From 4becfe471347f5d9b54372b8cf5d0de256079f7e Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Tue, 22 Oct 2024 03:10:06 +0000 Subject: [PATCH 01/33] Add Windows pipeline --- .github/workflows/win_ci.yml | 109 +++++++++++++++++++++++++++++++++++ 1 file changed, 109 insertions(+) create mode 100644 .github/workflows/win_ci.yml diff --git a/.github/workflows/win_ci.yml b/.github/workflows/win_ci.yml new file mode 100644 index 0000000..86815af --- /dev/null +++ b/.github/workflows/win_ci.yml @@ -0,0 +1,109 @@ +name: Windows_CI +on: + push: + branches: + - main + - rel-* + pull_request: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + Win32_debug_no_ort: + runs-on: windows-2022 + permissions: + actions: read + contents: read + security-events: write + steps: + - uses: actions/checkout@v4 + - name: Initialize CodeQL + uses: github/codeql-action/init@v3 + with: + config-file: ./.github/codeql/codeql-config.yml + languages: 'cpp' + - run: | + cmake --workflow --preset windows_win32_debug_no_ort_workflow + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v3 + with: + category: "/language:cpp" + output: sarif-results + upload: failure-only + + - name: filter-sarif + uses: advanced-security/filter-sarif@v1 + with: + patterns: | + +**/*.cc + +**/*.h + -tests/**/*.* + -build/**/*.* + input: sarif-results/cpp.sarif + output: sarif-results/cpp.sarif + + - name: Upload SARIF + uses: github/codeql-action/upload-sarif@v3 + with: + sarif_file: sarif-results/cpp.sarif + + Win32_release_no_ort: + runs-on: windows-2022 + steps: + - uses: actions/checkout@v4 + - run: | + cmake --workflow --preset windows_win32_release_no_ort_workflow + + WinX64_debug_no_ort: + runs-on: windows-2022 + permissions: + actions: read + contents: read + security-events: write + steps: + - uses: actions/checkout@v4 + - name: Initialize CodeQL + uses: github/codeql-action/init@v3 + with: + config-file: 
./.github/codeql/codeql-config.yml + languages: 'cpp' + - run: | + cmake --workflow --preset windows_x64_debug_no_ort_workflow + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v3 + with: + category: "/language:cpp" + output: sarif-results + upload: failure-only + + - name: filter-sarif + uses: advanced-security/filter-sarif@v1 + with: + patterns: | + +**/*.cc + +**/*.h + -tests/**/*.* + -build/**/*.* + input: sarif-results/cpp.sarif + output: sarif-results/cpp.sarif + + - name: Upload SARIF + uses: github/codeql-action/upload-sarif@v3 + with: + sarif_file: sarif-results/cpp.sarif + + WinX64_release_no_ort: + runs-on: windows-2022 + steps: + - uses: actions/checkout@v4 + - run: | + cmake --workflow --preset windows_x64_release_no_ort_workflow + + WinX64_release: + runs-on: windows-2022 + steps: + - uses: actions/checkout@v4 + - run: | + cmake --workflow --preset windows_x64_release_workflow \ No newline at end of file From 43096b6713fb170bf66980cf0aa5b998d04cb7a2 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 6 Dec 2024 20:53:35 +0000 Subject: [PATCH 02/33] update --- include/mlas_gemm_postprocessor.h | 1 - include/mlas_qnbit.h | 82 +- src/common/logging/logging.cc | 14 +- src/common/profiler.cc | 4 +- src/common/profiler.h | 4 +- src/common/threadpool.cc | 4 +- src/core/platform/windows/stacktrace.cc | 1 - src/lib/CMakeLists.txt | 1599 +++++++++-------- src/lib/fp16_common.h | 17 + src/lib/hqnbitgemm_kernel_neon_fp16.cpp | 898 +++++++++ src/lib/mlasi.h | 34 +- src/lib/platform.cpp | 25 +- src/lib/qgemm.h | 9 +- src/lib/{sqnbitgemm.cpp => qnbitgemm.cpp} | 404 ++++- src/lib/{sqnbitgemm.h => qnbitgemm.h} | 120 +- ...nel_neon.cpp => qnbitgemm_kernel_neon.cpp} | 49 +- ..._kernel_neon.h => qnbitgemm_kernel_neon.h} | 53 +- src/lib/scalar/SgemmKernelScalar.cpp | 36 +- src/lib/sgemm.cpp | 6 +- src/lib/sqnbitgemm_kernel_avx2.cpp | 34 +- .../sqnbitgemm_kernel_avx2_int8_blklen16.h | 2 +- .../sqnbitgemm_kernel_avx2_int8_blklen32.h | 2 +- 
.../sqnbitgemm_kernel_avx2_int8_blklen64.h | 4 +- src/lib/sqnbitgemm_kernel_avx512.cpp | 26 +- src/lib/sqnbitgemm_kernel_avx512_int8.h | 12 +- .../sqnbitgemm_kernel_avx512_int8_blklen128.h | 2 +- .../sqnbitgemm_kernel_avx512_int8_blklen16.h | 2 +- .../sqnbitgemm_kernel_avx512_int8_blklen32.h | 2 +- .../sqnbitgemm_kernel_avx512_int8_blklen64.h | 2 +- src/lib/sqnbitgemm_kernel_avx512vnni.cpp | 20 +- src/lib/sqnbitgemm_kernel_avx_common.h | 24 +- src/lib/sqnbitgemm_kernel_avx_common_fp32.h | 2 +- src/lib/sqnbitgemm_kernel_avx_common_int8.h | 2 +- src/lib/sqnbitgemm_kernel_neon_fp32.cpp | 10 +- src/lib/sqnbitgemm_kernel_neon_int8.cpp | 8 +- ...bitgemm_m1_sym_kernel_avx2_int8_blklen32.h | 2 +- ...bitgemm_m1_sym_kernel_avx2_int8_blklen64.h | 2 +- src/ort_include/core/common/logging/logging.h | 3 +- .../platform/EigenNonBlockingThreadPool.h | 19 +- src/ort_include/core/platform/ort_mutex.h | 9 - tests/bench/CMakeLists.txt | 2 +- ...nch_sqnbitgemm.cpp => bench_qnbitgemm.cpp} | 111 +- tests/bench/bench_util.h | 25 +- tests/unittest/test_hqnbitgemm_neon.cpp | 501 ++++++ tests/unittest/test_sqnbitgemm.cpp | 40 +- tests/unittest/test_util.h | 12 +- 46 files changed, 3021 insertions(+), 1219 deletions(-) create mode 100644 src/lib/hqnbitgemm_kernel_neon_fp16.cpp rename src/lib/{sqnbitgemm.cpp => qnbitgemm.cpp} (62%) rename src/lib/{sqnbitgemm.h => qnbitgemm.h} (71%) rename src/lib/{sqnbitgemm_kernel_neon.cpp => qnbitgemm_kernel_neon.cpp} (74%) rename src/lib/{sqnbitgemm_kernel_neon.h => qnbitgemm_kernel_neon.h} (69%) delete mode 100644 src/ort_include/core/platform/ort_mutex.h rename tests/bench/{bench_sqnbitgemm.cpp => bench_qnbitgemm.cpp} (53%) create mode 100644 tests/unittest/test_hqnbitgemm_neon.cpp diff --git a/include/mlas_gemm_postprocessor.h b/include/mlas_gemm_postprocessor.h index 8c24705..7ea29eb 100644 --- a/include/mlas_gemm_postprocessor.h +++ b/include/mlas_gemm_postprocessor.h @@ -16,7 +16,6 @@ Module Name: --*/ #pragma once -#include template class 
MLAS_GEMM_POSTPROCESSOR diff --git a/include/mlas_qnbit.h b/include/mlas_qnbit.h index 232bf22..9608644 100644 --- a/include/mlas_qnbit.h +++ b/include/mlas_qnbit.h @@ -27,51 +27,50 @@ Module Name: * @brief Define compute types of block quantization, in order of decreasing accuracy. */ typedef enum { - CompUndef = 0, /*!< undef */ - CompFp32, /*!< input fp32, accumulator fp32 */ - CompFp16, /*!< input fp16, accumulator fp16 */ - CompBf16, /*!< input bf16, accumulator fp32 */ - CompInt8, /*!< input int8, accumulator int32 */ - - // special values that should be the first and last actual values - - CompMostAccurate = CompUndef, - CompLeastAccurate = CompInt8, -} MLAS_SQNBIT_GEMM_COMPUTE_TYPE; + SQNBIT_CompFp32, /*!< input fp32, accumulator fp32 */ + HQNBIT_CompFp16, /*!< input fp16, accumulator fp16 */ + BHQNBIT_CompBf16, /*!< input bf16, accumulator fp32 */ + SQNBIT_CompInt8, /*!< input int8, accumulator int32, input fp32 */ + HQNBIT_CompInt8, /*!< input int8, accumulator int32, input fp16 */ +} MLAS_QNBIT_GEMM_COMPUTE_TYPE; /** * @brief Data parameters for float/n-bit quantized int GEMM routine. 
+ * + * @tparam T data type of input A */ -struct MLAS_SQNBIT_GEMM_DATA_PARAMS { - const float* A = nullptr; ///< address of A (float32 matrix) +template +struct MLAS_QNBIT_GEMM_DATA_PARAMS { + const T* A = nullptr; ///< address of A (float32/16 matrix) size_t lda = 0; ///< leading dimension of A const void* QuantBDataWorkspace; ///< address of quantized B (quantized n-bit int values) const std::byte* PackedQuantBData = nullptr; /// address of packed quantized B data - const float* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block + const T* QuantBScale = nullptr; ///< address of scale values of quantized B, one per block const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block - const float* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block - const float* Bias = nullptr; ///< optional address of Bias, vector size N - float* C = nullptr; ///< address of result matrix + const T* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block + const T* Bias = nullptr; ///< optional address of Bias, vector size N + T* C = nullptr; ///< address of result matrix size_t ldc = 0; ///< leading dimension of C ///< optional post processing to apply to result matrix - MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; + MLAS_GEMM_POSTPROCESSOR* PostProcessor = nullptr; }; /** * @brief Batched GEMM: C = A * B + Bias - * A must be a float32 matrix + * A must be a float32/16 matrix * B must be a quantized and packed n-bit int matrix * - * Call MlasIsSQNBitGemmAvailable() with the same parameters to determine whether this function may be called. + * Call MlasIsQNBitGemmAvailable() with the same parameters to determine whether this function may be called. 
* - * Call MlasSQNBitGemmPackQuantBDataSize() with the same parameters to determine whether - * MLAS_SQNBIT_GEMM_DATA_PARAMS::QuantBData in `DataParams` should point to a buffer packed with - * MlasSQNBitGemmPackQuantBData(). + * Call MlasQNBitGemmPackQuantBDataSize() with the same parameters to determine whether + * MLAS_QNBIT_GEMM_DATA_PARAMS::QuantBData in `DataParams` should point to a buffer packed with + * MlasQNBitGemmPackQuantBData(). * - * Call MlasSQNBitGemmBatchWorkspaceSize() with the same parameters to determine whether `Workspace` should + * Call MlasQNBitGemmBatchWorkspaceSize() with the same parameters to determine whether `Workspace` should * point to an intermediate workspace buffer. * + * @tparam T data type of input A * @param[in] M row size of matrix A and C * @param[in] N column size of matrix B and C * @param[in] K column size of matrix A and row size of matrix B @@ -81,36 +80,37 @@ struct MLAS_SQNBIT_GEMM_DATA_PARAMS { * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) * @param[inout] DataParams An array (size BatchN) of parameter blocks * @param[in] Workspace Address of intermediate workspace buffer. - If MlasSQNBitGemmBatchWorkspaceSize() returns a non-zero value, this must be a + If MlasQNBitGemmBatchWorkspaceSize() returns a non-zero value, this must be a buffer with at least that many bytes. Otherwise, it may be nullptr. * @param[in] ThreadPool optional thread pool to use */ +template void MLASCALL -MlasSQNBitGemmBatch( +MlasQNBitGemmBatch( size_t M, size_t N, size_t K, size_t BatchN, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, MLAS_THREADPOOL* ThreadPool = nullptr ); /** - * @brief Determines whether a float32/quantized n-bit int GEMM implementation is available on the current platform. 
+ * @brief Determines whether a float32/16 quantized n-bit int GEMM implementation is available on the current platform. * * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ bool MLASCALL -MlasIsSQNBitGemmAvailable( +MlasIsQNBitGemmAvailable( size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** @@ -126,22 +126,22 @@ MlasIsSQNBitGemmAvailable( * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ size_t MLASCALL -MlasSQNBitGemmBatchWorkspaceSize( +MlasQNBitGemmBatchWorkspaceSize( size_t M, size_t N, size_t K, size_t BatchN, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** * @brief Gets the size in bytes of the packed quantized B data. - * If non-zero, the quantized B data must first be packed by calling MlasSQNBitGemmPackQuantBData() with a buffer of - * this size, and then that packed quantized B data buffer must be passed to MlasSQNBitGemmBatch(). - * If zero, MlasSQNBitGemmPackQuantBData() must not be called and the quantized B data must be directly passed to - * MlasSQNBitGemmBatch(). + * If non-zero, the quantized B data must first be packed by calling MlasQNBitGemmPackQuantBData() with a buffer of + * this size, and then that packed quantized B data buffer must be passed to MlasQNBitGemmBatch(). + * If zero, MlasQNBitGemmPackQuantBData() must not be called and the quantized B data must be directly passed to + * MlasQNBitGemmBatch(). 
* * @param[in] N column size of matrix B and C * @param[in] K column size of matrix A and row size of matrix B @@ -150,12 +150,12 @@ MlasSQNBitGemmBatchWorkspaceSize( * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ size_t MLASCALL -MlasSQNBitGemmPackQuantBDataSize( +MlasQNBitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); /** @@ -186,12 +186,12 @@ MlasSQNBitGemmPackQuantBDataSize( * @param[in] ThreadPool thread pool to use (no parallel if nullptr) */ void MLASCALL -MlasSQNBitGemmPackQuantBData( +MlasQNBitGemmPackQuantBData( size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, void* PackedQuantBDataAndOrBlkSum, const void* QuantBScale, diff --git a/src/common/logging/logging.cc b/src/common/logging/logging.cc index a326095..103a932 100644 --- a/src/common/logging/logging.cc +++ b/src/common/logging/logging.cc @@ -63,13 +63,13 @@ LoggingManager* LoggingManager::GetDefaultInstance() { #pragma warning(disable : 26426) #endif -static OrtMutex& DefaultLoggerMutex() noexcept { - static OrtMutex mutex; +static std::mutex& DefaultLoggerMutex() noexcept { + static std::mutex mutex; return mutex; } Logger* LoggingManager::s_default_logger_ = nullptr; -OrtMutex sink_mutex_; +std::mutex sink_mutex_; #ifdef _MSC_VER #pragma warning(pop) @@ -106,7 +106,7 @@ LoggingManager::LoggingManager(std::unique_ptr sink, Severity default_min // lock mutex to create instance, and enable logging // this matches the mutex usage in Shutdown - std::lock_guard guard(DefaultLoggerMutex()); + std::lock_guard guard(DefaultLoggerMutex()); if (DefaultLoggerManagerInstance().load() != nullptr) { ORT_THROW("Only one instance of LoggingManager created with InstanceType::Default can exist at any point in time."); @@ -126,7 +126,7 @@ 
LoggingManager::LoggingManager(std::unique_ptr sink, Severity default_min LoggingManager::~LoggingManager() { if (owns_default_logger_) { // lock mutex to reset DefaultLoggerManagerInstance() and free default logger from this instance. - std::lock_guard guard(DefaultLoggerMutex()); + std::lock_guard guard(DefaultLoggerMutex()); #if ((__cplusplus >= 201703L) || (defined(_MSVC_LANG) && (_MSVC_LANG >= 201703L))) DefaultLoggerManagerInstance().store(nullptr, std::memory_order_release); #else @@ -252,7 +252,7 @@ unsigned int GetProcessId() { bool LoggingManager::AddSinkOfType(SinkType sink_type, std::function()> sinkFactory, logging::Severity severity) { - std::lock_guard guard(sink_mutex_); + std::lock_guard guard(sink_mutex_); if (sink_->GetType() != SinkType::CompositeSink) { // Current sink is not a composite, create a new composite sink and add the current sink to it auto new_composite = std::make_unique(); @@ -274,7 +274,7 @@ bool LoggingManager::AddSinkOfType(SinkType sink_type, std::function guard(sink_mutex_); + std::lock_guard guard(sink_mutex_); if (sink_->GetType() == SinkType::CompositeSink) { auto composite_sink = static_cast(sink_.get()); diff --git a/src/common/profiler.cc b/src/common/profiler.cc index 71bca6e..8562e55 100644 --- a/src/common/profiler.cc +++ b/src/common/profiler.cc @@ -85,7 +85,7 @@ void Profiler::EndTimeAndRecordEvent(EventCategory category, custom_logger_->SendProfileEvent(event); } else { // TODO: sync_gpu if needed. 
- std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); if (events_.size() < max_num_events_) { events_.emplace_back(std::move(event)); } else { @@ -115,7 +115,7 @@ std::string Profiler::EndProfiling() { LOGS(*session_logger_, INFO) << "Writing profiler data to file " << profile_stream_file_; } - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); profile_stream_ << "[\n"; for (const auto& ep_profiler : ep_profilers_) { diff --git a/src/common/profiler.h b/src/common/profiler.h index a0bca00..0103d8a 100644 --- a/src/common/profiler.h +++ b/src/common/profiler.h @@ -11,7 +11,7 @@ #include "core/common/profiler_common.h" #include "core/common/logging/logging.h" -#include "core/platform/ort_mutex.h" +#include namespace onnxruntime { @@ -130,7 +130,7 @@ class Profiler { static std::atomic global_max_num_events_; // Mutex controlling access to profiler data - OrtMutex mutex_; + std::mutex mutex_; bool enabled_{false}; #if defined(__wasm__) /* diff --git a/src/common/threadpool.cc b/src/common/threadpool.cc index 52d1c1e..7e8caa7 100644 --- a/src/common/threadpool.cc +++ b/src/common/threadpool.cc @@ -21,11 +21,11 @@ limitations under the License. #include "core/common/cpuid_info.h" #include "core/common/eigen_common_wrapper.h" #include "core/platform/EigenNonBlockingThreadPool.h" -#include "core/platform/ort_mutex.h" +#include #if !defined(ORT_MINIMAL_BUILD) #ifdef _WIN32 #include -#include +#include "processthreadsapi.h" #include #include #elif defined(__APPLE__) diff --git a/src/core/platform/windows/stacktrace.cc b/src/core/platform/windows/stacktrace.cc index 3401507..cc23d70 100644 --- a/src/core/platform/windows/stacktrace.cc +++ b/src/core/platform/windows/stacktrace.cc @@ -30,7 +30,6 @@ class CaptureStackTrace { // Get the stack trace. Currently only enabled for a DEBUG build as we require the DbgHelp library. 
std::vector GetStackTrace() { #ifndef NDEBUG -// TVM need to run with shared CRT, so won't work with debug helper now #if (defined __cpp_lib_stacktrace) && !(defined _OPSCHEMA_LIB_) && !(defined _GAMING_XBOX) && !(defined ONNXRUNTIME_ENABLE_MEMLEAK_CHECK) return detail::CaptureStackTrace().Trace(); #else diff --git a/src/lib/CMakeLists.txt b/src/lib/CMakeLists.txt index fbc7037..95829fe 100644 --- a/src/lib/CMakeLists.txt +++ b/src/lib/CMakeLists.txt @@ -1,795 +1,804 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -set(MLAS_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..) -set(MLAS_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}) -set(MLAS_INC_DIR ${MLAS_ROOT}/../include) - -include_directories(${ONNXRUNTIME_INCLUDE_DIR}) - -#Set global compile flags for all the source code(including third_party code like protobuf) -#This section must be before any add_subdirectory, otherwise build may fail because /MD,/MT mismatch -if (MSVC) - if (CMAKE_VS_PLATFORM_NAME) - # Multi-platform generator - set(onnxruntime_target_platform ${CMAKE_VS_PLATFORM_NAME}) - else() - set(onnxruntime_target_platform ${CMAKE_SYSTEM_PROCESSOR}) - endif() - if (onnxruntime_target_platform STREQUAL "ARM64") - set(onnxruntime_target_platform "ARM64") - enable_language(ASM_MARMASM) - elseif (onnxruntime_target_platform STREQUAL "ARM64EC") - enable_language(ASM_MARMASM) - elseif (onnxruntime_target_platform STREQUAL "ARM" OR CMAKE_GENERATOR MATCHES "ARM") - set(onnxruntime_target_platform "ARM") - enable_language(ASM_MARMASM) - elseif (onnxruntime_target_platform STREQUAL "x64" OR onnxruntime_target_platform STREQUAL "x86_64" OR onnxruntime_target_platform STREQUAL "AMD64" OR CMAKE_GENERATOR MATCHES "Win64") - set(onnxruntime_target_platform "x64") - enable_language(ASM_MASM) - elseif (onnxruntime_target_platform STREQUAL "Win32" OR onnxruntime_target_platform STREQUAL "x86" OR onnxruntime_target_platform STREQUAL "i386" OR onnxruntime_target_platform STREQUAL "i686") - 
set(onnxruntime_target_platform "x86") - enable_language(ASM_MASM) - message("Enabling SAFESEH for x86 build") - set(CMAKE_ASM_MASM_FLAGS "${CMAKE_ASM_MASM_FLAGS} /safeseh") - else() - message(FATAL_ERROR "Unknown CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") - endif() -endif() - -# -# All hardware agnostic source files here -# hardware specific files would cause trouble in -# multi-target build -# -add_library(onnxruntime_mlas STATIC - ${MLAS_SRC_DIR}/mlasi.h - ${MLAS_SRC_DIR}/platform.cpp - ${MLAS_SRC_DIR}/threading.cpp - ${MLAS_SRC_DIR}/sgemm.cpp - ${MLAS_SRC_DIR}/halfgemm.cpp - ${MLAS_SRC_DIR}/qgemm.cpp - ${MLAS_SRC_DIR}/qdwconv.cpp - ${MLAS_SRC_DIR}/convolve.cpp - ${MLAS_SRC_DIR}/convsym.cpp - ${MLAS_SRC_DIR}/pooling.cpp - ${MLAS_SRC_DIR}/transpose.cpp - ${MLAS_SRC_DIR}/reorder.cpp - ${MLAS_SRC_DIR}/snchwc.cpp - ${MLAS_SRC_DIR}/activate.cpp - ${MLAS_SRC_DIR}/logistic.cpp - ${MLAS_SRC_DIR}/tanh.cpp - ${MLAS_SRC_DIR}/erf.cpp - ${MLAS_SRC_DIR}/compute.cpp - ${MLAS_SRC_DIR}/quantize.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_default.cpp - ${MLAS_SRC_DIR}/qladd.cpp - ${MLAS_SRC_DIR}/qlmul.cpp - ${MLAS_SRC_DIR}/qpostprocessor.cpp - ${MLAS_SRC_DIR}/qlgavgpool.cpp - ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp - ${MLAS_SRC_DIR}/sqnbitgemm.h - ${MLAS_SRC_DIR}/sqnbitgemm.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h - ${MLAS_SRC_DIR}/flashattn.cpp - ${MLAS_SRC_DIR}/cast.cpp -) - -target_sources(onnxruntime_mlas PRIVATE - ${MLAS_INC_DIR}/mlas_float16.h - ${MLAS_INC_DIR}/mlas_gemm_postprocessor.h - ${MLAS_INC_DIR}/mlas_q4.h - ${MLAS_INC_DIR}/mlas_qnbit.h - ${MLAS_INC_DIR}/mlas.h -) - -if (NOT onnxruntime_ORT_MINIMAL_BUILD) - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/q4_dq.cpp - ${MLAS_SRC_DIR}/q4gemm.cpp - ) -endif() - - -#TODO: set MASM flags properly -function(setup_mlas_source_for_windows) - - # - # Sources common for all platforms. 
- # - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/activate_fp16.cpp - ${MLAS_SRC_DIR}/dwconv.cpp - ${MLAS_SRC_DIR}/pooling_fp16.cpp - ) - - #The onnxruntime_target_platform variable was added by Windows AI team in onnxruntime_common.cmake - #Don't use it for other platforms. - if((onnxruntime_target_platform STREQUAL "ARM64") OR (onnxruntime_target_platform STREQUAL "ARM64EC")) - set(PREPROCESS_ARMASM_FLAGS "") - set(ARMASM_FLAGS "") - - if(onnxruntime_target_platform STREQUAL "ARM64") - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp - ${MLAS_SRC_DIR}/fp16_neon_common.cpp - ) - - set(mlas_platform_preprocess_srcs - ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelDot.asm - ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelDotLd64.asm - ${MLAS_SRC_DIR}/arm64/ConvSymU8KernelDot.asm - ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/ConvSymU8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymS8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymU8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/DepthwiseQConvKernelSize9Neon.asm - ${MLAS_SRC_DIR}/arm64/HalfGemmKernelNeon.asm - ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelUdot.asm - ${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelSdot.asm - ${MLAS_SRC_DIR}/arm64/SgemmKernelNeon.asm - ${MLAS_SRC_DIR}/arm64/SgemvKernelNeon.asm - ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDot.asm - ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDotLd64.asm - ) - else() - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp - ) - - 
set(mlas_platform_preprocess_srcs - ${MLAS_SRC_DIR}/arm64ec/QgemmU8X8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64ec/SgemmKernelNeon.asm - ) - - string(APPEND PREPROCESS_ARMASM_FLAGS " /arm64EC") - string(APPEND ARMASM_FLAGS " -machine ARM64EC") - endif() - - if(CMAKE_BUILD_TYPE STREQUAL "Debug") - string(APPEND ARMASM_FLAGS " -g") - endif() - - # Remove double quotes from flag strings. - separate_arguments(PREPROCESS_ARMASM_FLAGS NATIVE_COMMAND "${PREPROCESS_ARMASM_FLAGS}") - separate_arguments(ARMASM_FLAGS NATIVE_COMMAND "${ARMASM_FLAGS}") - - # Run the C precompiler on each input before the assembler. - foreach(asm_filename ${mlas_platform_preprocess_srcs}) - get_filename_component(asm_filename_base ${asm_filename} NAME_WLE) - set(preprocess_filename ${CMAKE_CURRENT_BINARY_DIR}/${asm_filename_base}.i) - set(obj_filename ${CMAKE_CURRENT_BINARY_DIR}/${asm_filename_base}.obj) - add_custom_command( - OUTPUT ${obj_filename} - COMMAND - cl.exe ${PREPROCESS_ARMASM_FLAGS} /P ${asm_filename} /Fi${preprocess_filename} - COMMAND - armasm64.exe ${ARMASM_FLAGS} ${preprocess_filename} ${obj_filename} - DEPENDS ${asm_filename} - BYPRODUCTS ${preprocess_filename} - ) - target_sources(onnxruntime_mlas PRIVATE ${obj_filename}) - endforeach() - elseif(onnxruntime_target_platform STREQUAL "ARM") - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/arm/sgemmc.cpp - ) - elseif(onnxruntime_target_platform STREQUAL "x64") - - file(GLOB_RECURSE mlas_platform_srcs_avx CONFIGURE_DEPENDS - "${MLAS_SRC_DIR}/intrinsics/avx/*.cpp" - ) - set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "/arch:AVX") - - file(GLOB_RECURSE mlas_platform_srcs_avx2 CONFIGURE_DEPENDS - "${MLAS_SRC_DIR}/intrinsics/avx2/*.cpp" - ) - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "/arch:AVX2") - - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/dgemm.cpp - ${mlas_platform_srcs_avx} - ${mlas_platform_srcs_avx2} - 
${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp - ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp - ${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAmx.asm - ${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/QgemmU8U8KernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/QgemmU8X8KernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/QgemmU8X8KernelAvx512Core.asm - ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx512Core.asm - ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx512Vnni.asm - ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvxVnni.asm - ${MLAS_SRC_DIR}/amd64/ConvSymKernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/ConvSymKernelAvx512Core.asm - ${MLAS_SRC_DIR}/amd64/DgemmKernelSse2.asm - ${MLAS_SRC_DIR}/amd64/DgemmKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/DgemmKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/DgemmKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelSse2.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelM1Avx.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/SconvKernelSse2.asm - ${MLAS_SRC_DIR}/amd64/SconvKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/SconvKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/SconvKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/SpoolKernelSse2.asm - ${MLAS_SRC_DIR}/amd64/SpoolKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/SpoolKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/sgemma.asm - ${MLAS_SRC_DIR}/amd64/cvtfp16a.asm - ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/TransKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/TransKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/LogisticKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/TanhKernelFma3.asm - 
${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm - ) - if(MSVC_VERSION GREATER_EQUAL 1933) - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/amd64/cvtfp16Avx.asm - ) - endif() - - if (NOT onnxruntime_ORT_MINIMAL_BUILD) - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/q4gemm_avx512.cpp - ) - endif() - else() - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp - ${MLAS_SRC_DIR}/i386/SgemmKernelSse2.asm - ${MLAS_SRC_DIR}/i386/SgemmKernelAvx.asm - ) - endif() -endfunction() - -if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - if (onnxruntime_ENABLE_WEBASSEMBLY_SIMD) - file(GLOB_RECURSE mlas_platform_srcs - "${MLAS_SRC_DIR}/wasm_simd/*.cpp" - ) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/qgemm_kernel_wasmsimd.cpp - ) - else() - file(GLOB_RECURSE mlas_platform_srcs - "${MLAS_SRC_DIR}/scalar/*.cpp" - ) - endif() - target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) -elseif(MSVC) - setup_mlas_source_for_windows() -else() - - if(APPLE) - get_target_property(ONNXRUNTIME_MLAS_OSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) - - if(NOT ONNXRUNTIME_MLAS_OSX_ARCH) - set(ONNXRUNTIME_MLAS_OSX_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) - endif() - foreach(OSX_ARCH ${ONNXRUNTIME_MLAS_OSX_ARCH}) - if (OSX_ARCH STREQUAL "arm64") - set(ARM64 TRUE) - elseif (OSX_ARCH STREQUAL "arm64e") - set(ARM64 TRUE) - elseif (OSX_ARCH STREQUAL "arm") - set(ARM TRUE) - elseif (OSX_ARCH STREQUAL "x86_64") - set(X86_64 TRUE) - elseif (OSX_ARCH STREQUAL "i386") - set(X86 TRUE) - endif() - endforeach() - elseif(ANDROID) - if (CMAKE_ANDROID_ARCH_ABI STREQUAL "armeabi-v7a") - set(ARM TRUE) - elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "arm64-v8a") - set(ARM64 TRUE) - elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86_64") - set(X86_64 TRUE) - elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86") - set(X86 TRUE) - endif() - else() - #Linux/FreeBSD/PowerPC/... 
- #The value of CMAKE_SYSTEM_PROCESSOR should be from `uname -m` - #Example values: - #arm64v8/ubuntu -> aarch64 - #arm32v6/alpine -> armv7l - #arm32v7/centos -> armv7l - #ppc64le/debian -> ppc64le - #s390x/ubuntu -> s390x - #ppc64le/busybox -> ppc64le - #arm64v8/ubuntu -> aarch64 - #Android: armv7-a aarch64 i686 x86_64 - #chasun: I don't think anyone uses 'arm64' - if(CMAKE_SYSTEM_PROCESSOR MATCHES "^arm64.*") - set(ARM64 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^arm.*") - set(ARM TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*") - set(ARM64 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(powerpc.*|ppc.*)") - set(POWER TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") - set(X86 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") - set(X86_64 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^loongarch64.*") - set(LOONGARCH64 TRUE) - endif() - endif() - - if(APPLE) - get_target_property(ONNXRUNTIME_MLAS_MACOSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) - endif() - list(LENGTH ONNXRUNTIME_MLAS_MACOSX_ARCH ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH) - if(ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH GREATER 1) - set(ONNXRUNTIME_MLAS_MULTI_ARCH TRUE) - endif() - #If ONNXRUNTIME_MLAS_MULTI_ARCH is true, we need to go through every if branch below - #and split MLAS to multiple static libraries. - #Otherwise, it works like if(...) elseif(...) elseif(...) 
endif() - set(MLAS_SOURCE_IS_NOT_SET 1) - if(ARM) - enable_language(ASM) - - set(CMAKE_ASM_FLAGS "${CMAKE_ASM_FLAGS} -mfpu=neon") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon") - - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/aarch32/QgemmU8X8KernelNeon.S - ${MLAS_SRC_DIR}/arm/sgemmc.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp - ) - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(ARM64 AND MLAS_SOURCE_IS_NOT_SET ) - enable_language(ASM) - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelDot.S - ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelDotLd64.S - ${MLAS_SRC_DIR}/aarch64/ConvSymU8KernelDot.S - ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/ConvSymU8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymS8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymU8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvKernelSize9Neon.S - ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUdot.S - ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSdot.S - ${MLAS_SRC_DIR}/aarch64/SgemmKernelNeon.S - ${MLAS_SRC_DIR}/aarch64/SgemvKernelNeon.S - ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdot.S - ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdotLd64.S - ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.h - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp - PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") - if (NOT APPLE) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S - ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S - 
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S - ${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S - ${MLAS_SRC_DIR}/activate_fp16.cpp - ${MLAS_SRC_DIR}/dwconv.cpp - ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/pooling_fp16.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp - ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/fp16_neon_common.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - endif() - - if(ONNXRUNTIME_MLAS_MULTI_ARCH) - add_library(onnxruntime_mlas_arm64 STATIC ${mlas_platform_srcs}) - list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_arm64) - set_target_properties(onnxruntime_mlas_arm64 PROPERTIES OSX_ARCHITECTURES "arm64") - set(mlas_platform_srcs ) - else() - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(POWER AND MLAS_SOURCE_IS_NOT_SET) - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/power/SgemmKernelPower.cpp - ${MLAS_SRC_DIR}/dgemm.cpp - 
${MLAS_SRC_DIR}/power/DgemmKernelPower.cpp - ${MLAS_SRC_DIR}/power/QuantizePower.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPower.cpp PROPERTIES COMPILE_FLAGS "-DSINGLE") - - check_cxx_compiler_flag("-mcpu=power9" HAS_POWER9) - if (HAS_POWER9) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/power/QuantizePowerVSX.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/power/QuantizePowerVSX.cpp PROPERTIES COMPILE_FLAGS "-mcpu=power9") - endif() - - check_cxx_compiler_flag("-mcpu=power10" HAS_POWER10) - if(HAS_POWER10) - set(CMAKE_REQUIRED_FLAGS "-mcpu=power10") - check_cxx_source_compiles(" - #include - int main() { - __vector_quad acc0; - __builtin_mma_xxsetaccz (&acc0); - return 0; - }" - COMPILES_P10 - ) - if(COMPILES_P10) - check_cxx_source_compiles(" - #ifdef _AIX - #define POWER_10 0x40000 - #define POWER_10_ANDUP (POWER_10) - #include - #define __power_10_andup() (_system_configuration.implementation & POWER_10_ANDUP) - int main() { - bool HasP10 = (__power_10_andup() && __power_mma_version() == MMA_V31); - return 0; - } - #else - #include - int main() { - unsigned long hwcap2 = getauxval(AT_HWCAP2); - bool HasP10 = ((hwcap2 & PPC_FEATURE2_MMA) && (hwcap2 & PPC_FEATURE2_ARCH_3_1)); - return 0; - } - } - #endif" - HAS_P10_RUNTIME - ) - if (HAS_P10_RUNTIME) - set_source_files_properties(${MLAS_SRC_DIR}/platform.cpp PROPERTIES COMPILE_FLAGS "-DPOWER10") - set_source_files_properties(${MLAS_SRC_DIR}/qgemm.cpp PROPERTIES COMPILE_FLAGS "-DPOWER10") - endif() - set(mlas_platform_srcs_power10 - ${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp - ${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp - ${MLAS_SRC_DIR}/power/qgemm_kernel_power10.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10 -DSINGLE") - set_source_files_properties(${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10") - 
set_source_files_properties(${MLAS_SRC_DIR}/power/qgemm_kernel_power10.cpp PROPERTIES COMPILE_FLAGS "-O3 -mcpu=power10") - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${mlas_platform_srcs_power10} - ) - endif() - endif() - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(X86 AND MLAS_SOURCE_IS_NOT_SET) - enable_language(ASM) - - set(mlas_platform_srcs_sse2 - ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp - ${MLAS_SRC_DIR}/x86/SgemmKernelSse2.S - ) - set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2") - - set(mlas_platform_srcs_avx - ${MLAS_SRC_DIR}/x86/SgemmKernelAvx.S - ) - set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx") - - set(mlas_platform_srcs - ${mlas_platform_srcs_sse2} - ${mlas_platform_srcs_avx} - ) - - # In r23, NDK remove __x86.get_pc_thunk.* from libatomic. Add our own - # implementation to avoid external dependency. - if(ANDROID) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/x86/x86.get_pc_thunk.S - ) - endif() - - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(X86_64 AND MLAS_SOURCE_IS_NOT_SET) - enable_language(ASM) - - # Forward the flags for the minimum target platform version from the C - # compiler to the assembler. This works around CMakeASMCompiler.cmake.in - # not including the logic to set this flag for the assembler. - set(CMAKE_ASM${ASM_DIALECT}_OSX_DEPLOYMENT_TARGET_FLAG "${CMAKE_C_OSX_DEPLOYMENT_TARGET_FLAG}") - - # The LLVM assembler does not support the .arch directive to enable instruction - # set extensions and also doesn't support AVX-512F instructions without - # turning on support via command-line option. Group the sources by the - # instruction set extension and explicitly set the compiler flag as appropriate. 
- - set(mlas_platform_srcs_sse2 - ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp - ${MLAS_SRC_DIR}/x86_64/DgemmKernelSse2.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelSse2.S - ${MLAS_SRC_DIR}/x86_64/SgemmTransposePackB16x4Sse2.S - ${MLAS_SRC_DIR}/x86_64/SconvKernelSse2.S - ${MLAS_SRC_DIR}/x86_64/SpoolKernelSse2.S - ) - if(NOT APPLE) - set(mlas_platform_srcs_sse2 - ${mlas_platform_srcs_sse2} - ${MLAS_SRC_DIR}/x86_64/cvtfp16a.S - ) - endif() - set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2") - - set(mlas_platform_srcs_avx - ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelM1Avx.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelM1TransposeBAvx.S - ${MLAS_SRC_DIR}/x86_64/SgemmTransposePackB16x4Avx.S - ${MLAS_SRC_DIR}/x86_64/SconvKernelAvx.S - ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx.S - ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx.S - ${MLAS_SRC_DIR}/intrinsics/avx/min_max_elements.cpp - ) - set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx") - - set(mlas_platform_srcs_avx2 - ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/QgemmU8U8KernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvxVnni.S - ${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/DgemmKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/SconvKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/TransKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/LogisticKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/TanhKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/ErfKernelFma3.S - ${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp - ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp - ) - if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE)) - set(mlas_platform_srcs_avx2 - ${mlas_platform_srcs_avx2} - ${MLAS_SRC_DIR}/x86_64/cvtfp16Avx.S - ) - 
endif() -message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") -message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}") - -if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "11") - message(STATUS "Using -mavx2 -mfma -mavxvnni flags") - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c -mavxvnni") -else() - message(STATUS "Using -mavx2 -mfma flags") - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c") -endif() - set(mlas_platform_srcs_avx512f - ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/SconvKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/TransKernelAvx512F.S - ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp - ) - set_source_files_properties(${mlas_platform_srcs_avx512f} PROPERTIES COMPILE_FLAGS "-mavx512f") - - set(mlas_platform_srcs_avx512core - ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx512Core.S - ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx512Vnni.S - ${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx512Core.S - ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx512Core.S - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp - ) - set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl") - - set(mlas_platform_srcs_avx512vnni - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp - ) - set_source_files_properties(${mlas_platform_srcs_avx512vnni} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f") - - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/activate_fp16.cpp - ${MLAS_SRC_DIR}/dwconv.cpp - ${MLAS_SRC_DIR}/dgemm.cpp - ${MLAS_SRC_DIR}/pooling_fp16.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp - ${mlas_platform_srcs_sse2} - 
${mlas_platform_srcs_avx} - ${mlas_platform_srcs_avx2} - ${mlas_platform_srcs_avx512f} - ${mlas_platform_srcs_avx512core} - ${mlas_platform_srcs_avx512vnni} - ) - - if (NOT onnxruntime_ORT_MINIMAL_BUILD) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/q4gemm_avx512.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/q4gemm_avx512.cpp PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f") - endif() - if(NOT APPLE) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmxCommon.S - ${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp - ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S - ) - set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") - set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") - endif() - - if(ONNXRUNTIME_MLAS_MULTI_ARCH) - add_library(onnxruntime_mlas_x86_64 STATIC ${mlas_platform_srcs}) - set_target_properties(onnxruntime_mlas_x86_64 PROPERTIES OSX_ARCHITECTURES "x86_64") - list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_x86_64) - set(mlas_platform_srcs ) - else() - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(LOONGARCH64 AND MLAS_SOURCE_IS_NOT_SET) - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/qgemm_kernel_lsx.cpp - ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLasx.S - ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLsx.S - ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLasx.S - ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLsx.S - ${MLAS_SRC_DIR}/loongarch64/SconvKernelLsx.S - ${MLAS_SRC_DIR}/loongarch64/SconvKernelLasx.S - ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLSX.S - ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLasx.S - ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4LSX.S - ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4Lasx.S - ${MLAS_SRC_DIR}/loongarch64/SoftmaxKernelLasx.S - ) - set(CMAKE_CXX_FLAGS 
"${CMAKE_CXX_FLAGS} -mlsx -mlasx") - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH AND MLAS_SOURCE_IS_NOT_SET) - file(GLOB_RECURSE mlas_platform_srcs - "${MLAS_SRC_DIR}/scalar/*.cpp") - endif() - target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) -endif() - -foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) - target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_INCLUDE_DIR} ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) - target_link_libraries(${mlas_target} Microsoft.GSL::GSL) - - set_target_properties(${mlas_target} PROPERTIES FOLDER "ONNXRuntime") -endforeach() - -if (WIN32) - target_compile_options(onnxruntime_mlas PRIVATE "$<$:/wd6385>" "$<$:/wd4127>") - if (onnxruntime_ENABLE_STATIC_ANALYSIS) - target_compile_options(onnxruntime_mlas PRIVATE "$<$:/analyze:stacksize 131072>") - endif() -endif() - -if (PLATFORM_NAME STREQUAL "macabi") - # Needed for maccatalyst C compilation - # i.e. the flags below add "--target=x86_64-apple-ios14.0-macabi -ffunction-sections -fdata-sections" - target_compile_options(onnxruntime_mlas PRIVATE ${CMAKE_C_FLAGS}) -endif() - -if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_mlas - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) -endif() - -# set up source group for MLAS source files -block() - set(source_group_srcs) - foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) - get_target_property(mlas_target_srcs ${mlas_target} SOURCES) - foreach(mlas_target_src ${mlas_target_srcs}) - cmake_path(IS_PREFIX MLAS_ROOT ${mlas_target_src} in_mlas_root) - if(in_mlas_root) - list(APPEND source_group_srcs ${mlas_target_src}) - endif() - endforeach() - endforeach() -endblock() - - - - # - # Command line tool for quantization and de-quantization of 2-D fp32 tensors - # based on block-wise quantization of int4 - # - - 
add_executable(onnxruntime_mlas_q4dq - ${MLAS_SRC_DIR}/q4_dq_cli.cpp - ) - target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) - set_target_properties(onnxruntime_mlas_q4dq PROPERTIES FOLDER "ONNXRuntimeTest") - - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_MLAS_LIBS}) - if(NOT MLAS_NO_ONNXRUNTIME) - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE onnxruntime_common) - endif() - if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE cpuinfo) - endif() - if (CMAKE_SYSTEM_NAME STREQUAL "Android") - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${android_shared_libs}) - endif() - - if(WIN32) - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE debug Dbghelp Advapi32) - endif() - if (onnxruntime_LINK_LIBATOMIC) - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE atomic) - endif() - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE Threads::Threads) - - if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) - set_target_properties(onnxruntime_mlas_q4dq PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1") - else() - set_target_properties(onnxruntime_mlas_q4dq PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1") - endif() - endif() - +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +set(MLAS_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..) 
+set(MLAS_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}) +set(MLAS_INC_DIR ${MLAS_ROOT}/../include) + +include_directories(${ONNXRUNTIME_INCLUDE_DIR}) + +#Set global compile flags for all the source code(including third_party code like protobuf) +#This section must be before any add_subdirectory, otherwise build may fail because /MD,/MT mismatch +if (MSVC) + if (CMAKE_VS_PLATFORM_NAME) + # Multi-platform generator + set(onnxruntime_target_platform ${CMAKE_VS_PLATFORM_NAME}) + else() + set(onnxruntime_target_platform ${CMAKE_SYSTEM_PROCESSOR}) + endif() + if (onnxruntime_target_platform STREQUAL "ARM64") + set(onnxruntime_target_platform "ARM64") + enable_language(ASM_MARMASM) + elseif (onnxruntime_target_platform STREQUAL "ARM64EC") + enable_language(ASM_MARMASM) + elseif (onnxruntime_target_platform STREQUAL "ARM" OR CMAKE_GENERATOR MATCHES "ARM") + set(onnxruntime_target_platform "ARM") + enable_language(ASM_MARMASM) + elseif (onnxruntime_target_platform STREQUAL "x64" OR onnxruntime_target_platform STREQUAL "x86_64" OR onnxruntime_target_platform STREQUAL "AMD64" OR CMAKE_GENERATOR MATCHES "Win64") + set(onnxruntime_target_platform "x64") + enable_language(ASM_MASM) + elseif (onnxruntime_target_platform STREQUAL "Win32" OR onnxruntime_target_platform STREQUAL "x86" OR onnxruntime_target_platform STREQUAL "i386" OR onnxruntime_target_platform STREQUAL "i686") + set(onnxruntime_target_platform "x86") + enable_language(ASM_MASM) + message("Enabling SAFESEH for x86 build") + set(CMAKE_ASM_MASM_FLAGS "${CMAKE_ASM_MASM_FLAGS} /safeseh") + else() + message(FATAL_ERROR "Unknown CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") + endif() +endif() + +# +# All hardware agnostic source files here +# hardware specific files would cause trouble in +# multi-target build +# +add_library(onnxruntime_mlas STATIC + ${MLAS_SRC_DIR}/mlasi.h + ${MLAS_SRC_DIR}/platform.cpp + ${MLAS_SRC_DIR}/threading.cpp + ${MLAS_SRC_DIR}/sgemm.cpp + ${MLAS_SRC_DIR}/halfgemm.cpp + ${MLAS_SRC_DIR}/qgemm.cpp + 
${MLAS_SRC_DIR}/qdwconv.cpp + ${MLAS_SRC_DIR}/convolve.cpp + ${MLAS_SRC_DIR}/convsym.cpp + ${MLAS_SRC_DIR}/pooling.cpp + ${MLAS_SRC_DIR}/transpose.cpp + ${MLAS_SRC_DIR}/reorder.cpp + ${MLAS_SRC_DIR}/snchwc.cpp + ${MLAS_SRC_DIR}/activate.cpp + ${MLAS_SRC_DIR}/logistic.cpp + ${MLAS_SRC_DIR}/tanh.cpp + ${MLAS_SRC_DIR}/erf.cpp + ${MLAS_SRC_DIR}/compute.cpp + ${MLAS_SRC_DIR}/quantize.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_default.cpp + ${MLAS_SRC_DIR}/qladd.cpp + ${MLAS_SRC_DIR}/qlmul.cpp + ${MLAS_SRC_DIR}/qpostprocessor.cpp + ${MLAS_SRC_DIR}/qlgavgpool.cpp + ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp + ${MLAS_SRC_DIR}/qnbitgemm.h + ${MLAS_SRC_DIR}/qnbitgemm.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h + ${MLAS_SRC_DIR}/flashattn.cpp + ${MLAS_SRC_DIR}/cast.cpp +) + +target_sources(onnxruntime_mlas PRIVATE + ${MLAS_INC_DIR}/mlas_float16.h + ${MLAS_INC_DIR}/mlas_gemm_postprocessor.h + ${MLAS_INC_DIR}/mlas_q4.h + ${MLAS_INC_DIR}/mlas_qnbit.h + ${MLAS_INC_DIR}/mlas.h +) + +if (NOT onnxruntime_ORT_MINIMAL_BUILD) + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/q4_dq.cpp + ${MLAS_SRC_DIR}/q4gemm.cpp + ) +endif() + + +#TODO: set MASM flags properly +function(setup_mlas_source_for_windows) + + # + # Sources common for all platforms. + # + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/activate_fp16.cpp + ${MLAS_SRC_DIR}/dwconv.cpp + ${MLAS_SRC_DIR}/pooling_fp16.cpp + ) + + #The onnxruntime_target_platform variable was added by Windows AI team in onnxruntime_common.cmake + #Don't use it for other platforms. 
+ if((onnxruntime_target_platform STREQUAL "ARM64") OR (onnxruntime_target_platform STREQUAL "ARM64EC")) + set(PREPROCESS_ARMASM_FLAGS "") + set(ARMASM_FLAGS "") + + if(onnxruntime_target_platform STREQUAL "ARM64") + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp + ${MLAS_SRC_DIR}/fp16_neon_common.cpp + ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp + ) + + set(mlas_platform_preprocess_srcs + ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelDot.asm + ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelDotLd64.asm + ${MLAS_SRC_DIR}/arm64/ConvSymU8KernelDot.asm + ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/ConvSymU8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymS8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymU8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/DepthwiseQConvKernelSize9Neon.asm + ${MLAS_SRC_DIR}/arm64/HalfGemmKernelNeon.asm + ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelUdot.asm + ${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelSdot.asm + ${MLAS_SRC_DIR}/arm64/SgemmKernelNeon.asm + ${MLAS_SRC_DIR}/arm64/SgemvKernelNeon.asm + ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDot.asm + ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDotLd64.asm + ) + else() + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp + ) + + set(mlas_platform_preprocess_srcs + ${MLAS_SRC_DIR}/arm64ec/QgemmU8X8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64ec/SgemmKernelNeon.asm + ) + + string(APPEND PREPROCESS_ARMASM_FLAGS " /arm64EC") + string(APPEND ARMASM_FLAGS " -machine ARM64EC") + endif() + + 
if(CMAKE_BUILD_TYPE STREQUAL "Debug") + string(APPEND ARMASM_FLAGS " -g") + endif() + + # Remove double quotes from flag strings. + separate_arguments(PREPROCESS_ARMASM_FLAGS NATIVE_COMMAND "${PREPROCESS_ARMASM_FLAGS}") + separate_arguments(ARMASM_FLAGS NATIVE_COMMAND "${ARMASM_FLAGS}") + + # Run the C precompiler on each input before the assembler. + foreach(asm_filename ${mlas_platform_preprocess_srcs}) + get_filename_component(asm_filename_base ${asm_filename} NAME_WLE) + set(preprocess_filename ${CMAKE_CURRENT_BINARY_DIR}/${asm_filename_base}.i) + set(obj_filename ${CMAKE_CURRENT_BINARY_DIR}/${asm_filename_base}.obj) + add_custom_command( + OUTPUT ${obj_filename} + COMMAND + cl.exe ${PREPROCESS_ARMASM_FLAGS} /P ${asm_filename} /Fi${preprocess_filename} + COMMAND + armasm64.exe ${ARMASM_FLAGS} ${preprocess_filename} ${obj_filename} + DEPENDS ${asm_filename} + BYPRODUCTS ${preprocess_filename} + ) + target_sources(onnxruntime_mlas PRIVATE ${obj_filename}) + endforeach() + elseif(onnxruntime_target_platform STREQUAL "ARM") + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/arm/sgemmc.cpp + ) + elseif(onnxruntime_target_platform STREQUAL "x64") + + file(GLOB_RECURSE mlas_platform_srcs_avx CONFIGURE_DEPENDS + "${MLAS_SRC_DIR}/intrinsics/avx/*.cpp" + ) + set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "/arch:AVX") + + file(GLOB_RECURSE mlas_platform_srcs_avx2 CONFIGURE_DEPENDS + "${MLAS_SRC_DIR}/intrinsics/avx2/*.cpp" + ) + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "/arch:AVX2") + + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/dgemm.cpp + ${mlas_platform_srcs_avx} + ${mlas_platform_srcs_avx2} + ${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp + ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp + 
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp + ${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAmx.asm + ${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAvx2.asm + ${MLAS_SRC_DIR}/amd64/QgemmU8U8KernelAvx2.asm + ${MLAS_SRC_DIR}/amd64/QgemmU8X8KernelAvx2.asm + ${MLAS_SRC_DIR}/amd64/QgemmU8X8KernelAvx512Core.asm + ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx2.asm + ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx512Core.asm + ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx512Vnni.asm + ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvxVnni.asm + ${MLAS_SRC_DIR}/amd64/ConvSymKernelAvx2.asm + ${MLAS_SRC_DIR}/amd64/ConvSymKernelAvx512Core.asm + ${MLAS_SRC_DIR}/amd64/DgemmKernelSse2.asm + ${MLAS_SRC_DIR}/amd64/DgemmKernelAvx.asm + ${MLAS_SRC_DIR}/amd64/DgemmKernelFma3.asm + ${MLAS_SRC_DIR}/amd64/DgemmKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/SgemmKernelSse2.asm + ${MLAS_SRC_DIR}/amd64/SgemmKernelAvx.asm + ${MLAS_SRC_DIR}/amd64/SgemmKernelM1Avx.asm + ${MLAS_SRC_DIR}/amd64/SgemmKernelFma3.asm + ${MLAS_SRC_DIR}/amd64/SgemmKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/SconvKernelSse2.asm + ${MLAS_SRC_DIR}/amd64/SconvKernelAvx.asm + ${MLAS_SRC_DIR}/amd64/SconvKernelFma3.asm + ${MLAS_SRC_DIR}/amd64/SconvKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/SpoolKernelSse2.asm + ${MLAS_SRC_DIR}/amd64/SpoolKernelAvx.asm + ${MLAS_SRC_DIR}/amd64/SpoolKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/sgemma.asm + ${MLAS_SRC_DIR}/amd64/cvtfp16a.asm + ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx.asm + ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/TransKernelFma3.asm + ${MLAS_SRC_DIR}/amd64/TransKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/LogisticKernelFma3.asm + ${MLAS_SRC_DIR}/amd64/TanhKernelFma3.asm + ${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm + ) + if(MSVC_VERSION GREATER_EQUAL 1933) + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/amd64/cvtfp16Avx.asm + ) + endif() + + if (NOT onnxruntime_ORT_MINIMAL_BUILD) + target_sources(onnxruntime_mlas PRIVATE + 
${MLAS_SRC_DIR}/q4gemm_avx512.cpp + ) + endif() + else() + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp + ${MLAS_SRC_DIR}/i386/SgemmKernelSse2.asm + ${MLAS_SRC_DIR}/i386/SgemmKernelAvx.asm + ) + endif() +endfunction() + +if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + if (onnxruntime_ENABLE_WEBASSEMBLY_SIMD) + file(GLOB_RECURSE mlas_platform_srcs + "${MLAS_SRC_DIR}/wasm_simd/*.cpp" + ) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/qgemm_kernel_wasmsimd.cpp + ) + else() + file(GLOB_RECURSE mlas_platform_srcs + "${MLAS_SRC_DIR}/scalar/*.cpp" + ) + endif() + target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) +elseif(MSVC) + setup_mlas_source_for_windows() +else() + + if(APPLE) + get_target_property(ONNXRUNTIME_MLAS_OSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) + + if(NOT ONNXRUNTIME_MLAS_OSX_ARCH) + set(ONNXRUNTIME_MLAS_OSX_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) + endif() + foreach(OSX_ARCH ${ONNXRUNTIME_MLAS_OSX_ARCH}) + if (OSX_ARCH STREQUAL "arm64") + set(ARM64 TRUE) + elseif (OSX_ARCH STREQUAL "arm64e") + set(ARM64 TRUE) + elseif (OSX_ARCH STREQUAL "arm") + set(ARM TRUE) + elseif (OSX_ARCH STREQUAL "x86_64") + set(X86_64 TRUE) + elseif (OSX_ARCH STREQUAL "i386") + set(X86 TRUE) + endif() + endforeach() + elseif(ANDROID) + if (CMAKE_ANDROID_ARCH_ABI STREQUAL "armeabi-v7a") + set(ARM TRUE) + elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "arm64-v8a") + set(ARM64 TRUE) + elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86_64") + set(X86_64 TRUE) + elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86") + set(X86 TRUE) + endif() + else() + #Linux/FreeBSD/PowerPC/... 
+ #The value of CMAKE_SYSTEM_PROCESSOR should be from `uname -m` + #Example values: + #arm64v8/ubuntu -> aarch64 + #arm32v6/alpine -> armv7l + #arm32v7/centos -> armv7l + #ppc64le/debian -> ppc64le + #s390x/ubuntu -> s390x + #ppc64le/busybox -> ppc64le + #arm64v8/ubuntu -> aarch64 + #Android: armv7-a aarch64 i686 x86_64 + #chasun: I don't think anyone uses 'arm64' + if(CMAKE_SYSTEM_PROCESSOR MATCHES "^arm64.*") + set(ARM64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^arm.*") + set(ARM TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*") + set(ARM64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(powerpc.*|ppc.*)") + set(POWER TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") + set(X86 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") + set(X86_64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^loongarch64.*") + set(LOONGARCH64 TRUE) + endif() + endif() + + if(APPLE) + get_target_property(ONNXRUNTIME_MLAS_MACOSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) + endif() + list(LENGTH ONNXRUNTIME_MLAS_MACOSX_ARCH ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH) + if(ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH GREATER 1) + set(ONNXRUNTIME_MLAS_MULTI_ARCH TRUE) + endif() + #If ONNXRUNTIME_MLAS_MULTI_ARCH is true, we need to go through every if branch below + #and split MLAS to multiple static libraries. + #Otherwise, it works like if(...) elseif(...) elseif(...) 
endif() + set(MLAS_SOURCE_IS_NOT_SET 1) + if(ARM) + enable_language(ASM) + + set(CMAKE_ASM_FLAGS "${CMAKE_ASM_FLAGS} -mfpu=neon") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon") + + set(mlas_platform_srcs + ${MLAS_SRC_DIR}/aarch32/QgemmU8X8KernelNeon.S + ${MLAS_SRC_DIR}/arm/sgemmc.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp + ) + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() + if(ARM64 AND MLAS_SOURCE_IS_NOT_SET ) + enable_language(ASM) + set(mlas_platform_srcs + ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelDot.S + ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelDotLd64.S + ${MLAS_SRC_DIR}/aarch64/ConvSymU8KernelDot.S + ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/ConvSymU8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymS8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymU8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvKernelSize9Neon.S + ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUdot.S + ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSdot.S + ${MLAS_SRC_DIR}/aarch64/SgemmKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/SgemvKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdot.S + ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdotLd64.S + ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp + ) + set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp + PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") + if (NOT APPLE) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S + 
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S + ${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S + ${MLAS_SRC_DIR}/activate_fp16.cpp + ${MLAS_SRC_DIR}/dwconv.cpp + ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/pooling_fp16.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp + ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/fp16_neon_common.cpp + ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp + ) + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + endif() + + if(ONNXRUNTIME_MLAS_MULTI_ARCH) + add_library(onnxruntime_mlas_arm64 STATIC ${mlas_platform_srcs}) + list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_arm64) + set_target_properties(onnxruntime_mlas_arm64 PROPERTIES OSX_ARCHITECTURES "arm64") + set(mlas_platform_srcs ) + else() + set(MLAS_SOURCE_IS_NOT_SET 0) + 
endif() + endif() + if(POWER AND MLAS_SOURCE_IS_NOT_SET) + set(mlas_platform_srcs + ${MLAS_SRC_DIR}/power/SgemmKernelPower.cpp + ${MLAS_SRC_DIR}/dgemm.cpp + ${MLAS_SRC_DIR}/power/DgemmKernelPower.cpp + ${MLAS_SRC_DIR}/power/QuantizePower.cpp + ) + set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPower.cpp PROPERTIES COMPILE_FLAGS "-DSINGLE") + + check_cxx_compiler_flag("-mcpu=power9" HAS_POWER9) + if (HAS_POWER9) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/power/QuantizePowerVSX.cpp + ) + set_source_files_properties(${MLAS_SRC_DIR}/power/QuantizePowerVSX.cpp PROPERTIES COMPILE_FLAGS "-mcpu=power9") + endif() + + check_cxx_compiler_flag("-mcpu=power10" HAS_POWER10) + if(HAS_POWER10) + set(CMAKE_REQUIRED_FLAGS "-mcpu=power10") + check_cxx_source_compiles(" + #include + int main() { + __vector_quad acc0; + __builtin_mma_xxsetaccz (&acc0); + return 0; + }" + COMPILES_P10 + ) + if(COMPILES_P10) + check_cxx_source_compiles(" + #ifdef _AIX + #define POWER_10 0x40000 + #define POWER_10_ANDUP (POWER_10) + #include + #define __power_10_andup() (_system_configuration.implementation & POWER_10_ANDUP) + int main() { + bool HasP10 = (__power_10_andup() && __power_mma_version() == MMA_V31); + return 0; + } + #else + #include + int main() { + unsigned long hwcap2 = getauxval(AT_HWCAP2); + bool HasP10 = ((hwcap2 & PPC_FEATURE2_MMA) && (hwcap2 & PPC_FEATURE2_ARCH_3_1)); + return 0; + } + #endif" + HAS_P10_RUNTIME + ) + if (HAS_P10_RUNTIME) + set_source_files_properties(${MLAS_SRC_DIR}/platform.cpp PROPERTIES COMPILE_FLAGS "-DPOWER10") + set_source_files_properties(${MLAS_SRC_DIR}/qgemm.cpp PROPERTIES COMPILE_FLAGS "-DPOWER10") + endif() + set(mlas_platform_srcs_power10 + ${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp + ${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp + ${MLAS_SRC_DIR}/power/qgemm_kernel_power10.cpp + ) + set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10 -DSINGLE") + 
set_source_files_properties(${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10") + set_source_files_properties(${MLAS_SRC_DIR}/power/qgemm_kernel_power10.cpp PROPERTIES COMPILE_FLAGS "-O3 -mcpu=power10") + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${mlas_platform_srcs_power10} + ) + endif() + endif() + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() + if(X86 AND MLAS_SOURCE_IS_NOT_SET) + enable_language(ASM) + + set(mlas_platform_srcs_sse2 + ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp + ${MLAS_SRC_DIR}/x86/SgemmKernelSse2.S + ) + set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2") + + set(mlas_platform_srcs_avx + ${MLAS_SRC_DIR}/x86/SgemmKernelAvx.S + ) + set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx") + + set(mlas_platform_srcs + ${mlas_platform_srcs_sse2} + ${mlas_platform_srcs_avx} + ) + + # In r23, NDK remove __x86.get_pc_thunk.* from libatomic. Add our own + # implementation to avoid external dependency. + if(ANDROID) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/x86/x86.get_pc_thunk.S + ) + endif() + + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() + if(X86_64 AND MLAS_SOURCE_IS_NOT_SET) + enable_language(ASM) + + # Forward the flags for the minimum target platform version from the C + # compiler to the assembler. This works around CMakeASMCompiler.cmake.in + # not including the logic to set this flag for the assembler. + set(CMAKE_ASM${ASM_DIALECT}_OSX_DEPLOYMENT_TARGET_FLAG "${CMAKE_C_OSX_DEPLOYMENT_TARGET_FLAG}") + + # The LLVM assembler does not support the .arch directive to enable instruction + # set extensions and also doesn't support AVX-512F instructions without + # turning on support via command-line option. Group the sources by the + # instruction set extension and explicitly set the compiler flag as appropriate. 
+ + set(mlas_platform_srcs_sse2 + ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp + ${MLAS_SRC_DIR}/x86_64/DgemmKernelSse2.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelSse2.S + ${MLAS_SRC_DIR}/x86_64/SgemmTransposePackB16x4Sse2.S + ${MLAS_SRC_DIR}/x86_64/SconvKernelSse2.S + ${MLAS_SRC_DIR}/x86_64/SpoolKernelSse2.S + ) + if(NOT APPLE) + set(mlas_platform_srcs_sse2 + ${mlas_platform_srcs_sse2} + ${MLAS_SRC_DIR}/x86_64/cvtfp16a.S + ) + endif() + set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2") + + set(mlas_platform_srcs_avx + ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelM1Avx.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelM1TransposeBAvx.S + ${MLAS_SRC_DIR}/x86_64/SgemmTransposePackB16x4Avx.S + ${MLAS_SRC_DIR}/x86_64/SconvKernelAvx.S + ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx.S + ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx.S + ${MLAS_SRC_DIR}/intrinsics/avx/min_max_elements.cpp + ) + set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx") + + set(mlas_platform_srcs_avx2 + ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAvx2.S + ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx2.S + ${MLAS_SRC_DIR}/x86_64/QgemmU8U8KernelAvx2.S + ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvxVnni.S + ${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx2.S + ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx2.S + ${MLAS_SRC_DIR}/x86_64/DgemmKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/SconvKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/TransKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/LogisticKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/TanhKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/ErfKernelFma3.S + ${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp + ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp + ) + if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE)) + set(mlas_platform_srcs_avx2 + ${mlas_platform_srcs_avx2} + ${MLAS_SRC_DIR}/x86_64/cvtfp16Avx.S + ) + 
endif() +message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") +message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}") + +if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "11") + message(STATUS "Using -mavx2 -mfma -mavxvnni flags") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c -mavxvnni") +else() + message(STATUS "Using -mavx2 -mfma flags") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c") +endif() + set(mlas_platform_srcs_avx512f + ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx512F.S + ${MLAS_SRC_DIR}/x86_64/SconvKernelAvx512F.S + ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx512F.S + ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx512F.S + ${MLAS_SRC_DIR}/x86_64/TransKernelAvx512F.S + ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp + ) + set_source_files_properties(${mlas_platform_srcs_avx512f} PROPERTIES COMPILE_FLAGS "-mavx512f") + + set(mlas_platform_srcs_avx512core + ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx512Core.S + ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx512Vnni.S + ${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx512Core.S + ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx512Core.S + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp + ) + set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl") + + set(mlas_platform_srcs_avx512vnni + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp + ) + set_source_files_properties(${mlas_platform_srcs_avx512vnni} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f") + + set(mlas_platform_srcs + ${MLAS_SRC_DIR}/activate_fp16.cpp + ${MLAS_SRC_DIR}/dwconv.cpp + ${MLAS_SRC_DIR}/dgemm.cpp + ${MLAS_SRC_DIR}/pooling_fp16.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp + ${mlas_platform_srcs_sse2} + 
${mlas_platform_srcs_avx} + ${mlas_platform_srcs_avx2} + ${mlas_platform_srcs_avx512f} + ${mlas_platform_srcs_avx512core} + ${mlas_platform_srcs_avx512vnni} + ) + + if (NOT onnxruntime_ORT_MINIMAL_BUILD) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/q4gemm_avx512.cpp + ) + set_source_files_properties(${MLAS_SRC_DIR}/q4gemm_avx512.cpp PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f") + endif() + if(NOT APPLE) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmxCommon.S + ${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp + ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S + ) + set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") + set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") + endif() + + if(ONNXRUNTIME_MLAS_MULTI_ARCH) + add_library(onnxruntime_mlas_x86_64 STATIC ${mlas_platform_srcs}) + set_target_properties(onnxruntime_mlas_x86_64 PROPERTIES OSX_ARCHITECTURES "x86_64") + list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_x86_64) + set(mlas_platform_srcs ) + else() + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() + if(LOONGARCH64 AND MLAS_SOURCE_IS_NOT_SET) + set(mlas_platform_srcs + ${MLAS_SRC_DIR}/qgemm_kernel_lsx.cpp + ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLsx.S + ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLsx.S + ${MLAS_SRC_DIR}/loongarch64/SconvKernelLsx.S + ${MLAS_SRC_DIR}/loongarch64/SconvKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLSX.S + ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4LSX.S + ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4Lasx.S + ${MLAS_SRC_DIR}/loongarch64/SoftmaxKernelLasx.S + ) + set(CMAKE_CXX_FLAGS 
"${CMAKE_CXX_FLAGS} -mlsx -mlasx") + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH AND MLAS_SOURCE_IS_NOT_SET) + file(GLOB_RECURSE mlas_platform_srcs + "${MLAS_SRC_DIR}/scalar/*.cpp") + elseif (onnxruntime_FORCE_GENERIC_ALGORITHMS) + file(GLOB_RECURSE mlas_platform_srcs_generic + "${MLAS_SRC_DIR}/scalar/*.cpp") + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${mlas_platform_srcs_generic} + ) + endif() + target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) +endif() + +foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) + target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_INCLUDE_DIR} ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) + target_link_libraries(${mlas_target} Microsoft.GSL::GSL) + + set_target_properties(${mlas_target} PROPERTIES FOLDER "ONNXRuntime") +endforeach() + +if (WIN32) + target_compile_options(onnxruntime_mlas PRIVATE "$<$:/wd6385>" "$<$:/wd4127>") + if (onnxruntime_ENABLE_STATIC_ANALYSIS) + target_compile_options(onnxruntime_mlas PRIVATE "$<$:/analyze:stacksize 131072>") + endif() +endif() + +if (PLATFORM_NAME STREQUAL "macabi") + # Needed for maccatalyst C compilation + # i.e. 
the flags below add "--target=x86_64-apple-ios14.0-macabi -ffunction-sections -fdata-sections" + target_compile_options(onnxruntime_mlas PRIVATE ${CMAKE_C_FLAGS}) +endif() + +if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS onnxruntime_mlas + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) +endif() + +# set up source group for MLAS source files +block() + set(source_group_srcs) + foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) + get_target_property(mlas_target_srcs ${mlas_target} SOURCES) + foreach(mlas_target_src ${mlas_target_srcs}) + cmake_path(IS_PREFIX MLAS_ROOT ${mlas_target_src} in_mlas_root) + if(in_mlas_root) + list(APPEND source_group_srcs ${mlas_target_src}) + endif() + endforeach() + endforeach() +endblock() + + + + # + # Command line tool for quantization and de-quantization of 2-D fp32 tensors + # based on block-wise quantization of int4 + # + + add_executable(onnxruntime_mlas_q4dq + ${MLAS_SRC_DIR}/q4_dq_cli.cpp + ) + target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) + set_target_properties(onnxruntime_mlas_q4dq PROPERTIES FOLDER "ONNXRuntimeTest") + + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_MLAS_LIBS}) + if(NOT MLAS_NO_ONNXRUNTIME) + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE onnxruntime_common) + endif() + if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE cpuinfo) + endif() + if (CMAKE_SYSTEM_NAME STREQUAL "Android") + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${android_shared_libs}) + endif() + + if(WIN32) + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE debug Dbghelp Advapi32) + endif() + if (onnxruntime_LINK_LIBATOMIC) + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE atomic) + endif() + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE 
Threads::Threads) + + if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) + set_target_properties(onnxruntime_mlas_q4dq PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1") + else() + set_target_properties(onnxruntime_mlas_q4dq PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1") + endif() + endif() + diff --git a/src/lib/fp16_common.h b/src/lib/fp16_common.h index 30b66cd..f4c4990 100644 --- a/src/lib/fp16_common.h +++ b/src/lib/fp16_common.h @@ -64,6 +64,15 @@ MLAS_FORCEINLINE MLAS_FLOAT16X4 MlasLoadFloat16x4(const _mlas_fp16_* Buffer) { return vreinterpret_f16_u16(vld1_u16(Buffer)); } +template +MLAS_FORCEINLINE +MLAS_FLOAT16X4 +MlasLoadLaneFloat16x4(const _mlas_fp16_* Buffer, MLAS_FLOAT16X4 vec) { + return vreinterpret_f16_u16( + vld1_lane_u16(Buffer, vreinterpret_u16_f16(vec), lane) + ); +} + MLAS_FORCEINLINE MLAS_FLOAT16X4 MlasLoadPartialFloat16x4(const _mlas_fp16_* Buffer, size_t len) @@ -95,6 +104,14 @@ MlasStoreFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector) vst1_u16(Buffer, vreinterpret_u16_f16(Vector)); } +template +MLAS_FORCEINLINE +void +MlasStoreLaneFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector) +{ + vst1_lane_u16(Buffer, vreinterpret_u16_f16(Vector), lane); +} + MLAS_FORCEINLINE void MlasStorePartialFloat16x4(_mlas_fp16_* Buffer, MLAS_FLOAT16X4 Vector, size_t len) diff --git a/src/lib/hqnbitgemm_kernel_neon_fp16.cpp b/src/lib/hqnbitgemm_kernel_neon_fp16.cpp new file mode 100644 index 0000000..69e37d2 --- /dev/null +++ b/src/lib/hqnbitgemm_kernel_neon_fp16.cpp @@ -0,0 +1,898 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + hqnbitgemm_kernel_neon_fp16.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for ARM NEON specific to + MLAS_QNBIT_GEMM_COMPUTE_TYPE HQNBIT_CompFp16. 
+ +--*/ + +#include + +#include +#include +#include + +#include "fp16_common.h" +#include "qnbitgemm.h" +#include "qnbitgemm_kernel_neon.h" + +namespace sqnbitgemm_neon +{ +MLAS_FORCEINLINE void +Transpose8x8(uint8x8_t& v0, uint8x8_t& v1, uint8x8_t& v2, uint8x8_t& v3, + uint8x8_t& v4, uint8x8_t& v5, uint8x8_t& v6, uint8x8_t& v7) +{ + // v0: | B00 B10 | B20 B30 | B40 B50 | B60 B70 | B80 B90 | Ba0 Bb0 | Bc0 Bd0 | Be0 Bf0 | + // v1: | B01 B11 | B21 B31 | B41 B51 | B61 B71 | B81 B91 | Ba1 Bb1 | Bc1 Bd1 | Be1 Bf1 | + // v2: | B02 B12 | B22 B32 | B42 B52 | B62 B72 | B82 B92 | Ba2 Bb2 | Bc2 Bd2 | Be2 Bf2 | + // v3: | B03 B13 | B23 B33 | B43 B53 | B63 B73 | B83 B93 | Ba3 Bb3 | Bc3 Bd3 | Be3 Bf3 | + // v4: | B04 B14 | B24 B34 | B44 B54 | B64 B74 | B84 B94 | Ba4 Bb4 | Bc4 Bd4 | Be4 Bf4 | + // v5: | B05 B15 | B25 B35 | B45 B55 | B65 B75 | B85 B95 | Ba5 Bb5 | Bc5 Bd5 | Be5 Bf5 | + // v6: | B06 B16 | B26 B36 | B46 B56 | B66 B76 | B86 B96 | Ba6 Bb6 | Bc6 Bd6 | Be6 Bf6 | + // v7: | B07 B17 | B27 B37 | B47 B57 | B67 B77 | B87 B97 | Ba7 Bb7 | Bc7 Bd7 | Be7 Bf7 | + + uint8x8x2_t a0 = vtrn_u8(v0, v1); + uint8x8x2_t a1 = vtrn_u8(v2, v3); + uint8x8x2_t a2 = vtrn_u8(v4, v5); + uint8x8x2_t a3 = vtrn_u8(v6, v7); + + // a0[0]: | B00 B10 | B01 B11 | B40 B50 | B41 B51 | B80 B90 | B81 B91 | Bc0 Bd0 | Bc1 Bd1 | + // a0[1]: | B20 B30 | B21 B31 | B60 B70 | B61 B71 | Ba0 Bb0 | Ba1 Bb1 | Be0 Bf0 | Be1 Bf1 | + // a1[0]: | B02 B12 | B03 B13 | B42 B52 | B43 B53 | B82 B92 | B83 B93 | Bc2 Bd2 | Bc3 Bd3 | + // a1[1]: | B22 B32 | B23 B33 | B62 B72 | B63 B73 | Ba2 Bb2 | Ba3 Bb3 | Be2 Bf2 | Be3 Bf3 | + // a2[0]: | B04 B14 | B05 B15 | B44 B54 | B45 B55 | B84 B94 | B85 B95 | Bc4 Bd4 | Bc5 Bd5 | + // a2[1]: | B24 B34 | B25 B35 | B64 B74 | B65 B75 | Ba4 Bb4 | Ba5 Bb5 | Be4 Bf4 | Be5 Bf5 | + // a3[0]: | B06 B16 | B07 B17 | B46 B56 | B47 B57 | B86 B96 | B87 B97 | Bc6 Bd6 | Bc7 Bd7 | + // a3[1]: | B26 B36 | B27 B37 | B66 B76 | B67 B77 | Ba6 Bb6 | Ba7 Bb7 | Be6 Bf6 | Be7 Bf7 | + + uint16x4x2_t b0 = 
vtrn_u16(vreinterpret_u16_u8(a0.val[0]), vreinterpret_u16_u8(a1.val[0])); + uint16x4x2_t b1 = vtrn_u16(vreinterpret_u16_u8(a0.val[1]), vreinterpret_u16_u8(a1.val[1])); + uint16x4x2_t b2 = vtrn_u16(vreinterpret_u16_u8(a2.val[0]), vreinterpret_u16_u8(a3.val[0])); + uint16x4x2_t b3 = vtrn_u16(vreinterpret_u16_u8(a2.val[1]), vreinterpret_u16_u8(a3.val[1])); + + // b0[0]: | B00 B10 | B01 B11 | B02 B12 | B03 B13 | B80 B90 | B81 B91 | B82 B92 | B83 B93 | + // b0[1]: | B40 B50 | B41 B51 | B42 B52 | B43 B53 | Bc0 Bd0 | Bc1 Bd1 | Bc2 Bd2 | Bc3 Bd3 | + // b1[0]: | B20 B30 | B21 B31 | B22 B32 | B23 B33 | Ba0 Bb0 | Ba1 Bb1 | Ba2 Bb2 | Ba3 Bb3 | + // b1[1]: | B60 B70 | B61 B71 | B62 B72 | B63 B73 | Be0 Bf0 | Be1 Bf1 | Be2 Bf2 | Be3 Bf3 | + // b2[0]: | B04 B14 | B05 B15 | B06 B16 | B07 B17 | B84 B94 | B85 B95 | B86 B96 | B87 B97 | + // b2[1]: | B44 B54 | B45 B55 | B46 B56 | B47 B57 | Bc4 Bd4 | Bc5 Bd5 | Bc6 Bd6 | Bc7 Bd7 | + // b3[0]: | B24 B34 | B25 B35 | B26 B36 | B27 B37 | Ba4 Bb4 | Ba5 Bb5 | Ba6 Bb6 | Ba7 Bb7 | + // b3[1]: | B64 B74 | B65 B75 | B66 B76 | B67 B77 | Be4 Bf4 | Be5 Bf5 | Be6 Bf6 | Be7 Bf7 | + + uint32x2x2_t c0 = vtrn_u32(vreinterpret_u32_u16(b0.val[0]), vreinterpret_u32_u16(b2.val[0])); + uint32x2x2_t c1 = vtrn_u32(vreinterpret_u32_u16(b0.val[1]), vreinterpret_u32_u16(b2.val[1])); + uint32x2x2_t c2 = vtrn_u32(vreinterpret_u32_u16(b1.val[0]), vreinterpret_u32_u16(b3.val[0])); + uint32x2x2_t c3 = vtrn_u32(vreinterpret_u32_u16(b1.val[1]), vreinterpret_u32_u16(b3.val[1])); + + // c0[0]: | B00 B10 | B01 B11 | B02 B12 | B03 B13 | B04 B14 | B05 B15 | B06 B16 | B07 B17 | + // c0[1]: | B80 B90 | B81 B91 | B92 B92 | B83 B93 | B84 B94 | B85 B95 | B86 B96 | B87 B97 | + // c1[0]: | B40 B50 | B41 B51 | B42 B52 | B43 B53 | B44 B54 | B45 B55 | B46 B56 | B47 B57 | + // c1[1]: | Bc0 Bd0 | Bc1 Bd1 | Bc2 Bd2 | Bc3 Bd3 | Bc4 Bd4 | Bc5 Bd5 | Bc6 Bd6 | Bc7 Bd7 | + // c2[0]: | B20 B30 | B21 B31 | B22 B32 | B23 B33 | B24 B34 | B25 B35 | B26 B36 | B27 B37 | + // c2[1]: | Ba0 Bb0 | Ba1 Bb1 
| Ba2 Bb2 | Ba3 Bb3 | Ba4 Bb4 | Ba5 Bb5 | Ba6 Bb6 | Ba7 Bb7 | + // c3[0]: | B60 B70 | B61 B71 | B62 B72 | B63 B73 | B64 B74 | B65 B75 | B66 B76 | B67 B77 | + // c3[1]: | Be0 Bf0 | Be1 Bf1 | Be2 Bf2 | Be3 Bf3 | Be4 Bf4 | Be5 Bf5 | Be6 Bf6 | Be7 Bf7 | + + v0 = vreinterpret_u8_u32(c0.val[0]); + v1 = vreinterpret_u8_u32(c2.val[0]); + v2 = vreinterpret_u8_u32(c1.val[0]); + v3 = vreinterpret_u8_u32(c3.val[0]); + v4 = vreinterpret_u8_u32(c0.val[1]); + v5 = vreinterpret_u8_u32(c2.val[1]); + v6 = vreinterpret_u8_u32(c1.val[1]); + v7 = vreinterpret_u8_u32(c3.val[1]); +} + +MLAS_FORCEINLINE void +Transpose4x8(float16x8_t& v0, float16x8_t& v1, float16x8_t& v2, float16x8_t& v3) +{ + // |v00|v01|v02|v03|v04|v05|v06|v07| + // |v10|v11|v12|v13|v14|v15|v16|v17| + // |v20|v21|v22|v23|v24|v25|v26|v27| + // |v30|v31|v32|v33|v34|v35|v36|v37| + // => + // |v00|v10|v20|v30|v04|v14|v24|v34| + // |v01|v11|v21|v31|v05|v15|v25|v35| + // |v02|v12|v22|v32|v06|v16|v26|v36| + // |v03|v13|v23|v33|v07|v17|v27|v37| + float16x8x2_t t01 = vtrnq_f16(v0, v1); + float16x8x2_t t23 = vtrnq_f16(v2, v3); + + v0 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]))); + v1 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]))); + v2 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]))); + v3 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]))); +} + +MLAS_FORCEINLINE void +Transpose4x4(float16x4_t& v0, float16x4_t& v1, float16x4_t& v2, float16x4_t& v3) +{ + float16x4x2_t t01 = vtrn_f16(v0, v1); + float16x4x2_t t23 = vtrn_f16(v2, v3); + + v0 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0]))); + v1 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1]))); + v2 = 
vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0]))); + v3 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1]))); +} + +void +HQ4BitGemmPackQuantBData_CompFp16( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + MLAS_UNREFERENCED_PARAMETER(ComputeType); + constexpr size_t nbits = 4; + constexpr size_t k_blk_dim = 16; + constexpr size_t n_blk_dim = 8; + assert(BlkLen > 0 && BlkLen % k_blk_dim == 0); + + const size_t k_blk_num = MlasDivRoundup(K, k_blk_dim); + const size_t n_blk_num = MlasDivRoundup(N, n_blk_dim); + constexpr size_t k_blk_bytes = MlasQNBitBlkDataSizeInBytes(nbits, k_blk_dim); + const size_t iterations = k_blk_num * n_blk_num; // one iteration per block + const size_t ld = MlasDivRoundup(K, BlkLen) * MlasQNBitBlkDataSizeInBytes(nbits, BlkLen); + + // + // For blocks 16_K * 8_N, transpose bytes in 8x8 blocks like this: + // src B_k_n: + // | B00 B10 | B20 B30 | B40 B50 | B60 B70 | B80 B90 | Ba0 Bb0 | Bc0 Bd0 | Be0 Bf0 | + // | B01 B11 | B21 B31 | B41 B51 | B61 B71 | B81 B91 | Ba1 Bb1 | Bc1 Bd1 | Be1 Bf1 | + // | B02 B12 | B22 B32 | B42 B52 | B62 B72 | B82 B92 | Ba2 Bb2 | Bc2 Bd2 | Be2 Bf2 | + // | B03 B13 | B23 B33 | B43 B53 | B63 B73 | B83 B93 | Ba3 Bb3 | Bc3 Bd3 | Be3 Bf3 | + // | B04 B14 | B24 B34 | B44 B54 | B64 B74 | B84 B94 | Ba4 Bb4 | Bc4 Bd4 | Be4 Bf4 | + // | B05 B15 | B25 B35 | B45 B55 | B65 B75 | B85 B95 | Ba5 Bb5 | Bc5 Bd5 | Be5 Bf5 | + // | B06 B16 | B26 B36 | B46 B56 | B66 B76 | B86 B96 | Ba6 Bb6 | Bc6 Bd6 | Be6 Bf6 | + // | B07 B17 | B27 B37 | B47 B57 | B67 B77 | B87 B97 | Ba7 Bb7 | Bc7 Bd7 | Be7 Bf7 | + // => dst: + // | B00 B10 | B01 B11 | B02 B12 | B03 B13 | B04 B14 | B05 B15 | B06 B16 | B07 B17 | + // | B20 B30 | B21 B31 | B22 B32 | B23 B33 | B24 B34 | B25 B35 | B26 B36 | B27 B37 | + // | B40 B50 | 
B41 B51 | B42 B52 | B43 B53 | B44 B54 | B45 B55 | B46 B56 | B47 B57 | + // | B60 B70 | B61 B71 | B62 B72 | B63 B73 | B64 B74 | B65 B75 | B66 B76 | B67 B77 | + // | B80 B90 | B81 B91 | B92 B92 | B83 B93 | B84 B94 | B85 B95 | B86 B96 | B87 B97 | + // | Ba0 Bb0 | Ba1 Bb1 | Ba2 Bb2 | Ba3 Bb3 | Ba4 Bb4 | Ba5 Bb5 | Ba6 Bb6 | Ba7 Bb7 | + // | Bc0 Bd0 | Bc1 Bd1 | Bc2 Bd2 | Bc3 Bd3 | Bc4 Bd4 | Bc5 Bd5 | Bc6 Bd6 | Bc7 Bd7 | + // | Be0 Bf0 | Be1 Bf1 | Be2 Bf2 | Be3 Bf3 | Be4 Bf4 | Be5 Bf5 | Be6 Bf6 | Be7 Bf7 | + // + + // + // For blocks < 8_N: + // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | + // => + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + // + + MlasTrySimpleParallel( + ThreadPool, iterations, + [&](ptrdiff_t tid) { + const size_t n_blk = tid / k_blk_num; + const size_t k_blk = tid % k_blk_num; + size_t n = n_blk * n_blk_dim; + const size_t src_offset = n * ld + k_blk * k_blk_bytes; + + if (n + n_blk_dim <= N) { + const size_t dst_offset = n * ld + k_blk * k_blk_bytes * n_blk_dim; + const uint8_t* src = reinterpret_cast(QuantBDataBegin) + src_offset; + uint8_t* dst = reinterpret_cast(PackedQuantBDataBegin) + dst_offset; + + uint8x8_t v0 = vld1_u8(src); + uint8x8_t v1 = vld1_u8(src + ld); + uint8x8_t v2 = vld1_u8(src + 2*ld); + uint8x8_t v3 = vld1_u8(src + 3*ld); + uint8x8_t v4 = vld1_u8(src + 4*ld); + uint8x8_t v5 = vld1_u8(src + 5*ld); + uint8x8_t v6 = vld1_u8(src + 6*ld); + uint8x8_t v7 = vld1_u8(src + 7*ld); + + Transpose8x8(v0, v1, v2, v3, v4, v5, v6, v7); + + vst1_u8(dst, v0); + vst1_u8(dst + 8, v1); + vst1_u8(dst + 16, v2); + vst1_u8(dst + 24, v3); + vst1_u8(dst + 32, v4); + vst1_u8(dst + 40, v5); + vst1_u8(dst + 48, v6); + vst1_u8(dst + 56, v7); + } else { + const uint8_t* src = reinterpret_cast(QuantBDataBegin) + src_offset; + uint8_t* dst = reinterpret_cast(PackedQuantBDataBegin) + src_offset; + + for (; n < N; ++n, src += ld, dst += ld) { + uint8x8_t v0 = vld1_u8(src); + uint8x8_t v_even = vand_u8(v0, 
vdup_n_u8(0x0F)); + uint8x8_t v_odd = vshr_n_u8(v0, 4); + uint8x8x2_t v1 = vzip_u8(v_even, v_odd); + uint8x8_t v2 = vorr_u8(v1.val[0], vshl_n_u8(v1.val[1], 4)); + vst1_u8(dst, v2); + } + } + } + ); +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8 && K == 16), void> +HQ4BitBlkDequantBKernel( + const std::uint8_t* src_ptr, + const float16x8_t& scale, + const float16x8_t& neg_scaled_zp, + _mlas_fp16_* dst_ptr +) { + const uint8x8_t low_mask = vdup_n_u8(0x0F); + + uint8x8_t b01 = vld1_u8(src_ptr); + uint8x8_t b23 = vld1_u8(src_ptr + 8); + uint8x8_t b45 = vld1_u8(src_ptr + 16); + uint8x8_t b67 = vld1_u8(src_ptr + 24); + uint8x8_t b89 = vld1_u8(src_ptr + 32); + uint8x8_t bab = vld1_u8(src_ptr + 40); + uint8x8_t bcd = vld1_u8(src_ptr + 48); + uint8x8_t bef = vld1_u8(src_ptr + 56); + + float16x8_t b0 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b01, low_mask), 0)); + float16x8_t b1 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b01, 4), 0)); + float16x8_t b2 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b23, low_mask), 0)); + float16x8_t b3 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b23, 4), 0)); + float16x8_t b4 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b45, low_mask), 0)); + float16x8_t b5 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b45, 4), 0)); + float16x8_t b6 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b67, low_mask), 0)); + float16x8_t b7 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b67, 4), 0)); + float16x8_t b8 = vcvtq_f16_u16(vshll_n_u8(vand_u8(b89, low_mask), 0)); + float16x8_t b9 = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(b89, 4), 0)); + float16x8_t ba = vcvtq_f16_u16(vshll_n_u8(vand_u8(bab, low_mask), 0)); + float16x8_t bb = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(bab, 4), 0)); + float16x8_t bc = vcvtq_f16_u16(vshll_n_u8(vand_u8(bcd, low_mask), 0)); + float16x8_t bd = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(bcd, 4), 0)); + float16x8_t be = vcvtq_f16_u16(vshll_n_u8(vand_u8(bef, low_mask), 0)); + float16x8_t bf = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(bef, 4), 0)); + + float16x8_t c0 = vfmaq_f16(neg_scaled_zp, b0, scale); + 
float16x8_t c1 = vfmaq_f16(neg_scaled_zp, b1, scale); + float16x8_t c2 = vfmaq_f16(neg_scaled_zp, b2, scale); + float16x8_t c3 = vfmaq_f16(neg_scaled_zp, b3, scale); + float16x8_t c4 = vfmaq_f16(neg_scaled_zp, b4, scale); + float16x8_t c5 = vfmaq_f16(neg_scaled_zp, b5, scale); + float16x8_t c6 = vfmaq_f16(neg_scaled_zp, b6, scale); + float16x8_t c7 = vfmaq_f16(neg_scaled_zp, b7, scale); + float16x8_t c8 = vfmaq_f16(neg_scaled_zp, b8, scale); + float16x8_t c9 = vfmaq_f16(neg_scaled_zp, b9, scale); + float16x8_t ca = vfmaq_f16(neg_scaled_zp, ba, scale); + float16x8_t cb = vfmaq_f16(neg_scaled_zp, bb, scale); + float16x8_t cc = vfmaq_f16(neg_scaled_zp, bc, scale); + float16x8_t cd = vfmaq_f16(neg_scaled_zp, bd, scale); + float16x8_t ce = vfmaq_f16(neg_scaled_zp, be, scale); + float16x8_t cf = vfmaq_f16(neg_scaled_zp, bf, scale); + + MlasStoreFloat16x8(dst_ptr, c0); + MlasStoreFloat16x8(dst_ptr + 8, c1); + MlasStoreFloat16x8(dst_ptr + 16, c2); + MlasStoreFloat16x8(dst_ptr + 24, c3); + MlasStoreFloat16x8(dst_ptr + 32, c4); + MlasStoreFloat16x8(dst_ptr + 40, c5); + MlasStoreFloat16x8(dst_ptr + 48, c6); + MlasStoreFloat16x8(dst_ptr + 56, c7); + MlasStoreFloat16x8(dst_ptr + 64, c8); + MlasStoreFloat16x8(dst_ptr + 72, c9); + MlasStoreFloat16x8(dst_ptr + 80, ca); + MlasStoreFloat16x8(dst_ptr + 88, cb); + MlasStoreFloat16x8(dst_ptr + 96, cc); + MlasStoreFloat16x8(dst_ptr + 104, cd); + MlasStoreFloat16x8(dst_ptr + 112, ce); + MlasStoreFloat16x8(dst_ptr + 120, cf); +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 1 && K == 16), void> +HQ4BitBlkDequantBKernel( + const std::uint8_t* src_ptr, + const float16x8_t& scale, + const float16x8_t& neg_scaled_zp, + _mlas_fp16_* dst_ptr +) { + const uint8x8_t low_mask = vdup_n_u8(0x0F); + + uint8x8_t v0 = vld1_u8(src_ptr); + + float16x8_t f_low = vcvtq_f16_u16(vshll_n_u8(vand_u8(v0, low_mask), 0)); + float16x8_t f_high = vcvtq_f16_u16(vshll_n_u8(vshr_n_u8(v0, 4), 0)); + + float16x8_t c0 = vfmaq_f16(neg_scaled_zp, f_low, 
scale); + float16x8_t c1 = vfmaq_f16(neg_scaled_zp, f_high, scale); + + MlasStoreFloat16x8(dst_ptr, c0); + MlasStoreFloat16x8(dst_ptr + 8, c1); +} + +void +HQ4BitBlkDequantBForHgemm_CompFp16( + size_t BlkLen, + MLAS_FP16* FpData, + const std::byte* QuantBData, + const MLAS_FP16* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t K, + size_t BlockCountK +) { + MLAS_UNREFERENCED_PARAMETER(K); + constexpr size_t nbits = 4; + constexpr size_t kk_blk_dim = 16; + constexpr size_t n_blk_dim = 8; + assert(BlkLen > 0 && BlkLen % kk_blk_dim == 0); + + const size_t kk_blk_num = BlockCountK * BlkLen / kk_blk_dim; + constexpr size_t kk_blk_bytes = MlasQNBitBlkDataSizeInBytes(nbits, kk_blk_dim); + const size_t kk_n_src_bytes = kk_blk_bytes * n_blk_dim; + const size_t kk_n_dst_size = kk_blk_dim * n_blk_dim; + const size_t ld_blk_src = kk_blk_num * kk_n_src_bytes; + const size_t ld_blk_dst = BlkLen * BlockCountK * n_blk_dim; + const size_t ld_blk_scale = BlockCountK * n_blk_dim; + const size_t ld_zp = (BlockCountK + 1) / 2; + const size_t ld_blk_zp = ld_zp * n_blk_dim; + const float16x8_t zp_mid_point_vec = MlasBroadcastFloat16x8(MLAS_FP16(8.0f).val); + const bool has_zp = QuantBZeroPoint != nullptr; + + size_t n = 0; + for (; n + n_blk_dim <= CountN; n += n_blk_dim) { + const auto* scales_ptr = reinterpret_cast(QuantBScale); + const std::uint8_t* zero_points_ptr = reinterpret_cast(QuantBZeroPoint); + const std::uint8_t* src_ptr = reinterpret_cast(QuantBData); + auto* dst_ptr = reinterpret_cast<_mlas_fp16_*>(FpData); + + for (size_t k_blk_i = 0; k_blk_i < BlockCountK; ++k_blk_i) { + // prepare scales and zero_points for the block + _mlas_fp16_ scales[n_blk_dim]; + uint16_t zero_points[n_blk_dim]; + float16x8_t scale_vec; + float16x8_t neg_scaled_zp_vec; + + UnrolledLoop([&](int nn){ + scales[nn] = scales_ptr[nn * BlockCountK]; + }); + scale_vec = MlasLoadFloat16x8(scales); + + if (has_zp) { + UnrolledLoop([&](int nn){ + uint8_t zp = zero_points_ptr[nn * 
ld_zp]; + zp = (k_blk_i & 1) ? (zp >> 4) : (zp & 0x0F); + zero_points[nn] = static_cast(zp); + }); + uint16x8_t zp_u16_vec = vld1q_u16(zero_points); + neg_scaled_zp_vec = vcvtq_f16_u16(zp_u16_vec); + } else { + neg_scaled_zp_vec = zp_mid_point_vec; + } + neg_scaled_zp_vec = vnegq_f16(vmulq_f16(scale_vec, neg_scaled_zp_vec)); + + for (size_t kk = 0; kk < BlkLen; kk += kk_blk_dim) { + HQ4BitBlkDequantBKernel<8, 16>(src_ptr, scale_vec, neg_scaled_zp_vec, dst_ptr); + + src_ptr += kk_n_src_bytes; + dst_ptr += kk_n_dst_size; + } + + ++scales_ptr; + if (has_zp) { + zero_points_ptr += k_blk_i & 1; + } + } + + QuantBData += ld_blk_src; + FpData += ld_blk_dst; + QuantBScale += ld_blk_scale; + QuantBZeroPoint = has_zp ? QuantBZeroPoint + ld_blk_zp : nullptr; + } + + // remaining N + for (; n < CountN; ++n) { + const auto* scales_ptr = reinterpret_cast(QuantBScale); + const std::uint8_t* zero_points_ptr = reinterpret_cast(QuantBZeroPoint); + for (size_t k_blk_i = 0; k_blk_i < BlockCountK; ++k_blk_i) { + const auto scale = scales_ptr[0]; + float16x8_t scale_vec = MlasBroadcastFloat16x8(scale); + float16x8_t neg_scaled_zp_vec; + + if (has_zp) { + uint8_t zero_point = static_cast(zero_points_ptr[0]); + zero_point = (k_blk_i & 1) ? 
(zero_point >> 4) : (zero_point & 0x0F); + uint16x8_t zp_u16_vec = vdupq_n_u16(static_cast(zero_point)); + neg_scaled_zp_vec = vcvtq_f16_u16(zp_u16_vec); + } else { + neg_scaled_zp_vec = zp_mid_point_vec; + } + neg_scaled_zp_vec = vnegq_f16(vmulq_f16(scale_vec, neg_scaled_zp_vec)); + + for (size_t kk = 0; kk < BlkLen; kk += kk_blk_dim) { + HQ4BitBlkDequantBKernel<1, 16>( + reinterpret_cast(QuantBData), scale_vec, neg_scaled_zp_vec, + reinterpret_cast<_mlas_fp16_*>(FpData) + ); + + QuantBData += kk_blk_bytes; + FpData += kk_blk_dim; + } + + ++scales_ptr; + if (has_zp) { + zero_points_ptr += k_blk_i & 1; + } + } + + QuantBScale += BlockCountK; + if (has_zp) { + QuantBZeroPoint += ld_zp; + } + } +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8), float16x8_t> +PrepareAccumulator(const _mlas_fp16_* Bias) +{ + if (Bias) { + return MlasLoadFloat16x8(Bias); + } else { + return MlasZeroFloat16x8(); + } +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 4), float16x4_t> +PrepareAccumulator(const _mlas_fp16_* Bias) +{ + if (Bias) { + return MlasLoadFloat16x4(Bias); + } else { + return MlasZeroFloat16x4(); + } +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<((N == 2 || N == 1)), float16x4_t> +PrepareAccumulator(const _mlas_fp16_* Bias) +{ + float16x4_t v = MlasZeroFloat16x4(); + + if (Bias) { + v = MlasLoadLaneFloat16x4<0>(Bias, v); + if constexpr (N == 2) { + v = MlasLoadLaneFloat16x4<1>(Bias + 1, v); + } + return v; + } else { + return v; + } +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8 && M == 1 && K == 8), float16x8_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x8_t accumulator +) { + MLAS_UNREFERENCED_PARAMETER(ldb); + float16x8_t a0 = MlasLoadFloat16x8(A); + float16x8_t b0 = MlasLoadFloat16x8(B); + float16x8_t b1 = MlasLoadFloat16x8(B + 8); + float16x8_t b2 = MlasLoadFloat16x8(B + 16); + float16x8_t b3 = MlasLoadFloat16x8(B + 24); + float16x8_t 
b4 = MlasLoadFloat16x8(B + 32); + float16x8_t b5 = MlasLoadFloat16x8(B + 40); + float16x8_t b6 = MlasLoadFloat16x8(B + 48); + float16x8_t b7 = MlasLoadFloat16x8(B + 56); + + // This version uses less instructions, but introduces dependency path between instructions. + // Must pair it with loop unrolling to alleviate dependency path penalty. + float16x8_t c0 = vfmaq_laneq_f16(accumulator, b0, a0, 0); + c0 = vfmaq_laneq_f16(c0, b1, a0, 1); + c0 = vfmaq_laneq_f16(c0, b2, a0, 2); + c0 = vfmaq_laneq_f16(c0, b3, a0, 3); + c0 = vfmaq_laneq_f16(c0, b4, a0, 4); + c0 = vfmaq_laneq_f16(c0, b5, a0, 5); + c0 = vfmaq_laneq_f16(c0, b6, a0, 6); + c0 = vfmaq_laneq_f16(c0, b7, a0, 7); + + return c0; +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8 && M == 1 && K == 4), float16x8_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x8_t accumulator +) { + MLAS_UNREFERENCED_PARAMETER(ldb); + float16x4_t a0 = MlasLoadFloat16x4(A); + float16x8_t b0 = MlasLoadFloat16x8(B); + float16x8_t b1 = MlasLoadFloat16x8(B + 8); + float16x8_t b2 = MlasLoadFloat16x8(B + 16); + float16x8_t b3 = MlasLoadFloat16x8(B + 24); + + float16x8_t c0 = vfmaq_lane_f16(accumulator, b0, a0, 0); + c0 = vfmaq_lane_f16(c0, b1, a0, 1); + c0 = vfmaq_lane_f16(c0, b2, a0, 2); + c0 = vfmaq_lane_f16(c0, b3, a0, 3); + + return c0; +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<(N == 8 && M == 1 && (K == 2 || K == 1)), float16x8_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x8_t accumulator +) { + MLAS_UNREFERENCED_PARAMETER(ldb); + float16x4_t a0 = MlasZeroFloat16x4(); + a0 = MlasLoadLaneFloat16x4<0>(A, a0); + if constexpr (K == 2) a0 = MlasLoadLaneFloat16x4<1>(A + 1, a0); + float16x8_t b0 = MlasLoadFloat16x8(B), b1; + if constexpr (K == 2) b1 = MlasLoadFloat16x8(B + 8); + + float16x8_t c0 = vfmaq_lane_f16(accumulator, b0, a0, 0), c01; + if constexpr (K == 2) c01 = vfmaq_lane_f16(c0, 
b1, a0, 1); + + if constexpr (K == 1) + return c0; + else + return c01; +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<((N > 0 && N <= 4) && M == 1 && K == 8), float16x4_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x4_t accumulator +) { + float16x8_t a0 = MlasLoadFloat16x8(A); + + float16x8_t b0, b1, b2, b3; + b0 = MlasLoadFloat16x8(B); + if constexpr (N > 1) b1 = MlasLoadFloat16x8(B + ldb); + if constexpr (N > 2) b2 = MlasLoadFloat16x8(B + ldb * 2); + if constexpr (N > 3) b3 = MlasLoadFloat16x8(B + ldb * 3); + + float16x8_t c00, c01, c02, c03; + c00 = vmulq_f16(b0, a0); + if constexpr (N > 1) + c01 = vmulq_f16(b1, a0); + else + c01 = MlasZeroFloat16x8(); + if constexpr (N > 2) + c02 = vmulq_f16(b2, a0); + else + c02 = MlasZeroFloat16x8(); + if constexpr (N > 3) + c03 = vmulq_f16(b3, a0); + else + c03 = MlasZeroFloat16x8(); + + Transpose4x8(c00, c01, c02, c03); + + float16x8_t c_low_high = vaddq_f16(vaddq_f16(c00, c01), vaddq_f16(c02, c03)); + float16x4_t c_low = vget_low_f16(c_low_high); + float16x4_t c_high = vget_high_f16(c_low_high); + float16x4_t c = vadd_f16(c_low, c_high); + + return vadd_f16(c, accumulator); +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<((N > 0 && N <= 4) && M == 1 && (K == 4)), float16x4_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x4_t accumulator +) { + float16x4_t a0 = MlasLoadFloat16x4(A); + float16x4_t b0, b1, b2, b3; + b0 = MlasLoadFloat16x4(B); + if constexpr (N > 1) b1 = MlasLoadFloat16x4(B + ldb); + if constexpr (N > 2) b2 = MlasLoadFloat16x4(B + ldb * 2); + if constexpr (N > 3) b3 = MlasLoadFloat16x4(B + ldb * 3); + + float16x4_t c00, c01, c02, c03; + c00 = vmul_f16(b0, a0); + if constexpr (N > 1) + c01 = vmul_f16(b1, a0); + else + c01 = MlasZeroFloat16x4(); + if constexpr (N > 2) + c02 = vmul_f16(b2, a0); + else + c02 = MlasZeroFloat16x4(); + if constexpr (N > 3) + c03 = vmul_f16(b3, 
a0); + else + c03 = MlasZeroFloat16x4(); + + Transpose4x4(c00, c01, c02, c03); + + float16x4_t c = vadd_f16(vadd_f16(c00, c01), vadd_f16(c02, c03)); + return vadd_f16(c, accumulator); +} + +template +MLAS_FORCEINLINE +typename std::enable_if_t<((N > 0 && N <= 4) && M == 1 && (K > 0 && K < 4)), float16x4_t> +HQ4BitGemmMicroKernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const size_t ldb, + float16x4_t accumulator +) { + float16x4_t a0 = MlasZeroFloat16x4(); + float16x4_t b0 = MlasZeroFloat16x4(), b1, b2, b3; + if constexpr (N > 1) b1 = MlasZeroFloat16x4(); + if constexpr (N > 2) b2 = MlasZeroFloat16x4(); + if constexpr (N > 3) b3 = MlasZeroFloat16x4(); + + a0 = MlasLoadLaneFloat16x4<0>(A, a0); + b0 = MlasLoadLaneFloat16x4<0>(B, b0); + if constexpr (N > 1) b1 = MlasLoadLaneFloat16x4<0>(B + ldb, b1); + if constexpr (N > 2) b2 = MlasLoadLaneFloat16x4<0>(B + ldb * 2, b2); + if constexpr (N > 3) b3 = MlasLoadLaneFloat16x4<0>(B + ldb * 3, b3); + + if constexpr (K >= 2) { + a0 = MlasLoadLaneFloat16x4<1>(A + 1, a0); + b0 = MlasLoadLaneFloat16x4<1>(B + 1, b0); + if constexpr (N > 1) b1 = MlasLoadLaneFloat16x4<1>(B + 1 + ldb, b1); + if constexpr (N > 2) b2 = MlasLoadLaneFloat16x4<1>(B + 1 + ldb * 2, b2); + if constexpr (N > 3) b3 = MlasLoadLaneFloat16x4<1>(B + 1 + ldb * 3, b3); + } + + if constexpr (K >= 3) { + a0 = MlasLoadLaneFloat16x4<2>(A + 2, a0); + b0 = MlasLoadLaneFloat16x4<2>(B + 2, b0); + if constexpr (N > 1) b1 = MlasLoadLaneFloat16x4<2>(B + 2 + ldb, b1); + if constexpr (N > 2) b2 = MlasLoadLaneFloat16x4<2>(B + 2 + ldb * 2, b2); + if constexpr (N > 3) b3 = MlasLoadLaneFloat16x4<2>(B + 2 + ldb * 3, b3); + } + + float16x4_t c00, c01, c02, c03; + c00 = vmul_f16(b0, a0); + if constexpr (N > 1) + c01 = vmul_f16(b1, a0); + else + c01 = MlasZeroFloat16x4(); + if constexpr (N > 2) + c02 = vmul_f16(b2, a0); + else + c02 = MlasZeroFloat16x4(); + if constexpr (N > 3) + c03 = vmul_f16(b3, a0); + else + c03 = MlasZeroFloat16x4(); + + Transpose4x4(c00, c01, c02, c03); + 
+ float16x4_t c = vadd_f16(vadd_f16(c00, c01), vadd_f16(c02, c03)); + return vadd_f16(c, accumulator); +} + +template +typename std::enable_if_t<((CountN >= 1 && CountN <= 16 && ((CountN - 1) & CountN) == 0) && (CountM == 1 || CountM == 2)), void> +HQ4BitGemmKernel_CompFp16_Kernel( + const _mlas_fp16_* A, + const _mlas_fp16_* B, + const _mlas_fp16_* Bias, + _mlas_fp16_* C, + size_t K, + size_t lda, + size_t ldb, + size_t ldc +) { + using RegisterType = typename std::conditional_t<(CountN < 8), float16x4_t, float16x8_t>; + + RegisterType accu00, accu01, accu10, accu11; + constexpr size_t b_step = CountN >= 8 ? 8 : 1; + constexpr size_t N = CountN == 16 ? 8 : CountN; + + if constexpr (CountM == 2) { + accu00 = accu10 = PrepareAccumulator(Bias); + } else { + accu00 = PrepareAccumulator(Bias); + } + if constexpr (CountN == 16) { + if constexpr (CountM == 2) { + accu01 = accu11 = PrepareAccumulator(Bias ? Bias + 8 : nullptr); + } else { + accu01 = PrepareAccumulator(Bias ? Bias + 8 : nullptr); + } + } + + size_t k = 0; + for (; k + 8 <= K; k += 8, A += 8, B += b_step * 8) { + accu00 = HQ4BitGemmMicroKernel(A, B, ldb, accu00); + if constexpr (CountN == 16) { + accu01 = HQ4BitGemmMicroKernel(A, B + b_step * ldb, ldb, accu01); + } + if constexpr (CountM == 2) { + accu10 = HQ4BitGemmMicroKernel(A + lda, B, ldb, accu10); + if constexpr (CountN == 16) { + accu11 = HQ4BitGemmMicroKernel(A + lda, B + b_step * ldb, ldb, accu11); + } + } + } + + if (K & 4) { + accu00 = HQ4BitGemmMicroKernel(A, B, ldb, accu00); + if constexpr (CountN == 16) { + accu01 = HQ4BitGemmMicroKernel(A, B + b_step * ldb, ldb, accu01); + } + if constexpr (CountM == 2) { + accu10 = HQ4BitGemmMicroKernel(A + lda, B, ldb, accu10); + if constexpr (CountN == 16) { + accu11 = HQ4BitGemmMicroKernel(A + lda, B + b_step * ldb, ldb, accu11); + } + } + k += 4, A += 4, B += b_step * 4; + } + + if (K & 2) { + accu00 = HQ4BitGemmMicroKernel(A, B, ldb, accu00); + if constexpr (CountN == 16) { + accu01 = 
HQ4BitGemmMicroKernel(A, B + b_step * ldb, ldb, accu01); + } + if constexpr (CountM == 2) { + accu10 = HQ4BitGemmMicroKernel(A + lda, B, ldb, accu10); + if constexpr (CountN == 16) { + accu11 = HQ4BitGemmMicroKernel(A + lda, B + b_step * ldb, ldb, accu11); + } + } + k += 2, A += 2, B += b_step * 2; + } + + if (k < K) { + accu00 = HQ4BitGemmMicroKernel(A, B, ldb, accu00); + if constexpr (CountN == 16) { + accu01 = HQ4BitGemmMicroKernel(A, B + b_step * ldb, ldb, accu01); + } + if constexpr (CountM == 2) { + accu10 = HQ4BitGemmMicroKernel(A + lda, B, ldb, accu10); + if constexpr (CountN == 16) { + accu11 = HQ4BitGemmMicroKernel(A + lda, B + b_step * ldb, ldb, accu11); + } + } + } + + if constexpr (CountN >= 8) { + MlasStoreFloat16x8(C, accu00); + if constexpr (CountN == 16) { + MlasStoreFloat16x8(C + 8, accu01); + } + } else if constexpr (CountN == 4) { + MlasStoreFloat16x4(C, accu00); + } else { + MlasStoreLaneFloat16x4<0>(C, accu00); + if constexpr (CountN == 2) { + MlasStoreLaneFloat16x4<1>(C + 1, accu00); + } + } + + if constexpr (CountM == 2) { + if constexpr (CountN >= 8) { + MlasStoreFloat16x8(C + ldc, accu10); + if constexpr (CountN == 16) { + MlasStoreFloat16x8(C + ldc + 8, accu11); + } + } else if constexpr (CountN == 4) { + MlasStoreFloat16x4(C + ldc, accu10); + } else { + MlasStoreLaneFloat16x4<0>(C + ldc, accu10); + if constexpr (CountN == 2) { + MlasStoreLaneFloat16x4<1>(C + ldc + 1, accu10); + } + } + } +} + +void +HQ4BitGemmKernel_CompFp16( + const MLAS_FP16* A, + const MLAS_FP16* B, + const MLAS_FP16* Bias, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t K, + size_t lda, + size_t ldb, + size_t ldc +) { + assert(CountM <= 2); + + // 2M_16N is the balance between loop unrolling and register spill. + // More unroll will trigger register spill. + // Less unroll will increase micro kernel dependency path penalty. + // TODO: dequant 16N as continuous segments. Current version dequants 8N. 
+ const auto* a = reinterpret_cast(A); + const auto* b = reinterpret_cast(B); + const auto* bias = reinterpret_cast(Bias); + auto* c = reinterpret_cast<_mlas_fp16_*>(C); + + for (; CountN >= 16; CountN -= 16) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<16, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<16, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + b += 16 * ldb, c += 16; + if (bias) bias += 16; + } + + if (CountN & 8) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<8, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<8, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + b += 8 * ldb, c += 8; + if (bias) bias += 8; + } + + if (CountN & 4) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<4, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<4, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + b += 4 * ldb, c += 4; + if (bias) bias += 4; + } + + if (CountN & 2) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<2, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<2, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + b += 2 * ldb, c += 2; + if (bias) bias += 2; + } + + if (CountN & 1) { + if (CountM == 2) { + HQ4BitGemmKernel_CompFp16_Kernel<1, 2>(a, b, bias, c, K, lda, ldb, ldc); + } else { + HQ4BitGemmKernel_CompFp16_Kernel<1, 1>(a, b, bias, c, K, lda, ldb, ldc); + } + } +} +} // namespace sqnbitgemm_neon diff --git a/src/lib/mlasi.h b/src/lib/mlasi.h index 13ea8d9..0533a5e 100644 --- a/src/lib/mlasi.h +++ b/src/lib/mlasi.h @@ -358,6 +358,22 @@ size_t bool ZeroMode ); +#ifdef FORCE_GENERIC_ALGORITHMS +typedef +size_t +(MLASCALL MLAS_GEMM_FLOAT_KERNEL_GENERIC)( + const float* A, + const float* B, + float* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + float alpha + ); +#endif + #else #if defined(__aarch64__) && defined(__linux__) @@ -733,6 +749,10 @@ extern "C" { #if 
defined(MLAS_TARGET_AMD64_IX86) MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelSse; MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelAvx; +#ifdef FORCE_GENERIC_ALGORITHMS + MLAS_GEMM_FLOAT_KERNEL_GENERIC MlasSgemmKernelZero; + MLAS_GEMM_FLOAT_KERNEL_GENERIC MlasSgemmKernelAdd; +#endif #if defined(MLAS_TARGET_AMD64) MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelFma3; MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelAvx512F; @@ -1017,17 +1037,17 @@ extern const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchAvx512; // Float/quantized n-bit integer matrix/matrix multiply dispatch structure. // -struct MLAS_SQNBIT_GEMM_DISPATCH; +struct MLAS_QNBIT_GEMM_DISPATCH; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon; +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2; +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni; +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; -extern const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; // // Quantized depthwise convolution kernels. 
@@ -1184,7 +1204,7 @@ struct MLAS_PLATFORM { const MLAS_FPQ4GEMM_DISPATCH* FpQ4GemmDispatch{nullptr}; const MLAS_Q8Q4GEMM_DISPATCH* Q8Q4GemmDispatch{nullptr}; - const MLAS_SQNBIT_GEMM_DISPATCH* SQNBitGemmDispatch{nullptr}; + const MLAS_QNBIT_GEMM_DISPATCH* QNBitGemmDispatch{nullptr}; MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel; MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel; diff --git a/src/lib/platform.cpp b/src/lib/platform.cpp index ed56d82..2aea7a9 100644 --- a/src/lib/platform.cpp +++ b/src/lib/platform.cpp @@ -287,7 +287,11 @@ Return Value: this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel; this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel; #ifndef __APPLE__ +#ifndef FORCE_GENERIC_ALGORITHMS this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelSse; +#else // FORCE_GENERIC_ALGORITHMS + this->CastF16ToF32Kernel = nullptr; +#endif // FORCE_GENERIC_ALGORITHMS #endif // __APPLE__ this->NchwcBlockSize = 8; @@ -309,8 +313,11 @@ Return Value: // // Check if the processor supports SSE 4.1 instructions. // - +#ifndef FORCE_GENERIC_ALGORITHMS if ((Cpuid1[2] & 0x80000) != 0) { +#else // FORCE_GENERIC_ALGORITHMS + if (false) { +#endif // FORCE_GENERIC_ALGORITHMS this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchSse41; } @@ -320,7 +327,11 @@ Return Value: // Check if the processor supports the AVX and OSXSAVE features. // +#ifndef FORCE_GENERIC_ALGORITHMS if ((Cpuid1[2] & 0x18000000) == 0x18000000) { +#else // FORCE_GENERIC_ALGORITHMS + if (false) { +#endif // FORCE_GENERIC_ALGORITHMS // // Check if the operating system supports saving SSE and AVX states. 
@@ -388,7 +399,7 @@ Return Value: this->ConvDepthwiseS8S8Kernel = MlasConvDepthwiseKernelAvx2; this->ConvDepthwiseS8U8Kernel = MlasConvDepthwiseKernelAvx2; this->ComputeSumExpF32Kernel = MlasComputeSumExpF32KernelFma3; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2; this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx2; this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2; @@ -418,7 +429,7 @@ Return Value: this->GemmU8S8Kernel = MlasGemmU8S8KernelAvxVnni; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvxVnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvxVnni; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2vnni; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2vnni; } #if !defined(ORT_MINIMAL_BUILD) @@ -459,7 +470,7 @@ Return Value: this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512Core; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Core; this->FpQ4GemmDispatch = &MlasFpQ4GemmDispatchAvx512; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512; // // Check if the processor supports AVX512VNNI. @@ -472,7 +483,7 @@ Return Value: this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Vnni; this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Vnni; this->Q8Q4GemmDispatch = &MlasQ8Q4GemmDispatchAvx512vnni; - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512vnni; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512vnni; } } } @@ -532,6 +543,7 @@ Return Value: this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchNeon; + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; // // Check if the processor supports ASIMD dot product instructions. 
@@ -561,9 +573,6 @@ Return Value: this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchSdot; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchDot; this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; - - // MlasSQNBitGemmDispatchNeon has a dependency on dot product instructions - this->SQNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; } #if defined(__linux__) diff --git a/src/lib/qgemm.h b/src/lib/qgemm.h index 1ef5b5f..bcd878e 100644 --- a/src/lib/qgemm.h +++ b/src/lib/qgemm.h @@ -867,7 +867,8 @@ MlasGemmQuantGetDispatch( { const MLAS_GEMM_QUANT_DISPATCH* GemmQuantDispatch = &MlasGemmQuantDispatchDefault; -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_LARCH64) +#if !defined(FORCE_GENERIC_ALGORITHMS) +#if defined(MLAS_TARGET_AMD64_IX86) if (AIsSigned) { GemmQuantDispatch = BIsSigned ? GetMlasPlatform().GemmS8S8Dispatch : GetMlasPlatform().GemmS8U8Dispatch; @@ -895,7 +896,13 @@ MlasGemmQuantGetDispatch( if (GetMlasPlatform().GemmU8X8Dispatch == &MlasGemm8X8DispatchPOWER10) { GemmQuantDispatch = GetMlasPlatform().GemmU8X8Dispatch; } +#elif defined(MLAS_TARGET_LARCH64) + if (!AIsSigned) { + GemmQuantDispatch = + BIsSigned ? GetMlasPlatform().GemmU8S8Dispatch : GetMlasPlatform().GemmU8U8Dispatch; + } #endif +#endif // !defined(FORCE_GENERIC_ALGORITHMS) if (nullptr == GemmQuantDispatch) { std::stringstream ss; diff --git a/src/lib/sqnbitgemm.cpp b/src/lib/qnbitgemm.cpp similarity index 62% rename from src/lib/sqnbitgemm.cpp rename to src/lib/qnbitgemm.cpp index b45f3a1..f064a8e 100644 --- a/src/lib/sqnbitgemm.cpp +++ b/src/lib/qnbitgemm.cpp @@ -6,53 +6,57 @@ Licensed under the MIT License. Module Name: - sqnbitgemm.cpp + qnbitgemm.cpp Abstract: This module implements the float/quantized n-bit integer matrix - multiplication hardware agnostic entrypoint, MlasSQNBitGemmBatch, + multiplication hardware agnostic entrypoint, MlasQNBitGemmBatch, as well as some SQNBitGemm-related query functions. 
--*/ -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_q8_block.h" #include -#include namespace { -enum SQNBitGemmVariant { +enum QNBitGemmVariant { SQNBitGemmVariantInvalid = -1, // Valid variants SQNBitGemmVariant_BitWidth4_CompFp32 = 0, SQNBitGemmVariant_BitWidth4_CompInt8, + HQNBitGemmVariant_BitWidth4_CompFp16, + HQNBitGemmVariant_BitWidth4_CompInt8, // End of valid variants - // Keep this element last and ensure that its value is the number of valid SQNBitGemmVariant values. + // Keep this element last and ensure that its value is the number of valid QNBitGemmVariant values. // Its value is used as an array size. SQNBitGemmVariantCount, }; -SQNBitGemmVariant -GetSQNBitGemmVariant( +QNBitGemmVariant +GetQNBitGemmVariant( size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { if (BlkBitWidth == 4 && (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { - if (ComputeType == CompFp32 || - ComputeType == CompUndef) { // treat CompUndef (undefined) as CompFp32 + if (ComputeType == SQNBIT_CompFp32) { return SQNBitGemmVariant_BitWidth4_CompFp32; - } else if (ComputeType == CompInt8) { + } else if (ComputeType == HQNBIT_CompFp16) { + return HQNBitGemmVariant_BitWidth4_CompFp16; + } else if (ComputeType == SQNBIT_CompInt8) { return SQNBitGemmVariant_BitWidth4_CompInt8; + } else if (ComputeType == HQNBIT_CompInt8) { + return HQNBitGemmVariant_BitWidth4_CompInt8; } } @@ -62,23 +66,28 @@ GetSQNBitGemmVariant( } // namespace bool MLASCALL -MlasIsSQNBitGemmAvailable( +MlasIsQNBitGemmAvailable( size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; if (Dispatch == nullptr) { return false; } - const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, 
ComputeType); + const auto Variant = GetQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); switch (Variant) { case SQNBitGemmVariant_BitWidth4_CompFp32: { return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr && - Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr; + Dispatch->SQ4BitBlkDequantBForSgemm_CompFp32 != nullptr; + } + case HQNBitGemmVariant_BitWidth4_CompFp16: { + return Dispatch->HQ4BitGemmPackQuantBData != nullptr && + Dispatch->HQ4BitGemmKernel_CompFp16 != nullptr && + Dispatch->HQ4BitBlkDequantBForHgemm_CompFp16 != nullptr; } case SQNBitGemmVariant_BitWidth4_CompInt8: { // SQ4BitGemmKernel_BlkSum_CompInt8 return @@ -95,80 +104,80 @@ namespace { size_t -SQNBitGemmPerGemmWorkspaceSize( +QNBitGemmPerGemmWorkspaceSize( size_t M, size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; if (Dispatch == nullptr) { return 0; } - if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPerGemmWorkspaceSize != nullptr) { - return Dispatch->SQ4BitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, ComputeType); + if (BlkBitWidth == 4 && Dispatch->Q4BitGemmPerGemmWorkspaceSize != nullptr) { + return Dispatch->Q4BitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, ComputeType); } return 0; } size_t -SQNBitGemmPerGemmWorkspaceAlignment( +QNBitGemmPerGemmWorkspaceAlignment( size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; if (Dispatch == nullptr) { return 1; } - if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPerGemmWorkspaceAlignment != nullptr) { - return Dispatch->SQ4BitGemmPerGemmWorkspaceAlignment(BlkLen, ComputeType); + if (BlkBitWidth == 4 && 
Dispatch->Q4BitGemmPerGemmWorkspaceAlignment != nullptr) { + return Dispatch->Q4BitGemmPerGemmWorkspaceAlignment(BlkLen, ComputeType); } return 1; } size_t -SQNBitGemmPerGemmWorkspaceStride( +QNBitGemmPerGemmWorkspaceStride( size_t M, size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto Size = SQNBitGemmPerGemmWorkspaceSize(M, N, K, BlkBitWidth, BlkLen, ComputeType); - const auto Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); + const auto Size = QNBitGemmPerGemmWorkspaceSize(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const auto Alignment = QNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); return MlasDivRoundup(Size, Alignment) * Alignment; } } // namespace size_t MLASCALL -MlasSQNBitGemmBatchWorkspaceSize( +MlasQNBitGemmBatchWorkspaceSize( size_t M, size_t N, size_t K, size_t BatchN, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const size_t PerGemmWorkspaceStride = QNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); if (PerGemmWorkspaceStride == 0) { return 0; } - const size_t Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); + const size_t Alignment = QNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); const size_t WorkspaceSize = BatchN * PerGemmWorkspaceStride; @@ -176,21 +185,21 @@ MlasSQNBitGemmBatchWorkspaceSize( } size_t MLASCALL -MlasSQNBitGemmPackQuantBDataSize( +MlasQNBitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + const auto* Dispatch = 
GetMlasPlatform().QNBitGemmDispatch; if (Dispatch == nullptr) { return 0; } - if (BlkBitWidth == 4 && Dispatch->SQ4BitGemmPackQuantBDataSize != nullptr) { - return Dispatch->SQ4BitGemmPackQuantBDataSize( + if (BlkBitWidth == 4 && Dispatch->Q4BitGemmPackQuantBDataSize != nullptr) { + return Dispatch->Q4BitGemmPackQuantBDataSize( N, K, BlkLen, ComputeType ); } @@ -214,12 +223,12 @@ struct PerGemmQuantAWorkspace { }; void MLASCALL -MlasSQNBitGemmPackQuantBData( +MlasQNBitGemmPackQuantBData( size_t N, size_t K, size_t BlkBitWidth, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const void* QuantBData, void* PackedQuantBDataAndOrBlkSumWorkspace, const void* QuantBScale, @@ -228,15 +237,15 @@ MlasSQNBitGemmPackQuantBData( MLAS_THREADPOOL* ThreadPool ) { - const auto* Dispatch = GetMlasPlatform().SQNBitGemmDispatch; + const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; if (Dispatch == nullptr) { return; } if (BlkBitWidth == 4) { - if (ComputeType == CompInt8 && Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + if (ComputeType == SQNBIT_CompInt8 && Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum( N, K, @@ -249,6 +258,16 @@ MlasSQNBitGemmPackQuantBData( packed_quant_b, ThreadPool ); + } else if (ComputeType == HQNBIT_CompFp16 && Dispatch->HQ4BitGemmPackQuantBData != nullptr) { + Dispatch->HQ4BitGemmPackQuantBData( + N, + K, + BlkLen, + ComputeType, + static_cast(QuantBData), + static_cast(PackedQuantBDataAndOrBlkSumWorkspace), + ThreadPool + ); } else if (Dispatch->SQ4BitGemmPackQuantBData != nullptr) { // TODO: these assertions are true if called from matmul_nbits kernel but not from mlas 
tests. //assert(QuantBScale == nullptr); @@ -296,22 +315,11 @@ AddBiasForGemm(const float* Bias, float* C, size_t CountM, size_t CountN, size_t } } -typedef void(SQNBitGemmFn)( - size_t BlkLen, - size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, - void* PerGemmWorkspace, - size_t RangeStartM, - size_t RangeCountM, - size_t RangeStartN, - size_t RangeCountN -); - void SQ4BitGemm_CompFp32( const size_t BlkLen, const size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, + const MLAS_QNBIT_GEMM_DATA_PARAMS* const DataParams, void* const PerGemmWorkspace, const size_t RangeStartM, const size_t RangeCountM, @@ -356,7 +364,7 @@ SQ4BitGemm_CompFp32( float* c_blk = C + n; const float* bias = (Bias == nullptr) ? nullptr : Bias + n; - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmM1Kernel_CompFp32( + GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmM1Kernel_CompFp32( BlkLen, a_row, b_col, b_col_scale, b_col_zp, c_blk, CountN, K, k_blks, bias ); @@ -394,7 +402,7 @@ SQ4BitGemm_CompFp32( float* c_blk = C + n; const float* bias = (Bias == nullptr) ? 
nullptr : Bias + n; - GetMlasPlatform().SQNBitGemmDispatch->Q4BitBlkDequantBForSgemm_CompFp32( + GetMlasPlatform().QNBitGemmDispatch->SQ4BitBlkDequantBForSgemm_CompFp32( BlkLen, dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks ); @@ -426,11 +434,84 @@ SQ4BitGemm_CompFp32( } } +void +HQ4BitGemm_CompFp16( + const size_t BlkLen, + const size_t K, + const MLAS_QNBIT_GEMM_DATA_PARAMS* const DataParams, + void* const PerGemmWorkspace, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) +{ + constexpr size_t BlkBitWidth = 4; + MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspace); + + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; + const size_t k_blk_num = MlasDivRoundup(K, BlkLen); + const size_t qldb = k_blk_num * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t ldb = k_blk_num * BlkLen; + const size_t k_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blk_num); + + const MLAS_FP16* A = DataParams->A + RangeStartM * lda; + MLAS_FP16* C = DataParams->C + RangeStartM * ldc + RangeStartN; + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * qldb; + const MLAS_FP16* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blk_num; + const std::byte* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_zp_bytes; + const MLAS_FP16* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias; + + // 32N is the sweet spot of cache utilization. It is machine dependent though. + constexpr size_t StrideM = 2; + constexpr size_t StrideN = 32; + + // TODO(fajin): move allocation up to the op. 
+ size_t bufsize = ldb * StrideN * sizeof(MLAS_FP16); + MlasThreadedBufAlloc(bufsize); + auto* dequant_b = reinterpret_cast(ThreadedBufHolder.get()); + + for (size_t n = 0, countN; n < RangeCountN; n += countN) { + countN = std::min(StrideN, RangeCountN - n); + GetMlasPlatform().QNBitGemmDispatch->HQ4BitBlkDequantBForHgemm_CompFp16( + BlkLen, dequant_b, QuantBData, QuantBScale, QuantBZeroPoint, countN, K, k_blk_num + ); + + const MLAS_FP16* a = A; + MLAS_FP16* c = C; + for (size_t m = 0, countM; m < RangeCountM; m += countM) { + countM = std::min(StrideM, RangeCountM - m); + GetMlasPlatform().QNBitGemmDispatch->HQ4BitGemmKernel_CompFp16( + a, dequant_b, Bias, c, countM, countN, K, lda, ldb, ldc + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM + m, RangeStartN + n, countM, countN, ldc + ); + } + + a += countM * lda; + c += countM * ldc; + } + + QuantBData += countN * qldb; + QuantBScale += countN * k_blk_num; + QuantBZeroPoint = QuantBZeroPoint ? QuantBZeroPoint + countN * k_zp_bytes : nullptr; + Bias = Bias ? Bias + countN : nullptr; + C += countN; + } +} + void SQ4BitGemm_CompInt8( const size_t BlkLen, const size_t K, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* const DataParams, + const MLAS_QNBIT_GEMM_DATA_PARAMS* const DataParams, void* const PerGemmWorkspace, const size_t RangeStartM, const size_t RangeCountM, @@ -501,10 +582,10 @@ SQ4BitGemm_CompInt8( float* c_blk = C + n; const float* bias = (Bias == nullptr) ? 
nullptr : Bias + n; - if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8 != nullptr) { + if (GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_CompInt8 != nullptr) { size_t RowsRemaining = RangeCountM; while (RowsRemaining > 0) { - const auto RowsHandled = GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( + const auto RowsHandled = GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_CompInt8( BlkLen, a_row, b_col, b_col_scale, b_col_zp, c_blk, RowsRemaining, CountN, K, k_blks, ldc, bias ); @@ -523,10 +604,10 @@ SQ4BitGemm_CompInt8( } } #ifdef MLAS_TARGET_AMD64_IX86 - else if (GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) + else if (GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) { const float* b_blk_sum = QuantBBlkSum + n * k_blks; - GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8( + GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8( BlkLen, QuantA, QuantAScale, @@ -555,26 +636,29 @@ SQ4BitGemm_CompInt8( } } -typedef void(InitializeWorkspaceFn)( +template +void +InitializeWorkspace_CompInt8( size_t M, size_t N, size_t K, size_t BatchN, size_t BlkLen, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, size_t PerGemmWorkspaceStride, MLAS_THREADPOOL* ThreadPool ); +template <> void -InitializeWorkspace_CompInt8( +InitializeWorkspace_CompInt8( size_t M, size_t N, size_t K, size_t BatchN, size_t BlkLen, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, size_t PerGemmWorkspaceStride, MLAS_THREADPOOL* ThreadPool @@ -582,8 +666,8 @@ InitializeWorkspace_CompInt8( { MLAS_UNREFERENCED_PARAMETER(N); - const auto QuantizeARow = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARow_CompInt8; - const auto QuantizeARow2 = GetMlasPlatform().SQNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8; + const 
auto QuantizeARow = GetMlasPlatform().QNBitGemmDispatch->QuantizeARow_CompInt8; + const auto QuantizeARow2 = GetMlasPlatform().QNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8; const size_t BlockCountK = MlasDivRoundup(K, BlkLen); const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); @@ -623,61 +707,153 @@ InitializeWorkspace_CompInt8( } } -struct Operations { - InitializeWorkspaceFn* InitializeWorkspace = nullptr; - SQNBitGemmFn* SQNBitGemm = nullptr; -}; +template <> +void +InitializeWorkspace_CompInt8( + size_t M, + size_t N, + size_t K, + size_t BatchN, + size_t BlkLen, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + size_t PerGemmWorkspaceStride, + MLAS_THREADPOOL* ThreadPool +) { + MLAS_UNREFERENCED_PARAMETER(M); + MLAS_UNREFERENCED_PARAMETER(N); + MLAS_UNREFERENCED_PARAMETER(K); + MLAS_UNREFERENCED_PARAMETER(BatchN); + MLAS_UNREFERENCED_PARAMETER(BlkLen); + MLAS_UNREFERENCED_PARAMETER(DataParams); + MLAS_UNREFERENCED_PARAMETER(Workspace); + MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspaceStride); + MLAS_UNREFERENCED_PARAMETER(ThreadPool); +} + +template +using InitializeWorkspaceFn = std::function* DataParams, + void* Workspace, + size_t PerGemmWorkspaceStride, + MLAS_THREADPOOL* ThreadPool +)>; -constexpr auto OperationMap = []() { - std::array ops; +template +InitializeWorkspaceFn +GetInitializeWorkspace(QNBitGemmVariant variant); - ops[SQNBitGemmVariant_BitWidth4_CompFp32].SQNBitGemm = SQ4BitGemm_CompFp32; +template <> +InitializeWorkspaceFn +GetInitializeWorkspace(QNBitGemmVariant variant) +{ + switch (variant) { + case SQNBitGemmVariant_BitWidth4_CompInt8: + return InitializeWorkspace_CompInt8; + default: + return nullptr; + } +} + +template <> +InitializeWorkspaceFn +GetInitializeWorkspace(QNBitGemmVariant variant) +{ + switch (variant) { + case HQNBitGemmVariant_BitWidth4_CompInt8: + return InitializeWorkspace_CompInt8; + default: + return nullptr; + } +} + +template +using QNBitGemmFn = std::function* const DataParams, + 
void* const PerGemmWorkspace, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +)>; - ops[SQNBitGemmVariant_BitWidth4_CompInt8].InitializeWorkspace = InitializeWorkspace_CompInt8; - ops[SQNBitGemmVariant_BitWidth4_CompInt8].SQNBitGemm = SQ4BitGemm_CompInt8; +template +QNBitGemmFn +GetQNBitGemm(QNBitGemmVariant variant); - return ops; -}(); +template <> +QNBitGemmFn +GetQNBitGemm(QNBitGemmVariant variant) +{ + switch (variant) { + case SQNBitGemmVariant_BitWidth4_CompFp32: + return SQ4BitGemm_CompFp32; + case SQNBitGemmVariant_BitWidth4_CompInt8: + return SQ4BitGemm_CompInt8; + default: + return nullptr; + } +} + +template <> +QNBitGemmFn +GetQNBitGemm(QNBitGemmVariant variant) +{ + switch (variant) { + case HQNBitGemmVariant_BitWidth4_CompFp16: + return HQ4BitGemm_CompFp16; + default: + return nullptr; + } +} } // namespace +template void MLASCALL -MlasSQNBitGemmBatch( +MlasQNBitGemmBatch( const size_t M, const size_t N, const size_t K, const size_t BatchN, const size_t BlkBitWidth, const size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - const MLAS_SQNBIT_GEMM_DATA_PARAMS* DataParams, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, MLAS_THREADPOOL* ThreadPool ) { - const auto Variant = GetSQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); + const auto Variant = GetQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); assert(Variant != SQNBitGemmVariantInvalid); // // Ensure `Workspace` has correct alignment. 
// if (Workspace != nullptr) { - const size_t Alignment = SQNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); + const size_t Alignment = QNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); const uintptr_t WorkspaceAddress = reinterpret_cast(Workspace); Workspace = reinterpret_cast( (WorkspaceAddress + Alignment - 1) & (~(Alignment - 1)) ); } - const size_t PerGemmWorkspaceStride = SQNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const size_t PerGemmWorkspaceStride = QNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); - if (const auto InitializeWorkspaceOperation = OperationMap[Variant].InitializeWorkspace; + if (const auto InitializeWorkspaceOperation = GetInitializeWorkspace(Variant); InitializeWorkspaceOperation != nullptr) { InitializeWorkspaceOperation( M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool ); } - const auto ComputeOperation = OperationMap[Variant].SQNBitGemm; + const auto ComputeOperation = GetQNBitGemm(Variant); const size_t BlockCountK = MlasDivRoundup(K, BlkLen); @@ -686,11 +862,11 @@ MlasSQNBitGemmBatch( const auto* Data = &DataParams[gemm_i]; void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; - if (ComputeType == CompInt8 && GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); - const_cast(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; - const_cast(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; - const_cast(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + if (ComputeType == SQNBIT_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + const_cast*>(Data)->PackedQuantBData = 
packed_quant_b.PackedQuantBData; + const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); } else { @@ -757,11 +933,11 @@ MlasSQNBitGemmBatch( void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; - if (ComputeType == CompInt8 && GetMlasPlatform().SQNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); - const_cast(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; - const_cast(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; - const_cast(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + if (ComputeType == SQNBIT_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; + const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); @@ -770,3 +946,33 @@ MlasSQNBitGemmBatch( } }); } + +template +void MLASCALL +MlasQNBitGemmBatch( + const size_t M, + const size_t N, + const size_t K, + const size_t BatchN, + const size_t BlkBitWidth, + const size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + MLAS_THREADPOOL* ThreadPool +); + +template +void MLASCALL +MlasQNBitGemmBatch( + const size_t M, + const 
size_t N, + const size_t K, + const size_t BatchN, + const size_t BlkBitWidth, + const size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, + void* Workspace, + MLAS_THREADPOOL* ThreadPool +); diff --git a/src/lib/sqnbitgemm.h b/src/lib/qnbitgemm.h similarity index 71% rename from src/lib/sqnbitgemm.h rename to src/lib/qnbitgemm.h index 2da336c..eb3d0b4 100644 --- a/src/lib/sqnbitgemm.h +++ b/src/lib/qnbitgemm.h @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - sqnbitgemm.h + qnbitgemm.h Abstract: @@ -46,24 +46,25 @@ MlasAlignAddress(void* addr, const size_t alignment) return addr; } +template struct PackedQuantBDataStruct { PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen) : QuantBWorkspace_(PackedQuantBWorkspace), N_(N), BlockCountK_(BlockCountK), BlkLen_(BlkLen) { - // TODO: duplicate code from SQ4BitGemmPackQuantBDataSize + // TODO: duplicate code from Q4BitGemmPackQuantBDataSize constexpr size_t BlkBitWidth = 4; const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); + size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(T); // _mm256_load_si256 requires alignment on a 32-byte boundary PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32); - QuantBBlkSum = (float*)(PackedQuantBData + PackedQuantBDataSize); - QuantBBlkSum = (float*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment()); - PackedQuantBScale = (float*)((std::byte*)QuantBBlkSum + BlkSumSize); + QuantBBlkSum = (T*)(PackedQuantBData + PackedQuantBDataSize); + QuantBBlkSum = (T*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment()); + PackedQuantBScale = (T*)((std::byte*)QuantBBlkSum + BlkSumSize); } std::byte* PackedQuantBData; - float* PackedQuantBScale; - float* QuantBBlkSum; + T* PackedQuantBScale; + 
T* QuantBBlkSum; void* QuantBWorkspace_; size_t N_, BlockCountK_, BlkLen_; @@ -84,44 +85,45 @@ MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount) // Kernel dispatch structure. // -struct MLAS_SQNBIT_GEMM_DISPATCH { +struct MLAS_QNBIT_GEMM_DISPATCH { // // Quantized B data packing function prototypes. // - /** Gets size of packed quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBDataSize(). */ - typedef size_t(SQ4BitGemmPackQuantBDataSize_Fn)( + /** Gets size of packed quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBDataSize(). */ + typedef size_t(Q4BitGemmPackQuantBDataSize_Fn)( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - SQ4BitGemmPackQuantBDataSize_Fn* SQ4BitGemmPackQuantBDataSize = nullptr; + Q4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr; - /** Packs quantized B data containing 4-bit integers. See MlasSQNBitGemmPackQuantBData(). */ - typedef void(SQ4BitGemmPackQuantBData_Fn)( + /** Packs quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBData(). 
*/ + typedef void(Q4BitGemmPackQuantBData_Fn)( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin, MLAS_THREADPOOL* ThreadPool ); - SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; + Q4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; + Q4BitGemmPackQuantBData_Fn* HQ4BitGemmPackQuantBData = nullptr; typedef void(SQ4BitGemmPackQuantBDataAndSumBlk_Fn)( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, bool has_zp_input, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ); @@ -141,15 +143,15 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { * @param[in] BlkLen number of quantized values per block * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ - typedef size_t(SQ4BitGemmPerGemmWorkspaceSize_Fn)( + typedef size_t(Q4BitGemmPerGemmWorkspaceSize_Fn)( size_t M, size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - SQ4BitGemmPerGemmWorkspaceSize_Fn* SQ4BitGemmPerGemmWorkspaceSize = nullptr; + Q4BitGemmPerGemmWorkspaceSize_Fn* Q4BitGemmPerGemmWorkspaceSize = nullptr; /** * @brief Gets the required byte alignment of the per-GEMM intermediate workspace. 
@@ -157,15 +159,15 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { * @param[in] BlkLen number of quantized values per block * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ - typedef size_t(SQ4BitGemmPerGemmWorkspaceAlignment_Fn)( + typedef size_t(Q4BitGemmPerGemmWorkspaceAlignment_Fn)( size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - SQ4BitGemmPerGemmWorkspaceAlignment_Fn* SQ4BitGemmPerGemmWorkspaceAlignment = nullptr; + Q4BitGemmPerGemmWorkspaceAlignment_Fn* Q4BitGemmPerGemmWorkspaceAlignment = nullptr; // - // CompFp32 kernel function prototypes. + // SQNBIT_CompFp32 kernel function prototypes. // /** @@ -228,10 +230,41 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { size_t BlockStrideQuantB ); - Q4BitBlkDequantBForSgemm_CompFp32_Fn* Q4BitBlkDequantBForSgemm_CompFp32 = nullptr; + Q4BitBlkDequantBForSgemm_CompFp32_Fn* SQ4BitBlkDequantBForSgemm_CompFp32 = nullptr; + + /** + * @brief Dequantize B into the format expected by the Sgemm kernel. + * B is a quantized 4-bit integer matrix that is block quantized and column major. + * This is equivalent to dequantizing B and then running MlasSgemmCopyPackB. + * + * @param BlkLen Number of values in a block. + * @param[out] FpData Supplies the output buffer for the dequantized B float data. + * It should have enough space for + * (CountN + 16 - 1) / 16 * 16 * (CountK + BlkLen - 1) / BlkLen * BlkLen + * elements. Only the first (CountN + 16 - 1) / 16 * 16 * CountK elements are + * useful, but the kernel implementation can be simplified with the extra space. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param CountN Number of columns of B. + * @param CountK Number of rows of B. 
+ * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + */ + typedef void(Q4BitBlkDequantBForSgemm_CompFp16_Fn)( + size_t BlkLen, + MLAS_FP16* FpData, + const std::byte* QuantBData, + const MLAS_FP16* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB + ); + + Q4BitBlkDequantBForSgemm_CompFp16_Fn* HQ4BitBlkDequantBForHgemm_CompFp16 = nullptr; // - // CompInt8 kernel function prototypes. + // SQNBIT_CompInt8 kernel function prototypes. // /** @@ -337,4 +370,35 @@ struct MLAS_SQNBIT_GEMM_DISPATCH { float* AScaledGroupSum // scale_k * Sum_blklen(a_i) ); QuantizeARowComputeBlkSum_CompInt8_Fn* QuantizeARowComputeBlkSum_CompInt8 = nullptr; + + /** + * @brief Multiply fp16 matrix A rows with fp16 matrix B columns. + * Results are written to fp16 matrix C. + * If bias is provided, the bias are added to the result. + * + * @param A first row of the A matrix segment. Row major. + * @param B first column of the B matrix segment. Column major. + * @param Bias the bias at the target column. Optional. + * @param[out] C first element of the output matrix segment. Row major. + * @param CountM the number of rows of A chunk. + * @param CountN the number of columns of B chunk. + * @param K the number of columns of A matrix and rows of B matrix. + * @param lda the leading dimension of A. + * @param ldb the leading dimension of B. + * @param ldc the leading dimension of C. 
+ */ + typedef void(HQ4BitGemmKernel_CompFp16_Fn)( + const MLAS_FP16* A, + const MLAS_FP16* B, + const MLAS_FP16* Bias, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t K, + size_t lda, + size_t ldb, + size_t ldc + ); + + HQ4BitGemmKernel_CompFp16_Fn* HQ4BitGemmKernel_CompFp16 = nullptr; }; diff --git a/src/lib/sqnbitgemm_kernel_neon.cpp b/src/lib/qnbitgemm_kernel_neon.cpp similarity index 74% rename from src/lib/sqnbitgemm_kernel_neon.cpp rename to src/lib/qnbitgemm_kernel_neon.cpp index 3f32cc6..d05de64 100644 --- a/src/lib/sqnbitgemm_kernel_neon.cpp +++ b/src/lib/qnbitgemm_kernel_neon.cpp @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - sqnbitgemm_kernel_neon.cpp + qnbitgemm_kernel_neon.cpp Abstract: @@ -19,8 +19,8 @@ Module Name: #include -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_neon.h" +#include "qnbitgemm.h" +#include "qnbitgemm_kernel_neon.h" #include "sqnbitgemm_q8_block.h" namespace sqnbitgemm_neon @@ -34,11 +34,11 @@ namespace // size_t -SQ4BitGemmPackQuantBDataSize( +Q4BitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType @@ -55,7 +55,7 @@ SQ4BitGemmPackQuantBData( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin, MLAS_THREADPOOL* ThreadPool @@ -69,7 +69,7 @@ SQ4BitGemmPackQuantBData( const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); const size_t Iterations = N * BlockCountK; // one iteration per block - const size_t SubBlkLen = (ComputeType == CompInt8) + const size_t SubBlkLen = (ComputeType == SQNBIT_CompInt8) ? ((BlkLen == 16) ? 
16 : 32) : 16; @@ -126,18 +126,18 @@ SQ4BitGemmPackQuantBData( // size_t -SQ4BitGemmPerGemmWorkspaceSize( +Q4BitGemmPerGemmWorkspaceSize( size_t M, size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { MLAS_UNREFERENCED_PARAMETER(N); switch (ComputeType) { - case CompInt8: { + case SQNBIT_CompInt8: { // workspace buffer is used for block quantization of A to int8 const size_t BlockCountK = MlasDivRoundup(K, BlkLen); const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); @@ -150,15 +150,15 @@ SQ4BitGemmPerGemmWorkspaceSize( } size_t -SQ4BitGemmPerGemmWorkspaceAlignment( +Q4BitGemmPerGemmWorkspaceAlignment( size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { MLAS_UNREFERENCED_PARAMETER(BlkLen); switch (ComputeType) { - case CompInt8: { + case SQNBIT_CompInt8: { return Q8BlkAlignment(); } default: { @@ -175,20 +175,27 @@ SQ4BitGemmPerGemmWorkspaceAlignment( // Kernel dispatch structure definition. 
// -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = sqnbitgemm_neon::SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = sqnbitgemm_neon::Q4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = sqnbitgemm_neon::SQ4BitGemmPackQuantBData; - d.SQ4BitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::SQ4BitGemmPerGemmWorkspaceAlignment; + d.Q4BitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::Q4BitGemmPerGemmWorkspaceSize; + d.Q4BitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = sqnbitgemm_neon::SQ4BitGemmM1Kernel_CompFp32; - d.Q4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::Q4BitBlkDequantBForSgemm_CompFp32; - - d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::SQ4BitBlkDequantBForSgemm_CompFp32; + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot()) { + d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8; + } d.QuantizeARow_CompInt8 = sqnbitgemm_neon::QuantizeARow_CompInt8; +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + d.HQ4BitGemmPackQuantBData = sqnbitgemm_neon::HQ4BitGemmPackQuantBData_CompFp16; + d.HQ4BitBlkDequantBForHgemm_CompFp16 = sqnbitgemm_neon::HQ4BitBlkDequantBForHgemm_CompFp16; + d.HQ4BitGemmKernel_CompFp16 = sqnbitgemm_neon::HQ4BitGemmKernel_CompFp16; +#endif // MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64 + return d; }(); diff --git a/src/lib/sqnbitgemm_kernel_neon.h b/src/lib/qnbitgemm_kernel_neon.h similarity index 69% rename from src/lib/sqnbitgemm_kernel_neon.h rename to src/lib/qnbitgemm_kernel_neon.h index ef9345d..ccadd24 100644 --- 
a/src/lib/sqnbitgemm_kernel_neon.h +++ b/src/lib/qnbitgemm_kernel_neon.h @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - sqnbitgemm_kernel_neon.h + qnbitgemm_kernel_neon.h Abstract: @@ -30,13 +30,13 @@ namespace sqnbitgemm_neon // // Function declarations for SQNBitGemm ARM NEON kernel entry points. -// Refer to the prototypes in sqnbitgemm.h for documentation. +// Refer to the prototypes in qnbitgemm.h for documentation. // These are declared here so they can be used to initialize the -// MLAS_SQNBIT_GEMM_DISPATCH structure and also be implemented in separate +// MLAS_QNBIT_GEMM_DISPATCH structure and also be implemented in separate // files. // -// CompFp32 declarations +// SQNBIT_CompFp32 declarations void SQ4BitGemmM1Kernel_CompFp32( @@ -53,7 +53,7 @@ SQ4BitGemmM1Kernel_CompFp32( ); void -Q4BitBlkDequantBForSgemm_CompFp32( +SQ4BitBlkDequantBForSgemm_CompFp32( size_t BlkLen, float* FpData, const std::byte* QuantBData, @@ -64,7 +64,48 @@ Q4BitBlkDequantBForSgemm_CompFp32( size_t BlockCountK ); -// CompInt8 declarations +// HQNBIT_CompFp16 declarations +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) +void +HQ4BitGemmPackQuantBData_CompFp16( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +); + +void +HQ4BitBlkDequantBForHgemm_CompFp16( + size_t BlkLen, + MLAS_FP16* FpData, + const std::byte* QuantBData, + const MLAS_FP16* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t K, + size_t BlockCountK +); + +void +HQ4BitGemmKernel_CompFp16( + const MLAS_FP16* A, + const MLAS_FP16* B, + const MLAS_FP16* Bias, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t K, + size_t lda, + size_t ldb, + size_t ldc +); + +#endif // !(defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64)) + +// SQNBIT_CompInt8 declarations void QuantizeARow_CompInt8( diff 
--git a/src/lib/scalar/SgemmKernelScalar.cpp b/src/lib/scalar/SgemmKernelScalar.cpp index 6272925..cbec5d8 100644 --- a/src/lib/scalar/SgemmKernelScalar.cpp +++ b/src/lib/scalar/SgemmKernelScalar.cpp @@ -83,6 +83,8 @@ Return Value: #endif + int countb = 0; + do { float BElements00; @@ -116,6 +118,7 @@ Return Value: // const float* a = A; + const float* b = B; size_t k = CountK; while (k >= 2) { @@ -128,10 +131,10 @@ Return Value: Row1AElements1 = a[lda + 1]; } - BElements00 = B[0]; - BElements01 = B[1]; - BElements02 = B[2]; - BElements03 = B[3]; + BElements00 = b[0]; + BElements01 = b[1]; + BElements02 = b[2]; + BElements03 = b[3]; Row0Block00 = Row0Block00 + BElements00 * Row0AElements0; Row0Block01 = Row0Block01 + BElements01 * Row0AElements0; Row0Block02 = Row0Block02 + BElements02 * Row0AElements0; @@ -144,10 +147,10 @@ Return Value: Row1Block03 = Row1Block03 + BElements03 * Row1AElements0; } - BElements00 = B[4]; - BElements01 = B[5]; - BElements02 = B[6]; - BElements03 = B[7]; + BElements00 = b[16]; + BElements01 = b[17]; + BElements02 = b[18]; + BElements03 = b[19]; Row0Block00 = Row0Block00 + BElements00 * Row0AElements1; Row0Block01 = Row0Block01 + BElements01 * Row0AElements1; Row0Block02 = Row0Block02 + BElements02 * Row0AElements1; @@ -161,7 +164,7 @@ Return Value: } a += 2; - B += 8; + b += 32; k -= 2; } @@ -173,10 +176,10 @@ Return Value: Row1AElements0 = a[lda]; } - BElements00 = B[0]; - BElements01 = B[1]; - BElements02 = B[2]; - BElements03 = B[3]; + BElements00 = b[0]; + BElements01 = b[1]; + BElements02 = b[2]; + BElements03 = b[3]; Row0Block00 = Row0Block00 + BElements00 * Row0AElements0; Row0Block01 = Row0Block01 + BElements01 * Row0AElements0; Row0Block02 = Row0Block02 + BElements02 * Row0AElements0; @@ -188,8 +191,6 @@ Return Value: Row1Block02 = Row1Block02 + BElements02 * Row1AElements0; Row1Block03 = Row1Block03 + BElements03 * Row1AElements0; } - - B += 4; } // @@ -295,9 +296,14 @@ Return Value: break; } + B += 4; C += 4; CountN -= 4; + 
countb = (countb + 1) % 4; + if (countb == 0) { + B += CountK * 16 - 16; + } } while (CountN > 0); return ProcessTwoRows ? 2 : 1; diff --git a/src/lib/sgemm.cpp b/src/lib/sgemm.cpp index 4d7a1ce..f8b25fb 100644 --- a/src/lib/sgemm.cpp +++ b/src/lib/sgemm.cpp @@ -1061,7 +1061,7 @@ Return Value: size_t RowsHandled; -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) +#if (defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64)) && !defined(FORCE_GENERIC_ALGORITHMS) RowsHandled = GetMlasPlatform().GemmFloatKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode); #else if (ZeroMode) { @@ -1158,6 +1158,7 @@ Return Value: if (M == 1 && TransA == CblasNoTrans && alpha == 1.0f && (beta == 0.0f || beta == 1.0f)) { +#if !defined(FORCE_GENERIC_ALGORITHMS) #if defined(MLAS_TARGET_AMD64) MLAS_SGEMM_KERNEL_M1_ROUTINE* SgemmKernelM1Routine; @@ -1181,6 +1182,7 @@ Return Value: } #endif +#endif // !defined(FORCE_GENERIC_ALGORITHMS) } @@ -1193,7 +1195,7 @@ Return Value: if (N == 1 && ldb == 1 && ldc == 1 && alpha == 1.0f && (beta == 0.0f || beta == 1.0f)) { -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) && !defined(FORCE_GENERIC_ALGORITHMS) MLAS_SGEMM_KERNEL_M1_ROUTINE* SgemmKernelM1Routine; diff --git a/src/lib/sqnbitgemm_kernel_avx2.cpp b/src/lib/sqnbitgemm_kernel_avx2.cpp index abf8060..5d42f22 100644 --- a/src/lib/sqnbitgemm_kernel_avx2.cpp +++ b/src/lib/sqnbitgemm_kernel_avx2.cpp @@ -20,7 +20,7 @@ Module Name: #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_int8.h" #include "sqnbitgemm_kernel_avx2_int8_blklen16.h" @@ -1307,12 +1307,12 @@ SQ4BitGemmPackQuantBDataAndBlkSum( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, bool 
has_zp_input, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ) { @@ -1320,9 +1320,9 @@ SQ4BitGemmPackQuantBDataAndBlkSum( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - // TODO: always use SubBlkLen = 64 in CompInt8 + // TODO: always use SubBlkLen = 64 in SQNBIT_CompInt8 size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); - if (BlkLen == 32 && ComputeType == CompInt8) { + if (BlkLen == 32 && ComputeType == SQNBIT_CompInt8) { SubBlkLen = 64; } PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); @@ -1331,18 +1331,18 @@ SQ4BitGemmPackQuantBDataAndBlkSum( // // Kernel dispatch structure definition. // -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; + d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; @@ -1350,18 +1350,18 @@ const MLAS_SQNBIT_GEMM_DISPATCH 
MlasSQNBitGemmDispatchAvx2 = []() { return d; }(); -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; + d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; diff --git a/src/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h b/src/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h index 80d6780..445ead3 100644 --- a/src/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h +++ b/src/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" diff --git a/src/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/src/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h index af6f520..5dab809 100644 --- a/src/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h +++ b/src/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" diff --git a/src/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h 
b/src/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h index 174ebc5..d4b89bd 100644 --- a/src/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h +++ b/src/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" template @@ -117,7 +117,7 @@ accumulate_blklen64_r1c1blk1_avx2( __m256 scale_b_8_ps = _mm256_broadcast_ss(scale_b); acc0 = _mm256_fmadd_ps(sum_ps, _mm256_mul_ps(scale_a_8_ps, scale_b_8_ps), acc0); -#if !defined(__GNUC__) || (__GNUC__ > 9) +#if !defined(__GNUC__) || (__GNUC__ > 10) } #endif } diff --git a/src/lib/sqnbitgemm_kernel_avx512.cpp b/src/lib/sqnbitgemm_kernel_avx512.cpp index 127279a..b4e25d4 100644 --- a/src/lib/sqnbitgemm_kernel_avx512.cpp +++ b/src/lib/sqnbitgemm_kernel_avx512.cpp @@ -18,8 +18,8 @@ Module Name: #include #include #include -#include -#include "sqnbitgemm.h" + +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_int8.h" #include "sqnbitgemm_kernel_avx512_int8_blklen16.h" @@ -28,7 +28,7 @@ Module Name: #include "sqnbitgemm_kernel_avx512_int8_blklen128.h" // -// CompFp32 kernel implementation. +// SQNBIT_CompFp32 kernel implementation. // #include "sqnbitgemm_kernel_avx_common_fp32.h" @@ -151,7 +151,7 @@ SQ4BitGemmM1Kernel_CompFp32_avx512( } // -// CompInt8 kernel implementation. +// SQNBIT_CompInt8 kernel implementation. 
// MLAS_FORCEINLINE @@ -332,12 +332,12 @@ SQ4BitGemmPackQuantBDataAndBlkSum512( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, bool has_zp_input, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ) { @@ -346,24 +346,24 @@ SQ4BitGemmPackQuantBDataAndBlkSum512( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); - if (ComputeType == CompInt8) { + if (ComputeType == SQNBIT_CompInt8) { SubBlkLen = 128; } PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); } -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512; - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; + d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; diff 
--git a/src/lib/sqnbitgemm_kernel_avx512_int8.h b/src/lib/sqnbitgemm_kernel_avx512_int8.h index 7d9dc36..8f1ea66 100644 --- a/src/lib/sqnbitgemm_kernel_avx512_int8.h +++ b/src/lib/sqnbitgemm_kernel_avx512_int8.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" @@ -81,7 +81,7 @@ accumulate_blklen32_r2c1blk2_avx2( _mm256_sign_epi8(bv1_32_epi8, bv1_32_epi8), _mm256_sign_epi8(av01_32_epi8, bv1_32_epi8) ); const __m256i sum_16_epi16 = _mm256_hadd_epi16(dot0_16_epi16, dot1_16_epi16); - + __m256i one_16_epi16 = _mm256_srli_epi16(_mm256_cmpeq_epi16(bv0_32_epi8, bv0_32_epi8), 15); const __m256i sum_8_epi32 = _mm256_madd_epi16(one_16_epi16, sum_16_epi16); const __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); @@ -143,7 +143,7 @@ accumulate_blklen32_r2c1blk2_avx2( // const __m256i bv1 = _mm256_and_si256(_mm256_srli_epi16(bv_packed, 4), low_mask); // 16, 17,...30, 31, 48, 49,...,62, 63 __m256i bv1_32_epi8 = _mm256_srli_epi16(_mm256_sub_epi8(bv_packed, bv0_32_epi8), 4); // 16, 17,...30, 31, 48, 49,...,62, 63 - //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); + //__m256i bv0_32_epi8 = _mm256_set_m128i(_mm256_castsi256_si128(bv1), _mm256_castsi256_si128(bv0)); //// This (the second line below) saves one _mm256_extracti128_si256 against using _mm256_set_m128i. 
////__m256i bv1_32_epi8 = _mm256_set_m128i(_mm256_extracti128_si256(bv1, 1), _mm256_extracti128_si256(bv0, 1)); @@ -184,7 +184,7 @@ accumulate_blklen32_r2c1blk1_avx2( const __m128i bv_packed0 = _mm_loadu_si128(reinterpret_cast(QuantBDataPtr)); __m256i bv_32_epi8 = _mm256_set_m128i(_mm_srli_epi16(bv_packed0, 4), bv_packed0); bv_32_epi8 = _mm256_and_si256(_mm256_set1_epi8(0x0F), bv_32_epi8); - + const int8_t zp = get_zp(true, QuantBZeroPointPtr); const __m256i bzp = _mm256_set1_epi8(zp); bv_32_epi8 = _mm256_sub_epi8(bv_32_epi8, bzp); @@ -435,7 +435,7 @@ Q4Int8Gemm2x4BlkLen32Avx2( } } -template +template void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( const std::byte* QuantA, const std::byte* QuantBData, @@ -877,7 +877,7 @@ MLAS_FORCEINLINE QuantBZeroPoint + multipleCols * StrideQuantBZeroPoint, C + multipleRows * ldc + multipleCols, remainingRows, - remainingCols, + remainingCols, BlockCountK, Bias ? Bias + multipleCols : nullptr, lda, diff --git a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h index 60a8873..d79554c 100644 --- a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h +++ b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx512_int8_blklen64.h" diff --git a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h index bb14bab..0306488 100644 --- a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h +++ b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx2_int8_blklen16.h" #include "sqnbitgemm_kernel_avx512_int8_blklen32.h" diff --git a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h index e9df6b9..3b1096a 
100644 --- a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h +++ b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx2_int8_blklen32.h" #include "sqnbitgemm_kernel_avx512_int8_blklen64.h" diff --git a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h index 2a65ac4..72ce28d 100644 --- a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h +++ b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" static MLAS_FORCEINLINE __m256 diff --git a/src/lib/sqnbitgemm_kernel_avx512vnni.cpp b/src/lib/sqnbitgemm_kernel_avx512vnni.cpp index 6a5c011..a4468bb 100644 --- a/src/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/src/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -19,7 +19,7 @@ Module Name: #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_fp32.h" #include "sqnbitgemm_kernel_avx_common_int8.h" @@ -314,12 +314,12 @@ SQ4BitGemmPackQuantBDataAndBlkSum512vnni( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, bool has_zp_input, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ) { @@ -328,7 +328,7 @@ SQ4BitGemmPackQuantBDataAndBlkSum512vnni( const size_t BlockCountK = MlasDivRoundup(K, BlkLen); size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 
32 : 64); - if (ComputeType == CompInt8) { + if (ComputeType == SQNBIT_CompInt8) { SubBlkLen = 128; } PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); @@ -337,18 +337,18 @@ SQ4BitGemmPackQuantBDataAndBlkSum512vnni( // // Kernel dispatch structure definition. // -const MLAS_SQNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { - MLAS_SQNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.SQ4BitGemmPackQuantBDataSize = SQ4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512vnni; - d.SQ4BitGemmPerGemmWorkspaceSize = SQ4BitGemmPerGemmWorkspaceSize; - d.SQ4BitGemmPerGemmWorkspaceAlignment = SQ4BitGemmPerGemmWorkspaceAlignment; + d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; + d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; - d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; diff --git a/src/lib/sqnbitgemm_kernel_avx_common.h b/src/lib/sqnbitgemm_kernel_avx_common.h index 177f551..b0367b7 100644 --- a/src/lib/sqnbitgemm_kernel_avx_common.h +++ b/src/lib/sqnbitgemm_kernel_avx_common.h @@ -1,5 +1,5 @@ #pragma once -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_q8_block.h" // @@ -7,16 +7,16 @@ // static size_t -SQ4BitGemmPackQuantBDataSize( +Q4BitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE 
ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { constexpr size_t BlkBitWidth = 4; const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - if (ComputeType == CompInt8) { + if (ComputeType == SQNBIT_CompInt8) { size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); const size_t ScaleSize = N * BlockCountK * sizeof(float); size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); @@ -39,7 +39,7 @@ SQ4BitGemmPackQuantBData( size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, const std::byte* QuantBDataBegin, std::byte* PackedQuantBDataBegin, MLAS_THREADPOOL* ThreadPool @@ -304,7 +304,7 @@ PackQuantBDataAndBlkSum( const float* QuantBScaleBegin, bool has_zp_input, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& packed_quant_b, MLAS_THREADPOOL* ThreadPool ) { @@ -326,18 +326,18 @@ PackQuantBDataAndBlkSum( // static size_t -SQ4BitGemmPerGemmWorkspaceSize( +Q4BitGemmPerGemmWorkspaceSize( size_t M, size_t N, size_t K, size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { MLAS_UNREFERENCED_PARAMETER(N); switch(ComputeType) { - case CompInt8: { + case SQNBIT_CompInt8: { // workspace buffer is used for block quantization of A to int8 const size_t BlockCountK = MlasDivRoundup(K, BlkLen); // QuantData + Scale + BlkSum @@ -351,15 +351,15 @@ SQ4BitGemmPerGemmWorkspaceSize( } static size_t -SQ4BitGemmPerGemmWorkspaceAlignment( +Q4BitGemmPerGemmWorkspaceAlignment( size_t BlkLen, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { MLAS_UNREFERENCED_PARAMETER(BlkLen); switch (ComputeType) { - case CompInt8: { + case SQNBIT_CompInt8: { return Q8BlkAlignment(); } default: { diff --git a/src/lib/sqnbitgemm_kernel_avx_common_fp32.h b/src/lib/sqnbitgemm_kernel_avx_common_fp32.h index 
5cd380e..d15cfc7 100644 --- a/src/lib/sqnbitgemm_kernel_avx_common_fp32.h +++ b/src/lib/sqnbitgemm_kernel_avx_common_fp32.h @@ -1,5 +1,5 @@ #pragma once -#include "sqnbitgemm.h" +#include "qnbitgemm.h" template MLAS_FORCEINLINE diff --git a/src/lib/sqnbitgemm_kernel_avx_common_int8.h b/src/lib/sqnbitgemm_kernel_avx_common_int8.h index 895ce6c..2e96082 100644 --- a/src/lib/sqnbitgemm_kernel_avx_common_int8.h +++ b/src/lib/sqnbitgemm_kernel_avx_common_int8.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_q8_block.h" diff --git a/src/lib/sqnbitgemm_kernel_neon_fp32.cpp b/src/lib/sqnbitgemm_kernel_neon_fp32.cpp index 12ddc42..31a499b 100644 --- a/src/lib/sqnbitgemm_kernel_neon_fp32.cpp +++ b/src/lib/sqnbitgemm_kernel_neon_fp32.cpp @@ -13,7 +13,7 @@ Module Name: This module implements the float/quantized n-bit integer matrix multiplication kernels for ARM NEON specific to input type T1 as float32 and - MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompFp32. + MLAS_QNBIT_GEMM_COMPUTE_TYPE SQNBIT_CompFp32. --*/ @@ -21,8 +21,8 @@ Module Name: #include -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_neon.h" +#include "qnbitgemm.h" +#include "qnbitgemm_kernel_neon.h" namespace sqnbitgemm_neon { @@ -31,7 +31,7 @@ namespace { // -// CompFp32 kernel implementation. +// SQNBIT_CompFp32 kernel implementation. 
// MLAS_FORCEINLINE void @@ -608,7 +608,7 @@ Q4BitBlkDequantBForSgemm_CompFp32_Impl( } // namespace void -Q4BitBlkDequantBForSgemm_CompFp32( +SQ4BitBlkDequantBForSgemm_CompFp32( size_t BlkLen, float* FpData, const std::byte* QuantBData, diff --git a/src/lib/sqnbitgemm_kernel_neon_int8.cpp b/src/lib/sqnbitgemm_kernel_neon_int8.cpp index 0d62ea3..73beb06 100644 --- a/src/lib/sqnbitgemm_kernel_neon_int8.cpp +++ b/src/lib/sqnbitgemm_kernel_neon_int8.cpp @@ -13,7 +13,7 @@ Module Name: This module implements the float/quantized n-bit integer matrix multiplication kernels for ARM NEON specific to input type T1 as float32 and - MLAS_SQNBIT_GEMM_COMPUTE_TYPE CompInt8. + MLAS_QNBIT_GEMM_COMPUTE_TYPE SQNBIT_CompInt8. --*/ @@ -21,15 +21,15 @@ Module Name: #include -#include "sqnbitgemm.h" -#include "sqnbitgemm_kernel_neon.h" +#include "qnbitgemm.h" +#include "qnbitgemm_kernel_neon.h" #include "sqnbitgemm_q8_block.h" namespace sqnbitgemm_neon { // -// CompInt8 kernel implementation. +// SQNBIT_CompInt8 kernel implementation. 
// namespace diff --git a/src/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h b/src/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h index 45c3963..941b884 100644 --- a/src/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h +++ b/src/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" template diff --git a/src/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h b/src/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h index e9c3812..ed78dfa 100644 --- a/src/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h +++ b/src/lib/sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h @@ -3,7 +3,7 @@ #include #include -#include "sqnbitgemm.h" +#include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" diff --git a/src/ort_include/core/common/logging/logging.h b/src/ort_include/core/common/logging/logging.h index 571262a..3ad27d3 100644 --- a/src/ort_include/core/common/logging/logging.h +++ b/src/ort_include/core/common/logging/logging.h @@ -17,7 +17,6 @@ #include "core/common/logging/macros.h" #include "core/common/logging/severity.h" #include "core/common/logging/sink_types.h" -#include "core/platform/ort_mutex.h" /* @@ -258,7 +257,7 @@ class LoggingManager final { std::unique_ptr sink_; #ifdef _WIN32 - mutable OrtMutex sink_mutex_; + mutable std::mutex sink_mutex_; #endif Severity default_min_severity_; const bool default_filter_user_data_; diff --git a/src/ort_include/core/platform/EigenNonBlockingThreadPool.h b/src/ort_include/core/platform/EigenNonBlockingThreadPool.h index 26237b3..38a4c59 100644 --- a/src/ort_include/core/platform/EigenNonBlockingThreadPool.h +++ b/src/ort_include/core/platform/EigenNonBlockingThreadPool.h @@ -50,7 +50,6 @@ #include "core/common/denormal.h" #include "core/common/inlined_containers_fwd.h" #include "core/common/spin_pause.h" -#include "core/platform/ort_mutex.h" #include "core/platform/ort_spin_lock.h" // ORT thread pool 
overview @@ -459,7 +458,7 @@ class RunQueue { #ifdef USE_LOCK_FREE_QUEUE std::lock_guard mtx(spin_lock_); #else - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); #endif unsigned back = back_.load(std::memory_order_relaxed); Elem& e = array_[(back - 1) & kMask]; @@ -483,7 +482,7 @@ class RunQueue { #ifdef USE_LOCK_FREE_QUEUE std::lock_guard mtx(spin_lock_); #else - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); #endif unsigned back = back_.load(std::memory_order_relaxed); w_idx = (back - 1) & kMask; @@ -508,7 +507,7 @@ class RunQueue { #ifdef USE_LOCK_FREE_QUEUE std::lock_guard mtx(spin_lock_); #else - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); #endif unsigned back; Elem* e; @@ -554,7 +553,7 @@ class RunQueue { #ifdef USE_LOCK_FREE_QUEUE std::lock_guard mtx(spin_lock_); #else - std::lock_guard lock(mutex_); + std::lock_guard lock(mutex_); #endif Elem& e = array_[w_idx]; ElemState s = e.state.load(std::memory_order_relaxed); @@ -630,7 +629,7 @@ class RunQueue { #ifdef USE_LOCK_FREE_QUEUE OrtSpinLock spin_lock_; #else - OrtMutex mutex_; + std::mutex mutex_; #endif // Low log(kSize) + 1 bits in front_ and back_ contain rolling index of @@ -1439,7 +1438,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter ThreadStatus seen = GetStatus(); if (seen == ThreadStatus::Blocking || seen == ThreadStatus::Blocked) { - std::unique_lock lk(mutex); + std::unique_lock lk(mutex); // Blocking state exists only transiently during the SetBlock() method // while holding the lock. 
We may observe it at the start of this // function, but after acquiring the lock then the target thread @@ -1469,7 +1468,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter void SetBlocked(std::function should_block, std::function post_block) { - std::unique_lock lk(mutex); + std::unique_lock lk(mutex); assert(GetStatus() == ThreadStatus::Spinning); status.store(ThreadStatus::Blocking, std::memory_order_relaxed); if (should_block()) { @@ -1484,8 +1483,8 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter private: std::atomic status{ThreadStatus::Spinning}; - OrtMutex mutex; - OrtCondVar cv; + std::mutex mutex; + std::condition_variable cv; }; Environment& env_; diff --git a/src/ort_include/core/platform/ort_mutex.h b/src/ort_include/core/platform/ort_mutex.h deleted file mode 100644 index 5028b03..0000000 --- a/src/ort_include/core/platform/ort_mutex.h +++ /dev/null @@ -1,9 +0,0 @@ -#pragma once - -#include -#include - -namespace onnxruntime{ - using OrtMutex = std::mutex; - using OrtCondVar = std::condition_variable; -} \ No newline at end of file diff --git a/tests/bench/CMakeLists.txt b/tests/bench/CMakeLists.txt index 8bf535d..75c5836 100644 --- a/tests/bench/CMakeLists.txt +++ b/tests/bench/CMakeLists.txt @@ -1,5 +1,5 @@ include_directories(../../src/lib) -add_executable(onnxruntime_mlas_benchmark bench_computesoftmax.cpp bench_main.cpp bench_q4dq.cpp bench_q4gemm.cpp bench_qgemm.cpp bench_sconv.cpp bench_sgemm.cpp bench_sqnbitgemm.cpp bench_symm_qgemm.cpp bench_util.cpp) +add_executable(onnxruntime_mlas_benchmark bench_computesoftmax.cpp bench_main.cpp bench_q4dq.cpp bench_q4gemm.cpp bench_qgemm.cpp bench_sconv.cpp bench_sgemm.cpp bench_qnbitgemm.cpp bench_symm_qgemm.cpp bench_util.cpp) target_link_libraries(onnxruntime_mlas_benchmark PRIVATE benchmark::benchmark ${ONNXRUNTIME_MLAS_LIBS} ) if(NOT MLAS_NO_ONNXRUNTIME) target_link_libraries(onnxruntime_mlas_benchmark PRIVATE onnxruntime_common) 
diff --git a/tests/bench/bench_sqnbitgemm.cpp b/tests/bench/bench_qnbitgemm.cpp similarity index 53% rename from tests/bench/bench_sqnbitgemm.cpp rename to tests/bench/bench_qnbitgemm.cpp index 71db7d8..64d2298 100644 --- a/tests/bench/bench_sqnbitgemm.cpp +++ b/tests/bench/bench_qnbitgemm.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include "benchmark/benchmark.h" @@ -16,16 +17,16 @@ #include "core/util/thread_utils.h" #include "core/platform/env_var_utils.h" -template -void RunSQNBitGemmBenchmark(size_t BlkLen, - size_t M, size_t N, size_t K, - size_t Threads, - bool Symmetric, - bool HasBias, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, - benchmark::State& state) { - if (!MlasIsSQNBitGemmAvailable(BlkBitWidth, BlkLen, ComputeType)) { - state.SkipWithMessage("SQNBitGemm is not available with the given configuration on the current machine."); +template +void RunQNBitGemmBenchmark(size_t BlkLen, + size_t M, size_t N, size_t K, + size_t Threads, + bool Symmetric, + bool HasBias, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + benchmark::State& state) { + if (!MlasIsQNBitGemmAvailable(BlkBitWidth, BlkLen, ComputeType)) { + state.SkipWithMessage("QNBitGemm is not available with the given configuration on the current machine."); return; } @@ -43,40 +44,40 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); - const auto A = RandomVectorUniform(M * K, -1.0f, 1.0f); - const auto B = RandomVectorUniform(K * N, -1.0f, 1.0f); + const auto A = RandomVectorUniform(M * K, AType(-1.0f), AType(1.0f)); + const auto B = RandomVectorUniform(K * N, AType(-1.0f), AType(1.0f)); - const auto Bias = HasBias ? RandomVectorUniform(N, -1.0f, 1.0f) : std::vector(); + const auto Bias = HasBias ? 
RandomVectorUniform(N, AType(-1.0f), AType(1.0f)) : std::vector(); - std::vector C(static_cast(M * N)); + std::vector C(static_cast(M * N)); std::vector QuantBData(QuantBDataSizeInBytes); - std::vector QuantBScale(QuantBScaleSize); + std::vector QuantBScale(QuantBScaleSize); std::vector QuantBZeroPoint(Symmetric ? 0 : QuantBZeroPointSizeInBytes); bool has_zp_input = !Symmetric; - MlasQuantizeBlockwise(QuantBData.data(), QuantBScale.data(), + MlasQuantizeBlockwise(QuantBData.data(), QuantBScale.data(), Symmetric ? nullptr : QuantBZeroPoint.data(), B.data(), static_cast(BlkLen), /* columnwise */ true, static_cast(K), static_cast(N), static_cast(N), tp.get()); std::unique_ptr Workspace; - if (const auto WorkspaceSize = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); + if (const auto WorkspaceSize = MlasQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); WorkspaceSize > 0) { Workspace = std::make_unique(WorkspaceSize); } std::unique_ptr PackedQuantBData; - if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); + if (const auto PackedQuantBDataSize = MlasQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); PackedQuantBDataSize > 0) { PackedQuantBData = std::make_unique(PackedQuantBDataSize); - MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData.data(), PackedQuantBData.get(), - QuantBScale.data(), has_zp_input, QuantBZeroPoint.data(), - tp.get()); + MlasQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData.data(), PackedQuantBData.get(), + QuantBScale.data(), has_zp_input, QuantBZeroPoint.data(), + tp.get()); } - MLAS_SQNBIT_GEMM_DATA_PARAMS params{}; + MLAS_QNBIT_GEMM_DATA_PARAMS params{}; params.A = A.data(); params.lda = K; if (PackedQuantBData != nullptr) @@ -92,15 +93,15 @@ void RunSQNBitGemmBenchmark(size_t BlkLen, params.ldc = N; // warm up run - MlasSQNBitGemmBatch(M, N, K, 1, 
BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); + MlasQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); for (auto _ : state) { - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); + MlasQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace.get(), tp.get()); } } -template -void SQNBITGEMM(benchmark::State& state) { +template +void QNBITGEMM(benchmark::State& state) { using onnxruntime::narrow; const auto BlkLen = narrow(state.range(0)); @@ -110,46 +111,50 @@ void SQNBITGEMM(benchmark::State& state) { const auto Threads = narrow(state.range(4)); const auto Symmetric = narrow(state.range(5)); const bool HasBias = narrow(state.range(6)); - const auto ComputeType = static_cast(state.range(7)); + const auto ComputeType = static_cast(state.range(7)); - RunSQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, ComputeType, state); + RunQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, ComputeType, state); } -static void SQNBitGemmArgs(benchmark::internal::Benchmark* b) { +template +static void QNBitGemmArgs(benchmark::internal::Benchmark* b) { b->ArgNames({"BlkLen", "M", "N", "K", "Threads", "Symmetric", "HasBias", "ComputeType"}); b->ArgsProduct({ - {128}, // BlkLen - {1}, // M - {4096, 11008}, // N - {4096, 11008}, // K - {1, 8}, // Threads - {int64_t{false}, int64_t{true}}, // Symmetric - {int64_t{false}, int64_t{true}}, // HasBias - {int64_t{CompFp32}, int64_t{CompInt8}}, // ComputeType + {128}, // BlkLen + {1, 4096}, // M + {4096, 11008}, // N + {4096, 11008}, // K + {1, 8}, // Threads + {int64_t{false}, int64_t{true}}, // Symmetric + {int64_t{false}, int64_t{true}}, // HasBias + std::is_same_v + ? 
std::vector{int64_t{HQNBIT_CompFp16}} + : std::vector{int64_t{SQNBIT_CompFp32}, int64_t{SQNBIT_CompInt8}}, // ComputeType }); } -BENCHMARK(SQNBITGEMM<4>)->Apply(SQNBitGemmArgs)->UseRealTime(); +BENCHMARK(QNBITGEMM)->Apply(QNBitGemmArgs)->UseRealTime(); +BENCHMARK(QNBITGEMM)->Apply(QNBitGemmArgs)->UseRealTime(); // This test gets benchmark arguments from environment variables. -template -void SQNBITGEMM_ENV(benchmark::State& state) { +template +void QNBITGEMM_ENV(benchmark::State& state) { using onnxruntime::ParseEnvironmentVariableWithDefault; - const auto BlkLen = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_BLKLEN", 32); - const auto M = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_M", 1); - const auto N = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_N", 4096); - const auto K = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_K", 4096); - const auto Threads = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_THREADS", 1); - const auto Symmetric = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_SYMMETRIC", true); - const auto HasBias = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_HAS_BIAS", false); - const auto ComputeType = ParseEnvironmentVariableWithDefault("ORT_SQNBITGEMM_COMPUTE_TYPE", - static_cast(CompFp32)); + const auto BlkLen = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_BLKLEN", 32); + const auto M = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_M", 1); + const auto N = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_N", 4096); + const auto K = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_K", 4096); + const auto Threads = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_THREADS", 1); + const auto Symmetric = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_SYMMETRIC", true); + const auto HasBias = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_HAS_BIAS", false); + const auto ComputeType = ParseEnvironmentVariableWithDefault("ORT_QNBITGEMM_COMPUTE_TYPE", + static_cast(SQNBIT_CompFp32)); - 
RunSQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, - static_cast(ComputeType), - state); + RunQNBitGemmBenchmark(BlkLen, M, N, K, Threads, Symmetric, HasBias, + static_cast(ComputeType), + state); std::ostringstream s; s << "BlkBitWidth:" << BlkBitWidth << "/BlkLen:" << BlkLen @@ -159,4 +164,4 @@ void SQNBITGEMM_ENV(benchmark::State& state) { state.SetLabel(s.str()); } -BENCHMARK(SQNBITGEMM_ENV<4>)->UseRealTime(); +BENCHMARK(QNBITGEMM_ENV)->UseRealTime(); diff --git a/tests/bench/bench_util.h b/tests/bench/bench_util.h index f96dd5c..e3abd7b 100644 --- a/tests/bench/bench_util.h +++ b/tests/bench/bench_util.h @@ -8,8 +8,12 @@ #include #include +#include "core/framework/float16.h" +#include "mlas.h" + template -std::vector RandomVectorUniform( +typename std::enable_if_t, std::vector> +RandomVectorUniform( size_t N, ElementType min_value = std::numeric_limits::lowest(), ElementType max_value = std::numeric_limits::max()) { @@ -26,6 +30,25 @@ std::vector RandomVectorUniform( return r; } +template +typename std::enable_if_t, std::vector> +RandomVectorUniform( + size_t N, + ElementType min_value, + ElementType max_value) { + if (min_value.ToFloat() >= max_value.ToFloat()) { + return std::vector(N, min_value); + } + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(min_value.ToFloat(), max_value.ToFloat()); + + std::vector r(N); + for (size_t i = 0; i < N; i++) { + r[i] = ElementType(distribution(generator)); + } + return r; +} + std::vector RandomVectorUniform(std::vector shape, float min_value, float max_value); std::vector BenchArgsVector(benchmark::State& state, size_t& start, size_t count); diff --git a/tests/unittest/test_hqnbitgemm_neon.cpp b/tests/unittest/test_hqnbitgemm_neon.cpp new file mode 100644 index 0000000..b598c20 --- /dev/null +++ b/tests/unittest/test_hqnbitgemm_neon.cpp @@ -0,0 +1,501 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. 
+ +Licensed under the MIT License. + +Module Name: + + test_hqnbitgemm_neon.cpp + +Abstract: + + Tests for MLAS n-bit int block quantized GEMM on ARM CPU with input A type T1 fp16. + +--*/ + +#include +#include + +#include "test_util.h" +#include "core/mlas/lib/mlasi.h" +#include "core/mlas/lib/qnbitgemm.h" +#include "mlas_qnbit.h" + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + +class MlasNeonFp16CastTest : public MlasTestBase { + private: + MatrixGuardBuffer fp32Buffer_; + MatrixGuardBuffer fp16Buffer_; + + template + void TestFp16ToFp32() { + const auto* src = fp16Buffer_.GetFilledBuffer(count, [](unsigned short* start, size_t size) { + for (size_t i = 0; i < size; i++) { + start[i] = static_cast(i); + } + }); + auto* dest = fp32Buffer_.GetBuffer(count, true); + + MlasCastF16ToF32KernelNeon(src, dest, count); + + for (size_t i = 0; i < count; i++) { + if ((src[i] & 0x1c00) == 0x1c00) continue; // skip inf and nan + ASSERT_EQ(dest[i], MLAS_FP16::FromBits(src[i]).ToFloat()); + } + } + + template + void TestFp32ToFp16() { + const auto* src = fp32Buffer_.GetFilledBuffer(count, [](float* p, size_t size) { + for (size_t i = 0; i < size; i++) { + p[i] = static_cast(i) + 0.125f; + } + }); + auto* dest = fp16Buffer_.GetBuffer(count, true); + + MlasCastF32ToF16KernelNeon(src, dest, count); + + for (size_t i = 0; i < count; i++) { + ASSERT_EQ(dest[i], MLAS_FP16(src[i]).val); + } + } + + public: + static const char* GetTestSuiteName() { + return "NeonFp16Cast"; + } + + void ExecuteShort(void) override { + TestFp16ToFp32<(1 << 16)>(); + TestFp16ToFp32<1>(); + TestFp16ToFp32<4>(); + TestFp16ToFp32<7>(); + TestFp32ToFp16<(1 << 16)>(); + TestFp32ToFp16<3>(); + TestFp32ToFp16<4>(); + TestFp32ToFp16<6>(); + } +}; + +class MlasNeonFp16PrepackTest : public MlasTestBase { + private: + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_int_distribution<> distrib_; + MatrixGuardBuffer input_, ref_, 
packed_; + + template + MLAS_FORCEINLINE void Transpose8x8(const uint8_t* src, size_t n, size_t k, uint8_t* dst) { + for (size_t c = 0; c < 8; c++) { + for (size_t r = 0; r < 8; r++) { + size_t i = (n + c) * Ldb + r + k; + size_t j = n * Ldb + (r + k) * 8 + c; + dst[j] = src[i]; + } + } + } + + MLAS_FORCEINLINE + uint8_t GetInt4(uint8_t v, size_t i) { + return (i & 1) ? (v >> 4) : (v & 0x0f); + } + + MLAS_FORCEINLINE + void PrepackSlice(const uint8_t* src, size_t j, uint8_t* dst) { + for (size_t i = 0; i < 8; i++) { + uint8_t v0 = GetInt4(src[j + (i >> 1)], i); + uint8_t v1 = GetInt4(src[j + ((8 + i) >> 1)], i + 8); + dst[j + i] = v0 | (v1 << 4); + } + } + + template + MLAS_FORCEINLINE void Prepack(const uint8_t* src, uint8_t* dst) { + size_t n = 0; + for (; n + 8 <= N; n += 8) { + for (size_t k = 0; k < Ldb; k += 8) { + Transpose8x8(src, n, k, dst); + } + } + + for (; n < N; ++n) { + for (size_t k = 0; k < Ldb; k += 8) { + PrepackSlice(src, n * Ldb + k, dst); + } + } + } + + template + MLAS_FORCEINLINE void Check(const uint8_t* packed, const uint8_t* ref) { + size_t n = 0; + for (; n + 8 <= N; n += 8) { + for (size_t i = 0; i < K; i += 2) { + for (size_t j = 0; j < 8; ++j) { + ASSERT_EQ(packed[n * Ldb + (i >> 1) * 8 + j], ref[n * Ldb + (i >> 1) * 8 + j]) + << " seed " << seed_ + << " n " << n << " i " << i << " j " << j; + } + } + } + + for (; n < N; ++n) { + for (size_t i = 0; i < K; i += 2) { + ASSERT_EQ(packed[n * Ldb + (i >> 1)], ref[n * Ldb + (i >> 1)]) + << " seed " << seed_ + << " n " << n << " i " << i; + } + } + } + + template + void TestPrepack() { + constexpr size_t Bits = 4; + constexpr size_t Ldb = (((K + BlkLen - 1) & (~(BlkLen - 1))) * Bits + 7) / 8; + constexpr size_t BufferSize = N * Ldb; + auto InitializeBuffer = [this](uint8_t* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = static_cast(distrib_(gen_)); + } + }; + + const auto* input = input_.GetFilledBuffer(BufferSize, InitializeBuffer); + auto* packed = 
packed_.GetBuffer(BufferSize, true); + auto* ref = ref_.GetBuffer(BufferSize, true); + MlasQNBitGemmPackQuantBData( + N, K, Bits, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::HQNBIT_CompFp16, input, packed, + nullptr, false, nullptr, nullptr); + Prepack(input, ref); + Check(packed, ref); + } + + public: + MlasNeonFp16PrepackTest() + : seed_(19287), gen_(seed_), distrib_(0, 255) { + } + + static const char* GetTestSuiteName() { + return "NeonFp16Prepack"; + } + + void ExecuteShort(void) override { + TestPrepack<1, 1, 16>(); + TestPrepack<1, 15, 16>(); + TestPrepack<1, 31, 16>(); + TestPrepack<8, 1, 16>(); + TestPrepack<8, 16, 16>(); + TestPrepack<9, 31, 16>(); + TestPrepack<9, 33, 32>(); + TestPrepack<15, 33, 16>(); + TestPrepack<17, 67, 16>(); + TestPrepack<17, 96, 128>(); + TestPrepack<263, 263, 16>(); + } +}; + +class MlasNeonFp16DequantBTest : public MlasTestBase { + private: + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_int_distribution<> distrib_; + std::uniform_real_distribution _distribFp; + MatrixGuardBuffer input_, zero_points_; + MatrixGuardBuffer dequant_, ref_, scales_; + + MLAS_FORCEINLINE + uint8_t GetInt4(uint8_t v, size_t i) { + return (i & 1) ? 
(v >> 4) : (v & 0x0f); + } + + template + void DequantB(const uint8_t* src, MLAS_FP16* dst, const MLAS_FP16* scales, const uint8_t* zero_points) { + constexpr size_t blkNum = (K + BlkLen - 1) / BlkLen; + constexpr size_t ld_src = (blkNum * BlkLen + 1) / 2; + constexpr size_t ld_dst = blkNum * BlkLen; + constexpr size_t ld_zp = (blkNum + 1) / 2; + size_t n = 0; + for (; n + 8 <= N; n += 8) { + size_t i_src = n * ld_src, i_dst = n * ld_dst, i_scale = n * blkNum, i_zp = n * ld_zp; + for (size_t blk = 0; blk < blkNum; i_zp += (blk & 1), ++blk, ++i_scale) { + for (size_t i = 0; i < BlkLen; i += 2, i_dst += 8) { + for (size_t j = 0; j < 8; ++j, ++i_src, ++i_dst) { + uint8_t v = src[i_src]; + float v0 = static_cast(GetInt4(v, 0)); + float v1 = static_cast(GetInt4(v, 1)); + float zp = static_cast(UseZeroPoints ? GetInt4(zero_points[i_zp + ld_zp * j], blk) : 8); + float scale = scales[i_scale + blkNum * j]; + dst[i_dst] = MLAS_FP16(v0 * scale - zp * scale); + dst[i_dst + 8] = MLAS_FP16(v1 * scale - zp * scale); + } + } + } + } + + for (; n < N; ++n) { + size_t i_src = n * ld_src, i_dst = n * ld_dst, i_scale = n * blkNum, i_zp = n * ld_zp; + for (size_t blk = 0; blk < blkNum; i_zp += (blk & 1), ++blk, ++i_scale) { + float zp = static_cast(UseZeroPoints ? 
GetInt4(zero_points[i_zp], blk) : 8); + float scale = scales[i_scale]; + for (size_t i = 0; i < BlkLen; i += 16, i_dst += 8) { + for (size_t j = 0; j < 16; j += 2, ++i_src, ++i_dst) { + uint8_t v = src[i_src]; + float v0 = static_cast(GetInt4(v, 0)); + float v1 = static_cast(GetInt4(v, 1)); + dst[i_dst] = MLAS_FP16(v0 * scale - zp * scale); + dst[i_dst + 8] = MLAS_FP16(v1 * scale - zp * scale); + } + } + } + } + } + + MLAS_FORCEINLINE + bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = std::abs(v0.ToFloat()), f1 = std::abs(v1.ToFloat()); + return std::abs(f0 - f1) <= f1 * rtol + atol; + } + + template + MLAS_FORCEINLINE void Check(const MLAS_FP16* target, const MLAS_FP16* ref) { + size_t n = 0; + for (; n + 8 <= N; n += 8) { + for (size_t i = 0; i < K; ++i) { + for (size_t j = 0; j < 8; ++j) { + size_t idx = n * Ldb + i * 8 + j; + ASSERT_TRUE(FloatEqual(target[idx], ref[idx], 0.01f, 0.01f)) + << " seed " << seed_ + << " v0 " << target[idx] << " v1 " << ref[idx] + << " n " << n << " i " << i << " j " << j; + } + } + } + + for (; n < N; ++n) { + for (size_t i = 0; i < K; ++i) { + size_t idx = n * Ldb + i; + ASSERT_TRUE(FloatEqual(target[idx], ref[idx], 0.01f, 0.01f)) + << " seed " << seed_ + << " v0 " << target[idx] << " v1 " << ref[idx] + << " n " << n << " i " << i; + } + } + } + + template + void TestDequant() { + constexpr size_t BlkNum = (K + BlkLen - 1) / BlkLen; + constexpr size_t BCount = BlkNum * BlkLen * N; + constexpr size_t ScaleCount = N * BlkNum; + constexpr size_t ZpSize = N * ((BlkNum + 1) / 2); + + auto InitializeBuffer_i8 = [this](uint8_t* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = static_cast(distrib_(gen_)); + } + }; + + auto InitializeBuffer_fp16 = [this](MLAS_FP16* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(_distribFp(gen_)); + } + }; + + const auto* input = input_.GetFilledBuffer(BCount / 2, InitializeBuffer_i8); + const auto* 
zero_points = zero_points_.GetFilledBuffer(ZpSize, InitializeBuffer_i8); + auto* dequant = dequant_.GetBuffer(BCount); + auto* ref = ref_.GetBuffer(BCount); + const auto* scales = scales_.GetFilledBuffer(ScaleCount, InitializeBuffer_fp16); + GetMlasPlatform().QNBitGemmDispatch->HQ4BitBlkDequantBForHgemm_CompFp16( + BlkLen, dequant, reinterpret_cast(input), scales, + UseZeroPoints ? reinterpret_cast(zero_points) : nullptr, + N, K, BlkNum); + DequantB(input, ref, scales, zero_points); + Check(dequant, ref); + } + + public: + MlasNeonFp16DequantBTest() + : seed_(19287), gen_(seed_), distrib_(0, 255), _distribFp(0.5f, 2.0f) { + } + + static const char* GetTestSuiteName() { + return "NeonFp16DequantB"; + } + + void ExecuteShort(void) override { + TestDequant<1, 1, 16, false>(); + TestDequant<1, 1, 16, true>(); + TestDequant<1, 15, 16, false>(); + TestDequant<1, 15, 16, true>(); + TestDequant<1, 31, 16, false>(); + TestDequant<1, 31, 16, true>(); + TestDequant<8, 1, 16, false>(); + TestDequant<8, 1, 16, true>(); + TestDequant<8, 16, 16, false>(); + TestDequant<8, 16, 16, true>(); + TestDequant<9, 31, 16, false>(); + TestDequant<9, 31, 16, true>(); + TestDequant<9, 33, 32, false>(); + TestDequant<9, 33, 32, true>(); + TestDequant<15, 33, 16, false>(); + TestDequant<15, 33, 16, true>(); + TestDequant<17, 67, 16, false>(); + TestDequant<17, 67, 16, true>(); + TestDequant<17, 96, 128, false>(); + TestDequant<17, 96, 128, true>(); + TestDequant<263, 263, 16, false>(); + TestDequant<263, 263, 16, true>(); + } +}; + +class MlasNeonFp16HQ4BitGemmKernelTest : public MlasTestBase { + private: + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + MatrixGuardBuffer A_, B_, C_, ref_, bias_; + + MLAS_FORCEINLINE + void InitializeBuffer(MLAS_FP16* buffer, float min, float max, size_t count) { + std::uniform_real_distribution distrib(min, max); + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib(gen_)); + } + } + + MLAS_FORCEINLINE + 
bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = v0.ToFloat(), f1 = v1.ToFloat(); + return std::abs(f0 - f1) <= std::abs(f1 * rtol) + atol; + } + + template + float GetBVal(const MLAS_FP16* B, size_t n, size_t k) { + size_t i; + if ((N & (~7)) > n) { + size_t full8 = n & (~7); + i = full8 * ldb + 8 * k + (n - full8); + } else { + i = n * ldb + k; + } + return B[i].ToFloat(); + } + + template + void MatMul(const MLAS_FP16* A, const MLAS_FP16* B, const MLAS_FP16* bias, MLAS_FP16* C) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float accu = UseBias ? bias[n] : 0.0f; + for (size_t k = 0; k < K; ++k) { + float a = A[m * K + k].ToFloat(); + float b = GetBVal(B, n, k); + accu = accu + a * b; + } + C[m * N + n] = MLAS_FP16(accu); + } + } + } + + template + MLAS_FORCEINLINE void Check(const MLAS_FP16* target, const MLAS_FP16* ref) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + size_t i = m * Ldc + n; + ASSERT_TRUE(FloatEqual(target[i], ref[i], 0.02f, 0.055f)) + << " seed " << seed_ + << " v0 " << target[i] << " v1 " << ref[i] + << " m " << m << " n " << n; + } + } + } + + template + void TestHQ4BitGemmKernel() { + static_assert(M <= 2); + constexpr size_t BlkNum = (K + BlkLen - 1) / BlkLen; + constexpr size_t ldb = BlkNum * BlkLen; + + const auto* A = A_.GetFilledBuffer(M * K, [this](MLAS_FP16* p, size_t t) { + InitializeBuffer(p, -0.25f, 0.25f, t); + }); + const auto* B = B_.GetFilledBuffer(ldb * N, [this](MLAS_FP16* p, size_t t) { + InitializeBuffer(p, -0.25f, 0.25f, t); + }); + auto* C = C_.GetBuffer(M * N, true); + auto* ref = ref_.GetBuffer(M * N, true); + auto* bias = bias_.GetFilledBuffer(N, [this](MLAS_FP16* p, size_t t) { + InitializeBuffer(p, -5.0f, 5.0f, t); + }); + + GetMlasPlatform().QNBitGemmDispatch->HQ4BitGemmKernel_CompFp16( + A, B, UseBias ? 
bias : nullptr, C, M, N, K, K, ldb, N); + + MatMul(A, B, bias, ref); + Check(C, ref); + } + + public: + MlasNeonFp16HQ4BitGemmKernelTest() + : seed_(19287), gen_(seed_) { + } + + static const char* GetTestSuiteName() { + return "NeonFp16HQ4BitGemmKernel"; + } + + template + void ExecuteShort_T(void) { + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + TestHQ4BitGemmKernel(); + } + + void ExecuteShort(void) override { + ExecuteShort_T<1>(); + ExecuteShort_T<2>(); + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + if (GetMlasPlatform().QNBitGemmDispatch) { + if (GetMlasPlatform().QNBitGemmDispatch->HQ4BitGemmPackQuantBData) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + if (GetMlasPlatform().QNBitGemmDispatch->HQ4BitBlkDequantBForHgemm_CompFp16) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + if (GetMlasPlatform().QNBitGemmDispatch->HQ4BitGemmKernel_CompFp16) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + } + } + return count; +}); + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) diff --git a/tests/unittest/test_sqnbitgemm.cpp b/tests/unittest/test_sqnbitgemm.cpp index 0710981..e22018a 100644 --- 
a/tests/unittest/test_sqnbitgemm.cpp +++ b/tests/unittest/test_sqnbitgemm.cpp @@ -18,11 +18,11 @@ Module Name: #include "mlas_q4.h" #include "mlas_qnbit.h" -static constexpr const char* ComputeTypeName(MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType) { +static constexpr const char* ComputeTypeName(MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType) { switch (ComputeType) { - case CompFp32: + case SQNBIT_CompFp32: return "Fp32"; - case CompInt8: + case SQNBIT_CompInt8: return "Int8"; default: return "unknown"; @@ -63,16 +63,16 @@ class MlasSQNBitGemmTest : public MlasTestBase { float* C, size_t ldc, void* Workspace, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, MLAS_THREADPOOL* Threadpool) { - MLAS_SQNBIT_GEMM_DATA_PARAMS params; + MLAS_QNBIT_GEMM_DATA_PARAMS params; params.A = A; params.lda = lda; params.Bias = Bias; params.C = C; params.ldc = ldc; #ifdef MLAS_TARGET_AMD64_IX86 - if (ComputeType == CompInt8) { + if (ComputeType == SQNBIT_CompInt8) { params.QuantBDataWorkspace = PackedQuantBDataWorkspace; } #endif @@ -81,7 +81,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { params.QuantBZeroPoint = QuantBZeroPoint; params.PostProcessor = nullptr; - MlasSQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace, Threadpool); + MlasQNBitGemmBatch(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType, ¶ms, Workspace, Threadpool); } void QuantizeA(size_t M, size_t K, const float* A, int8_t* QuantAData, float* QuantAScale) { @@ -201,7 +201,7 @@ class MlasSQNBitGemmTest : public MlasTestBase { public: void Test(size_t M, size_t N, size_t K, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, bool WithThreadpool, bool Symmetric, bool WithBias) { MLAS_THREADPOOL* Threadpool = WithThreadpool ? 
GetMlasThreadPool() : nullptr; @@ -265,19 +265,19 @@ class MlasSQNBitGemmTest : public MlasTestBase { } void* Workspace = nullptr; - if (const auto WorkspaceSize = MlasSQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); + if (const auto WorkspaceSize = MlasQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); WorkspaceSize > 0) { Workspace = BufferWorkspace.GetBuffer(WorkspaceSize); } void* PackedQuantBDataWorkspace = nullptr; - if (const auto PackedQuantBDataSize = MlasSQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); + if (const auto PackedQuantBDataSize = MlasQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); PackedQuantBDataSize > 0) { PackedQuantBDataWorkspace = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize); bool has_zp_input = QuantBZeroPoint != nullptr; - MlasSQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBDataWorkspace, - QuantBScale, has_zp_input, QuantBZeroPoint, - GetMlasThreadPool()); + MlasQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData, PackedQuantBDataWorkspace, + QuantBScale, has_zp_input, QuantBZeroPoint, + GetMlasThreadPool()); } CallGemm(M, N, K, @@ -289,9 +289,9 @@ class MlasSQNBitGemmTest : public MlasTestBase { ComputeType, Threadpool); - if (ComputeType == CompFp32) { + if (ComputeType == SQNBIT_CompFp32) { CallReferenceGemm_CompFp32(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); - } else if (ComputeType == CompInt8) { + } else if (ComputeType == SQNBIT_CompInt8) { CallReferenceGemm_CompInt8(M, N, K, A, QuantBData, QuantBScale, QuantBZeroPoint, Bias, CReference); } else { FAIL() << "Test is not implemented for compute type " @@ -324,7 +324,7 @@ template class SQNBitGemmShortExecuteTest : public MlasTestFixture> { public: explicit SQNBitGemmShortExecuteTest(size_t M, size_t N, size_t K, - MLAS_SQNBIT_GEMM_COMPUTE_TYPE ComputeType, + MLAS_QNBIT_GEMM_COMPUTE_TYPE 
ComputeType, bool WithThreadpool, bool Symmetric, bool WithBias) : M_(M), N_(N), @@ -341,11 +341,11 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture= range ? FillValue - range : FillValue; } }); } From f75cc7c604f55052837d80dff21e521f1b4a6830 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 6 Dec 2024 20:55:02 +0000 Subject: [PATCH 03/33] update --- tests/unittest/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unittest/CMakeLists.txt b/tests/unittest/CMakeLists.txt index a61df1b..4c2cd0e 100644 --- a/tests/unittest/CMakeLists.txt +++ b/tests/unittest/CMakeLists.txt @@ -26,6 +26,7 @@ test_sbgemm.cpp test_scaleoutput.cpp test_softmax.cpp test_sqnbitgemm.cpp +test_hqnbitgemm_neon.cpp test_symm_qgemm.cpp test_transpose.cpp) if(MSVC) From ac525b3f3ac800dc0897db483ed1dcfdff10f0d2 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 6 Dec 2024 20:57:08 +0000 Subject: [PATCH 04/33] revert --- include/mlas_gemm_postprocessor.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/mlas_gemm_postprocessor.h b/include/mlas_gemm_postprocessor.h index 7ea29eb..8c24705 100644 --- a/include/mlas_gemm_postprocessor.h +++ b/include/mlas_gemm_postprocessor.h @@ -16,6 +16,7 @@ Module Name: --*/ #pragma once +#include template class MLAS_GEMM_POSTPROCESSOR From 14b8f54031b3cc410bedebafc5b000cc98063848 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 6 Dec 2024 20:57:50 +0000 Subject: [PATCH 05/33] revert --- src/common/threadpool.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/threadpool.cc b/src/common/threadpool.cc index 7e8caa7..bb94f62 100644 --- a/src/common/threadpool.cc +++ b/src/common/threadpool.cc @@ -25,7 +25,7 @@ limitations under the License. 
#if !defined(ORT_MINIMAL_BUILD) #ifdef _WIN32 #include -#include "processthreadsapi.h" +#include #include #include #elif defined(__APPLE__) From e5e59f48345efd683bcab9018de29803ba02f31b Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 6 Dec 2024 20:58:13 +0000 Subject: [PATCH 06/33] revert --- src/lib/CMakeLists.txt | 1608 ++++++++++++++++++++-------------------- 1 file changed, 804 insertions(+), 804 deletions(-) diff --git a/src/lib/CMakeLists.txt b/src/lib/CMakeLists.txt index 95829fe..6d27525 100644 --- a/src/lib/CMakeLists.txt +++ b/src/lib/CMakeLists.txt @@ -1,804 +1,804 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -set(MLAS_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..) -set(MLAS_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}) -set(MLAS_INC_DIR ${MLAS_ROOT}/../include) - -include_directories(${ONNXRUNTIME_INCLUDE_DIR}) - -#Set global compile flags for all the source code(including third_party code like protobuf) -#This section must be before any add_subdirectory, otherwise build may fail because /MD,/MT mismatch -if (MSVC) - if (CMAKE_VS_PLATFORM_NAME) - # Multi-platform generator - set(onnxruntime_target_platform ${CMAKE_VS_PLATFORM_NAME}) - else() - set(onnxruntime_target_platform ${CMAKE_SYSTEM_PROCESSOR}) - endif() - if (onnxruntime_target_platform STREQUAL "ARM64") - set(onnxruntime_target_platform "ARM64") - enable_language(ASM_MARMASM) - elseif (onnxruntime_target_platform STREQUAL "ARM64EC") - enable_language(ASM_MARMASM) - elseif (onnxruntime_target_platform STREQUAL "ARM" OR CMAKE_GENERATOR MATCHES "ARM") - set(onnxruntime_target_platform "ARM") - enable_language(ASM_MARMASM) - elseif (onnxruntime_target_platform STREQUAL "x64" OR onnxruntime_target_platform STREQUAL "x86_64" OR onnxruntime_target_platform STREQUAL "AMD64" OR CMAKE_GENERATOR MATCHES "Win64") - set(onnxruntime_target_platform "x64") - enable_language(ASM_MASM) - elseif (onnxruntime_target_platform STREQUAL "Win32" OR 
onnxruntime_target_platform STREQUAL "x86" OR onnxruntime_target_platform STREQUAL "i386" OR onnxruntime_target_platform STREQUAL "i686") - set(onnxruntime_target_platform "x86") - enable_language(ASM_MASM) - message("Enabling SAFESEH for x86 build") - set(CMAKE_ASM_MASM_FLAGS "${CMAKE_ASM_MASM_FLAGS} /safeseh") - else() - message(FATAL_ERROR "Unknown CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") - endif() -endif() - -# -# All hardware agnostic source files here -# hardware specific files would cause trouble in -# multi-target build -# -add_library(onnxruntime_mlas STATIC - ${MLAS_SRC_DIR}/mlasi.h - ${MLAS_SRC_DIR}/platform.cpp - ${MLAS_SRC_DIR}/threading.cpp - ${MLAS_SRC_DIR}/sgemm.cpp - ${MLAS_SRC_DIR}/halfgemm.cpp - ${MLAS_SRC_DIR}/qgemm.cpp - ${MLAS_SRC_DIR}/qdwconv.cpp - ${MLAS_SRC_DIR}/convolve.cpp - ${MLAS_SRC_DIR}/convsym.cpp - ${MLAS_SRC_DIR}/pooling.cpp - ${MLAS_SRC_DIR}/transpose.cpp - ${MLAS_SRC_DIR}/reorder.cpp - ${MLAS_SRC_DIR}/snchwc.cpp - ${MLAS_SRC_DIR}/activate.cpp - ${MLAS_SRC_DIR}/logistic.cpp - ${MLAS_SRC_DIR}/tanh.cpp - ${MLAS_SRC_DIR}/erf.cpp - ${MLAS_SRC_DIR}/compute.cpp - ${MLAS_SRC_DIR}/quantize.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_default.cpp - ${MLAS_SRC_DIR}/qladd.cpp - ${MLAS_SRC_DIR}/qlmul.cpp - ${MLAS_SRC_DIR}/qpostprocessor.cpp - ${MLAS_SRC_DIR}/qlgavgpool.cpp - ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp - ${MLAS_SRC_DIR}/qnbitgemm.h - ${MLAS_SRC_DIR}/qnbitgemm.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h - ${MLAS_SRC_DIR}/flashattn.cpp - ${MLAS_SRC_DIR}/cast.cpp -) - -target_sources(onnxruntime_mlas PRIVATE - ${MLAS_INC_DIR}/mlas_float16.h - ${MLAS_INC_DIR}/mlas_gemm_postprocessor.h - ${MLAS_INC_DIR}/mlas_q4.h - ${MLAS_INC_DIR}/mlas_qnbit.h - ${MLAS_INC_DIR}/mlas.h -) - -if (NOT onnxruntime_ORT_MINIMAL_BUILD) - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/q4_dq.cpp - ${MLAS_SRC_DIR}/q4gemm.cpp - ) -endif() - - -#TODO: set MASM flags properly -function(setup_mlas_source_for_windows) - - # - # Sources common for all 
platforms. - # - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/activate_fp16.cpp - ${MLAS_SRC_DIR}/dwconv.cpp - ${MLAS_SRC_DIR}/pooling_fp16.cpp - ) - - #The onnxruntime_target_platform variable was added by Windows AI team in onnxruntime_common.cmake - #Don't use it for other platforms. - if((onnxruntime_target_platform STREQUAL "ARM64") OR (onnxruntime_target_platform STREQUAL "ARM64EC")) - set(PREPROCESS_ARMASM_FLAGS "") - set(ARMASM_FLAGS "") - - if(onnxruntime_target_platform STREQUAL "ARM64") - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp - ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h - ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp - ${MLAS_SRC_DIR}/fp16_neon_common.cpp - ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp - ) - - set(mlas_platform_preprocess_srcs - ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelDot.asm - ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelDotLd64.asm - ${MLAS_SRC_DIR}/arm64/ConvSymU8KernelDot.asm - ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/ConvSymU8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymS8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymU8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/DepthwiseQConvKernelSize9Neon.asm - ${MLAS_SRC_DIR}/arm64/HalfGemmKernelNeon.asm - ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelUdot.asm - ${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelSdot.asm - ${MLAS_SRC_DIR}/arm64/SgemmKernelNeon.asm - ${MLAS_SRC_DIR}/arm64/SgemvKernelNeon.asm - ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDot.asm - ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDotLd64.asm - ) - else() - target_sources(onnxruntime_mlas PRIVATE - 
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp - ) - - set(mlas_platform_preprocess_srcs - ${MLAS_SRC_DIR}/arm64ec/QgemmU8X8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64ec/SgemmKernelNeon.asm - ) - - string(APPEND PREPROCESS_ARMASM_FLAGS " /arm64EC") - string(APPEND ARMASM_FLAGS " -machine ARM64EC") - endif() - - if(CMAKE_BUILD_TYPE STREQUAL "Debug") - string(APPEND ARMASM_FLAGS " -g") - endif() - - # Remove double quotes from flag strings. - separate_arguments(PREPROCESS_ARMASM_FLAGS NATIVE_COMMAND "${PREPROCESS_ARMASM_FLAGS}") - separate_arguments(ARMASM_FLAGS NATIVE_COMMAND "${ARMASM_FLAGS}") - - # Run the C precompiler on each input before the assembler. - foreach(asm_filename ${mlas_platform_preprocess_srcs}) - get_filename_component(asm_filename_base ${asm_filename} NAME_WLE) - set(preprocess_filename ${CMAKE_CURRENT_BINARY_DIR}/${asm_filename_base}.i) - set(obj_filename ${CMAKE_CURRENT_BINARY_DIR}/${asm_filename_base}.obj) - add_custom_command( - OUTPUT ${obj_filename} - COMMAND - cl.exe ${PREPROCESS_ARMASM_FLAGS} /P ${asm_filename} /Fi${preprocess_filename} - COMMAND - armasm64.exe ${ARMASM_FLAGS} ${preprocess_filename} ${obj_filename} - DEPENDS ${asm_filename} - BYPRODUCTS ${preprocess_filename} - ) - target_sources(onnxruntime_mlas PRIVATE ${obj_filename}) - endforeach() - elseif(onnxruntime_target_platform STREQUAL "ARM") - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/arm/sgemmc.cpp - ) - elseif(onnxruntime_target_platform STREQUAL "x64") - - file(GLOB_RECURSE mlas_platform_srcs_avx CONFIGURE_DEPENDS - "${MLAS_SRC_DIR}/intrinsics/avx/*.cpp" - ) - set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "/arch:AVX") - - file(GLOB_RECURSE mlas_platform_srcs_avx2 CONFIGURE_DEPENDS - "${MLAS_SRC_DIR}/intrinsics/avx2/*.cpp" - ) - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "/arch:AVX2") - - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/dgemm.cpp - ${mlas_platform_srcs_avx} - 
${mlas_platform_srcs_avx2} - ${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp - ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp - ${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAmx.asm - ${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/QgemmU8U8KernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/QgemmU8X8KernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/QgemmU8X8KernelAvx512Core.asm - ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx512Core.asm - ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx512Vnni.asm - ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvxVnni.asm - ${MLAS_SRC_DIR}/amd64/ConvSymKernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/ConvSymKernelAvx512Core.asm - ${MLAS_SRC_DIR}/amd64/DgemmKernelSse2.asm - ${MLAS_SRC_DIR}/amd64/DgemmKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/DgemmKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/DgemmKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelSse2.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelM1Avx.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/SconvKernelSse2.asm - ${MLAS_SRC_DIR}/amd64/SconvKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/SconvKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/SconvKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/SpoolKernelSse2.asm - ${MLAS_SRC_DIR}/amd64/SpoolKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/SpoolKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/sgemma.asm - ${MLAS_SRC_DIR}/amd64/cvtfp16a.asm - ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/TransKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/TransKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/LogisticKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/TanhKernelFma3.asm - 
${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm - ) - if(MSVC_VERSION GREATER_EQUAL 1933) - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/amd64/cvtfp16Avx.asm - ) - endif() - - if (NOT onnxruntime_ORT_MINIMAL_BUILD) - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/q4gemm_avx512.cpp - ) - endif() - else() - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp - ${MLAS_SRC_DIR}/i386/SgemmKernelSse2.asm - ${MLAS_SRC_DIR}/i386/SgemmKernelAvx.asm - ) - endif() -endfunction() - -if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - if (onnxruntime_ENABLE_WEBASSEMBLY_SIMD) - file(GLOB_RECURSE mlas_platform_srcs - "${MLAS_SRC_DIR}/wasm_simd/*.cpp" - ) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/qgemm_kernel_wasmsimd.cpp - ) - else() - file(GLOB_RECURSE mlas_platform_srcs - "${MLAS_SRC_DIR}/scalar/*.cpp" - ) - endif() - target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) -elseif(MSVC) - setup_mlas_source_for_windows() -else() - - if(APPLE) - get_target_property(ONNXRUNTIME_MLAS_OSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) - - if(NOT ONNXRUNTIME_MLAS_OSX_ARCH) - set(ONNXRUNTIME_MLAS_OSX_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) - endif() - foreach(OSX_ARCH ${ONNXRUNTIME_MLAS_OSX_ARCH}) - if (OSX_ARCH STREQUAL "arm64") - set(ARM64 TRUE) - elseif (OSX_ARCH STREQUAL "arm64e") - set(ARM64 TRUE) - elseif (OSX_ARCH STREQUAL "arm") - set(ARM TRUE) - elseif (OSX_ARCH STREQUAL "x86_64") - set(X86_64 TRUE) - elseif (OSX_ARCH STREQUAL "i386") - set(X86 TRUE) - endif() - endforeach() - elseif(ANDROID) - if (CMAKE_ANDROID_ARCH_ABI STREQUAL "armeabi-v7a") - set(ARM TRUE) - elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "arm64-v8a") - set(ARM64 TRUE) - elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86_64") - set(X86_64 TRUE) - elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86") - set(X86 TRUE) - endif() - else() - #Linux/FreeBSD/PowerPC/... 
- #The value of CMAKE_SYSTEM_PROCESSOR should be from `uname -m` - #Example values: - #arm64v8/ubuntu -> aarch64 - #arm32v6/alpine -> armv7l - #arm32v7/centos -> armv7l - #ppc64le/debian -> ppc64le - #s390x/ubuntu -> s390x - #ppc64le/busybox -> ppc64le - #arm64v8/ubuntu -> aarch64 - #Android: armv7-a aarch64 i686 x86_64 - #chasun: I don't think anyone uses 'arm64' - if(CMAKE_SYSTEM_PROCESSOR MATCHES "^arm64.*") - set(ARM64 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^arm.*") - set(ARM TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*") - set(ARM64 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(powerpc.*|ppc.*)") - set(POWER TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") - set(X86 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") - set(X86_64 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^loongarch64.*") - set(LOONGARCH64 TRUE) - endif() - endif() - - if(APPLE) - get_target_property(ONNXRUNTIME_MLAS_MACOSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) - endif() - list(LENGTH ONNXRUNTIME_MLAS_MACOSX_ARCH ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH) - if(ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH GREATER 1) - set(ONNXRUNTIME_MLAS_MULTI_ARCH TRUE) - endif() - #If ONNXRUNTIME_MLAS_MULTI_ARCH is true, we need to go through every if branch below - #and split MLAS to multiple static libraries. - #Otherwise, it works like if(...) elseif(...) elseif(...) 
endif() - set(MLAS_SOURCE_IS_NOT_SET 1) - if(ARM) - enable_language(ASM) - - set(CMAKE_ASM_FLAGS "${CMAKE_ASM_FLAGS} -mfpu=neon") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon") - - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/aarch32/QgemmU8X8KernelNeon.S - ${MLAS_SRC_DIR}/arm/sgemmc.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp - ) - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(ARM64 AND MLAS_SOURCE_IS_NOT_SET ) - enable_language(ASM) - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelDot.S - ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelDotLd64.S - ${MLAS_SRC_DIR}/aarch64/ConvSymU8KernelDot.S - ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/ConvSymU8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymS8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymU8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvKernelSize9Neon.S - ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUdot.S - ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSdot.S - ${MLAS_SRC_DIR}/aarch64/SgemmKernelNeon.S - ${MLAS_SRC_DIR}/aarch64/SgemvKernelNeon.S - ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdot.S - ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdotLd64.S - ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp - ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h - ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp - PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") - if (NOT APPLE) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S - ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S - 
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S - ${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S - ${MLAS_SRC_DIR}/activate_fp16.cpp - ${MLAS_SRC_DIR}/dwconv.cpp - ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/pooling_fp16.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp - ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/fp16_neon_common.cpp - ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - endif() - - if(ONNXRUNTIME_MLAS_MULTI_ARCH) - add_library(onnxruntime_mlas_arm64 STATIC ${mlas_platform_srcs}) - list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_arm64) - set_target_properties(onnxruntime_mlas_arm64 PROPERTIES OSX_ARCHITECTURES "arm64") - set(mlas_platform_srcs ) - else() - set(MLAS_SOURCE_IS_NOT_SET 0) - 
endif() - endif() - if(POWER AND MLAS_SOURCE_IS_NOT_SET) - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/power/SgemmKernelPower.cpp - ${MLAS_SRC_DIR}/dgemm.cpp - ${MLAS_SRC_DIR}/power/DgemmKernelPower.cpp - ${MLAS_SRC_DIR}/power/QuantizePower.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPower.cpp PROPERTIES COMPILE_FLAGS "-DSINGLE") - - check_cxx_compiler_flag("-mcpu=power9" HAS_POWER9) - if (HAS_POWER9) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/power/QuantizePowerVSX.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/power/QuantizePowerVSX.cpp PROPERTIES COMPILE_FLAGS "-mcpu=power9") - endif() - - check_cxx_compiler_flag("-mcpu=power10" HAS_POWER10) - if(HAS_POWER10) - set(CMAKE_REQUIRED_FLAGS "-mcpu=power10") - check_cxx_source_compiles(" - #include - int main() { - __vector_quad acc0; - __builtin_mma_xxsetaccz (&acc0); - return 0; - }" - COMPILES_P10 - ) - if(COMPILES_P10) - check_cxx_source_compiles(" - #ifdef _AIX - #define POWER_10 0x40000 - #define POWER_10_ANDUP (POWER_10) - #include - #define __power_10_andup() (_system_configuration.implementation & POWER_10_ANDUP) - int main() { - bool HasP10 = (__power_10_andup() && __power_mma_version() == MMA_V31); - return 0; - } - #else - #include - int main() { - unsigned long hwcap2 = getauxval(AT_HWCAP2); - bool HasP10 = ((hwcap2 & PPC_FEATURE2_MMA) && (hwcap2 & PPC_FEATURE2_ARCH_3_1)); - return 0; - } - #endif" - HAS_P10_RUNTIME - ) - if (HAS_P10_RUNTIME) - set_source_files_properties(${MLAS_SRC_DIR}/platform.cpp PROPERTIES COMPILE_FLAGS "-DPOWER10") - set_source_files_properties(${MLAS_SRC_DIR}/qgemm.cpp PROPERTIES COMPILE_FLAGS "-DPOWER10") - endif() - set(mlas_platform_srcs_power10 - ${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp - ${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp - ${MLAS_SRC_DIR}/power/qgemm_kernel_power10.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10 -DSINGLE") - 
set_source_files_properties(${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10") - set_source_files_properties(${MLAS_SRC_DIR}/power/qgemm_kernel_power10.cpp PROPERTIES COMPILE_FLAGS "-O3 -mcpu=power10") - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${mlas_platform_srcs_power10} - ) - endif() - endif() - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(X86 AND MLAS_SOURCE_IS_NOT_SET) - enable_language(ASM) - - set(mlas_platform_srcs_sse2 - ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp - ${MLAS_SRC_DIR}/x86/SgemmKernelSse2.S - ) - set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2") - - set(mlas_platform_srcs_avx - ${MLAS_SRC_DIR}/x86/SgemmKernelAvx.S - ) - set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx") - - set(mlas_platform_srcs - ${mlas_platform_srcs_sse2} - ${mlas_platform_srcs_avx} - ) - - # In r23, NDK remove __x86.get_pc_thunk.* from libatomic. Add our own - # implementation to avoid external dependency. - if(ANDROID) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/x86/x86.get_pc_thunk.S - ) - endif() - - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(X86_64 AND MLAS_SOURCE_IS_NOT_SET) - enable_language(ASM) - - # Forward the flags for the minimum target platform version from the C - # compiler to the assembler. This works around CMakeASMCompiler.cmake.in - # not including the logic to set this flag for the assembler. - set(CMAKE_ASM${ASM_DIALECT}_OSX_DEPLOYMENT_TARGET_FLAG "${CMAKE_C_OSX_DEPLOYMENT_TARGET_FLAG}") - - # The LLVM assembler does not support the .arch directive to enable instruction - # set extensions and also doesn't support AVX-512F instructions without - # turning on support via command-line option. Group the sources by the - # instruction set extension and explicitly set the compiler flag as appropriate. 
- - set(mlas_platform_srcs_sse2 - ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp - ${MLAS_SRC_DIR}/x86_64/DgemmKernelSse2.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelSse2.S - ${MLAS_SRC_DIR}/x86_64/SgemmTransposePackB16x4Sse2.S - ${MLAS_SRC_DIR}/x86_64/SconvKernelSse2.S - ${MLAS_SRC_DIR}/x86_64/SpoolKernelSse2.S - ) - if(NOT APPLE) - set(mlas_platform_srcs_sse2 - ${mlas_platform_srcs_sse2} - ${MLAS_SRC_DIR}/x86_64/cvtfp16a.S - ) - endif() - set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2") - - set(mlas_platform_srcs_avx - ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelM1Avx.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelM1TransposeBAvx.S - ${MLAS_SRC_DIR}/x86_64/SgemmTransposePackB16x4Avx.S - ${MLAS_SRC_DIR}/x86_64/SconvKernelAvx.S - ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx.S - ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx.S - ${MLAS_SRC_DIR}/intrinsics/avx/min_max_elements.cpp - ) - set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx") - - set(mlas_platform_srcs_avx2 - ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/QgemmU8U8KernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvxVnni.S - ${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/DgemmKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/SconvKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/TransKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/LogisticKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/TanhKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/ErfKernelFma3.S - ${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp - ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp - ) - if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE)) - set(mlas_platform_srcs_avx2 - ${mlas_platform_srcs_avx2} - ${MLAS_SRC_DIR}/x86_64/cvtfp16Avx.S - ) - 
endif() -message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") -message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}") - -if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "11") - message(STATUS "Using -mavx2 -mfma -mavxvnni flags") - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c -mavxvnni") -else() - message(STATUS "Using -mavx2 -mfma flags") - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c") -endif() - set(mlas_platform_srcs_avx512f - ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/SconvKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/TransKernelAvx512F.S - ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp - ) - set_source_files_properties(${mlas_platform_srcs_avx512f} PROPERTIES COMPILE_FLAGS "-mavx512f") - - set(mlas_platform_srcs_avx512core - ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx512Core.S - ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx512Vnni.S - ${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx512Core.S - ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx512Core.S - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp - ) - set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl") - - set(mlas_platform_srcs_avx512vnni - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp - ) - set_source_files_properties(${mlas_platform_srcs_avx512vnni} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f") - - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/activate_fp16.cpp - ${MLAS_SRC_DIR}/dwconv.cpp - ${MLAS_SRC_DIR}/dgemm.cpp - ${MLAS_SRC_DIR}/pooling_fp16.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp - ${mlas_platform_srcs_sse2} - 
${mlas_platform_srcs_avx} - ${mlas_platform_srcs_avx2} - ${mlas_platform_srcs_avx512f} - ${mlas_platform_srcs_avx512core} - ${mlas_platform_srcs_avx512vnni} - ) - - if (NOT onnxruntime_ORT_MINIMAL_BUILD) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/q4gemm_avx512.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/q4gemm_avx512.cpp PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f") - endif() - if(NOT APPLE) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmxCommon.S - ${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp - ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S - ) - set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") - set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") - endif() - - if(ONNXRUNTIME_MLAS_MULTI_ARCH) - add_library(onnxruntime_mlas_x86_64 STATIC ${mlas_platform_srcs}) - set_target_properties(onnxruntime_mlas_x86_64 PROPERTIES OSX_ARCHITECTURES "x86_64") - list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_x86_64) - set(mlas_platform_srcs ) - else() - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(LOONGARCH64 AND MLAS_SOURCE_IS_NOT_SET) - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/qgemm_kernel_lsx.cpp - ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLasx.S - ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLsx.S - ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLasx.S - ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLsx.S - ${MLAS_SRC_DIR}/loongarch64/SconvKernelLsx.S - ${MLAS_SRC_DIR}/loongarch64/SconvKernelLasx.S - ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLSX.S - ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLasx.S - ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4LSX.S - ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4Lasx.S - ${MLAS_SRC_DIR}/loongarch64/SoftmaxKernelLasx.S - ) - set(CMAKE_CXX_FLAGS 
"${CMAKE_CXX_FLAGS} -mlsx -mlasx") - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH AND MLAS_SOURCE_IS_NOT_SET) - file(GLOB_RECURSE mlas_platform_srcs - "${MLAS_SRC_DIR}/scalar/*.cpp") - elseif (onnxruntime_FORCE_GENERIC_ALGORITHMS) - file(GLOB_RECURSE mlas_platform_srcs_generic - "${MLAS_SRC_DIR}/scalar/*.cpp") - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${mlas_platform_srcs_generic} - ) - endif() - target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) -endif() - -foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) - target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_INCLUDE_DIR} ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) - target_link_libraries(${mlas_target} Microsoft.GSL::GSL) - - set_target_properties(${mlas_target} PROPERTIES FOLDER "ONNXRuntime") -endforeach() - -if (WIN32) - target_compile_options(onnxruntime_mlas PRIVATE "$<$:/wd6385>" "$<$:/wd4127>") - if (onnxruntime_ENABLE_STATIC_ANALYSIS) - target_compile_options(onnxruntime_mlas PRIVATE "$<$:/analyze:stacksize 131072>") - endif() -endif() - -if (PLATFORM_NAME STREQUAL "macabi") - # Needed for maccatalyst C compilation - # i.e. 
the flags below add "--target=x86_64-apple-ios14.0-macabi -ffunction-sections -fdata-sections" - target_compile_options(onnxruntime_mlas PRIVATE ${CMAKE_C_FLAGS}) -endif() - -if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_mlas - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) -endif() - -# set up source group for MLAS source files -block() - set(source_group_srcs) - foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) - get_target_property(mlas_target_srcs ${mlas_target} SOURCES) - foreach(mlas_target_src ${mlas_target_srcs}) - cmake_path(IS_PREFIX MLAS_ROOT ${mlas_target_src} in_mlas_root) - if(in_mlas_root) - list(APPEND source_group_srcs ${mlas_target_src}) - endif() - endforeach() - endforeach() -endblock() - - - - # - # Command line tool for quantization and de-quantization of 2-D fp32 tensors - # based on block-wise quantization of int4 - # - - add_executable(onnxruntime_mlas_q4dq - ${MLAS_SRC_DIR}/q4_dq_cli.cpp - ) - target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) - set_target_properties(onnxruntime_mlas_q4dq PROPERTIES FOLDER "ONNXRuntimeTest") - - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_MLAS_LIBS}) - if(NOT MLAS_NO_ONNXRUNTIME) - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE onnxruntime_common) - endif() - if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE cpuinfo) - endif() - if (CMAKE_SYSTEM_NAME STREQUAL "Android") - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${android_shared_libs}) - endif() - - if(WIN32) - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE debug Dbghelp Advapi32) - endif() - if (onnxruntime_LINK_LIBATOMIC) - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE atomic) - endif() - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE 
Threads::Threads) - - if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) - set_target_properties(onnxruntime_mlas_q4dq PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1") - else() - set_target_properties(onnxruntime_mlas_q4dq PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1") - endif() - endif() - +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +set(MLAS_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..) +set(MLAS_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}) +set(MLAS_INC_DIR ${MLAS_ROOT}/../include) + +include_directories(${ONNXRUNTIME_INCLUDE_DIR}) + +#Set global compile flags for all the source code(including third_party code like protobuf) +#This section must be before any add_subdirectory, otherwise build may fail because /MD,/MT mismatch +if (MSVC) + if (CMAKE_VS_PLATFORM_NAME) + # Multi-platform generator + set(onnxruntime_target_platform ${CMAKE_VS_PLATFORM_NAME}) + else() + set(onnxruntime_target_platform ${CMAKE_SYSTEM_PROCESSOR}) + endif() + if (onnxruntime_target_platform STREQUAL "ARM64") + set(onnxruntime_target_platform "ARM64") + enable_language(ASM_MARMASM) + elseif (onnxruntime_target_platform STREQUAL "ARM64EC") + enable_language(ASM_MARMASM) + elseif (onnxruntime_target_platform STREQUAL "ARM" OR CMAKE_GENERATOR MATCHES "ARM") + set(onnxruntime_target_platform "ARM") + enable_language(ASM_MARMASM) + elseif (onnxruntime_target_platform STREQUAL "x64" OR onnxruntime_target_platform STREQUAL "x86_64" OR onnxruntime_target_platform STREQUAL "AMD64" OR CMAKE_GENERATOR MATCHES "Win64") + set(onnxruntime_target_platform "x64") + enable_language(ASM_MASM) + elseif (onnxruntime_target_platform STREQUAL "Win32" OR onnxruntime_target_platform STREQUAL "x86" OR onnxruntime_target_platform STREQUAL "i386" OR onnxruntime_target_platform STREQUAL "i686") + set(onnxruntime_target_platform "x86") + enable_language(ASM_MASM) + message("Enabling SAFESEH for x86 
build") + set(CMAKE_ASM_MASM_FLAGS "${CMAKE_ASM_MASM_FLAGS} /safeseh") + else() + message(FATAL_ERROR "Unknown CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") + endif() +endif() + +# +# All hardware agnostic source files here +# hardware specific files would cause trouble in +# multi-target build +# +add_library(onnxruntime_mlas STATIC + ${MLAS_SRC_DIR}/mlasi.h + ${MLAS_SRC_DIR}/platform.cpp + ${MLAS_SRC_DIR}/threading.cpp + ${MLAS_SRC_DIR}/sgemm.cpp + ${MLAS_SRC_DIR}/halfgemm.cpp + ${MLAS_SRC_DIR}/qgemm.cpp + ${MLAS_SRC_DIR}/qdwconv.cpp + ${MLAS_SRC_DIR}/convolve.cpp + ${MLAS_SRC_DIR}/convsym.cpp + ${MLAS_SRC_DIR}/pooling.cpp + ${MLAS_SRC_DIR}/transpose.cpp + ${MLAS_SRC_DIR}/reorder.cpp + ${MLAS_SRC_DIR}/snchwc.cpp + ${MLAS_SRC_DIR}/activate.cpp + ${MLAS_SRC_DIR}/logistic.cpp + ${MLAS_SRC_DIR}/tanh.cpp + ${MLAS_SRC_DIR}/erf.cpp + ${MLAS_SRC_DIR}/compute.cpp + ${MLAS_SRC_DIR}/quantize.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_default.cpp + ${MLAS_SRC_DIR}/qladd.cpp + ${MLAS_SRC_DIR}/qlmul.cpp + ${MLAS_SRC_DIR}/qpostprocessor.cpp + ${MLAS_SRC_DIR}/qlgavgpool.cpp + ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp + ${MLAS_SRC_DIR}/qnbitgemm.h + ${MLAS_SRC_DIR}/qnbitgemm.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h + ${MLAS_SRC_DIR}/flashattn.cpp + ${MLAS_SRC_DIR}/cast.cpp +) + +target_sources(onnxruntime_mlas PRIVATE + ${MLAS_INC_DIR}/mlas_float16.h + ${MLAS_INC_DIR}/mlas_gemm_postprocessor.h + ${MLAS_INC_DIR}/mlas_q4.h + ${MLAS_INC_DIR}/mlas_qnbit.h + ${MLAS_INC_DIR}/mlas.h +) + +if (NOT onnxruntime_ORT_MINIMAL_BUILD) + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/q4_dq.cpp + ${MLAS_SRC_DIR}/q4gemm.cpp + ) +endif() + + +#TODO: set MASM flags properly +function(setup_mlas_source_for_windows) + + # + # Sources common for all platforms. 
+ # + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/activate_fp16.cpp + ${MLAS_SRC_DIR}/dwconv.cpp + ${MLAS_SRC_DIR}/pooling_fp16.cpp + ) + + #The onnxruntime_target_platform variable was added by Windows AI team in onnxruntime_common.cmake + #Don't use it for other platforms. + if((onnxruntime_target_platform STREQUAL "ARM64") OR (onnxruntime_target_platform STREQUAL "ARM64EC")) + set(PREPROCESS_ARMASM_FLAGS "") + set(ARMASM_FLAGS "") + + if(onnxruntime_target_platform STREQUAL "ARM64") + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp + ${MLAS_SRC_DIR}/fp16_neon_common.cpp + ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp + ) + + set(mlas_platform_preprocess_srcs + ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelDot.asm + ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelDotLd64.asm + ${MLAS_SRC_DIR}/arm64/ConvSymU8KernelDot.asm + ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/ConvSymU8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymS8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymU8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/DepthwiseQConvKernelSize9Neon.asm + ${MLAS_SRC_DIR}/arm64/HalfGemmKernelNeon.asm + ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelUdot.asm + ${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelSdot.asm + ${MLAS_SRC_DIR}/arm64/SgemmKernelNeon.asm + ${MLAS_SRC_DIR}/arm64/SgemvKernelNeon.asm + ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDot.asm + ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDotLd64.asm + ) + else() + target_sources(onnxruntime_mlas PRIVATE + 
${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp + ) + + set(mlas_platform_preprocess_srcs + ${MLAS_SRC_DIR}/arm64ec/QgemmU8X8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64ec/SgemmKernelNeon.asm + ) + + string(APPEND PREPROCESS_ARMASM_FLAGS " /arm64EC") + string(APPEND ARMASM_FLAGS " -machine ARM64EC") + endif() + + if(CMAKE_BUILD_TYPE STREQUAL "Debug") + string(APPEND ARMASM_FLAGS " -g") + endif() + + # Remove double quotes from flag strings. + separate_arguments(PREPROCESS_ARMASM_FLAGS NATIVE_COMMAND "${PREPROCESS_ARMASM_FLAGS}") + separate_arguments(ARMASM_FLAGS NATIVE_COMMAND "${ARMASM_FLAGS}") + + # Run the C precompiler on each input before the assembler. + foreach(asm_filename ${mlas_platform_preprocess_srcs}) + get_filename_component(asm_filename_base ${asm_filename} NAME_WLE) + set(preprocess_filename ${CMAKE_CURRENT_BINARY_DIR}/${asm_filename_base}.i) + set(obj_filename ${CMAKE_CURRENT_BINARY_DIR}/${asm_filename_base}.obj) + add_custom_command( + OUTPUT ${obj_filename} + COMMAND + cl.exe ${PREPROCESS_ARMASM_FLAGS} /P ${asm_filename} /Fi${preprocess_filename} + COMMAND + armasm64.exe ${ARMASM_FLAGS} ${preprocess_filename} ${obj_filename} + DEPENDS ${asm_filename} + BYPRODUCTS ${preprocess_filename} + ) + target_sources(onnxruntime_mlas PRIVATE ${obj_filename}) + endforeach() + elseif(onnxruntime_target_platform STREQUAL "ARM") + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/arm/sgemmc.cpp + ) + elseif(onnxruntime_target_platform STREQUAL "x64") + + file(GLOB_RECURSE mlas_platform_srcs_avx CONFIGURE_DEPENDS + "${MLAS_SRC_DIR}/intrinsics/avx/*.cpp" + ) + set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "/arch:AVX") + + file(GLOB_RECURSE mlas_platform_srcs_avx2 CONFIGURE_DEPENDS + "${MLAS_SRC_DIR}/intrinsics/avx2/*.cpp" + ) + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "/arch:AVX2") + + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/dgemm.cpp + ${mlas_platform_srcs_avx} + 
${mlas_platform_srcs_avx2} + ${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp + ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp + ${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAmx.asm + ${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAvx2.asm + ${MLAS_SRC_DIR}/amd64/QgemmU8U8KernelAvx2.asm + ${MLAS_SRC_DIR}/amd64/QgemmU8X8KernelAvx2.asm + ${MLAS_SRC_DIR}/amd64/QgemmU8X8KernelAvx512Core.asm + ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx2.asm + ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx512Core.asm + ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx512Vnni.asm + ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvxVnni.asm + ${MLAS_SRC_DIR}/amd64/ConvSymKernelAvx2.asm + ${MLAS_SRC_DIR}/amd64/ConvSymKernelAvx512Core.asm + ${MLAS_SRC_DIR}/amd64/DgemmKernelSse2.asm + ${MLAS_SRC_DIR}/amd64/DgemmKernelAvx.asm + ${MLAS_SRC_DIR}/amd64/DgemmKernelFma3.asm + ${MLAS_SRC_DIR}/amd64/DgemmKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/SgemmKernelSse2.asm + ${MLAS_SRC_DIR}/amd64/SgemmKernelAvx.asm + ${MLAS_SRC_DIR}/amd64/SgemmKernelM1Avx.asm + ${MLAS_SRC_DIR}/amd64/SgemmKernelFma3.asm + ${MLAS_SRC_DIR}/amd64/SgemmKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/SconvKernelSse2.asm + ${MLAS_SRC_DIR}/amd64/SconvKernelAvx.asm + ${MLAS_SRC_DIR}/amd64/SconvKernelFma3.asm + ${MLAS_SRC_DIR}/amd64/SconvKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/SpoolKernelSse2.asm + ${MLAS_SRC_DIR}/amd64/SpoolKernelAvx.asm + ${MLAS_SRC_DIR}/amd64/SpoolKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/sgemma.asm + ${MLAS_SRC_DIR}/amd64/cvtfp16a.asm + ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx.asm + ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/TransKernelFma3.asm + ${MLAS_SRC_DIR}/amd64/TransKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/LogisticKernelFma3.asm + ${MLAS_SRC_DIR}/amd64/TanhKernelFma3.asm + 
${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm + ) + if(MSVC_VERSION GREATER_EQUAL 1933) + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/amd64/cvtfp16Avx.asm + ) + endif() + + if (NOT onnxruntime_ORT_MINIMAL_BUILD) + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/q4gemm_avx512.cpp + ) + endif() + else() + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp + ${MLAS_SRC_DIR}/i386/SgemmKernelSse2.asm + ${MLAS_SRC_DIR}/i386/SgemmKernelAvx.asm + ) + endif() +endfunction() + +if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + if (onnxruntime_ENABLE_WEBASSEMBLY_SIMD) + file(GLOB_RECURSE mlas_platform_srcs + "${MLAS_SRC_DIR}/wasm_simd/*.cpp" + ) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/qgemm_kernel_wasmsimd.cpp + ) + else() + file(GLOB_RECURSE mlas_platform_srcs + "${MLAS_SRC_DIR}/scalar/*.cpp" + ) + endif() + target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) +elseif(MSVC) + setup_mlas_source_for_windows() +else() + + if(APPLE) + get_target_property(ONNXRUNTIME_MLAS_OSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) + + if(NOT ONNXRUNTIME_MLAS_OSX_ARCH) + set(ONNXRUNTIME_MLAS_OSX_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) + endif() + foreach(OSX_ARCH ${ONNXRUNTIME_MLAS_OSX_ARCH}) + if (OSX_ARCH STREQUAL "arm64") + set(ARM64 TRUE) + elseif (OSX_ARCH STREQUAL "arm64e") + set(ARM64 TRUE) + elseif (OSX_ARCH STREQUAL "arm") + set(ARM TRUE) + elseif (OSX_ARCH STREQUAL "x86_64") + set(X86_64 TRUE) + elseif (OSX_ARCH STREQUAL "i386") + set(X86 TRUE) + endif() + endforeach() + elseif(ANDROID) + if (CMAKE_ANDROID_ARCH_ABI STREQUAL "armeabi-v7a") + set(ARM TRUE) + elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "arm64-v8a") + set(ARM64 TRUE) + elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86_64") + set(X86_64 TRUE) + elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86") + set(X86 TRUE) + endif() + else() + #Linux/FreeBSD/PowerPC/... 
+ #The value of CMAKE_SYSTEM_PROCESSOR should be from `uname -m` + #Example values: + #arm64v8/ubuntu -> aarch64 + #arm32v6/alpine -> armv7l + #arm32v7/centos -> armv7l + #ppc64le/debian -> ppc64le + #s390x/ubuntu -> s390x + #ppc64le/busybox -> ppc64le + #arm64v8/ubuntu -> aarch64 + #Android: armv7-a aarch64 i686 x86_64 + #chasun: I don't think anyone uses 'arm64' + if(CMAKE_SYSTEM_PROCESSOR MATCHES "^arm64.*") + set(ARM64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^arm.*") + set(ARM TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*") + set(ARM64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(powerpc.*|ppc.*)") + set(POWER TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") + set(X86 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") + set(X86_64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^loongarch64.*") + set(LOONGARCH64 TRUE) + endif() + endif() + + if(APPLE) + get_target_property(ONNXRUNTIME_MLAS_MACOSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) + endif() + list(LENGTH ONNXRUNTIME_MLAS_MACOSX_ARCH ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH) + if(ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH GREATER 1) + set(ONNXRUNTIME_MLAS_MULTI_ARCH TRUE) + endif() + #If ONNXRUNTIME_MLAS_MULTI_ARCH is true, we need to go through every if branch below + #and split MLAS to multiple static libraries. + #Otherwise, it works like if(...) elseif(...) elseif(...) 
endif() + set(MLAS_SOURCE_IS_NOT_SET 1) + if(ARM) + enable_language(ASM) + + set(CMAKE_ASM_FLAGS "${CMAKE_ASM_FLAGS} -mfpu=neon") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon") + + set(mlas_platform_srcs + ${MLAS_SRC_DIR}/aarch32/QgemmU8X8KernelNeon.S + ${MLAS_SRC_DIR}/arm/sgemmc.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp + ) + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() + if(ARM64 AND MLAS_SOURCE_IS_NOT_SET ) + enable_language(ASM) + set(mlas_platform_srcs + ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelDot.S + ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelDotLd64.S + ${MLAS_SRC_DIR}/aarch64/ConvSymU8KernelDot.S + ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/ConvSymU8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymS8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymU8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvKernelSize9Neon.S + ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUdot.S + ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSdot.S + ${MLAS_SRC_DIR}/aarch64/SgemmKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/SgemvKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdot.S + ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdotLd64.S + ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp + ) + set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp + PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") + if (NOT APPLE) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S + 
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S + ${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S + ${MLAS_SRC_DIR}/activate_fp16.cpp + ${MLAS_SRC_DIR}/dwconv.cpp + ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/pooling_fp16.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp + ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/fp16_neon_common.cpp + ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp + ) + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + endif() + + if(ONNXRUNTIME_MLAS_MULTI_ARCH) + add_library(onnxruntime_mlas_arm64 STATIC ${mlas_platform_srcs}) + list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_arm64) + set_target_properties(onnxruntime_mlas_arm64 PROPERTIES OSX_ARCHITECTURES "arm64") + set(mlas_platform_srcs ) + else() + set(MLAS_SOURCE_IS_NOT_SET 0) + 
endif() + endif() + if(POWER AND MLAS_SOURCE_IS_NOT_SET) + set(mlas_platform_srcs + ${MLAS_SRC_DIR}/power/SgemmKernelPower.cpp + ${MLAS_SRC_DIR}/dgemm.cpp + ${MLAS_SRC_DIR}/power/DgemmKernelPower.cpp + ${MLAS_SRC_DIR}/power/QuantizePower.cpp + ) + set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPower.cpp PROPERTIES COMPILE_FLAGS "-DSINGLE") + + check_cxx_compiler_flag("-mcpu=power9" HAS_POWER9) + if (HAS_POWER9) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/power/QuantizePowerVSX.cpp + ) + set_source_files_properties(${MLAS_SRC_DIR}/power/QuantizePowerVSX.cpp PROPERTIES COMPILE_FLAGS "-mcpu=power9") + endif() + + check_cxx_compiler_flag("-mcpu=power10" HAS_POWER10) + if(HAS_POWER10) + set(CMAKE_REQUIRED_FLAGS "-mcpu=power10") + check_cxx_source_compiles(" + #include + int main() { + __vector_quad acc0; + __builtin_mma_xxsetaccz (&acc0); + return 0; + }" + COMPILES_P10 + ) + if(COMPILES_P10) + check_cxx_source_compiles(" + #ifdef _AIX + #define POWER_10 0x40000 + #define POWER_10_ANDUP (POWER_10) + #include + #define __power_10_andup() (_system_configuration.implementation & POWER_10_ANDUP) + int main() { + bool HasP10 = (__power_10_andup() && __power_mma_version() == MMA_V31); + return 0; + } + #else + #include + int main() { + unsigned long hwcap2 = getauxval(AT_HWCAP2); + bool HasP10 = ((hwcap2 & PPC_FEATURE2_MMA) && (hwcap2 & PPC_FEATURE2_ARCH_3_1)); + return 0; + } + #endif" + HAS_P10_RUNTIME + ) + if (HAS_P10_RUNTIME) + set_source_files_properties(${MLAS_SRC_DIR}/platform.cpp PROPERTIES COMPILE_FLAGS "-DPOWER10") + set_source_files_properties(${MLAS_SRC_DIR}/qgemm.cpp PROPERTIES COMPILE_FLAGS "-DPOWER10") + endif() + set(mlas_platform_srcs_power10 + ${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp + ${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp + ${MLAS_SRC_DIR}/power/qgemm_kernel_power10.cpp + ) + set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10 -DSINGLE") + 
set_source_files_properties(${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10") + set_source_files_properties(${MLAS_SRC_DIR}/power/qgemm_kernel_power10.cpp PROPERTIES COMPILE_FLAGS "-O3 -mcpu=power10") + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${mlas_platform_srcs_power10} + ) + endif() + endif() + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() + if(X86 AND MLAS_SOURCE_IS_NOT_SET) + enable_language(ASM) + + set(mlas_platform_srcs_sse2 + ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp + ${MLAS_SRC_DIR}/x86/SgemmKernelSse2.S + ) + set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2") + + set(mlas_platform_srcs_avx + ${MLAS_SRC_DIR}/x86/SgemmKernelAvx.S + ) + set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx") + + set(mlas_platform_srcs + ${mlas_platform_srcs_sse2} + ${mlas_platform_srcs_avx} + ) + + # In r23, NDK remove __x86.get_pc_thunk.* from libatomic. Add our own + # implementation to avoid external dependency. + if(ANDROID) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/x86/x86.get_pc_thunk.S + ) + endif() + + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() + if(X86_64 AND MLAS_SOURCE_IS_NOT_SET) + enable_language(ASM) + + # Forward the flags for the minimum target platform version from the C + # compiler to the assembler. This works around CMakeASMCompiler.cmake.in + # not including the logic to set this flag for the assembler. + set(CMAKE_ASM${ASM_DIALECT}_OSX_DEPLOYMENT_TARGET_FLAG "${CMAKE_C_OSX_DEPLOYMENT_TARGET_FLAG}") + + # The LLVM assembler does not support the .arch directive to enable instruction + # set extensions and also doesn't support AVX-512F instructions without + # turning on support via command-line option. Group the sources by the + # instruction set extension and explicitly set the compiler flag as appropriate. 
+ + set(mlas_platform_srcs_sse2 + ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp + ${MLAS_SRC_DIR}/x86_64/DgemmKernelSse2.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelSse2.S + ${MLAS_SRC_DIR}/x86_64/SgemmTransposePackB16x4Sse2.S + ${MLAS_SRC_DIR}/x86_64/SconvKernelSse2.S + ${MLAS_SRC_DIR}/x86_64/SpoolKernelSse2.S + ) + if(NOT APPLE) + set(mlas_platform_srcs_sse2 + ${mlas_platform_srcs_sse2} + ${MLAS_SRC_DIR}/x86_64/cvtfp16a.S + ) + endif() + set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2") + + set(mlas_platform_srcs_avx + ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelM1Avx.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelM1TransposeBAvx.S + ${MLAS_SRC_DIR}/x86_64/SgemmTransposePackB16x4Avx.S + ${MLAS_SRC_DIR}/x86_64/SconvKernelAvx.S + ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx.S + ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx.S + ${MLAS_SRC_DIR}/intrinsics/avx/min_max_elements.cpp + ) + set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx") + + set(mlas_platform_srcs_avx2 + ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAvx2.S + ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx2.S + ${MLAS_SRC_DIR}/x86_64/QgemmU8U8KernelAvx2.S + ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvxVnni.S + ${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx2.S + ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx2.S + ${MLAS_SRC_DIR}/x86_64/DgemmKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/SconvKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/TransKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/LogisticKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/TanhKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/ErfKernelFma3.S + ${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp + ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp + ) + if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE)) + set(mlas_platform_srcs_avx2 + ${mlas_platform_srcs_avx2} + ${MLAS_SRC_DIR}/x86_64/cvtfp16Avx.S + ) + 
endif() +message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") +message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}") + +if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "11") + message(STATUS "Using -mavx2 -mfma -mavxvnni flags") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c -mavxvnni") +else() + message(STATUS "Using -mavx2 -mfma flags") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c") +endif() + set(mlas_platform_srcs_avx512f + ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx512F.S + ${MLAS_SRC_DIR}/x86_64/SconvKernelAvx512F.S + ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx512F.S + ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx512F.S + ${MLAS_SRC_DIR}/x86_64/TransKernelAvx512F.S + ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp + ) + set_source_files_properties(${mlas_platform_srcs_avx512f} PROPERTIES COMPILE_FLAGS "-mavx512f") + + set(mlas_platform_srcs_avx512core + ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx512Core.S + ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx512Vnni.S + ${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx512Core.S + ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx512Core.S + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp + ) + set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl") + + set(mlas_platform_srcs_avx512vnni + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp + ) + set_source_files_properties(${mlas_platform_srcs_avx512vnni} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f") + + set(mlas_platform_srcs + ${MLAS_SRC_DIR}/activate_fp16.cpp + ${MLAS_SRC_DIR}/dwconv.cpp + ${MLAS_SRC_DIR}/dgemm.cpp + ${MLAS_SRC_DIR}/pooling_fp16.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp + ${mlas_platform_srcs_sse2} + 
${mlas_platform_srcs_avx} + ${mlas_platform_srcs_avx2} + ${mlas_platform_srcs_avx512f} + ${mlas_platform_srcs_avx512core} + ${mlas_platform_srcs_avx512vnni} + ) + + if (NOT onnxruntime_ORT_MINIMAL_BUILD) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/q4gemm_avx512.cpp + ) + set_source_files_properties(${MLAS_SRC_DIR}/q4gemm_avx512.cpp PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f") + endif() + if(NOT APPLE) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmxCommon.S + ${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp + ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S + ) + set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") + set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") + endif() + + if(ONNXRUNTIME_MLAS_MULTI_ARCH) + add_library(onnxruntime_mlas_x86_64 STATIC ${mlas_platform_srcs}) + set_target_properties(onnxruntime_mlas_x86_64 PROPERTIES OSX_ARCHITECTURES "x86_64") + list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_x86_64) + set(mlas_platform_srcs ) + else() + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() + if(LOONGARCH64 AND MLAS_SOURCE_IS_NOT_SET) + set(mlas_platform_srcs + ${MLAS_SRC_DIR}/qgemm_kernel_lsx.cpp + ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLsx.S + ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLsx.S + ${MLAS_SRC_DIR}/loongarch64/SconvKernelLsx.S + ${MLAS_SRC_DIR}/loongarch64/SconvKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLSX.S + ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4LSX.S + ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4Lasx.S + ${MLAS_SRC_DIR}/loongarch64/SoftmaxKernelLasx.S + ) + set(CMAKE_CXX_FLAGS 
"${CMAKE_CXX_FLAGS} -mlsx -mlasx") + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH AND MLAS_SOURCE_IS_NOT_SET) + file(GLOB_RECURSE mlas_platform_srcs + "${MLAS_SRC_DIR}/scalar/*.cpp") + elseif (onnxruntime_FORCE_GENERIC_ALGORITHMS) + file(GLOB_RECURSE mlas_platform_srcs_generic + "${MLAS_SRC_DIR}/scalar/*.cpp") + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${mlas_platform_srcs_generic} + ) + endif() + target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) +endif() + +foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) + target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_INCLUDE_DIR} ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) + target_link_libraries(${mlas_target} Microsoft.GSL::GSL) + + set_target_properties(${mlas_target} PROPERTIES FOLDER "ONNXRuntime") +endforeach() + +if (WIN32) + target_compile_options(onnxruntime_mlas PRIVATE "$<$:/wd6385>" "$<$:/wd4127>") + if (onnxruntime_ENABLE_STATIC_ANALYSIS) + target_compile_options(onnxruntime_mlas PRIVATE "$<$:/analyze:stacksize 131072>") + endif() +endif() + +if (PLATFORM_NAME STREQUAL "macabi") + # Needed for maccatalyst C compilation + # i.e. 
the flags below add "--target=x86_64-apple-ios14.0-macabi -ffunction-sections -fdata-sections" + target_compile_options(onnxruntime_mlas PRIVATE ${CMAKE_C_FLAGS}) +endif() + +if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS onnxruntime_mlas + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) +endif() + +# set up source group for MLAS source files +block() + set(source_group_srcs) + foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) + get_target_property(mlas_target_srcs ${mlas_target} SOURCES) + foreach(mlas_target_src ${mlas_target_srcs}) + cmake_path(IS_PREFIX MLAS_ROOT ${mlas_target_src} in_mlas_root) + if(in_mlas_root) + list(APPEND source_group_srcs ${mlas_target_src}) + endif() + endforeach() + endforeach() +endblock() + + + + # + # Command line tool for quantization and de-quantization of 2-D fp32 tensors + # based on block-wise quantization of int4 + # + + add_executable(onnxruntime_mlas_q4dq + ${MLAS_SRC_DIR}/q4_dq_cli.cpp + ) + target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) + set_target_properties(onnxruntime_mlas_q4dq PROPERTIES FOLDER "ONNXRuntimeTest") + + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_MLAS_LIBS}) + if(NOT MLAS_NO_ONNXRUNTIME) + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE onnxruntime_common) + endif() + if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE cpuinfo) + endif() + if (CMAKE_SYSTEM_NAME STREQUAL "Android") + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${android_shared_libs}) + endif() + + if(WIN32) + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE debug Dbghelp Advapi32) + endif() + if (onnxruntime_LINK_LIBATOMIC) + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE atomic) + endif() + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE 
Threads::Threads) + + if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) + set_target_properties(onnxruntime_mlas_q4dq PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1") + else() + set_target_properties(onnxruntime_mlas_q4dq PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1") + endif() + endif() + From 5739d2ebdeb75abf391cf7fc12ab4e5afe517a44 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 6 Dec 2024 21:03:45 +0000 Subject: [PATCH 07/33] update --- {src/lib => include}/qnbitgemm.h | 0 tests/unittest/test_hqnbitgemm_neon.cpp | 4 ++-- 2 files changed, 2 insertions(+), 2 deletions(-) rename {src/lib => include}/qnbitgemm.h (100%) diff --git a/src/lib/qnbitgemm.h b/include/qnbitgemm.h similarity index 100% rename from src/lib/qnbitgemm.h rename to include/qnbitgemm.h diff --git a/tests/unittest/test_hqnbitgemm_neon.cpp b/tests/unittest/test_hqnbitgemm_neon.cpp index b598c20..946eb67 100644 --- a/tests/unittest/test_hqnbitgemm_neon.cpp +++ b/tests/unittest/test_hqnbitgemm_neon.cpp @@ -18,8 +18,8 @@ Module Name: #include #include "test_util.h" -#include "core/mlas/lib/mlasi.h" -#include "core/mlas/lib/qnbitgemm.h" +#include "mlasi.h" +#include "qnbitgemm.h" #include "mlas_qnbit.h" #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) From f4b1197f5fdba005ab6156c0898168a96cfdda45 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 6 Dec 2024 21:08:23 +0000 Subject: [PATCH 08/33] update --- src/lib/CMakeLists.txt | 1 - tests/unittest/CMakeLists.txt | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/lib/CMakeLists.txt b/src/lib/CMakeLists.txt index 6d27525..0942945 100644 --- a/src/lib/CMakeLists.txt +++ b/src/lib/CMakeLists.txt @@ -68,7 +68,6 @@ add_library(onnxruntime_mlas STATIC ${MLAS_SRC_DIR}/qpostprocessor.cpp ${MLAS_SRC_DIR}/qlgavgpool.cpp ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp - ${MLAS_SRC_DIR}/qnbitgemm.h ${MLAS_SRC_DIR}/qnbitgemm.cpp 
${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h ${MLAS_SRC_DIR}/flashattn.cpp diff --git a/tests/unittest/CMakeLists.txt b/tests/unittest/CMakeLists.txt index 4c2cd0e..757082b 100644 --- a/tests/unittest/CMakeLists.txt +++ b/tests/unittest/CMakeLists.txt @@ -47,7 +47,7 @@ if(IOS) XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO" ) endif() -target_include_directories(mlas_unittest PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT} +target_include_directories(mlas_unittest PRIVATE ${ONNXRUNTIME_ROOT}/lib ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR}) target_link_libraries(mlas_unittest PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS}) if(NOT MLAS_NO_ONNXRUNTIME) From ddf8d4fb2033d1ccd7fe37e19788617309faecc3 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 6 Dec 2024 21:11:45 +0000 Subject: [PATCH 09/33] fix --- src/lib/sqnbitgemm_kernel_avx512.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/sqnbitgemm_kernel_avx512.cpp b/src/lib/sqnbitgemm_kernel_avx512.cpp index b4e25d4..b6cd624 100644 --- a/src/lib/sqnbitgemm_kernel_avx512.cpp +++ b/src/lib/sqnbitgemm_kernel_avx512.cpp @@ -18,7 +18,7 @@ Module Name: #include #include #include - +#include #include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_int8.h" From 24e92e048abb64972cb09ab29b7e2902805a2ffe Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 6 Dec 2024 21:29:38 +0000 Subject: [PATCH 10/33] update --- src/lib/sqnbitgemm_kernel_avx512.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lib/sqnbitgemm_kernel_avx512.cpp b/src/lib/sqnbitgemm_kernel_avx512.cpp index b6cd624..592b244 100644 --- a/src/lib/sqnbitgemm_kernel_avx512.cpp +++ b/src/lib/sqnbitgemm_kernel_avx512.cpp @@ -19,6 +19,8 @@ Module Name: #include #include #include +#include + #include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" #include "sqnbitgemm_kernel_avx_common_int8.h" From 74fb793e94215c04ad3283295affeaf51f6ff489 Mon Sep 17 
00:00:00 2001 From: Changming Sun Date: Sat, 31 May 2025 08:58:29 -0700 Subject: [PATCH 11/33] sync to 70de20b94fdc77a0e603ef868a2c53ef38da557c --- cmake/deps.txt | 2 +- include/mlas.h | 218 +- include/mlas_gemm_postprocessor.h | 1 - include/mlas_q4.h | 4 +- include/mlas_qnbit.h | 26 +- include/qnbitgemm.h | 158 +- src/common/cpuid_info.cc | 134 +- src/common/cpuid_uarch.cc | 50 +- src/common/logging/logging.cc | 1 - src/common/string_utils.h | 18 + src/common/threadpool.cc | 4 +- src/core/platform/check_intel.h | 13 + src/core/platform/posix/env.cc | 9 +- src/core/platform/windows/env.cc | 1328 ++-- src/core/platform/windows/env.h | 21 +- src/lib/CMakeLists.txt | 1688 ++--- src/lib/activate_fp16.cpp | 34 +- src/lib/amd64/ConvSymKernelAvx2.asm | 87 + ...6_neon_common.cpp => cast_kernel_neon.cpp} | 2 +- src/lib/compute.cpp | 177 +- src/lib/dwconv.cpp | 6 +- src/lib/eltwise.cpp | 71 + src/lib/eltwise.h | 37 + src/lib/eltwise_kernel_neon.cpp | 32 + src/lib/eltwise_kernel_neon.h | 28 + src/lib/eltwise_kernel_neon_fp16.cpp | 118 + src/lib/fp16_common.h | 290 +- src/lib/halfgemm.cpp | 247 + src/lib/halfgemm.h | 197 + src/lib/halfgemm_kernel_neon_fp16.cpp | 3174 +++++++++ src/lib/hgemm_kernel_neon.cpp | 30 + src/lib/hqnbitgemm_kernel_neon_fp16.cpp | 33 - .../intrinsics/avx2/saturation_check_avx2.cpp | 62 + src/lib/kai_ukernel_interface.cpp | 81 + src/lib/kai_ukernel_interface.h | 12 + src/lib/mlasi.h | 52 +- src/lib/platform.cpp | 26 +- src/lib/pooling_fp16.cpp | 20 +- src/lib/q4_dq.cpp | 519 +- src/lib/q4common.h | 1 - src/lib/q4gemm_avx512.cpp | 1 - src/lib/qgemm.cpp | 13 +- src/lib/qgemm.h | 8 + src/lib/qgemm_kernel_wasmrelaxedsimd.cpp | 563 ++ src/lib/qnbitgemm.cpp | 261 +- src/lib/qnbitgemm_kernel_neon.cpp | 188 +- src/lib/qnbitgemm_kernel_neon.h | 37 + src/lib/quantize.cpp | 4 +- src/lib/rotary_embedding.cpp | 108 + src/lib/rotary_embedding.h | 57 + src/lib/rotary_embedding_kernel_avx2.cpp | 308 + src/lib/rotary_embedding_kernel_avx2.h | 47 + 
src/lib/rotary_embedding_kernel_neon.cpp | 32 + src/lib/rotary_embedding_kernel_neon.h | 37 + src/lib/rotary_embedding_kernel_neon_fp16.cpp | 279 + src/lib/saturation_check.cpp | 42 + src/lib/sgemm.cpp | 9 +- src/lib/softmax.h | 129 + src/lib/softmax_kernel_neon.cpp | 38 + src/lib/softmax_kernel_neon.h | 40 + src/lib/softmax_kernel_neon_fp16.cpp | 917 +++ src/lib/sqnbitgemm_kernel_avx2.cpp | 133 +- .../sqnbitgemm_kernel_avx2_int8_blklen16.h | 699 ++ .../sqnbitgemm_kernel_avx2_int8_blklen32.h | 709 ++ .../sqnbitgemm_kernel_avx2_int8_blklen64.h | 674 ++ src/lib/sqnbitgemm_kernel_avx512.cpp | 135 +- .../sqnbitgemm_kernel_avx512_int8_blklen128.h | 559 ++ .../sqnbitgemm_kernel_avx512_int8_blklen16.h | 703 +- .../sqnbitgemm_kernel_avx512_int8_blklen32.h | 805 ++- .../sqnbitgemm_kernel_avx512_int8_blklen64.h | 707 +- src/lib/sqnbitgemm_kernel_avx512vnni.cpp | 135 +- src/lib/sqnbitgemm_kernel_avx_common.h | 162 +- src/lib/sqnbitgemm_kernel_neon_int8.cpp | 85 +- src/lib/sqnbitgemm_q8_block.h | 2 +- src/lib/tanh.cpp | 57 +- src/lib/transpose.cpp | 338 +- src/lib/x86_64/ConvSymKernelAvx2.S | 91 + src/ort_include/core/common/common.h | 34 +- src/ort_include/core/common/cpuid_info.h | 17 + src/ort_include/core/common/exceptions.h | 47 +- src/ort_include/core/common/logging/logging.h | 9 +- src/ort_include/core/common/parse_string.h | 40 +- src/ort_include/core/common/profiler_common.h | 4 +- src/ort_include/core/common/spin_pause.h | 17 +- src/ort_include/core/common/status.h | 12 +- src/ort_include/core/framework/callback.h | 74 + src/ort_include/core/framework/float16.h | 8 +- .../platform/EigenNonBlockingThreadPool.h | 152 +- src/ort_include/core/platform/env.h | 8 +- .../core/session/onnxruntime_c_api.h | 6293 ++++++++++++++++- src/ort_include/core/util/thread_utils.h | 33 +- tests/bench/bench_cast.cpp | 54 + tests/bench/bench_computesoftmax.cpp | 12 - tests/bench/bench_hgemm.cpp | 89 + tests/bench/bench_qnbitgemm.cpp | 9 +- tests/bench/bench_rope.cpp | 58 + 
tests/unittest/test_blockq4.cpp | 255 +- tests/unittest/test_eltwise.cpp | 106 + tests/unittest/test_halfgemm.h | 2 +- tests/unittest/test_hgemm_neon.cpp | 683 ++ tests/unittest/test_rope.cpp | 141 + tests/unittest/test_softcap.cpp | 112 + tests/unittest/test_softmax.cpp | 202 + tests/unittest/test_sq8bitgemm.cpp | 521 ++ tests/unittest/test_sqnbitgemm.cpp | 10 +- tests/unittest/test_sqnbitgemm_neon_fp16.cpp | 2 +- tests/unittest/test_transpose.cpp | 30 +- tests/unittest/test_util.h | 38 +- 108 files changed, 24749 insertions(+), 2446 deletions(-) create mode 100644 src/core/platform/check_intel.h rename src/lib/{fp16_neon_common.cpp => cast_kernel_neon.cpp} (99%) create mode 100644 src/lib/eltwise.cpp create mode 100644 src/lib/eltwise.h create mode 100644 src/lib/eltwise_kernel_neon.cpp create mode 100644 src/lib/eltwise_kernel_neon.h create mode 100644 src/lib/eltwise_kernel_neon_fp16.cpp create mode 100644 src/lib/halfgemm_kernel_neon_fp16.cpp create mode 100644 src/lib/hgemm_kernel_neon.cpp create mode 100644 src/lib/intrinsics/avx2/saturation_check_avx2.cpp create mode 100644 src/lib/kai_ukernel_interface.cpp create mode 100644 src/lib/kai_ukernel_interface.h create mode 100644 src/lib/qgemm_kernel_wasmrelaxedsimd.cpp create mode 100644 src/lib/rotary_embedding.cpp create mode 100644 src/lib/rotary_embedding.h create mode 100644 src/lib/rotary_embedding_kernel_avx2.cpp create mode 100644 src/lib/rotary_embedding_kernel_avx2.h create mode 100644 src/lib/rotary_embedding_kernel_neon.cpp create mode 100644 src/lib/rotary_embedding_kernel_neon.h create mode 100644 src/lib/rotary_embedding_kernel_neon_fp16.cpp create mode 100644 src/lib/saturation_check.cpp create mode 100644 src/lib/softmax.h create mode 100644 src/lib/softmax_kernel_neon.cpp create mode 100644 src/lib/softmax_kernel_neon.h create mode 100644 src/lib/softmax_kernel_neon_fp16.cpp create mode 100644 tests/bench/bench_cast.cpp create mode 100644 tests/bench/bench_hgemm.cpp create mode 100644 
tests/bench/bench_rope.cpp create mode 100644 tests/unittest/test_eltwise.cpp create mode 100644 tests/unittest/test_hgemm_neon.cpp create mode 100644 tests/unittest/test_rope.cpp create mode 100644 tests/unittest/test_softcap.cpp create mode 100644 tests/unittest/test_sq8bitgemm.cpp diff --git a/cmake/deps.txt b/cmake/deps.txt index 1524485..fa296d6 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -1,4 +1,4 @@ -eigen;https://gitlab.com/libeigen/eigen/-/archive/ff174f79264d3f8dc0115dea7a288f98208b694f/eigen-ff174f79264d3f8dc0115dea7a288f98208b694f.zip;e06074b74725f2677369be2eb2e97e57e2dc4353 +eigen;https://gitlab.com/libeigen/eigen/-/archive/ff174f79264d3f8dc0115dea7a288f98208b694f/eigen-ff174f79264d3f8dc0115dea7a288f98208b694f.zip;666e2f940faeef0196e72617a5d01241a22b67f3 microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14 googletest;https://github.com/google/googletest/archive/refs/tags/v1.15.0.zip;9d2d0af8d77ac726ea55d44a8fa727ec98311349 google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.8.5.zip;cd47d3d272faf353600c8cc2fdec2b52d6f69177 diff --git a/include/mlas.h b/include/mlas.h index 28ae64c..2663709 100644 --- a/include/mlas.h +++ b/include/mlas.h @@ -63,7 +63,10 @@ Module Name: #endif #if defined(__wasm__) #define MLAS_TARGET_WASM -#if defined(__wasm_simd128__) +#if defined(__wasm_relaxed_simd__) +#define MLAS_TARGET_WASM_RELAXED_SIMD +#define MLAS_TARGET_WASM_SIMD +#elif defined(__wasm_simd128__) #define MLAS_TARGET_WASM_SIMD #else #define MLAS_TARGET_WASM_SCALAR @@ -990,11 +993,12 @@ MlasComputeErf( size_t N ); +template void MLASCALL MlasComputeExp( - const float* Input, - float* Output, + const T* Input, + T* Output, size_t N ); @@ -1006,11 +1010,12 @@ MlasComputeLogistic( size_t N ); +template void MLASCALL MlasComputeSoftmax( - const float* Input, - float* Output, + const T* Input, + T* Output, size_t N, size_t D, bool LogSoftmax, @@ -1018,61 +1023,48 @@ 
MlasComputeSoftmax( MLAS_THREADPOOL* ThreadPool ); +template void MLASCALL -MlasComputeTanh( - const float* Input, - float* Output, - size_t N - ); - -// -// Transpose routines. -// - -void -MLASCALL -MlasTranspose( - const uint8_t* Input, - uint8_t* Output, - size_t M, - size_t N +MlasComputeSoftcap( + const T* Input, + T* Output, + size_t N, + T cap ); +template void MLASCALL -MlasTranspose( - const int8_t* Input, - int8_t* Output, - size_t M, +MlasEltwiseAdd( + const T* left, + const T* right, + T* output, size_t N ); +template void MLASCALL -MlasTranspose( - const uint16_t* Input, - uint16_t* Output, - size_t M, +MlasComputeTanh( + const T* Input, + T* Output, size_t N ); -void -MLASCALL -MlasTranspose( - const uint32_t* Input, - uint32_t* Output, - size_t M, - size_t N - ); +// +// Transpose routines. +// +template void MLASCALL MlasTranspose( - const float* Input, - float* Output, + const DataType* Input, + DataType* Output, size_t M, - size_t N + size_t N, + MLAS_THREADPOOL* ThreadPool ); // @@ -1435,7 +1427,141 @@ MLAS_FP16* Destination, size_t Count ); - /** +/** + * @brief rotary embedding for one hidden state vector + * + * @tparam T: data type of input, sin, cos and output. Currently only float32/16 are supported. + * @param input: input tensor, of shape [dim] + * @param sin: sin tensor, of shape [dim/2] + * @param cos: cos tensor, of shape [dim/2] + * @param dim: dimension of rotary embedding + * @param interleaved: whether the real part and imaginary parts are interleaved + * @param output: output tensor, of shape [dim] + */ +template +void +MLASCALL +MlasRotaryEmbedOneRow( + const T* input, + const T* sin_data, + const T* cos_data, + size_t dim, + bool interleaved, + T* output +); + +/** + * @brief Supply matrices data information to half precision gemm functions + */ +struct MLAS_HGEMM_DATA_PARAMS { + const MLAS_FP16* A; /**< Supplies the address of matrix A */ + size_t lda; /**< Supplies the first dimension of matrix A. 
*/ + const MLAS_FP16* B; /**< Supplies the address of matrix B */ + size_t ldb; /**< Supplies the first dimension of matrix B. */ + MLAS_FP16* C; /**< Supplies the address of matrix C */ + size_t ldc; /**< Supplies the first dimension of matrix C. */ + uint16_t alpha; /**< Supplies the scalar alpha multiplier (see GEMM definition). FP16 encoding. */ + uint16_t beta; /**< Supplies the scalar beta multiplier (see GEMM definition). FP16 encoding. */ +}; + +/** + * @brief Check whether current CPU supports half precision gemm. + */ +bool +MLASCALL +MlasHGemmSupported( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB + ); + +/** + * @brief Check whether mlas supports GQA kernels with the type and transpose settings. + */ +template +bool +MLASCALL +MlasGQASupported( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB + ); + +/** + * @brief Batched half precision matrix/matrix multiply operation (HGEMM) + * + * @param TransA Supplies the transpose operation for matrix A. + * @param TransB Supplies the transpose operation for matrix B. + * @param M Supplies the number of rows of matrix A and matrix C. + * @param N Supplies the number of columns of matrix B and matrix C. + * @param K Supplies the number of columns of matrix A and the number of rows of matrix B. + * @param Data A array of matrices data parameters + * @param BatchSize Supplies number of multiplications in this batch + * @param ThreadPool Supplies the thread pool object to use, else nullptr if the + base library threading support should be used. + */ +void +MLASCALL +MlasGemmBatch( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_HGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool + ); + +/** + * @brief half precision matrix/matrix multiply operation (HGEMM) + * C = alpha * op(A) * op(B) + beta * C + * + * @param TransA Supplies the transpose operation for matrix A. Currently only support CblasNoTrans. 
+ * @param TransB Supplies the transpose operation for matrix B. Currently only support CblasTrans. + * @param M Supplies the number of rows of matrix A and matrix C. + * @param N Supplies the number of columns of matrix B and matrix C. + * @param K Supplies the number of columns of matrix A and the number of rows of matrix B. + * @param A Supplies the address of matrix A + * @param lda Supplies the first dimension of matrix A. + * @param B Supplies the address of matrix B + * @param ldb Supplies the first dimension of matrix B. + * @param C Supplies the address of matrix C + * @param ldc Supplies the first dimension of matrix C. + * @param alpha Supplies the scalar alpha multiplier (see GEMM definition) + * @param beta Supplies the scalar beta multiplier (see GEMM definition) + * @param ThreadPool Supplies the thread pool object to use, else nullptr if the base library threading support + * should be used. + */ +inline +void +MlasGemm( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_FP16* A, + size_t lda, + const MLAS_FP16* B, + size_t ldb, + MLAS_FP16* C, + size_t ldc, + uint16_t alpha, + uint16_t beta, + MLAS_THREADPOOL* ThreadPool +) { + MLAS_HGEMM_DATA_PARAMS Data; + Data.A = A; + Data.lda = lda; + Data.B = B; + Data.ldb = ldb; + Data.C = C; + Data.ldc = ldc; + Data.alpha = alpha; + Data.beta = beta; + MlasGemmBatch(TransA, TransB, M, N, K, &Data, 1, ThreadPool); +} + +/** * @brief Whether current CPU supports FP16 acceleration. 
*/ bool MLASCALL @@ -1780,20 +1906,22 @@ MlasConvDepthwise( MLAS_HALF_GEMM_POSTPROCESSOR* PostProc ); - inline void MlasTranspose( const MLAS_FP16* Input, MLAS_FP16* Output, size_t M, - size_t N + size_t N, + MLAS_THREADPOOL* ThreadPool ) { MlasTranspose( reinterpret_cast(Input), reinterpret_cast(Output), - M, N); + M, + N, + ThreadPool); } diff --git a/include/mlas_gemm_postprocessor.h b/include/mlas_gemm_postprocessor.h index 8c24705..7f5ec05 100644 --- a/include/mlas_gemm_postprocessor.h +++ b/include/mlas_gemm_postprocessor.h @@ -17,7 +17,6 @@ Module Name: #pragma once #include - template class MLAS_GEMM_POSTPROCESSOR { diff --git a/include/mlas_q4.h b/include/mlas_q4.h index aec1407..c5f846f 100644 --- a/include/mlas_q4.h +++ b/include/mlas_q4.h @@ -266,7 +266,7 @@ MlasBlockwiseQuantizedShape( /** * @brief Compute the sizes of the quantized data and quantization parameter buffers. * - * @param qbits The bit width of each quantized value. + * @tparam qbits The bit width of each quantized value. * @param block_size The number of quantized values in a block. * @param columnwise Whether a block contains values from a matrix column (true) or row (false). * @param rows Number of matrix rows. @@ -277,9 +277,9 @@ MlasBlockwiseQuantizedShape( * * If the qbits or block_size values are unsupported the output sizes will be zero. 
*/ +template void MLASCALL MlasBlockwiseQuantizedBufferSizes( - int qbits, int block_size, bool columnwise, int rows, diff --git a/include/mlas_qnbit.h b/include/mlas_qnbit.h index 9608644..3627989 100644 --- a/include/mlas_qnbit.h +++ b/include/mlas_qnbit.h @@ -123,6 +123,7 @@ MlasIsQNBitGemmAvailable( * @param[in] BatchN number of batches * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block + * @param[in] HasZeroPoint whether zero points are provided * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ size_t MLASCALL @@ -133,6 +134,7 @@ MlasQNBitGemmBatchWorkspaceSize( size_t BatchN, size_t BlkBitWidth, size_t BlkLen, + bool HasZeroPoint, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); @@ -147,6 +149,7 @@ MlasQNBitGemmBatchWorkspaceSize( * @param[in] K column size of matrix A and row size of matrix B * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) * @param[in] BlkLen number of quantized values per block + * @param[in] HasZeroPoint whether zero points are provided * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ size_t MLASCALL @@ -155,6 +158,7 @@ MlasQNBitGemmPackQuantBDataSize( size_t K, size_t BlkBitWidth, size_t BlkLen, + bool HasZeroPoint, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); @@ -181,7 +185,7 @@ MlasQNBitGemmPackQuantBDataSize( * @param[in] QuantBData quantized B data * @param[in] PackedQuantBDataAndOrBlkSum buffer to store packed quantized B data and/or BlkSum * @param[in] QuantBScale quantized B scale - * @param[in] has_zp_input whether QuantBZeroPoint is provided + * @param[in] HasZeroPoint whether QuantBZeroPoint is provided * @param[in] QuantBZeroPoint quantized B zero point * @param[in] ThreadPool thread pool to use (no parallel if nullptr) */ @@ -195,7 +199,25 @@ MlasQNBitGemmPackQuantBData( const void* QuantBData, void* PackedQuantBDataAndOrBlkSum, const void* 
QuantBScale, - bool has_zp_input, + bool HasZeroPoint, const void* QuantBZeroPoint, MLAS_THREADPOOL* ThreadPool ); + +/** + * @brief Returns true if scales are packed when calling MlasQNBitGemmPackQuantBData the first time. + * + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BlkBitWidth quantized value bit width (e.g., 4 means 4 bit ints) + * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + * @param[in] HasZeroPoint whether QuantBZeroPoint is provided + */ +bool MLASCALL +MlasQNBitGemmScalesPacked( + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + bool HasZeroPoint +); diff --git a/include/qnbitgemm.h b/include/qnbitgemm.h index eb3d0b4..4c13310 100644 --- a/include/qnbitgemm.h +++ b/include/qnbitgemm.h @@ -46,18 +46,19 @@ MlasAlignAddress(void* addr, const size_t alignment) return addr; } -template +template struct PackedQuantBDataStruct { PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen) : QuantBWorkspace_(PackedQuantBWorkspace), N_(N), BlockCountK_(BlockCountK), BlkLen_(BlkLen) { - // TODO: duplicate code from Q4BitGemmPackQuantBDataSize - constexpr size_t BlkBitWidth = 4; const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(T); - - // _mm256_load_si256 requires alignment on a 32-byte boundary - PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32); +#if defined(MLAS_TARGET_AMD64_IX86) + // avx512 requires alignment on a 64-byte boundary + PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 64); +#else + PackedQuantBData = (std::byte*)PackedQuantBWorkspace; +#endif QuantBBlkSum = (T*)(PackedQuantBData + PackedQuantBDataSize); QuantBBlkSum = (T*)MlasAlignAddress(QuantBBlkSum, 
MlasQNBitQuantBBlkSumAlignment()); PackedQuantBScale = (T*)((std::byte*)QuantBBlkSum + BlkSumSize); @@ -95,11 +96,23 @@ struct MLAS_QNBIT_GEMM_DISPATCH { size_t N, size_t K, size_t BlkLen, + bool HasZeroPoint, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); Q4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr; + /** Gets size of packed quantized B data containing 8-bit integers. See MlasQNBitGemmPackQuantBDataSize(). */ + typedef size_t(Q8BitGemmPackQuantBDataSize_Fn)( + size_t N, + size_t K, + size_t BlkLen, + bool HasZeroPoint, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType + ); + + Q8BitGemmPackQuantBDataSize_Fn* Q8BitGemmPackQuantBDataSize = nullptr; + /** Packs quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBData(). */ typedef void(Q4BitGemmPackQuantBData_Fn)( size_t N, @@ -121,14 +134,29 @@ struct MLAS_QNBIT_GEMM_DISPATCH { MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, - bool has_zp_input, + bool HasZeroPoint, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& PackedQuantB, MLAS_THREADPOOL* ThreadPool ); SQ4BitGemmPackQuantBDataAndSumBlk_Fn* SQ4BitGemmPackQuantBDataAndBlkSum = nullptr; + typedef void(SQ8BitGemmPackQuantBDataAndSumBlk_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool + ); + + SQ8BitGemmPackQuantBDataAndSumBlk_Fn* SQ8BitGemmPackQuantBDataAndBlkSum = nullptr; + // // Workspace size calculation function prototypes. 
// @@ -141,17 +169,19 @@ struct MLAS_QNBIT_GEMM_DISPATCH { * @param[in] N column size of matrix B and C * @param[in] K column size of matrix A and row size of matrix B * @param[in] BlkLen number of quantized values per block + * @param[in] HasZeroPoint whether zero points are provided * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ - typedef size_t(Q4BitGemmPerGemmWorkspaceSize_Fn)( + typedef size_t(QNBitGemmPerGemmWorkspaceSize_Fn)( size_t M, size_t N, size_t K, size_t BlkLen, + bool HasZeroPoint, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - Q4BitGemmPerGemmWorkspaceSize_Fn* Q4BitGemmPerGemmWorkspaceSize = nullptr; + QNBitGemmPerGemmWorkspaceSize_Fn* QNBitGemmPerGemmWorkspaceSize = nullptr; /** * @brief Gets the required byte alignment of the per-GEMM intermediate workspace. @@ -159,12 +189,12 @@ struct MLAS_QNBIT_GEMM_DISPATCH { * @param[in] BlkLen number of quantized values per block * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) */ - typedef size_t(Q4BitGemmPerGemmWorkspaceAlignment_Fn)( + typedef size_t(QNBitGemmPerGemmWorkspaceAlignment_Fn)( size_t BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ); - Q4BitGemmPerGemmWorkspaceAlignment_Fn* Q4BitGemmPerGemmWorkspaceAlignment = nullptr; + QNBitGemmPerGemmWorkspaceAlignment_Fn* QNBitGemmPerGemmWorkspaceAlignment = nullptr; // // SQNBIT_CompFp32 kernel function prototypes. @@ -267,6 +297,39 @@ struct MLAS_QNBIT_GEMM_DISPATCH { // SQNBIT_CompInt8 kernel function prototypes. // + /** + * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. + * A and B are block quantized and B is column major. + * A should be packed using QuantizeA_Packed_CompInt8. + * + * @param BlkLen Number of values in a block. + * @param QuantA Supplies the quantized A matrix. + Binary data containing block quantized int8 data and scale values. + * @param PackedQuantBData Supplies the packed quantized B matrix data. 
+ * @param[out] C Supplies the output C matrix. + * @param RangeStartM Start of M range. + * @param RangeCountM Number of rows of A and C. + * @param RangeStartN Start of N range. + * @param RangeCountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param ldc Number of elements between adjacent rows of C. + */ + typedef void(SQ4BitGemmKernel_Packed_CompInt8_Fn)( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* PackedQuantBData, + float* C, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN, + size_t CountK, + size_t ldc, + const float* Bias + ); + + SQ4BitGemmKernel_Packed_CompInt8_Fn* SQ4BitGemmKernel_Packed_CompInt8 = nullptr; + /** * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. * A and B are block quantized and B is column major. @@ -306,6 +369,45 @@ struct MLAS_QNBIT_GEMM_DISPATCH { SQ4BitGemmKernel_BlkSum_CompInt8_Fn* SQ4BitGemmKernel_BlkSum_CompInt8 = nullptr; + /** + * @brief Multiply quantized 8-bit integer matrix A with quantized 8-bit integer matrix B. + * A and B are block quantized and B is column major. + * + * @param BlkLen Number of values in a block. + * @param QuantA Supplies the quantized A matrix. + Binary data containing block quantized int8 data and scale values. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockCountK Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. + * @param ldc Number of elements between adjacent rows of C. + * @param ABlockSum Supplies the blksum of A. 
+ * @param QuantBBlkSum Supplies the blksum of B. + */ + typedef size_t(SQ8BitGemmKernel_BlkSum_CompInt8_Fn)( + size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum + ); + + SQ8BitGemmKernel_BlkSum_CompInt8_Fn* SQ8BitGemmKernel_BlkSum_CompInt8 = nullptr; + /** * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. * A and B are block quantized and B is column major. @@ -343,6 +445,38 @@ struct MLAS_QNBIT_GEMM_DISPATCH { SQ4BitGemmKernel_CompInt8_Fn* SQ4BitGemmKernel_CompInt8 = nullptr; + /** + * @brief Whether to use SQ4BitGemmKernel_Packed_CompInt8 for this problem. + */ + typedef bool(UsePacked_CompInt8_Fn)( + size_t K, + size_t BlkLen, + bool HasZp + ); + + UsePacked_CompInt8_Fn* UsePacked_CompInt8 = nullptr; + + /** + * @brief Block quantize values from matrix A from floats to quantized 8-bit integers. + * Used in conjunction with SQ4BitGemmKernel_Packed_CompInt8. + * + * @param BlkLen Number of values in a block. + * @param A Supplies the A matrix. + * @param CountM Number of rows of A. + * @param CountK Number of columns of A. + * @param[out] QuantA Supplies the output quantized A matrix. + * Binary data containing block quantized int8 data and scale values. + */ + typedef void(QuantizeA_Packed_CompInt8_Fn)( + size_t BlkLen, + const float* A, + size_t CountM, + size_t CountK, + std::byte* QuantA + ); + + QuantizeA_Packed_CompInt8_Fn* QuantizeA_Packed_CompInt8 = nullptr; + /** * @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers. 
* diff --git a/src/common/cpuid_info.cc b/src/common/cpuid_info.cc index 04172e4..aec3ee5 100644 --- a/src/common/cpuid_info.cc +++ b/src/common/cpuid_info.cc @@ -5,7 +5,9 @@ #include "core/common/logging/severity.h" #ifdef __linux__ - +#if (defined(_M_AMD64) || defined(__x86_64__)) && !defined(__ANDROID__) +#include +#endif #include #include #if !defined(__NR_getcpu) @@ -42,8 +44,6 @@ #include -#define HAS_WINDOWS_DESKTOP WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP) - #ifndef PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE #define PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE 43 #endif @@ -107,6 +107,9 @@ void CPUIDInfo::X86Init() { int data[4] = {-1}; GetCPUID(0, data); + vendor_ = GetX86Vendor(data); + vendor_id_ = GetVendorId(vendor_); + int num_IDs = data[0]; if (num_IDs >= 1) { GetCPUID(1, data); @@ -141,8 +144,24 @@ } } +std::string CPUIDInfo::GetX86Vendor(int32_t* data) { + char vendor[sizeof(int32_t) * 3 + 1]{}; + *reinterpret_cast(vendor + 0) = data[1]; + *reinterpret_cast(vendor + 4) = data[3]; + *reinterpret_cast(vendor + 8) = data[2]; + return vendor; +} + #endif // defined(CPUIDINFO_ARCH_X86) +uint32_t CPUIDInfo::GetVendorId(const std::string& vendor) { + if (vendor == "GenuineIntel") return 0x8086; + if (vendor == "AuthenticAMD") return 0x1022; + if (vendor.find("Qualcomm") == 0) return 'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24); + if (vendor.find("NV") == 0) return 0x10DE; + return 0; +} + #if defined(CPUIDINFO_ARCH_ARM) #if defined(__linux__) @@ -194,21 +213,26 @@ void CPUIDInfo::ArmLinuxInit() { #elif defined(_WIN32) // ^ defined(__linux__) void CPUIDInfo::ArmWindowsInit() { -// ARM32 certainly doesn't have fp16, so we will skip the logic to avoid using RegGetValueA Windows API -#if !defined(_M_ARM) -#pragma region Application Family or OneCore Family -#if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP | WINAPI_PARTITION_SYSTEM) - // Read MIDR from windows registry + // Get the ARM vendor string from the registry + vendor_ = 
GetArmWindowsVendor(); + vendor_id_ = GetVendorId(vendor_); + + // Read MIDR and ID_AA64ISAR1_EL1 register values from Windows registry + // There should be one per CPU + std::vector midr_values{}, id_aa64isar1_el1_values{}; + // TODO!! Don't support multiple processor group yet!! constexpr int MAX_CORES = 64; constexpr int MAX_VALUE_NAME = 4096; - CHAR midrKey[MAX_VALUE_NAME] = ""; // buffer for processor registry name - uint32_t lastUarch = cpuinfo_uarch_unknown; - for (int i = 0; i < MAX_CORES - 1; i++) { - snprintf(midrKey, MAX_VALUE_NAME, "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\%d", i); - uint64_t midrVal; - unsigned long midrSize = sizeof(uint64_t); + CHAR processor_subkey[MAX_VALUE_NAME] = ""; // buffer for processor registry name + + for (size_t i = 0; i < MAX_CORES - 1; i++) { + snprintf(processor_subkey, MAX_VALUE_NAME, "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\%d", + static_cast(i)); + + uint64_t midr_value; + unsigned long data_size = sizeof(midr_value); /* * ARM lists for each coprocessor register 5 fields: op0/op1/CRn/CRm/op2. 
@@ -223,48 +247,74 @@ void CPUIDInfo::ArmWindowsInit() { * * For the CP value of MIDR, op0 = 3 and the others are all = 0, so we come up with 0x4000, */ - auto retCode = ::RegGetValueA(HKEY_LOCAL_MACHINE, midrKey, "CP 4000", RRF_RT_REG_QWORD, nullptr, &midrVal, &midrSize); - if (retCode != ERROR_SUCCESS) { + if (::RegGetValueA(HKEY_LOCAL_MACHINE, processor_subkey, "CP 4000", RRF_RT_REG_QWORD, + nullptr, &midr_value, &data_size) != ERROR_SUCCESS) { break; } - uint32_t uarch = cpuinfo_uarch_unknown; - decodeMIDR((uint32_t)midrVal, &uarch); - core_uarchs_.push_back(uarch); - if (uarch == cpuinfo_uarch_cortex_a53 || uarch == cpuinfo_uarch_cortex_a55r0 || - uarch == cpuinfo_uarch_cortex_a55) { - is_armv8_narrow_ld_.push_back(true); - } else { - is_armv8_narrow_ld_.push_back(false); + + uint64_t id_aa64isar1_el1_value; + data_size = sizeof(id_aa64isar1_el1_value); + + // CP 4031 corresponds to ID_AA64ISAR1_EL1 register + if (::RegGetValueA(HKEY_LOCAL_MACHINE, processor_subkey, "CP 4031", RRF_RT_REG_QWORD, + nullptr, &id_aa64isar1_el1_value, &data_size) != ERROR_SUCCESS) { + break; } - if (i == 0) { - lastUarch = uarch; - } else if (lastUarch != uarch) { - is_hybrid_ = true; - lastUarch = uarch; + midr_values.push_back(midr_value); + id_aa64isar1_el1_values.push_back(id_aa64isar1_el1_value); + } + + // process midr_values + { + uint32_t lastUarch = cpuinfo_uarch_unknown; + for (size_t i = 0; i < midr_values.size(); ++i) { + uint32_t uarch = cpuinfo_uarch_unknown; + decodeMIDR(static_cast(midr_values[i]), &uarch); + core_uarchs_.push_back(uarch); + if (uarch == cpuinfo_uarch_cortex_a53 || uarch == cpuinfo_uarch_cortex_a55r0 || + uarch == cpuinfo_uarch_cortex_a55) { + is_armv8_narrow_ld_.push_back(true); + } else { + is_armv8_narrow_ld_.push_back(false); + } + + if (i == 0) { + lastUarch = uarch; + } else if (lastUarch != uarch) { + is_hybrid_ = true; + lastUarch = uarch; + } } } -#endif // WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_APP | WINAPI_PARTITION_SYSTEM) + + 
has_arm_neon_i8mm_ = std::all_of( + id_aa64isar1_el1_values.begin(), id_aa64isar1_el1_values.end(), + [](uint64_t id_aa64isar1_el1_value) { + // I8MM, bits [55:52] + return ((id_aa64isar1_el1_value >> 52) & 0xF) != 0; + }); has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); -#else // ^ !defined(_M_ARM) / v defined(_M_ARM) - has_arm_neon_dot_ = false; -#endif // defined(_M_ARM) #if defined(CPUINFO_SUPPORTED) if (pytorch_cpuinfo_init_) { has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); - has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); - has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); + // cpuinfo_has_arm_i8mm() doesn't work on Windows yet. See https://github.com/pytorch/cpuinfo/issues/279. + // has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); + has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && has_arm_neon_i8mm_; has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16(); - } else -#endif // defined(CPUINFO_SUPPORTED) - { - has_fp16_ = false; - has_arm_neon_i8mm_ = false; - has_arm_sve_i8mm_ = false; - has_arm_neon_bf16_ = false; } +#endif // defined(CPUINFO_SUPPORTED) +} + +std::string CPUIDInfo::GetArmWindowsVendor() { + const int MAX_VALUE_NAME = 256; + const CHAR vendorKey[] = "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0"; + CHAR vendorVal[MAX_VALUE_NAME] = ""; + unsigned long vendorSize = sizeof(char) * MAX_VALUE_NAME; + ::RegGetValueA(HKEY_LOCAL_MACHINE, vendorKey, "Vendor Identifier", RRF_RT_REG_SZ | RRF_ZEROONFAILURE, nullptr, &vendorVal, &vendorSize); + return vendorVal; } #elif defined(__APPLE__) // ^ defined(_WIN32) diff --git a/src/common/cpuid_uarch.cc b/src/common/cpuid_uarch.cc index 16634b2..28a3524 100644 --- a/src/common/cpuid_uarch.cc +++ b/src/common/cpuid_uarch.cc @@ -30,9 +30,11 @@ inline static uint32_t midr_get_part(uint32_t midr) { return (midr & CPUINFO_ARM_MIDR_PART_MASK) >> CPUINFO_ARM_MIDR_PART_OFFSET; } +#if 0 inline static uint32_t midr_get_variant(uint32_t midr) { return (midr & 
CPUINFO_ARM_MIDR_VARIANT_MASK) >> CPUINFO_ARM_MIDR_VARIANT_OFFSET; } +#endif void decodeMIDR( uint32_t midr, @@ -137,8 +139,8 @@ void decodeMIDR( *uarch = cpuinfo_uarch_arm11; break; // #endif /* ARM */ - default: - std::cerr << "unknown ARM CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; + // default: + // std::cerr << "unknown ARM CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } } break; @@ -156,8 +158,8 @@ void decodeMIDR( *uarch = cpuinfo_uarch_thunderx2; break; // #endif - default: - std::cerr << "unknown Broadcom CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; + // default: + // std::cerr << "unknown Broadcom CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #if (defined(_M_ARM64) || defined(__aarch64__)) && !defined(__ANDROID__) @@ -172,8 +174,8 @@ void decodeMIDR( case 0x0AF: /* ThunderX2 99XX */ *uarch = cpuinfo_uarch_thunderx2; break; - default: - std::cerr << "unknown Cavium CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; + // default: + // std::cerr << "unknown Cavium CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #endif @@ -187,8 +189,8 @@ void decodeMIDR( case 0xD40: /* Kirin 980 Big/Medium cores -> Cortex-A76 */ *uarch = cpuinfo_uarch_cortex_a76; break; - default: - std::cerr << "unknown Huawei CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; + // default: + // std::cerr << "unknown Huawei CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #if defined(_M_ARM) || defined(__arm__) @@ -199,8 +201,8 @@ void decodeMIDR( case 6: /* PXA 3XX */ *uarch = cpuinfo_uarch_xscale; break; - default: - std::cerr << "unknown Intel CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; + // default: + // std::cerr << "unknown Intel CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #endif /* ARM */ @@ -215,8 +217,8 @@ void decodeMIDR( case 0x004: *uarch = 
cpuinfo_uarch_carmel; break; - default: - std::cerr << "unknown Nvidia CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; + // default: + // std::cerr << "unknown Nvidia CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; #if !defined(__ANDROID__) @@ -225,8 +227,8 @@ void decodeMIDR( case 0x000: *uarch = cpuinfo_uarch_xgene; break; - default: - std::cerr << "unknown Applied Micro CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; + // default: + // std::cerr << "unknown Applied Micro CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; #endif @@ -297,8 +299,8 @@ void decodeMIDR( *uarch = cpuinfo_uarch_saphira; break; // #endif /* ARM64 && !defined(__ANDROID__) */ - default: - std::cerr << "unknown Qualcomm CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; + // default: + // std::cerr << "unknown Qualcomm CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; case 'S': @@ -343,10 +345,10 @@ void decodeMIDR( */ *uarch = cpuinfo_uarch_exynos_m5; break; - default: - std::cerr << "unknown Samsung CPU variant 0x" - << std::hex << midr_get_variant(midr) << " part 0x" << std::hex << midr_get_part(midr) - << " ignored\n"; + // default: + // std::cerr << "unknown Samsung CPU variant 0x" + //<< std::hex << midr_get_variant(midr) << " part 0x" << std::hex << midr_get_part(midr) + //<< " ignored\n"; } break; // #if defined(_M_ARM) || defined(__arm__) @@ -356,13 +358,13 @@ void decodeMIDR( case 0x584: /* PJ4B-MP / PJ4C */ *uarch = cpuinfo_uarch_pj4; break; - default: - std::cerr << "unknown Marvell CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; + // default: + // std::cerr << "unknown Marvell CPU part 0x" << std::hex << midr_get_part(midr) << " ignored\n"; } break; // #endif /* ARM */ - default: - std::cerr << "unknown CPU uarch from MIDR value: 0x" << std::hex << midr << "\n"; + // default: + // std::cerr << "unknown CPU uarch from MIDR value: 0x" << 
std::hex << midr << "\n"; } } diff --git a/src/common/logging/logging.cc b/src/common/logging/logging.cc index 103a932..4c94a29 100644 --- a/src/common/logging/logging.cc +++ b/src/common/logging/logging.cc @@ -249,7 +249,6 @@ unsigned int GetProcessId() { #endif } - bool LoggingManager::AddSinkOfType(SinkType sink_type, std::function()> sinkFactory, logging::Severity severity) { std::lock_guard guard(sink_mutex_); diff --git a/src/common/string_utils.h b/src/common/string_utils.h index 716eed1..c2e26f6 100644 --- a/src/common/string_utils.h +++ b/src/common/string_utils.h @@ -3,6 +3,8 @@ #pragma once +#include +#include #include #include #include @@ -84,5 +86,21 @@ inline uint32_t GetHashFromString(const std::string& str_value) { return hash; } +/** + * Returns a lowercase version of the input string. + * @param str The string to lowercase. + * @return The lowercased string. + */ +inline std::string GetLowercaseString(std::string str) { + // https://en.cppreference.com/w/cpp/string/byte/tolower + // The behavior of tolower from is undefined if the argument is neither representable as unsigned char + // nor equal to EOF. To use tolower safely with a plain char (or signed char), the argument must be converted to + // unsigned char. + std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) { + return static_cast(std::tolower(c)); + }); + return str; +} + } // namespace utils } // namespace onnxruntime diff --git a/src/common/threadpool.cc b/src/common/threadpool.cc index bb94f62..0fd2336 100644 --- a/src/common/threadpool.cc +++ b/src/common/threadpool.cc @@ -25,7 +25,7 @@ limitations under the License. #if !defined(ORT_MINIMAL_BUILD) #ifdef _WIN32 #include -#include +#include "processthreadsapi.h" #include #include #elif defined(__APPLE__) @@ -439,7 +439,7 @@ void ThreadPool::ParallelForFixedBlockSizeScheduling(const std::ptrdiff_t total, // threads is handled within RunInParallel, hence we can deallocate lc and other state captured by // run_work. 
RunInParallel(run_work, num_work_items, block_size); - } + } } void ThreadPool::SimpleParallelFor(std::ptrdiff_t total, const std::function& fn) { diff --git a/src/core/platform/check_intel.h b/src/core/platform/check_intel.h new file mode 100644 index 0000000..1b82940 --- /dev/null +++ b/src/core/platform/check_intel.h @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +typedef struct { + bool is_intel; + bool is_intel_specified_platform; +} CheckIntelResult; + +CheckIntelResult CheckIntel(); +} // namespace onnxruntime diff --git a/src/core/platform/posix/env.cc b/src/core/platform/posix/env.cc index b956268..2bb17f3 100644 --- a/src/core/platform/posix/env.cc +++ b/src/core/platform/posix/env.cc @@ -148,14 +148,13 @@ class PosixThread : public EnvThread { unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param), Eigen::ThreadPoolInterface* param, const ThreadOptions& thread_options) { ORT_ENFORCE(index >= 0, "Negative thread index is not allowed"); - - + auto param_ptr = std::make_unique(name_prefix, index, start_address, param); if (narrow(index) < thread_options.affinities.size()) { param_ptr->affinity = thread_options.affinities[index]; } - { + { pthread_attr_t attr; int s = pthread_attr_init(&attr); if (s != 0) { @@ -183,7 +182,7 @@ class PosixThread : public EnvThread { } ~PosixThread() override { - { + { void* res; #ifdef NDEBUG pthread_join(hThread, &res); @@ -412,7 +411,6 @@ class PosixEnv : public Env { return Status::OK(); } - static common::Status ReportSystemError(const char* operation_name, const std::string& path) { auto [err_no, err_msg] = GetErrnoInfo(); std::ostringstream oss; @@ -483,7 +481,6 @@ class PosixEnv : public Env { return filename; } - // \brief returns a value for the queried variable name (var_name) std::string GetEnvironmentVar(const std::string& var_name) const override { char* val = getenv(var_name.c_str()); diff 
--git a/src/core/platform/windows/env.cc b/src/core/platform/windows/env.cc index f7b063f..300acc9 100644 --- a/src/core/platform/windows/env.cc +++ b/src/core/platform/windows/env.cc @@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, @@ -27,630 +27,762 @@ limitations under the License. #include #include -#include -#include "core/session/onnxruntime_c_api.h" +#include "core/common/logging/logging.h" #include "core/common/narrow.h" #include "core/common/span_utils.h" #include "core/platform/env.h" #include "core/platform/scoped_resource.h" - +#if defined(_M_X64) && !defined(_M_ARM64EC) +#include "core/platform/windows/hardware_core_enumerator.h" +#endif #include +#include #include "core/platform/path_lib.h" // for LoopDir() +#include "core/platform/windows/dll_load_error.h" EXTERN_C IMAGE_DOS_HEADER __ImageBase; namespace onnxruntime { - class UnmapFileParam { - public: - void* addr; - size_t len; - }; - - - - std::wstring Basename(const std::wstring& path) { - auto basename_index = path.find_last_of(L"/\\") + 1; // results in 0 if no separator is found - return path.substr(basename_index); - } - - class WindowsThread : public EnvThread { - private: - struct Param { - const ORTCHAR_T* name_prefix; - int index; - unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param); - Eigen::ThreadPoolInterface* param; - std::optional affinity; - Param(const ORTCHAR_T* name_prefix1, - int index1, - unsigned (*start_address1)(int id, Eigen::ThreadPoolInterface* param), - Eigen::ThreadPoolInterface* param1) - : name_prefix(name_prefix1), - index(index1), - start_address(start_address1), - param(param1) {} - }; - - 
public: - WindowsThread(const ORTCHAR_T* name_prefix, int index, - unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param), Eigen::ThreadPoolInterface* param, - const ThreadOptions& thread_options) { - ORT_ENFORCE(index >= 0, "Negative thread index is not allowed"); - - - std::unique_ptr local_param = std::make_unique(name_prefix, index, start_address, param); - if (narrow(index) < thread_options.affinities.size()) { - local_param->affinity = thread_options.affinities[index]; - } - - { - _set_errno(0); - _set_doserrno(0); - auto th_handle = _beginthreadex(nullptr, thread_options.stack_size, ThreadMain, - local_param.get(), 0, - &threadID); - if (th_handle == 0) { - auto dos_error = _doserrno; - auto [err, msg] = GetErrnoInfo(); - ORT_THROW("WindowThread:_beginthreadex failed with errno:", err, " message:", msg, - " doserrno:", dos_error); - } - local_param.release(); - hThread.reset(reinterpret_cast(th_handle)); - // Do not throw beyond this point so we do not lose thread handle and then not being able to join it. - } - } - - ~WindowsThread() { - { - DWORD waitStatus = WaitForSingleObject(hThread.get(), INFINITE); - assert(waitStatus != WAIT_FAILED); - } - } - - private: - typedef HRESULT(WINAPI* SetThreadDescriptionFunc)(HANDLE hThread, PCWSTR lpThreadDescription); +class UnmapFileParam { + public: + void* addr; + size_t len; +}; + +static void UnmapFile(void* param) noexcept { + std::unique_ptr p(reinterpret_cast(param)); + bool ret = UnmapViewOfFile(p->addr); + if (!ret) { + const auto error_code = GetLastError(); + LOGS_DEFAULT(ERROR) << "unmap view of file failed. 
error code: " << error_code + << " error msg: " << std::system_category().message(error_code); + } +} + +std::wstring Basename(const std::wstring& path) { + auto basename_index = path.find_last_of(L"/\\") + 1; // results in 0 if no separator is found + return path.substr(basename_index); +} + +class WindowsThread : public EnvThread { + private: + struct Param { + const ORTCHAR_T* name_prefix; + int index; + unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param); + Eigen::ThreadPoolInterface* param; + std::optional affinity; + Param(const ORTCHAR_T* name_prefix1, + int index1, + unsigned (*start_address1)(int id, Eigen::ThreadPoolInterface* param), + Eigen::ThreadPoolInterface* param1) + : name_prefix(name_prefix1), + index(index1), + start_address(start_address1), + param(param1) {} + }; + + public: + WindowsThread(const ORTCHAR_T* name_prefix, int index, + unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param), Eigen::ThreadPoolInterface* param, + const ThreadOptions& thread_options) { + ORT_ENFORCE(index >= 0, "Negative thread index is not allowed"); + + std::unique_ptr local_param = std::make_unique(name_prefix, index, start_address, param); + if (narrow(index) < thread_options.affinities.size()) { + local_param->affinity = thread_options.affinities[index]; + } + + { + _set_errno(0); + _set_doserrno(0); + auto th_handle = _beginthreadex(nullptr, thread_options.stack_size, ThreadMain, + local_param.get(), 0, + &threadID); + if (th_handle == 0) { + auto dos_error = _doserrno; + auto [err, msg] = GetErrnoInfo(); + ORT_THROW("WindowThread:_beginthreadex failed with errno:", err, " message:", msg, + " doserrno:", dos_error); + } + local_param.release(); + hThread.reset(reinterpret_cast(th_handle)); + // Do not throw beyond this point so we do not lose thread handle and then not being able to join it. 
+ } + } + + ~WindowsThread() { + { + DWORD waitStatus = WaitForSingleObject(hThread.get(), INFINITE); + FAIL_FAST_LAST_ERROR_IF(waitStatus == WAIT_FAILED); + } + } + + private: + typedef HRESULT(WINAPI* SetThreadDescriptionFunc)(HANDLE hThread, PCWSTR lpThreadDescription); #pragma warning(push) #pragma warning(disable : 6387) - static unsigned __stdcall ThreadMain(void* param) { - std::unique_ptr p(static_cast(param)); - - // Not all machines have kernel32.dll and/or SetThreadDescription (e.g. Azure App Service sandbox) - // so we need to ensure it's available before calling. - HMODULE kernelModule = GetModuleHandle(TEXT("kernel32.dll")); - if (kernelModule != nullptr) { - auto setThreadDescriptionFn = (SetThreadDescriptionFunc)GetProcAddress(kernelModule, "SetThreadDescription"); - if (setThreadDescriptionFn != nullptr) { - const ORTCHAR_T* name_prefix = (p->name_prefix == nullptr || wcslen(p->name_prefix) == 0) ? L"onnxruntime" - : p->name_prefix; - std::wostringstream oss; - oss << name_prefix << "-" << p->index; - // Ignore any errors - (void)(setThreadDescriptionFn)(GetCurrentThread(), oss.str().c_str()); - } - } - - unsigned ret = 0; - ORT_TRY{ - - -ret = p->start_address(p->index, p->param); - } - ORT_CATCH(...) { - p->param->Cancel(); - ret = 1; - } - return ret; - } + static unsigned __stdcall ThreadMain(void* param) { + std::unique_ptr p(static_cast(param)); + + // Not all machines have kernel32.dll and/or SetThreadDescription (e.g. Azure App Service sandbox) + // so we need to ensure it's available before calling. + HMODULE kernelModule = GetModuleHandle(TEXT("kernel32.dll")); + if (kernelModule != nullptr) { + auto setThreadDescriptionFn = (SetThreadDescriptionFunc)GetProcAddress(kernelModule, "SetThreadDescription"); + if (setThreadDescriptionFn != nullptr) { + const ORTCHAR_T* name_prefix = (p->name_prefix == nullptr || wcslen(p->name_prefix) == 0) ? 
L"onnxruntime" + : p->name_prefix; + std::wostringstream oss; + oss << name_prefix << "-" << p->index; + // Ignore any errors + (void)(setThreadDescriptionFn)(GetCurrentThread(), oss.str().c_str()); + } + } + + unsigned ret = 0; + ORT_TRY { + if (p->affinity.has_value() && !p->affinity->empty()) { + int group_id = -1; + KAFFINITY mask = 0; + constexpr KAFFINITY bit = 1; + const WindowsEnv& env = WindowsEnv::Instance(); + for (auto global_processor_id : *p->affinity) { + auto processor_info = env.GetProcessorAffinityMask(global_processor_id); + if (processor_info.local_processor_id > -1 && + processor_info.local_processor_id < sizeof(KAFFINITY) * CHAR_BIT) { + mask |= bit << processor_info.local_processor_id; + } else { + // Logical processor id starts from 0 internally, but in ort API, it starts from 1, + // that's why id need to increase by 1 when logging. + LOGS_DEFAULT(ERROR) << "Cannot set affinity for thread " << GetCurrentThreadId() + << ", processor " << global_processor_id + 1 << " does not exist"; + group_id = -1; + mask = 0; + break; + } + if (group_id == -1) { + group_id = processor_info.group_id; + } else if (group_id != processor_info.group_id) { + LOGS_DEFAULT(ERROR) << "Cannot set cross-group affinity for thread " + << GetCurrentThreadId() << ", first on group " + << group_id << ", then on " << processor_info.group_id; + group_id = -1; + mask = 0; + break; + } + } // for + if (group_id > -1 && mask) { + GROUP_AFFINITY thread_affinity = {}; + thread_affinity.Group = static_cast(group_id); + thread_affinity.Mask = mask; + if (SetThreadGroupAffinity(GetCurrentThread(), &thread_affinity, nullptr)) { + LOGS_DEFAULT(VERBOSE) << "SetThreadAffinityMask done for thread: " << GetCurrentThreadId() + << ", group_id: " << thread_affinity.Group + << ", mask: " << thread_affinity.Mask; + } else { + const auto error_code = GetLastError(); + LOGS_DEFAULT(ERROR) << "SetThreadAffinityMask failed for thread: " << GetCurrentThreadId() + << ", index: " << p->index + << ", 
mask: " << *p->affinity + << ", error code: " << error_code + << ", error msg: " << std::system_category().message(error_code) + << ". Specify the number of threads explicitly so the affinity is not set."; + } + } + } + + ret = p->start_address(p->index, p->param); + } + ORT_CATCH(...) { + p->param->Cancel(); + ret = 1; + } + return ret; + } #pragma warning(pop) - static void CustomThreadMain(void* param) { - std::unique_ptr p(static_cast(param)); - ORT_TRY{ - p->start_address(p->index, p->param); - } - ORT_CATCH(...) { - p->param->Cancel(); - } - } - unsigned threadID = 0; - wil::unique_handle hThread; - }; + static void CustomThreadMain(void* param) { + std::unique_ptr p(static_cast(param)); + ORT_TRY { + p->start_address(p->index, p->param); + } + ORT_CATCH(...) { + p->param->Cancel(); + } + } + unsigned threadID = 0; + wil::unique_handle hThread; +}; #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(push) #pragma warning(disable : 26409) #endif - EnvThread* WindowsEnv::CreateThread(_In_opt_z_ const ORTCHAR_T* name_prefix, int index, - unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param), - Eigen::ThreadPoolInterface* param, const ThreadOptions& thread_options) { - return new WindowsThread(name_prefix, index, start_address, param, thread_options); - } +EnvThread* WindowsEnv::CreateThread(_In_opt_z_ const ORTCHAR_T* name_prefix, int index, + unsigned (*start_address)(int id, Eigen::ThreadPoolInterface* param), + Eigen::ThreadPoolInterface* param, const ThreadOptions& thread_options) { + return new WindowsThread(name_prefix, index, start_address, param, thread_options); +} #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) #endif - Env& Env::Default() { - return WindowsEnv::Instance(); - } - - void WindowsEnv::SleepForMicroseconds(int64_t micros) const { - Sleep(static_cast(micros) / 1000); - } - - - int WindowsEnv::DefaultNumCores() { - return std::max(1, static_cast(std::thread::hardware_concurrency() / 2)); - } - - int 
WindowsEnv::GetNumPhysicalCpuCores() const { - return cores_.empty() ? DefaultNumCores() : static_cast(cores_.size()); - - } - - std::vector WindowsEnv::GetDefaultThreadAffinities() const { - return cores_.empty() ? std::vector(DefaultNumCores(), LogicalProcessors{}) : cores_; - } - - int WindowsEnv::GetL2CacheSize() const { - return l2_cache_size_; - } - - WindowsEnv& WindowsEnv::Instance() { - static WindowsEnv default_env; - return default_env; - } - - PIDType WindowsEnv::GetSelfPid() const { - return GetCurrentProcessId(); - } - - Status WindowsEnv::GetFileLength(_In_z_ const ORTCHAR_T* file_path, size_t& length) const { - wil::unique_hfile file_handle{ - CreateFile2(file_path, FILE_READ_ATTRIBUTES, FILE_SHARE_READ, OPEN_EXISTING, NULL) }; - if (file_handle.get() == INVALID_HANDLE_VALUE) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "open file ", ToUTF8String(Basename(file_path)), " fail, errcode = ", error_code, " - ", std::system_category().message(error_code)); - } - LARGE_INTEGER filesize; - if (!GetFileSizeEx(file_handle.get(), &filesize)) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GetFileSizeEx ", ToUTF8String(Basename(file_path)), " fail, errcode = ", error_code, " - ", std::system_category().message(error_code)); - } - if (static_cast(filesize.QuadPart) > std::numeric_limits::max()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GetFileLength: File is too large"); - } - length = static_cast(filesize.QuadPart); - return Status::OK(); - } - - common::Status WindowsEnv::GetFileLength(int fd, /*out*/ size_t& file_size) const { - using namespace common; - if (fd < 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid fd was supplied: ", fd); - } - - struct _stat buf; - int rc = _fstat(fd, &buf); - if (rc < 0) { - return Status(SYSTEM, errno); - } - - if (buf.st_size < 0) { - return ORT_MAKE_STATUS(SYSTEM, FAIL, "Received negative size from stat call"); - } 
- - if (static_cast(buf.st_size) > std::numeric_limits::max()) { - return ORT_MAKE_STATUS(SYSTEM, FAIL, "File is too large."); - } - - file_size = static_cast(buf.st_size); - return Status::OK(); - } - - Status WindowsEnv::ReadFileIntoBuffer(_In_z_ const ORTCHAR_T* const file_path, const FileOffsetType offset, const size_t length, - const gsl::span buffer) const { - ORT_RETURN_IF_NOT(file_path, "file_path == nullptr"); - ORT_RETURN_IF_NOT(offset >= 0, "offset < 0"); - ORT_RETURN_IF_NOT(length <= buffer.size(), "length > buffer.size()"); - wil::unique_hfile file_handle{ - CreateFile2(file_path, GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING, NULL) }; - if (file_handle.get() == INVALID_HANDLE_VALUE) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "open file ", ToUTF8String(Basename(file_path)), " fail, errcode = ", error_code, " - ", std::system_category().message(error_code)); - } - - if (length == 0) - return Status::OK(); - - if (offset > 0) { - LARGE_INTEGER current_position; - current_position.QuadPart = offset; - if (!SetFilePointerEx(file_handle.get(), current_position, &current_position, FILE_BEGIN)) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "SetFilePointerEx ", ToUTF8String(Basename(file_path)), " fail, errcode = ", error_code, " - ", std::system_category().message(error_code)); - } - } - - size_t total_bytes_read = 0; - while (total_bytes_read < length) { - constexpr DWORD k_max_bytes_to_read = 1 << 30; // read at most 1GB each time - const size_t bytes_remaining = length - total_bytes_read; - const DWORD bytes_to_read = static_cast(std::min(bytes_remaining, k_max_bytes_to_read)); - DWORD bytes_read; - - if (!ReadFile(file_handle.get(), buffer.data() + total_bytes_read, bytes_to_read, &bytes_read, nullptr)) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ReadFile ", ToUTF8String(Basename(file_path)), " fail, errcode = ", error_code, " - ", 
 std::system_category().message(error_code)); - } - - if (bytes_read != bytes_to_read) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ReadFile ", ToUTF8String(Basename(file_path)), " fail: unexpected end"); - } - - total_bytes_read += bytes_read; - } - - return Status::OK(); - } - - - - common::Status WindowsEnv::GetCanonicalPath( - const PathString& path, - PathString& canonical_path) const { - // adapted from MSVC STL std::filesystem::canonical() implementation - // https://github.com/microsoft/STL/blob/ed3cbf36416a385828e7a5987ca52cb42882d84b/stl/inc/filesystem#L2986 - CREATEFILE2_EXTENDED_PARAMETERS param; - memset(&param, 0, sizeof(param)); - param.dwSize = sizeof(CREATEFILE2_EXTENDED_PARAMETERS); - param.dwFileFlags = FILE_FLAG_BACKUP_SEMANTICS; - wil::unique_hfile file_handle{ CreateFile2( - path.c_str(), - FILE_READ_ATTRIBUTES, - FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, - OPEN_EXISTING, - &param) }; - - if (file_handle.get() == INVALID_HANDLE_VALUE) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "open file ", ToUTF8String(Basename(path)), " fail, errcode = ", - error_code, " - ", std::system_category().message(error_code)); - } - - constexpr DWORD initial_buffer_size = MAX_PATH; - std::vector result_buffer{}; - result_buffer.resize(initial_buffer_size); - - while (true) { - const DWORD result_length = GetFinalPathNameByHandleW( - file_handle.get(), - result_buffer.data(), - static_cast(result_buffer.size()), - 0); - - ORT_RETURN_IF_NOT( - result_length > 0, "GetFinalPathNameByHandle() failed: ", GetLastError()); - - if (result_length < result_buffer.size()) { // buffer is large enough - canonical_path.assign(result_buffer.data(), result_length); - break; - } - - // need larger buffer - result_buffer.resize(result_length); - } - - // update prefixes - if (canonical_path.find(ORT_TSTR(R"(\\?\)")) == 0) { - if (canonical_path.size() > 6 && - (ORT_TSTR('A') <= canonical_path[4] && canonical_path[4] <= ORT_TSTR('Z') 
|| - ORT_TSTR('a') <= canonical_path[4] && canonical_path[4] <= ORT_TSTR('z')) && - canonical_path[5] == ORT_TSTR(':')) { - // "\\?\:" -> ":" - canonical_path.erase(0, 4); - } - else if (canonical_path.find(ORT_TSTR(R"(UNC\)"), 4) == 4) { - // "\\?\UNC\" -> "\\" - canonical_path.erase(2, 6); - } - } - - return Status::OK(); - } - - // Return the path of the executable/shared library for the current running code. This is to make it - // possible to load other shared libraries installed next to our core runtime code. - PathString WindowsEnv::GetRuntimePath() const { - wchar_t buffer[MAX_PATH]; - if (!GetModuleFileNameW(reinterpret_cast(&__ImageBase), buffer, _countof(buffer))) { - return PathString(); - } - - // Remove the filename at the end, but keep the trailing slash - PathString path(buffer); - auto slash_index = path.find_last_of(ORT_TSTR('\\')); - if (slash_index == std::string::npos) { - // Windows supports forward slashes - slash_index = path.find_last_of(ORT_TSTR('/')); - if (slash_index == std::string::npos) { - return PathString(); - } - } - return path.substr(0, slash_index + 1); - } - - Status WindowsEnv::LoadDynamicLibrary(const PathString& wlibrary_filename, bool /*global_symbols*/, void** handle) const { +Env& Env::Default() { + return WindowsEnv::Instance(); +} + +void WindowsEnv::SleepForMicroseconds(int64_t micros) const { + Sleep(static_cast(micros) / 1000); +} + +// EIGEN_NO_CPUID is not defined in any C/C++ source code. It is a compile option. +#if defined(_M_X64) && !defined(_M_ARM64EC) && !defined(EIGEN_NO_CPUID) +static constexpr std::array kVendorID_Intel = {0x756e6547, 0x6c65746e, 0x49656e69}; // "GenuntelineI" +#endif +int WindowsEnv::DefaultNumCores() { + return std::max(1, static_cast(std::thread::hardware_concurrency() / 2)); +} + +int WindowsEnv::GetNumPhysicalCpuCores() const { +// EIGEN_NO_CPUID is not defined in any C/C++ source code. It is a compile option. 
+#if defined(_M_X64) && !defined(_M_ARM64EC) && !defined(EIGEN_NO_CPUID) + // The following code is a temporary fix for a perf problem on Intel's Meteor Lake CPUs. The Intel compute platform has + // a hybrid architecture that some CPU cores runs significant slower than the others. If we distribute our compute work + // evenly to all CPU cores, the slowest CPU core will drag the performance down. So, instead, we reduce the total number + // of threads to exclude the slowest cores out. + // The following code is based on assumptions that: + // 1. All Intel hybrid CPUs should have 3 levels of cache. + // 2. If a CPU core is only associated with two levels of cache, it should be a low performance CPU core and should + // not be used. + // Since we don't know what the next Intel hybrid CPU would be like, later on we may need to rework the following code. + // However, no matter what the code should not cause any crash. The worst is it might return 1 that + // thread pools will not be created, which is just a perf issue and does not impact usability. + // TODO: detect if CPUID instruction is available per instructions at https://wiki.osdev.org/CPUID#Checking_CPUID_availability + int regs[4]; + __cpuid(regs, 0); + bool bIsIntel = + (kVendorID_Intel[0] == regs[1]) && + (kVendorID_Intel[1] == regs[2]) && + (kVendorID_Intel[2] == regs[3]); + if (bIsIntel && regs[0] >= 7) { + // Query Structured Extended Feature Flags Enumeration Leaf + __cpuid(regs, 0x7); + // The bit 15 of EDX indicates if the processor is identified as a hybrid part. + bool ishybrid = regs[3] & (1 << 15); + if (ishybrid) { + // NOTE: even if ishybrid is true, it doesn't mean the processor must have P-cores and E-cores. + // On Intel CPUs we assume the HardwareCoreEnumerator::DefaultIntraOpNumThreads function would never fail. + // NOTE: due to resource restrictions, we cannot test this branch in our CI build pipelines. 
+ return std::max(static_cast(1), HardwareCoreEnumerator::DefaultIntraOpNumThreads()); + } else { + return cores_.empty() ? DefaultNumCores() : static_cast(cores_.size()); + } + } else +#endif + { + return cores_.empty() ? DefaultNumCores() : static_cast(cores_.size()); + } +} + +std::vector WindowsEnv::GetDefaultThreadAffinities() const { + return cores_.empty() ? std::vector(DefaultNumCores(), LogicalProcessors{}) : cores_; +} + +int WindowsEnv::GetL2CacheSize() const { + return l2_cache_size_; +} + +WindowsEnv& WindowsEnv::Instance() { + static WindowsEnv default_env; + return default_env; +} + +PIDType WindowsEnv::GetSelfPid() const { + return GetCurrentProcessId(); +} + +Status WindowsEnv::GetFileLength(_In_z_ const ORTCHAR_T* file_path, size_t& length) const { + wil::unique_hfile file_handle{ + CreateFile2(file_path, FILE_READ_ATTRIBUTES, FILE_SHARE_READ, OPEN_EXISTING, NULL)}; + if (file_handle.get() == INVALID_HANDLE_VALUE) { + const auto error_code = GetLastError(); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "open file ", ToUTF8String(Basename(file_path)), " fail, errcode = ", error_code, " - ", std::system_category().message(error_code)); + } + LARGE_INTEGER filesize; + if (!GetFileSizeEx(file_handle.get(), &filesize)) { + const auto error_code = GetLastError(); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GetFileSizeEx ", ToUTF8String(Basename(file_path)), " fail, errcode = ", error_code, " - ", std::system_category().message(error_code)); + } + if (static_cast(filesize.QuadPart) > std::numeric_limits::max()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GetFileLength: File is too large"); + } + length = static_cast(filesize.QuadPart); + return Status::OK(); +} + +common::Status WindowsEnv::GetFileLength(int fd, /*out*/ size_t& file_size) const { + using namespace common; + if (fd < 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid fd was supplied: ", fd); + } + + struct _stat buf; + int rc = _fstat(fd, &buf); + if (rc < 0) { + 
 return Status(SYSTEM, errno); + } + + if (buf.st_size < 0) { + return ORT_MAKE_STATUS(SYSTEM, FAIL, "Received negative size from stat call"); + } + + if (static_cast(buf.st_size) > std::numeric_limits::max()) { + return ORT_MAKE_STATUS(SYSTEM, FAIL, "File is too large."); + } + + file_size = static_cast(buf.st_size); + return Status::OK(); +} + +Status WindowsEnv::ReadFileIntoBuffer(_In_z_ const ORTCHAR_T* const file_path, const FileOffsetType offset, const size_t length, + const gsl::span buffer) const { + ORT_RETURN_IF_NOT(file_path, "file_path == nullptr"); + ORT_RETURN_IF_NOT(offset >= 0, "offset < 0"); + ORT_RETURN_IF_NOT(length <= buffer.size(), "length > buffer.size()"); + wil::unique_hfile file_handle{ + CreateFile2(file_path, GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING, NULL)}; + if (file_handle.get() == INVALID_HANDLE_VALUE) { + const auto error_code = GetLastError(); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "open file ", ToUTF8String(Basename(file_path)), " fail, errcode = ", error_code, " - ", std::system_category().message(error_code)); + } + + if (length == 0) + return Status::OK(); + + if (offset > 0) { + LARGE_INTEGER current_position; + current_position.QuadPart = offset; + if (!SetFilePointerEx(file_handle.get(), current_position, &current_position, FILE_BEGIN)) { + const auto error_code = GetLastError(); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "SetFilePointerEx ", ToUTF8String(Basename(file_path)), " fail, errcode = ", error_code, " - ", std::system_category().message(error_code)); + } + } + + size_t total_bytes_read = 0; + while (total_bytes_read < length) { + constexpr DWORD k_max_bytes_to_read = 1 << 30; // read at most 1GB each time + const size_t bytes_remaining = length - total_bytes_read; + const DWORD bytes_to_read = static_cast(std::min(bytes_remaining, k_max_bytes_to_read)); + DWORD bytes_read; + + if (!ReadFile(file_handle.get(), buffer.data() + total_bytes_read, bytes_to_read, &bytes_read, nullptr)) { + const auto error_code = 
 GetLastError(); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ReadFile ", ToUTF8String(Basename(file_path)), " fail, errcode = ", error_code, " - ", std::system_category().message(error_code)); + } + + if (bytes_read != bytes_to_read) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ReadFile ", ToUTF8String(Basename(file_path)), " fail: unexpected end"); + } + + total_bytes_read += bytes_read; + } + + return Status::OK(); +} + + + +bool WindowsEnv::FileExists(const std::wstring& path) const { + DWORD attributes = GetFileAttributesW(path.c_str()); + return (attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_NORMAL); +} + +bool WindowsEnv::FileExists(const std::string& path) const { + DWORD attributes = GetFileAttributesA(path.c_str()); + return (attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_NORMAL); +} + +common::Status WindowsEnv::GetCanonicalPath( + const PathString& path, + PathString& canonical_path) const { + // adapted from MSVC STL std::filesystem::canonical() implementation + // https://github.com/microsoft/STL/blob/ed3cbf36416a385828e7a5987ca52cb42882d84b/stl/inc/filesystem#L2986 + CREATEFILE2_EXTENDED_PARAMETERS param; + memset(&param, 0, sizeof(param)); + param.dwSize = sizeof(CREATEFILE2_EXTENDED_PARAMETERS); + param.dwFileFlags = FILE_FLAG_BACKUP_SEMANTICS; + wil::unique_hfile file_handle{CreateFile2( + path.c_str(), + FILE_READ_ATTRIBUTES, + FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, + OPEN_EXISTING, + &param)}; + + if (file_handle.get() == INVALID_HANDLE_VALUE) { + const auto error_code = GetLastError(); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "open file ", ToUTF8String(Basename(path)), " fail, errcode = ", + error_code, " - ", std::system_category().message(error_code)); + } + + constexpr DWORD initial_buffer_size = MAX_PATH; + std::vector result_buffer{}; + result_buffer.resize(initial_buffer_size); + + while (true) { + const DWORD result_length = GetFinalPathNameByHandleW( + file_handle.get(), + 
result_buffer.data(), + static_cast(result_buffer.size()), + 0); + + ORT_RETURN_IF_NOT( + result_length > 0, "GetFinalPathNameByHandle() failed: ", GetLastError()); + + if (result_length < result_buffer.size()) { // buffer is large enough + canonical_path.assign(result_buffer.data(), result_length); + break; + } + + // need larger buffer + result_buffer.resize(result_length); + } + + // update prefixes + if (canonical_path.find(ORT_TSTR(R"(\\?\)")) == 0) { + if (canonical_path.size() > 6 && + (ORT_TSTR('A') <= canonical_path[4] && canonical_path[4] <= ORT_TSTR('Z') || + ORT_TSTR('a') <= canonical_path[4] && canonical_path[4] <= ORT_TSTR('z')) && + canonical_path[5] == ORT_TSTR(':')) { + // "\\?\:" -> ":" + canonical_path.erase(0, 4); + } else if (canonical_path.find(ORT_TSTR(R"(UNC\)"), 4) == 4) { + // "\\?\UNC\" -> "\\" + canonical_path.erase(2, 6); + } + } + + return Status::OK(); +} + +// Return the path of the executable/shared library for the current running code. This is to make it +// possible to load other shared libraries installed next to our core runtime code. 
+PathString WindowsEnv::GetRuntimePath() const { + wchar_t buffer[MAX_PATH]; + if (!GetModuleFileNameW(reinterpret_cast(&__ImageBase), buffer, _countof(buffer))) { + return PathString(); + } + + // Remove the filename at the end, but keep the trailing slash + PathString path(buffer); + auto slash_index = path.find_last_of(ORT_TSTR('\\')); + if (slash_index == std::string::npos) { + // Windows supports forward slashes + slash_index = path.find_last_of(ORT_TSTR('/')); + if (slash_index == std::string::npos) { + return PathString(); + } + } + return path.substr(0, slash_index + 1); +} + +Status WindowsEnv::LoadDynamicLibrary(const PathString& wlibrary_filename, bool /*global_symbols*/, void** handle) const { #if WINAPI_FAMILY == WINAPI_FAMILY_PC_APP - * handle = ::LoadPackagedLibrary(wlibrary_filename.c_str(), 0); + *handle = ::LoadPackagedLibrary(wlibrary_filename.c_str(), 0); #else - // TODO: in most cases, the path name is a relative path and the behavior of the following line of code is undefined. - * handle = ::LoadLibraryExW(wlibrary_filename.c_str(), nullptr, LOAD_WITH_ALTERED_SEARCH_PATH); + // TODO: in most cases, the path name is a relative path and the behavior of the following line of code is undefined. 
+ *handle = ::LoadLibraryExW(wlibrary_filename.c_str(), nullptr, LOAD_WITH_ALTERED_SEARCH_PATH); #endif - if (!*handle) { - const auto error_code = GetLastError(); - static constexpr DWORD bufferLength = 64 * 1024; - std::wstring s(bufferLength, '\0'); - FormatMessageW( - FORMAT_MESSAGE_FROM_SYSTEM | - FORMAT_MESSAGE_IGNORE_INSERTS, - NULL, - error_code, - MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - (LPWSTR)s.data(), - 0, NULL); - std::wostringstream oss; - oss << L"LoadLibrary failed with error " << error_code << L" \"" << s.c_str() << L"\" when trying to load \"" << wlibrary_filename << L"\""; - std::wstring errmsg = oss.str(); - // TODO: trim the ending '\r' and/or '\n' - common::Status status(common::ONNXRUNTIME, common::FAIL, ToUTF8String(errmsg)); - return status; - } - return Status::OK(); - } - - Status WindowsEnv::UnloadDynamicLibrary(void* handle) const { - if (::FreeLibrary(reinterpret_cast(handle)) == 0) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "FreeLibrary failed with error ", error_code, " - ", std::system_category().message(error_code)); - } - return Status::OK(); - } - - namespace dlfcn_win32 { - // adapted from https://github.com/dlfcn-win32 version 1.3.1. - // Simplified to only support finding symbols in libraries that were linked against. - // If ORT dynamically loads a custom ops library using RegisterCustomOpsLibrary[_V2] the handle from the library load - // is explicitly provided in the call to GetSymbolFromLibrary. 
- // - /* Load Psapi.dll at runtime, this avoids linking caveat */ - bool MyEnumProcessModules(HANDLE hProcess, HMODULE* lphModule, DWORD cb, LPDWORD lpcbNeeded) { - using EnumProcessModulesFn = BOOL(WINAPI*)(HANDLE, HMODULE*, DWORD, LPDWORD); - static EnumProcessModulesFn EnumProcessModulesPtr = []() { - EnumProcessModulesFn fn = nullptr; - // Windows 7 and newer versions have K32EnumProcessModules in Kernel32.dll which is always pre-loaded - HMODULE psapi = GetModuleHandleA("Kernel32.dll"); - if (psapi) { - fn = (EnumProcessModulesFn)(LPVOID)GetProcAddress(psapi, "K32EnumProcessModules"); - } - - return fn; - }(); - - if (EnumProcessModulesPtr == nullptr) { - return false; - } - - return EnumProcessModulesPtr(hProcess, lphModule, cb, lpcbNeeded); - } - - void* SearchModulesForSymbol(const char* name) { - HANDLE current_proc = GetCurrentProcess(); - DWORD size = 0; - void* symbol = nullptr; - - // GetModuleHandle(NULL) only returns the current program file. So if we want to get ALL loaded module including - // those in linked DLLs, we have to use EnumProcessModules(). 
- if (MyEnumProcessModules(current_proc, nullptr, 0, &size) != false) { - size_t num_handles = size / sizeof(HMODULE); - std::unique_ptr modules = std::make_unique(num_handles); - HMODULE* modules_ptr = modules.get(); - DWORD cb_needed = 0; - if (MyEnumProcessModules(current_proc, modules_ptr, size, &cb_needed) != 0 && size == cb_needed) { - for (size_t i = 0; i < num_handles; i++) { - symbol = GetProcAddress(modules[i], name); - if (symbol != nullptr) { - break; - } - } - } - } - - return symbol; - } - } // namespace dlfcn_win32 - - Status WindowsEnv::GetSymbolFromLibrary(void* handle, const std::string& symbol_name, void** symbol) const { - Status status = Status::OK(); - - // global search to replicate dlsym RTLD_DEFAULT if handle is nullptr - if (handle == nullptr) { - *symbol = dlfcn_win32::SearchModulesForSymbol(symbol_name.c_str()); - } - else { - *symbol = ::GetProcAddress(reinterpret_cast(handle), symbol_name.c_str()); - } - - if (!*symbol) { - const auto error_code = GetLastError(); - static constexpr DWORD bufferLength = 64 * 1024; - std::wstring s(bufferLength, '\0'); - FormatMessageW(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, error_code, - MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - (LPWSTR)s.data(), 0, NULL); - std::wostringstream oss; - oss << L"Failed to find symbol " << ToWideString(symbol_name) << L" in library, error code: " - << error_code << L" \"" << s.c_str() << L"\""; - std::wstring errmsg = oss.str(); - // TODO: trim the ending '\r' and/or '\n' - status = Status(common::ONNXRUNTIME, common::FAIL, ToUTF8String(errmsg)); - } - - return status; - } - - std::string WindowsEnv::FormatLibraryFileName(const std::string& name, const std::string& version) const { - ORT_UNUSED_PARAMETER(name); - ORT_UNUSED_PARAMETER(version); - ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented"); - } - - - - // \brief returns a value for the queried variable name (var_name) - std::string WindowsEnv::GetEnvironmentVar(const std::string& 
var_name) const { - // Why getenv() should be avoided on Windows: - // https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/getenv-wgetenv - // Instead use the Win32 API: GetEnvironmentVariableA() - - // Max limit of an environment variable on Windows including the null-terminating character - constexpr DWORD kBufferSize = 32767; - - // Create buffer to hold the result - std::string buffer(kBufferSize, '\0'); - - // The last argument is the size of the buffer pointed to by the lpBuffer parameter, including the null-terminating character, in characters. - // If the function succeeds, the return value is the number of characters stored in the buffer pointed to by lpBuffer, not including the terminating null character. - // Therefore, If the function succeeds, kBufferSize should be larger than char_count. - auto char_count = GetEnvironmentVariableA(var_name.c_str(), buffer.data(), kBufferSize); - - if (kBufferSize > char_count) { - buffer.resize(char_count); - return buffer; - } - - // Else either the call was failed, or the buffer wasn't large enough. - // TODO: Understand the reason for failure by calling GetLastError(). - // If it is due to the specified environment variable being found in the environment block, - // GetLastError() returns ERROR_ENVVAR_NOT_FOUND. - // For now, we assume that the environment variable is not found. - - return std::string(); - } - - /* - Read logical processor info from the map. - {-1,-1} stands for failure. - */ - ProcessorInfo WindowsEnv::GetProcessorAffinityMask(int global_processor_id) const { - if (global_processor_info_map_.count(global_processor_id)) { - return global_processor_info_map_.at(global_processor_id); - } - else { - return { -1, -1 }; - } - } - - WindowsEnv::WindowsEnv() { - l2_cache_size_ = 0; - InitializeCpuInfo(); - } - - /* - Discover all cores in a windows system. - Note - every "id" here, given it be group id, core id, or logical processor id, starts from 0. 
- */ - void WindowsEnv::InitializeCpuInfo() { - DWORD returnLength = 0; - GetLogicalProcessorInformationEx(RelationProcessorCore, nullptr, &returnLength); - auto last_error = GetLastError(); - if (last_error != ERROR_INSUFFICIENT_BUFFER) { - return; - } - - std::unique_ptr allocation = std::make_unique(returnLength); - SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX* processorInfos = reinterpret_cast(allocation.get()); - - if (!GetLogicalProcessorInformationEx(RelationProcessorCore, processorInfos, &returnLength)) { - return; - } - - int core_id = 0; - int global_processor_id = 0; - const BYTE* iter = reinterpret_cast(processorInfos); - const BYTE* end = iter + returnLength; - std::stringstream log_stream; - - while (iter < end) { - auto processor_info = reinterpret_cast(iter); - auto size = processor_info->Size; - - // Discoverred a phyical core and it belongs exclusively to a single group - if (processor_info->Relationship == RelationProcessorCore && - processor_info->Processor.GroupCount == 1) { - log_stream << std::endl - << "core " << core_id + 1 << " consist of logical processors: "; - LogicalProcessors core_global_proc_ids; - constexpr KAFFINITY bit = 1; - constexpr int id_upper_bound = sizeof(KAFFINITY) * CHAR_BIT; - const auto& group_mask = processor_info->Processor.GroupMask[0]; - for (int logical_proessor_id = 0; logical_proessor_id < id_upper_bound; ++logical_proessor_id) { - if (group_mask.Mask & (bit << logical_proessor_id)) { - log_stream << global_processor_id + 1 << " "; - core_global_proc_ids.push_back(global_processor_id); - /* - * Build up a map between global processor id and local processor id. - * The map helps to bridge between ort API and windows affinity API - - * we need local processor id to build an affinity mask for a particular group. 
- */ - global_processor_info_map_.insert_or_assign(global_processor_id, - ProcessorInfo{ static_cast(group_mask.Group), - logical_proessor_id }); - global_processor_id++; - } - } - cores_.push_back(std::move(core_global_proc_ids)); - core_id++; - } - iter += size; - } - - DWORD newLength = 0; - GetLogicalProcessorInformationEx(RelationCache, nullptr, &newLength); - last_error = GetLastError(); - if (last_error != ERROR_INSUFFICIENT_BUFFER) { - return; - } - - if (newLength > returnLength) { - // Re-allocate - allocation = std::make_unique(newLength); - processorInfos = reinterpret_cast(allocation.get()); - } - - if (!GetLogicalProcessorInformationEx(RelationCache, processorInfos, &newLength)) { - return; - } - - iter = reinterpret_cast(processorInfos); - end = iter + newLength; - - while (iter < end) { - auto processor_info = reinterpret_cast(iter); - auto size = processor_info->Size; - - if (processor_info->Relationship == RelationCache && - processor_info->Cache.Level == 2) { - // L2 cache - l2_cache_size_ = static_cast(processor_info->Cache.CacheSize); - break; - } - - iter += size; - } - - } + if (!*handle) { + const auto error_code = GetLastError(); + static constexpr DWORD bufferLength = 64 * 1024; + std::wstring s(bufferLength, '\0'); + FormatMessageW( + FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, + error_code, + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + (LPWSTR)s.data(), + bufferLength, NULL); + s.erase(std::remove(s.begin(), s.end(), L'\r'), s.end()); + s.erase(std::remove(s.begin(), s.end(), L'\n'), s.end()); + std::wostringstream oss; + oss << DetermineLoadLibraryError(wlibrary_filename.c_str(), LOAD_WITH_ALTERED_SEARCH_PATH) + << L" (Error " << error_code << ": \"" << s.c_str() << "\")"; + std::wstring errmsg = oss.str(); + common::Status status(common::ONNXRUNTIME, common::FAIL, ToUTF8String(errmsg)); + return status; + } + return Status::OK(); +} + +Status WindowsEnv::UnloadDynamicLibrary(void* handle) const { + if 
(::FreeLibrary(reinterpret_cast(handle)) == 0) { + const auto error_code = GetLastError(); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "FreeLibrary failed with error ", error_code, " - ", std::system_category().message(error_code)); + } + return Status::OK(); +} + +namespace dlfcn_win32 { +// adapted from https://github.com/dlfcn-win32 version 1.3.1. +// Simplified to only support finding symbols in libraries that were linked against. +// If ORT dynamically loads a custom ops library using RegisterCustomOpsLibrary[_V2] the handle from the library load +// is explicitly provided in the call to GetSymbolFromLibrary. +// +/* Load Psapi.dll at runtime, this avoids linking caveat */ +bool MyEnumProcessModules(HANDLE hProcess, HMODULE* lphModule, DWORD cb, LPDWORD lpcbNeeded) { + using EnumProcessModulesFn = BOOL(WINAPI*)(HANDLE, HMODULE*, DWORD, LPDWORD); + static EnumProcessModulesFn EnumProcessModulesPtr = []() { + EnumProcessModulesFn fn = nullptr; + // Windows 7 and newer versions have K32EnumProcessModules in Kernel32.dll which is always pre-loaded + HMODULE psapi = GetModuleHandleA("Kernel32.dll"); + if (psapi) { + fn = (EnumProcessModulesFn)(LPVOID)GetProcAddress(psapi, "K32EnumProcessModules"); + } + + return fn; + }(); + + if (EnumProcessModulesPtr == nullptr) { + return false; + } + + return EnumProcessModulesPtr(hProcess, lphModule, cb, lpcbNeeded); +} + +void* SearchModulesForSymbol(const char* name) { + HANDLE current_proc = GetCurrentProcess(); + DWORD size = 0; + void* symbol = nullptr; + + // GetModuleHandle(NULL) only returns the current program file. So if we want to get ALL loaded module including + // those in linked DLLs, we have to use EnumProcessModules(). 
+ if (MyEnumProcessModules(current_proc, nullptr, 0, &size) != false) { + size_t num_handles = size / sizeof(HMODULE); + std::unique_ptr modules = std::make_unique(num_handles); + HMODULE* modules_ptr = modules.get(); + DWORD cb_needed = 0; + if (MyEnumProcessModules(current_proc, modules_ptr, size, &cb_needed) != 0 && size == cb_needed) { + for (size_t i = 0; i < num_handles; i++) { + symbol = GetProcAddress(modules[i], name); + if (symbol != nullptr) { + break; + } + } + } + } + + return symbol; +} +} // namespace dlfcn_win32 + +Status WindowsEnv::GetSymbolFromLibrary(void* handle, const std::string& symbol_name, void** symbol) const { + Status status = Status::OK(); + + // global search to replicate dlsym RTLD_DEFAULT if handle is nullptr + if (handle == nullptr) { + *symbol = dlfcn_win32::SearchModulesForSymbol(symbol_name.c_str()); + } else { + *symbol = ::GetProcAddress(reinterpret_cast(handle), symbol_name.c_str()); + } + + if (!*symbol) { + const auto error_code = GetLastError(); + static constexpr DWORD bufferLength = 64 * 1024; + std::wstring s(bufferLength, '\0'); + FormatMessageW(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, error_code, + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + (LPWSTR)s.data(), 0, NULL); + std::wostringstream oss; + oss << L"Failed to find symbol " << ToWideString(symbol_name) << L" in library, error code: " + << error_code << L" \"" << s.c_str() << L"\""; + std::wstring errmsg = oss.str(); + // TODO: trim the ending '\r' and/or '\n' + status = Status(common::ONNXRUNTIME, common::FAIL, ToUTF8String(errmsg)); + } + + return status; +} + +std::string WindowsEnv::FormatLibraryFileName(const std::string& name, const std::string& version) const { + ORT_UNUSED_PARAMETER(name); + ORT_UNUSED_PARAMETER(version); + ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented"); +} + +// \brief returns a value for the queried variable name (var_name) +std::string WindowsEnv::GetEnvironmentVar(const std::string& var_name) const 
{
+  // Why getenv() should be avoided on Windows:
+  // https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/getenv-wgetenv
+  // Instead use the Win32 API: GetEnvironmentVariableA()
+
+  // Max limit of an environment variable on Windows including the null-terminating character
+  constexpr DWORD kBufferSize = 32767;
+
+  // Create buffer to hold the result
+  std::string buffer(kBufferSize, '\0');
+
+  // The last argument is the size of the buffer pointed to by the lpBuffer parameter, including the null-terminating character, in characters.
+  // If the function succeeds, the return value is the number of characters stored in the buffer pointed to by lpBuffer, not including the terminating null character.
+  // Therefore, if the function succeeds, kBufferSize should be larger than char_count.
+  auto char_count = GetEnvironmentVariableA(var_name.c_str(), buffer.data(), kBufferSize);
+
+  if (kBufferSize > char_count) {
+    buffer.resize(char_count);
+    return buffer;
+  }
+
+  // Else either the call failed, or the buffer wasn't large enough.
+  // TODO: Understand the reason for failure by calling GetLastError().
+  // If it is due to the specified environment variable not being found in the environment block,
+  // GetLastError() returns ERROR_ENVVAR_NOT_FOUND.
+  // For now, we assume that the environment variable is not found.
+
+  return std::string();
+}
+
+/*
+Read logical processor info from the map.
+{-1,-1} stands for failure.
+*/
+ProcessorInfo WindowsEnv::GetProcessorAffinityMask(int global_processor_id) const {
+  if (global_processor_info_map_.count(global_processor_id)) {
+    return global_processor_info_map_.at(global_processor_id);
+  } else {
+    return {-1, -1};
+  }
+}
+
+WindowsEnv::WindowsEnv() {
+  l2_cache_size_ = 0;
+  InitializeCpuInfo();
+}
+
+/*
+Discover all cores in a windows system.
+Note - every "id" here, given it be group id, core id, or logical processor id, starts from 0. 
+*/ +void WindowsEnv::InitializeCpuInfo() { + DWORD returnLength = 0; + GetLogicalProcessorInformationEx(RelationProcessorCore, nullptr, &returnLength); + auto last_error = GetLastError(); + if (last_error != ERROR_INSUFFICIENT_BUFFER) { + const auto error_code = GetLastError(); + if (logging::LoggingManager::HasDefaultLogger()) { + LOGS_DEFAULT(ERROR) << "Failed to calculate byte size for saving cpu info on windows" + << ", error code: " << error_code + << ", error msg: " << std::system_category().message(error_code); + } + return; + } + + std::unique_ptr allocation = std::make_unique(returnLength); + SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX* processorInfos = reinterpret_cast(allocation.get()); + + if (!GetLogicalProcessorInformationEx(RelationProcessorCore, processorInfos, &returnLength)) { + const auto error_code = GetLastError(); + if (logging::LoggingManager::HasDefaultLogger()) { + LOGS_DEFAULT(ERROR) << "Failed to fetch cpu info on windows" + << ", error code: " << error_code + << ", error msg: " << std::system_category().message(error_code); + } + return; + } + + int core_id = 0; + int global_processor_id = 0; + const BYTE* iter = reinterpret_cast(processorInfos); + const BYTE* end = iter + returnLength; + std::stringstream log_stream; + + while (iter < end) { + auto processor_info = reinterpret_cast(iter); + auto size = processor_info->Size; + + // Discoverred a phyical core and it belongs exclusively to a single group + if (processor_info->Relationship == RelationProcessorCore && + processor_info->Processor.GroupCount == 1) { + log_stream << std::endl + << "core " << core_id + 1 << " consist of logical processors: "; + LogicalProcessors core_global_proc_ids; + constexpr KAFFINITY bit = 1; + constexpr int id_upper_bound = sizeof(KAFFINITY) * CHAR_BIT; + const auto& group_mask = processor_info->Processor.GroupMask[0]; + for (int logical_proessor_id = 0; logical_proessor_id < id_upper_bound; ++logical_proessor_id) { + if (group_mask.Mask & (bit << 
logical_proessor_id)) { + log_stream << global_processor_id + 1 << " "; + core_global_proc_ids.push_back(global_processor_id); + /* + * Build up a map between global processor id and local processor id. + * The map helps to bridge between ort API and windows affinity API - + * we need local processor id to build an affinity mask for a particular group. + */ + global_processor_info_map_.insert_or_assign(global_processor_id, + ProcessorInfo{static_cast(group_mask.Group), + logical_proessor_id}); + global_processor_id++; + } + } + cores_.push_back(std::move(core_global_proc_ids)); + core_id++; + } + iter += size; + } + + DWORD newLength = 0; + GetLogicalProcessorInformationEx(RelationCache, nullptr, &newLength); + last_error = GetLastError(); + if (last_error != ERROR_INSUFFICIENT_BUFFER) { + const auto error_code = GetLastError(); + if (logging::LoggingManager::HasDefaultLogger()) { + LOGS_DEFAULT(ERROR) << "Failed to calculate byte size for saving cpu info on windows" + << ", error code: " << error_code + << ", error msg: " << std::system_category().message(error_code); + } + return; + } + + if (newLength > returnLength) { + // Re-allocate + allocation = std::make_unique(newLength); + processorInfos = reinterpret_cast(allocation.get()); + } + + if (!GetLogicalProcessorInformationEx(RelationCache, processorInfos, &newLength)) { + const auto error_code = GetLastError(); + if (logging::LoggingManager::HasDefaultLogger()) { + LOGS_DEFAULT(ERROR) << "Failed to fetch cpu info on windows" + << ", error code: " << error_code + << ", error msg: " << std::system_category().message(error_code); + } + return; + } + + iter = reinterpret_cast(processorInfos); + end = iter + newLength; + + while (iter < end) { + auto processor_info = reinterpret_cast(iter); + auto size = processor_info->Size; + + if (processor_info->Relationship == RelationCache && + processor_info->Cache.Level == 2) { + // L2 cache + l2_cache_size_ = static_cast(processor_info->Cache.CacheSize); + break; + } + + 
iter += size; + } + + if (logging::LoggingManager::HasDefaultLogger()) { + LOGS_DEFAULT(VERBOSE) << "Found total " << cores_.size() << " core(s) from windows system:"; + LOGS_DEFAULT(VERBOSE) << log_stream.str(); + LOGS_DEFAULT(VERBOSE) << "\nDetected L2 cache size: " << l2_cache_size_ << " bytes"; + } +} } // namespace onnxruntime diff --git a/src/core/platform/windows/env.h b/src/core/platform/windows/env.h index 9e53a71..05b92bb 100644 --- a/src/core/platform/windows/env.h +++ b/src/core/platform/windows/env.h @@ -15,6 +15,7 @@ limitations under the License. // Portions Copyright (c) Microsoft Corporation #include "core/platform/env.h" +#include "core/platform/windows/telemetry.h" #include "core/common/inlined_containers.h" #include @@ -61,14 +62,29 @@ class WindowsEnv : public Env { common::Status GetFileLength(int fd, /*out*/ size_t& file_size) const override; Status ReadFileIntoBuffer(_In_z_ const ORTCHAR_T* const file_path, const FileOffsetType offset, const size_t length, const gsl::span buffer) const override; - - + Status MapFileIntoMemory(_In_z_ const ORTCHAR_T* file_path, + FileOffsetType offset, + size_t length, + MappedMemoryPtr& mapped_memory) const override; + bool FolderExists(const std::wstring& path) const override; + bool FolderExists(const std::string& path) const override; + bool FileExists(const std::wstring& path) const override; + bool FileExists(const std::string& path) const override; + common::Status CreateFolder(const std::wstring& path) const override; + common::Status CreateFolder(const std::string& path) const override; + common::Status DeleteFolder(const PathString& path) const override; + common::Status FileOpenRd(const std::wstring& path, /*out*/ int& fd) const override; + common::Status FileOpenWr(const std::wstring& path, /*out*/ int& fd) const override; + common::Status FileOpenRd(const std::string& path, /*out*/ int& fd) const override; + common::Status FileOpenWr(const std::string& path, /*out*/ int& fd) const override; + 
common::Status FileClose(int fd) const override; common::Status GetCanonicalPath(const PathString& path, PathString& canonical_path) const override; PathString GetRuntimePath() const override; Status LoadDynamicLibrary(const PathString& library_filename, bool /*global_symbols*/, void** handle) const override; Status UnloadDynamicLibrary(void* handle) const override; Status GetSymbolFromLibrary(void* handle, const std::string& symbol_name, void** symbol) const override; std::string FormatLibraryFileName(const std::string& name, const std::string& version) const override; + const Telemetry& GetTelemetryProvider() const override; std::string GetEnvironmentVar(const std::string& var_name) const override; ProcessorInfo GetProcessorAffinityMask(int global_processor_id) const; @@ -122,6 +138,7 @@ class WindowsEnv : public Env { private: void InitializeCpuInfo(); typedef VOID(WINAPI* FnGetSystemTimePreciseAsFileTime)(LPFILETIME); + WindowsTelemetry telemetry_provider_; }; } // namespace onnxruntime diff --git a/src/lib/CMakeLists.txt b/src/lib/CMakeLists.txt index 0942945..719423b 100644 --- a/src/lib/CMakeLists.txt +++ b/src/lib/CMakeLists.txt @@ -1,803 +1,885 @@ -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. - -set(MLAS_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..) 
-set(MLAS_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}) -set(MLAS_INC_DIR ${MLAS_ROOT}/../include) - -include_directories(${ONNXRUNTIME_INCLUDE_DIR}) - -#Set global compile flags for all the source code(including third_party code like protobuf) -#This section must be before any add_subdirectory, otherwise build may fail because /MD,/MT mismatch -if (MSVC) - if (CMAKE_VS_PLATFORM_NAME) - # Multi-platform generator - set(onnxruntime_target_platform ${CMAKE_VS_PLATFORM_NAME}) - else() - set(onnxruntime_target_platform ${CMAKE_SYSTEM_PROCESSOR}) - endif() - if (onnxruntime_target_platform STREQUAL "ARM64") - set(onnxruntime_target_platform "ARM64") - enable_language(ASM_MARMASM) - elseif (onnxruntime_target_platform STREQUAL "ARM64EC") - enable_language(ASM_MARMASM) - elseif (onnxruntime_target_platform STREQUAL "ARM" OR CMAKE_GENERATOR MATCHES "ARM") - set(onnxruntime_target_platform "ARM") - enable_language(ASM_MARMASM) - elseif (onnxruntime_target_platform STREQUAL "x64" OR onnxruntime_target_platform STREQUAL "x86_64" OR onnxruntime_target_platform STREQUAL "AMD64" OR CMAKE_GENERATOR MATCHES "Win64") - set(onnxruntime_target_platform "x64") - enable_language(ASM_MASM) - elseif (onnxruntime_target_platform STREQUAL "Win32" OR onnxruntime_target_platform STREQUAL "x86" OR onnxruntime_target_platform STREQUAL "i386" OR onnxruntime_target_platform STREQUAL "i686") - set(onnxruntime_target_platform "x86") - enable_language(ASM_MASM) - message("Enabling SAFESEH for x86 build") - set(CMAKE_ASM_MASM_FLAGS "${CMAKE_ASM_MASM_FLAGS} /safeseh") - else() - message(FATAL_ERROR "Unknown CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") - endif() -endif() - -# -# All hardware agnostic source files here -# hardware specific files would cause trouble in -# multi-target build -# -add_library(onnxruntime_mlas STATIC - ${MLAS_SRC_DIR}/mlasi.h - ${MLAS_SRC_DIR}/platform.cpp - ${MLAS_SRC_DIR}/threading.cpp - ${MLAS_SRC_DIR}/sgemm.cpp - ${MLAS_SRC_DIR}/halfgemm.cpp - ${MLAS_SRC_DIR}/qgemm.cpp - 
${MLAS_SRC_DIR}/qdwconv.cpp - ${MLAS_SRC_DIR}/convolve.cpp - ${MLAS_SRC_DIR}/convsym.cpp - ${MLAS_SRC_DIR}/pooling.cpp - ${MLAS_SRC_DIR}/transpose.cpp - ${MLAS_SRC_DIR}/reorder.cpp - ${MLAS_SRC_DIR}/snchwc.cpp - ${MLAS_SRC_DIR}/activate.cpp - ${MLAS_SRC_DIR}/logistic.cpp - ${MLAS_SRC_DIR}/tanh.cpp - ${MLAS_SRC_DIR}/erf.cpp - ${MLAS_SRC_DIR}/compute.cpp - ${MLAS_SRC_DIR}/quantize.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_default.cpp - ${MLAS_SRC_DIR}/qladd.cpp - ${MLAS_SRC_DIR}/qlmul.cpp - ${MLAS_SRC_DIR}/qpostprocessor.cpp - ${MLAS_SRC_DIR}/qlgavgpool.cpp - ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp - ${MLAS_SRC_DIR}/qnbitgemm.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h - ${MLAS_SRC_DIR}/flashattn.cpp - ${MLAS_SRC_DIR}/cast.cpp -) - -target_sources(onnxruntime_mlas PRIVATE - ${MLAS_INC_DIR}/mlas_float16.h - ${MLAS_INC_DIR}/mlas_gemm_postprocessor.h - ${MLAS_INC_DIR}/mlas_q4.h - ${MLAS_INC_DIR}/mlas_qnbit.h - ${MLAS_INC_DIR}/mlas.h -) - -if (NOT onnxruntime_ORT_MINIMAL_BUILD) - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/q4_dq.cpp - ${MLAS_SRC_DIR}/q4gemm.cpp - ) -endif() - - -#TODO: set MASM flags properly -function(setup_mlas_source_for_windows) - - # - # Sources common for all platforms. - # - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/activate_fp16.cpp - ${MLAS_SRC_DIR}/dwconv.cpp - ${MLAS_SRC_DIR}/pooling_fp16.cpp - ) - - #The onnxruntime_target_platform variable was added by Windows AI team in onnxruntime_common.cmake - #Don't use it for other platforms. 
- if((onnxruntime_target_platform STREQUAL "ARM64") OR (onnxruntime_target_platform STREQUAL "ARM64EC")) - set(PREPROCESS_ARMASM_FLAGS "") - set(ARMASM_FLAGS "") - - if(onnxruntime_target_platform STREQUAL "ARM64") - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp - ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h - ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp - ${MLAS_SRC_DIR}/fp16_neon_common.cpp - ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp - ) - - set(mlas_platform_preprocess_srcs - ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelDot.asm - ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelDotLd64.asm - ${MLAS_SRC_DIR}/arm64/ConvSymU8KernelDot.asm - ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/ConvSymU8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymS8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymU8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/DepthwiseQConvKernelSize9Neon.asm - ${MLAS_SRC_DIR}/arm64/HalfGemmKernelNeon.asm - ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelUdot.asm - ${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelSdot.asm - ${MLAS_SRC_DIR}/arm64/SgemmKernelNeon.asm - ${MLAS_SRC_DIR}/arm64/SgemvKernelNeon.asm - ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDot.asm - ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDotLd64.asm - ) - else() - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp - ) - - set(mlas_platform_preprocess_srcs - ${MLAS_SRC_DIR}/arm64ec/QgemmU8X8KernelNeon.asm - ${MLAS_SRC_DIR}/arm64ec/SgemmKernelNeon.asm - ) - - string(APPEND PREPROCESS_ARMASM_FLAGS " /arm64EC") - string(APPEND ARMASM_FLAGS " -machine ARM64EC") - endif() - - 
if(CMAKE_BUILD_TYPE STREQUAL "Debug") - string(APPEND ARMASM_FLAGS " -g") - endif() - - # Remove double quotes from flag strings. - separate_arguments(PREPROCESS_ARMASM_FLAGS NATIVE_COMMAND "${PREPROCESS_ARMASM_FLAGS}") - separate_arguments(ARMASM_FLAGS NATIVE_COMMAND "${ARMASM_FLAGS}") - - # Run the C precompiler on each input before the assembler. - foreach(asm_filename ${mlas_platform_preprocess_srcs}) - get_filename_component(asm_filename_base ${asm_filename} NAME_WLE) - set(preprocess_filename ${CMAKE_CURRENT_BINARY_DIR}/${asm_filename_base}.i) - set(obj_filename ${CMAKE_CURRENT_BINARY_DIR}/${asm_filename_base}.obj) - add_custom_command( - OUTPUT ${obj_filename} - COMMAND - cl.exe ${PREPROCESS_ARMASM_FLAGS} /P ${asm_filename} /Fi${preprocess_filename} - COMMAND - armasm64.exe ${ARMASM_FLAGS} ${preprocess_filename} ${obj_filename} - DEPENDS ${asm_filename} - BYPRODUCTS ${preprocess_filename} - ) - target_sources(onnxruntime_mlas PRIVATE ${obj_filename}) - endforeach() - elseif(onnxruntime_target_platform STREQUAL "ARM") - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/arm/sgemmc.cpp - ) - elseif(onnxruntime_target_platform STREQUAL "x64") - - file(GLOB_RECURSE mlas_platform_srcs_avx CONFIGURE_DEPENDS - "${MLAS_SRC_DIR}/intrinsics/avx/*.cpp" - ) - set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "/arch:AVX") - - file(GLOB_RECURSE mlas_platform_srcs_avx2 CONFIGURE_DEPENDS - "${MLAS_SRC_DIR}/intrinsics/avx2/*.cpp" - ) - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "/arch:AVX2") - - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/dgemm.cpp - ${mlas_platform_srcs_avx} - ${mlas_platform_srcs_avx2} - ${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp - ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp - 
${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp - ${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAmx.asm - ${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/QgemmU8U8KernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/QgemmU8X8KernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/QgemmU8X8KernelAvx512Core.asm - ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx512Core.asm - ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx512Vnni.asm - ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvxVnni.asm - ${MLAS_SRC_DIR}/amd64/ConvSymKernelAvx2.asm - ${MLAS_SRC_DIR}/amd64/ConvSymKernelAvx512Core.asm - ${MLAS_SRC_DIR}/amd64/DgemmKernelSse2.asm - ${MLAS_SRC_DIR}/amd64/DgemmKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/DgemmKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/DgemmKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelSse2.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelM1Avx.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/SgemmKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/SconvKernelSse2.asm - ${MLAS_SRC_DIR}/amd64/SconvKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/SconvKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/SconvKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/SpoolKernelSse2.asm - ${MLAS_SRC_DIR}/amd64/SpoolKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/SpoolKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/sgemma.asm - ${MLAS_SRC_DIR}/amd64/cvtfp16a.asm - ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx.asm - ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/TransKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/TransKernelAvx512F.asm - ${MLAS_SRC_DIR}/amd64/LogisticKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/TanhKernelFma3.asm - ${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm - ) - if(MSVC_VERSION GREATER_EQUAL 1933) - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/amd64/cvtfp16Avx.asm - ) - endif() - - if (NOT onnxruntime_ORT_MINIMAL_BUILD) - target_sources(onnxruntime_mlas PRIVATE - 
${MLAS_SRC_DIR}/q4gemm_avx512.cpp - ) - endif() - else() - target_sources(onnxruntime_mlas PRIVATE - ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp - ${MLAS_SRC_DIR}/i386/SgemmKernelSse2.asm - ${MLAS_SRC_DIR}/i386/SgemmKernelAvx.asm - ) - endif() -endfunction() - -if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - if (onnxruntime_ENABLE_WEBASSEMBLY_SIMD) - file(GLOB_RECURSE mlas_platform_srcs - "${MLAS_SRC_DIR}/wasm_simd/*.cpp" - ) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/qgemm_kernel_wasmsimd.cpp - ) - else() - file(GLOB_RECURSE mlas_platform_srcs - "${MLAS_SRC_DIR}/scalar/*.cpp" - ) - endif() - target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) -elseif(MSVC) - setup_mlas_source_for_windows() -else() - - if(APPLE) - get_target_property(ONNXRUNTIME_MLAS_OSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) - - if(NOT ONNXRUNTIME_MLAS_OSX_ARCH) - set(ONNXRUNTIME_MLAS_OSX_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) - endif() - foreach(OSX_ARCH ${ONNXRUNTIME_MLAS_OSX_ARCH}) - if (OSX_ARCH STREQUAL "arm64") - set(ARM64 TRUE) - elseif (OSX_ARCH STREQUAL "arm64e") - set(ARM64 TRUE) - elseif (OSX_ARCH STREQUAL "arm") - set(ARM TRUE) - elseif (OSX_ARCH STREQUAL "x86_64") - set(X86_64 TRUE) - elseif (OSX_ARCH STREQUAL "i386") - set(X86 TRUE) - endif() - endforeach() - elseif(ANDROID) - if (CMAKE_ANDROID_ARCH_ABI STREQUAL "armeabi-v7a") - set(ARM TRUE) - elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "arm64-v8a") - set(ARM64 TRUE) - elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86_64") - set(X86_64 TRUE) - elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86") - set(X86 TRUE) - endif() - else() - #Linux/FreeBSD/PowerPC/... 
- #The value of CMAKE_SYSTEM_PROCESSOR should be from `uname -m` - #Example values: - #arm64v8/ubuntu -> aarch64 - #arm32v6/alpine -> armv7l - #arm32v7/centos -> armv7l - #ppc64le/debian -> ppc64le - #s390x/ubuntu -> s390x - #ppc64le/busybox -> ppc64le - #arm64v8/ubuntu -> aarch64 - #Android: armv7-a aarch64 i686 x86_64 - #chasun: I don't think anyone uses 'arm64' - if(CMAKE_SYSTEM_PROCESSOR MATCHES "^arm64.*") - set(ARM64 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^arm.*") - set(ARM TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*") - set(ARM64 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(powerpc.*|ppc.*)") - set(POWER TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") - set(X86 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") - set(X86_64 TRUE) - elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^loongarch64.*") - set(LOONGARCH64 TRUE) - endif() - endif() - - if(APPLE) - get_target_property(ONNXRUNTIME_MLAS_MACOSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) - endif() - list(LENGTH ONNXRUNTIME_MLAS_MACOSX_ARCH ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH) - if(ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH GREATER 1) - set(ONNXRUNTIME_MLAS_MULTI_ARCH TRUE) - endif() - #If ONNXRUNTIME_MLAS_MULTI_ARCH is true, we need to go through every if branch below - #and split MLAS to multiple static libraries. - #Otherwise, it works like if(...) elseif(...) elseif(...) 
endif() - set(MLAS_SOURCE_IS_NOT_SET 1) - if(ARM) - enable_language(ASM) - - set(CMAKE_ASM_FLAGS "${CMAKE_ASM_FLAGS} -mfpu=neon") - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon") - - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/aarch32/QgemmU8X8KernelNeon.S - ${MLAS_SRC_DIR}/arm/sgemmc.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp - ) - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(ARM64 AND MLAS_SOURCE_IS_NOT_SET ) - enable_language(ASM) - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelDot.S - ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelDotLd64.S - ${MLAS_SRC_DIR}/aarch64/ConvSymU8KernelDot.S - ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/ConvSymU8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymS8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymU8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvKernelSize9Neon.S - ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUdot.S - ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSdot.S - ${MLAS_SRC_DIR}/aarch64/SgemmKernelNeon.S - ${MLAS_SRC_DIR}/aarch64/SgemvKernelNeon.S - ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S - ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdot.S - ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdotLd64.S - ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp - ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h - ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp - PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") - if (NOT APPLE) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S - ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S - 
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S - ${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S - ${MLAS_SRC_DIR}/activate_fp16.cpp - ${MLAS_SRC_DIR}/dwconv.cpp - ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/pooling_fp16.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp - ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp - ${MLAS_SRC_DIR}/fp16_neon_common.cpp - ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") - set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") - set_source_files_properties(${MLAS_SRC_DIR}/fp16_neon_common.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") - endif() - - if(ONNXRUNTIME_MLAS_MULTI_ARCH) - add_library(onnxruntime_mlas_arm64 STATIC ${mlas_platform_srcs}) - list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_arm64) - set_target_properties(onnxruntime_mlas_arm64 PROPERTIES OSX_ARCHITECTURES "arm64") - set(mlas_platform_srcs ) - else() - set(MLAS_SOURCE_IS_NOT_SET 0) - 
endif() - endif() - if(POWER AND MLAS_SOURCE_IS_NOT_SET) - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/power/SgemmKernelPower.cpp - ${MLAS_SRC_DIR}/dgemm.cpp - ${MLAS_SRC_DIR}/power/DgemmKernelPower.cpp - ${MLAS_SRC_DIR}/power/QuantizePower.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPower.cpp PROPERTIES COMPILE_FLAGS "-DSINGLE") - - check_cxx_compiler_flag("-mcpu=power9" HAS_POWER9) - if (HAS_POWER9) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/power/QuantizePowerVSX.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/power/QuantizePowerVSX.cpp PROPERTIES COMPILE_FLAGS "-mcpu=power9") - endif() - - check_cxx_compiler_flag("-mcpu=power10" HAS_POWER10) - if(HAS_POWER10) - set(CMAKE_REQUIRED_FLAGS "-mcpu=power10") - check_cxx_source_compiles(" - #include - int main() { - __vector_quad acc0; - __builtin_mma_xxsetaccz (&acc0); - return 0; - }" - COMPILES_P10 - ) - if(COMPILES_P10) - check_cxx_source_compiles(" - #ifdef _AIX - #define POWER_10 0x40000 - #define POWER_10_ANDUP (POWER_10) - #include - #define __power_10_andup() (_system_configuration.implementation & POWER_10_ANDUP) - int main() { - bool HasP10 = (__power_10_andup() && __power_mma_version() == MMA_V31); - return 0; - } - #else - #include - int main() { - unsigned long hwcap2 = getauxval(AT_HWCAP2); - bool HasP10 = ((hwcap2 & PPC_FEATURE2_MMA) && (hwcap2 & PPC_FEATURE2_ARCH_3_1)); - return 0; - } - #endif" - HAS_P10_RUNTIME - ) - if (HAS_P10_RUNTIME) - set_source_files_properties(${MLAS_SRC_DIR}/platform.cpp PROPERTIES COMPILE_FLAGS "-DPOWER10") - set_source_files_properties(${MLAS_SRC_DIR}/qgemm.cpp PROPERTIES COMPILE_FLAGS "-DPOWER10") - endif() - set(mlas_platform_srcs_power10 - ${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp - ${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp - ${MLAS_SRC_DIR}/power/qgemm_kernel_power10.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10 -DSINGLE") - 
set_source_files_properties(${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10") - set_source_files_properties(${MLAS_SRC_DIR}/power/qgemm_kernel_power10.cpp PROPERTIES COMPILE_FLAGS "-O3 -mcpu=power10") - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${mlas_platform_srcs_power10} - ) - endif() - endif() - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(X86 AND MLAS_SOURCE_IS_NOT_SET) - enable_language(ASM) - - set(mlas_platform_srcs_sse2 - ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp - ${MLAS_SRC_DIR}/x86/SgemmKernelSse2.S - ) - set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2") - - set(mlas_platform_srcs_avx - ${MLAS_SRC_DIR}/x86/SgemmKernelAvx.S - ) - set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx") - - set(mlas_platform_srcs - ${mlas_platform_srcs_sse2} - ${mlas_platform_srcs_avx} - ) - - # In r23, NDK remove __x86.get_pc_thunk.* from libatomic. Add our own - # implementation to avoid external dependency. - if(ANDROID) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/x86/x86.get_pc_thunk.S - ) - endif() - - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(X86_64 AND MLAS_SOURCE_IS_NOT_SET) - enable_language(ASM) - - # Forward the flags for the minimum target platform version from the C - # compiler to the assembler. This works around CMakeASMCompiler.cmake.in - # not including the logic to set this flag for the assembler. - set(CMAKE_ASM${ASM_DIALECT}_OSX_DEPLOYMENT_TARGET_FLAG "${CMAKE_C_OSX_DEPLOYMENT_TARGET_FLAG}") - - # The LLVM assembler does not support the .arch directive to enable instruction - # set extensions and also doesn't support AVX-512F instructions without - # turning on support via command-line option. Group the sources by the - # instruction set extension and explicitly set the compiler flag as appropriate. 
- - set(mlas_platform_srcs_sse2 - ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp - ${MLAS_SRC_DIR}/x86_64/DgemmKernelSse2.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelSse2.S - ${MLAS_SRC_DIR}/x86_64/SgemmTransposePackB16x4Sse2.S - ${MLAS_SRC_DIR}/x86_64/SconvKernelSse2.S - ${MLAS_SRC_DIR}/x86_64/SpoolKernelSse2.S - ) - if(NOT APPLE) - set(mlas_platform_srcs_sse2 - ${mlas_platform_srcs_sse2} - ${MLAS_SRC_DIR}/x86_64/cvtfp16a.S - ) - endif() - set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2") - - set(mlas_platform_srcs_avx - ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelM1Avx.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelM1TransposeBAvx.S - ${MLAS_SRC_DIR}/x86_64/SgemmTransposePackB16x4Avx.S - ${MLAS_SRC_DIR}/x86_64/SconvKernelAvx.S - ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx.S - ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx.S - ${MLAS_SRC_DIR}/intrinsics/avx/min_max_elements.cpp - ) - set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx") - - set(mlas_platform_srcs_avx2 - ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/QgemmU8U8KernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvxVnni.S - ${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx2.S - ${MLAS_SRC_DIR}/x86_64/DgemmKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/SconvKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/TransKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/LogisticKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/TanhKernelFma3.S - ${MLAS_SRC_DIR}/x86_64/ErfKernelFma3.S - ${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp - ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp - ) - if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE)) - set(mlas_platform_srcs_avx2 - ${mlas_platform_srcs_avx2} - ${MLAS_SRC_DIR}/x86_64/cvtfp16Avx.S - ) - 
endif() -message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") -message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}") - -if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "11") - message(STATUS "Using -mavx2 -mfma -mavxvnni flags") - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c -mavxvnni") -else() - message(STATUS "Using -mavx2 -mfma flags") - set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c") -endif() - set(mlas_platform_srcs_avx512f - ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/SconvKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx512F.S - ${MLAS_SRC_DIR}/x86_64/TransKernelAvx512F.S - ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp - ) - set_source_files_properties(${mlas_platform_srcs_avx512f} PROPERTIES COMPILE_FLAGS "-mavx512f") - - set(mlas_platform_srcs_avx512core - ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx512Core.S - ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx512Vnni.S - ${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx512Core.S - ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx512Core.S - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp - ) - set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl") - - set(mlas_platform_srcs_avx512vnni - ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp - ) - set_source_files_properties(${mlas_platform_srcs_avx512vnni} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f") - - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/activate_fp16.cpp - ${MLAS_SRC_DIR}/dwconv.cpp - ${MLAS_SRC_DIR}/dgemm.cpp - ${MLAS_SRC_DIR}/pooling_fp16.cpp - ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp - ${mlas_platform_srcs_sse2} - 
${mlas_platform_srcs_avx} - ${mlas_platform_srcs_avx2} - ${mlas_platform_srcs_avx512f} - ${mlas_platform_srcs_avx512core} - ${mlas_platform_srcs_avx512vnni} - ) - - if (NOT onnxruntime_ORT_MINIMAL_BUILD) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/q4gemm_avx512.cpp - ) - set_source_files_properties(${MLAS_SRC_DIR}/q4gemm_avx512.cpp PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f") - endif() - if(NOT APPLE) - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmxCommon.S - ${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp - ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S - ) - set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") - set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f") - endif() - - if(ONNXRUNTIME_MLAS_MULTI_ARCH) - add_library(onnxruntime_mlas_x86_64 STATIC ${mlas_platform_srcs}) - set_target_properties(onnxruntime_mlas_x86_64 PROPERTIES OSX_ARCHITECTURES "x86_64") - list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_x86_64) - set(mlas_platform_srcs ) - else() - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(LOONGARCH64 AND MLAS_SOURCE_IS_NOT_SET) - set(mlas_platform_srcs - ${MLAS_SRC_DIR}/qgemm_kernel_lsx.cpp - ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLasx.S - ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLsx.S - ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLasx.S - ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLsx.S - ${MLAS_SRC_DIR}/loongarch64/SconvKernelLsx.S - ${MLAS_SRC_DIR}/loongarch64/SconvKernelLasx.S - ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLSX.S - ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLasx.S - ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4LSX.S - ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4Lasx.S - ${MLAS_SRC_DIR}/loongarch64/SoftmaxKernelLasx.S - ) - set(CMAKE_CXX_FLAGS 
"${CMAKE_CXX_FLAGS} -mlsx -mlasx") - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) - set(MLAS_SOURCE_IS_NOT_SET 0) - endif() - endif() - if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH AND MLAS_SOURCE_IS_NOT_SET) - file(GLOB_RECURSE mlas_platform_srcs - "${MLAS_SRC_DIR}/scalar/*.cpp") - elseif (onnxruntime_FORCE_GENERIC_ALGORITHMS) - file(GLOB_RECURSE mlas_platform_srcs_generic - "${MLAS_SRC_DIR}/scalar/*.cpp") - set(mlas_platform_srcs - ${mlas_platform_srcs} - ${mlas_platform_srcs_generic} - ) - endif() - target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) -endif() - -foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) - target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_INCLUDE_DIR} ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) - target_link_libraries(${mlas_target} Microsoft.GSL::GSL) - - set_target_properties(${mlas_target} PROPERTIES FOLDER "ONNXRuntime") -endforeach() - -if (WIN32) - target_compile_options(onnxruntime_mlas PRIVATE "$<$:/wd6385>" "$<$:/wd4127>") - if (onnxruntime_ENABLE_STATIC_ANALYSIS) - target_compile_options(onnxruntime_mlas PRIVATE "$<$:/analyze:stacksize 131072>") - endif() -endif() - -if (PLATFORM_NAME STREQUAL "macabi") - # Needed for maccatalyst C compilation - # i.e. 
the flags below add "--target=x86_64-apple-ios14.0-macabi -ffunction-sections -fdata-sections" - target_compile_options(onnxruntime_mlas PRIVATE ${CMAKE_C_FLAGS}) -endif() - -if (NOT onnxruntime_BUILD_SHARED_LIB) - install(TARGETS onnxruntime_mlas - ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} - LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} - RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} - FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) -endif() - -# set up source group for MLAS source files -block() - set(source_group_srcs) - foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) - get_target_property(mlas_target_srcs ${mlas_target} SOURCES) - foreach(mlas_target_src ${mlas_target_srcs}) - cmake_path(IS_PREFIX MLAS_ROOT ${mlas_target_src} in_mlas_root) - if(in_mlas_root) - list(APPEND source_group_srcs ${mlas_target_src}) - endif() - endforeach() - endforeach() -endblock() - - - - # - # Command line tool for quantization and de-quantization of 2-D fp32 tensors - # based on block-wise quantization of int4 - # - - add_executable(onnxruntime_mlas_q4dq - ${MLAS_SRC_DIR}/q4_dq_cli.cpp - ) - target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) - set_target_properties(onnxruntime_mlas_q4dq PROPERTIES FOLDER "ONNXRuntimeTest") - - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_MLAS_LIBS}) - if(NOT MLAS_NO_ONNXRUNTIME) - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE onnxruntime_common) - endif() - if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE cpuinfo) - endif() - if (CMAKE_SYSTEM_NAME STREQUAL "Android") - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${android_shared_libs}) - endif() - - if(WIN32) - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE debug Dbghelp Advapi32) - endif() - if (onnxruntime_LINK_LIBATOMIC) - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE atomic) - endif() - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE 
Threads::Threads) - - if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) - set_target_properties(onnxruntime_mlas_q4dq PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1") - else() - set_target_properties(onnxruntime_mlas_q4dq PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1") - endif() - endif() - +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +set(MLAS_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..) +set(MLAS_SRC_DIR ${CMAKE_CURRENT_SOURCE_DIR}) +set(MLAS_INC_DIR ${MLAS_ROOT}/../include) + +include_directories(${ONNXRUNTIME_INCLUDE_DIR}) + +#Set global compile flags for all the source code(including third_party code like protobuf) +#This section must be before any add_subdirectory, otherwise build may fail because /MD,/MT mismatch +if (MSVC) + if (CMAKE_VS_PLATFORM_NAME) + # Multi-platform generator + set(onnxruntime_target_platform ${CMAKE_VS_PLATFORM_NAME}) + else() + set(onnxruntime_target_platform ${CMAKE_SYSTEM_PROCESSOR}) + endif() + if (onnxruntime_target_platform STREQUAL "ARM64") + set(onnxruntime_target_platform "ARM64") + enable_language(ASM_MARMASM) + elseif (onnxruntime_target_platform STREQUAL "ARM64EC") + enable_language(ASM_MARMASM) + elseif (onnxruntime_target_platform STREQUAL "ARM" OR CMAKE_GENERATOR MATCHES "ARM") + set(onnxruntime_target_platform "ARM") + enable_language(ASM_MARMASM) + elseif (onnxruntime_target_platform STREQUAL "x64" OR onnxruntime_target_platform STREQUAL "x86_64" OR onnxruntime_target_platform STREQUAL "AMD64" OR CMAKE_GENERATOR MATCHES "Win64") + set(onnxruntime_target_platform "x64") + enable_language(ASM_MASM) + elseif (onnxruntime_target_platform STREQUAL "Win32" OR onnxruntime_target_platform STREQUAL "x86" OR onnxruntime_target_platform STREQUAL "i386" OR onnxruntime_target_platform STREQUAL "i686") + set(onnxruntime_target_platform "x86") + enable_language(ASM_MASM) + message("Enabling SAFESEH for x86 
build") + set(CMAKE_ASM_MASM_FLAGS "${CMAKE_ASM_MASM_FLAGS} /safeseh") + else() + message(FATAL_ERROR "Unknown CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") + endif() +endif() + +# +# All hardware agnostic source files here +# hardware specific files would cause trouble in +# multi-target build +# +add_library(onnxruntime_mlas STATIC + ${MLAS_SRC_DIR}/mlasi.h + ${MLAS_SRC_DIR}/platform.cpp + ${MLAS_SRC_DIR}/threading.cpp + ${MLAS_SRC_DIR}/sgemm.cpp + ${MLAS_SRC_DIR}/halfgemm.cpp + ${MLAS_SRC_DIR}/qgemm.cpp + ${MLAS_SRC_DIR}/qdwconv.cpp + ${MLAS_SRC_DIR}/convolve.cpp + ${MLAS_SRC_DIR}/convsym.cpp + ${MLAS_SRC_DIR}/pooling.cpp + ${MLAS_SRC_DIR}/transpose.cpp + ${MLAS_SRC_DIR}/reorder.cpp + ${MLAS_SRC_DIR}/snchwc.cpp + ${MLAS_SRC_DIR}/activate.cpp + ${MLAS_SRC_DIR}/logistic.cpp + ${MLAS_SRC_DIR}/tanh.cpp + ${MLAS_SRC_DIR}/eltwise.h + ${MLAS_SRC_DIR}/eltwise.cpp + ${MLAS_SRC_DIR}/erf.cpp + ${MLAS_SRC_DIR}/compute.cpp + ${MLAS_SRC_DIR}/quantize.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_default.cpp + ${MLAS_SRC_DIR}/qladd.cpp + ${MLAS_SRC_DIR}/qlmul.cpp + ${MLAS_SRC_DIR}/qpostprocessor.cpp + ${MLAS_SRC_DIR}/qlgavgpool.cpp + ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp + ${MLAS_SRC_DIR}/qnbitgemm.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h + ${MLAS_SRC_DIR}/flashattn.cpp + ${MLAS_SRC_DIR}/cast.cpp + ${MLAS_SRC_DIR}/rotary_embedding.h + ${MLAS_SRC_DIR}/rotary_embedding.cpp + ${MLAS_SRC_DIR}/softmax.h + ${MLAS_SRC_DIR}/saturation_check.cpp +) + +target_sources(onnxruntime_mlas PRIVATE + ${MLAS_INC_DIR}/mlas_float16.h + ${MLAS_INC_DIR}/mlas_gemm_postprocessor.h + ${MLAS_INC_DIR}/mlas_q4.h + ${MLAS_INC_DIR}/mlas_qnbit.h + ${MLAS_INC_DIR}/mlas.h +) + +if (NOT onnxruntime_ORT_MINIMAL_BUILD) + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/q4_dq.cpp + ${MLAS_SRC_DIR}/q4gemm.cpp + ) +endif() + +#TODO: set MASM flags properly +function(setup_mlas_source_for_windows) + + # + # Sources common for all platforms. 
+ # + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/activate_fp16.cpp + ${MLAS_SRC_DIR}/dwconv.cpp + ${MLAS_SRC_DIR}/pooling_fp16.cpp + ) + + #The onnxruntime_target_platform variable was added by Windows AI team in onnxruntime_common.cmake + #Don't use it for other platforms. + if((onnxruntime_target_platform STREQUAL "ARM64") OR (onnxruntime_target_platform STREQUAL "ARM64EC")) + set(PREPROCESS_ARMASM_FLAGS "") + set(ARMASM_FLAGS "") + + if(onnxruntime_target_platform STREQUAL "ARM64") + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp + ${MLAS_SRC_DIR}/cast_kernel_neon.cpp + ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/softmax_kernel_neon.h + ${MLAS_SRC_DIR}/softmax_kernel_neon.cpp + ${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/eltwise_kernel_neon.h + ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp + ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp + ) + + set(mlas_platform_preprocess_srcs + ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelDot.asm + ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelDotLd64.asm + ${MLAS_SRC_DIR}/arm64/ConvSymU8KernelDot.asm + ${MLAS_SRC_DIR}/arm64/ConvSymS8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/ConvSymU8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymS8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/DepthwiseQConvSymU8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/DepthwiseQConvKernelSize9Neon.asm + ${MLAS_SRC_DIR}/arm64/HalfGemmKernelNeon.asm + 
${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/QgemmU8X8KernelUdot.asm + ${MLAS_SRC_DIR}/arm64/QgemmS8S8KernelSdot.asm + ${MLAS_SRC_DIR}/arm64/SgemmKernelNeon.asm + ${MLAS_SRC_DIR}/arm64/SgemvKernelNeon.asm + ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDot.asm + ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDotLd64.asm + ) + + if (onnxruntime_USE_KLEIDIAI) + setup_kleidiai() + endif() + else() + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp + ) + + set(mlas_platform_preprocess_srcs + ${MLAS_SRC_DIR}/arm64ec/QgemmU8X8KernelNeon.asm + ${MLAS_SRC_DIR}/arm64ec/SgemmKernelNeon.asm + ) + + string(APPEND PREPROCESS_ARMASM_FLAGS " /arm64EC") + string(APPEND ARMASM_FLAGS " -machine ARM64EC") + endif() + + if(CMAKE_BUILD_TYPE STREQUAL "Debug") + string(APPEND ARMASM_FLAGS " -g") + endif() + + # Remove double quotes from flag strings. + separate_arguments(PREPROCESS_ARMASM_FLAGS NATIVE_COMMAND "${PREPROCESS_ARMASM_FLAGS}") + separate_arguments(ARMASM_FLAGS NATIVE_COMMAND "${ARMASM_FLAGS}") + + # Run the C precompiler on each input before the assembler. 
+ foreach(asm_filename ${mlas_platform_preprocess_srcs}) + get_filename_component(asm_filename_base ${asm_filename} NAME_WLE) + set(preprocess_filename ${CMAKE_CURRENT_BINARY_DIR}/${asm_filename_base}.i) + set(obj_filename ${CMAKE_CURRENT_BINARY_DIR}/${asm_filename_base}.obj) + add_custom_command( + OUTPUT ${obj_filename} + COMMAND + cl.exe ${PREPROCESS_ARMASM_FLAGS} /P ${asm_filename} /Fi${preprocess_filename} + COMMAND + armasm64.exe ${ARMASM_FLAGS} ${preprocess_filename} ${obj_filename} + DEPENDS ${asm_filename} + BYPRODUCTS ${preprocess_filename} + ) + target_sources(onnxruntime_mlas PRIVATE ${obj_filename}) + endforeach() + elseif(onnxruntime_target_platform STREQUAL "ARM") + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/arm/sgemmc.cpp + ) + elseif(onnxruntime_target_platform STREQUAL "x64") + + file(GLOB_RECURSE mlas_platform_srcs_avx CONFIGURE_DEPENDS + "${MLAS_SRC_DIR}/intrinsics/avx/*.cpp" + ) + set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "/arch:AVX") + + file(GLOB_RECURSE mlas_platform_srcs_avx2 CONFIGURE_DEPENDS + "${MLAS_SRC_DIR}/intrinsics/avx2/*.cpp" + ) + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "/arch:AVX2") + + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/dgemm.cpp + ${mlas_platform_srcs_avx} + ${mlas_platform_srcs_avx2} + ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.h + ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp + ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp + ${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAmx.asm + ${MLAS_SRC_DIR}/amd64/QgemmU8S8KernelAvx2.asm + 
${MLAS_SRC_DIR}/amd64/QgemmU8U8KernelAvx2.asm + ${MLAS_SRC_DIR}/amd64/QgemmU8X8KernelAvx2.asm + ${MLAS_SRC_DIR}/amd64/QgemmU8X8KernelAvx512Core.asm + ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx2.asm + ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx512Core.asm + ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvx512Vnni.asm + ${MLAS_SRC_DIR}/amd64/QgemvU8S8KernelAvxVnni.asm + ${MLAS_SRC_DIR}/amd64/ConvSymKernelAvx2.asm + ${MLAS_SRC_DIR}/amd64/ConvSymKernelAvx512Core.asm + ${MLAS_SRC_DIR}/amd64/DgemmKernelSse2.asm + ${MLAS_SRC_DIR}/amd64/DgemmKernelAvx.asm + ${MLAS_SRC_DIR}/amd64/DgemmKernelFma3.asm + ${MLAS_SRC_DIR}/amd64/DgemmKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/SgemmKernelSse2.asm + ${MLAS_SRC_DIR}/amd64/SgemmKernelAvx.asm + ${MLAS_SRC_DIR}/amd64/SgemmKernelM1Avx.asm + ${MLAS_SRC_DIR}/amd64/SgemmKernelFma3.asm + ${MLAS_SRC_DIR}/amd64/SgemmKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/SconvKernelSse2.asm + ${MLAS_SRC_DIR}/amd64/SconvKernelAvx.asm + ${MLAS_SRC_DIR}/amd64/SconvKernelFma3.asm + ${MLAS_SRC_DIR}/amd64/SconvKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/SpoolKernelSse2.asm + ${MLAS_SRC_DIR}/amd64/SpoolKernelAvx.asm + ${MLAS_SRC_DIR}/amd64/SpoolKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/sgemma.asm + ${MLAS_SRC_DIR}/amd64/cvtfp16a.asm + ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx.asm + ${MLAS_SRC_DIR}/amd64/SoftmaxKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/TransKernelFma3.asm + ${MLAS_SRC_DIR}/amd64/TransKernelAvx512F.asm + ${MLAS_SRC_DIR}/amd64/LogisticKernelFma3.asm + ${MLAS_SRC_DIR}/amd64/TanhKernelFma3.asm + ${MLAS_SRC_DIR}/amd64/ErfKernelFma3.asm + ) + + if(onnxruntime_ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER) + set_source_files_properties(${MLAS_SRC_DIR}/amd64/ConvSymKernelAvx2.asm PROPERTIES COMPILE_FLAGS "-DENABLE_CONVSYMKERNELAVX2_SAT_CHECKER") + endif() + + if(MSVC_VERSION GREATER_EQUAL 1933) + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/amd64/cvtfp16Avx.asm + ) + endif() + + if (NOT onnxruntime_ORT_MINIMAL_BUILD) + target_sources(onnxruntime_mlas PRIVATE + 
${MLAS_SRC_DIR}/q4gemm_avx512.cpp + ) + endif() + else() + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp + ${MLAS_SRC_DIR}/i386/SgemmKernelSse2.asm + ${MLAS_SRC_DIR}/i386/SgemmKernelAvx.asm + ) + endif() +endfunction() + +function(setup_kleidiai) + target_compile_definitions(onnxruntime_mlas PRIVATE USE_KLEIDIAI) + + # Disable the KleidiAI tests + set(KLEIDIAI_BUILD_TESTS OFF) + + # Fetch KleidiAI sources: + if (NOT TARGET kleidiai) + onnxruntime_fetchcontent_declare(kleidiai URL ${DEP_URL_kleidiai} URL_HASH SHA1=${DEP_SHA1_kleidiai} EXCLUDE_FROM_ALL) + endif() + onnxruntime_fetchcontent_makeavailable(kleidiai) + + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/kai_ukernel_interface.cpp + ) + target_link_libraries(onnxruntime_mlas PRIVATE kleidiai) +endfunction() + +if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + if (onnxruntime_ENABLE_WEBASSEMBLY_SIMD) + file(GLOB_RECURSE mlas_platform_srcs + "${MLAS_SRC_DIR}/wasm_simd/*.cpp" + ) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/qgemm_kernel_wasmsimd.cpp + ) + if (onnxruntime_ENABLE_WEBASSEMBLY_RELAXED_SIMD) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/qgemm_kernel_wasmrelaxedsimd.cpp + ) + endif() + else() + file(GLOB_RECURSE mlas_platform_srcs + "${MLAS_SRC_DIR}/scalar/*.cpp" + ) + endif() + target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) +elseif(MSVC) + setup_mlas_source_for_windows() +else() + + if(APPLE) + get_target_property(ONNXRUNTIME_MLAS_OSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) + + if(NOT ONNXRUNTIME_MLAS_OSX_ARCH) + set(ONNXRUNTIME_MLAS_OSX_ARCH ${CMAKE_HOST_SYSTEM_PROCESSOR}) + endif() + foreach(OSX_ARCH ${ONNXRUNTIME_MLAS_OSX_ARCH}) + if (OSX_ARCH STREQUAL "arm64") + set(ARM64 TRUE) + elseif (OSX_ARCH STREQUAL "arm64e") + set(ARM64 TRUE) + elseif (OSX_ARCH STREQUAL "arm") + set(ARM TRUE) + elseif (OSX_ARCH STREQUAL "x86_64") + set(X86_64 TRUE) + elseif (OSX_ARCH 
STREQUAL "i386") + set(X86 TRUE) + endif() + endforeach() + elseif(ANDROID) + if (CMAKE_ANDROID_ARCH_ABI STREQUAL "armeabi-v7a") + set(ARM TRUE) + elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "arm64-v8a") + set(ARM64 TRUE) + elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86_64") + set(X86_64 TRUE) + elseif (CMAKE_ANDROID_ARCH_ABI STREQUAL "x86") + set(X86 TRUE) + endif() + else() + #Linux/FreeBSD/PowerPC/... + #The value of CMAKE_SYSTEM_PROCESSOR should be from `uname -m` + #Example values: + #arm64v8/ubuntu -> aarch64 + #arm32v6/alpine -> armv7l + #arm32v7/centos -> armv7l + #ppc64le/debian -> ppc64le + #s390x/ubuntu -> s390x + #ppc64le/busybox -> ppc64le + #arm64v8/ubuntu -> aarch64 + #Android: armv7-a aarch64 i686 x86_64 + #chasun: I don't think anyone uses 'arm64' + if(CMAKE_SYSTEM_PROCESSOR MATCHES "^arm64.*") + set(ARM64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^arm.*") + set(ARM TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*") + set(ARM64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(powerpc.*|ppc.*)") + set(POWER TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$") + set(X86 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$") + set(X86_64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^loongarch64.*") + set(LOONGARCH64 TRUE) + endif() + endif() + + if(APPLE) + get_target_property(ONNXRUNTIME_MLAS_MACOSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) + endif() + list(LENGTH ONNXRUNTIME_MLAS_MACOSX_ARCH ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH) + if(ONNXRUNTIME_MLAS_MACOSX_ARCH_LENGTH GREATER 1) + set(ONNXRUNTIME_MLAS_MULTI_ARCH TRUE) + endif() + #If ONNXRUNTIME_MLAS_MULTI_ARCH is true, we need to go through every if branch below + #and split MLAS to multiple static libraries. + #Otherwise, it works like if(...) elseif(...) elseif(...) 
endif() + set(MLAS_SOURCE_IS_NOT_SET 1) + if(ARM) + enable_language(ASM) + + set(CMAKE_ASM_FLAGS "${CMAKE_ASM_FLAGS} -mfpu=neon") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon") + + set(mlas_platform_srcs + ${MLAS_SRC_DIR}/aarch32/QgemmU8X8KernelNeon.S + ${MLAS_SRC_DIR}/arm/sgemmc.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp + ) + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() + if(ARM64 AND MLAS_SOURCE_IS_NOT_SET ) + enable_language(ASM) + set(mlas_platform_srcs + ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelDot.S + ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelDotLd64.S + ${MLAS_SRC_DIR}/aarch64/ConvSymU8KernelDot.S + ${MLAS_SRC_DIR}/aarch64/ConvSymS8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/ConvSymU8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymS8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvSymU8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/DepthwiseQConvKernelSize9Neon.S + ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUdot.S + ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSdot.S + ${MLAS_SRC_DIR}/aarch64/SgemmKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/SgemvKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelNeon.S + ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdot.S + ${MLAS_SRC_DIR}/aarch64/SymQgemmS8KernelSdotLd64.S + ${MLAS_SRC_DIR}/qgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_udot.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_sdot.cpp + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.h + ${MLAS_SRC_DIR}/qnbitgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_fp32.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.h + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon.cpp + ${MLAS_SRC_DIR}/hgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/softmax_kernel_neon.h + ${MLAS_SRC_DIR}/softmax_kernel_neon.cpp + ${MLAS_SRC_DIR}/eltwise_kernel_neon.h + ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp + ) + if 
(onnxruntime_USE_KLEIDIAI) + setup_kleidiai() + endif() + set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp + PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") + if (NOT APPLE) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S + ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S + ${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S + ${MLAS_SRC_DIR}/activate_fp16.cpp + ${MLAS_SRC_DIR}/dwconv.cpp + ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/pooling_fp16.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp + ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp + ${MLAS_SRC_DIR}/cast_kernel_neon.cpp + ${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp + ) + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") + 
set_source_files_properties(${MLAS_SRC_DIR}/cast_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/hqnbitgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/rotary_embedding_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/halfgemm_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/softmax_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ endif()
+
+ if(ONNXRUNTIME_MLAS_MULTI_ARCH)
+ add_library(onnxruntime_mlas_arm64 STATIC ${mlas_platform_srcs})
+ set_target_properties(onnxruntime_mlas_arm64 PROPERTIES OSX_ARCHITECTURES "arm64")
+ list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_arm64)
+ set(mlas_platform_srcs )
+ else()
+ set(MLAS_SOURCE_IS_NOT_SET 0)
+ endif()
+ endif()
+ if(POWER AND MLAS_SOURCE_IS_NOT_SET)
+ set(mlas_platform_srcs
+ ${MLAS_SRC_DIR}/power/SgemmKernelPower.cpp
+ ${MLAS_SRC_DIR}/dgemm.cpp
+ ${MLAS_SRC_DIR}/power/DgemmKernelPower.cpp
+ ${MLAS_SRC_DIR}/power/QuantizePower.cpp
+ )
+ set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPower.cpp PROPERTIES COMPILE_FLAGS "-DSINGLE")
+
+ check_cxx_compiler_flag("-mcpu=power9" HAS_POWER9)
+ if (HAS_POWER9)
+ set(mlas_platform_srcs
+ ${mlas_platform_srcs}
+ ${MLAS_SRC_DIR}/power/QuantizePowerVSX.cpp
+ )
+ set_source_files_properties(${MLAS_SRC_DIR}/power/QuantizePowerVSX.cpp PROPERTIES COMPILE_FLAGS "-mcpu=power9")
+ endif()
+
+ check_cxx_compiler_flag("-mcpu=power10" HAS_POWER10)
+ if(HAS_POWER10)
+ set(CMAKE_REQUIRED_FLAGS "-mcpu=power10")
+ check_cxx_source_compiles("
+ #include <altivec.h>
+ int main() {
+ __vector_quad acc0;
+ __builtin_mma_xxsetaccz (&acc0);
+ return 0;
+ }"
+ 
COMPILES_P10
+ )
+ if(COMPILES_P10)
+ check_cxx_source_compiles("
+ #ifdef _AIX
+ #define POWER_10 0x40000
+ #define POWER_10_ANDUP (POWER_10)
+ #include <sys/systemcfg.h>
+ #define __power_10_andup() (_system_configuration.implementation & POWER_10_ANDUP)
+ int main() {
+ bool HasP10 = (__power_10_andup() && __power_mma_version() == MMA_V31);
+ return 0;
+ }
+ #else
+ #include <sys/auxv.h>
+ int main() {
+ unsigned long hwcap2 = getauxval(AT_HWCAP2);
+ bool HasP10 = ((hwcap2 & PPC_FEATURE2_MMA) && (hwcap2 & PPC_FEATURE2_ARCH_3_1));
+ return 0;
+ }
+ #endif"
+ HAS_P10_RUNTIME
+ )
+ if (HAS_P10_RUNTIME)
+ set_source_files_properties(${MLAS_SRC_DIR}/platform.cpp PROPERTIES COMPILE_FLAGS "-DPOWER10")
+ set_source_files_properties(${MLAS_SRC_DIR}/qgemm.cpp PROPERTIES COMPILE_FLAGS "-DPOWER10")
+ endif()
+ set(mlas_platform_srcs_power10
+ ${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp
+ ${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp
+ ${MLAS_SRC_DIR}/power/qgemm_kernel_power10.cpp
+ )
+ set_source_files_properties(${MLAS_SRC_DIR}/power/SgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10 -DSINGLE")
+ set_source_files_properties(${MLAS_SRC_DIR}/power/DgemmKernelPOWER10.cpp PROPERTIES COMPILE_FLAGS "-O2 -mcpu=power10")
+ set_source_files_properties(${MLAS_SRC_DIR}/power/qgemm_kernel_power10.cpp PROPERTIES COMPILE_FLAGS "-O3 -mcpu=power10")
+ set(mlas_platform_srcs
+ ${mlas_platform_srcs}
+ ${mlas_platform_srcs_power10}
+ )
+ endif()
+ endif()
+ if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH)
+ set(MLAS_SOURCE_IS_NOT_SET 0)
+ endif()
+ endif()
+ if(X86 AND MLAS_SOURCE_IS_NOT_SET)
+ enable_language(ASM)
+
+ set(mlas_platform_srcs_sse2
+ ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp
+ ${MLAS_SRC_DIR}/x86/SgemmKernelSse2.S
+ )
+ set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2")
+
+ set(mlas_platform_srcs_avx
+ ${MLAS_SRC_DIR}/x86/SgemmKernelAvx.S
+ )
+ set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx")
+
+ set(mlas_platform_srcs
+ 
${mlas_platform_srcs_sse2} + ${mlas_platform_srcs_avx} + ) + + # In r23, NDK remove __x86.get_pc_thunk.* from libatomic. Add our own + # implementation to avoid external dependency. + if(ANDROID) + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${MLAS_SRC_DIR}/x86/x86.get_pc_thunk.S + ) + endif() + + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() + if(X86_64 AND MLAS_SOURCE_IS_NOT_SET) + enable_language(ASM) + + # Forward the flags for the minimum target platform version from the C + # compiler to the assembler. This works around CMakeASMCompiler.cmake.in + # not including the logic to set this flag for the assembler. + set(CMAKE_ASM${ASM_DIALECT}_OSX_DEPLOYMENT_TARGET_FLAG "${CMAKE_C_OSX_DEPLOYMENT_TARGET_FLAG}") + + # The LLVM assembler does not support the .arch directive to enable instruction + # set extensions and also doesn't support AVX-512F instructions without + # turning on support via command-line option. Group the sources by the + # instruction set extension and explicitly set the compiler flag as appropriate. 
+ + set(mlas_platform_srcs_sse2 + ${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp + ${MLAS_SRC_DIR}/x86_64/DgemmKernelSse2.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelSse2.S + ${MLAS_SRC_DIR}/x86_64/SgemmTransposePackB16x4Sse2.S + ${MLAS_SRC_DIR}/x86_64/SconvKernelSse2.S + ${MLAS_SRC_DIR}/x86_64/SpoolKernelSse2.S + ) + if(NOT APPLE) + set(mlas_platform_srcs_sse2 + ${mlas_platform_srcs_sse2} + ${MLAS_SRC_DIR}/x86_64/cvtfp16a.S + ) + endif() + set_source_files_properties(${mlas_platform_srcs_sse2} PROPERTIES COMPILE_FLAGS "-msse2") + + set(mlas_platform_srcs_avx + ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelM1Avx.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelM1TransposeBAvx.S + ${MLAS_SRC_DIR}/x86_64/SgemmTransposePackB16x4Avx.S + ${MLAS_SRC_DIR}/x86_64/SconvKernelAvx.S + ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx.S + ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx.S + ${MLAS_SRC_DIR}/intrinsics/avx/min_max_elements.cpp + ) + set_source_files_properties(${mlas_platform_srcs_avx} PROPERTIES COMPILE_FLAGS "-mavx") + + set(mlas_platform_srcs_avx2 + ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAvx2.S + ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx2.S + ${MLAS_SRC_DIR}/x86_64/QgemmU8U8KernelAvx2.S + ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvxVnni.S + ${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx2.S + ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx2.S + ${MLAS_SRC_DIR}/x86_64/DgemmKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/SconvKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/TransKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/LogisticKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/TanhKernelFma3.S + ${MLAS_SRC_DIR}/x86_64/ErfKernelFma3.S + ${MLAS_SRC_DIR}/intrinsics/avx2/qladd_avx2.cpp + ${MLAS_SRC_DIR}/intrinsics/avx2/qdwconv_avx2.cpp + ${MLAS_SRC_DIR}/intrinsics/avx2/saturation_check_avx2.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx2.cpp + ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.h + ${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp + 
${MLAS_SRC_DIR}/rotary_embedding_kernel_avx2.cpp + ) + if(CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 13.1 AND NOT(APPLE)) + set(mlas_platform_srcs_avx2 + ${mlas_platform_srcs_avx2} + ${MLAS_SRC_DIR}/x86_64/cvtfp16Avx.S + ) + endif() + +message(STATUS "CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}") +message(STATUS "CMAKE_CXX_COMPILER_VERSION: ${CMAKE_CXX_COMPILER_VERSION}") + +if(NOT "${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" OR CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "11") + message(STATUS "Using -mavx2 -mfma -mavxvnni flags") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c -mavxvnni") +else() + message(STATUS "Using -mavx2 -mfma flags") + set_source_files_properties(${mlas_platform_srcs_avx2} PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c") +endif() + set(mlas_platform_srcs_avx512f + ${MLAS_SRC_DIR}/x86_64/DgemmKernelAvx512F.S + ${MLAS_SRC_DIR}/x86_64/SgemmKernelAvx512F.S + ${MLAS_SRC_DIR}/x86_64/SconvKernelAvx512F.S + ${MLAS_SRC_DIR}/x86_64/SoftmaxKernelAvx512F.S + ${MLAS_SRC_DIR}/x86_64/SpoolKernelAvx512F.S + ${MLAS_SRC_DIR}/x86_64/TransKernelAvx512F.S + ${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp + ) + set_source_files_properties(${mlas_platform_srcs_avx512f} PROPERTIES COMPILE_FLAGS "-mavx512f") + + set(mlas_platform_srcs_avx512core + ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx512Core.S + ${MLAS_SRC_DIR}/x86_64/QgemvU8S8KernelAvx512Vnni.S + ${MLAS_SRC_DIR}/x86_64/QgemmU8X8KernelAvx512Core.S + ${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx512Core.S + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512.cpp + ) + set_source_files_properties(${mlas_platform_srcs_avx512core} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl") + + set(mlas_platform_srcs_avx512vnni + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_avx512vnni.cpp + ) + set_source_files_properties(${mlas_platform_srcs_avx512vnni} PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f") + + 
set(mlas_platform_srcs
+    ${MLAS_SRC_DIR}/activate_fp16.cpp
+    ${MLAS_SRC_DIR}/dwconv.cpp
+    ${MLAS_SRC_DIR}/dgemm.cpp
+    ${MLAS_SRC_DIR}/pooling_fp16.cpp
+    ${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp
+    ${mlas_platform_srcs_sse2}
+    ${mlas_platform_srcs_avx}
+    ${mlas_platform_srcs_avx2}
+    ${mlas_platform_srcs_avx512f}
+    ${mlas_platform_srcs_avx512core}
+    ${mlas_platform_srcs_avx512vnni}
+  )
+
+  if (NOT onnxruntime_ORT_MINIMAL_BUILD)
+    set(mlas_platform_srcs
+      ${mlas_platform_srcs}
+      ${MLAS_SRC_DIR}/q4gemm_avx512.cpp
+    )
+    set_source_files_properties(${MLAS_SRC_DIR}/q4gemm_avx512.cpp PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f")
+  endif()
+  if(NOT APPLE)
+    set(mlas_platform_srcs
+      ${mlas_platform_srcs}
+      ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmxCommon.S
+      ${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp
+      ${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S
+    )
+    set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
+    set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
+  endif()
+
+  if(onnxruntime_ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER)
+    set_source_files_properties(${MLAS_SRC_DIR}/x86_64/ConvSymKernelAvx2.S PROPERTIES COMPILE_FLAGS "-mavx2 -mfma -mf16c -DENABLE_CONVSYMKERNELAVX2_SAT_CHECKER")
+  endif()
+
+  if(ONNXRUNTIME_MLAS_MULTI_ARCH)
+    add_library(onnxruntime_mlas_x86_64 STATIC ${mlas_platform_srcs})
+    set_target_properties(onnxruntime_mlas_x86_64 PROPERTIES OSX_ARCHITECTURES "x86_64")
+    list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_x86_64)
+    set(mlas_platform_srcs )
+  else()
+    set(MLAS_SOURCE_IS_NOT_SET 0)
+  endif()
+endif()
+if(LOONGARCH64 AND MLAS_SOURCE_IS_NOT_SET)
+  set(mlas_platform_srcs
+    ${MLAS_SRC_DIR}/qgemm_kernel_lsx.cpp
+    ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLasx.S
+    ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLsx.S
+    
${MLAS_SRC_DIR}/loongarch64/DgemmKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLsx.S + ${MLAS_SRC_DIR}/loongarch64/SconvKernelLsx.S + ${MLAS_SRC_DIR}/loongarch64/SconvKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLSX.S + ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLasx.S + ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4LSX.S + ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4Lasx.S + ${MLAS_SRC_DIR}/loongarch64/SoftmaxKernelLasx.S + ) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mlsx -mlasx") + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH AND MLAS_SOURCE_IS_NOT_SET) + file(GLOB_RECURSE mlas_platform_srcs + "${MLAS_SRC_DIR}/scalar/*.cpp") + elseif (onnxruntime_FORCE_GENERIC_ALGORITHMS) + file(GLOB_RECURSE mlas_platform_srcs_generic + "${MLAS_SRC_DIR}/scalar/*.cpp") + set(mlas_platform_srcs + ${mlas_platform_srcs} + ${mlas_platform_srcs_generic} + ) + endif() + target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs}) +endif() + +foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) + target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_INCLUDE_DIR} ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) + target_link_libraries(${mlas_target} Microsoft.GSL::GSL) + + set_target_properties(${mlas_target} PROPERTIES FOLDER "ONNXRuntime") +endforeach() + +if (WIN32) + target_compile_options(onnxruntime_mlas PRIVATE "$<$:/wd6385>" "$<$:/wd4127>") + if (onnxruntime_ENABLE_STATIC_ANALYSIS) + target_compile_options(onnxruntime_mlas PRIVATE "$<$:/analyze:stacksize 131072>") + endif() +endif() + +if (PLATFORM_NAME STREQUAL "macabi") + # Needed for maccatalyst C compilation + # i.e. 
the flags below add "--target=x86_64-apple-ios14.0-macabi -ffunction-sections -fdata-sections" + target_compile_options(onnxruntime_mlas PRIVATE ${CMAKE_C_FLAGS}) +endif() + +if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS onnxruntime_mlas EXPORT ${PROJECT_NAME}Targets + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) +endif() + +# set up source group for MLAS source files +block() + set(source_group_srcs) + foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) + get_target_property(mlas_target_srcs ${mlas_target} SOURCES) + foreach(mlas_target_src ${mlas_target_srcs}) + cmake_path(IS_PREFIX MLAS_ROOT ${mlas_target_src} in_mlas_root) + if(in_mlas_root) + list(APPEND source_group_srcs ${mlas_target_src}) + endif() + endforeach() + endforeach() +endblock() + + + + + # + # Command line tool for quantization and de-quantization of 2-D fp32 tensors + # based on block-wise quantization of int4 + # + + add_executable(onnxruntime_mlas_q4dq + ${MLAS_SRC_DIR}/q4_dq_cli.cpp + ) + target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) + set_target_properties(onnxruntime_mlas_q4dq PROPERTIES FOLDER "ONNXRuntimeTest") + + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) + if(NOT MLAS_NO_ONNXRUNTIME) + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE onnxruntime_common) + endif() + if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE cpuinfo) + endif() + if(NOT WIN32) + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${CMAKE_DL_LIBS}) + endif() + if (CMAKE_SYSTEM_NAME STREQUAL "Android") + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${android_shared_libs}) + endif() + + if(WIN32) + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE debug Dbghelp Advapi32) + endif() + if 
(onnxruntime_LINK_LIBATOMIC) + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE atomic) + endif() + target_link_libraries(onnxruntime_mlas_q4dq PRIVATE Threads::Threads) + + if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") + if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) + set_target_properties(onnxruntime_mlas_q4dq PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1") + else() + set_target_properties(onnxruntime_mlas_q4dq PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1") + endif() + endif() diff --git a/src/lib/activate_fp16.cpp b/src/lib/activate_fp16.cpp index 776ec67..564c6a4 100644 --- a/src/lib/activate_fp16.cpp +++ b/src/lib/activate_fp16.cpp @@ -51,12 +51,12 @@ struct MLAS_HALF_ACTIVATION_FUNCTION MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value) { - return MlasMaximumFloat16x8(ZeroVec, Value); + return MlasMaximumFloat16(ZeroVec, Value); } MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value) { - return MlasMaximumFloat16x4(MlasToLowHalfFloat16x4(ZeroVec), Value); + return MlasMaximumFloat16(MlasToLowHalfFloat16x4(ZeroVec), Value); } }; @@ -75,7 +75,7 @@ struct MLAS_HALF_ACTIVATION_FUNCTION MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value) { - MLAS_FLOAT16X8 ValueTimesAlpha = MlasMultiplyFloat16x8(Value, AlphaBroadcast); + MLAS_FLOAT16X8 ValueTimesAlpha = MlasMultiplyFloat16(Value, AlphaBroadcast); return MlasBitwiseSelectFloat16x8(MlasCmpLessEqualFloat16x8(Value, ZeroVec), ValueTimesAlpha, Value); } @@ -83,7 +83,7 @@ struct MLAS_HALF_ACTIVATION_FUNCTION MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value) { MLAS_FLOAT16X4 ValueTimesAlpha = - MlasMultiplyFloat16x4(Value, MlasToLowHalfFloat16x4(AlphaBroadcast)); + MlasMultiplyFloat16(Value, MlasToLowHalfFloat16x4(AlphaBroadcast)); return MlasBitwiseSelectFloat16x4( MlasCmpLessEqualFloat16x4(Value, MlasToLowHalfFloat16x4(ZeroVec)), ValueTimesAlpha, Value); @@ -539,16 +539,16 @@ struct MLAS_HALF_ACTIVATION_FUNCTION { MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value) { - Value = 
MlasMaximumFloat16x8(MinimumBroadcast, Value); - Value = MlasMinimumFloat16x8(MaximumBroadcast, Value); + Value = MlasMaximumFloat16(MinimumBroadcast, Value); + Value = MlasMinimumFloat16(MaximumBroadcast, Value); return Value; } MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value) { - Value = MlasMaximumFloat16x4(MlasToLowHalfFloat16x4(MinimumBroadcast), Value); - Value = MlasMinimumFloat16x4(MlasToLowHalfFloat16x4(MaximumBroadcast), Value); + Value = MlasMaximumFloat16(MlasToLowHalfFloat16x4(MinimumBroadcast), Value); + Value = MlasMinimumFloat16(MlasToLowHalfFloat16x4(MaximumBroadcast), Value); return Value; } }; @@ -573,19 +573,19 @@ struct MLAS_HALF_ACTIVATION_FUNCTION MLAS_FLOAT16X8 Activate(MLAS_FLOAT16X8 Value) { - Value = MlasMultiplyAddFloat16x8(Value, AlphaBroadcast, BetaBroadcast); - Value = MlasMinimumFloat16x8(MaximumBroadcast, Value); - Value = MlasMaximumFloat16x8(MinimumBroadcast, Value); + Value = MlasMultiplyAddFloat16(Value, AlphaBroadcast, BetaBroadcast); + Value = MlasMinimumFloat16(MaximumBroadcast, Value); + Value = MlasMaximumFloat16(MinimumBroadcast, Value); return Value; } MLAS_FLOAT16X4 Activate(MLAS_FLOAT16X4 Value) { - Value = MlasMultiplyAddFloat16x4(Value, MlasToLowHalfFloat16x4(AlphaBroadcast), + Value = MlasMultiplyAddFloat16(Value, MlasToLowHalfFloat16x4(AlphaBroadcast), MlasToLowHalfFloat16x4(BetaBroadcast)); - Value = MlasMinimumFloat16x4(MlasToLowHalfFloat16x4(MaximumBroadcast), Value); - Value = MlasMaximumFloat16x4(MlasToLowHalfFloat16x4(MinimumBroadcast), Value); + Value = MlasMinimumFloat16(MlasToLowHalfFloat16x4(MaximumBroadcast), Value); + Value = MlasMaximumFloat16(MlasToLowHalfFloat16x4(MinimumBroadcast), Value); return Value; } @@ -692,7 +692,7 @@ MlasActivationKernel( MLAS_FLOAT16X8 AVec = MlasLoadFloat16x8(addsrc); MLAS_FLOAT16X8 Vector = MlasLoadFloat16x8(buffer); addsrc += 8; - Vector = MlasAddFloat16x8(Vector, AVec); + Vector = MlasAddFloat16(Vector, AVec); Vector = ActivationFunction.Activate(Vector); 
MlasStoreFloat16x8(buffer, Vector); buffer += 8; @@ -703,7 +703,7 @@ MlasActivationKernel( MLAS_FLOAT16X4 AVec = MlasLoadFloat16x4(addsrc); MLAS_FLOAT16X4 Vector = MlasLoadFloat16x4(buffer); addsrc += 4; - Vector = MlasAddFloat16x4(Vector, AVec); + Vector = MlasAddFloat16(Vector, AVec); Vector = ActivationFunction.Activate(Vector); MlasStoreFloat16x4(buffer, Vector); buffer += 4; @@ -715,7 +715,7 @@ MlasActivationKernel( MLAS_FLOAT16X4 buf; std::memcpy(&addbuf, addsrc, n * sizeof(_mlas_fp16_)); std::memcpy(&buf, buffer, n * sizeof(_mlas_fp16_)); - buf = MlasAddFloat16x4(buf, addbuf); + buf = MlasAddFloat16(buf, addbuf); buf = ActivationFunction.Activate(buf); MlasStorePartialFloat16x4(buffer, buf, n); } diff --git a/src/lib/amd64/ConvSymKernelAvx2.asm b/src/lib/amd64/ConvSymKernelAvx2.asm index a42d7ff..9c334be 100644 --- a/src/lib/amd64/ConvSymKernelAvx2.asm +++ b/src/lib/amd64/ConvSymKernelAvx2.asm @@ -23,6 +23,87 @@ INCLUDE ConvSymKernelCommon.inc INCLUDE AssembleAvxVnni.inc .list +extern CheckSaturationForVPMADDUBSW:proc + +CheckSaturation MACRO VecReg1Num, VecReg2Num + +; +; Save all caller-saved registers (RAX, RCX, RDX, RSI, RDI, R8, R9, R10, R11). no RSI, RDI. +; + + push_reg rax + push_reg rcx + push_reg rdx + push_reg r8 + push_reg r9 + push_reg r10 + push_reg r11 + + sub rsp, 512 ; reserve space for 16 YMM registers (32 bytes) + +; +; Save YMM registers (YMM0 to YMM15). 
+; + + vmovdqu YMMWORD PTR [rsp], ymm0 + vmovdqu YMMWORD PTR [rsp+32], ymm1 + vmovdqu YMMWORD PTR [rsp+64], ymm2 + vmovdqu YMMWORD PTR [rsp+96], ymm3 + vmovdqu YMMWORD PTR [rsp+128], ymm4 + vmovdqu YMMWORD PTR [rsp+160], ymm5 + vmovdqu YMMWORD PTR [rsp+192], ymm6 + vmovdqu YMMWORD PTR [rsp+224], ymm7 + vmovdqu YMMWORD PTR [rsp+256], ymm8 + vmovdqu YMMWORD PTR [rsp+288], ymm9 + vmovdqu YMMWORD PTR [rsp+320], ymm10 + vmovdqu YMMWORD PTR [rsp+352], ymm11 + vmovdqu YMMWORD PTR [rsp+384], ymm12 + vmovdqu YMMWORD PTR [rsp+416], ymm13 + vmovdqu YMMWORD PTR [rsp+448], ymm14 + vmovdqu YMMWORD PTR [rsp+480], ymm15 + + lea rcx, [rsp+32*VecReg1Num] ; first operand (unsigned) + lea rdx, [rsp+32*VecReg2Num] ; second operand (signed) + + call CheckSaturationForVPMADDUBSW + +; +; Restore YMM registers. +; + + vmovdqu ymm0, YMMWORD PTR [rsp] + vmovdqu ymm1, YMMWORD PTR [rsp+32] + vmovdqu ymm2, YMMWORD PTR [rsp+64] + vmovdqu ymm3, YMMWORD PTR [rsp+96] + vmovdqu ymm4, YMMWORD PTR [rsp+128] + vmovdqu ymm5, YMMWORD PTR [rsp+160] + vmovdqu ymm6, YMMWORD PTR [rsp+192] + vmovdqu ymm7, YMMWORD PTR [rsp+224] + vmovdqu ymm8, YMMWORD PTR [rsp+256] + vmovdqu ymm9, YMMWORD PTR [rsp+288] + vmovdqu ymm10, YMMWORD PTR [rsp+320] + vmovdqu ymm11, YMMWORD PTR [rsp+352] + vmovdqu ymm12, YMMWORD PTR [rsp+384] + vmovdqu ymm13, YMMWORD PTR [rsp+416] + vmovdqu ymm14, YMMWORD PTR [rsp+448] + vmovdqu ymm15, YMMWORD PTR [rsp+480] + + add rsp, 512 ; clean up the reserved stack space + +; +; Restore all caller-saved registers (RAX, RCX, RDX, RSI, RDI, R8, R9, R10, R11), no RSI, RDI. 
+; + + pop r11 + pop r10 + pop r9 + pop r8 + pop rdx + pop rcx + pop rax + + ENDM + ; ; Macro Description: ; @@ -50,9 +131,15 @@ INCLUDE AssembleAvxVnni.inc MultiplyAccumulateRowAvx2 MACRO Vec1Reg, Vec2Reg +IFDEF ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER + CheckSaturation 2,0 +ENDIF vpmaddubsw ymm3,ymm2,ymm0 vpmaddwd ymm3,ymm3,ymm12 vpaddd Vec1Reg,Vec1Reg,ymm3 +IFDEF ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER + CheckSaturation 2,1 +ENDIF vpmaddubsw ymm2,ymm2,ymm1 vpmaddwd ymm2,ymm2,ymm12 vpaddd Vec2Reg,Vec2Reg,ymm2 diff --git a/src/lib/fp16_neon_common.cpp b/src/lib/cast_kernel_neon.cpp similarity index 99% rename from src/lib/fp16_neon_common.cpp rename to src/lib/cast_kernel_neon.cpp index 29734c2..8a385c9 100644 --- a/src/lib/fp16_neon_common.cpp +++ b/src/lib/cast_kernel_neon.cpp @@ -6,7 +6,7 @@ Licensed under the MIT License. Module Name: - fp16_neon_common.cpp + cast_kernel_neon.cpp Abstract: diff --git a/src/lib/compute.cpp b/src/lib/compute.cpp index 73df23e..96a2398 100644 --- a/src/lib/compute.cpp +++ b/src/lib/compute.cpp @@ -20,6 +20,7 @@ Module Name: --*/ #include "mlasi.h" +#include "softmax.h" // // Bundles the constants for use by kernels written in assembly. @@ -68,12 +69,13 @@ MLAS_INTERNAL_DATA const float MlasMinimumF32Value = std::numeric_limits: // threads. 
// +template struct MLAS_SOFTMAX_WORK_BLOCK { ptrdiff_t ThreadCountN; bool LogSoftmax; bool SmoothSoftmax; - const float* Input; - float* Output; + const T* Input; + T* Output; size_t N; size_t D; }; @@ -244,9 +246,10 @@ Return Value: } } +template <> void MLASCALL -MlasComputeExp( +MlasComputeExp( const float* Input, float* Output, size_t N @@ -280,6 +283,20 @@ Return Value: #endif } +template <> +void MLASCALL +MlasComputeExp( + const MLAS_FP16* Input, + MLAS_FP16* Output, + size_t N +) { + const auto* dispatch = GetMlasPlatform().SoftmaxDispatch; + if (dispatch == nullptr || dispatch->Exp_Fp16 == nullptr) { + MLAS_THROW_EX(std::runtime_error, "Exp_Fp16 is not supported."); + } + dispatch->Exp_Fp16(Input, Output, N); +} + MLAS_FORCEINLINE MLAS_FLOAT32X4 MlasComputeSumExpVector( @@ -783,10 +800,18 @@ Return Value: } } +template void MlasComputeSoftmaxThreaded( void* Context, ptrdiff_t Index +); + +template <> +void +MlasComputeSoftmaxThreaded( + void* Context, + ptrdiff_t Index ) /*++ @@ -807,7 +832,7 @@ Return Value: --*/ { - const auto* WorkBlock = (MLAS_SOFTMAX_WORK_BLOCK*)Context; + const auto* WorkBlock = (MLAS_SOFTMAX_WORK_BLOCK*)Context; // // Partition the operation along the N dimension. @@ -906,11 +931,85 @@ Return Value: } } +template <> +void +MlasComputeSoftmaxThreaded( + void* Context, + ptrdiff_t Index +) +/*++ + +Routine Description: + + This routine is invoked from a worker thread to execute a segment of a + softmax or log softmax operation. + +Arguments: + + Context - Supplies the pointer to the context for the threaded operation. + + ThreadId - Supplies the current index of the threaded operation. + +Return Value: + + None. 
+ +--*/ +{ + const auto* WorkBlock = (MLAS_SOFTMAX_WORK_BLOCK*)Context; + size_t n; + size_t CountN; + MlasPartitionWork(Index, WorkBlock->ThreadCountN, WorkBlock->N, &n, &CountN); + + const size_t D = WorkBlock->D; + const bool LogSoftmax = WorkBlock->LogSoftmax; + const bool SmoothSoftmax = WorkBlock->SmoothSoftmax; + + const MLAS_FP16* Input = WorkBlock->Input + n * D; + MLAS_FP16* Output = WorkBlock->Output + n * D; + + const auto* dispatch = GetMlasPlatform().SoftmaxDispatch; + if (dispatch == nullptr || + dispatch->ReduceMax_Fp16 == nullptr || + dispatch->SumExp_Fp16 == nullptr || + (LogSoftmax && dispatch->LogSoftmax_Fp16 == nullptr) || + (!LogSoftmax && dispatch->Softmax_Fp16 == nullptr)) { + MLAS_THROW_EX(std::runtime_error, "Lacks kernels for fp16 softmax."); + } + + while (CountN > 0) { + MLAS_FP16 Maximum = dispatch->ReduceMax_Fp16(Input, D); + MLAS_FP16 NegativeMaximum = Maximum.Negate(); + if (SmoothSoftmax && !NegativeMaximum.IsNegative()) { + NegativeMaximum = MLAS_FP16::FromBits(0); + } + + MLAS_FP16* Temp = LogSoftmax ? nullptr : Output; + MLAS_FP16 Accumulation = dispatch->SumExp_Fp16(Input, Temp, D, NegativeMaximum); + float accumulation_fp32 = Accumulation.ToFloat(); + + if (SmoothSoftmax) { + accumulation_fp32 += expf(NegativeMaximum.ToFloat()); + } + + if (LogSoftmax) { + dispatch->LogSoftmax_Fp16(Input, Output, D, NegativeMaximum, MLAS_FP16(std::log(accumulation_fp32))); + } else { + dispatch->Softmax_Fp16(Output, Output, D, MLAS_FP16(accumulation_fp32)); + } + + Input += D; + Output += D; + CountN--; + } +} + +template void MLASCALL MlasComputeSoftmax( - const float* Input, - float* Output, + const T* Input, + T* Output, size_t N, size_t D, bool LogSoftmax, @@ -949,7 +1048,7 @@ Return Value: --*/ { - MLAS_SOFTMAX_WORK_BLOCK WorkBlock; + MLAS_SOFTMAX_WORK_BLOCK WorkBlock; // // Capture the softmax parameters to the work block. 
@@ -985,5 +1084,67 @@ Return Value: WorkBlock.ThreadCountN = ThreadCountN; - MlasExecuteThreaded(MlasComputeSoftmaxThreaded, &WorkBlock, ThreadCountN, ThreadPool); + MlasExecuteThreaded(MlasComputeSoftmaxThreaded, &WorkBlock, ThreadCountN, ThreadPool); +} + +template +void +MLASCALL +MlasComputeSoftmax( + const float* Input, + float* Output, + size_t N, + size_t D, + bool LogSoftmax, + bool SmoothSoftmax, + MLAS_THREADPOOL* ThreadPool +); + +template +void +MLASCALL +MlasComputeSoftmax( + const MLAS_FP16* Input, + MLAS_FP16* Output, + size_t N, + size_t D, + bool LogSoftmax, + bool SmoothSoftmax, + MLAS_THREADPOOL* ThreadPool +); + +template <> +bool +MLASCALL +MlasGQASupported( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB +) { + if (!MlasHGemmSupported(TransA, TransB)) { + return false; + } + + const auto* softmax_dispatch = GetMlasPlatform().SoftmaxDispatch; + if (softmax_dispatch == nullptr || + softmax_dispatch->Tanh_Fp16 == nullptr || + softmax_dispatch->Softcap_Fp16 == nullptr || + softmax_dispatch->SumExp_Fp16 == nullptr || + softmax_dispatch->Softmax_Fp16 == nullptr || + softmax_dispatch->ReduceMax_Fp16 == nullptr) { + return false; + } + + return true; +} + +template <> +bool +MLASCALL +MlasGQASupported( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB +) { + MLAS_UNREFERENCED_PARAMETER(TransA); + MLAS_UNREFERENCED_PARAMETER(TransB); + return true; } diff --git a/src/lib/dwconv.cpp b/src/lib/dwconv.cpp index d48d9cb..0fff937 100644 --- a/src/lib/dwconv.cpp +++ b/src/lib/dwconv.cpp @@ -43,7 +43,7 @@ MlasConvDepthwiseKernel( MLAS_FLOAT16X8 InputVector = MlasLoadFloat16x8(&Input[k][ChannelOffset]); MLAS_FLOAT16X8 FilterVector = MlasLoadFloat16x8(&Filter[ChannelKernelOffset]); - Accumulator = MlasMultiplyAddFloat16x8(InputVector, FilterVector, Accumulator); + Accumulator = MlasMultiplyAddFloat16(InputVector, FilterVector, Accumulator); ChannelKernelOffset += Channels; } MlasStoreFloat16x8(Output, Accumulator); @@ -61,7 +61,7 @@ MlasConvDepthwiseKernel( 
MLAS_FLOAT16X4 InputVector = MlasLoadFloat16x4(&Input[k][ChannelOffset]); MLAS_FLOAT16X4 FilterVector = MlasLoadFloat16x4(&Filter[ChannelKernelOffset]); - Accumulator = MlasMultiplyAddFloat16x4(InputVector, FilterVector, Accumulator); + Accumulator = MlasMultiplyAddFloat16(InputVector, FilterVector, Accumulator); ChannelKernelOffset += Channels; } MlasStoreFloat16x4(Output, Accumulator); @@ -80,7 +80,7 @@ MlasConvDepthwiseKernel( MLAS_FLOAT16X4 InputValue = MlasLoadFloat16x4(&Input[k][ChannelOffset]); MLAS_FLOAT16X4 FilterValue = MlasLoadFloat16x4(&Filter[ChannelKernelOffset]); - Accumulator = MlasMultiplyAddFloat16x4(InputValue, FilterValue, Accumulator); + Accumulator = MlasMultiplyAddFloat16(InputValue, FilterValue, Accumulator); ChannelKernelOffset += Channels; } MlasStorePartialFloat16x4(Output, Accumulator, c); diff --git a/src/lib/eltwise.cpp b/src/lib/eltwise.cpp new file mode 100644 index 0000000..f63d71b --- /dev/null +++ b/src/lib/eltwise.cpp @@ -0,0 +1,71 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + eltwise.cpp + +Abstract: + + This module implements routines to compute element-wise operations on two vectors. 
+ + Currently supported element-wise operations: + - Add + +--*/ + +#include "mlasi.h" +#include "eltwise.h" + +template <> +void +MLASCALL +MlasEltwiseAdd( + const float* left, + const float* right, + float* output, + size_t N +) { + while (N > 0) { + if (N >= 4) { + MLAS_FLOAT32X4 LeftVec = MlasLoadFloat32x4(left); + MLAS_FLOAT32X4 RightVec = MlasLoadFloat32x4(right); + + MLAS_FLOAT32X4 ResultVec = MlasAddFloat32x4(LeftVec, RightVec); + + MlasStoreFloat32x4(output, ResultVec); + + left += 4; + right += 4; + output += 4; + N -= 4; + } else { + *output = *left + *right; + + left += 1; + right += 1; + output += 1; + N -= 1; + } + } +} + + +template <> +void +MLASCALL +MlasEltwiseAdd( + const MLAS_FP16* left, + const MLAS_FP16* right, + MLAS_FP16* output, + size_t N +) { + const auto* dispatch = GetMlasPlatform().EltwiseDispatch; + if (dispatch == nullptr || dispatch->Add_Fp16 == nullptr) { + MLAS_THROW_EX(std::runtime_error, "Add_Fp16 is not supported."); + } + dispatch->Add_Fp16(left, right, output, N); +} diff --git a/src/lib/eltwise.h b/src/lib/eltwise.h new file mode 100644 index 0000000..a8345c4 --- /dev/null +++ b/src/lib/eltwise.h @@ -0,0 +1,37 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + eltwise.h + +Abstract: + + This module includes kernel function prototypes and helper functions for + element-wise operations. + +--*/ +#pragma once + +#include "mlasi.h" + +struct MLAS_ELTWISE_DISPATCH { + /** + * @brief Compute the element-wise addition of the two given vectors + * @param left Address of the left operand + * @param right Address of the right operand + * @param output Address of the output array. Could be the same as the input array. 
+ * @param N Number of elements in the input arrays + */ + typedef void(Add_Fp16_Fn)( + const MLAS_FP16* left, + const MLAS_FP16* right, + MLAS_FP16* output, + size_t N + ); + + Add_Fp16_Fn* Add_Fp16 = nullptr; +}; diff --git a/src/lib/eltwise_kernel_neon.cpp b/src/lib/eltwise_kernel_neon.cpp new file mode 100644 index 0000000..415c128 --- /dev/null +++ b/src/lib/eltwise_kernel_neon.cpp @@ -0,0 +1,32 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + eltwise_kernel_neon.cpp + +Abstract: + + This module implements the element-wise kernels for ARM NEON. + +--*/ + +#include "eltwise.h" +#include "eltwise_kernel_neon.h" + +// +// Kernel dispatch structure definition. +// +const MLAS_ELTWISE_DISPATCH MlasEltwiseDispatchNeon = []() { + MLAS_ELTWISE_DISPATCH d; + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + if (MlasFp16AccelerationSupported()) { + d.Add_Fp16 = eltwise_neon::Add_Kernel_Fp16; + } +#endif + return d; +}(); diff --git a/src/lib/eltwise_kernel_neon.h b/src/lib/eltwise_kernel_neon.h new file mode 100644 index 0000000..d99a3e9 --- /dev/null +++ b/src/lib/eltwise_kernel_neon.h @@ -0,0 +1,28 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + eltwise_kernel_neon.h + +Abstract: + + This module includes function declarations and common helper functions for + element-wise operations on ARM cpu. + +--*/ + +#pragma once + +#include + +#include "mlasi.h" + +namespace eltwise_neon { + +void Add_Kernel_Fp16(const MLAS_FP16* left, const MLAS_FP16* right, MLAS_FP16* output, size_t N); + +} // namespace eltwise_neon diff --git a/src/lib/eltwise_kernel_neon_fp16.cpp b/src/lib/eltwise_kernel_neon_fp16.cpp new file mode 100644 index 0000000..decbdb5 --- /dev/null +++ b/src/lib/eltwise_kernel_neon_fp16.cpp @@ -0,0 +1,118 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. 
+ +Licensed under the MIT License. + +Module Name: + + eltwise_kernel_neon_fp16.cpp + +Abstract: + + This module implements the fp16 element-wise kernels for ARM NEON. + +--*/ +#include +#include + +#include "fp16_common.h" +#include "eltwise.h" +#include "eltwise_kernel_neon.h" + +namespace eltwise_neon { + +void Add_Kernel_Fp16(const MLAS_FP16* left, const MLAS_FP16* right, MLAS_FP16* output, size_t N) { + const auto* left_fp16 = reinterpret_cast(left); + const auto* right_fp16 = reinterpret_cast(right); + auto* output_fp16 = reinterpret_cast<_mlas_fp16_*>(output); + + while (N >= 32) { + auto l0 = MlasLoadFloat16x8(left_fp16); + auto l1 = MlasLoadFloat16x8(left_fp16 + 8); + auto l2 = MlasLoadFloat16x8(left_fp16 + 16); + auto l3 = MlasLoadFloat16x8(left_fp16 + 24); + + auto r0 = MlasLoadFloat16x8(right_fp16); + auto r1 = MlasLoadFloat16x8(right_fp16 + 8); + auto r2 = MlasLoadFloat16x8(right_fp16 + 16); + auto r3 = MlasLoadFloat16x8(right_fp16 + 24); + + auto o0 = MlasAddFloat16(l0, r0); + auto o1 = MlasAddFloat16(l1, r1); + auto o2 = MlasAddFloat16(l2, r2); + auto o3 = MlasAddFloat16(l3, r3); + + MlasStoreFloat16x8(output_fp16, o0); + MlasStoreFloat16x8(output_fp16 + 8, o1); + MlasStoreFloat16x8(output_fp16 + 16, o2); + MlasStoreFloat16x8(output_fp16 + 24, o3); + + left_fp16 += 32; + right_fp16 += 32; + output_fp16 += 32; + N -= 32; + } + + if (N & 16) { + auto l0 = MlasLoadFloat16x8(left_fp16); + auto l1 = MlasLoadFloat16x8(left_fp16 + 8); + + auto r0 = MlasLoadFloat16x8(right_fp16); + auto r1 = MlasLoadFloat16x8(right_fp16 + 8); + + auto o0 = MlasAddFloat16(l0, r0); + auto o1 = MlasAddFloat16(l1, r1); + + MlasStoreFloat16x8(output_fp16, o0); + MlasStoreFloat16x8(output_fp16 + 8, o1); + + left_fp16 += 16; + right_fp16 += 16; + output_fp16 += 16; + N -= 16; + } + + if (N & 8) { + auto l0 = MlasLoadFloat16x8(left_fp16); + auto r0 = MlasLoadFloat16x8(right_fp16); + auto o0 = MlasAddFloat16(l0, r0); + MlasStoreFloat16x8(output_fp16, o0); + + left_fp16 += 8; + 
right_fp16 += 8; + output_fp16 += 8; + N -= 8; + } + + if (N & 4) { + auto l0 = MlasLoadFloat16x4(left_fp16); + auto r0 = MlasLoadFloat16x4(right_fp16); + auto o0 = MlasAddFloat16(l0, r0); + MlasStoreFloat16x4(output_fp16, o0); + + left_fp16 += 4; + right_fp16 += 4; + output_fp16 += 4; + N -= 4; + } + + if (N == 3) { + auto l0 = MlasLoadPartialFloat16x4(left_fp16, 3); + auto r0 = MlasLoadPartialFloat16x4(right_fp16, 3); + auto o0 = MlasAddFloat16(l0, r0); + MlasStorePartialFloat16x4(output_fp16, o0, 3); + } else if (N == 2) { + auto l0 = MlasLoadPartialFloat16x4(left_fp16, 2); + auto r0 = MlasLoadPartialFloat16x4(right_fp16, 2); + auto o0 = MlasAddFloat16(l0, r0); + MlasStorePartialFloat16x4(output_fp16, o0, 2); + } else if (N == 1) { + auto l0 = MlasLoadPartialFloat16x4(left_fp16, 1); + auto r0 = MlasLoadPartialFloat16x4(right_fp16, 1); + auto o0 = MlasAddFloat16(l0, r0); + MlasStorePartialFloat16x4(output_fp16, o0, 1); + } +} + +} // namespace eltwise_neon diff --git a/src/lib/fp16_common.h b/src/lib/fp16_common.h index f4c4990..d4713cc 100644 --- a/src/lib/fp16_common.h +++ b/src/lib/fp16_common.h @@ -27,10 +27,28 @@ typedef float16x8_t MLAS_FLOAT16X8; typedef float16x4_t MLAS_FLOAT16X4; typedef uint16x8_t MLAS_UINT16X8; typedef uint16x4_t MLAS_UINT16X4; +typedef int16x8_t MLAS_INT16X8; +typedef int16x4_t MLAS_INT16X4; MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasReinterpretAsFloat16x8(MLAS_INT32X4 Vector) { return vreinterpretq_f16_s32(Vector); } +MlasReinterpretInt32AsFloat16(MLAS_INT32X4 Vector) { return vreinterpretq_f16_s32(Vector); } + +MLAS_FORCEINLINE +MLAS_FLOAT16X8 +MlasReinterpretInt16AsFloat16(MLAS_INT16X8 Vector) { return vreinterpretq_f16_s16(Vector); } + +MLAS_FORCEINLINE +MLAS_FLOAT16X4 +MlasReinterpretInt16AsFloat16(MLAS_INT16X4 Vector) { return vreinterpret_f16_s16(Vector); } + +MLAS_FORCEINLINE +MLAS_INT16X8 +MlasReinterpretFloat16AsInt16(MLAS_FLOAT16X8 Vector) { return vreinterpretq_s16_f16(Vector); } + +MLAS_FORCEINLINE +MLAS_INT16X4 
+MlasReinterpretFloat16AsInt16(MLAS_FLOAT16X4 Vector) { return vreinterpret_s16_f16(Vector); } MLAS_FORCEINLINE MLAS_FLOAT16X8 @@ -142,94 +160,114 @@ MlasToLowHalfFloat16x4(MLAS_FLOAT16X8 V) MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasAddFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +MlasAddFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) { return vaddq_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X4 -MlasAddFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) +MlasAddFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) { return vadd_f16(Vector1, Vector2); } +MLAS_FORCEINLINE +MLAS_INT16X8 +MlasAddInt16(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2) +{ + return vaddq_s16(Vector1, Vector2); +} + +MLAS_FORCEINLINE +MLAS_INT16X4 +MlasAddInt16(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2) +{ + return vadd_s16(Vector1, Vector2); +} + MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasSubtractFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +MlasSubtractFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) { return vsubq_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X4 -MlasSubtractFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) +MlasSubtractFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) { return vsub_f16(Vector1, Vector2); } +MLAS_FORCEINLINE +MLAS_INT16X8 +MlasSubtractInt16(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2) +{ + return vsubq_s16(Vector1, Vector2); +} + +MLAS_FORCEINLINE +MLAS_INT16X4 +MlasSubtractInt16(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2) +{ + return vsub_s16(Vector1, Vector2); +} + MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasMultiplyFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +MlasMultiplyFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) { return vmulq_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X4 -MlasMultiplyFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) +MlasMultiplyFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) { return vmul_f16(Vector1, Vector2); } 
MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasDivFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) +MlasDivideFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) { return vdivq_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X4 -MlasDivFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) +MlasDivideFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2) { return vdiv_f16(Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasMultiplyAddFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, MLAS_FLOAT16X8 Vector3) +MlasMultiplyAddFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, MLAS_FLOAT16X8 Dest) { - return vfmaq_f16(Vector3, Vector1, Vector2); + return vfmaq_f16(Dest, Vector1, Vector2); } MLAS_FORCEINLINE MLAS_FLOAT16X4 -MlasMultiplyAddFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2, MLAS_FLOAT16X4 Vector3) +MlasMultiplyAddFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2, MLAS_FLOAT16X4 Dest) { - return vfma_f16(Vector3, Vector1, Vector2); + return vfma_f16(Dest, Vector1, Vector2); } - MLAS_FORCEINLINE void MlasMultiplyAddFloat16x8(MLAS_FLOAT16X8 Vector1, _mlas_fp16_ Scalar2, MLAS_FLOAT16X8 Vector3) { - MlasMultiplyAddFloat16x8(Vector1, MlasBroadcastFloat16x8(Scalar2), Vector3); + MlasMultiplyAddFloat16(Vector1, MlasBroadcastFloat16x8(Scalar2), Vector3); } MLAS_FORCEINLINE void MlasMultiplyAddFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, _mlas_fp16_ Scalar3) { - MlasMultiplyAddFloat16x8(Vector1, Vector2, MlasBroadcastFloat16x8(Scalar3)); -} - -MLAS_FORCEINLINE -MLAS_FLOAT16X8 -MlasDivideFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) -{ - return vdivq_f16(Vector1, Vector2); + MlasMultiplyAddFloat16(Vector1, Vector2, MlasBroadcastFloat16x8(Scalar3)); } MLAS_FORCEINLINE @@ -277,50 +315,127 @@ MlasBlendFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2, MLAS_FLOAT16X MLAS_FORCEINLINE MLAS_FLOAT16X8 -MlasMaximumFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2) 
+MlasMaximumFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
 {
     return vmaxq_f16(Vector1, Vector2);
 }
 
 MLAS_FORCEINLINE
 MLAS_FLOAT16X4
-MlasMaximumFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2)
+MlasMaximumFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2)
 {
     return vmax_f16(Vector1, Vector2);
 }
 
+MLAS_FORCEINLINE
+MLAS_INT16X8
+MlasMaximumInt16(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2)
+{
+    return vmaxq_s16(Vector1, Vector2);
+}
+
+MLAS_FORCEINLINE
+MLAS_INT16X4
+MlasMaximumInt16(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2)
+{
+    return vmax_s16(Vector1, Vector2);
+}
+
 MLAS_FORCEINLINE
 MLAS_FLOAT16X8
-MlasMinimumFloat16x8(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
+MlasMinimumFloat16(MLAS_FLOAT16X8 Vector1, MLAS_FLOAT16X8 Vector2)
 {
     return vminq_f16(Vector1, Vector2);
 }
 
 MLAS_FORCEINLINE
 MLAS_FLOAT16X4
-MlasMinimumFloat16x4(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2)
+MlasMinimumFloat16(MLAS_FLOAT16X4 Vector1, MLAS_FLOAT16X4 Vector2)
 {
     return vmin_f16(Vector1, Vector2);
 }
 
+MLAS_FORCEINLINE
+MLAS_INT16X8
+MlasMinimumInt16(MLAS_INT16X8 Vector1, MLAS_INT16X8 Vector2)
+{
+    return vminq_s16(Vector1, Vector2);
+}
+
+MLAS_FORCEINLINE
+MLAS_INT16X4
+MlasMinimumInt16(MLAS_INT16X4 Vector1, MLAS_INT16X4 Vector2)
+{
+    return vmin_s16(Vector1, Vector2);
+}
+
 MLAS_FORCEINLINE
 MLAS_FLOAT16X8
 MlasClampFloat16x8(MLAS_FLOAT16X8 Value, _mlas_fp16_ LowerRange, _mlas_fp16_ UpperRange)
 {
-    Value = MlasMaximumFloat16x8(MlasBroadcastFloat16x8(LowerRange), Value);
-    Value = MlasMinimumFloat16x8(MlasBroadcastFloat16x8(UpperRange), Value);
+    Value = MlasMaximumFloat16(MlasBroadcastFloat16x8(LowerRange), Value);
+    Value = MlasMinimumFloat16(MlasBroadcastFloat16x8(UpperRange), Value);
+    return Value;
+}
+
+template <typename T>
+MLAS_FORCEINLINE
+T
+MlasClampFloat16(T Value, T LowerRange, T UpperRange)
+{
+    Value = MlasMaximumFloat16(LowerRange, Value);
+    Value = MlasMinimumFloat16(UpperRange, Value);
+    return Value;
+}
+
+template <typename T>
+MLAS_FORCEINLINE
+T
+MlasClampInt16(T Value, T 
LowerRange, T UpperRange) +{ + Value = MlasMaximumInt16(LowerRange, Value); + Value = MlasMinimumInt16(UpperRange, Value); return Value; } MLAS_FORCEINLINE _mlas_fp16_ -MlasReduceAddFloat16x8(MLAS_FLOAT16X8 Vector) +MlasReduceAddFloat16(MLAS_FLOAT16X8 Vector) { + Vector = vpaddq_f16(Vector, Vector); Vector = vpaddq_f16(Vector, Vector); Vector = vpaddq_f16(Vector, Vector); return vgetq_lane_u16(vreinterpretq_u16_f16(Vector), 0); } +MLAS_FORCEINLINE +_mlas_fp16_ +MlasReduceAddFloat16(MLAS_FLOAT16X4 Vector) +{ + Vector = vpadd_f16(Vector, Vector); + Vector = vpadd_f16(Vector, Vector); + return vget_lane_u16(vreinterpret_u16_f16(Vector), 0); +} + +MLAS_FORCEINLINE +_mlas_fp16_ +MlasReduceMaximumFloat16(MLAS_FLOAT16X8 Vector) +{ + Vector = vpmaxq_f16(Vector, Vector); + Vector = vpmaxq_f16(Vector, Vector); + Vector = vpmaxq_f16(Vector, Vector); + return vgetq_lane_u16(vreinterpretq_u16_f16(Vector), 0); +} + +MLAS_FORCEINLINE +_mlas_fp16_ +MlasReduceMaximumFloat16(MLAS_FLOAT16X4 Vector) +{ + Vector = vpmax_f16(Vector, Vector); + Vector = vpmax_f16(Vector, Vector); + return vget_lane_u16(vreinterpret_u16_f16(Vector), 0); +} + MLAS_FORCEINLINE MLAS_UINT16X8 MlasCmpLessEqualFloat16x8(MLAS_FLOAT16X8 left, MLAS_FLOAT16X8 right) @@ -349,4 +464,119 @@ MlasBitwiseSelectFloat16x4(MLAS_UINT16X4 select, MLAS_FLOAT16X4 ones, MLAS_FLOAT return vbsl_f16(select, ones, zeros); } +MLAS_FORCEINLINE +void +Transpose8x8(MLAS_FLOAT16X8& v0, MLAS_FLOAT16X8& v1, MLAS_FLOAT16X8& v2, MLAS_FLOAT16X8& v3, + MLAS_FLOAT16X8& v4, MLAS_FLOAT16X8& v5, MLAS_FLOAT16X8& v6, MLAS_FLOAT16X8& v7) +{ + // |v00|v01|v02|v03|v04|v05|v06|v07| + // |v10|v11|v12|v13|v14|v15|v16|v17| + // |v20|v21|v22|v23|v24|v25|v26|v27| + // |v30|v31|v32|v33|v34|v35|v36|v37| + // |v40|v41|v42|v43|v44|v45|v46|v47| + // |v50|v51|v52|v53|v54|v55|v56|v57| + // |v60|v61|v62|v63|v64|v65|v66|v67| + // |v70|v71|v72|v73|v74|v75|v76|v77| + float16x8x2_t t01 = vtrnq_f16(v0, v1); + float16x8x2_t t23 = vtrnq_f16(v2, v3); + float16x8x2_t t45 = 
vtrnq_f16(v4, v5); + float16x8x2_t t67 = vtrnq_f16(v6, v7); + // |v00|v10|v02|v12|v04|v14|v06|v16| + // |v01|v11|v03|v13|v05|v15|v07|v17| + // |v20|v30|v22|v32|v24|v34|v26|v36| + // |v21|v31|v23|v33|v25|v35|v27|v37| + // |v40|v50|v42|v52|v44|v54|v46|v56| + // |v41|v51|v43|v53|v45|v55|v47|v57| + // |v60|v70|v62|v72|v64|v74|v66|v76| + // |v61|v71|v63|v73|v65|v75|v67|v77| + float32x4x2_t t02 = vtrnq_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0])); + float32x4x2_t t13 = vtrnq_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1])); + float32x4x2_t t46 = vtrnq_f32(vreinterpretq_f32_f16(t45.val[0]), vreinterpretq_f32_f16(t67.val[0])); + float32x4x2_t t57 = vtrnq_f32(vreinterpretq_f32_f16(t45.val[1]), vreinterpretq_f32_f16(t67.val[1])); + // |v00|v10|v20|v30|v04|v14|v24|v34| + // |v01|v11|v21|v31|v05|v15|v25|v35| + // |v02|v12|v22|v32|v06|v16|v26|v36| + // |v03|v13|v23|v33|v07|v17|v27|v37| + // |v40|v50|v60|v70|v44|v54|v64|v74| + // |v41|v51|v61|v71|v45|v55|v65|v75| + // |v42|v52|v62|v72|v46|v56|v66|v76| + // |v43|v53|v63|v73|v47|v57|v67|v77| + v0 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t02.val[0]), vreinterpretq_f64_f32(t46.val[0]))); + v4 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t02.val[0]), vreinterpretq_f64_f32(t46.val[0]))); + v2 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t02.val[1]), vreinterpretq_f64_f32(t46.val[1]))); + v6 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t02.val[1]), vreinterpretq_f64_f32(t46.val[1]))); + v1 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t13.val[0]), vreinterpretq_f64_f32(t57.val[0]))); + v5 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t13.val[0]), vreinterpretq_f64_f32(t57.val[0]))); + v3 = vreinterpretq_f16_f64(vtrn1q_f64(vreinterpretq_f64_f32(t13.val[1]), vreinterpretq_f64_f32(t57.val[1]))); + v7 = vreinterpretq_f16_f64(vtrn2q_f64(vreinterpretq_f64_f32(t13.val[1]), vreinterpretq_f64_f32(t57.val[1]))); + // 
|v00|v10|v20|v30|v40|v50|v60|v70| + // |v01|v11|v21|v31|v41|v51|v61|v71| + // |v02|v12|v22|v32|v42|v52|v62|v72| + // |v03|v13|v23|v33|v43|v53|v63|v73| + // |v04|v14|v24|v34|v44|v54|v64|v74| + // |v05|v15|v25|v35|v45|v55|v65|v75| + // |v06|v16|v26|v36|v46|v56|v66|v76| + // |v07|v17|v27|v37|v47|v57|v67|v77| +} + +MLAS_FORCEINLINE +void +Transpose4x8(MLAS_FLOAT16X8& v0, MLAS_FLOAT16X8& v1, MLAS_FLOAT16X8& v2, MLAS_FLOAT16X8& v3) +{ + // |v00|v01|v02|v03|v04|v05|v06|v07| + // |v10|v11|v12|v13|v14|v15|v16|v17| + // |v20|v21|v22|v23|v24|v25|v26|v27| + // |v30|v31|v32|v33|v34|v35|v36|v37| + // => + // |v00|v10|v20|v30|v04|v14|v24|v34| + // |v01|v11|v21|v31|v05|v15|v25|v35| + // |v02|v12|v22|v32|v06|v16|v26|v36| + // |v03|v13|v23|v33|v07|v17|v27|v37| + float16x8x2_t t01 = vtrnq_f16(v0, v1); + float16x8x2_t t23 = vtrnq_f16(v2, v3); + + v0 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]))); + v2 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]))); + v1 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]))); + v3 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]))); +} + +MLAS_FORCEINLINE +void +Transpose4x4(MLAS_FLOAT16X4& v0, MLAS_FLOAT16X4& v1, MLAS_FLOAT16X4& v2, MLAS_FLOAT16X4& v3) +{ + // |v00|v01|v02|v03| + // |v10|v11|v12|v13| + // |v20|v21|v22|v23| + // |v30|v31|v32|v33| + // => + // |v00|v10|v20|v30| + // |v01|v11|v21|v31| + // |v02|v12|v22|v32| + // |v03|v13|v23|v33| + float16x4x2_t t01 = vtrn_f16(v0, v1); + float16x4x2_t t23 = vtrn_f16(v2, v3); + + v0 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0]))); + v1 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1]))); + v2 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[0]), 
vreinterpret_f32_f16(t23.val[0])));
+    v3 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1])));
+}
+
+template <int ShiftCount>
+MLAS_FORCEINLINE
+MLAS_INT16X8
+MlasShiftLeftInt16(MLAS_INT16X8 Vector)
+{
+    return vshlq_n_s16(Vector, ShiftCount);
+}
+
+template <int ShiftCount>
+MLAS_FORCEINLINE
+MLAS_INT16X4
+MlasShiftLeftInt16(MLAS_INT16X4 Vector)
+{
+    return vshl_n_s16(Vector, ShiftCount);
+}
+
 #endif // fp16 vector intrinsic supported
diff --git a/src/lib/halfgemm.cpp b/src/lib/halfgemm.cpp
index 49387d2..66a3356 100644
--- a/src/lib/halfgemm.cpp
+++ b/src/lib/halfgemm.cpp
@@ -324,6 +324,253 @@ MlasHalfGemmKernel(
     }
 }
 
+bool
+MLASCALL
+MlasHGemmSupported(
+    CBLAS_TRANSPOSE TransA,
+    CBLAS_TRANSPOSE TransB
+) {
+    auto* dispatch = GetMlasPlatform().HGemmDispatch;
+    if (TransA == CblasNoTrans && TransB == CblasTrans) {
+        return dispatch &&
+               dispatch->HGemmKernel_TransposedB &&
+               dispatch->HPackBKernel_TransposedB &&
+               dispatch->HGemmKernel_PackedB;
+    } else if (TransA == CblasNoTrans && TransB == CblasNoTrans) {
+        return dispatch &&
+               dispatch->HGemmKernel_B &&
+               dispatch->HPackBKernel_B &&
+               dispatch->HGemmKernel_PackedB;
+    }
+
+    return false;
+}
+
+void
+HGemmOperation(
+    CBLAS_TRANSPOSE TransA,
+    CBLAS_TRANSPOSE TransB,
+    size_t K, // full K slice
+    const MLAS_HGEMM_DATA_PARAMS* DataParams,
+    const size_t RangeStartM,
+    const size_t RangeCountM,
+    const size_t RangeStartN,
+    const size_t RangeCountN
+) {
+    const size_t lda = DataParams->lda;
+    const size_t ldb = DataParams->ldb;
+    const size_t ldc = DataParams->ldc;
+    const _mlas_fp16_ alpha = DataParams->alpha;
+    const _mlas_fp16_ beta = DataParams->beta;
+    auto* dispatch = GetMlasPlatform().HGemmDispatch;
+    constexpr size_t StrideM = 2;
+    const auto beta_add = MLAS_FP16(1.0f);
+    constexpr size_t buffer_size = MLAS_HGEMM_STRIDEN * MLAS_HGEMM_STRIDEK;
+
+    if (TransA == CblasNoTrans && TransB == CblasTrans) {
+        const auto* A = DataParams->A + RangeStartM * lda;
+        const auto* B = DataParams->B + 
RangeStartN * ldb; + auto* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + if (RangeCountM <= StrideM) { + if (!dispatch || !dispatch->HGemmKernel_TransposedB) { + MLAS_THROW_EX(std::runtime_error, "hgemm does not have A x Transposed(B) kernels"); + } + // When M is small, B is visited once. The overhead of Pack(B') exceeds the benefits + // from A x Pack(B'). Therefore directly calculate A x B'. + // Without PackB, to utilize memory locality, iterate full K. + constexpr size_t StrideN = MLAS_HGEMM_STRIDEN_THREAD_ALIGN; + for (size_t n = 0, countN; n < RangeCountN; n += countN) { + countN = std::min(StrideN, RangeCountN - n); + dispatch->HGemmKernel_TransposedB(A, B, C, RangeCountM, countN, K, lda, ldb, ldc, alpha, beta); + B += countN * ldb; + C += countN; + } + } else { + if (!dispatch || !dispatch->HPackBKernel_TransposedB || !dispatch->HGemmKernel_PackedB) { + MLAS_THROW_EX(std::runtime_error, "hgemm does not have A x Transposed(B) kernels"); + } + // 16N is the smallest pack unit. 
+ // TODO(fajin): optimize alpha == 1 + MLAS_DECLSPEC_ALIGN(MLAS_FP16 PackedB[buffer_size], MLAS_HGEMM_STRIDEN_THREAD_ALIGN * sizeof(_mlas_fp16_)); + size_t StrideN = MLAS_HGEMM_STRIDEN; + size_t StrideK = MLAS_HGEMM_STRIDEK; + if (RangeCountN >= K) { + while (StrideK / 2 >= K) { + StrideN *= 2; + StrideK /= 2; + } + + } else { + while (StrideN > MLAS_HGEMM_STRIDEN_THREAD_ALIGN && StrideN / 2 >= RangeCountN) { + StrideK *= 2; + StrideN /= 2; + } + } + + for (size_t n = 0, countN; n < RangeCountN; n += countN) { + countN = std::min(StrideN, RangeCountN - n); + const MLAS_FP16* a = A; + const MLAS_FP16* b = B; + MLAS_FP16* c = C; + for (size_t k = 0, countK; k < K; k += countK) { + countK = std::min(StrideK, K - k); + dispatch->HPackBKernel_TransposedB(b, PackedB, countN, countK, ldb); + const MLAS_FP16* aa = a; + MLAS_FP16* cc = c; + for (size_t m = 0, countM; m < RangeCountM; m += countM) { + countM = std::min(StrideM, RangeCountM - m); + // First K iteration, beta is applied to the whole C. In rest K iterations, use add mode. + dispatch->HGemmKernel_PackedB( + aa, PackedB, cc, countM, countN, countK, lda, ldc, alpha, k == 0 ? beta : beta_add.val); + aa += countM * lda; + cc += countM * ldc; + } + a += countK; + b += countK; + } + B += countN * ldb; + C += countN; + } + } + } else if (TransA == CblasNoTrans && TransB == CblasNoTrans) { + const auto* A = DataParams->A + RangeStartM * lda; + const auto* B = DataParams->B + RangeStartN; + auto* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + if (RangeCountM <= StrideM) { + if (!dispatch || !dispatch->HGemmKernel_B) { + MLAS_THROW_EX(std::runtime_error, "hgemm does not have A x B kernels"); + } + + // When M is small, B is visited once. The overhead of Pack(B) exceeds the benefits + // from A x Pack(B). Therefore directly calculate A x B. + // When beta is 0 or 1, iterate full N and cache accumulators in C. + // When beta is not 0 or 1, iterate full K, accumulat in register, max 8 accumulators. 
+ // TODO(fajin): merge beta cases with alpha == 1 + dispatch->HGemmKernel_B(A, B, C, RangeCountM, RangeCountN, K, lda, ldb, ldc, alpha, beta); + } else { + if (!dispatch || !dispatch->HPackBKernel_B || !dispatch->HGemmKernel_PackedB) { + MLAS_THROW_EX(std::runtime_error, "hgemm does not have A x B kernels"); + } + // TODO(fajin): optimize blocking for large K small N + // - pack along N + // - loop K in outer loop + // - optimize alpha == 1 case + MLAS_DECLSPEC_ALIGN(MLAS_FP16 PackedB[buffer_size], MLAS_HGEMM_STRIDEN_THREAD_ALIGN * sizeof(_mlas_fp16_)); + size_t StrideN = MLAS_HGEMM_STRIDEN; + size_t StrideK = MLAS_HGEMM_STRIDEK; + if (RangeCountN >= K) { + while (StrideK / 2 >= K) { + StrideN *= 2; + StrideK /= 2; + } + } else { + while (StrideN > MLAS_HGEMM_STRIDEN_THREAD_ALIGN && StrideN / 2 >= RangeCountN) { + StrideK *= 2; + StrideN /= 2; + } + } + + for (size_t n = 0, countN; n < RangeCountN; n += countN) { + countN = std::min(StrideN, RangeCountN - n); + const MLAS_FP16* a = A; + const MLAS_FP16* b = B; + MLAS_FP16* c = C; + for (size_t k = 0, countK; k < K; k += countK) { + countK = std::min(StrideK, K - k); + dispatch->HPackBKernel_B(b, PackedB, countN, countK, ldb); + const MLAS_FP16* aa = a; + MLAS_FP16* cc = c; + for (size_t m = 0, countM; m < RangeCountM; m += countM) { + countM = std::min(StrideM, RangeCountM - m); + // First K iteration, beta is applied to the whole C. In rest K iterations, use add mode. + dispatch->HGemmKernel_PackedB( + aa, PackedB, cc, countM, countN, countK, lda, ldc, alpha, k == 0 ? 
beta : beta_add.val);
+                        aa += countM * lda;
+                        cc += countM * ldc;
+                    }
+                    a += countK;
+                    b += countK * ldb;
+                }
+                B += countN;
+                C += countN;
+            }
+        }
+    } else {
+        MLAS_THROW_EX(std::runtime_error, "hgemm currently only support A x Transpose(B) or A x B");
+    }
+}
+
+void
+MLASCALL
+MlasGemmBatch(
+    CBLAS_TRANSPOSE TransA,
+    CBLAS_TRANSPOSE TransB,
+    size_t M,
+    size_t N,
+    size_t K,
+    const MLAS_HGEMM_DATA_PARAMS* Data,
+    size_t BatchSize,
+    MLAS_THREADPOOL* ThreadPool
+) {
+    if (!ThreadPool) {
+        for (size_t gemm_i = 0; gemm_i < BatchSize; gemm_i++) {
+            HGemmOperation(TransA, TransB, K, &Data[gemm_i], 0, M, 0, N);
+        }
+        return;
+    }
+
+    const double Complexity = double(M) * double(N) * double(K) * double(BatchSize);
+    ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_HGEMM_THREAD_COMPLEXITY)) + 1;
+    ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool);
+
+    if (TargetThreadCount >= MaximumThreadCount) {
+        TargetThreadCount = MaximumThreadCount;
+    }
+
+    // Segment the operation across multiple threads. 
+ + ptrdiff_t ThreadsPerGemm = TargetThreadCount / BatchSize; + if (ThreadsPerGemm < 1) { + ThreadsPerGemm = 1; + } + + constexpr size_t StrideM = 128; + + size_t nc = N; + if (ThreadsPerGemm > 1) { + // more than one thread per GEMM + + const size_t BlockedM = MlasDivRoundup(M, StrideM); + const size_t max_nc = MlasDivRoundup(N * BlockedM, ThreadsPerGemm); + if (max_nc < nc) { + nc = std::min( + nc, MlasDivRoundup(max_nc, MLAS_HGEMM_STRIDEN_THREAD_ALIGN) * MLAS_HGEMM_STRIDEN_THREAD_ALIGN); + } + } + const size_t StrideN = nc; + + const size_t ThreadCountM = MlasDivRoundup(M, StrideM); + const size_t ThreadCountN = MlasDivRoundup(N, StrideN); + ThreadsPerGemm = ThreadCountM * ThreadCountN; + + MlasTrySimpleParallel(ThreadPool, ThreadsPerGemm * static_cast(BatchSize), [&](ptrdiff_t tid) { + const auto gemm_i = tid / ThreadsPerGemm; + const auto blk_i = tid % ThreadsPerGemm; + + const ptrdiff_t ThreadIdN = blk_i / ThreadCountM; + const ptrdiff_t ThreadIdM = blk_i % ThreadCountM; + + const size_t RangeStartM = ThreadIdM * StrideM; + const size_t RangeCountM = std::min(M - RangeStartM, (size_t)StrideM); + + const size_t RangeStartN = ThreadIdN * StrideN; + const size_t RangeCountN = std::min(N - RangeStartN, (size_t)StrideN); + + HGemmOperation(TransA, TransB, K, &Data[gemm_i], RangeStartM, RangeCountM, RangeStartN, RangeCountN); + }); +} const MLAS_HALFGEMM_DISPATCH MlasHalfGemmDispatchDefault = { MlasHalfGemmOperation, diff --git a/src/lib/halfgemm.h b/src/lib/halfgemm.h index 61e2fbb..529db48 100644 --- a/src/lib/halfgemm.h +++ b/src/lib/halfgemm.h @@ -513,3 +513,200 @@ MlasHalfGemmGetDispatch() return &MlasHalfGemmDispatchDefault; #endif } + +namespace hgemm_neon { + +void HPackB_TransposedB_Kernel( + const MLAS_FP16* B, + MLAS_FP16* PackedB, + size_t CountN, + size_t CountK, + size_t ldb +); + +void HPackB_B_Kernel( + const MLAS_FP16* B, + MLAS_FP16* PackedB, + size_t CountN, + size_t CountK, + size_t ldb +); + +void HGemm_TransposedB_Kernel( + const MLAS_FP16* A, 
+ const MLAS_FP16* B, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +); + +void HGemm_B_Kernel( + const MLAS_FP16* A, + const MLAS_FP16* B, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +); + +void HGemm_PackedB_Kernel( + const MLAS_FP16* A, + const MLAS_FP16* PackedB, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +); + +} // namespace hgemm_neon + +struct MLAS_HGEMM_DISPATCH { + /** + * @brief Pack the B matrix segment. B is column-major. Elements from CountK rows x N columns are packed + * continuously in row-major. + * First pack CountK rows x 32 columns, then pack CountK rows x 16 columns, then 8. + * If there are < 8 columns left, pad the columns with 0. + * @param B the first element of the B matrix segment. Column major. + * @param[out] PackedB the first element of the packed B matrix segment. + * @param CountN the number of columns of B chunk. + * @param CountK the number of rows of B chunk. + * @param ldb the leading dimension of B. + */ + typedef void(HPackBKernel_TransposedB_Fn) ( + const MLAS_FP16* B, + MLAS_FP16* PackedB, + size_t CountN, + size_t CountK, + size_t ldb + ); + + HPackBKernel_TransposedB_Fn* HPackBKernel_TransposedB = nullptr; + + /** + * @brief Pack the B matrix segment. B is row-major. Elements from CountK rows x N columns are packed + * continuously in row-major. + * First pack CountK rows x 32 columns, then pack CountK rows x 16 columns, then 8. + * If there are < 8 columns left, pad the columns with 0. + * @param B the first element of the B matrix segment. Row major. + * @param[out] PackedB the first element of the packed B matrix segment. + * @param CountN the number of columns of B chunk. + * @param CountK the number of rows of B chunk. 
+ * @param ldb the leading dimension of B. + */ + typedef void(HPackBKernel_B_Fn) ( + const MLAS_FP16* B, + MLAS_FP16* PackedB, + size_t CountN, + size_t CountK, + size_t ldb + ); + + HPackBKernel_B_Fn* HPackBKernel_B = nullptr; + + /** + * @brief C = alpha * A * Transpose(B) + beta * C. CountM <= 2. B is not packed. Used when M is small. + * + * @param A first row of the A matrix segment. Row major. + * @param B first column of the B matrix segment. Column major. + * @param[out] C first element of the output matrix segment. Row major. + * @param CountM the number of rows of A chunk. + * @param CountN the number of columns of B chunk. + * @param CountK the number of columns of A chunk and the number of rows of B chunk. + * @param lda the leading dimension of A. + * @param ldb the leading dimension of B. + * @param ldc the leading dimension of C. + * @param alpha the alpha scalar value. + * @param beta the beta scalar value. + */ + typedef void(HGemmKernel_TransposedB_Fn)( + const MLAS_FP16* A, + const MLAS_FP16* B, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta + ); + + HGemmKernel_TransposedB_Fn* HGemmKernel_TransposedB = nullptr; + + /** + * @brief C = alpha * A * B + beta * C. CountM <= 2. B is not packed. Used when M is small. + * + * @param A first row of the A matrix segment. Row major. + * @param B first row of the B matrix segment. Row major. + * @param[out] C first element of the output matrix segment. Row major. + * @param CountM the number of rows of A chunk. + * @param CountN the number of columns of B chunk. + * @param CountK the number of columns of A chunk and the number of rows of B chunk. + * @param lda the leading dimension of A. + * @param ldb the leading dimension of B. + * @param ldc the leading dimension of C. + * @param alpha the alpha scalar value. + * @param beta the beta scalar value. 
+ */ + typedef void(HGemmKernel_B_Fn)( + const MLAS_FP16* A, + const MLAS_FP16* B, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta + ); + + HGemmKernel_B_Fn* HGemmKernel_B = nullptr; + + /** + * @brief C = alpha * A * Transpose(B) + beta * C. CountM <= 2. B has been packed using + * HPackBKernel_TransposedB_Fn or HPackBKernel_B_Fn. Use when M is large. + * + * @param A first row of the A matrix segment. Row major. + * @param PackedB first element of the packed B buffer. + * @param[out] C first element of the output matrix segment. Row major. + * @param CountM the number of rows of A chunk. + * @param CountN the number of columns of B chunk. + * @param CountK the number of columns of A chunk and the number of rows of B chunk. + * @param lda the leading dimension of A. + * @param ldc the leading dimension of C. + * @param alpha the alpha scalar value. + * @param beta the beta scalar value. + */ + typedef void(HGemmKernel_PackedB_Fn)( + const MLAS_FP16* A, + const MLAS_FP16* PackedB, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta + ); + + HGemmKernel_PackedB_Fn* HGemmKernel_PackedB = nullptr; +}; diff --git a/src/lib/halfgemm_kernel_neon_fp16.cpp b/src/lib/halfgemm_kernel_neon_fp16.cpp new file mode 100644 index 0000000..959df7f --- /dev/null +++ b/src/lib/halfgemm_kernel_neon_fp16.cpp @@ -0,0 +1,3174 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + halfgemm_kernel_neon_fp16.cpp + +Abstract: + + This module implements half precision GEMM kernel for neon. 
+ +--*/ + +#include + +#include "halfgemm.h" +#include "fp16_common.h" + +namespace hgemm_neon { + +void HPackB_TransposedB_Kernel( + const MLAS_FP16* B, + MLAS_FP16* PackedB, + size_t CountN, + size_t CountK, + size_t ldb +) { + const _mlas_fp16_* B_data = reinterpret_cast(B); + _mlas_fp16_* PackedB_data = reinterpret_cast<_mlas_fp16_*>(PackedB); + const bool Kr0 = (CountK % 4) > 0; + const bool Kr1 = (CountK % 4) > 1; + const bool Kr2 = (CountK % 4) > 2; + const bool Kr3 = CountK & 4; + for (; CountN >= 32; CountN -= 32, B_data += 32 * ldb) { + const _mlas_fp16_* b = B_data; + size_t k = CountK; + constexpr size_t step = 8 * 32; // pack 8 * 16 + for (; k >= 8; k -= 8, b += 8, PackedB_data += step) { + size_t baseb = 0; + size_t basep = 0; + float16x8_t v0 = MlasLoadFloat16x8(b); + float16x8_t v1 = MlasLoadFloat16x8(b + ldb); + float16x8_t v2 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t v3 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t v4 = MlasLoadFloat16x8(b + 4 * ldb); + float16x8_t v5 = MlasLoadFloat16x8(b + 5 * ldb); + float16x8_t v6 = MlasLoadFloat16x8(b + 6 * ldb); + float16x8_t v7 = MlasLoadFloat16x8(b + 7 * ldb); + for (size_t i = 0; i < 3; ++i, baseb += 8 * ldb, basep += 8) { + Transpose8x8(v0, v1, v2, v3, v4, v5, v6, v7); + MlasStoreFloat16x8(PackedB_data + basep, v0); + MlasStoreFloat16x8(PackedB_data + basep + 32, v1); + MlasStoreFloat16x8(PackedB_data + basep + 64, v2); + MlasStoreFloat16x8(PackedB_data + basep + 96, v3); + MlasStoreFloat16x8(PackedB_data + basep + 128, v4); + MlasStoreFloat16x8(PackedB_data + basep + 160, v5); + MlasStoreFloat16x8(PackedB_data + basep + 192, v6); + MlasStoreFloat16x8(PackedB_data + basep + 224, v7); + v0 = MlasLoadFloat16x8(b + baseb + 8 * ldb); + v1 = MlasLoadFloat16x8(b + baseb + 9 * ldb); + v2 = MlasLoadFloat16x8(b + baseb + 10 * ldb); + v3 = MlasLoadFloat16x8(b + baseb + 11 * ldb); + v4 = MlasLoadFloat16x8(b + baseb + 12 * ldb); + v5 = MlasLoadFloat16x8(b + baseb + 13 * ldb); + v6 = MlasLoadFloat16x8(b + 
baseb + 14 * ldb); + v7 = MlasLoadFloat16x8(b + baseb + 15 * ldb); + } + Transpose8x8(v0, v1, v2, v3, v4, v5, v6, v7); + MlasStoreFloat16x8(PackedB_data + basep, v0); + MlasStoreFloat16x8(PackedB_data + basep + 32, v1); + MlasStoreFloat16x8(PackedB_data + basep + 64, v2); + MlasStoreFloat16x8(PackedB_data + basep + 96, v3); + MlasStoreFloat16x8(PackedB_data + basep + 128, v4); + MlasStoreFloat16x8(PackedB_data + basep + 160, v5); + MlasStoreFloat16x8(PackedB_data + basep + 192, v6); + MlasStoreFloat16x8(PackedB_data + basep + 224, v7); + } + + if (Kr3) { + size_t baseb = 0; + size_t basep = 0; + float16x4_t v0 = MlasLoadFloat16x4(b); + float16x4_t v1 = MlasLoadFloat16x4(b + ldb); + float16x4_t v2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t v3 = MlasLoadFloat16x4(b + 3 * ldb); + float16x4_t v4 = MlasLoadFloat16x4(b + 4 * ldb); + float16x4_t v5 = MlasLoadFloat16x4(b + 5 * ldb); + float16x4_t v6 = MlasLoadFloat16x4(b + 6 * ldb); + float16x4_t v7 = MlasLoadFloat16x4(b + 7 * ldb); + for (size_t i = 0; i < 3; ++i, baseb += 8 * ldb, basep += 8) { + Transpose4x4(v0, v1, v2, v3); + Transpose4x4(v4, v5, v6, v7); + MlasStoreFloat16x4(PackedB_data + basep, v0); + MlasStoreFloat16x4(PackedB_data + basep + 4, v4); + MlasStoreFloat16x4(PackedB_data + basep + 32, v1); + MlasStoreFloat16x4(PackedB_data + basep + 36, v5); + MlasStoreFloat16x4(PackedB_data + basep + 64, v2); + MlasStoreFloat16x4(PackedB_data + basep + 68, v6); + MlasStoreFloat16x4(PackedB_data + basep + 96, v3); + MlasStoreFloat16x4(PackedB_data + basep + 100, v7); + v0 = MlasLoadFloat16x4(b + baseb + 8 * ldb); + v1 = MlasLoadFloat16x4(b + baseb + 9 * ldb); + v2 = MlasLoadFloat16x4(b + baseb + 10 * ldb); + v3 = MlasLoadFloat16x4(b + baseb + 11 * ldb); + v4 = MlasLoadFloat16x4(b + baseb + 12 * ldb); + v5 = MlasLoadFloat16x4(b + baseb + 13 * ldb); + v6 = MlasLoadFloat16x4(b + baseb + 14 * ldb); + v7 = MlasLoadFloat16x4(b + baseb + 15 * ldb); + } + Transpose4x4(v0, v1, v2, v3); + Transpose4x4(v4, v5, v6, v7); + 
MlasStoreFloat16x4(PackedB_data + basep, v0); + MlasStoreFloat16x4(PackedB_data + basep + 4, v4); + MlasStoreFloat16x4(PackedB_data + basep + 32, v1); + MlasStoreFloat16x4(PackedB_data + basep + 36, v5); + MlasStoreFloat16x4(PackedB_data + basep + 64, v2); + MlasStoreFloat16x4(PackedB_data + basep + 68, v6); + MlasStoreFloat16x4(PackedB_data + basep + 96, v3); + MlasStoreFloat16x4(PackedB_data + basep + 100, v7); + k -= 4, b += 4, PackedB_data += 4 * 32; + } + + if (Kr0) { + size_t baseb = 0; + size_t basep = 0; + float16x4_t v0 = MlasLoadPartialFloat16x4(b, k); + float16x4_t v1 = MlasLoadPartialFloat16x4(b + ldb, k); + float16x4_t v2 = MlasLoadPartialFloat16x4(b + 2 * ldb, k); + float16x4_t v3 = MlasLoadPartialFloat16x4(b + 3 * ldb, k); + float16x4_t v4 = MlasLoadPartialFloat16x4(b + 4 * ldb, k); + float16x4_t v5 = MlasLoadPartialFloat16x4(b + 5 * ldb, k); + float16x4_t v6 = MlasLoadPartialFloat16x4(b + 6 * ldb, k); + float16x4_t v7 = MlasLoadPartialFloat16x4(b + 7 * ldb, k); + for (size_t i = 0; i < 3; ++i, baseb += 8 * ldb, basep += 8) { + Transpose4x4(v0, v1, v2, v3); + Transpose4x4(v4, v5, v6, v7); + MlasStoreFloat16x4(PackedB_data + basep, v0); + MlasStoreFloat16x4(PackedB_data + basep + 4, v4); + if (Kr1) { + MlasStoreFloat16x4(PackedB_data + basep + 32, v1); + MlasStoreFloat16x4(PackedB_data + basep + 36, v5); + } + if (Kr2) { + MlasStoreFloat16x4(PackedB_data + basep + 64, v2); + MlasStoreFloat16x4(PackedB_data + basep + 68, v6); + } + v0 = MlasLoadPartialFloat16x4(b + baseb + 8 * ldb, k); + v1 = MlasLoadPartialFloat16x4(b + baseb + 9 * ldb, k); + v2 = MlasLoadPartialFloat16x4(b + baseb + 10 * ldb, k); + v3 = MlasLoadPartialFloat16x4(b + baseb + 11 * ldb, k); + v4 = MlasLoadPartialFloat16x4(b + baseb + 12 * ldb, k); + v5 = MlasLoadPartialFloat16x4(b + baseb + 13 * ldb, k); + v6 = MlasLoadPartialFloat16x4(b + baseb + 14 * ldb, k); + v7 = MlasLoadPartialFloat16x4(b + baseb + 15 * ldb, k); + + } + Transpose4x4(v0, v1, v2, v3); + Transpose4x4(v4, v5, v6, v7); 
+ MlasStoreFloat16x4(PackedB_data + basep, v0); + MlasStoreFloat16x4(PackedB_data + basep + 4, v4); + if (Kr1) { + MlasStoreFloat16x4(PackedB_data + basep + 32, v1); + MlasStoreFloat16x4(PackedB_data + basep + 36, v5); + } + if (Kr2) { + MlasStoreFloat16x4(PackedB_data + basep + 64, v2); + MlasStoreFloat16x4(PackedB_data + basep + 68, v6); + } + PackedB_data += k * 32; + } + } + + if (CountN & 16) { + const _mlas_fp16_* b = B_data; + size_t k = CountK; + constexpr size_t step = 8 * 16; // pack 8 * 16 + for (; k >= 8; k -= 8, b += 8, PackedB_data += step) { + float16x8_t v0 = MlasLoadFloat16x8(b); + float16x8_t v1 = MlasLoadFloat16x8(b + ldb); + float16x8_t v2 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t v3 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t v4 = MlasLoadFloat16x8(b + 4 * ldb); + float16x8_t v5 = MlasLoadFloat16x8(b + 5 * ldb); + float16x8_t v6 = MlasLoadFloat16x8(b + 6 * ldb); + float16x8_t v7 = MlasLoadFloat16x8(b + 7 * ldb); + float16x8_t v8 = MlasLoadFloat16x8(b + 8 * ldb); + float16x8_t v9 = MlasLoadFloat16x8(b + 9 * ldb); + float16x8_t vA = MlasLoadFloat16x8(b + 10 * ldb); + float16x8_t vB = MlasLoadFloat16x8(b + 11 * ldb); + float16x8_t vC = MlasLoadFloat16x8(b + 12 * ldb); + float16x8_t vD = MlasLoadFloat16x8(b + 13 * ldb); + float16x8_t vE = MlasLoadFloat16x8(b + 14 * ldb); + float16x8_t vF = MlasLoadFloat16x8(b + 15 * ldb); + Transpose8x8(v0, v1, v2, v3, v4, v5, v6, v7); + Transpose8x8(v8, v9, vA, vB, vC, vD, vE, vF); + + MlasStoreFloat16x8(PackedB_data, v0); + MlasStoreFloat16x8(PackedB_data + 8, v8); + MlasStoreFloat16x8(PackedB_data + 16, v1); + MlasStoreFloat16x8(PackedB_data + 24, v9); + MlasStoreFloat16x8(PackedB_data + 32, v2); + MlasStoreFloat16x8(PackedB_data + 40, vA); + MlasStoreFloat16x8(PackedB_data + 48, v3); + MlasStoreFloat16x8(PackedB_data + 56, vB); + MlasStoreFloat16x8(PackedB_data + 64, v4); + MlasStoreFloat16x8(PackedB_data + 72, vC); + MlasStoreFloat16x8(PackedB_data + 80, v5); + MlasStoreFloat16x8(PackedB_data + 88, 
vD); + MlasStoreFloat16x8(PackedB_data + 96, v6); + MlasStoreFloat16x8(PackedB_data + 104, vE); + MlasStoreFloat16x8(PackedB_data + 112, v7); + MlasStoreFloat16x8(PackedB_data + 120, vF); + } + + if (Kr3) { + float16x4_t v0 = MlasLoadFloat16x4(b); + float16x4_t v1 = MlasLoadFloat16x4(b + ldb); + float16x4_t v2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t v3 = MlasLoadFloat16x4(b + 3 * ldb); + float16x4_t v4 = MlasLoadFloat16x4(b + 4 * ldb); + float16x4_t v5 = MlasLoadFloat16x4(b + 5 * ldb); + float16x4_t v6 = MlasLoadFloat16x4(b + 6 * ldb); + float16x4_t v7 = MlasLoadFloat16x4(b + 7 * ldb); + float16x4_t v8 = MlasLoadFloat16x4(b + 8 * ldb); + float16x4_t v9 = MlasLoadFloat16x4(b + 9 * ldb); + float16x4_t vA = MlasLoadFloat16x4(b + 10 * ldb); + float16x4_t vB = MlasLoadFloat16x4(b + 11 * ldb); + float16x4_t vC = MlasLoadFloat16x4(b + 12 * ldb); + float16x4_t vD = MlasLoadFloat16x4(b + 13 * ldb); + float16x4_t vE = MlasLoadFloat16x4(b + 14 * ldb); + float16x4_t vF = MlasLoadFloat16x4(b + 15 * ldb); + Transpose4x4(v0, v1, v2, v3); + Transpose4x4(v4, v5, v6, v7); + Transpose4x4(v8, v9, vA, vB); + Transpose4x4(vC, vD, vE, vF); + MlasStoreFloat16x4(PackedB_data, v0); + MlasStoreFloat16x4(PackedB_data + 4, v4); + MlasStoreFloat16x4(PackedB_data + 8, v8); + MlasStoreFloat16x4(PackedB_data + 12, vC); + MlasStoreFloat16x4(PackedB_data + 16, v1); + MlasStoreFloat16x4(PackedB_data + 20, v5); + MlasStoreFloat16x4(PackedB_data + 24, v9); + MlasStoreFloat16x4(PackedB_data + 28, vD); + MlasStoreFloat16x4(PackedB_data + 32, v2); + MlasStoreFloat16x4(PackedB_data + 36, v6); + MlasStoreFloat16x4(PackedB_data + 40, vA); + MlasStoreFloat16x4(PackedB_data + 44, vE); + MlasStoreFloat16x4(PackedB_data + 48, v3); + MlasStoreFloat16x4(PackedB_data + 52, v7); + MlasStoreFloat16x4(PackedB_data + 56, vB); + MlasStoreFloat16x4(PackedB_data + 60, vF); + + k -= 4, b += 4, PackedB_data += 4 * 16; + } + + if (Kr0) { + float16x4_t v0 = MlasLoadPartialFloat16x4(b, k); + float16x4_t v1 = 
MlasLoadPartialFloat16x4(b + ldb, k); + float16x4_t v2 = MlasLoadPartialFloat16x4(b + 2 * ldb, k); + float16x4_t v3 = MlasLoadPartialFloat16x4(b + 3 * ldb, k); + float16x4_t v4 = MlasLoadPartialFloat16x4(b + 4 * ldb, k); + float16x4_t v5 = MlasLoadPartialFloat16x4(b + 5 * ldb, k); + float16x4_t v6 = MlasLoadPartialFloat16x4(b + 6 * ldb, k); + float16x4_t v7 = MlasLoadPartialFloat16x4(b + 7 * ldb, k); + float16x4_t v8 = MlasLoadPartialFloat16x4(b + 8 * ldb, k); + float16x4_t v9 = MlasLoadPartialFloat16x4(b + 9 * ldb, k); + float16x4_t vA = MlasLoadPartialFloat16x4(b + 10 * ldb, k); + float16x4_t vB = MlasLoadPartialFloat16x4(b + 11 * ldb, k); + float16x4_t vC = MlasLoadPartialFloat16x4(b + 12 * ldb, k); + float16x4_t vD = MlasLoadPartialFloat16x4(b + 13 * ldb, k); + float16x4_t vE = MlasLoadPartialFloat16x4(b + 14 * ldb, k); + float16x4_t vF = MlasLoadPartialFloat16x4(b + 15 * ldb, k); + Transpose4x4(v0, v1, v2, v3); + Transpose4x4(v4, v5, v6, v7); + Transpose4x4(v8, v9, vA, vB); + Transpose4x4(vC, vD, vE, vF); + MlasStoreFloat16x4(PackedB_data, v0); + MlasStoreFloat16x4(PackedB_data + 4, v4); + MlasStoreFloat16x4(PackedB_data + 8, v8); + MlasStoreFloat16x4(PackedB_data + 12, vC); + if (Kr1) { + MlasStoreFloat16x4(PackedB_data + 16, v1); + MlasStoreFloat16x4(PackedB_data + 20, v5); + MlasStoreFloat16x4(PackedB_data + 24, v9); + MlasStoreFloat16x4(PackedB_data + 28, vD); + } + if (Kr2) { + MlasStoreFloat16x4(PackedB_data + 32, v2); + MlasStoreFloat16x4(PackedB_data + 36, v6); + MlasStoreFloat16x4(PackedB_data + 40, vA); + MlasStoreFloat16x4(PackedB_data + 44, vE); + } + + PackedB_data += k * 16; + } + + CountN -= 16, B_data += 16 * ldb; + } + + if (CountN & 8) { + const _mlas_fp16_* b = B_data; + size_t k = CountK; + constexpr size_t step = 8 * 8; // pack 8 * 8 + for (; k >= 8; k -= 8, b += 8, PackedB_data += step) { + float16x8_t v0 = MlasLoadFloat16x8(b); + float16x8_t v1 = MlasLoadFloat16x8(b + ldb); + float16x8_t v2 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t 
v3 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t v4 = MlasLoadFloat16x8(b + 4 * ldb); + float16x8_t v5 = MlasLoadFloat16x8(b + 5 * ldb); + float16x8_t v6 = MlasLoadFloat16x8(b + 6 * ldb); + float16x8_t v7 = MlasLoadFloat16x8(b + 7 * ldb); + Transpose8x8(v0, v1, v2, v3, v4, v5, v6, v7); + + MlasStoreFloat16x8(PackedB_data, v0); + MlasStoreFloat16x8(PackedB_data + 8, v1); + MlasStoreFloat16x8(PackedB_data + 16, v2); + MlasStoreFloat16x8(PackedB_data + 24, v3); + MlasStoreFloat16x8(PackedB_data + 32, v4); + MlasStoreFloat16x8(PackedB_data + 40, v5); + MlasStoreFloat16x8(PackedB_data + 48, v6); + MlasStoreFloat16x8(PackedB_data + 56, v7); + } + + if (Kr3) { + float16x4_t v0 = MlasLoadFloat16x4(b); + float16x4_t v1 = MlasLoadFloat16x4(b + ldb); + float16x4_t v2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t v3 = MlasLoadFloat16x4(b + 3 * ldb); + float16x4_t v4 = MlasLoadFloat16x4(b + 4 * ldb); + float16x4_t v5 = MlasLoadFloat16x4(b + 5 * ldb); + float16x4_t v6 = MlasLoadFloat16x4(b + 6 * ldb); + float16x4_t v7 = MlasLoadFloat16x4(b + 7 * ldb); + Transpose4x4(v0, v1, v2, v3); + Transpose4x4(v4, v5, v6, v7); + MlasStoreFloat16x4(PackedB_data, v0); + MlasStoreFloat16x4(PackedB_data + 4, v4); + MlasStoreFloat16x4(PackedB_data + 8, v1); + MlasStoreFloat16x4(PackedB_data + 12, v5); + MlasStoreFloat16x4(PackedB_data + 16, v2); + MlasStoreFloat16x4(PackedB_data + 20, v6); + MlasStoreFloat16x4(PackedB_data + 24, v3); + MlasStoreFloat16x4(PackedB_data + 28, v7); + k -= 4, b += 4, PackedB_data += 4 * 8; + } + + if (Kr0) { + float16x4_t v0 = MlasLoadPartialFloat16x4(b, k); + float16x4_t v1 = MlasLoadPartialFloat16x4(b + ldb, k); + float16x4_t v2 = MlasLoadPartialFloat16x4(b + 2 * ldb, k); + float16x4_t v3 = MlasLoadPartialFloat16x4(b + 3 * ldb, k); + float16x4_t v4 = MlasLoadPartialFloat16x4(b + 4 * ldb, k); + float16x4_t v5 = MlasLoadPartialFloat16x4(b + 5 * ldb, k); + float16x4_t v6 = MlasLoadPartialFloat16x4(b + 6 * ldb, k); + float16x4_t v7 = MlasLoadPartialFloat16x4(b + 7 * 
ldb, k); + Transpose4x4(v0, v1, v2, v3); + Transpose4x4(v4, v5, v6, v7); + MlasStoreFloat16x4(PackedB_data, v0); + MlasStoreFloat16x4(PackedB_data + 4, v4); + if (Kr1) { + MlasStoreFloat16x4(PackedB_data + 8, v1); + MlasStoreFloat16x4(PackedB_data + 12, v5); + } + if (Kr2) { + MlasStoreFloat16x4(PackedB_data + 16, v2); + MlasStoreFloat16x4(PackedB_data + 20, v6); + } + + PackedB_data += k * 8; + } + + B_data += 8 * ldb; + CountN -= 8; + } + + if (CountN > 0) { + const _mlas_fp16_* b = B_data; + size_t k = CountK; + constexpr size_t step = 8 * 8; // pack extended 8 * 8 + for (; k >= 8; k -= 8, b += 8, PackedB_data += step) { + float16x8_t v[8]; + size_t i = 0; + for (; i < CountN; ++i) { + v[i] = MlasLoadFloat16x8(b + i * ldb); + } + for (; i < 8; ++i) { + v[i] = MlasZeroFloat16x8(); + } + Transpose8x8(v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]); + MlasStoreFloat16x8(PackedB_data, v[0]); + MlasStoreFloat16x8(PackedB_data + 8, v[1]); + MlasStoreFloat16x8(PackedB_data + 16, v[2]); + MlasStoreFloat16x8(PackedB_data + 24, v[3]); + MlasStoreFloat16x8(PackedB_data + 32, v[4]); + MlasStoreFloat16x8(PackedB_data + 40, v[5]); + MlasStoreFloat16x8(PackedB_data + 48, v[6]); + MlasStoreFloat16x8(PackedB_data + 56, v[7]); + } + + if (Kr3) { + float16x4_t v[8]; + size_t i = 0; + for (; i < CountN; ++i) { + v[i] = MlasLoadFloat16x4(b + i * ldb); + } + for (; i < 8; ++i) { + v[i] = MlasZeroFloat16x4(); + } + Transpose4x4(v[0], v[1], v[2], v[3]); + Transpose4x4(v[4], v[5], v[6], v[7]); + MlasStoreFloat16x4(PackedB_data, v[0]); + MlasStoreFloat16x4(PackedB_data + 4, v[4]); + MlasStoreFloat16x4(PackedB_data + 8, v[1]); + MlasStoreFloat16x4(PackedB_data + 12, v[5]); + MlasStoreFloat16x4(PackedB_data + 16, v[2]); + MlasStoreFloat16x4(PackedB_data + 20, v[6]); + MlasStoreFloat16x4(PackedB_data + 24, v[3]); + MlasStoreFloat16x4(PackedB_data + 28, v[7]); + k -= 4, b += 4, PackedB_data += 4 * 8; + } + + if (Kr0) { + float16x4_t v[8]; + size_t i = 0; + for (; i < CountN; ++i) { + v[i] = 
MlasLoadPartialFloat16x4(b + i * ldb, k); + } + for (; i < 8; ++i) { + v[i] = MlasZeroFloat16x4(); + } + Transpose4x4(v[0], v[1], v[2], v[3]); + Transpose4x4(v[4], v[5], v[6], v[7]); + MlasStoreFloat16x4(PackedB_data, v[0]); + MlasStoreFloat16x4(PackedB_data + 4, v[4]); + if (Kr1) { + MlasStoreFloat16x4(PackedB_data + 8, v[1]); + MlasStoreFloat16x4(PackedB_data + 12, v[5]); + } + if (Kr2) { + MlasStoreFloat16x4(PackedB_data + 16, v[2]); + MlasStoreFloat16x4(PackedB_data + 20, v[6]); + } + } + } +} + +void HPackB_B_Kernel( + const MLAS_FP16* B, + MLAS_FP16* PackedB, + size_t CountN, + size_t CountK, + size_t ldb +) { + const _mlas_fp16_* B_data = reinterpret_cast(B); + _mlas_fp16_* PackedB_data = reinterpret_cast<_mlas_fp16_*>(PackedB); + + for (; CountN >= 32; CountN -= 32, B_data += 32) { + const _mlas_fp16_* b = B_data; + size_t k = CountK; + uint16x8x4_t v0 = vld4q_u16(b); + for (; k >= 2; --k, b += ldb, PackedB_data += 32) { + vst4q_u16(PackedB_data, v0); + v0 = vld4q_u16(b + ldb); + } + vst4q_u16(PackedB_data, v0); + PackedB_data += 32; + } + + if (CountN & 16) { + const _mlas_fp16_* b = B_data; + size_t k = CountK; + uint16x8x2_t v0 = vld2q_u16(b); + for (; k >= 2; --k, b += ldb, PackedB_data += 16) { + vst2q_u16(PackedB_data, v0); + v0 = vld2q_u16(b + ldb); + } + vst2q_u16(PackedB_data, v0); + PackedB_data += 16; + CountN -= 16, B_data += 16; + } + + if (CountN & 8) { + const _mlas_fp16_* b = B_data; + size_t k = CountK; + uint16x8_t v0 = vld1q_u16(b); + for (; k >= 2; --k, b += ldb, PackedB_data += 8) { + vst1q_u16(PackedB_data, v0); + v0 = vld1q_u16(b + ldb); + } + vst1q_u16(PackedB_data, v0); + PackedB_data += 8; + + B_data += 8; + CountN -= 8; + } + + if (CountN > 4) { + float16x4_t v0 = MlasLoadFloat16x4(B_data); + float16x4_t v1 = MlasLoadPartialFloat16x4(B_data + 4, CountN - 4); + for (; CountK >= 2; B_data += ldb, PackedB_data += 8, --CountK) { + MlasStoreFloat16x4(PackedB_data, v0); + MlasStoreFloat16x4(PackedB_data + 4, v1); + v0 = 
MlasLoadFloat16x4(B_data + ldb); + v1 = MlasLoadPartialFloat16x4(B_data + ldb + 4, CountN - 4); + } + MlasStoreFloat16x4(PackedB_data, v0); + MlasStoreFloat16x4(PackedB_data + 4, v1); + } else if (CountN > 0) { + float16x4_t v0 = MlasLoadPartialFloat16x4(B_data, CountN); + for (; CountK >= 2; B_data += ldb, PackedB_data += 8, --CountK) { + MlasStoreFloat16x4(PackedB_data, v0); + v0 = MlasLoadPartialFloat16x4(B_data + ldb, CountN); + } + MlasStoreFloat16x4(PackedB_data, v0); + } +} + +MLAS_FORCEINLINE +float16x8_t addq_f16x4(float16x8_t v0, float16x8_t v1, float16x8_t v2, float16x8_t v3) { + v0 = vaddq_f16(v0, v1); + v2 = vaddq_f16(v2, v3); + v0 = vaddq_f16(v0, v2); + return v0; +} + +MLAS_FORCEINLINE +float16x8_t addq_f16x8(float16x8_t v0, float16x8_t v1, float16x8_t v2, float16x8_t v3, + float16x8_t v4, float16x8_t v5, float16x8_t v6, float16x8_t v7) { + return vaddq_f16(addq_f16x4(v0, v1, v2, v3), addq_f16x4(v4, v5, v6, v7)); +} + +MLAS_FORCEINLINE +float16x8_t maq_lane_f16_accu(float16x8_t accu0, float16x8_t v0, float16x8_t v1, float16x8_t v2, float16x8_t v3, + float16x4_t a0) { + accu0 = vfmaq_lane_f16(accu0, v0, a0, 0); + accu0 = vfmaq_lane_f16(accu0, v1, a0, 1); + accu0 = vfmaq_lane_f16(accu0, v2, a0, 2); + accu0 = vfmaq_lane_f16(accu0, v3, a0, 3); + return accu0; +} + +MLAS_FORCEINLINE +float16x8_t maq_laneq_f16_accu(float16x8_t accu0, float16x8_t v0, float16x8_t v1, float16x8_t v2, float16x8_t v3, + float16x8_t v4, float16x8_t v5, float16x8_t v6, float16x8_t v7, float16x8_t a0) { + accu0 = vfmaq_laneq_f16(accu0, v0, a0, 0); + accu0 = vfmaq_laneq_f16(accu0, v1, a0, 1); + accu0 = vfmaq_laneq_f16(accu0, v2, a0, 2); + accu0 = vfmaq_laneq_f16(accu0, v3, a0, 3); + accu0 = vfmaq_laneq_f16(accu0, v4, a0, 4); + accu0 = vfmaq_laneq_f16(accu0, v5, a0, 5); + accu0 = vfmaq_laneq_f16(accu0, v6, a0, 6); + accu0 = vfmaq_laneq_f16(accu0, v7, a0, 7); + return accu0; +} + +MLAS_FORCEINLINE +float16x4_t ma_laneq_f16_accu(float16x4_t accu0, float16x4_t v0, float16x4_t v1, 
float16x4_t v2, float16x4_t v3, + float16x4_t v4, float16x4_t v5, float16x4_t v6, float16x4_t v7, float16x8_t a0) { + accu0 = vfma_laneq_f16(accu0, v0, a0, 0); + accu0 = vfma_laneq_f16(accu0, v1, a0, 1); + accu0 = vfma_laneq_f16(accu0, v2, a0, 2); + accu0 = vfma_laneq_f16(accu0, v3, a0, 3); + accu0 = vfma_laneq_f16(accu0, v4, a0, 4); + accu0 = vfma_laneq_f16(accu0, v5, a0, 5); + accu0 = vfma_laneq_f16(accu0, v6, a0, 6); + accu0 = vfma_laneq_f16(accu0, v7, a0, 7); + return accu0; +} + +MLAS_FORCEINLINE +float16x4_t ma_lane_f16_accu(float16x4_t accu, float16x4_t v0, float16x4_t v1, float16x4_t v2, float16x4_t v3, + float16x4_t a0) { + accu = vfma_lane_f16(accu, v0, a0, 0); + accu = vfma_lane_f16(accu, v1, a0, 1); + accu = vfma_lane_f16(accu, v2, a0, 2); + accu = vfma_lane_f16(accu, v3, a0, 3); + return accu; +} + +// beta_behavior: beta == 0.0f16 -> 0, beta == 1.0f16 -> 1, otherwise -> 2 +template +void HGemm_TransposedB_Kernel_Impl( + const _mlas_fp16_* A_data, + const _mlas_fp16_* B_data, + _mlas_fp16_* C_data, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +) { + const bool largeK = CountK >= 8; + const bool Kr0 = (CountK & 3); + const bool Kr1 = (CountK & 3) > 1; + const bool Kr2 = (CountK & 3) > 2; + const bool Kr3 = (CountK & 4); + for (; CountN >= 8; CountN -= 8, B_data += 8 * ldb, C_data += 8) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu02 = MlasZeroFloat16x8(); + float16x8_t accu03 = MlasZeroFloat16x8(); + float16x8_t accu04 = MlasZeroFloat16x8(); + float16x8_t accu05 = MlasZeroFloat16x8(); + float16x8_t accu06 = MlasZeroFloat16x8(); + float16x8_t accu07 = MlasZeroFloat16x8(); + float16x8_t accu10, accu11, accu12, accu13, accu14, accu15, accu16, accu17; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x8(); + accu11 = MlasZeroFloat16x8(); + accu12 
= MlasZeroFloat16x8(); + accu13 = MlasZeroFloat16x8(); + accu14 = MlasZeroFloat16x8(); + accu15 = MlasZeroFloat16x8(); + accu16 = MlasZeroFloat16x8(); + accu17 = MlasZeroFloat16x8(); + } + if (largeK) { + float16x8_t b0 = MlasLoadFloat16x8(b); + float16x8_t b1 = MlasLoadFloat16x8(b + ldb); + float16x8_t b2 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b3 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t b4 = MlasLoadFloat16x8(b + 4 * ldb); + float16x8_t b5 = MlasLoadFloat16x8(b + 5 * ldb); + float16x8_t b6 = MlasLoadFloat16x8(b + 6 * ldb); + float16x8_t b7 = MlasLoadFloat16x8(b + 7 * ldb); + float16x8_t a0 = MlasLoadFloat16x8(a); + for (; k >= 16; k -= 8, a += 8, b += 8) { + accu00 = vfmaq_f16(accu00, b0, a0); + accu01 = vfmaq_f16(accu01, b1, a0); + accu02 = vfmaq_f16(accu02, b2, a0); + accu03 = vfmaq_f16(accu03, b3, a0); + accu04 = vfmaq_f16(accu04, b4, a0); + accu05 = vfmaq_f16(accu05, b5, a0); + accu06 = vfmaq_f16(accu06, b6, a0); + accu07 = vfmaq_f16(accu07, b7, a0); + if constexpr (CountM == 2) { + float16x8_t a1 = MlasLoadFloat16x8(a + lda); + accu10 = vfmaq_f16(accu10, b0, a1); + accu11 = vfmaq_f16(accu11, b1, a1); + accu12 = vfmaq_f16(accu12, b2, a1); + accu13 = vfmaq_f16(accu13, b3, a1); + accu14 = vfmaq_f16(accu14, b4, a1); + accu15 = vfmaq_f16(accu15, b5, a1); + accu16 = vfmaq_f16(accu16, b6, a1); + accu17 = vfmaq_f16(accu17, b7, a1); + } + b0 = MlasLoadFloat16x8(b + 8); + b1 = MlasLoadFloat16x8(b + 8 + ldb); + b2 = MlasLoadFloat16x8(b + 8 + 2 * ldb); + b3 = MlasLoadFloat16x8(b + 8 + 3 * ldb); + b4 = MlasLoadFloat16x8(b + 8 + 4 * ldb); + b5 = MlasLoadFloat16x8(b + 8 + 5 * ldb); + b6 = MlasLoadFloat16x8(b + 8 + 6 * ldb); + b7 = MlasLoadFloat16x8(b + 8 + 7 * ldb); + a0 = MlasLoadFloat16x8(a + 8); + } + accu00 = vfmaq_f16(accu00, b0, a0); + accu01 = vfmaq_f16(accu01, b1, a0); + accu02 = vfmaq_f16(accu02, b2, a0); + accu03 = vfmaq_f16(accu03, b3, a0); + accu04 = vfmaq_f16(accu04, b4, a0); + accu05 = vfmaq_f16(accu05, b5, a0); + accu06 = vfmaq_f16(accu06, b6, 
a0); + accu07 = vfmaq_f16(accu07, b7, a0); + if constexpr (CountM == 2) { + float16x8_t a1 = MlasLoadFloat16x8(a + lda); + accu10 = vfmaq_f16(accu10, b0, a1); + accu11 = vfmaq_f16(accu11, b1, a1); + accu12 = vfmaq_f16(accu12, b2, a1); + accu13 = vfmaq_f16(accu13, b3, a1); + accu14 = vfmaq_f16(accu14, b4, a1); + accu15 = vfmaq_f16(accu15, b5, a1); + accu16 = vfmaq_f16(accu16, b6, a1); + accu17 = vfmaq_f16(accu17, b7, a1); + } + k -= 8, a += 8, b += 8; + } + Transpose8x8(accu00, accu01, accu02, accu03, accu04, accu05, accu06, accu07); + accu00 = addq_f16x8(accu00, accu01, accu02, accu03, accu04, accu05, accu06, accu07); + if constexpr (CountM == 2) { + Transpose8x8(accu10, accu11, accu12, accu13, accu14, accu15, accu16, accu17); + accu10 = addq_f16x8(accu10, accu11, accu12, accu13, accu14, accu15, accu16, accu17); + } + + if (Kr3) { + float16x4_t b0 = MlasLoadFloat16x4(b); + float16x4_t b1 = MlasLoadFloat16x4(b + ldb); + float16x4_t b2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t b3 = MlasLoadFloat16x4(b + 3 * ldb); + float16x4_t b4 = MlasLoadFloat16x4(b + 4 * ldb); + float16x4_t b5 = MlasLoadFloat16x4(b + 5 * ldb); + float16x4_t b6 = MlasLoadFloat16x4(b + 6 * ldb); + float16x4_t b7 = MlasLoadFloat16x4(b + 7 * ldb); + Transpose4x4(b0, b1, b2, b3); + Transpose4x4(b4, b5, b6, b7); + float16x8_t v0 = vcombine_f16(b0, b4); + float16x8_t v1 = vcombine_f16(b1, b5); + float16x8_t v2 = vcombine_f16(b2, b6); + float16x8_t v3 = vcombine_f16(b3, b7); + float16x4_t a0 = MlasLoadFloat16x4(a); + accu00 = maq_lane_f16_accu(accu00, v0, v1, v2, v3, a0); + if constexpr (CountM == 2) { + float16x4_t a1 = MlasLoadFloat16x4(a + lda); + accu10 = maq_lane_f16_accu(accu10, v0, v1, v2, v3, a1); + } + k -= 4, a += 4, b += 4; + } + + if (Kr0) { + float16x4_t b0 = MlasLoadPartialFloat16x4(b, k); + float16x4_t b1 = MlasLoadPartialFloat16x4(b + ldb, k); + float16x4_t b2 = MlasLoadPartialFloat16x4(b + 2 * ldb, k); + float16x4_t b3 = MlasLoadPartialFloat16x4(b + 3 * ldb, k); + float16x4_t b4 = 
MlasLoadPartialFloat16x4(b + 4 * ldb, k); + float16x4_t b5 = MlasLoadPartialFloat16x4(b + 5 * ldb, k); + float16x4_t b6 = MlasLoadPartialFloat16x4(b + 6 * ldb, k); + float16x4_t b7 = MlasLoadPartialFloat16x4(b + 7 * ldb, k); + Transpose4x4(b0, b1, b2, b3); + Transpose4x4(b4, b5, b6, b7); + float16x8_t v0 = vcombine_f16(b0, b4); + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + accu00 = vfmaq_lane_f16(accu00, v0, a0, 0); + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + accu10 = vfmaq_lane_f16(accu10, v0, a1, 0); + } + if (Kr1) { + float16x8_t v1 = vcombine_f16(b1, b5); + accu00 = vfmaq_lane_f16(accu00, v1, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, v1, a1, 1); + } + } + if (Kr2) { + float16x8_t v2 = vcombine_f16(b2, b6); + accu00 = vfmaq_lane_f16(accu00, v2, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, v2, a1, 2); + } + } + } + + if constexpr (beta_behavior == 1) { + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + float16x8_t c0 = MlasLoadFloat16x8(C_data); + accu00 = vfmaq_f16(c0, accu00, alpha_v); + MlasStoreFloat16x8(C_data, accu00); + if constexpr (CountM == 2) { + float16x8_t c1 = MlasLoadFloat16x8(C_data + ldc); + accu10 = vfmaq_f16(c1, accu10, alpha_v); + MlasStoreFloat16x8(C_data + ldc, accu10); + } + } else if constexpr (beta_behavior == 2) { + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + float16x8_t beta_v = MlasBroadcastFloat16x8(beta); + float16x8_t c0 = MlasLoadFloat16x8(C_data); + accu00 = vfmaq_f16(vmulq_f16(c0, beta_v), accu00, alpha_v); + MlasStoreFloat16x8(C_data, accu00); + if constexpr (CountM == 2) { + float16x8_t c1 = MlasLoadFloat16x8(C_data + ldc); + accu10 = vfmaq_f16(vmulq_f16(c1, beta_v), accu10, alpha_v); + MlasStoreFloat16x8(C_data + ldc, accu10); + } + } else { + float16x8_t alpha_v = MlasBroadcastFloat16x8(alpha); + accu00 = vmulq_f16(accu00, alpha_v); + MlasStoreFloat16x8(C_data, accu00); + if constexpr (CountM == 2) { + 
accu10 = vmulq_f16(accu10, alpha_v); + MlasStoreFloat16x8(C_data + ldc, accu10); + } + } + } + + if (CountN & 4) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu02 = MlasZeroFloat16x8(); + float16x8_t accu03 = MlasZeroFloat16x8(); + float16x8_t accu10, accu11, accu12, accu13; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x8(); + accu11 = MlasZeroFloat16x8(); + accu12 = MlasZeroFloat16x8(); + accu13 = MlasZeroFloat16x8(); + } + if (largeK) { + float16x8_t b0 = MlasLoadFloat16x8(b); + float16x8_t b1 = MlasLoadFloat16x8(b + ldb); + float16x8_t b2 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b3 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t a0 = MlasLoadFloat16x8(a); + for (; k >= 16; k -= 8, a += 8, b += 8) { + accu00 = vfmaq_f16(accu00, b0, a0); + accu01 = vfmaq_f16(accu01, b1, a0); + accu02 = vfmaq_f16(accu02, b2, a0); + accu03 = vfmaq_f16(accu03, b3, a0); + if constexpr (CountM == 2) { + float16x8_t a1 = MlasLoadFloat16x8(a + lda); + accu10 = vfmaq_f16(accu10, b0, a1); + accu11 = vfmaq_f16(accu11, b1, a1); + accu12 = vfmaq_f16(accu12, b2, a1); + accu13 = vfmaq_f16(accu13, b3, a1); + } + b0 = MlasLoadFloat16x8(b + 8); + b1 = MlasLoadFloat16x8(b + 8 + ldb); + b2 = MlasLoadFloat16x8(b + 8 + 2 * ldb); + b3 = MlasLoadFloat16x8(b + 8 + 3 * ldb); + a0 = MlasLoadFloat16x8(a + 8); + } + accu00 = vfmaq_f16(accu00, b0, a0); + accu01 = vfmaq_f16(accu01, b1, a0); + accu02 = vfmaq_f16(accu02, b2, a0); + accu03 = vfmaq_f16(accu03, b3, a0); + if constexpr (CountM == 2) { + float16x8_t a1 = MlasLoadFloat16x8(a + lda); + accu10 = vfmaq_f16(accu10, b0, a1); + accu11 = vfmaq_f16(accu11, b1, a1); + accu12 = vfmaq_f16(accu12, b2, a1); + accu13 = vfmaq_f16(accu13, b3, a1); + } + k -= 8, a += 8, b += 8; + } + Transpose4x8(accu00, accu01, accu02, accu03); + accu00 = addq_f16x4(accu00, accu01, accu02, accu03); + float16x4_t accu0 = 
vadd_f16(vget_low_f16(accu00), vget_high_f16(accu00)), accu1; + if constexpr (CountM == 2) { + Transpose4x8(accu10, accu11, accu12, accu13); + accu10 = addq_f16x4(accu10, accu11, accu12, accu13); + accu1 = vadd_f16(vget_low_f16(accu10), vget_high_f16(accu10)); + } + + if (Kr3) { + float16x4_t b0 = MlasLoadFloat16x4(b); + float16x4_t b1 = MlasLoadFloat16x4(b + ldb); + float16x4_t b2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t b3 = MlasLoadFloat16x4(b + 3 * ldb); + Transpose4x4(b0, b1, b2, b3); + float16x4_t a0 = MlasLoadFloat16x4(a); + accu0 = ma_lane_f16_accu(accu0, b0, b1, b2, b3, a0); + if constexpr (CountM == 2) { + float16x4_t a1 = MlasLoadFloat16x4(a + lda); + accu1 = ma_lane_f16_accu(accu1, b0, b1, b2, b3, a1); + } + k -= 4, a += 4, b += 4; + } + + if (Kr0) { + float16x4_t b0 = MlasLoadPartialFloat16x4(b, k); + float16x4_t b1 = MlasLoadPartialFloat16x4(b + ldb, k); + float16x4_t b2 = MlasLoadPartialFloat16x4(b + 2 * ldb, k); + float16x4_t b3 = MlasLoadPartialFloat16x4(b + 3 * ldb, k); + Transpose4x4(b0, b1, b2, b3); + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + accu0 = vfma_lane_f16(accu0, b0, a0, 0); + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + accu1 = vfma_lane_f16(accu1, b0, a1, 0); + } + if (Kr1) { + accu0 = vfma_lane_f16(accu0, b1, a0, 1); + if constexpr (CountM == 2) { + accu1 = vfma_lane_f16(accu1, b1, a1, 1); + } + } + if (Kr2) { + accu0 = vfma_lane_f16(accu0, b2, a0, 2); + if constexpr (CountM == 2) { + accu1 = vfma_lane_f16(accu1, b2, a1, 2); + } + } + } + + if constexpr (beta_behavior == 1) { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + float16x4_t c0 = MlasLoadFloat16x4(C_data); + accu0 = vfma_f16(c0, accu0, alpha_v); + MlasStoreFloat16x4(C_data, accu0); + if constexpr (CountM == 2) { + float16x4_t c1 = MlasLoadFloat16x4(C_data + ldc); + accu1 = vfma_f16(c1, accu1, alpha_v); + MlasStoreFloat16x4(C_data + ldc, accu1); + } + } else if constexpr (beta_behavior == 2) { + float16x4_t alpha_v = 
MlasBroadcastFloat16x4(alpha); + float16x4_t beta_v = MlasBroadcastFloat16x4(beta); + float16x4_t c0 = MlasLoadFloat16x4(C_data); + accu0 = vfma_f16(vmul_f16(c0, beta_v), accu0, alpha_v); + MlasStoreFloat16x4(C_data, accu0); + if constexpr (CountM == 2) { + float16x4_t c1 = MlasLoadFloat16x4(C_data + ldc); + accu1 = vfma_f16(vmul_f16(c1, beta_v), accu1, alpha_v); + MlasStoreFloat16x4(C_data + ldc, accu1); + } + } else { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + accu0 = vmul_f16(accu0, alpha_v); + MlasStoreFloat16x4(C_data, accu0); + if constexpr (CountM == 2) { + accu1 = vmul_f16(accu1, alpha_v); + MlasStoreFloat16x4(C_data + ldc, accu1); + } + } + + CountN -= 4, B_data += 4 * ldb, C_data += 4; + } + + if (CountN > 0) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x8_t accu0[4], accu1[4]; + size_t i = 0; + for (i = 0; i < 4; ++i) { + accu0[i] = MlasZeroFloat16x8(); + if constexpr (CountM == 2) { + accu1[i] = MlasZeroFloat16x8(); + } + } + for (; k >= 8; k -= 8, a += 8, b += 8) { + float16x8_t a0 = MlasLoadFloat16x8(a), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x8(a + lda); + } + for (i = 0; i < CountN; ++i) { + float16x8_t bi = MlasLoadFloat16x8(b + i * ldb); + accu0[i] = vfmaq_f16(accu0[i], bi, a0); + if constexpr (CountM == 2) { + accu1[i] = vfmaq_f16(accu1[i], bi, a1); + } + } + } + Transpose4x8(accu0[0], accu0[1], accu0[2], accu0[3]); + float16x8_t accu00 = addq_f16x4(accu0[0], accu0[1], accu0[2], accu0[3]); + float16x4_t accu_0 = vadd_f16(vget_low_f16(accu00), vget_high_f16(accu00)), accu_1; + if constexpr (CountM == 2) { + Transpose4x8(accu1[0], accu1[1], accu1[2], accu1[3]); + float16x8_t accu10 = addq_f16x4(accu1[0], accu1[1], accu1[2], accu1[3]); + accu_1 = vadd_f16(vget_low_f16(accu10), vget_high_f16(accu10)); + } + + if (Kr3) { + float16x4_t bs[4]; + for (i = 0; i < CountN; ++i) { + bs[i] = MlasLoadFloat16x4(b + i * ldb); + } + for (; i < 4; ++i) { + bs[i] = MlasZeroFloat16x4(); + } + 
Transpose4x4(bs[0], bs[1], bs[2], bs[3]); + float16x4_t a0 = MlasLoadFloat16x4(a); + accu_0 = ma_lane_f16_accu(accu_0, bs[0], bs[1], bs[2], bs[3], a0); + if constexpr (CountM == 2) { + float16x4_t a1 = MlasLoadFloat16x4(a + lda); + accu_1 = ma_lane_f16_accu(accu_1, bs[0], bs[1], bs[2], bs[3], a1); + } + k -= 4, a += 4, b += 4; + } + + if (Kr0) { + float16x4_t bs[4]; + for (i = 0; i < CountN; ++i) { + bs[i] = MlasLoadPartialFloat16x4(b + i * ldb, k); + } + for (; i < 4; ++i) { + bs[i] = MlasZeroFloat16x4(); + } + Transpose4x4(bs[0], bs[1], bs[2], bs[3]); + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + accu_0 = vfma_lane_f16(accu_0, bs[0], a0, 0); + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + accu_1 = vfma_lane_f16(accu_1, bs[0], a1, 0); + } + if (Kr1) { + accu_0 = vfma_lane_f16(accu_0, bs[1], a0, 1); + if constexpr (CountM == 2) { + accu_1 = vfma_lane_f16(accu_1, bs[1], a1, 1); + } + } + if (Kr2) { + accu_0 = vfma_lane_f16(accu_0, bs[2], a0, 2); + if constexpr (CountM == 2) { + accu_1 = vfma_lane_f16(accu_1, bs[2], a1, 2); + } + } + } + + if constexpr (beta_behavior == 1) { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + float16x4_t c0 = MlasLoadPartialFloat16x4(C_data, CountN); + accu_0 = vfma_f16(c0, accu_0, alpha_v); + MlasStorePartialFloat16x4(C_data, accu_0, CountN); + if constexpr (CountM == 2) { + float16x4_t c1 = MlasLoadPartialFloat16x4(C_data + ldc, CountN); + accu_1 = vfma_f16(c1, accu_1, alpha_v); + MlasStorePartialFloat16x4(C_data + ldc, accu_1, CountN); + } + } else if constexpr (beta_behavior == 2) { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + float16x4_t beta_v = MlasBroadcastFloat16x4(beta); + float16x4_t c0 = MlasLoadPartialFloat16x4(C_data, CountN); + accu_0 = vfma_f16(vmul_f16(c0, beta_v), accu_0, alpha_v); + MlasStorePartialFloat16x4(C_data, accu_0, CountN); + if constexpr (CountM == 2) { + float16x4_t c1 = MlasLoadPartialFloat16x4(C_data + ldc, CountN); + accu_1 = 
vfma_f16(vmul_f16(c1, beta_v), accu_1, alpha_v); + MlasStorePartialFloat16x4(C_data + ldc, accu_1, CountN); + } + } else { + float16x4_t alpha_v = MlasBroadcastFloat16x4(alpha); + accu_0 = vmul_f16(accu_0, alpha_v); + MlasStorePartialFloat16x4(C_data, accu_0, CountN); + if constexpr (CountM == 2) { + accu_1 = vmul_f16(accu_1, alpha_v); + MlasStorePartialFloat16x4(C_data + ldc, accu_1, CountN); + } + } + } +} + +// Full K. Directly save to C. +void HGemm_TransposedB_Kernel( + const MLAS_FP16* A, + const MLAS_FP16* B, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +) { + if (CountM > 2) { + MLAS_THROW_EX(std::runtime_error, "HGemm_TransposedB_Kernel only support <= 2 rows"); + } + const auto* A_data = reinterpret_cast(A); + const auto* B_data = reinterpret_cast(B); + auto* C_data = reinterpret_cast<_mlas_fp16_*>(C); + const auto f16_0 = MLAS_FP16(0.0f); + const auto f16_1 = MLAS_FP16(1.0f); + if (CountM == 1) { + if (beta == f16_0.val) { + HGemm_TransposedB_Kernel_Impl<0, 1>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc, alpha, beta); + } else if (beta == f16_1.val) { + HGemm_TransposedB_Kernel_Impl<1, 1>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc, alpha, beta); + } else { + HGemm_TransposedB_Kernel_Impl<2, 1>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc, alpha, beta); + } + } else { + if (beta == f16_0.val) { + HGemm_TransposedB_Kernel_Impl<0, 2>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc, alpha, beta); + } else if (beta == f16_1.val) { + HGemm_TransposedB_Kernel_Impl<1, 2>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc, alpha, beta); + } else { + HGemm_TransposedB_Kernel_Impl<2, 2>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc, alpha, beta); + } + } +} + +// handle C = alpha * A * B + beta * C where alpha != 1 or beta != 0 or 1 +template +void HGemm_B_Kernel_Complicated( + const _mlas_fp16_* A_data, + const 
_mlas_fp16_* B_data, + _mlas_fp16_* C_data, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +) { + const size_t ldb4 = ldb * 4; + float16x8_t alpha_v8 = MlasBroadcastFloat16x8(alpha); + float16x8_t beta_v8 = MlasBroadcastFloat16x8(beta); + float16x4_t alpha_v4 = MlasBroadcastFloat16x4(alpha); + float16x4_t beta_v4 = MlasBroadcastFloat16x4(beta); + const bool largeK = CountK >= 4; + const bool Kr0 = CountK & 3; + const bool Kr1 = (CountK & 3 ) > 1; + const bool Kr2 = (CountK & 3 ) > 2; + for (; CountN >= 32; CountN -= 32, B_data += 32, C_data += 32) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu02 = MlasZeroFloat16x8(); + float16x8_t accu03 = MlasZeroFloat16x8(); + float16x8_t accu10, accu11, accu12, accu13; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x8(); + accu11 = MlasZeroFloat16x8(); + accu12 = MlasZeroFloat16x8(); + accu13 = MlasZeroFloat16x8(); + } + if (largeK) { + float16x4_t a0 = MlasLoadFloat16x4(a), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda); + } + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b30 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t b01 = MlasLoadFloat16x8(b + 8); + float16x8_t b11 = MlasLoadFloat16x8(b + ldb + 8); + float16x8_t b21 = MlasLoadFloat16x8(b + 2 * ldb + 8); + float16x8_t b31 = MlasLoadFloat16x8(b + 3 * ldb + 8); + float16x8_t b02 = MlasLoadFloat16x8(b + 16); + float16x8_t b12 = MlasLoadFloat16x8(b + ldb + 16); + float16x8_t b22 = MlasLoadFloat16x8(b + 2 * ldb + 16); + float16x8_t b32 = MlasLoadFloat16x8(b + 3 * ldb + 16); + float16x8_t b03 = MlasLoadFloat16x8(b + 24); + float16x8_t b13 = MlasLoadFloat16x8(b + ldb + 24); + float16x8_t b23 = MlasLoadFloat16x8(b + 2 * ldb + 24); + 
float16x8_t b33 = MlasLoadFloat16x8(b + 3 * ldb + 24); + for (; k >= 8; k -= 4, a += 4, b += ldb4) { + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + accu02 = maq_lane_f16_accu(accu02, b02, b12, b22, b32, a0); + accu03 = maq_lane_f16_accu(accu03, b03, b13, b23, b33, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + accu12 = maq_lane_f16_accu(accu12, b02, b12, b22, b32, a1); + accu13 = maq_lane_f16_accu(accu13, b03, b13, b23, b33, a1); + } + a0 = MlasLoadFloat16x4(a + 4); + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + 4 + lda); + } + b00 = MlasLoadFloat16x8(b + ldb4); + b10 = MlasLoadFloat16x8(b + 5 * ldb); + b20 = MlasLoadFloat16x8(b + 6 * ldb); + b30 = MlasLoadFloat16x8(b + 7 * ldb); + b01 = MlasLoadFloat16x8(b + ldb4 + 8); + b11 = MlasLoadFloat16x8(b + 5 * ldb + 8); + b21 = MlasLoadFloat16x8(b + 6 * ldb + 8); + b31 = MlasLoadFloat16x8(b + 7 * ldb + 8); + b02 = MlasLoadFloat16x8(b + ldb4 + 16); + b12 = MlasLoadFloat16x8(b + 5 * ldb + 16); + b22 = MlasLoadFloat16x8(b + 6 * ldb + 16); + b32 = MlasLoadFloat16x8(b + 7 * ldb + 16); + b03 = MlasLoadFloat16x8(b + ldb4 + 24); + b13 = MlasLoadFloat16x8(b + 5 * ldb + 24); + b23 = MlasLoadFloat16x8(b + 6 * ldb + 24); + b33 = MlasLoadFloat16x8(b + 7 * ldb + 24); + } + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + accu02 = maq_lane_f16_accu(accu02, b02, b12, b22, b32, a0); + accu03 = maq_lane_f16_accu(accu03, b03, b13, b23, b33, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + accu12 = maq_lane_f16_accu(accu12, b02, b12, b22, b32, a1); + accu13 = maq_lane_f16_accu(accu13, b03, b13, b23, b33, a1); + } + k -= 4, a += 4, b += ldb4; + } + + if 
(Kr0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + } + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b01 = MlasLoadFloat16x8(b + 8); + float16x8_t b02 = MlasLoadFloat16x8(b + 16); + float16x8_t b03 = MlasLoadFloat16x8(b + 24); + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + accu01 = vfmaq_lane_f16(accu01, b01, a0, 0); + accu02 = vfmaq_lane_f16(accu02, b02, a0, 0); + accu03 = vfmaq_lane_f16(accu03, b03, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + accu11 = vfmaq_lane_f16(accu11, b01, a1, 0); + accu12 = vfmaq_lane_f16(accu12, b02, a1, 0); + accu13 = vfmaq_lane_f16(accu13, b03, a1, 0); + } + if (Kr1) { + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b11 = MlasLoadFloat16x8(b + ldb + 8); + float16x8_t b12 = MlasLoadFloat16x8(b + ldb + 16); + float16x8_t b13 = MlasLoadFloat16x8(b + ldb + 24); + accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + accu01 = vfmaq_lane_f16(accu01, b11, a0, 1); + accu02 = vfmaq_lane_f16(accu02, b12, a0, 1); + accu03 = vfmaq_lane_f16(accu03, b13, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b10, a1, 1); + accu11 = vfmaq_lane_f16(accu11, b11, a1, 1); + accu12 = vfmaq_lane_f16(accu12, b12, a1, 1); + accu13 = vfmaq_lane_f16(accu13, b13, a1, 1); + } + } + if (Kr2) { + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b21 = MlasLoadFloat16x8(b + 2 * ldb + 8); + float16x8_t b22 = MlasLoadFloat16x8(b + 2 * ldb + 16); + float16x8_t b23 = MlasLoadFloat16x8(b + 2 * ldb + 24); + accu00 = vfmaq_lane_f16(accu00, b20, a0, 2); + accu01 = vfmaq_lane_f16(accu01, b21, a0, 2); + accu02 = vfmaq_lane_f16(accu02, b22, a0, 2); + accu03 = vfmaq_lane_f16(accu03, b23, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + accu11 = vfmaq_lane_f16(accu11, b21, a1, 2); + accu12 = vfmaq_lane_f16(accu12, b22, a1, 2); + accu13 = vfmaq_lane_f16(accu13, 
b23, a1, 2); + } + } + } + + float16x8_t c00 = MlasLoadFloat16x8(C_data); + float16x8_t c01 = MlasLoadFloat16x8(C_data + 8); + float16x8_t c02 = MlasLoadFloat16x8(C_data + 16); + float16x8_t c03 = MlasLoadFloat16x8(C_data + 24); + MlasStoreFloat16x8(C_data, vfmaq_f16(vmulq_f16(c00, beta_v8), accu00, alpha_v8)); + MlasStoreFloat16x8(C_data + 8, vfmaq_f16(vmulq_f16(c01, beta_v8), accu01, alpha_v8)); + MlasStoreFloat16x8(C_data + 16, vfmaq_f16(vmulq_f16(c02, beta_v8), accu02, alpha_v8)); + MlasStoreFloat16x8(C_data + 24, vfmaq_f16(vmulq_f16(c03, beta_v8), accu03, alpha_v8)); + if constexpr (CountM == 2) { + float16x8_t c10 = MlasLoadFloat16x8(C_data + ldc); + float16x8_t c11 = MlasLoadFloat16x8(C_data + ldc + 8); + float16x8_t c12 = MlasLoadFloat16x8(C_data + ldc + 16); + float16x8_t c13 = MlasLoadFloat16x8(C_data + ldc + 24); + MlasStoreFloat16x8(C_data + ldc, vfmaq_f16(vmulq_f16(c10, beta_v8), accu10, alpha_v8)); + MlasStoreFloat16x8(C_data + ldc + 8, vfmaq_f16(vmulq_f16(c11, beta_v8), accu11, alpha_v8)); + MlasStoreFloat16x8(C_data + ldc + 16, vfmaq_f16(vmulq_f16(c12, beta_v8), accu12, alpha_v8)); + MlasStoreFloat16x8(C_data + ldc + 24, vfmaq_f16(vmulq_f16(c13, beta_v8), accu13, alpha_v8)); + } + } + + if (CountN & 16) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu10, accu11; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x8(); + accu11 = MlasZeroFloat16x8(); + } + if (largeK) { + float16x4_t a0 = MlasLoadFloat16x4(a), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda); + } + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b30 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t b01 = MlasLoadFloat16x8(b + 8); + float16x8_t b11 = MlasLoadFloat16x8(b + ldb + 8); + float16x8_t b21 = MlasLoadFloat16x8(b + 2 * 
ldb + 8); + float16x8_t b31 = MlasLoadFloat16x8(b + 3 * ldb + 8); + for (; k >= 8; k -= 4, a += 4, b += ldb4) { + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + } + a0 = MlasLoadFloat16x4(a + 4); + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + 4 + lda); + } + b00 = MlasLoadFloat16x8(b + ldb4); + b10 = MlasLoadFloat16x8(b + 5 * ldb); + b20 = MlasLoadFloat16x8(b + 6 * ldb); + b30 = MlasLoadFloat16x8(b + 7 * ldb); + b01 = MlasLoadFloat16x8(b + ldb4 + 8); + b11 = MlasLoadFloat16x8(b + 5 * ldb + 8); + b21 = MlasLoadFloat16x8(b + 6 * ldb + 8); + b31 = MlasLoadFloat16x8(b + 7 * ldb + 8); + + } + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + } + k -= 4, a += 4, b += ldb4; + } + + if (Kr0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + } + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b01 = MlasLoadFloat16x8(b + 8); + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + accu01 = vfmaq_lane_f16(accu01, b01, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + accu11 = vfmaq_lane_f16(accu11, b01, a1, 0); + } + if (Kr1) { + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b11 = MlasLoadFloat16x8(b + ldb + 8); + accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + accu01 = vfmaq_lane_f16(accu01, b11, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b10, a1, 1); + accu11 = vfmaq_lane_f16(accu11, b11, a1, 1); + } + } + if (Kr2) { + float16x8_t b20 = 
MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b21 = MlasLoadFloat16x8(b + 2 * ldb + 8); + accu00 = vfmaq_lane_f16(accu00, b20, a0, 2); + accu01 = vfmaq_lane_f16(accu01, b21, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + accu11 = vfmaq_lane_f16(accu11, b21, a1, 2); + } + } + } + + float16x8_t c00 = MlasLoadFloat16x8(C_data); + float16x8_t c01 = MlasLoadFloat16x8(C_data + 8); + MlasStoreFloat16x8(C_data, vfmaq_f16(vmulq_f16(c00, beta_v8), accu00, alpha_v8)); + MlasStoreFloat16x8(C_data + 8, vfmaq_f16(vmulq_f16(c01, beta_v8), accu01, alpha_v8)); + if constexpr (CountM == 2) { + float16x8_t c10 = MlasLoadFloat16x8(C_data + ldc); + float16x8_t c11 = MlasLoadFloat16x8(C_data + ldc + 8); + MlasStoreFloat16x8(C_data + ldc, vfmaq_f16(vmulq_f16(c10, beta_v8), accu10, alpha_v8)); + MlasStoreFloat16x8(C_data + ldc + 8, vfmaq_f16(vmulq_f16(c11, beta_v8), accu11, alpha_v8)); + } + + CountN -= 16, B_data += 16, C_data += 16; + } + + if (CountN & 8) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu10; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x8(); + } + if (largeK) { + float16x4_t a0 = MlasLoadFloat16x4(a), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda); + } + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b30 = MlasLoadFloat16x8(b + 3 * ldb); + for (; k >= 8; k -= 4, a += 4, b += ldb4) { + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + } + a0 = MlasLoadFloat16x4(a + 4); + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + 4 + lda); + } + b00 = MlasLoadFloat16x8(b + ldb4); + b10 = MlasLoadFloat16x8(b + 5 * ldb); + b20 = MlasLoadFloat16x8(b + 6 * ldb); + b30 = MlasLoadFloat16x8(b + 7 * ldb); + } + 
accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + } + k -= 4, a += 4, b += ldb4; + } + + if (Kr0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + } + float16x8_t b00 = MlasLoadFloat16x8(b); + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + } + if (Kr1) { + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b10, a1, 1); + } + } + if (Kr2) { + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + accu00 = vfmaq_lane_f16(accu00, b20, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + } + } + } + + float16x8_t c00 = MlasLoadFloat16x8(C_data); + MlasStoreFloat16x8(C_data, vfmaq_f16(vmulq_f16(c00, beta_v8), accu00, alpha_v8)); + if constexpr (CountM == 2) { + float16x8_t c10 = MlasLoadFloat16x8(C_data + ldc); + MlasStoreFloat16x8(C_data + ldc, vfmaq_f16(vmulq_f16(c10, beta_v8), accu10, alpha_v8)); + } + + CountN -= 8, B_data += 8, C_data += 8; + } + + if (CountN & 4) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x4_t accu00 = MlasZeroFloat16x4(); + float16x4_t accu10; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x4(); + } + if (largeK) { + float16x4_t a0 = MlasLoadFloat16x4(a), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda); + } + float16x4_t b00 = MlasLoadFloat16x4(b); + float16x4_t b10 = MlasLoadFloat16x4(b + ldb); + float16x4_t b20 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t b30 = MlasLoadFloat16x4(b + 3 * ldb); + for (; k >= 8; k -= 4, a += 4, b += ldb4) { + accu00 = ma_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + if constexpr (CountM == 2) { + accu10 = ma_lane_f16_accu(accu10, b00, 
b10, b20, b30, a1); + } + a0 = MlasLoadFloat16x4(a + 4); + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + 4 + lda); + } + b00 = MlasLoadFloat16x4(b + ldb4); + b10 = MlasLoadFloat16x4(b + 5 * ldb); + b20 = MlasLoadFloat16x4(b + 6 * ldb); + b30 = MlasLoadFloat16x4(b + 7 * ldb); + } + accu00 = ma_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + if constexpr (CountM == 2) { + accu10 = ma_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + } + k -= 4, a += 4, b += ldb4; + } + + if (Kr0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + } + float16x4_t b00 = MlasLoadFloat16x4(b); + accu00 = vfma_lane_f16(accu00, b00, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b00, a1, 0); + } + if (Kr1) { + float16x4_t b10 = MlasLoadFloat16x4(b + ldb); + accu00 = vfma_lane_f16(accu00, b10, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b10, a1, 1); + } + } + if (Kr2) { + float16x4_t b20 = MlasLoadFloat16x4(b + 2 * ldb); + accu00 = vfma_lane_f16(accu00, b20, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b20, a1, 2); + } + } + } + + float16x4_t c00 = MlasLoadFloat16x4(C_data); + MlasStoreFloat16x4(C_data, vfma_f16(vmul_f16(c00, beta_v4), accu00, alpha_v4)); + if constexpr (CountM == 2) { + float16x4_t c10 = MlasLoadFloat16x4(C_data + ldc); + MlasStoreFloat16x4(C_data + ldc, vfma_f16(vmul_f16(c10, beta_v4), accu10, alpha_v4)); + } + + CountN -= 4, B_data += 4, C_data += 4; + } + + if (CountN > 0) { + const auto* a = A_data; + const auto* b = B_data; + size_t k = CountK; + float16x4_t accu00 = MlasZeroFloat16x4(); + float16x4_t accu10; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x4(); + } + if (largeK) { + float16x4_t a0 = MlasLoadFloat16x4(a), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda); + } + float16x4_t b00 = MlasLoadPartialFloat16x4(b, CountN); + float16x4_t b10 = 
MlasLoadPartialFloat16x4(b + ldb, CountN); + float16x4_t b20 = MlasLoadPartialFloat16x4(b + 2 * ldb, CountN); + float16x4_t b30 = MlasLoadPartialFloat16x4(b + 3 * ldb, CountN); + for (; k >= 8; k -= 4, a += 4, b += ldb4) { + accu00 = ma_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + if constexpr (CountM == 2) { + accu10 = ma_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + } + a0 = MlasLoadFloat16x4(a + 4); + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + 4 + lda); + } + b00 = MlasLoadPartialFloat16x4(b + ldb4, CountN); + b10 = MlasLoadPartialFloat16x4(b + 5 * ldb, CountN); + b20 = MlasLoadPartialFloat16x4(b + 6 * ldb, CountN); + b30 = MlasLoadPartialFloat16x4(b + 7 * ldb, CountN); + } + accu00 = ma_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + if constexpr (CountM == 2) { + accu10 = ma_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + } + k -= 4, a += 4, b += ldb4; + } + + if (Kr0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + } + float16x4_t b00 = MlasLoadPartialFloat16x4(b, CountN); + accu00 = vfma_lane_f16(accu00, b00, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b00, a1, 0); + } + if (Kr1) { + float16x4_t b10 = MlasLoadPartialFloat16x4(b + ldb, CountN); + accu00 = vfma_lane_f16(accu00, b10, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b10, a1, 1); + } + } + if (Kr2) { + float16x4_t b20 = MlasLoadPartialFloat16x4(b + 2 * ldb, CountN); + accu00 = vfma_lane_f16(accu00, b20, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b20, a1, 2); + } + } + } + + float16x4_t c00 = MlasLoadPartialFloat16x4(C_data, CountN); + MlasStorePartialFloat16x4(C_data, vfma_f16(vmul_f16(c00, beta_v4), accu00, alpha_v4), CountN); + if constexpr (CountM == 2) { + float16x4_t c10 = MlasLoadPartialFloat16x4(C_data + ldc, CountN); + MlasStorePartialFloat16x4(C_data + ldc, vfma_f16(vmul_f16(c10, beta_v4), accu10, 
alpha_v4), CountN); + } + } +} + +// Handle C = A * B + C or C = A * B +template +void HGemm_B_Kernel_Simple( + const _mlas_fp16_* A_data, + const _mlas_fp16_* B_data, + _mlas_fp16_* C_data, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc +) { + const size_t ldb4 = ldb * 4; + const bool largeN = CountN >= 32; + const bool Nr0 = (CountN % 4) > 0; + const bool Kr1 = (CountK % 4) > 1; + const bool Kr2 = (CountK % 4) > 2; + if constexpr (zero_mode) { + // process first K + if (CountK >= 4) { + float16x4_t a0 = MlasLoadFloat16x4(A_data), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(A_data + lda); + } + size_t n = CountN; + const auto* b = B_data; + auto* c = C_data; + if (largeN) { + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b30 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t b01 = MlasLoadFloat16x8(b + 8); + float16x8_t b11 = MlasLoadFloat16x8(b + ldb + 8); + float16x8_t b21 = MlasLoadFloat16x8(b + 2 * ldb + 8); + float16x8_t b31 = MlasLoadFloat16x8(b + 3 * ldb + 8); + float16x8_t b02 = MlasLoadFloat16x8(b + 16); + float16x8_t b12 = MlasLoadFloat16x8(b + ldb + 16); + float16x8_t b22 = MlasLoadFloat16x8(b + 2 * ldb + 16); + float16x8_t b32 = MlasLoadFloat16x8(b + 3 * ldb + 16); + float16x8_t b03 = MlasLoadFloat16x8(b + 24); + float16x8_t b13 = MlasLoadFloat16x8(b + ldb + 24); + float16x8_t b23 = MlasLoadFloat16x8(b + 2 * ldb + 24); + float16x8_t b33 = MlasLoadFloat16x8(b + 3 * ldb + 24); + for (; n >= 64; n -= 32, b += 32, c += 32) { + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu02 = MlasZeroFloat16x8(); + float16x8_t accu03 = MlasZeroFloat16x8(); + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + accu02 = maq_lane_f16_accu(accu02, b02, b12, b22, b32, a0); + accu03 = 
maq_lane_f16_accu(accu03, b03, b13, b23, b33, a0); + MlasStoreFloat16x8(c, accu00); + MlasStoreFloat16x8(c + 8, accu01); + MlasStoreFloat16x8(c + 16, accu02); + MlasStoreFloat16x8(c + 24, accu03); + if constexpr (CountM == 2) { + float16x8_t accu10 = MlasZeroFloat16x8(); + float16x8_t accu11 = MlasZeroFloat16x8(); + float16x8_t accu12 = MlasZeroFloat16x8(); + float16x8_t accu13 = MlasZeroFloat16x8(); + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + accu12 = maq_lane_f16_accu(accu12, b02, b12, b22, b32, a1); + accu13 = maq_lane_f16_accu(accu13, b03, b13, b23, b33, a1); + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + MlasStoreFloat16x8(c + ldc + 16, accu12); + MlasStoreFloat16x8(c + ldc + 24, accu13); + } + b00 = MlasLoadFloat16x8(b + 32); + b10 = MlasLoadFloat16x8(b + ldb + 32); + b20 = MlasLoadFloat16x8(b + 2 * ldb + 32); + b30 = MlasLoadFloat16x8(b + 3 * ldb + 32); + b01 = MlasLoadFloat16x8(b + 40); + b11 = MlasLoadFloat16x8(b + ldb + 40); + b21 = MlasLoadFloat16x8(b + 2 * ldb + 40); + b31 = MlasLoadFloat16x8(b + 3 * ldb + 40); + b02 = MlasLoadFloat16x8(b + 48); + b12 = MlasLoadFloat16x8(b + ldb + 48); + b22 = MlasLoadFloat16x8(b + 2 * ldb + 48); + b32 = MlasLoadFloat16x8(b + 3 * ldb + 48); + b03 = MlasLoadFloat16x8(b + 56); + b13 = MlasLoadFloat16x8(b + ldb + 56); + b23 = MlasLoadFloat16x8(b + 2 * ldb + 56); + b33 = MlasLoadFloat16x8(b + 3 * ldb + 56); + } + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu02 = MlasZeroFloat16x8(); + float16x8_t accu03 = MlasZeroFloat16x8(); + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + accu02 = maq_lane_f16_accu(accu02, b02, b12, b22, b32, a0); + accu03 = maq_lane_f16_accu(accu03, b03, b13, b23, b33, a0); + MlasStoreFloat16x8(c, accu00); + MlasStoreFloat16x8(c + 8, accu01); + 
MlasStoreFloat16x8(c + 16, accu02); + MlasStoreFloat16x8(c + 24, accu03); + if constexpr (CountM == 2) { + float16x8_t accu10 = MlasZeroFloat16x8(); + float16x8_t accu11 = MlasZeroFloat16x8(); + float16x8_t accu12 = MlasZeroFloat16x8(); + float16x8_t accu13 = MlasZeroFloat16x8(); + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + accu12 = maq_lane_f16_accu(accu12, b02, b12, b22, b32, a1); + accu13 = maq_lane_f16_accu(accu13, b03, b13, b23, b33, a1); + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + MlasStoreFloat16x8(c + ldc + 16, accu12); + MlasStoreFloat16x8(c + ldc + 24, accu13); + } + n -= 32, b += 32, c += 32; + } + if (n & 16) { + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b30 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t b01 = MlasLoadFloat16x8(b + 8); + float16x8_t b11 = MlasLoadFloat16x8(b + ldb + 8); + float16x8_t b21 = MlasLoadFloat16x8(b + 2 * ldb + 8); + float16x8_t b31 = MlasLoadFloat16x8(b + 3 * ldb + 8); + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + MlasStoreFloat16x8(c, accu00); + MlasStoreFloat16x8(c + 8, accu01); + if constexpr (CountM == 2) { + float16x8_t accu10 = MlasZeroFloat16x8(); + float16x8_t accu11 = MlasZeroFloat16x8(); + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + } + n -= 16, b += 16, c += 16; + } + if (n & 8) { + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * 
ldb); + float16x8_t b30 = MlasLoadFloat16x8(b + 3 * ldb); + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + MlasStoreFloat16x8(c, accu00); + if constexpr (CountM == 2) { + float16x8_t accu10 = MlasZeroFloat16x8(); + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + MlasStoreFloat16x8(c + ldc, accu10); + } + n -= 8, b += 8, c += 8; + } + if (n & 4) { + float16x4_t accu00 = MlasZeroFloat16x4(); + float16x4_t b0 = MlasLoadFloat16x4(b); + float16x4_t b1 = MlasLoadFloat16x4(b + ldb); + float16x4_t b2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t b3 = MlasLoadFloat16x4(b + 3 * ldb); + accu00 = ma_lane_f16_accu(accu00, b0, b1, b2, b3, a0); + MlasStoreFloat16x4(c, accu00); + if constexpr (CountM == 2) { + float16x4_t accu10 = MlasZeroFloat16x4(); + accu10 = ma_lane_f16_accu(accu10, b0, b1, b2, b3, a1); + MlasStoreFloat16x4(c + ldc, accu10); + } + n -= 4, b += 4, c += 4; + } + if (Nr0) { + float16x4_t accu00 = MlasZeroFloat16x4(); + float16x4_t b0 = MlasLoadPartialFloat16x4(b, n); + float16x4_t b1 = MlasLoadPartialFloat16x4(b + ldb, n); + float16x4_t b2 = MlasLoadPartialFloat16x4(b + 2 * ldb, n); + float16x4_t b3 = MlasLoadPartialFloat16x4(b + 3 * ldb, n); + accu00 = ma_lane_f16_accu(accu00, b0, b1, b2, b3, a0); + MlasStorePartialFloat16x4(c, accu00, n); + if constexpr (CountM == 2) { + float16x4_t accu10 = MlasZeroFloat16x4(); + accu10 = ma_lane_f16_accu(accu10, b0, b1, b2, b3, a1); + MlasStorePartialFloat16x4(c + ldc, accu10, n); + } + } + CountK -= 4, B_data += ldb4, A_data += 4; + } else if (CountK > 0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(A_data, CountK), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(A_data + lda, CountK); + } + size_t n = CountN; + const auto* b = B_data; + auto* c = C_data; + if (largeN) { + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b01 = MlasLoadFloat16x8(b + 8); + float16x8_t b02 = MlasLoadFloat16x8(b + 16); + float16x8_t b03 = MlasLoadFloat16x8(b + 24); + float16x8_t b10 = 
MlasZeroFloat16x8(); + float16x8_t b11 = MlasZeroFloat16x8(); + float16x8_t b12 = MlasZeroFloat16x8(); + float16x8_t b13 = MlasZeroFloat16x8(); + float16x8_t b20 = MlasZeroFloat16x8(); + float16x8_t b21 = MlasZeroFloat16x8(); + float16x8_t b22 = MlasZeroFloat16x8(); + float16x8_t b23 = MlasZeroFloat16x8(); + if (Kr1) { + b10 = MlasLoadFloat16x8(b + ldb); + b11 = MlasLoadFloat16x8(b + ldb + 8); + b12 = MlasLoadFloat16x8(b + ldb + 16); + b13 = MlasLoadFloat16x8(b + ldb + 24); + } + if (Kr2) { + b20 = MlasLoadFloat16x8(b + 2 * ldb); + b21 = MlasLoadFloat16x8(b + 2 * ldb + 8); + b22 = MlasLoadFloat16x8(b + 2 * ldb + 16); + b23 = MlasLoadFloat16x8(b + 2 * ldb + 24); + } + for (; n >= 64; n -= 32, b += 32, c += 32) { + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu02 = MlasZeroFloat16x8(); + float16x8_t accu03 = MlasZeroFloat16x8(); + float16x8_t accu10, accu11, accu12, accu13; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x8(); + accu11 = MlasZeroFloat16x8(); + accu12 = MlasZeroFloat16x8(); + accu13 = MlasZeroFloat16x8(); + } + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + accu01 = vfmaq_lane_f16(accu01, b01, a0, 0); + accu02 = vfmaq_lane_f16(accu02, b02, a0, 0); + accu03 = vfmaq_lane_f16(accu03, b03, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + accu11 = vfmaq_lane_f16(accu11, b01, a1, 0); + accu12 = vfmaq_lane_f16(accu12, b02, a1, 0); + accu13 = vfmaq_lane_f16(accu13, b03, a1, 0); + } + if (Kr1) { + accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + accu01 = vfmaq_lane_f16(accu01, b11, a0, 1); + accu02 = vfmaq_lane_f16(accu02, b12, a0, 1); + accu03 = vfmaq_lane_f16(accu03, b13, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b10, a1, 1); + accu11 = vfmaq_lane_f16(accu11, b11, a1, 1); + accu12 = vfmaq_lane_f16(accu12, b12, a1, 1); + accu13 = vfmaq_lane_f16(accu13, b13, a1, 1); + } + } + if (Kr2) { + accu00 = vfmaq_lane_f16(accu00, b20, 
a0, 2); + accu01 = vfmaq_lane_f16(accu01, b21, a0, 2); + accu02 = vfmaq_lane_f16(accu02, b22, a0, 2); + accu03 = vfmaq_lane_f16(accu03, b23, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + accu11 = vfmaq_lane_f16(accu11, b21, a1, 2); + accu12 = vfmaq_lane_f16(accu12, b22, a1, 2); + accu13 = vfmaq_lane_f16(accu13, b23, a1, 2); + } + } + MlasStoreFloat16x8(c, accu00); + MlasStoreFloat16x8(c + 8, accu01); + MlasStoreFloat16x8(c + 16, accu02); + MlasStoreFloat16x8(c + 24, accu03); + if constexpr (CountM == 2) { + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + MlasStoreFloat16x8(c + ldc + 16, accu12); + MlasStoreFloat16x8(c + ldc + 24, accu13); + } + b00 = MlasLoadFloat16x8(b + 32); + b01 = MlasLoadFloat16x8(b + 40); + b02 = MlasLoadFloat16x8(b + 48); + b03 = MlasLoadFloat16x8(b + 56); + if (Kr1) { + b10 = MlasLoadFloat16x8(b + ldb + 32); + b11 = MlasLoadFloat16x8(b + ldb + 40); + b12 = MlasLoadFloat16x8(b + ldb + 48); + b13 = MlasLoadFloat16x8(b + ldb + 56); + } + if (Kr2) { + b20 = MlasLoadFloat16x8(b + 2 * ldb + 32); + b21 = MlasLoadFloat16x8(b + 2 * ldb + 40); + b22 = MlasLoadFloat16x8(b + 2 * ldb + 48); + b23 = MlasLoadFloat16x8(b + 2 * ldb + 56); + } + } + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu02 = MlasZeroFloat16x8(); + float16x8_t accu03 = MlasZeroFloat16x8(); + float16x8_t accu10, accu11, accu12, accu13; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x8(); + accu11 = MlasZeroFloat16x8(); + accu12 = MlasZeroFloat16x8(); + accu13 = MlasZeroFloat16x8(); + } + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + accu01 = vfmaq_lane_f16(accu01, b01, a0, 0); + accu02 = vfmaq_lane_f16(accu02, b02, a0, 0); + accu03 = vfmaq_lane_f16(accu03, b03, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + accu11 = vfmaq_lane_f16(accu11, b01, a1, 0); + accu12 = vfmaq_lane_f16(accu12, b02, a1, 0); + 
accu13 = vfmaq_lane_f16(accu13, b03, a1, 0); + } + if (Kr1) { + accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + accu01 = vfmaq_lane_f16(accu01, b11, a0, 1); + accu02 = vfmaq_lane_f16(accu02, b12, a0, 1); + accu03 = vfmaq_lane_f16(accu03, b13, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b10, a1, 1); + accu11 = vfmaq_lane_f16(accu11, b11, a1, 1); + accu12 = vfmaq_lane_f16(accu12, b12, a1, 1); + accu13 = vfmaq_lane_f16(accu13, b13, a1, 1); + } + } + if (Kr2) { + accu00 = vfmaq_lane_f16(accu00, b20, a0, 2); + accu01 = vfmaq_lane_f16(accu01, b21, a0, 2); + accu02 = vfmaq_lane_f16(accu02, b22, a0, 2); + accu03 = vfmaq_lane_f16(accu03, b23, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + accu11 = vfmaq_lane_f16(accu11, b21, a1, 2); + accu12 = vfmaq_lane_f16(accu12, b22, a1, 2); + accu13 = vfmaq_lane_f16(accu13, b23, a1, 2); + } + } + MlasStoreFloat16x8(c, accu00); + MlasStoreFloat16x8(c + 8, accu01); + MlasStoreFloat16x8(c + 16, accu02); + MlasStoreFloat16x8(c + 24, accu03); + if constexpr (CountM == 2) { + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + MlasStoreFloat16x8(c + ldc + 16, accu12); + MlasStoreFloat16x8(c + ldc + 24, accu13); + } + n -= 32, b += 32, c += 32; + } + if (n & 16) { + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu10, accu11; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x8(); + accu11 = MlasZeroFloat16x8(); + } + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b01 = MlasLoadFloat16x8(b + 8); + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + accu01 = vfmaq_lane_f16(accu01, b01, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + accu11 = vfmaq_lane_f16(accu11, b01, a1, 0); + } + if (Kr1) { + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b11 = MlasLoadFloat16x8(b + ldb + 8); + accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + 
accu01 = vfmaq_lane_f16(accu01, b11, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b10, a1, 1); + accu11 = vfmaq_lane_f16(accu11, b11, a1, 1); + } + } + if (Kr2) { + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b21 = MlasLoadFloat16x8(b + 2 * ldb + 8); + accu00 = vfmaq_lane_f16(accu00, b20, a0, 2); + accu01 = vfmaq_lane_f16(accu01, b21, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + accu11 = vfmaq_lane_f16(accu11, b21, a1, 2); + } + } + MlasStoreFloat16x8(c, accu00); + MlasStoreFloat16x8(c + 8, accu01); + if constexpr (CountM == 2) { + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + } + n -= 16, b += 16, c += 16; + } + if (n & 8) { + float16x8_t accu00 = MlasZeroFloat16x8(), accu10; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x8(); + } + float16x8_t b0 = MlasLoadFloat16x8(b); + accu00 = vfmaq_lane_f16(accu00, b0, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b0, a1, 0); + } + if (Kr1) { + float16x8_t b1 = MlasLoadFloat16x8(b + ldb); + accu00 = vfmaq_lane_f16(accu00, b1, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b1, a1, 1); + } + } + if (Kr2) { + float16x8_t b2 = MlasLoadFloat16x8(b + 2 * ldb); + accu00 = vfmaq_lane_f16(accu00, b2, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b2, a1, 2); + } + } + MlasStoreFloat16x8(c, accu00); + if constexpr (CountM == 2) { + MlasStoreFloat16x8(c + ldc, accu10); + } + n -= 8, b += 8, c += 8; + } + if (n & 4) { + float16x4_t accu00 = MlasZeroFloat16x4(), accu10; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x4(); + } + float16x4_t b0 = MlasLoadFloat16x4(b); + accu00 = vfma_lane_f16(accu00, b0, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b0, a1, 0); + } + if (Kr1) { + float16x4_t b1 = MlasLoadFloat16x4(b + ldb); + accu00 = vfma_lane_f16(accu00, b1, a0, 1); + if constexpr (CountM 
== 2) { + accu10 = vfma_lane_f16(accu10, b1, a1, 1); + } + } + if (Kr2) { + float16x4_t b2 = MlasLoadFloat16x4(b + 2 * ldb); + accu00 = vfma_lane_f16(accu00, b2, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b2, a1, 2); + } + } + MlasStoreFloat16x4(c, accu00); + if constexpr (CountM == 2) { + MlasStoreFloat16x4(c + ldc, accu10); + } + n -= 4, b += 4, c += 4; + } + if (Nr0) { + float16x4_t accu00 = MlasZeroFloat16x4(), accu10; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x4(); + } + float16x4_t b0 = MlasLoadPartialFloat16x4(b, n); + accu00 = vfma_lane_f16(accu00, b0, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b0, a1, 0); + } + if (Kr1) { + float16x4_t b1 = MlasLoadPartialFloat16x4(b + ldb, n); + accu00 = vfma_lane_f16(accu00, b1, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b1, a1, 1); + } + } + if (Kr2) { + float16x4_t b2 = MlasLoadPartialFloat16x4(b + 2 * ldb, n); + accu00 = vfma_lane_f16(accu00, b2, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b2, a1, 2); + } + } + MlasStorePartialFloat16x4(c, accu00, n); + if constexpr (CountM == 2) { + MlasStorePartialFloat16x4(c + ldc, accu10, n); + } + } + + CountK -= CountK, B_data += ldb * CountK, A_data += CountK; + } + } + + for (; CountK >= 4; CountK -= 4, B_data += ldb4, A_data += 4) { + float16x4_t a0 = MlasLoadFloat16x4(A_data), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(A_data + lda); + } + size_t n = CountN; + const auto* b = B_data; + auto* c = C_data; + if (largeN) { + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b30 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t b01 = MlasLoadFloat16x8(b + 8); + float16x8_t b11 = MlasLoadFloat16x8(b + ldb + 8); + float16x8_t b21 = MlasLoadFloat16x8(b + 2 * ldb + 8); + float16x8_t b31 = MlasLoadFloat16x8(b + 3 * ldb + 8); + float16x8_t 
b02 = MlasLoadFloat16x8(b + 16); + float16x8_t b12 = MlasLoadFloat16x8(b + ldb + 16); + float16x8_t b22 = MlasLoadFloat16x8(b + 2 * ldb + 16); + float16x8_t b32 = MlasLoadFloat16x8(b + 3 * ldb + 16); + float16x8_t b03 = MlasLoadFloat16x8(b + 24); + float16x8_t b13 = MlasLoadFloat16x8(b + ldb + 24); + float16x8_t b23 = MlasLoadFloat16x8(b + 2 * ldb + 24); + float16x8_t b33 = MlasLoadFloat16x8(b + 3 * ldb + 24); + for (; n >= 64; n -= 32, b += 32, c += 32) { + float16x8_t accu00 = MlasLoadFloat16x8(c); + float16x8_t accu01 = MlasLoadFloat16x8(c + 8); + float16x8_t accu02 = MlasLoadFloat16x8(c + 16); + float16x8_t accu03 = MlasLoadFloat16x8(c + 24); + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + accu02 = maq_lane_f16_accu(accu02, b02, b12, b22, b32, a0); + accu03 = maq_lane_f16_accu(accu03, b03, b13, b23, b33, a0); + MlasStoreFloat16x8(c, accu00); + MlasStoreFloat16x8(c + 8, accu01); + MlasStoreFloat16x8(c + 16, accu02); + MlasStoreFloat16x8(c + 24, accu03); + if constexpr (CountM == 2) { + float16x8_t accu10 = MlasLoadFloat16x8(c + ldc); + float16x8_t accu11 = MlasLoadFloat16x8(c + ldc + 8); + float16x8_t accu12 = MlasLoadFloat16x8(c + ldc + 16); + float16x8_t accu13 = MlasLoadFloat16x8(c + ldc + 24); + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + accu12 = maq_lane_f16_accu(accu12, b02, b12, b22, b32, a1); + accu13 = maq_lane_f16_accu(accu13, b03, b13, b23, b33, a1); + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + MlasStoreFloat16x8(c + ldc + 16, accu12); + MlasStoreFloat16x8(c + ldc + 24, accu13); + } + b00 = MlasLoadFloat16x8(b + 32); + b10 = MlasLoadFloat16x8(b + ldb + 32); + b20 = MlasLoadFloat16x8(b + 2 * ldb + 32); + b30 = MlasLoadFloat16x8(b + 3 * ldb + 32); + b01 = MlasLoadFloat16x8(b + 40); + b11 = MlasLoadFloat16x8(b + ldb + 40); + b21 = MlasLoadFloat16x8(b + 2 * ldb 
+ 40); + b31 = MlasLoadFloat16x8(b + 3 * ldb + 40); + b02 = MlasLoadFloat16x8(b + 48); + b12 = MlasLoadFloat16x8(b + ldb + 48); + b22 = MlasLoadFloat16x8(b + 2 * ldb + 48); + b32 = MlasLoadFloat16x8(b + 3 * ldb + 48); + b03 = MlasLoadFloat16x8(b + 56); + b13 = MlasLoadFloat16x8(b + ldb + 56); + b23 = MlasLoadFloat16x8(b + 2 * ldb + 56); + b33 = MlasLoadFloat16x8(b + 3 * ldb + 56); + } + float16x8_t accu00 = MlasLoadFloat16x8(c); + float16x8_t accu01 = MlasLoadFloat16x8(c + 8); + float16x8_t accu02 = MlasLoadFloat16x8(c + 16); + float16x8_t accu03 = MlasLoadFloat16x8(c + 24); + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + accu02 = maq_lane_f16_accu(accu02, b02, b12, b22, b32, a0); + accu03 = maq_lane_f16_accu(accu03, b03, b13, b23, b33, a0); + MlasStoreFloat16x8(c, accu00); + MlasStoreFloat16x8(c + 8, accu01); + MlasStoreFloat16x8(c + 16, accu02); + MlasStoreFloat16x8(c + 24, accu03); + if constexpr (CountM == 2) { + float16x8_t accu10 = MlasLoadFloat16x8(c + ldc); + float16x8_t accu11 = MlasLoadFloat16x8(c + ldc + 8); + float16x8_t accu12 = MlasLoadFloat16x8(c + ldc + 16); + float16x8_t accu13 = MlasLoadFloat16x8(c + ldc + 24); + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + accu12 = maq_lane_f16_accu(accu12, b02, b12, b22, b32, a1); + accu13 = maq_lane_f16_accu(accu13, b03, b13, b23, b33, a1); + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + MlasStoreFloat16x8(c + ldc + 16, accu12); + MlasStoreFloat16x8(c + ldc + 24, accu13); + } + n -= 32, b += 32, c += 32; + } + if (n & 16) { + float16x8_t accu00 = MlasLoadFloat16x8(c); + float16x8_t accu01 = MlasLoadFloat16x8(c + 8); + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b30 = MlasLoadFloat16x8(b + 3 * ldb); + float16x8_t 
b01 = MlasLoadFloat16x8(b + 8); + float16x8_t b11 = MlasLoadFloat16x8(b + ldb + 8); + float16x8_t b21 = MlasLoadFloat16x8(b + 2 * ldb + 8); + float16x8_t b31 = MlasLoadFloat16x8(b + 3 * ldb + 8); + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + MlasStoreFloat16x8(c, accu00); + MlasStoreFloat16x8(c + 8, accu01); + if constexpr (CountM == 2) { + float16x8_t accu10 = MlasLoadFloat16x8(c + ldc); + float16x8_t accu11 = MlasLoadFloat16x8(c + ldc + 8); + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + } + n -= 16, b += 16, c += 16; + } + if (n & 8) { + float16x8_t accu00 = MlasLoadFloat16x8(c); + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b30 = MlasLoadFloat16x8(b + 3 * ldb); + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + MlasStoreFloat16x8(c, accu00); + if constexpr (CountM == 2) { + float16x8_t accu10 = MlasLoadFloat16x8(c + ldc); + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + MlasStoreFloat16x8(c + ldc, accu10); + } + n -= 8, b += 8, c += 8; + } + if (n & 4) { + float16x4_t accu00 = MlasLoadFloat16x4(c); + float16x4_t b0 = MlasLoadFloat16x4(b); + float16x4_t b1 = MlasLoadFloat16x4(b + ldb); + float16x4_t b2 = MlasLoadFloat16x4(b + 2 * ldb); + float16x4_t b3 = MlasLoadFloat16x4(b + 3 * ldb); + accu00 = ma_lane_f16_accu(accu00, b0, b1, b2, b3, a0); + MlasStoreFloat16x4(c, accu00); + if constexpr (CountM == 2) { + float16x4_t accu10 = MlasLoadFloat16x4(c + ldc); + accu10 = ma_lane_f16_accu(accu10, b0, b1, b2, b3, a1); + MlasStoreFloat16x4(c + ldc, accu10); + } + n -= 4, b += 4, c += 4; + } + if (Nr0) { + float16x4_t accu00 = MlasLoadPartialFloat16x4(c, n); + float16x4_t b0 = 
MlasLoadPartialFloat16x4(b, n); + float16x4_t b1 = MlasLoadPartialFloat16x4(b + ldb, n); + float16x4_t b2 = MlasLoadPartialFloat16x4(b + 2 * ldb, n); + float16x4_t b3 = MlasLoadPartialFloat16x4(b + 3 * ldb, n); + accu00 = ma_lane_f16_accu(accu00, b0, b1, b2, b3, a0); + MlasStorePartialFloat16x4(c, accu00, n); + if constexpr (CountM == 2) { + float16x4_t accu10 = MlasLoadPartialFloat16x4(c + ldc, n); + accu10 = ma_lane_f16_accu(accu10, b0, b1, b2, b3, a1); + MlasStorePartialFloat16x4(c + ldc, accu10, n); + } + } + } + + if (CountK > 0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(A_data, CountK), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(A_data + lda, CountK); + } + size_t n = CountN; + const auto* b = B_data; + auto* c = C_data; + if (largeN) { + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b01 = MlasLoadFloat16x8(b + 8); + float16x8_t b02 = MlasLoadFloat16x8(b + 16); + float16x8_t b03 = MlasLoadFloat16x8(b + 24); + float16x8_t b10 = MlasZeroFloat16x8(); + float16x8_t b11 = MlasZeroFloat16x8(); + float16x8_t b12 = MlasZeroFloat16x8(); + float16x8_t b13 = MlasZeroFloat16x8(); + float16x8_t b20 = MlasZeroFloat16x8(); + float16x8_t b21 = MlasZeroFloat16x8(); + float16x8_t b22 = MlasZeroFloat16x8(); + float16x8_t b23 = MlasZeroFloat16x8(); + if (Kr1) { + b10 = MlasLoadFloat16x8(b + ldb); + b11 = MlasLoadFloat16x8(b + ldb + 8); + b12 = MlasLoadFloat16x8(b + ldb + 16); + b13 = MlasLoadFloat16x8(b + ldb + 24); + } + if (Kr2) { + b20 = MlasLoadFloat16x8(b + 2 * ldb); + b21 = MlasLoadFloat16x8(b + 2 * ldb + 8); + b22 = MlasLoadFloat16x8(b + 2 * ldb + 16); + b23 = MlasLoadFloat16x8(b + 2 * ldb + 24); + } + for (; n >= 64; n -= 32, b += 32, c += 32) { + float16x8_t accu00 = MlasLoadFloat16x8(c); + float16x8_t accu01 = MlasLoadFloat16x8(c + 8); + float16x8_t accu02 = MlasLoadFloat16x8(c + 16); + float16x8_t accu03 = MlasLoadFloat16x8(c + 24); + float16x8_t accu10, accu11, accu12, accu13; + if constexpr (CountM == 2) { + accu10 = 
MlasLoadFloat16x8(c + ldc); + accu11 = MlasLoadFloat16x8(c + ldc + 8); + accu12 = MlasLoadFloat16x8(c + ldc + 16); + accu13 = MlasLoadFloat16x8(c + ldc + 24); + } + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + accu01 = vfmaq_lane_f16(accu01, b01, a0, 0); + accu02 = vfmaq_lane_f16(accu02, b02, a0, 0); + accu03 = vfmaq_lane_f16(accu03, b03, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + accu11 = vfmaq_lane_f16(accu11, b01, a1, 0); + accu12 = vfmaq_lane_f16(accu12, b02, a1, 0); + accu13 = vfmaq_lane_f16(accu13, b03, a1, 0); + } + if (Kr1) { + accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + accu01 = vfmaq_lane_f16(accu01, b11, a0, 1); + accu02 = vfmaq_lane_f16(accu02, b12, a0, 1); + accu03 = vfmaq_lane_f16(accu03, b13, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b10, a1, 1); + accu11 = vfmaq_lane_f16(accu11, b11, a1, 1); + accu12 = vfmaq_lane_f16(accu12, b12, a1, 1); + accu13 = vfmaq_lane_f16(accu13, b13, a1, 1); + } + } + if (Kr2) { + accu00 = vfmaq_lane_f16(accu00, b20, a0, 2); + accu01 = vfmaq_lane_f16(accu01, b21, a0, 2); + accu02 = vfmaq_lane_f16(accu02, b22, a0, 2); + accu03 = vfmaq_lane_f16(accu03, b23, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + accu11 = vfmaq_lane_f16(accu11, b21, a1, 2); + accu12 = vfmaq_lane_f16(accu12, b22, a1, 2); + accu13 = vfmaq_lane_f16(accu13, b23, a1, 2); + } + } + MlasStoreFloat16x8(c, accu00); + MlasStoreFloat16x8(c + 8, accu01); + MlasStoreFloat16x8(c + 16, accu02); + MlasStoreFloat16x8(c + 24, accu03); + if constexpr (CountM == 2) { + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + MlasStoreFloat16x8(c + ldc + 16, accu12); + MlasStoreFloat16x8(c + ldc + 24, accu13); + } + b00 = MlasLoadFloat16x8(b + 32); + b01 = MlasLoadFloat16x8(b + 40); + b02 = MlasLoadFloat16x8(b + 48); + b03 = MlasLoadFloat16x8(b + 56); + if (Kr1) { + b10 = MlasLoadFloat16x8(b + ldb + 32); + b11 = 
MlasLoadFloat16x8(b + ldb + 40); + b12 = MlasLoadFloat16x8(b + ldb + 48); + b13 = MlasLoadFloat16x8(b + ldb + 56); + } + if (Kr2) { + b20 = MlasLoadFloat16x8(b + 2 * ldb + 32); + b21 = MlasLoadFloat16x8(b + 2 * ldb + 40); + b22 = MlasLoadFloat16x8(b + 2 * ldb + 48); + b23 = MlasLoadFloat16x8(b + 2 * ldb + 56); + } + } + float16x8_t accu00 = MlasLoadFloat16x8(c); + float16x8_t accu01 = MlasLoadFloat16x8(c + 8); + float16x8_t accu02 = MlasLoadFloat16x8(c + 16); + float16x8_t accu03 = MlasLoadFloat16x8(c + 24); + float16x8_t accu10, accu11, accu12, accu13; + if constexpr (CountM == 2) { + accu10 = MlasLoadFloat16x8(c + ldc); + accu11 = MlasLoadFloat16x8(c + ldc + 8); + accu12 = MlasLoadFloat16x8(c + ldc + 16); + accu13 = MlasLoadFloat16x8(c + ldc + 24); + } + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + accu01 = vfmaq_lane_f16(accu01, b01, a0, 0); + accu02 = vfmaq_lane_f16(accu02, b02, a0, 0); + accu03 = vfmaq_lane_f16(accu03, b03, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + accu11 = vfmaq_lane_f16(accu11, b01, a1, 0); + accu12 = vfmaq_lane_f16(accu12, b02, a1, 0); + accu13 = vfmaq_lane_f16(accu13, b03, a1, 0); + } + if (Kr1) { + accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + accu01 = vfmaq_lane_f16(accu01, b11, a0, 1); + accu02 = vfmaq_lane_f16(accu02, b12, a0, 1); + accu03 = vfmaq_lane_f16(accu03, b13, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b10, a1, 1); + accu11 = vfmaq_lane_f16(accu11, b11, a1, 1); + accu12 = vfmaq_lane_f16(accu12, b12, a1, 1); + accu13 = vfmaq_lane_f16(accu13, b13, a1, 1); + } + } + if (Kr2) { + accu00 = vfmaq_lane_f16(accu00, b20, a0, 2); + accu01 = vfmaq_lane_f16(accu01, b21, a0, 2); + accu02 = vfmaq_lane_f16(accu02, b22, a0, 2); + accu03 = vfmaq_lane_f16(accu03, b23, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + accu11 = vfmaq_lane_f16(accu11, b21, a1, 2); + accu12 = vfmaq_lane_f16(accu12, b22, a1, 2); + accu13 = 
vfmaq_lane_f16(accu13, b23, a1, 2); + } + } + MlasStoreFloat16x8(c, accu00); + MlasStoreFloat16x8(c + 8, accu01); + MlasStoreFloat16x8(c + 16, accu02); + MlasStoreFloat16x8(c + 24, accu03); + if constexpr (CountM == 2) { + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + MlasStoreFloat16x8(c + ldc + 16, accu12); + MlasStoreFloat16x8(c + ldc + 24, accu13); + } + n -= 32, b += 32, c += 32; + } + if (n & 16) { + float16x8_t accu00 = MlasLoadFloat16x8(c); + float16x8_t accu01 = MlasLoadFloat16x8(c + 8); + float16x8_t accu10, accu11; + if constexpr (CountM == 2) { + accu10 = MlasLoadFloat16x8(c + ldc); + accu11 = MlasLoadFloat16x8(c + ldc + 8); + } + float16x8_t b00 = MlasLoadFloat16x8(b); + float16x8_t b01 = MlasLoadFloat16x8(b + 8); + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + accu01 = vfmaq_lane_f16(accu01, b01, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + accu11 = vfmaq_lane_f16(accu11, b01, a1, 0); + } + if (Kr1) { + float16x8_t b10 = MlasLoadFloat16x8(b + ldb); + float16x8_t b11 = MlasLoadFloat16x8(b + ldb + 8); + accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + accu01 = vfmaq_lane_f16(accu01, b11, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b10, a1, 1); + accu11 = vfmaq_lane_f16(accu11, b11, a1, 1); + } + } + if (Kr2) { + float16x8_t b20 = MlasLoadFloat16x8(b + 2 * ldb); + float16x8_t b21 = MlasLoadFloat16x8(b + 2 * ldb + 8); + accu00 = vfmaq_lane_f16(accu00, b20, a0, 2); + accu01 = vfmaq_lane_f16(accu01, b21, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + accu11 = vfmaq_lane_f16(accu11, b21, a1, 2); + } + } + MlasStoreFloat16x8(c, accu00); + MlasStoreFloat16x8(c + 8, accu01); + if constexpr (CountM == 2) { + MlasStoreFloat16x8(c + ldc, accu10); + MlasStoreFloat16x8(c + ldc + 8, accu11); + } + n -= 16, b += 16, c += 16; + } + if (n & 8) { + float16x8_t accu00 = MlasLoadFloat16x8(c); + float16x8_t accu10; + if 
constexpr (CountM == 2) { + accu10 = MlasLoadFloat16x8(c + ldc); + } + float16x8_t b0 = MlasLoadFloat16x8(b); + accu00 = vfmaq_lane_f16(accu00, b0, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b0, a1, 0); + } + if (Kr1) { + float16x8_t b1 = MlasLoadFloat16x8(b + ldb); + accu00 = vfmaq_lane_f16(accu00, b1, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b1, a1, 1); + } + } + if (Kr2) { + float16x8_t b2 = MlasLoadFloat16x8(b + 2 * ldb); + accu00 = vfmaq_lane_f16(accu00, b2, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b2, a1, 2); + } + } + MlasStoreFloat16x8(c, accu00); + if constexpr (CountM == 2) { + MlasStoreFloat16x8(c + ldc, accu10); + } + n -= 8, b += 8, c += 8; + } + if (n & 4) { + float16x4_t accu00 = MlasLoadFloat16x4(c); + float16x4_t accu10; + if constexpr (CountM == 2) { + accu10 = MlasLoadFloat16x4(c + ldc); + } + float16x4_t b0 = MlasLoadFloat16x4(b); + accu00 = vfma_lane_f16(accu00, b0, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b0, a1, 0); + } + if (Kr1) { + float16x4_t b1 = MlasLoadFloat16x4(b + ldb); + accu00 = vfma_lane_f16(accu00, b1, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b1, a1, 1); + } + } + if (Kr2) { + float16x4_t b2 = MlasLoadFloat16x4(b + 2 * ldb); + accu00 = vfma_lane_f16(accu00, b2, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b2, a1, 2); + } + } + MlasStoreFloat16x4(c, accu00); + if constexpr (CountM == 2) { + MlasStoreFloat16x4(c + ldc, accu10); + } + n -= 4, b += 4, c += 4; + } + if (Nr0) { + float16x4_t accu00 = MlasLoadPartialFloat16x4(c, n); + float16x4_t accu10; + if constexpr (CountM == 2) { + accu10 = MlasLoadPartialFloat16x4(c + ldc, n); + } + float16x4_t b0 = MlasLoadPartialFloat16x4(b, n); + accu00 = vfma_lane_f16(accu00, b0, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b0, a1, 0); + } + if (Kr1) { + float16x4_t b1 = 
MlasLoadPartialFloat16x4(b + ldb, n); + accu00 = vfma_lane_f16(accu00, b1, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b1, a1, 1); + } + } + if (Kr2) { + float16x4_t b2 = MlasLoadPartialFloat16x4(b + 2 * ldb, n); + accu00 = vfma_lane_f16(accu00, b2, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfma_lane_f16(accu10, b2, a1, 2); + } + } + MlasStorePartialFloat16x4(c, accu00, n); + if constexpr (CountM == 2) { + MlasStorePartialFloat16x4(c + ldc, accu10, n); + } + } + + CountK -= CountK, B_data += ldb * CountK, A_data += CountK; + } + } + +void HGemm_B_Kernel( + const MLAS_FP16* A, + const MLAS_FP16* B, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldb, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +) { + if (CountM > 2) { + MLAS_THROW_EX(std::runtime_error, "HGemm_B_Kernel only support <= 2 rows"); + } + const auto* A_data = reinterpret_cast<const _mlas_fp16_*>(A); + const auto* B_data = reinterpret_cast<const _mlas_fp16_*>(B); + auto* C_data = reinterpret_cast<_mlas_fp16_*>(C); + const auto f16_0 = MLAS_FP16(0.0f); + const auto f16_1 = MLAS_FP16(1.0f); + if (CountM == 1) { + if (alpha == f16_1.val && beta == f16_0.val) { + HGemm_B_Kernel_Simple<1, true>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc); + } else if (alpha == f16_1.val && beta == f16_1.val) { + HGemm_B_Kernel_Simple<1, false>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc); + } else { + HGemm_B_Kernel_Complicated<1>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc, alpha, beta); + } + } else { + if (alpha == f16_1.val && beta == f16_0.val) { + HGemm_B_Kernel_Simple<2, true>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc); + } else if (alpha == f16_1.val && beta == f16_1.val) { + HGemm_B_Kernel_Simple<2, false>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc); + } else { + HGemm_B_Kernel_Complicated<2>(A_data, B_data, C_data, CountN, CountK, lda, ldb, ldc, alpha, beta); + } + } +} + +// beta_behavior: 0 -> beta == 0,
1 -> beta == 1, 2 -> beta != 0 && beta != 1 +template <size_t beta_behavior, size_t CountM> +void HGemm_PackedB_Kernel_Impl( + const _mlas_fp16_* A, + const _mlas_fp16_* PackedB, + _mlas_fp16_* C, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +) { + const float16x8_t alpha_v8 = MlasBroadcastFloat16x8(alpha); + const float16x8_t beta_v8 = MlasBroadcastFloat16x8(beta); + const float16x4_t alpha_v4 = MlasBroadcastFloat16x4(alpha); + const float16x4_t beta_v4 = MlasBroadcastFloat16x4(beta); + const bool Kr0 = CountK & 3; + const bool Kr1 = (CountK & 3) > 1; + const bool Kr2 = (CountK & 3) > 2; + const bool largeK = CountK >= 4; + for (; CountN >= 32; CountN -= 32, C += 32) { + const auto* a = A; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(); + float16x8_t accu02 = MlasZeroFloat16x8(); + float16x8_t accu03 = MlasZeroFloat16x8(); + float16x8_t accu10, accu11, accu12, accu13; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x8(); + accu11 = MlasZeroFloat16x8(); + accu12 = MlasZeroFloat16x8(); + accu13 = MlasZeroFloat16x8(); + } + if (largeK) { + float16x4_t a0 = MlasLoadFloat16x4(a), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda); + } + float16x8_t b00 = MlasLoadFloat16x8(PackedB); + float16x8_t b10 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b20 = MlasLoadFloat16x8(PackedB + 64); + float16x8_t b30 = MlasLoadFloat16x8(PackedB + 96); + float16x8_t b01 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b11 = MlasLoadFloat16x8(PackedB + 40); + float16x8_t b21 = MlasLoadFloat16x8(PackedB + 72); + float16x8_t b31 = MlasLoadFloat16x8(PackedB + 104); + float16x8_t b02 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b12 = MlasLoadFloat16x8(PackedB + 48); + float16x8_t b22 = MlasLoadFloat16x8(PackedB + 80); + float16x8_t b32 = MlasLoadFloat16x8(PackedB + 112); + float16x8_t b03 = MlasLoadFloat16x8(PackedB + 24); + float16x8_t b13 = MlasLoadFloat16x8(PackedB + 56); + 
float16x8_t b23 = MlasLoadFloat16x8(PackedB + 88); + float16x8_t b33 = MlasLoadFloat16x8(PackedB + 120); + for (; k >= 8; k -= 4, a += 4, PackedB += 4 * 32) { + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + accu02 = maq_lane_f16_accu(accu02, b02, b12, b22, b32, a0); + accu03 = maq_lane_f16_accu(accu03, b03, b13, b23, b33, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + accu12 = maq_lane_f16_accu(accu12, b02, b12, b22, b32, a1); + accu13 = maq_lane_f16_accu(accu13, b03, b13, b23, b33, a1); + } + a0 = MlasLoadFloat16x4(a + 4); + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda + 4); + } + b00 = MlasLoadFloat16x8(PackedB + 128); + b10 = MlasLoadFloat16x8(PackedB + 160); + b20 = MlasLoadFloat16x8(PackedB + 192); + b30 = MlasLoadFloat16x8(PackedB + 224); + b01 = MlasLoadFloat16x8(PackedB + 136); + b11 = MlasLoadFloat16x8(PackedB + 168); + b21 = MlasLoadFloat16x8(PackedB + 200); + b31 = MlasLoadFloat16x8(PackedB + 232); + b02 = MlasLoadFloat16x8(PackedB + 144); + b12 = MlasLoadFloat16x8(PackedB + 176); + b22 = MlasLoadFloat16x8(PackedB + 208); + b32 = MlasLoadFloat16x8(PackedB + 240); + b03 = MlasLoadFloat16x8(PackedB + 152); + b13 = MlasLoadFloat16x8(PackedB + 184); + b23 = MlasLoadFloat16x8(PackedB + 216); + b33 = MlasLoadFloat16x8(PackedB + 248); + } + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + accu02 = maq_lane_f16_accu(accu02, b02, b12, b22, b32, a0); + accu03 = maq_lane_f16_accu(accu03, b03, b13, b23, b33, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + accu12 = maq_lane_f16_accu(accu12, b02, b12, b22, b32, a1); + accu13 = maq_lane_f16_accu(accu13, b03, b13, b23, b33, a1); 
+ } + k -= 4, a += 4, PackedB += 4 * 32; + } + + if (Kr0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + } + float16x8_t b00 = MlasLoadFloat16x8(PackedB); + float16x8_t b01 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b02 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b03 = MlasLoadFloat16x8(PackedB + 24); + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + accu01 = vfmaq_lane_f16(accu01, b01, a0, 0); + accu02 = vfmaq_lane_f16(accu02, b02, a0, 0); + accu03 = vfmaq_lane_f16(accu03, b03, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + accu11 = vfmaq_lane_f16(accu11, b01, a1, 0); + accu12 = vfmaq_lane_f16(accu12, b02, a1, 0); + accu13 = vfmaq_lane_f16(accu13, b03, a1, 0); + } + if (Kr1) { + float16x8_t b10 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b11 = MlasLoadFloat16x8(PackedB + 40); + float16x8_t b12 = MlasLoadFloat16x8(PackedB + 48); + float16x8_t b13 = MlasLoadFloat16x8(PackedB + 56); + accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + accu01 = vfmaq_lane_f16(accu01, b11, a0, 1); + accu02 = vfmaq_lane_f16(accu02, b12, a0, 1); + accu03 = vfmaq_lane_f16(accu03, b13, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b10, a1, 1); + accu11 = vfmaq_lane_f16(accu11, b11, a1, 1); + accu12 = vfmaq_lane_f16(accu12, b12, a1, 1); + accu13 = vfmaq_lane_f16(accu13, b13, a1, 1); + } + } + if (Kr2) { + float16x8_t b20 = MlasLoadFloat16x8(PackedB + 64); + float16x8_t b21 = MlasLoadFloat16x8(PackedB + 72); + float16x8_t b22 = MlasLoadFloat16x8(PackedB + 80); + float16x8_t b23 = MlasLoadFloat16x8(PackedB + 88); + accu00 = vfmaq_lane_f16(accu00, b20, a0, 2); + accu01 = vfmaq_lane_f16(accu01, b21, a0, 2); + accu02 = vfmaq_lane_f16(accu02, b22, a0, 2); + accu03 = vfmaq_lane_f16(accu03, b23, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + accu11 = vfmaq_lane_f16(accu11, b21, a1, 2); + accu12 
= vfmaq_lane_f16(accu12, b22, a1, 2); + accu13 = vfmaq_lane_f16(accu13, b23, a1, 2); + } + } + PackedB += k * 32; + } + + if constexpr (beta_behavior == 1) { + float16x8_t c00 = MlasLoadFloat16x8(C); + float16x8_t c01 = MlasLoadFloat16x8(C + 8); + float16x8_t c02 = MlasLoadFloat16x8(C + 16); + float16x8_t c03 = MlasLoadFloat16x8(C + 24); + + MlasStoreFloat16x8(C, vfmaq_f16(c00, accu00, alpha_v8)); + MlasStoreFloat16x8(C + 8, vfmaq_f16(c01, accu01, alpha_v8)); + MlasStoreFloat16x8(C + 16, vfmaq_f16(c02, accu02, alpha_v8)); + MlasStoreFloat16x8(C + 24, vfmaq_f16(c03, accu03, alpha_v8)); + if constexpr (CountM == 2) { + float16x8_t c10 = MlasLoadFloat16x8(C + ldc); + float16x8_t c11 = MlasLoadFloat16x8(C + ldc + 8); + float16x8_t c12 = MlasLoadFloat16x8(C + ldc + 16); + float16x8_t c13 = MlasLoadFloat16x8(C + ldc + 24); + MlasStoreFloat16x8(C + ldc, vfmaq_f16(c10, accu10, alpha_v8)); + MlasStoreFloat16x8(C + ldc + 8, vfmaq_f16(c11, accu11, alpha_v8)); + MlasStoreFloat16x8(C + ldc + 16, vfmaq_f16(c12, accu12, alpha_v8)); + MlasStoreFloat16x8(C + ldc + 24, vfmaq_f16(c13, accu13, alpha_v8)); + } + } else if constexpr (beta_behavior == 2) { + float16x8_t c00 = MlasLoadFloat16x8(C); + float16x8_t c01 = MlasLoadFloat16x8(C + 8); + float16x8_t c02 = MlasLoadFloat16x8(C + 16); + float16x8_t c03 = MlasLoadFloat16x8(C + 24); + + MlasStoreFloat16x8(C, vfmaq_f16(vmulq_f16(c00, beta_v8), accu00, alpha_v8)); + MlasStoreFloat16x8(C + 8, vfmaq_f16(vmulq_f16(c01, beta_v8), accu01, alpha_v8)); + MlasStoreFloat16x8(C + 16, vfmaq_f16(vmulq_f16(c02, beta_v8), accu02, alpha_v8)); + MlasStoreFloat16x8(C + 24, vfmaq_f16(vmulq_f16(c03, beta_v8), accu03, alpha_v8)); + if constexpr (CountM == 2) { + float16x8_t c10 = MlasLoadFloat16x8(C + ldc); + float16x8_t c11 = MlasLoadFloat16x8(C + ldc + 8); + float16x8_t c12 = MlasLoadFloat16x8(C + ldc + 16); + float16x8_t c13 = MlasLoadFloat16x8(C + ldc + 24); + MlasStoreFloat16x8(C + ldc, vfmaq_f16(vmulq_f16(c10, beta_v8), accu10, alpha_v8)); + 
MlasStoreFloat16x8(C + ldc + 8, vfmaq_f16(vmulq_f16(c11, beta_v8), accu11, alpha_v8)); + MlasStoreFloat16x8(C + ldc + 16, vfmaq_f16(vmulq_f16(c12, beta_v8), accu12, alpha_v8)); + MlasStoreFloat16x8(C + ldc + 24, vfmaq_f16(vmulq_f16(c13, beta_v8), accu13, alpha_v8)); + } + } else { + MlasStoreFloat16x8(C, vmulq_f16(accu00, alpha_v8)); + MlasStoreFloat16x8(C + 8, vmulq_f16(accu01, alpha_v8)); + MlasStoreFloat16x8(C + 16, vmulq_f16(accu02, alpha_v8)); + MlasStoreFloat16x8(C + 24, vmulq_f16(accu03, alpha_v8)); + if constexpr (CountM == 2) { + MlasStoreFloat16x8(C + ldc, vmulq_f16(accu10, alpha_v8)); + MlasStoreFloat16x8(C + ldc + 8, vmulq_f16(accu11, alpha_v8)); + MlasStoreFloat16x8(C + ldc + 16, vmulq_f16(accu12, alpha_v8)); + MlasStoreFloat16x8(C + ldc + 24, vmulq_f16(accu13, alpha_v8)); + } + } + } + + if (CountN & 16) { + const auto* a = A; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu01 = MlasZeroFloat16x8(), accu10, accu11; + if constexpr (CountM == 2) { + accu10 = MlasZeroFloat16x8(); + accu11 = MlasZeroFloat16x8(); + } + if (largeK) { + float16x4_t a0 = MlasLoadFloat16x4(a), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda); + } + float16x8_t b00 = MlasLoadFloat16x8(PackedB); + float16x8_t b01 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b10 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b11 = MlasLoadFloat16x8(PackedB + 24); + float16x8_t b20 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b21 = MlasLoadFloat16x8(PackedB + 40); + float16x8_t b30 = MlasLoadFloat16x8(PackedB + 48); + float16x8_t b31 = MlasLoadFloat16x8(PackedB + 56); + for (; k >= 8; k -= 4, a += 4, PackedB += 4 * 16) { + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + } + a0 = MlasLoadFloat16x4(a + 4); + 
if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda + 4); + } + b00 = MlasLoadFloat16x8(PackedB + 64); + b01 = MlasLoadFloat16x8(PackedB + 72); + b10 = MlasLoadFloat16x8(PackedB + 80); + b11 = MlasLoadFloat16x8(PackedB + 88); + b20 = MlasLoadFloat16x8(PackedB + 96); + b21 = MlasLoadFloat16x8(PackedB + 104); + b30 = MlasLoadFloat16x8(PackedB + 112); + b31 = MlasLoadFloat16x8(PackedB + 120); + } + accu00 = maq_lane_f16_accu(accu00, b00, b10, b20, b30, a0); + accu01 = maq_lane_f16_accu(accu01, b01, b11, b21, b31, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b00, b10, b20, b30, a1); + accu11 = maq_lane_f16_accu(accu11, b01, b11, b21, b31, a1); + } + k -= 4, a += 4, PackedB += 4 * 16; + } + + if (Kr0) { + float16x8_t b00 = MlasLoadFloat16x8(PackedB); + float16x8_t b01 = MlasLoadFloat16x8(PackedB + 8); + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + accu00 = vfmaq_lane_f16(accu00, b00, a0, 0); + accu01 = vfmaq_lane_f16(accu01, b01, a0, 0); + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + accu10 = vfmaq_lane_f16(accu10, b00, a1, 0); + accu11 = vfmaq_lane_f16(accu11, b01, a1, 0); + } + if (Kr1) { + float16x8_t b10 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b11 = MlasLoadFloat16x8(PackedB + 24); + accu00 = vfmaq_lane_f16(accu00, b10, a0, 1); + accu01 = vfmaq_lane_f16(accu01, b11, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b10, a1, 1); + accu11 = vfmaq_lane_f16(accu11, b11, a1, 1); + } + } + if (Kr2) { + float16x8_t b20 = MlasLoadFloat16x8(PackedB + 32); + float16x8_t b21 = MlasLoadFloat16x8(PackedB + 40); + accu00 = vfmaq_lane_f16(accu00, b20, a0, 2); + accu01 = vfmaq_lane_f16(accu01, b21, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b20, a1, 2); + accu11 = vfmaq_lane_f16(accu11, b21, a1, 2); + } + } + PackedB += k * 16; + } + + if constexpr (beta_behavior == 1) { + float16x8_t c00 = MlasLoadFloat16x8(C); + float16x8_t c01 = 
MlasLoadFloat16x8(C + 8); + accu00 = vfmaq_f16(c00, accu00, alpha_v8); + accu01 = vfmaq_f16(c01, accu01, alpha_v8); + MlasStoreFloat16x8(C, accu00); + MlasStoreFloat16x8(C + 8, accu01); + if constexpr (CountM == 2) { + float16x8_t c10 = MlasLoadFloat16x8(C + ldc); + float16x8_t c11 = MlasLoadFloat16x8(C + ldc + 8); + accu10 = vfmaq_f16(c10, accu10, alpha_v8); + accu11 = vfmaq_f16(c11, accu11, alpha_v8); + MlasStoreFloat16x8(C + ldc, accu10); + MlasStoreFloat16x8(C + ldc + 8, accu11); + } + } else if constexpr (beta_behavior == 2) { + float16x8_t c00 = MlasLoadFloat16x8(C); + float16x8_t c01 = MlasLoadFloat16x8(C + 8); + accu00 = vfmaq_f16(vmulq_f16(c00, beta_v8), accu00, alpha_v8); + accu01 = vfmaq_f16(vmulq_f16(c01, beta_v8), accu01, alpha_v8); + MlasStoreFloat16x8(C, accu00); + MlasStoreFloat16x8(C + 8, accu01); + if constexpr (CountM == 2) { + float16x8_t c10 = MlasLoadFloat16x8(C + ldc); + float16x8_t c11 = MlasLoadFloat16x8(C + ldc + 8); + accu10 = vfmaq_f16(vmulq_f16(c10, beta_v8), accu10, alpha_v8); + accu11 = vfmaq_f16(vmulq_f16(c11, beta_v8), accu11, alpha_v8); + MlasStoreFloat16x8(C + ldc, accu10); + MlasStoreFloat16x8(C + ldc + 8, accu11); + } + } else { + accu00 = vmulq_f16(accu00, alpha_v8); + accu01 = vmulq_f16(accu01, alpha_v8); + MlasStoreFloat16x8(C, accu00); + MlasStoreFloat16x8(C + 8, accu01); + if constexpr (CountM == 2) { + accu10 = vmulq_f16(accu10, alpha_v8); + accu11 = vmulq_f16(accu11, alpha_v8); + MlasStoreFloat16x8(C + ldc, accu10); + MlasStoreFloat16x8(C + ldc + 8, accu11); + } + } + + CountN -= 16, C += 16; + } + + if (CountN & 8) { + const auto* a = A; + size_t k = CountK; + float16x8_t accu00 = MlasZeroFloat16x8(); + float16x8_t accu10 = MlasZeroFloat16x8(); + if (largeK) { + float16x4_t a0 = MlasLoadFloat16x4(a), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda); + } + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); 
+ float16x8_t b3 = MlasLoadFloat16x8(PackedB + 24); + for (; k >= 8; k -= 4, a += 4, PackedB += 4 * 8) { + accu00 = maq_lane_f16_accu(accu00, b0, b1, b2, b3, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b0, b1, b2, b3, a1); + } + a0 = MlasLoadFloat16x4(a + 4); + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda + 4); + } + b0 = MlasLoadFloat16x8(PackedB + 32); + b1 = MlasLoadFloat16x8(PackedB + 40); + b2 = MlasLoadFloat16x8(PackedB + 48); + b3 = MlasLoadFloat16x8(PackedB + 56); + } + accu00 = maq_lane_f16_accu(accu00, b0, b1, b2, b3, a0); + if constexpr (CountM == 2) { + accu10 = maq_lane_f16_accu(accu10, b0, b1, b2, b3, a1); + } + k -= 4, a += 4, PackedB += 4 * 8; + } + + if (Kr0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + } + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + accu00 = vfmaq_lane_f16(accu00, b0, a0, 0); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b0, a1, 0); + } + if (Kr1) { + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + accu00 = vfmaq_lane_f16(accu00, b1, a0, 1); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b1, a1, 1); + } + } + if (Kr2) { + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + accu00 = vfmaq_lane_f16(accu00, b2, a0, 2); + if constexpr (CountM == 2) { + accu10 = vfmaq_lane_f16(accu10, b2, a1, 2); + } + } + PackedB += k * 8; + } + + if constexpr (beta_behavior == 1) { + float16x8_t c0 = MlasLoadFloat16x8(C); + accu00 = vfmaq_f16(c0, accu00, alpha_v8); + MlasStoreFloat16x8(C, accu00); + if constexpr (CountM == 2) { + float16x8_t c1 = MlasLoadFloat16x8(C + ldc); + accu10 = vfmaq_f16(c1, accu10, alpha_v8); + MlasStoreFloat16x8(C + ldc, accu10); + } + } else if constexpr (beta_behavior == 2) { + float16x8_t c0 = MlasLoadFloat16x8(C); + accu00 = vfmaq_f16(vmulq_f16(c0, beta_v8), accu00, alpha_v8); + MlasStoreFloat16x8(C, accu00); + if constexpr (CountM == 2) { + 
float16x8_t c1 = MlasLoadFloat16x8(C + ldc); + accu10 = vfmaq_f16(vmulq_f16(c1, beta_v8), accu10, alpha_v8); + MlasStoreFloat16x8(C + ldc, accu10); + } + } else { + accu00 = vmulq_f16(accu00, alpha_v8); + MlasStoreFloat16x8(C, accu00); + if constexpr (CountM == 2) { + accu10 = vmulq_f16(accu10, alpha_v8); + MlasStoreFloat16x8(C + ldc, accu10); + } + } + + CountN -= 8, C += 8; + } + + if (CountN > 0) { + const auto* a = A; + size_t k = CountK; + float16x8_t accu0 = MlasZeroFloat16x8(), accu1; + if constexpr (CountM == 2) { + accu1 = MlasZeroFloat16x8(); + } + if (largeK) { + float16x4_t a0 = MlasLoadFloat16x4(a), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda); + } + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + float16x8_t b3 = MlasLoadFloat16x8(PackedB + 24); + for (; k >= 8; k -= 4, a += 4, PackedB += 4 * 8) { + accu0 = maq_lane_f16_accu(accu0, b0, b1, b2, b3, a0); + if constexpr (CountM == 2) { + accu1 = maq_lane_f16_accu(accu1, b0, b1, b2, b3, a1); + } + a0 = MlasLoadFloat16x4(a + 4); + if constexpr (CountM == 2) { + a1 = MlasLoadFloat16x4(a + lda + 4); + } + b0 = MlasLoadFloat16x8(PackedB + 32); + b1 = MlasLoadFloat16x8(PackedB + 40); + b2 = MlasLoadFloat16x8(PackedB + 48); + b3 = MlasLoadFloat16x8(PackedB + 56); + } + accu0 = maq_lane_f16_accu(accu0, b0, b1, b2, b3, a0); + if constexpr (CountM == 2) { + accu1 = maq_lane_f16_accu(accu1, b0, b1, b2, b3, a1); + } + k -= 4, a += 4, PackedB += 4 * 8; + } + + if (Kr0) { + float16x4_t a0 = MlasLoadPartialFloat16x4(a, k), a1; + if constexpr (CountM == 2) { + a1 = MlasLoadPartialFloat16x4(a + lda, k); + } + float16x8_t b0 = MlasLoadFloat16x8(PackedB); + accu0 = vfmaq_lane_f16(accu0, b0, a0, 0); + if constexpr (CountM == 2) { + accu1 = vfmaq_lane_f16(accu1, b0, a1, 0); + } + if (Kr1) { + float16x8_t b1 = MlasLoadFloat16x8(PackedB + 8); + accu0 = vfmaq_lane_f16(accu0, b1, a0, 1); + if constexpr 
(CountM == 2) { + accu1 = vfmaq_lane_f16(accu1, b1, a1, 1); + } + } + if (Kr2) { + float16x8_t b2 = MlasLoadFloat16x8(PackedB + 16); + accu0 = vfmaq_lane_f16(accu0, b2, a0, 2); + if constexpr (CountM == 2) { + accu1 = vfmaq_lane_f16(accu1, b2, a1, 2); + } + } + PackedB += k * 8; + } + + float16x4_t accu0_low = vget_low_f16(accu0); + float16x4_t accu0_high = vget_high_f16(accu0); + float16x4_t accu1_low, accu1_high; + if constexpr (CountM == 2) { + accu1_low = vget_low_f16(accu1); + accu1_high = vget_high_f16(accu1); + } + + if (CountN & 4) { + if constexpr (beta_behavior == 1) { + float16x4_t c0 = MlasLoadFloat16x4(C); + MlasStoreFloat16x4(C, vfma_f16(c0, accu0_low, alpha_v4)); + if constexpr (CountM == 2) { + float16x4_t c1 = MlasLoadFloat16x4(C + ldc); + MlasStoreFloat16x4(C + ldc, vfma_f16(c1, accu1_low, alpha_v4)); + } + } else if constexpr (beta_behavior == 2) { + float16x4_t c0 = MlasLoadFloat16x4(C); + MlasStoreFloat16x4(C, vfma_f16(vmul_f16(c0, beta_v4), accu0_low, alpha_v4)); + if constexpr (CountM == 2) { + float16x4_t c1 = MlasLoadFloat16x4(C + ldc); + MlasStoreFloat16x4(C + ldc, vfma_f16(vmul_f16(c1, beta_v4), accu1_low, alpha_v4)); + } + } else { + MlasStoreFloat16x4(C, vmul_f16(accu0_low, alpha_v4)); + if constexpr (CountM == 2) { + MlasStoreFloat16x4(C + ldc, vmul_f16(accu1_low, alpha_v4)); + } + } + CountN -= 4, C += 4; + accu0_low = accu0_high; + if constexpr (CountM == 2) { + accu1_low = accu1_high; + } + } + + if (CountN) { + if constexpr (beta_behavior == 1) { + float16x4_t c0 = MlasLoadPartialFloat16x4(C, CountN); + MlasStorePartialFloat16x4(C, vfma_f16(c0, accu0_low, alpha_v4), CountN); + if constexpr (CountM == 2) { + float16x4_t c1 = MlasLoadPartialFloat16x4(C + ldc, CountN); + MlasStorePartialFloat16x4(C + ldc, vfma_f16(c1, accu1_low, alpha_v4), CountN); + } + } else if constexpr (beta_behavior == 2) { + float16x4_t c0 = MlasLoadPartialFloat16x4(C, CountN); + MlasStorePartialFloat16x4(C, vfma_f16(vmul_f16(c0, beta_v4), accu0_low, alpha_v4), 
CountN); + if constexpr (CountM == 2) { + float16x4_t c1 = MlasLoadPartialFloat16x4(C + ldc, CountN); + MlasStorePartialFloat16x4(C + ldc, vfma_f16(vmul_f16(c1, beta_v4), accu1_low, alpha_v4), CountN); + } + } else { + MlasStorePartialFloat16x4(C, vmul_f16(accu0_low, alpha_v4), CountN); + if constexpr (CountM == 2) { + MlasStorePartialFloat16x4(C + ldc, vmul_f16(accu1_low, alpha_v4), CountN); + } + } + } + } +} + +void HGemm_PackedB_Kernel( + const MLAS_FP16* A, + const MLAS_FP16* PackedB, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t lda, + size_t ldc, + _mlas_fp16_ alpha, + _mlas_fp16_ beta +) { + if (CountM > 2) { + MLAS_THROW_EX(std::runtime_error, "HGemm_PackedB_Kernel only support <= 2 rows"); + } + + const auto* A_data = reinterpret_cast<const _mlas_fp16_*>(A); + const auto* PackedB_data = reinterpret_cast<const _mlas_fp16_*>(PackedB); + auto* C_data = reinterpret_cast<_mlas_fp16_*>(C); + const auto f16_0 = MLAS_FP16(0.0f); + const auto f16_1 = MLAS_FP16(1.0f); + if (CountM == 1) { + if (beta == f16_0.val) { + HGemm_PackedB_Kernel_Impl<0, 1>(A_data, PackedB_data, C_data, CountN, CountK, lda, ldc, alpha, beta); + } else if (beta == f16_1.val) { + HGemm_PackedB_Kernel_Impl<1, 1>(A_data, PackedB_data, C_data, CountN, CountK, lda, ldc, alpha, beta); + } else { + HGemm_PackedB_Kernel_Impl<2, 1>(A_data, PackedB_data, C_data, CountN, CountK, lda, ldc, alpha, beta); + } + } else { + if (beta == f16_0.val) { + HGemm_PackedB_Kernel_Impl<0, 2>(A_data, PackedB_data, C_data, CountN, CountK, lda, ldc, alpha, beta); + } else if (beta == f16_1.val) { + HGemm_PackedB_Kernel_Impl<1, 2>(A_data, PackedB_data, C_data, CountN, CountK, lda, ldc, alpha, beta); + } else { + HGemm_PackedB_Kernel_Impl<2, 2>(A_data, PackedB_data, C_data, CountN, CountK, lda, ldc, alpha, beta); + } + } +} + +} // namespace hgemm_neon diff --git a/src/lib/hgemm_kernel_neon.cpp b/src/lib/hgemm_kernel_neon.cpp new file mode 100644 index 0000000..1531ce9 --- /dev/null +++ b/src/lib/hgemm_kernel_neon.cpp @@ -0,0 +1,30
@@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + hgemm_kernel_neon.cpp + +Abstract: + + This module implements half precision GEMM kernel for neon. + +--*/ + +#include "mlasi.h" +#include "halfgemm.h" + +const MLAS_HGEMM_DISPATCH MlasHGemmDispatchNeon = [](){ + MLAS_HGEMM_DISPATCH d; +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + d.HPackBKernel_TransposedB = hgemm_neon::HPackB_TransposedB_Kernel; + d.HPackBKernel_B = hgemm_neon::HPackB_B_Kernel; + d.HGemmKernel_TransposedB = hgemm_neon::HGemm_TransposedB_Kernel; + d.HGemmKernel_B = hgemm_neon::HGemm_B_Kernel; + d.HGemmKernel_PackedB = hgemm_neon::HGemm_PackedB_Kernel; +#endif + return d; +}(); diff --git a/src/lib/hqnbitgemm_kernel_neon_fp16.cpp b/src/lib/hqnbitgemm_kernel_neon_fp16.cpp index 69e37d2..5b1f9d7 100644 --- a/src/lib/hqnbitgemm_kernel_neon_fp16.cpp +++ b/src/lib/hqnbitgemm_kernel_neon_fp16.cpp @@ -93,39 +93,6 @@ Transpose8x8(uint8x8_t& v0, uint8x8_t& v1, uint8x8_t& v2, uint8x8_t& v3, v7 = vreinterpret_u8_u32(c3.val[1]); } -MLAS_FORCEINLINE void -Transpose4x8(float16x8_t& v0, float16x8_t& v1, float16x8_t& v2, float16x8_t& v3) -{ - // |v00|v01|v02|v03|v04|v05|v06|v07| - // |v10|v11|v12|v13|v14|v15|v16|v17| - // |v20|v21|v22|v23|v24|v25|v26|v27| - // |v30|v31|v32|v33|v34|v35|v36|v37| - // => - // |v00|v10|v20|v30|v04|v14|v24|v34| - // |v01|v11|v21|v31|v05|v15|v25|v35| - // |v02|v12|v22|v32|v06|v16|v26|v36| - // |v03|v13|v23|v33|v07|v17|v27|v37| - float16x8x2_t t01 = vtrnq_f16(v0, v1); - float16x8x2_t t23 = vtrnq_f16(v2, v3); - - v0 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]))); - v1 = vreinterpretq_f16_f32(vtrn1q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]))); - v2 = vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[0]), vreinterpretq_f32_f16(t23.val[0]))); - v3 = 
vreinterpretq_f16_f32(vtrn2q_f32(vreinterpretq_f32_f16(t01.val[1]), vreinterpretq_f32_f16(t23.val[1]))); -} - -MLAS_FORCEINLINE void -Transpose4x4(float16x4_t& v0, float16x4_t& v1, float16x4_t& v2, float16x4_t& v3) -{ - float16x4x2_t t01 = vtrn_f16(v0, v1); - float16x4x2_t t23 = vtrn_f16(v2, v3); - - v0 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0]))); - v1 = vreinterpret_f16_f32(vtrn1_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1]))); - v2 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[0]), vreinterpret_f32_f16(t23.val[0]))); - v3 = vreinterpret_f16_f32(vtrn2_f32(vreinterpret_f32_f16(t01.val[1]), vreinterpret_f32_f16(t23.val[1]))); -} - void HQ4BitGemmPackQuantBData_CompFp16( size_t N, diff --git a/src/lib/intrinsics/avx2/saturation_check_avx2.cpp b/src/lib/intrinsics/avx2/saturation_check_avx2.cpp new file mode 100644 index 0000000..5ff4c0e --- /dev/null +++ b/src/lib/intrinsics/avx2/saturation_check_avx2.cpp @@ -0,0 +1,62 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + saturation_check_avx2.cpp + +Abstract: + + This module implements logic to check saturation of the VPMADDUBSW + instruction. 
+ +--*/ + +#include + +#include +#include + +namespace onnxruntime +{ +extern std::atomic saturation_count; +} + +extern "C" void +CheckSaturationForVPMADDUBSW(const __m256i* unsigned_ptr, const __m256i* signed_ptr) +{ + // Load data from memory (unaligned load) + __m256i unsigned_data = _mm256_loadu_si256(unsigned_ptr); + __m256i signed_data = _mm256_loadu_si256(signed_ptr); + + alignas(32) uint8_t unsigned_bytes[32]; // Unsigned input values + alignas(32) int8_t signed_bytes[32]; // Signed input values + + // Store the data into the byte arrays + _mm256_store_si256(reinterpret_cast<__m256i*>(unsigned_bytes), unsigned_data); + _mm256_store_si256(reinterpret_cast<__m256i*>(signed_bytes), signed_data); + + bool saturation_detected = false; + + // Iterate through the 16 pairs of 8-bit unsigned and signed values + for (int i = 0; i < 16; ++i) { + // Perform the VPMADDUBSW operation in higher precision (int32_t) + int32_t computed_value = + static_cast(signed_bytes[2 * i]) * static_cast(static_cast(unsigned_bytes[2 * i])) + + static_cast(signed_bytes[2 * i + 1]) * static_cast(static_cast(unsigned_bytes[2 * i + 1])); + + // If the computed value exceeds the 16-bit signed integer range, saturation occurred + if (computed_value > INT16_MAX || computed_value < INT16_MIN) { + saturation_detected = true; + break; + } + } + + // If saturation is detected, log a warning (only log once based on the atomic count) + if (saturation_detected && ++onnxruntime::saturation_count < 2) { + std::cerr << "Warning: saturation detected in VPMADDUBSW instruction." 
<< std::endl; + } +} diff --git a/src/lib/kai_ukernel_interface.cpp b/src/lib/kai_ukernel_interface.cpp new file mode 100644 index 0000000..fdada83 --- /dev/null +++ b/src/lib/kai_ukernel_interface.cpp @@ -0,0 +1,81 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#include "kai_ukernel_interface.h" +#include "mlasi.h" + +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm.h" + +const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod = + {kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod}; + +const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod = + 
{kai_get_m_step_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod}; + +const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod = + {kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod}; + +const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm = + 
{kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm}; + +const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel() { + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) { + return kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm; + } else { + return kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod; + } +} + +const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel() { + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) { + return kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod; + } else { + return kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod; + } +} diff --git a/src/lib/kai_ukernel_interface.h b/src/lib/kai_ukernel_interface.h new file mode 100644 index 0000000..1a6f111 --- /dev/null +++ b/src/lib/kai_ukernel_interface.h @@ -0,0 +1,12 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp_qsi4c32p_interface.h" + +const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemmUKernel(); +const 
kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& GetKleidiAIGemvUKernel(); diff --git a/src/lib/mlasi.h b/src/lib/mlasi.h index 0533a5e..184816a 100644 --- a/src/lib/mlasi.h +++ b/src/lib/mlasi.h @@ -18,6 +18,7 @@ Module Name: #pragma once #include +#include #include #include #include @@ -301,6 +302,8 @@ static_assert(sizeof(MLAS_FP16) == FP16_SIZE); // Define the default strides to step through slices of the input matrices. // +#define MLAS_HGEMM_STRIDEN 128 +#define MLAS_HGEMM_STRIDEK 128 #define MLAS_SGEMM_STRIDEN 128 #define MLAS_SGEMM_STRIDEK 128 #define MLAS_SGEMM_PACKED_STRIDEN 128 @@ -317,6 +320,7 @@ static_assert(sizeof(MLAS_FP16) == FP16_SIZE); // the effort at this time. // +#define MLAS_HGEMM_STRIDEN_THREAD_ALIGN 32 #define MLAS_SGEMM_STRIDEN_THREAD_ALIGN 16 #define MLAS_DGEMM_STRIDEN_THREAD_ALIGN 8 #define MLAS_QGEMM_STRIDEN_THREAD_ALIGN 16 @@ -944,6 +948,7 @@ extern "C" { #define MLAS_SGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) #define MLAS_DGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) #define MLAS_QGEMM_THREAD_COMPLEXITY 65536 +#define MLAS_HGEMM_THREAD_COMPLEXITY 65536 #if defined(__aarch64__) && defined(__linux__) #define MLAS_SBGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) @@ -992,9 +997,14 @@ extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSdot; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUmmla; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSmmla; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmSimd; +extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmRelaxedSimd; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmQuantDispatchDefault; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemm8X8DispatchPOWER10; +#if defined(MLAS_TARGET_WASM_RELAXED_SIMD) +extern bool HasUSDot(); +#endif + // // Symmetric quantized qgemm dispatch structure // @@ -1039,7 +1049,10 @@ extern const MLAS_FPQ4GEMM_DISPATCH MlasFpQ4GemmDispatchAvx512; struct MLAS_QNBIT_GEMM_DISPATCH; -extern const 
MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon; +const MLAS_QNBIT_GEMM_DISPATCH& +GetMlasQNBitGemmDispatchNeon( + bool InitializeWithDotSupport +); extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2; @@ -1049,6 +1062,27 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; +// +// Rotary embedding dispatch structure. +// +struct MLAS_ROPE_DISPATCH; +extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon; +extern const MLAS_ROPE_DISPATCH MlasRopeDispatchAvx2; + +// +// half gemm dispatch structure +// +struct MLAS_HGEMM_DISPATCH; +extern const MLAS_HGEMM_DISPATCH MlasHGemmDispatchNeon; + +// softmax dispatch structure +struct MLAS_SOFTMAX_DISPATCH; +extern const MLAS_SOFTMAX_DISPATCH MlasSoftmaxDispatchNeon; + +// eltwise dispatch structure +struct MLAS_ELTWISE_DISPATCH; +extern const MLAS_ELTWISE_DISPATCH MlasEltwiseDispatchNeon; + // // Quantized depthwise convolution kernels. // @@ -1111,6 +1145,10 @@ struct MLAS_PLATFORM { MLAS_PLATFORM(void); + // TODO: move to cpuinfo + bool Avx2Supported_ = false; + bool Avx512Supported_ = false; + #if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) MLAS_GEMM_FLOAT_KERNEL* GemmFloatKernel; #endif @@ -1208,6 +1246,11 @@ struct MLAS_PLATFORM { MLAS_CAST_F16_TO_F32_KERNEL* CastF16ToF32Kernel; MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel; + + const MLAS_ROPE_DISPATCH* RopeDispatch{nullptr}; + const MLAS_HGEMM_DISPATCH* HGemmDispatch{nullptr}; + const MLAS_SOFTMAX_DISPATCH* SoftmaxDispatch{nullptr}; + const MLAS_ELTWISE_DISPATCH* EltwiseDispatch{nullptr}; }; inline @@ -1417,6 +1460,9 @@ MlasConvDepthwiseFloat_CHW( #endif #elif defined(MLAS_TARGET_WASM_SIMD) #define MLAS_WASM_SIMD_INTRINSICS +#if defined(MLAS_TARGET_WASM_RELAXED_SIMD) +#define MLAS_WASM_RELAXED_SIMD_INTRINSICS +#endif #elif defined(MLAS_TARGET_LARCH64) #define MLAS_LSX_INTRINSICS #endif @@ -2227,6 +2273,8 @@ MlasMaximumFloat32x4(MLAS_FLOAT32X4 
Vector1, MLAS_FLOAT32X4 Vector2) #elif defined(MLAS_VSX_INTRINSICS) // Don't use vec_max to avoid undefined behavior if NAN return vec_sel(Vector2, Vector1, vec_cmpgt(Vector1, Vector2)); +#elif defined(MLAS_WASM_RELAXED_SIMD_INTRINSICS) + return wasm_f32x4_relaxed_max(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_max(Vector1, Vector2); #elif defined(MLAS_LSX_INTRINSICS) @@ -2247,6 +2295,8 @@ MlasMinimumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) #elif defined(MLAS_VSX_INTRINSICS) // Don't use vec_min to avoid undefined behavior if NAN return vec_sel(Vector2, Vector1, vec_cmpgt(Vector2, Vector1)); +#elif defined(MLAS_WASM_RELAXED_SIMD_INTRINSICS) + return wasm_f32x4_relaxed_min(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_min(Vector1, Vector2); #elif defined(MLAS_LSX_INTRINSICS) diff --git a/src/lib/platform.cpp b/src/lib/platform.cpp index 2aea7a9..7e875c2 100644 --- a/src/lib/platform.cpp +++ b/src/lib/platform.cpp @@ -375,6 +375,8 @@ Return Value: if (((Cpuid1[2] & 0x1000) != 0) && ((Cpuid7[1] & 0x20) != 0)) { + this->Avx2Supported_ = true; + this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAvx2; this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx2; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx2; @@ -402,6 +404,7 @@ Return Value: this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx2; this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelAvx2; this->CastF32ToF16Kernel = &MlasCastF32ToF16KernelAvx2; + this->RopeDispatch = &MlasRopeDispatchAvx2; // @@ -465,6 +468,8 @@ Return Value: if ((Cpuid7[1] & 0xC0020000) == 0xC0020000) { + this->Avx512Supported_ = true; + this->GemmU8S8Kernel = MlasGemmU8S8KernelAvx512Core; this->GemvU8S8Kernel = MlasGemvU8S8KernelAvx512Core; this->GemmU8U8Kernel = MlasGemmU8U8KernelAvx512Core; @@ -543,28 +548,25 @@ Return Value: this->SymmQgemmDispatch = &MlasSymmQgemmS8DispatchNeon; this->ConvSymU8S8Dispatch = &MlasConvSymU8DispatchNeon; this->ConvSymS8S8Dispatch = 
&MlasConvSymS8DispatchNeon; - this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchNeon; + this->RopeDispatch = &MlasRopeDispatchNeon; + this->HGemmDispatch = &MlasHGemmDispatchNeon; + this->SoftmaxDispatch = &MlasSoftmaxDispatchNeon; + this->EltwiseDispatch = &MlasEltwiseDispatchNeon; // // Check if the processor supports ASIMD dot product instructions. // - bool HasDotProductInstructions; - -#if defined(_WIN32) - HasDotProductInstructions = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); -#else - // Use the cpuinfo value which is read from sysctl and has some additional special cases. - // https://github.com/pytorch/cpuinfo/blob/959002f82d7962a473d8bf301845f2af720e0aa4/src/arm/mach/init.c#L369-L379 + // Note: // Do NOT use ID_AA64ISAR0_EL1. It causes illegal instruction errors on Mac M1 and ARMv8-A chips // as well as failing on other ARM chips as it is an EL1 level register that requires extra // privileges to read. // // uint64_t isar0_el1; // asm("mrs %[reg], ID_AA64ISAR0_EL1\n" : [reg] "=r"(isar0_el1) : :); - // HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u; - HasDotProductInstructions = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot(); -#endif + // const bool HasDotProductInstructions = ((isar0_el1 >> 44) & 0xfu) == 0x1u; + + const bool HasDotProductInstructions = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot(); if (HasDotProductInstructions) { this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchUdot; @@ -575,6 +577,8 @@ Return Value: this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; } + this->QNBitGemmDispatch = &GetMlasQNBitGemmDispatchNeon(HasDotProductInstructions); + #if defined(__linux__) // // Check if the processor supports ASIMD I8MM instructions. 
diff --git a/src/lib/pooling_fp16.cpp b/src/lib/pooling_fp16.cpp index 98e8473..9765192 100644 --- a/src/lib/pooling_fp16.cpp +++ b/src/lib/pooling_fp16.cpp @@ -80,23 +80,23 @@ PoolInit16x4() } template<> -MLAS_FORCEINLINE +MLAS_FORCEINLINE MLAS_FLOAT16X8 PoolAggregate16x8(MLAS_FLOAT16X8 agg, MLAS_FLOAT16X8 element) { - return MlasMaximumFloat16x8(agg, element); + return MlasMaximumFloat16(agg, element); } template<> -MLAS_FORCEINLINE +MLAS_FORCEINLINE MLAS_FLOAT16X4 PoolAggregate16x4(MLAS_FLOAT16X4 agg, MLAS_FLOAT16X4 element) { - return MlasMaximumFloat16x4(agg, element); + return MlasMaximumFloat16(agg, element); } template<> -MLAS_FORCEINLINE +MLAS_FORCEINLINE MLAS_FLOAT16X8 PoolSummary16x8(MLAS_FLOAT16X8 agg, size_t size) { @@ -105,7 +105,7 @@ PoolSummary16x8(MLAS_FLOAT16X8 agg, size_t size) } template<> -MLAS_FORCEINLINE +MLAS_FORCEINLINE MLAS_FLOAT16X4 PoolSummary16x4(MLAS_FLOAT16X4 agg, size_t size) { @@ -144,28 +144,28 @@ template <> MLAS_FORCEINLINE MLAS_FLOAT16X8 PoolAggregate16x8(MLAS_FLOAT16X8 agg, MLAS_FLOAT16X8 element) { - return MlasAddFloat16x8(agg, element); + return MlasAddFloat16(agg, element); } template <> MLAS_FORCEINLINE MLAS_FLOAT16X4 PoolAggregate16x4(MLAS_FLOAT16X4 agg, MLAS_FLOAT16X4 element) { - return MlasAddFloat16x4(agg, element); + return MlasAddFloat16(agg, element); } template <> MLAS_FORCEINLINE MLAS_FLOAT16X8 PoolSummary16x8(MLAS_FLOAT16X8 agg, MLAS_FLOAT16X8 context) { - return MlasDivFloat16x8(agg, context); + return MlasDivideFloat16(agg, context); } template <> MLAS_FORCEINLINE MLAS_FLOAT16X4 PoolSummary16x4(MLAS_FLOAT16X4 agg, MLAS_FLOAT16X8 context) { - return MlasDivFloat16x4(agg, MlasToLowHalfFloat16x4(context)); + return MlasDivideFloat16(agg, MlasToLowHalfFloat16x4(context)); } diff --git a/src/lib/q4_dq.cpp b/src/lib/q4_dq.cpp index df61d3e..c543770 100644 --- a/src/lib/q4_dq.cpp +++ b/src/lib/q4_dq.cpp @@ -20,10 +20,13 @@ Module Name: #include "q4common.h" -template -constexpr size_t BlkQ4BufSize(size_t N, size_t K) 
{ - const size_t KBlocks = MlasDivRoundup(K, T::BlkLen); - return N * KBlocks * T::BlobSize; +template +constexpr +size_t +BlkQ4BufSize(size_t N, size_t K) +{ + const size_t KBlocks = MlasDivRoundup(K, T::BlkLen); + return N * KBlocks * T::BlobSize; } size_t @@ -325,7 +328,7 @@ struct BitsTraits { static constexpr float halfRange = static_cast(kMid - kMin); // number of qbit elements to pack into whole bytes - static constexpr int kPackSize = (qbits == 8) ? 1 : (qbits == 4) ? 2 : (qbits == 2) ? 4 : 0; + static constexpr int kPackSize = (qbits == 8) ? 1 : ((qbits == 4) ? 2 : ((qbits == 2) ? 4 : 0)); static_assert(kPackSize != 0, "Packing to whole bytes not supported for this qbits!"); }; @@ -384,12 +387,14 @@ range2scale(float min, float max, ScaleT& scale) /** - * @brief Blockwise quantization methods + * TODO(fajin): use int4/8 for symmetric quantization so the (vq - zp) operation in MatMulNBits can be saved. + * @brief Blockwise quantization methods. Source is row major. Dest, scale and zp are column major. + * Always quantize to unsigned int. * @tparam ElementT source data type, e.g. fp32/fp16 * @tparam block_size number of elemenets quantized together * @tparam qbits number of bits in each quantized element - * @tparam Columnwise true: elements in a block come from one single column - * false: elements in a block come from one single row + * @tparam Columnwise true: quantize along src column, pack along src column. + * false: quantize along src row, pack along src column. 
*/ template < typename ElementT, @@ -399,11 +404,18 @@ template < struct BlockwiseQuantizer { // To support other qbits, need to add bit packing code for // storing to dst and zero points - static_assert(qbits == 4, "Only 4b block quantization is supported!"); + static_assert(qbits == 2 || qbits == 4 || qbits == 8, "Only 2b, 4b and 8b block quantization is supported!"); using QuantBlk = std::conditional_t, Shape2D<1, block_size>>; using ThreadBlk = Shape2D::kPackSize, QuantBlk::kColumn>; + static + MLAS_FORCEINLINE + int GetElem(int val, int idx) + { + return (val >> (qbits * idx)) & ((1 << qbits) - 1); + } + static MLAS_FORCEINLINE void quantizeMetaShape(int rows, int columns, int& meta_rows, int& meta_cols) @@ -437,14 +449,14 @@ struct BlockwiseQuantizer { scale_num_elements = meta_rows * meta_cols; if (zero_point_bytes) { - // this works for qbits == 4 but may need to be updated for other qbits values + // this works for qbits == 2, 4 or 8 but may need to be updated for other qbits values *zero_point_bytes = ((meta_rows * qbits + 7) / 8) * meta_cols; } } /** * @brief Quantized a Matrix shape [rows, columns], resulting quantized - * and packed data are stored in column major (transposed) + * and packed data are stored in column major (transposed). 
* @param[out] dst pointer to the quantized weights, column major: [columns, rows] * @param[out] scale pointer to the scales, column major: [columns/QuantBlk::kColumn, rows/QuantBlk::kRow] * @param[out] zero_points pointer to the zero points, same shape as scale @@ -476,8 +488,10 @@ struct BlockwiseQuantizer { MlasTryBatchParallel( thread_pool, total_thrd_blks, [&](ptrdiff_t block_idx) { - uint8_t zp_bytes[BitsTraits::kPackSize]; - std::fill_n(zp_bytes, BitsTraits::kPackSize, (uint8_t)8); + constexpr int kPackSize = BitsTraits::kPackSize; + uint8_t zp_bytes[kPackSize], vi[kPackSize]; + std::fill_n(zp_bytes, kPackSize, (uint8_t)BitsTraits::kMid); + std::fill_n(vi, kPackSize, 0); const int32_t r_blk_idx = static_cast(block_idx / thrd_col_blks); const int32_t c_blk_idx = static_cast(block_idx % thrd_col_blks); @@ -492,7 +506,7 @@ struct BlockwiseQuantizer { const int meta_col = c / QuantBlk::kColumn; // compute scale and zero point - for (int kpack = 0; kpack < BitsTraits::kPackSize; kpack++) { + for (int kpack = 0; kpack < kPackSize; kpack++) { // scan a single block to extract range [min, max] float min = std::numeric_limits::max(); @@ -518,40 +532,42 @@ struct BlockwiseQuantizer { } } - // !! 
4b specific code as we need to pack 2 4b numbers into one byte if (zero_points != nullptr) { - const int32_t meta_idx = meta_col * ((row_blks + 1) / 2) + meta_row / 2; - zero_points[meta_idx] = (zp_bytes[0] & 0xf) | (zp_bytes[1] << 4); + const int32_t meta_idx = meta_col * ((row_blks + kPackSize - 1) / kPackSize) + meta_row / kPackSize; + if constexpr (qbits == 8) { + zero_points[meta_idx] = zp_bytes[0]; + } else if constexpr (qbits == 4) { + zero_points[meta_idx] = (zp_bytes[0] & 0xf) | (zp_bytes[1] << 4); + } else if constexpr (qbits == 2) { + zero_points[meta_idx] = (zp_bytes[0] & 0x3) | (zp_bytes[1] << 2) | (zp_bytes[2] << 4) | (zp_bytes[3] << 6); + } else { + MLAS_THROW_EX(std::runtime_error, "Unsupported qbits"); + } } for (int32_t j = c; j < c_end; ++j) { const int32_t meta_c = j / QuantBlk::kColumn; - for (int32_t i = r; i < r_end; i += 2) { - const int32_t meta_r = i / QuantBlk::kRow; - const float scale = static_cast(scales[meta_c * row_blks + meta_r]); - const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; - const int8_t zp = zp_bytes[meta_r & 1]; - const int8_t zp1 = zp_bytes[((i + 1) / QuantBlk::kRow) & 1]; - - const float v0 = static_cast(src[i * leadingDimension + j]); - const uint8_t vi0 = (uint8_t)std::clamp(roundf(v0 * reciprocal_scale + zp), - 0.0f, BitsTraits::kMaxFp); - - uint8_t vi1 = (uint8_t)zp; - if (i + 1 < r_end) { - float reciprocal_scale1 = reciprocal_scale; - if constexpr (QuantBlk::kRow == 1) { - const float scale1 = - static_cast(scales[meta_c * row_blks + meta_r + 1]); - reciprocal_scale1 = scale1 ? 
1.0f / scale1 : 0.0f; - } - const float v1 = static_cast(src[(i + 1) * leadingDimension + j]); - vi1 = (uint8_t)std::clamp(roundf(v1 * reciprocal_scale1 + zp1), 0.0f, - BitsTraits::kMaxFp); + for (int32_t i = r; i < r_end; i += kPackSize) { + for (int l = 0; l < kPackSize && i + l < r_end; l++) { + const int32_t meta_r = (i + l) / QuantBlk::kRow; + const float scale = static_cast(scales[meta_c * row_blks + meta_r]); + const float reciprocal_scale = scale ? 1.0f / scale : 0.0f; + const int32_t zp = zp_bytes[meta_r % kPackSize]; + + const float v = static_cast(src[(i + l) * leadingDimension + j]); + vi[l] = (uint8_t)std::clamp(roundf(v * reciprocal_scale + zp), + 0.0f, BitsTraits::kMaxFp); } - // !! 4b specific code - dst[j * q_rows + i / 2] = (vi0 & 0xf) | (vi1 << 4); + if constexpr (qbits == 8) { + dst[j * q_rows + i / kPackSize] = vi[0]; + } else if constexpr (qbits == 4) { + dst[j * q_rows + i / kPackSize] = (vi[0] & 0xf) | (vi[1] << 4); + } else if constexpr (qbits == 2) { + dst[j * q_rows + i / kPackSize] = (vi[0] & 0x3) | (vi[1] << 2) | (vi[2] << 4) | (vi[3] << 6); + } else { + MLAS_THROW_EX(std::runtime_error, "Unsupported qbits"); + } } } }); @@ -586,6 +602,7 @@ struct BlockwiseQuantizer { int q_rows, q_cols; quantizedShape(rows, columns, q_rows, q_cols); + constexpr int32_t kPackSize = BitsTraits::kPackSize; MlasTryBatchParallel( thread_pool, total_thrd_blks, @@ -602,38 +619,22 @@ struct BlockwiseQuantizer { for (int32_t j = c; j < c_end; ++j) { const int32_t meta_col = j / QuantBlk::kColumn; - // !! 4b specific code - // the whole loop is 4b specific due to sub 8 bit packing - // and unpacking. 
We can potentially make this qbits generic - // by wraping the packing/unpacking code like cutlass::Array - for (int32_t i = r; i < r_end; i += 2) { + for (int32_t i = r; i < r_end; ++i) { const int32_t meta_row = i / QuantBlk::kRow; - - const float scale0 = - static_cast(scales[meta_col * row_blks + meta_row]); - + const float scale = static_cast(scales[meta_col * row_blks + meta_row]); const int zp_pair = - (zero_points == nullptr) - ? 0x88 - : zero_points[meta_col * ((row_blks + 1) / 2) + meta_row / 2]; - const int zp0 = (meta_row & 1) ? (zp_pair >> 4) : (zp_pair & 0xf); - - const uint8_t vi0 = weights[j * q_rows + i / 2] & 0xf; - const float v0 = (static_cast(vi0) - zp0) * scale0; - - dst[j * rows + i] = static_cast(v0); - if ((i + 1) < r_end) { - float scale1 = scale0; - int zp1 = zp0; - if constexpr (QuantBlk::kRow == 1) { - scale1 = - static_cast(scales[meta_col * row_blks + meta_row + 1]); - zp1 = (zp_pair >> 4) & 0xf; - } - const uint8_t vi1 = weights[j * q_rows + i / 2] >> 4; - const float v1 = (static_cast(vi1) - zp1) * scale1; - dst[j * rows + (i + 1)] = static_cast(v1); - } + zero_points + ? zero_points[meta_col * ((row_blks + kPackSize - 1) / kPackSize) + meta_row / kPackSize] + : 0; + const int vi_pair = weights[j * q_rows + i / kPackSize]; + + const int zp = + zero_points + ? 
GetElem(zp_pair, meta_row % kPackSize) + : BitsTraits::kMid; + const int vi = GetElem(vi_pair, i % kPackSize); + const float v = (vi - zp) * scale; + dst[j * rows + i] = ElementT(v); } } }); @@ -1413,6 +1414,27 @@ MlasBlockwiseQuantizedShape( } } +template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ); + +template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ); template void @@ -1436,6 +1458,50 @@ MlasBlockwiseQuantMetaShape( int& meta_cols ); + template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ); + +template +void +MlasBlockwiseQuantMetaShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& meta_rows, + int& meta_cols + ); + +template +void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ); + +template +void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ); + template void MlasBlockwiseQuantizedShape( @@ -1458,9 +1524,31 @@ MlasBlockwiseQuantizedShape( int& q_cols ); + template +void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ); + +template +void +MlasBlockwiseQuantizedShape( + int block_size, + bool columnwise, + int rows, + int columns, + int& q_rows, + int& q_cols + ); + +template void MLASCALL MlasBlockwiseQuantizedBufferSizes( - int qbits, int block_size, bool columnwise, int rows, @@ -1475,75 +1563,108 @@ MlasBlockwiseQuantizedBufferSizes( *q_zero_point_size_in_bytes = 0; } - if (qbits == 4) { - switch (block_size) { - case 16: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, 
q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; - - case 32: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; - - case 64: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; - - case 128: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; - - case 256: - if (columnwise) { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } else { - BlockwiseQuantizer::quantizedBufferSizes( - rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes - ); - } - break; + switch (block_size) { + case 16: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; - default: - // Only block size 16, 32, 64, 128, 256 are supported. 
- break; - } + case 32: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 64: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 128: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + case 256: + if (columnwise) { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } else { + BlockwiseQuantizer::quantizedBufferSizes( + rows, columns, q_data_size_in_bytes, q_scale_num_elements, q_zero_point_size_in_bytes + ); + } + break; + + default: + // Only block size 16, 32, 64, 128, 256 are supported. 
+ break; } } +template +void MLASCALL +MlasBlockwiseQuantizedBufferSizes<2>( + int block_size, + bool columnwise, + int rows, + int columns, + size_t& q_data_size_in_bytes, + size_t& q_scale_num_elements, + size_t* q_zero_point_size_in_bytes +); + +template +void MLASCALL +MlasBlockwiseQuantizedBufferSizes<4>( + int block_size, + bool columnwise, + int rows, + int columns, + size_t& q_data_size_in_bytes, + size_t& q_scale_num_elements, + size_t* q_zero_point_size_in_bytes +); + +template +void MLASCALL +MlasBlockwiseQuantizedBufferSizes<8>( + int block_size, + bool columnwise, + int rows, + int columns, + size_t& q_data_size_in_bytes, + size_t& q_scale_num_elements, + size_t* q_zero_point_size_in_bytes +); template void @@ -1617,6 +1738,36 @@ MlasQuantizeBlockwise( } } +template +void +MlasQuantizeBlockwise( + uint8_t* dst, + float* scales, + uint8_t* zero_points, + const float* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool + ); + +template +void +MlasQuantizeBlockwise( + uint8_t* dst, + MLAS_FP16* scales, + uint8_t* zero_points, + const MLAS_FP16* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool + ); + template void MlasQuantizeBlockwise( @@ -1647,6 +1798,35 @@ MlasQuantizeBlockwise( MLAS_THREADPOOL* thread_pool ); + template + void + MlasQuantizeBlockwise( + uint8_t* dst, + float* scales, + uint8_t* zero_points, + const float* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool + ); + + template + void + MlasQuantizeBlockwise( + uint8_t* dst, + MLAS_FP16* scales, + uint8_t* zero_points, + const MLAS_FP16* src, + int block_size, + bool columnwise, + int rows, + int columns, + int leading_dimension, + MLAS_THREADPOOL* thread_pool + ); template void @@ -1714,6 +1894,32 @@ MlasDequantizeBlockwise( } } +template void +MlasDequantizeBlockwise( + 
float* dst, + const uint8_t* src, + const float* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasDequantizeBlockwise( + MLAS_FP16* dst, + const uint8_t* src, + const MLAS_FP16* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL* thread_pool +); + template void MlasDequantizeBlockwise( float* dst, @@ -1727,6 +1933,45 @@ MlasDequantizeBlockwise( MLAS_THREADPOOL* thread_pool ); +template void +MlasDequantizeBlockwise( + MLAS_FP16* dst, + const uint8_t* src, + const MLAS_FP16* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasDequantizeBlockwise( + float* dst, + const uint8_t* src, + const float* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL* thread_pool +); + +template void +MlasDequantizeBlockwise( + MLAS_FP16* dst, + const uint8_t* src, + const MLAS_FP16* scales, + const uint8_t* zero_points, + int block_size, + bool columnwise, + int rows, + int columns, + MLAS_THREADPOOL* thread_pool +); + template bool MlasQDQQuantizeBlockwise( diff --git a/src/lib/q4common.h b/src/lib/q4common.h index 74f8058..5febbd8 100644 --- a/src/lib/q4common.h +++ b/src/lib/q4common.h @@ -30,7 +30,6 @@ Module Name: } while (false) #endif - #include "mlas_q4.h" #include "mlasi.h" diff --git a/src/lib/q4gemm_avx512.cpp b/src/lib/q4gemm_avx512.cpp index fd71777..f7af82e 100644 --- a/src/lib/q4gemm_avx512.cpp +++ b/src/lib/q4gemm_avx512.cpp @@ -21,7 +21,6 @@ Module Name: #include #include -#include struct MLAS_FP_Q4_GEMM_KERNEL_AVX512VNNI { static constexpr size_t StrideM = 256; diff --git a/src/lib/qgemm.cpp b/src/lib/qgemm.cpp index 859fcd0..f5b33d2 100644 --- a/src/lib/qgemm.cpp +++ b/src/lib/qgemm.cpp @@ -144,14 +144,7 @@ 
MlasGemmBatch( const double Complexity = double(M) * double(N) * double(K) * double(BatchN); - ptrdiff_t TargetThreadCount; - - if (Complexity < double(MLAS_QGEMM_THREAD_COMPLEXITY * GetMlasPlatform().MaximumThreadCount)) { - TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1; - } else { - TargetThreadCount = GetMlasPlatform().MaximumThreadCount; - } - + ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_QGEMM_THREAD_COMPLEXITY)) + 1; ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); if (TargetThreadCount >= MaximumThreadCount) { @@ -308,7 +301,7 @@ size_t MLASCALL MlasGemmPackBSize( size_t N, - size_t K, + size_t K, bool AIsSigned, bool BIsSigned ) @@ -479,7 +472,7 @@ size_t MLASCALL MlasSymmQgemmPackBSize( size_t N, - size_t K, + size_t K, bool AIsSigned ) { diff --git a/src/lib/qgemm.h b/src/lib/qgemm.h index bcd878e..596267c 100644 --- a/src/lib/qgemm.h +++ b/src/lib/qgemm.h @@ -886,6 +886,14 @@ MlasGemmQuantGetDispatch( if(BIsSigned || !AIsSigned) { GemmQuantDispatch = &MlasGemmU8X8DispatchNeon; } +#elif defined(MLAS_TARGET_WASM_RELAXED_SIMD) + if (!AIsSigned) { + if (HasUSDot()) { + GemmQuantDispatch = &MlasGemmU8X8DispatchWasmRelaxedSimd; + } else { + GemmQuantDispatch = &MlasGemmU8X8DispatchWasmSimd; + } + } #elif defined(MLAS_TARGET_WASM_SIMD) if (!AIsSigned) { GemmQuantDispatch = &MlasGemmU8X8DispatchWasmSimd; diff --git a/src/lib/qgemm_kernel_wasmrelaxedsimd.cpp b/src/lib/qgemm_kernel_wasmrelaxedsimd.cpp new file mode 100644 index 0000000..a3a0fa7 --- /dev/null +++ b/src/lib/qgemm_kernel_wasmrelaxedsimd.cpp @@ -0,0 +1,563 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + qgemm_kernel_wasmrelaxedsimd.cpp + +Abstract: + + This module implements QGEMM kernel for WebAssembly Relaxed SIMD128. 
+ +--*/ + +#include "mlasi.h" +#include "qgemm.h" + +bool HasUSDot() { +// Check out-of-bounds behavior of Relaxed Integer Dot Product with Accumulation with signed and unsigned input (e.g. vpdpbusd). + const v128_t int8_input = wasm_i8x16_const(0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0); + const volatile v128_t xint8_input = wasm_i8x16_const(0, 0, 0, -128, 0, 0, -128, 0, 0, -128, 0, 0, -128, 0, 0, 0); // volatile to confuse Clang which otherwise ICE's + const v128_t xint8_output = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(int8_input, xint8_input, wasm_i8x16_const_splat(0)); + + const volatile v128_t overflow_input = wasm_i8x16_const(-128, -128, -128, -128, -128, -128, -1, -1, -1, -1, -128, -128, -1, -1, -1, -1); // volatile to confuse Clang which otherwise ICE's + const v128_t overflow_output = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(wasm_i8x16_const_splat(-128), overflow_input, wasm_i8x16_const_splat(0)); + return !wasm_v128_any_true(wasm_v128_or( + wasm_v128_xor(xint8_output, wasm_i32x4_const_splat(128)), + wasm_v128_xor(overflow_output, wasm_i32x4_const(-65536, -98048, -98048, -130560)))); +} + +// wasm implementation of "_mm_unpacklo_epi8" +v128_t __attribute__((__always_inline__, __nodebug__)) wasm_i8x16_unpacklo_relaxed(v128_t a, v128_t b) { + return wasm_i8x16_shuffle(a, b, 0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23); +} + +// wasm implementation of "_mm_unpacklo_epi16" +v128_t __attribute__((__always_inline__, __nodebug__)) wasm_i16x8_unpacklo_relaxed(v128_t a, v128_t b) { + return wasm_i8x16_shuffle(a, b, 0, 1, 16, 17, 2, 3, 18, 19, 4, 5, 20, 21, 6, 7, 22, 23); +} + +// wasm implementation of "_mm_unpackhi_epi16" +v128_t __attribute__((__always_inline__, __nodebug__)) wasm_i16x8_unpackhi_relaxed(v128_t a, v128_t b) { + return wasm_i8x16_shuffle(a, b, 8, 9, 24, 25, 10, 11, 26, 27, 12, 13, 28, 29, 14, 15, 30, 31); +} + +struct MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD +{ + typedef uint8_t PackedAType; + typedef uint8_t PackedBType; + typedef 
uint8_t OffsetAType; + typedef int8_t OffsetBType; + + static constexpr size_t PackedK = 4; + static constexpr MLAS_GEMM_QUANT_STRIDES Strides{ 12, 128, 128 }; + static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{0, 0, 0}; +}; + +constexpr size_t MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedK; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::Strides; + +template<> +MLAS_FORCEINLINE +int32_t +MlasGemmQuantFixupZeroPointB( + int32_t ZeroPointB, + bool BIsSigned + ) +{ + if (!BIsSigned) { + ZeroPointB = MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::OffsetBType(ZeroPointB ^ 0x80); + } + + return ZeroPointB; +} + +template<> +void +MlasGemmQuantCopyPackA( + MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer, + bool AIsSigned + ) +{ + MLAS_UNREFERENCED_PARAMETER(AIsSigned); + const v128_t ZeroVector = wasm_i64x2_const(0, 0); + const v128_t OnesWordBroadcast = wasm_i16x8_splat(1); + uint8_t PaddedMatrixAData[8] = { 0 }; + + // + // Process a single row of matrix A in a loop. + // + + while (CountM > 0) { + + const uint8_t* a = A; + size_t k = CountK; + v128_t ReductionVector = ZeroVector; + + // + // Copy the source bytes to the packed buffer. + // + // The packed buffer has the same data ordering as the source bytes, + // but CountK is aligned up to a multiple of 4 to maintain 32-bit + // alignment. All extra bytes are zero-padded. + // + // Zero extend the source bytes to 16-bits and accumulate + // into an intermediate per-row + // accumulator. CountK cannot be greater than 128 to avoid overflowing + // these signed 16-bit accumulators. 
+ // + + while (k >= 8) { + + v128_t Bytes = wasm_v128_load64_zero(&a[0]); + v128_t Words = wasm_i8x16_unpacklo_relaxed(Bytes, ZeroVector); + + ReductionVector = wasm_i16x8_add(ReductionVector, Words); + + wasm_v128_store64_lane(&D[0], Bytes, 0); + + a += 8; + D += 8; + k -= 8; + } + + if (k > 0) { + + // + // Copy the remaining bytes to the zero padded stack buffer. + // + + uint8_t* padded = PaddedMatrixAData; + uint8_t* padded_end = padded + k; + + do { + padded[0] = a[0]; + padded++; + a++; + } while (padded < padded_end); + + v128_t Bytes = wasm_v128_load64_zero(PaddedMatrixAData); + v128_t Words = wasm_i8x16_unpacklo_relaxed(Bytes, ZeroVector); + + ReductionVector = wasm_i16x8_add(ReductionVector, Words); + + // + // Copy quads of 8-bit values from the vector to the packed + // buffer and rotate the vector for the next iteration. + // + + for (size_t quads = (k + 3) / 4; quads > 0; quads--) { + *((int32_t*)D) = wasm_i32x4_extract_lane(Bytes, 0); + D += 4; + Bytes = wasm_i32x4_shuffle(Bytes, wasm_i32x4_splat(0), 1, 2, 3, 0); + } + } + + // + // Reduce the partial accumulators. 
+ // + + ReductionVector = wasm_i32x4_dot_i16x8(ReductionVector, OnesWordBroadcast); + ReductionVector = wasm_i32x4_add(ReductionVector, + wasm_i32x4_shuffle(ReductionVector, wasm_i32x4_splat(0), 2, 3, 2, 3)); + ReductionVector = wasm_i32x4_add(ReductionVector, + wasm_i32x4_shuffle(ReductionVector, wasm_i32x4_splat(0), 1, 0, 1, 0)); + + *RowSumBuffer++ = wasm_i32x4_extract_lane(ReductionVector, 0); + + A += lda; + CountM -= 1; + } +} + + +MLAS_FORCEINLINE +void +MlasGemmU8X8CopyPackBProcessWasmRelaxedSimd( + MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedBType* D, + v128_t BytesRow0, + v128_t BytesRow1, + v128_t BytesRow2, + v128_t BytesRow3, + v128_t BitFlipVector, + v128_t OnesByteBroadcast, + v128_t ColumnSums[2] +) +{ + v128_t PairsInterleaved0 = wasm_i8x16_unpacklo_relaxed(BytesRow0, BytesRow1); + v128_t PairsInterleaved1 = wasm_i8x16_unpacklo_relaxed(BytesRow2, BytesRow3); + + PairsInterleaved0 = wasm_v128_xor(PairsInterleaved0, BitFlipVector); + PairsInterleaved1 = wasm_v128_xor(PairsInterleaved1, BitFlipVector); + + v128_t QuadsInterleaved0 = wasm_i16x8_unpacklo_relaxed(PairsInterleaved0, PairsInterleaved1); + v128_t QuadsInterleaved1 = wasm_i16x8_unpackhi_relaxed(PairsInterleaved0, PairsInterleaved1); + + ColumnSums[0] = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(QuadsInterleaved0, OnesByteBroadcast, ColumnSums[0]); + ColumnSums[1] = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(QuadsInterleaved1, OnesByteBroadcast, ColumnSums[1]); + + wasm_v128_store(&D[0], QuadsInterleaved0); + wasm_v128_store(&D[16], QuadsInterleaved1); +} + +template<> +void +MlasGemmQuantCopyPackB( + MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ) +{ + const v128_t OnesByteBroadcast = wasm_i8x16_splat(1); + const v128_t BitFlipVector = wasm_i32x4_splat(BIsSigned ? 0 : 0x80808080); + + // + // Process 8 columns of matrix B in a loop. 
+ // + + while (CountN >= 8) { + + const uint8_t* b = B; + size_t k = CountK; + v128_t ColumnSums[2]; + + ColumnSums[0] = wasm_i64x2_const(0, 0); + ColumnSums[1] = wasm_i64x2_const(0, 0); + + // + // Interleave rows of matrix B and write to the packed buffer. + // + + while (k >= MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedK) { + + v128_t BytesRow0 = wasm_v128_load64_zero(&b[0]); + v128_t BytesRow1 = wasm_v128_load64_zero(&b[ldb]); + v128_t BytesRow2 = wasm_v128_load64_zero(&b[ldb * 2]); + v128_t BytesRow3 = wasm_v128_load64_zero(&b[ldb * 3]); + + MlasGemmU8X8CopyPackBProcessWasmRelaxedSimd(D, BytesRow0, BytesRow1, BytesRow2, BytesRow3, BitFlipVector, OnesByteBroadcast, ColumnSums); + + b += ldb * 4; + D += 32; + k -= 4; + } + + if (k > 0) { + + v128_t BytesRow0 = wasm_v128_load64_zero(&b[0]); + v128_t BytesRow1 = BitFlipVector; + v128_t BytesRow2 = BitFlipVector; + v128_t BytesRow3 = BitFlipVector; + + if (k >= 2) { + BytesRow1 = wasm_v128_load64_zero(&b[ldb]); + } + + if (k >= 3) { + BytesRow2 = wasm_v128_load64_zero(&b[ldb * 2]); + } + + MlasGemmU8X8CopyPackBProcessWasmRelaxedSimd(D, BytesRow0, BytesRow1, BytesRow2, BytesRow3, BitFlipVector, OnesByteBroadcast, ColumnSums); + + D += 32; + } + + wasm_v128_store(&ColumnSumBuffer[0], ColumnSums[0]); + wasm_v128_store(&ColumnSumBuffer[4], ColumnSums[1]); + ColumnSumBuffer += 8; + + B += 8; + CountN -= 8; + } + + // + // Process the remaining columns of matrix B. + // + + if (CountN > 0) { + + const uint8_t* b = B; + size_t k = CountK; + v128_t ColumnSums[2]; + uint8_t PaddedMatrixBData[32]; + + wasm_v128_store(&PaddedMatrixBData[0], BitFlipVector); + wasm_v128_store(&PaddedMatrixBData[16], BitFlipVector); + + ColumnSums[0] = wasm_i64x2_const(0, 0); + ColumnSums[1] = wasm_i64x2_const(0, 0); + + // + // Interleave rows of matrix B using an intermediate zero padded stack + // buffer and write to the packed buffer. 
+ // + + while (k >= MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedK) { + + const uint8_t* bcopy = b; + uint8_t* padded = PaddedMatrixBData; + uint8_t* padded_end = padded + CountN; + + do { + padded[0] = bcopy[0]; + padded[8] = bcopy[ldb]; + padded[16] = bcopy[ldb * 2]; + padded[24] = bcopy[ldb * 3]; + padded++; + bcopy++; + } while (padded < padded_end); + + v128_t BytesRow0 = wasm_v128_load64_zero(&PaddedMatrixBData[0]); + v128_t BytesRow1 = wasm_v128_load64_zero(&PaddedMatrixBData[8]); + v128_t BytesRow2 = wasm_v128_load64_zero(&PaddedMatrixBData[16]); + v128_t BytesRow3 = wasm_v128_load64_zero(&PaddedMatrixBData[24]); + + MlasGemmU8X8CopyPackBProcessWasmRelaxedSimd(D, BytesRow0, BytesRow1, BytesRow2, BytesRow3, BitFlipVector, OnesByteBroadcast, ColumnSums); + + b += ldb * 4; + D += 32; + k -= 4; + } + + if (k > 0) { + + const uint8_t* bcopy = b; + uint8_t* padded = PaddedMatrixBData; + uint8_t* padded_end = padded + CountN; + + wasm_v128_store(&PaddedMatrixBData[0], BitFlipVector); + wasm_v128_store(&PaddedMatrixBData[16], BitFlipVector); + + if (k == 3) { + do { + padded[0] = bcopy[0]; + padded[8] = bcopy[ldb]; + padded[16] = bcopy[ldb * 2]; + padded++; + bcopy++; + } while (padded < padded_end); + } else if (k == 2) { + do { + padded[0] = bcopy[0]; + padded[8] = bcopy[ldb]; + padded++; + bcopy++; + } while (padded < padded_end); + } else { + do { + padded[0] = bcopy[0]; + padded++; + bcopy++; + } while (padded < padded_end); + } + + v128_t BytesRow0 = wasm_v128_load64_zero(&PaddedMatrixBData[0]); + v128_t BytesRow1 = wasm_v128_load64_zero(&PaddedMatrixBData[8]); + v128_t BytesRow2 = wasm_v128_load64_zero(&PaddedMatrixBData[16]); + v128_t BytesRow3 = wasm_v128_load64_zero(&PaddedMatrixBData[24]); + + MlasGemmU8X8CopyPackBProcessWasmRelaxedSimd(D, BytesRow0, BytesRow1, BytesRow2, BytesRow3, BitFlipVector, OnesByteBroadcast, ColumnSums); + } + + wasm_v128_store(&ColumnSumBuffer[0], ColumnSums[0]); + wasm_v128_store(&ColumnSumBuffer[4], ColumnSums[1]); + } +} + 
+MLAS_FORCEINLINE +void +MlasGemmU8X8MultiplyAccumulateRowWasmRelaxedSimd( + v128_t ABroadcast, + const uint8_t* B, + v128_t Accumulators[2] +) +{ + v128_t BElements0 = wasm_v128_load(&B[0]); + v128_t BElements1 = wasm_v128_load(&B[16]); + + Accumulators[0] = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(BElements0, ABroadcast, Accumulators[0]); + Accumulators[1] = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(BElements1, ABroadcast, Accumulators[1]); +} + + +template<> +size_t +MlasGemmQuantKernel( + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode + ) +{ + MLAS_UNREFERENCED_PARAMETER(CountM); + MLAS_UNREFERENCED_PARAMETER(ldc); + + while (CountN > 0) { + + v128_t Accumulators[2]; + + // + // Initialize the accumulators with the row and column sums. + // + + int32_t RowSumValue = RowSumBuffer[0]; + + if (ZeroPointB != nullptr) { + + int32_t ScaledRowSumBuffer[8]; + + for (size_t i = 0; i < 8; i++) { + ScaledRowSumBuffer[i] = RowSumValue * ZeroPointB[i]; + } + + ZeroPointB += 8; + + Accumulators[0] = wasm_v128_load(&ScaledRowSumBuffer[0]); + Accumulators[1] = wasm_v128_load(&ScaledRowSumBuffer[4]); + + } + else { + + Accumulators[0] = wasm_i32x4_splat(RowSumValue); + Accumulators[1] = Accumulators[0]; + } + + Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load(&ColumnSumBuffer[0])); + Accumulators[1] = wasm_i32x4_add(Accumulators[1], wasm_v128_load(&ColumnSumBuffer[4])); + ColumnSumBuffer += 8; + + // + // Broadcast each pair of 16-bit values from the matrix A and multiply + // with the pair of 16-bit values from matrix B, and add the 32-bit + // intermediate into the accumulator registers. 
+ // + + const uint8_t* a = A; + size_t k = PackedCountK; + + while (k >= 4) { + + v128_t AElements = wasm_v128_load((v128_t*)a); + v128_t ABroadcast; + + ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 0, 0, 0, 0); + MlasGemmU8X8MultiplyAccumulateRowWasmRelaxedSimd(ABroadcast, &B[0], Accumulators); + + ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 1, 1, 1, 1); + MlasGemmU8X8MultiplyAccumulateRowWasmRelaxedSimd(ABroadcast, &B[32], Accumulators); + + ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 2, 2, 2, 2); + MlasGemmU8X8MultiplyAccumulateRowWasmRelaxedSimd(ABroadcast, &B[64], Accumulators); + + ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 3, 3, 3, 3); + MlasGemmU8X8MultiplyAccumulateRowWasmRelaxedSimd(ABroadcast, &B[96], Accumulators); + + a += 4 * 4; + B += 4 * 32; + k -= 4; + } + + while (k > 0) { + + v128_t ABroadcast = wasm_i32x4_splat(*((int32_t*)a)); + MlasGemmU8X8MultiplyAccumulateRowWasmRelaxedSimd(ABroadcast, &B[0], Accumulators); + + a += 4; + B += 32; + k -= 1; + } + + // + // Output the accumulator block after optionally accumulating the values + // from matrix C. + // + + if (CountN >= 8) { + + if (!ZeroMode) { + Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load(&C[0])); + Accumulators[1] = wasm_i32x4_add(Accumulators[1], wasm_v128_load(&C[4])); + } + + wasm_v128_store(&C[0], Accumulators[0]); + wasm_v128_store(&C[4], Accumulators[1]); + + C += 8; + CountN -= 8; + + } + else { + + // + // Output the remaining partial output block. 
+ // + + if ((CountN & 4) != 0) { + + if (!ZeroMode) { + Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load(&C[0])); + } + + wasm_v128_store(&C[0], Accumulators[0]); + C += 4; + + Accumulators[0] = Accumulators[1]; + } + + if ((CountN & 2) != 0) { + + if (!ZeroMode) { + Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load64_zero(&C[0])); + } + + wasm_v128_store64_lane(&C[0], Accumulators[0], 0); + C += 2; + + Accumulators[0] = wasm_i32x4_shuffle(Accumulators[0], wasm_i32x4_splat(0), 2, 3, 2, 3); + } + + if ((CountN & 1) != 0) { + + int32_t AccumulatorValue = wasm_i32x4_extract_lane(Accumulators[0], 0); + + if (!ZeroMode) { + AccumulatorValue += C[0]; + } + + C[0] = AccumulatorValue; + } + + CountN = 0; + } + } + + return 1; +} + +const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmRelaxedSimd = { + MlasGemmQuantOperation, + nullptr, + nullptr, + MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedK, + 0, + 4 // multiple of kernel stride M +}; diff --git a/src/lib/qnbitgemm.cpp b/src/lib/qnbitgemm.cpp index f064a8e..19d11a6 100644 --- a/src/lib/qnbitgemm.cpp +++ b/src/lib/qnbitgemm.cpp @@ -28,10 +28,11 @@ enum QNBitGemmVariant { // Valid variants - SQNBitGemmVariant_BitWidth4_CompFp32 = 0, - SQNBitGemmVariant_BitWidth4_CompInt8, - HQNBitGemmVariant_BitWidth4_CompFp16, - HQNBitGemmVariant_BitWidth4_CompInt8, + SQ4BitGemmVariant_CompFp32 = 0, + SQ4BitGemmVariant_CompInt8, + HQ4BitGemmVariant_CompFp16, + HQ4BitGemmVariant_CompInt8, + SQ8BitGemmVariant_CompInt8, // End of valid variants @@ -47,16 +48,21 @@ GetQNBitGemmVariant( MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - if (BlkBitWidth == 4 && - (BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { - if (ComputeType == SQNBIT_CompFp32) { - return SQNBitGemmVariant_BitWidth4_CompFp32; - } else if (ComputeType == HQNBIT_CompFp16) { - return HQNBitGemmVariant_BitWidth4_CompFp16; - } else if (ComputeType == SQNBIT_CompInt8) { - return 
SQNBitGemmVariant_BitWidth4_CompInt8; - } else if (ComputeType == HQNBIT_CompInt8) { - return HQNBitGemmVariant_BitWidth4_CompInt8; + if ((BlkLen == 16 || BlkLen == 32 || BlkLen == 64 || BlkLen == 128 || BlkLen == 256)) { + if (BlkBitWidth == 4) { + if (ComputeType == SQNBIT_CompFp32) { + return SQ4BitGemmVariant_CompFp32; + } else if (ComputeType == HQNBIT_CompFp16) { + return HQ4BitGemmVariant_CompFp16; + } else if (ComputeType == SQNBIT_CompInt8) { + return SQ4BitGemmVariant_CompInt8; + } else if (ComputeType == HQNBIT_CompInt8) { + return HQ4BitGemmVariant_CompInt8; + } + } else if (BlkBitWidth == 8) { + if (ComputeType == SQNBIT_CompInt8) { + return SQ8BitGemmVariant_CompInt8; + } } } @@ -80,20 +86,26 @@ MlasIsQNBitGemmAvailable( const auto Variant = GetQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType); switch (Variant) { - case SQNBitGemmVariant_BitWidth4_CompFp32: { + case SQ4BitGemmVariant_CompFp32: { return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr && Dispatch->SQ4BitBlkDequantBForSgemm_CompFp32 != nullptr; } - case HQNBitGemmVariant_BitWidth4_CompFp16: { + case HQ4BitGemmVariant_CompFp16: { return Dispatch->HQ4BitGemmPackQuantBData != nullptr && Dispatch->HQ4BitGemmKernel_CompFp16 != nullptr && Dispatch->HQ4BitBlkDequantBForHgemm_CompFp16 != nullptr; } - case SQNBitGemmVariant_BitWidth4_CompInt8: { // SQ4BitGemmKernel_BlkSum_CompInt8 + case SQ4BitGemmVariant_CompInt8: { // SQ4BitGemmKernel_BlkSum_CompInt8 return + (Dispatch->SQ4BitGemmKernel_Packed_CompInt8 != nullptr && Dispatch->QuantizeA_Packed_CompInt8 != nullptr) || (Dispatch->SQ4BitGemmKernel_CompInt8 != nullptr && Dispatch->QuantizeARow_CompInt8 != nullptr) || (Dispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr && Dispatch->QuantizeARowComputeBlkSum_CompInt8 != nullptr); } + case SQ8BitGemmVariant_CompInt8: { + return Dispatch->SQ8BitGemmPackQuantBDataAndBlkSum != nullptr && + Dispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr && + Dispatch->QuantizeARowComputeBlkSum_CompInt8 != 
nullptr; + } default: { return false; } @@ -110,16 +122,17 @@ QNBitGemmPerGemmWorkspaceSize( size_t K, size_t BlkBitWidth, size_t BlkLen, + bool HasZeroPoint, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; - if (Dispatch == nullptr) { + if (Dispatch == nullptr || Dispatch->QNBitGemmPerGemmWorkspaceSize == nullptr) { return 0; } - if (BlkBitWidth == 4 && Dispatch->Q4BitGemmPerGemmWorkspaceSize != nullptr) { - return Dispatch->Q4BitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, ComputeType); + if (BlkBitWidth == 4 || BlkBitWidth == 8) { + return Dispatch->QNBitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, HasZeroPoint, ComputeType); } return 0; @@ -133,12 +146,12 @@ QNBitGemmPerGemmWorkspaceAlignment( ) { const auto* Dispatch = GetMlasPlatform().QNBitGemmDispatch; - if (Dispatch == nullptr) { + if (Dispatch == nullptr || Dispatch->QNBitGemmPerGemmWorkspaceAlignment == nullptr) { return 1; } - if (BlkBitWidth == 4 && Dispatch->Q4BitGemmPerGemmWorkspaceAlignment != nullptr) { - return Dispatch->Q4BitGemmPerGemmWorkspaceAlignment(BlkLen, ComputeType); + if (BlkBitWidth == 4 || BlkBitWidth == 8) { + return Dispatch->QNBitGemmPerGemmWorkspaceAlignment(BlkLen, ComputeType); } return 1; @@ -151,10 +164,11 @@ QNBitGemmPerGemmWorkspaceStride( size_t K, size_t BlkBitWidth, size_t BlkLen, + bool HasZeroPoint, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const auto Size = QNBitGemmPerGemmWorkspaceSize(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const auto Size = QNBitGemmPerGemmWorkspaceSize(M, N, K, BlkBitWidth, BlkLen, HasZeroPoint, ComputeType); const auto Alignment = QNBitGemmPerGemmWorkspaceAlignment(BlkBitWidth, BlkLen, ComputeType); return MlasDivRoundup(Size, Alignment) * Alignment; } @@ -169,10 +183,12 @@ MlasQNBitGemmBatchWorkspaceSize( size_t BatchN, size_t BlkBitWidth, size_t BlkLen, + bool HasZeroPoint, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - const size_t PerGemmWorkspaceStride = 
QNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const size_t PerGemmWorkspaceStride = + QNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, HasZeroPoint, ComputeType); if (PerGemmWorkspaceStride == 0) { return 0; } @@ -190,6 +206,7 @@ MlasQNBitGemmPackQuantBDataSize( size_t K, size_t BlkBitWidth, size_t BlkLen, + bool HasZeroPoint, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { @@ -200,7 +217,11 @@ MlasQNBitGemmPackQuantBDataSize( if (BlkBitWidth == 4 && Dispatch->Q4BitGemmPackQuantBDataSize != nullptr) { return Dispatch->Q4BitGemmPackQuantBDataSize( - N, K, BlkLen, ComputeType + N, K, BlkLen, HasZeroPoint, ComputeType + ); + } else if (BlkBitWidth == 8 && Dispatch->Q8BitGemmPackQuantBDataSize != nullptr) { + return Dispatch->Q8BitGemmPackQuantBDataSize( + N, K, BlkLen, HasZeroPoint, ComputeType ); } @@ -232,7 +253,7 @@ MlasQNBitGemmPackQuantBData( const void* QuantBData, void* PackedQuantBDataAndOrBlkSumWorkspace, const void* QuantBScale, - bool has_zp_input, + bool HasZeroPoint, const void* QuantBZeroPoint, MLAS_THREADPOOL* ThreadPool ) @@ -245,7 +266,7 @@ MlasQNBitGemmPackQuantBData( if (BlkBitWidth == 4) { if (ComputeType == SQNBIT_CompInt8 && Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum( N, K, @@ -253,7 +274,7 @@ MlasQNBitGemmPackQuantBData( ComputeType, static_cast(QuantBData), static_cast(QuantBScale), - has_zp_input, + HasZeroPoint, static_cast(QuantBZeroPoint), packed_quant_b, ThreadPool @@ -283,7 +304,47 @@ MlasQNBitGemmPackQuantBData( ); return; } + } else if (BlkBitWidth == 8) { + if (ComputeType == SQNBIT_CompInt8 && Dispatch->SQ8BitGemmPackQuantBDataAndBlkSum != nullptr) { + const size_t BlockCountK = 
MlasDivRoundup(K, BlkLen); + PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); + Dispatch->SQ8BitGemmPackQuantBDataAndBlkSum( + N, + K, + BlkLen, + ComputeType, + static_cast(QuantBData), + static_cast(QuantBScale), + HasZeroPoint, + static_cast(QuantBZeroPoint), + packed_quant_b, + ThreadPool + ); + } + } +} + +bool MLASCALL +MlasQNBitGemmScalesPacked( + size_t K, + size_t BlkBitWidth, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + bool HasZeroPoint +) { +#ifdef MLAS_TARGET_ARM64 + if (BlkBitWidth == 4 && ComputeType == SQNBIT_CompInt8) { + const auto UsePacked = GetMlasPlatform().QNBitGemmDispatch->UsePacked_CompInt8; + return UsePacked && UsePacked(K, BlkLen, HasZeroPoint); } +#else + MLAS_UNREFERENCED_PARAMETER(K); + MLAS_UNREFERENCED_PARAMETER(BlkBitWidth); + MLAS_UNREFERENCED_PARAMETER(BlkLen); + MLAS_UNREFERENCED_PARAMETER(ComputeType); + MLAS_UNREFERENCED_PARAMETER(HasZeroPoint); +#endif // MLAS_TARGET_ARM64 + return false; } namespace @@ -519,6 +580,16 @@ SQ4BitGemm_CompInt8( const size_t RangeCountN ) { + const auto UsePacked = GetMlasPlatform().QNBitGemmDispatch->UsePacked_CompInt8; + const auto SQ4BitGemm = GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_Packed_CompInt8; + if (UsePacked && SQ4BitGemm && UsePacked(K, BlkLen, DataParams->QuantBZeroPoint)) { + const std::byte* QuantA = static_cast(PerGemmWorkspace); + SQ4BitGemm(BlkLen, QuantA, DataParams->PackedQuantBData, + DataParams->C, RangeStartM, RangeCountM, RangeStartN, RangeCountN, K, + DataParams->ldc, DataParams->Bias); + return; + } + #ifdef MLAS_TARGET_AMD64_IX86 PerGemmQuantAWorkspace* const per_gemm_quant_a_workspace = static_cast(PerGemmWorkspace); constexpr size_t BlkBitWidth = 4; @@ -636,6 +707,86 @@ SQ4BitGemm_CompInt8( } } +void +SQ8BitGemm_CompInt8( + const size_t BlkLen, + const size_t K, + const MLAS_QNBIT_GEMM_DATA_PARAMS* const DataParams, + void* const PerGemmWorkspace, + const size_t RangeStartM, + const 
size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN +) +{ + PerGemmQuantAWorkspace* const per_gemm_quant_a_workspace = static_cast(PerGemmWorkspace); + constexpr size_t BlkBitWidth = 8; + + const size_t k_blks = MlasDivRoundup(K, BlkLen); + + // quant A scale is embedded in QuantData if QuantScale is nullptr. + const size_t lda = k_blks * (per_gemm_quant_a_workspace->QuantScale ? BlkLen : Q8BlkSize(BlkLen)); + const size_t ldc = DataParams->ldc; + const size_t ldb = k_blks * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t k_blks_zp_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(k_blks); + + const std::byte* QuantA = per_gemm_quant_a_workspace->QuantData + RangeStartM * lda; + const float* QuantAScale = per_gemm_quant_a_workspace->QuantScale + RangeStartM * k_blks; + + assert(RangeStartN % 16 == 0); + const std::byte* QuantBData = static_cast(DataParams->PackedQuantBData) + RangeStartN * ldb; + const float* QuantBScale = DataParams->QuantBScale + RangeStartN * k_blks; + const std::byte* QuantBZeroPoint = + (DataParams->QuantBZeroPoint == nullptr) + ? nullptr + : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; + const float* ABlockSum = per_gemm_quant_a_workspace->BlockSum + RangeStartM * k_blks; + const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN * k_blks; + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + + const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; + + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + CountN = std::min(RangeCountN - n, size_t{128}); + + const std::byte* b_col = QuantBData + n * ldb; + const float* b_col_scale = QuantBScale + n * k_blks; + const std::byte* b_col_zp = + (QuantBZeroPoint == nullptr) ? nullptr : QuantBZeroPoint + n * k_blks_zp_bytes; + float* c_blk = C + n; + const float* bias = (Bias == nullptr) ? 
nullptr : Bias + n; + + if (GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr) { + const float* b_blk_sum = QuantBBlkSum + n * k_blks; + GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8( + BlkLen, + QuantA, + QuantAScale, + b_col, + b_col_scale, + b_col_zp, + c_blk, + RangeCountM, + CountN, + K, + k_blks, + bias, + ldc, + ABlockSum, + b_blk_sum + ); + + if (DataParams->PostProcessor != nullptr) { + DataParams->PostProcessor->Process( + DataParams->C, RangeStartM, RangeStartN + n, + RangeCountM, CountN, ldc + ); + } + } + } +} + template void InitializeWorkspace_CompInt8( @@ -666,6 +817,8 @@ InitializeWorkspace_CompInt8( { MLAS_UNREFERENCED_PARAMETER(N); + const auto UsePacked = GetMlasPlatform().QNBitGemmDispatch->UsePacked_CompInt8; + const auto QuantizeA_Packed = GetMlasPlatform().QNBitGemmDispatch->QuantizeA_Packed_CompInt8; const auto QuantizeARow = GetMlasPlatform().QNBitGemmDispatch->QuantizeARow_CompInt8; const auto QuantizeARow2 = GetMlasPlatform().QNBitGemmDispatch->QuantizeARowComputeBlkSum_CompInt8; @@ -673,7 +826,15 @@ InitializeWorkspace_CompInt8( const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); // TODO: try parallel on BatchN * M threads because BatchN is usually 1. 
- if (QuantizeARow) { + if (UsePacked && QuantizeA_Packed && UsePacked(K, BlkLen, DataParams->QuantBZeroPoint)) { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; + + const float* ARowPtr = data.A; + std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + QuantizeA_Packed(BlkLen, ARowPtr, M, K, QuantARowPtr); + }); + } else if (QuantizeARow) { MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { const auto& data = DataParams[gemm_idx]; @@ -753,7 +914,8 @@ InitializeWorkspaceFn GetInitializeWorkspace(QNBitGemmVariant variant) { switch (variant) { - case SQNBitGemmVariant_BitWidth4_CompInt8: + case SQ4BitGemmVariant_CompInt8: + case SQ8BitGemmVariant_CompInt8: return InitializeWorkspace_CompInt8; default: return nullptr; @@ -765,7 +927,7 @@ InitializeWorkspaceFn GetInitializeWorkspace(QNBitGemmVariant variant) { switch (variant) { - case HQNBitGemmVariant_BitWidth4_CompInt8: + case HQ4BitGemmVariant_CompInt8: return InitializeWorkspace_CompInt8; default: return nullptr; @@ -793,10 +955,12 @@ QNBitGemmFn GetQNBitGemm(QNBitGemmVariant variant) { switch (variant) { - case SQNBitGemmVariant_BitWidth4_CompFp32: + case SQ4BitGemmVariant_CompFp32: return SQ4BitGemm_CompFp32; - case SQNBitGemmVariant_BitWidth4_CompInt8: + case SQ4BitGemmVariant_CompInt8: return SQ4BitGemm_CompInt8; + case SQ8BitGemmVariant_CompInt8: + return SQ8BitGemm_CompInt8; default: return nullptr; } @@ -807,7 +971,7 @@ QNBitGemmFn GetQNBitGemm(QNBitGemmVariant variant) { switch (variant) { - case HQNBitGemmVariant_BitWidth4_CompFp16: + case HQ4BitGemmVariant_CompFp16: return HQ4BitGemm_CompFp16; default: return nullptr; @@ -844,7 +1008,9 @@ MlasQNBitGemmBatch( ); } - const size_t PerGemmWorkspaceStride = QNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, ComputeType); + const bool has_zp_input = DataParams->QuantBZeroPoint; + const size_t PerGemmWorkspaceStride = + 
QNBitGemmPerGemmWorkspaceStride(M, N, K, BlkBitWidth, BlkLen, has_zp_input, ComputeType); if (const auto InitializeWorkspaceOperation = GetInitializeWorkspace(Variant); InitializeWorkspaceOperation != nullptr) { @@ -862,8 +1028,15 @@ MlasQNBitGemmBatch( const auto* Data = &DataParams[gemm_i]; void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; - if (ComputeType == SQNBIT_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + if (Variant == SQ4BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) { + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; + const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); + ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); + } else if (Variant == SQ8BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr) { + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; @@ -933,8 +1106,16 @@ MlasQNBitGemmBatch( void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; - if (ComputeType == SQNBIT_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { - PackedQuantBDataStruct 
packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + if (Variant == SQ4BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) { + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; + const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; + const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + + PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); + ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); + } else if (Variant == SQ8BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr) { + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; diff --git a/src/lib/qnbitgemm_kernel_neon.cpp b/src/lib/qnbitgemm_kernel_neon.cpp index d05de64..0d06eb0 100644 --- a/src/lib/qnbitgemm_kernel_neon.cpp +++ b/src/lib/qnbitgemm_kernel_neon.cpp @@ -15,14 +15,23 @@ Module Name: --*/ +#include "qnbitgemm_kernel_neon.h" + #include #include +#include #include "qnbitgemm.h" -#include "qnbitgemm_kernel_neon.h" #include "sqnbitgemm_q8_block.h" +#ifdef USE_KLEIDIAI +#include "kai/kai_common.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0.h" +#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" +#include "kai_ukernel_interface.h" +#endif + namespace sqnbitgemm_neon { @@ -38,16 +47,31 @@ Q4BitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkLen, + bool HasZeroPoint, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { 
+#ifndef USE_KLEIDIAI + MLAS_UNREFERENCED_PARAMETER(HasZeroPoint); MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType - - constexpr size_t BlkBitWidth = 4; - - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - return PackedQuantBDataSize; +#endif + +#ifdef USE_KLEIDIAI + if (ComputeType == SQNBIT_CompInt8 && UseKleidiAI(K, BlkLen, HasZeroPoint)) { + const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel = GetKleidiAIGemmUKernel(); + const size_t nr = ukernel.get_nr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, BlkLen, kai_dt_bf16); + } else +#endif + { + constexpr size_t BlkBitWidth = 4; + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize; + } } void @@ -121,27 +145,98 @@ SQ4BitGemmPackQuantBData( ); } +void +SQ4BitGemmPackQuantBDataAndBlkSum( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte*, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool +) +{ +#ifndef USE_KLEIDIAI + MLAS_UNREFERENCED_PARAMETER(QuantBScaleBegin); + MLAS_UNREFERENCED_PARAMETER(HasZeroPoint); +#endif + assert(BlkLen >= 16 && BlkLen % 16 == 0); + +#ifdef USE_KLEIDIAI + if (UseKleidiAI(K, BlkLen, HasZeroPoint)) { + const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel = GetKleidiAIGemmUKernel(); + std::byte* PackedQuantBDataBegin = PackedQuantB.PackedQuantBData; + + const size_t nr = ukernel.get_nr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + + kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0_params 
params; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + params.scale_dt = kai_dt_bf16; + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t scales_len = N * BlockCountK; + std::vector scales(scales_len); + for (size_t i = 0; i < scales_len; i++) { + const uint32_t* i32 = reinterpret_cast(&QuantBScaleBegin[i]); + scales[i] = *i32 >> 16; + } + + kai_run_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(1, N, K, nr, kr, sr, BlkLen, + reinterpret_cast(QuantBDataBegin), BlockCountK * BlkLen / 2, + nullptr, scales.data(), BlockCountK * sizeof(uint16_t), + PackedQuantBDataBegin, 0, ¶ms); + } else +#endif + { + std::byte* PackedQuantBDataBegin = reinterpret_cast(PackedQuantB.QuantBWorkspace_); + SQ4BitGemmPackQuantBData(N, K, BlkLen, ComputeType, QuantBDataBegin, PackedQuantBDataBegin, ThreadPool); + } +} + // // Workspace size calculation function implementation. // size_t -Q4BitGemmPerGemmWorkspaceSize( +QNBitGemmPerGemmWorkspaceSize( size_t M, size_t N, size_t K, size_t BlkLen, + bool HasZeroPoint, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { MLAS_UNREFERENCED_PARAMETER(N); +#ifndef USE_KLEIDIAI + MLAS_UNREFERENCED_PARAMETER(HasZeroPoint); +#endif switch (ComputeType) { case SQNBIT_CompInt8: { // workspace buffer is used for block quantization of A to int8 - const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); - return PerGemmWorkspaceSize; +#ifdef USE_KLEIDIAI + if (UseKleidiAI(K, BlkLen, HasZeroPoint)) { + const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel = + M == 1? 
GetKleidiAIGemvUKernel() : GetKleidiAIGemmUKernel(); + + const size_t mr = ukernel.get_mr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + return kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(M, K, mr, kr, sr); + } else +#endif + { + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); + return PerGemmWorkspaceSize; + } } default: { return 0; @@ -150,7 +245,7 @@ Q4BitGemmPerGemmWorkspaceSize( } size_t -Q4BitGemmPerGemmWorkspaceAlignment( +QNBitGemmPerGemmWorkspaceAlignment( size_t BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) @@ -169,33 +264,66 @@ Q4BitGemmPerGemmWorkspaceAlignment( } // namespace +bool +UseKleidiAI(size_t K, size_t BlkLen, bool HasZp) +{ +#ifdef USE_KLEIDIAI + bool has_dotprod = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot(); + return (BlkLen % 32) == 0 && (K % BlkLen) == 0 && !HasZp && has_dotprod; +#else + MLAS_UNREFERENCED_PARAMETER(K); + MLAS_UNREFERENCED_PARAMETER(BlkLen); + MLAS_UNREFERENCED_PARAMETER(HasZp); + return false; +#endif +} + } // namespace sqnbitgemm_neon // -// Kernel dispatch structure definition. +// Kernel dispatch structure accessor. // -const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() { - MLAS_QNBIT_GEMM_DISPATCH d; +const MLAS_QNBIT_GEMM_DISPATCH& +GetMlasQNBitGemmDispatchNeon( + bool InitializeWithDotSupport +) +{ + // Note: The InitializeWithX parameters are only used in the invocation of this method that initializes the static + // MLAS_QNBIT_GEMM_DISPATCH instance. 
- d.Q4BitGemmPackQuantBDataSize = sqnbitgemm_neon::Q4BitGemmPackQuantBDataSize; - d.SQ4BitGemmPackQuantBData = sqnbitgemm_neon::SQ4BitGemmPackQuantBData; + static const MLAS_QNBIT_GEMM_DISPATCH MlasQNBitGemmDispatchNeon = [&]() { + MLAS_QNBIT_GEMM_DISPATCH d; - d.Q4BitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::Q4BitGemmPerGemmWorkspaceSize; - d.Q4BitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::Q4BitGemmPerGemmWorkspaceAlignment; + d.Q4BitGemmPackQuantBDataSize = sqnbitgemm_neon::Q4BitGemmPackQuantBDataSize; + d.SQ4BitGemmPackQuantBData = sqnbitgemm_neon::SQ4BitGemmPackQuantBData; + d.SQ4BitGemmPackQuantBDataAndBlkSum = sqnbitgemm_neon::SQ4BitGemmPackQuantBDataAndBlkSum; - d.SQ4BitGemmM1Kernel_CompFp32 = sqnbitgemm_neon::SQ4BitGemmM1Kernel_CompFp32; - d.SQ4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::SQ4BitBlkDequantBForSgemm_CompFp32; - if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeonDot()) { - d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8; - } - d.QuantizeARow_CompInt8 = sqnbitgemm_neon::QuantizeARow_CompInt8; + d.QNBitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::QNBitGemmPerGemmWorkspaceSize; + d.QNBitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::QNBitGemmPerGemmWorkspaceAlignment; + + d.SQ4BitGemmM1Kernel_CompFp32 = sqnbitgemm_neon::SQ4BitGemmM1Kernel_CompFp32; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::SQ4BitBlkDequantBForSgemm_CompFp32; + + if (InitializeWithDotSupport) { + d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8; + d.QuantizeARow_CompInt8 = sqnbitgemm_neon::QuantizeARow_CompInt8; + d.UsePacked_CompInt8 = sqnbitgemm_neon::UsePacked_CompInt8; + +#ifdef USE_KLEIDIAI + d.SQ4BitGemmKernel_Packed_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_Packed_CompInt8; + d.QuantizeA_Packed_CompInt8 = sqnbitgemm_neon::QuantizeA_Packed_CompInt8; +#endif + } #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) - d.HQ4BitGemmPackQuantBData = 
sqnbitgemm_neon::HQ4BitGemmPackQuantBData_CompFp16; - d.HQ4BitBlkDequantBForHgemm_CompFp16 = sqnbitgemm_neon::HQ4BitBlkDequantBForHgemm_CompFp16; - d.HQ4BitGemmKernel_CompFp16 = sqnbitgemm_neon::HQ4BitGemmKernel_CompFp16; + d.HQ4BitGemmPackQuantBData = sqnbitgemm_neon::HQ4BitGemmPackQuantBData_CompFp16; + d.HQ4BitBlkDequantBForHgemm_CompFp16 = sqnbitgemm_neon::HQ4BitBlkDequantBForHgemm_CompFp16; + d.HQ4BitGemmKernel_CompFp16 = sqnbitgemm_neon::HQ4BitGemmKernel_CompFp16; #endif // MLAS_F16VEC_INTRINSICS_SUPPORTED && MLAS_TARGET_ARM64 - return d; -}(); + return d; + }(); + + return MlasQNBitGemmDispatchNeon; +} diff --git a/src/lib/qnbitgemm_kernel_neon.h b/src/lib/qnbitgemm_kernel_neon.h index ccadd24..a254ec9 100644 --- a/src/lib/qnbitgemm_kernel_neon.h +++ b/src/lib/qnbitgemm_kernel_neon.h @@ -23,6 +23,7 @@ Module Name: #include #include +#include "mlas_qnbit.h" #include "mlasi.h" namespace sqnbitgemm_neon @@ -107,6 +108,13 @@ HQ4BitGemmKernel_CompFp16( // SQNBIT_CompInt8 declarations +bool +UsePacked_CompInt8( + size_t K, + size_t BlkLen, + bool HasZp +); + void QuantizeARow_CompInt8( size_t BlkLen, @@ -131,6 +139,35 @@ SQ4BitGemmKernel_CompInt8( const float* Bias ); +#ifdef USE_KLEIDIAI +void +QuantizeA_Packed_CompInt8( + size_t BlkLen, + const float* A, + size_t CountM, + size_t CountK, + std::byte* QuantA +); + +void +SQ4BitGemmKernel_Packed_CompInt8( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* PackedQuantBData, + float* C, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN, + size_t CountK, + size_t ldc, + const float *Bias +); +#endif + +bool +UseKleidiAI(size_t K, size_t BlkLen, bool HasZp); + // // General helpers. 
// diff --git a/src/lib/quantize.cpp b/src/lib/quantize.cpp index ae638fa..fad174f 100644 --- a/src/lib/quantize.cpp +++ b/src/lib/quantize.cpp @@ -1704,8 +1704,8 @@ MlasRequantizeOutput( float min_f = float(std::numeric_limits::lowest() - ZeroPoint); float max_f = float(std::numeric_limits::max() - ZeroPoint); const __m128 PerMatrixScaleVector = PerColumnScale ? MlasReinterpretAsFloat32x4(__lsx_vldi(0)) : MlasReinterpretAsFloat32x4(__lsx_vldrepl_w(Scale, 0)); - const __m128 MinimumValueVector = MlasReinterpretAsFloat32x4(__lsx_vreplgr2vr_w( *((uint32_t*)&min_f))); - const __m128 MaximumValueVector = MlasReinterpretAsFloat32x4(__lsx_vreplgr2vr_w( *((uint32_t*)&max_f))); + const __m128 MinimumValueVector = MlasReinterpretAsFloat32x4((__m128i)(v4f32){min_f,min_f,min_f,min_f}); + const __m128 MaximumValueVector = MlasReinterpretAsFloat32x4((__m128i)(v4f32){max_f,max_f,max_f,max_f}); const __m128i ZeroPointVector = __lsx_vreplgr2vr_w(ZeroPoint); if (nullptr != Bias) { diff --git a/src/lib/rotary_embedding.cpp b/src/lib/rotary_embedding.cpp new file mode 100644 index 0000000..63e0a7f --- /dev/null +++ b/src/lib/rotary_embedding.cpp @@ -0,0 +1,108 @@ +/*++ + +Copyright (c) Intel Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding.cpp + +Abstract: + + This module implements rotary embedding kernels for fp32/16. + +--*/ + +#include "rotary_embedding.h" + +template +void +MLASCALL +MlasRotaryEmbedOneRow_FallBack( + const T* input_data, + const T* sin_data, + const T* cos_data, + size_t rotary_emb_dim, + bool interleaved, + T* output_data +) { + const size_t half_rotary_emb_dim = rotary_emb_dim / 2; + size_t cache_idx = 0; + bool sign = false; + size_t j = 0; + for (size_t i = 0; i < rotary_emb_dim; i++) { + if (interleaved) { + cache_idx = (i / 2) % half_rotary_emb_dim; + sign = i & 1; + j = sign ? 
i - 1 : i + 1; // i - sign + } else { + cache_idx = i % half_rotary_emb_dim; + sign = (i >= half_rotary_emb_dim); + j = (i + half_rotary_emb_dim) % rotary_emb_dim; + } + float output_data_i = static_cast(input_data[i]) * static_cast(cos_data[cache_idx]); + float input_data_j = static_cast(input_data[j]); + float sin_data_cache_idx = static_cast(sin_data[cache_idx]); + if (sign) { + output_data_i += input_data_j * sin_data_cache_idx; + } else { + output_data_i -= input_data_j * sin_data_cache_idx; + } + output_data[i] = static_cast(output_data_i); + } +} + +template <> +void +MLASCALL +MlasRotaryEmbedOneRow( + const float* input, + const float* sin_data, + const float* cos_data, + size_t dim, + bool interleaved, + float* output +) { + const auto* dispatch = GetMlasPlatform().RopeDispatch; + + if (dispatch == nullptr || dispatch->SRope == nullptr) { + MlasRotaryEmbedOneRow_FallBack(input, sin_data, cos_data, dim, interleaved, output); + return; + } + + dispatch->SRope(input, sin_data, cos_data, dim, interleaved, output); +} + +template <> +void +MLASCALL +MlasRotaryEmbedOneRow( + const MLAS_FP16* input, + const MLAS_FP16* sin_data, + const MLAS_FP16* cos_data, + size_t dim, + bool interleaved, + MLAS_FP16* output +) { + const auto* dispatch = GetMlasPlatform().RopeDispatch; + + if (dispatch == nullptr || dispatch->HRope == nullptr) { + MlasRotaryEmbedOneRow_FallBack(input, sin_data, cos_data, dim, interleaved, output); + return; + } + + dispatch->HRope(input, sin_data, cos_data, dim, interleaved, output); +} + +template +void +MLASCALL +MlasRotaryEmbedOneRow_FallBack( + const float* input_data, + const float* sin_data, + const float* cos_data, + size_t rotary_emb_dim, + bool interleaved, + float* output_data +); diff --git a/src/lib/rotary_embedding.h b/src/lib/rotary_embedding.h new file mode 100644 index 0000000..c017ece --- /dev/null +++ b/src/lib/rotary_embedding.h @@ -0,0 +1,57 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. 
+ +Licensed under the MIT License. + +Module Name: + + rotary_embedding.h + +Abstract: + + This module includes kernel function prototypes and helper functions for + implementing rotary embedding. + +--*/ + +#pragma once + +#include "mlasi.h" + +struct MLAS_ROPE_DISPATCH { + // rotary embedding kernel for fp32 + typedef void(SRope_Fn)( + const float* input, + const float* sin_data, + const float* cos_data, + size_t dim, + bool interleaved, + float* output + ); + + SRope_Fn* SRope = nullptr; + + // rotary embedding kernel for fp16 + typedef void(HRope_Fn)( + const MLAS_FP16* input, + const MLAS_FP16* sin_data, + const MLAS_FP16* cos_data, + size_t dim, + bool interleaved, + MLAS_FP16* output + ); + + HRope_Fn* HRope = nullptr; +}; + +template +void MLASCALL +MlasRotaryEmbedOneRow_FallBack( + const T* input_data, + const T* sin_data, + const T* cos_data, + size_t rotary_emb_dim, + bool interleaved, + T* output_data +); diff --git a/src/lib/rotary_embedding_kernel_avx2.cpp b/src/lib/rotary_embedding_kernel_avx2.cpp new file mode 100644 index 0000000..024e67d --- /dev/null +++ b/src/lib/rotary_embedding_kernel_avx2.cpp @@ -0,0 +1,308 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding_kernel_avx2.cpp + +Abstract: + + This module implements the rotary embedding kernels for AVX2 supported h/w. 
+ +--*/ + + +#include + +#include "rotary_embedding.h" +#include "rotary_embedding_kernel_avx2.h" + +namespace rope_avx2 { + +namespace { + +typedef __m256 float32x8_t; +static constexpr int32_t mask_buffer[16] = {-1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0}; + +template +void +RopeKernel_Avx2_fp16_Impl( + const MLAS_FP16* input, + const MLAS_FP16* sin_data, + const MLAS_FP16* cos_data, + size_t dim, + MLAS_FP16* output +); + +float32x8_t +load_fp16_and_convert_to_fp32(const MLAS_FP16* input) +{ + __m128i fp16 = _mm_lddqu_si128(reinterpret_cast(input)); + return _mm256_cvtph_ps(fp16); +} + +void +convert_to_fp16_and_store(MLAS_FP16* dst_fp16, const __m256 output) +{ + __m128i fp16_chunk = _mm256_cvtps_ph(output, _MM_FROUND_TO_NEAREST_INT); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst_fp16), fp16_chunk); +} + +template <> +void +RopeKernel_Avx2_fp16_Impl( + const MLAS_FP16* input, + const MLAS_FP16* sin_data, + const MLAS_FP16* cos_data, + size_t dim, + MLAS_FP16* output +) +{ + // ?cast input -> const unsigned short* + const size_t half_dim = dim >> 1; + size_t i = 0, j = half_dim; + for (; i + 7 < half_dim; i += 8, j += 8) { + float32x8_t real = load_fp16_and_convert_to_fp32(input + i); + float32x8_t imag = load_fp16_and_convert_to_fp32(input + j); + float32x8_t sin_val = load_fp16_and_convert_to_fp32(sin_data + i); + float32x8_t cos_val = load_fp16_and_convert_to_fp32(cos_data + i); + // Compute Real and Imaginary output values + float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); + float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); + // Store back into non interleaved format + convert_to_fp16_and_store(output + i, real_out); + convert_to_fp16_and_store(output + j, imag_out); + } + for (; i < half_dim; i++, j++) { + float real = input[i].ToFloat(); + float imag = input[j].ToFloat(); + float sin_val = sin_data[i]; + float cos_val = cos_data[i]; + output[i] = MLAS_FP16(real * cos_val 
- imag * sin_val); + output[j] = MLAS_FP16(real * sin_val + imag * cos_val); + } +} + +template <> +void +RopeKernel_Avx2_fp16_Impl( + const MLAS_FP16* input, + const MLAS_FP16* sin_data, + const MLAS_FP16* cos_data, + size_t dim, + MLAS_FP16* output +) +{ + // ?cast input -> const unsigned short* + size_t i = 0; + for (; i + 15 < dim; i += 16) { + float32x8_t x0 = load_fp16_and_convert_to_fp32(input + i); + float32x8_t x1 = load_fp16_and_convert_to_fp32(input + i + 8); + float32x8_t real_s = _mm256_shuffle_ps(x0, x1, 0b10001000); + float32x8_t imag_s = _mm256_shuffle_ps(x0, x1, 0b11011101); + __m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); + float32x8_t real = _mm256_permutevar8x32_ps(real_s, in_mask_vec); + float32x8_t imag = _mm256_permutevar8x32_ps(imag_s, in_mask_vec); + float32x8_t sin_val = load_fp16_and_convert_to_fp32(sin_data + i / 2); + float32x8_t cos_val = load_fp16_and_convert_to_fp32(cos_data + i / 2); + // Compute Real and Imaginary output values + float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); + float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); + // Store back into interleaved format + __m256i out_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); + float32x8_t real_out_s = _mm256_permutevar8x32_ps(real_out, out_mask_vec); + float32x8_t imag_out_s = _mm256_permutevar8x32_ps(imag_out, out_mask_vec); + float32x8_t y0 = _mm256_unpacklo_ps(real_out_s, imag_out_s); + float32x8_t y1 = _mm256_unpackhi_ps(real_out_s, imag_out_s); + + // Store back into non interleaved format + convert_to_fp16_and_store(output + i, y0); + convert_to_fp16_and_store(output + i + 8, y1); + } + + for (; i < dim; i++) { + size_t cache_idx = i / 2; + bool sign = i & 1; + size_t j = sign ? 
i - 1 : i + 1; + + float output_data_i = input[i].ToFloat() * cos_data[cache_idx].ToFloat(); + float input_data_j = input[j].ToFloat(); + float sin_data_cache_idx = sin_data[cache_idx].ToFloat(); + if (sign) { + output_data_i += input_data_j * sin_data_cache_idx; + } else { + output_data_i -= input_data_j * sin_data_cache_idx; + } + output[i] = MLAS_FP16(output_data_i); + } +} + +template +void +RopeKernel_Avx2_fp32_Impl( + const float* input, + const float* sin_data, + const float* cos_data, + size_t dim, + float* output +); + +template <> +void +RopeKernel_Avx2_fp32_Impl( + const float* input, + const float* sin_data, + const float* cos_data, + size_t dim, + float* output +) { + const size_t half_dim = dim >> 1; + size_t i = 0, j = half_dim; + for (; i + 7 < half_dim; i += 8, j += 8) { + float32x8_t real = _mm256_loadu_ps(input + i); + float32x8_t imag = _mm256_loadu_ps(input + j); + float32x8_t sin_val = _mm256_loadu_ps(sin_data + i); + float32x8_t cos_val = _mm256_loadu_ps(cos_data + i); + //Compute Real and Imaginary output values + float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); + float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); + //Store back into non interleaved format + _mm256_storeu_ps(output + i, real_out); + _mm256_storeu_ps(output + j, imag_out); + } + if (half_dim - i != 0) { + size_t rem = half_dim - i; + const __m256i mask = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - rem)); + //Use a mask to load the remaining input values + float32x8_t real = _mm256_maskload_ps(input + i, mask); + float32x8_t imag = _mm256_maskload_ps(input + j, mask); + float32x8_t sin_val = _mm256_maskload_ps(sin_data + i, mask); + float32x8_t cos_val = _mm256_maskload_ps(cos_data + i, mask); + //Compute Real and Imaginary output values + float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); + float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, 
cos_val)); + //Store back into non interleaved format + _mm256_maskstore_ps(output + i, mask, real_out); + _mm256_maskstore_ps(output + j, mask, imag_out); + } +} + +template <> +void +RopeKernel_Avx2_fp32_Impl( + const float* input, + const float* sin_data, + const float* cos_data, + size_t dim, + float* output +) { + size_t i = 0; + for (; i + 15 < dim; i += 16) { + float32x8_t x0 = _mm256_loadu_ps(input + i); + float32x8_t x1 = _mm256_loadu_ps(input + i + 8); + //Load imaginary and real values to separate non-interleaved vectors + float32x8_t real_s = _mm256_shuffle_ps(x0, x1, 0b10001000); + float32x8_t imag_s = _mm256_shuffle_ps(x0, x1, 0b11011101); + __m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); + float32x8_t real = _mm256_permutevar8x32_ps(real_s, in_mask_vec); + float32x8_t imag = _mm256_permutevar8x32_ps(imag_s, in_mask_vec); + float32x8_t sin_val = _mm256_loadu_ps(sin_data + i / 2); + float32x8_t cos_val = _mm256_loadu_ps(cos_data + i / 2); + //Compute Real and Imaginary output values + float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); + float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); + //Store back into interleaved format + __m256i out_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); + float32x8_t real_out_s = _mm256_permutevar8x32_ps(real_out, out_mask_vec); + float32x8_t imag_out_s = _mm256_permutevar8x32_ps(imag_out, out_mask_vec); + float32x8_t y0 = _mm256_unpacklo_ps(real_out_s, imag_out_s); + float32x8_t y1 = _mm256_unpackhi_ps(real_out_s, imag_out_s); + _mm256_storeu_ps(output + i, y0); + _mm256_storeu_ps(output + i + 8, y1); + } + if (dim - i != 0) { + size_t rem = dim - i; + const __m256i mask0 = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - (rem>8?8:rem))); + const __m256i mask1 = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - (rem>8?(rem-8):0))); + float32x8_t x0 = _mm256_maskload_ps(input + i, mask0); //Load the first set of data using 
mask + float32x8_t x1 = _mm256_maskload_ps(input + i + 8, mask1); //Load the reminder of data using a second mask + //Load imaginary and real values to separate non-interleaved vectors + float32x8_t real_s = _mm256_shuffle_ps(x0, x1, 0b10001000); + float32x8_t imag_s = _mm256_shuffle_ps(x0, x1, 0b11011101); + __m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); + float32x8_t real = _mm256_permutevar8x32_ps(real_s, in_mask_vec); + float32x8_t imag = _mm256_permutevar8x32_ps(imag_s, in_mask_vec); + float32x8_t sin_val = _mm256_loadu_ps(sin_data+ i / 2); + float32x8_t cos_val = _mm256_loadu_ps(cos_data + i / 2); + //Compute Real and Imaginary output values + float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); + float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); + //Store back into interleaved format + __m256i out_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); + float32x8_t real_out_s = _mm256_permutevar8x32_ps(real_out, out_mask_vec); + float32x8_t imag_out_s = _mm256_permutevar8x32_ps(imag_out, out_mask_vec); + float32x8_t y0 = _mm256_unpacklo_ps(real_out_s, imag_out_s); + float32x8_t y1 = _mm256_unpackhi_ps(real_out_s, imag_out_s); + _mm256_maskstore_ps(output + i, mask0, y0); + _mm256_maskstore_ps(output + i + 8, mask1, y1); + } +} + +} // rope_avx2 namespace + +void +RopeKernel_Avx2_fp32( + const float* input, + const float* sin_data, + const float* cos_data, + size_t dim, + bool interleaved, + float* output +) { + // real part and imaginary part must be paired + assert(dim % 2 == 0); + const auto* input_impl = reinterpret_cast(input); + const auto* sin_impl = reinterpret_cast(sin_data); + const auto* cos_impl = reinterpret_cast(cos_data); + auto* output_impl = reinterpret_cast(output); + + if (interleaved) { + RopeKernel_Avx2_fp32_Impl(input_impl, sin_impl, cos_impl, dim, output_impl); + } else { + RopeKernel_Avx2_fp32_Impl(input_impl, sin_impl, cos_impl, dim, output_impl); + } +} 
+ +void +RopeKernel_Avx2_fp16( + const MLAS_FP16* input, + const MLAS_FP16* sin_data, + const MLAS_FP16* cos_data, + size_t dim, + bool interleaved, + MLAS_FP16* output +) +{ + // real part and imaginary part must be paired + assert(dim % 2 == 0); + + if (interleaved) { + RopeKernel_Avx2_fp16_Impl(input, sin_data, cos_data, dim, output); + } else { + RopeKernel_Avx2_fp16_Impl(input, sin_data, cos_data, dim, output); + } +} +} + +// +// Kernel dispatch structure definition. +// +const MLAS_ROPE_DISPATCH MlasRopeDispatchAvx2 = []() { + MLAS_ROPE_DISPATCH d; + d.SRope = rope_avx2::RopeKernel_Avx2_fp32; + d.HRope = rope_avx2::RopeKernel_Avx2_fp16; + return d; +}(); diff --git a/src/lib/rotary_embedding_kernel_avx2.h b/src/lib/rotary_embedding_kernel_avx2.h new file mode 100644 index 0000000..c08833e --- /dev/null +++ b/src/lib/rotary_embedding_kernel_avx2.h @@ -0,0 +1,47 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding_kernel_avx2.h + +Abstract: + + This module includes function declarations and common helper functions for + rotary embedding on for AVX2 enabled h/w. + +--*/ + +#pragma once + + + +#include "mlasi.h" + +namespace rope_avx2 { + +// Rotary embedding kernel for FP32. Embed one hidden state vector. +void +RopeKernel_Avx2( + const float* input, + const float* sin_data, + const float* cos_data, + size_t dim, + bool interleaved, + float* output +); + +void +RopeKernel_Avx2_fp16( + const MLAS_FP16* input, + const MLAS_FP16* sin_data, + const MLAS_FP16* cos_data, + size_t dim, + bool interleaved, + MLAS_FP16* output +); + +} // namespace rope_avx2 diff --git a/src/lib/rotary_embedding_kernel_neon.cpp b/src/lib/rotary_embedding_kernel_neon.cpp new file mode 100644 index 0000000..e59a95c --- /dev/null +++ b/src/lib/rotary_embedding_kernel_neon.cpp @@ -0,0 +1,32 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. 
+ +Module Name: + + rotary_embedding_kernel_neon.cpp + +Abstract: + + This module implements the rotary embedding kernels for ARM NEON. + +--*/ + +#include "rotary_embedding.h" +#include "rotary_embedding_kernel_neon.h" + +// +// Kernel dispatch structure definition. +// +const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon = []() { + MLAS_ROPE_DISPATCH d; + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + if (MlasFp16AccelerationSupported()) { + d.HRope = rope_neon::RopeKernel_Fp16; + } +#endif + return d; +}(); diff --git a/src/lib/rotary_embedding_kernel_neon.h b/src/lib/rotary_embedding_kernel_neon.h new file mode 100644 index 0000000..8153f65 --- /dev/null +++ b/src/lib/rotary_embedding_kernel_neon.h @@ -0,0 +1,37 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding_kernel_neon.h + +Abstract: + + This module includes function declarations and common helper functions for + rotary embedding on ARM cpu. + +--*/ + +#pragma once + +#include + +#include "mlasi.h" + +namespace rope_neon { + +// Rotary embedding kernel for fp16. Embed one hidden state vector. +void +RopeKernel_Fp16( + const MLAS_FP16* input, + const MLAS_FP16* sin, + const MLAS_FP16* cos, + size_t dim, + bool interleaved, + MLAS_FP16* output +); + +} // namespace rope_neon diff --git a/src/lib/rotary_embedding_kernel_neon_fp16.cpp b/src/lib/rotary_embedding_kernel_neon_fp16.cpp new file mode 100644 index 0000000..3a93723 --- /dev/null +++ b/src/lib/rotary_embedding_kernel_neon_fp16.cpp @@ -0,0 +1,279 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + rotary_embedding_kernel_neon_fp16.cpp + +Abstract: + + This module implements the fp16 rotary embedding kernels for ARM NEON. 
+ +--*/ + +#include +#include + +#include "fp16_common.h" +#include "rotary_embedding.h" +#include "rotary_embedding_kernel_neon.h" + +namespace rope_neon { + +namespace { + +template +void +RopeKernel_Fp16_Impl( + const _mlas_fp16_* input, + const _mlas_fp16_* sin, + const _mlas_fp16_* cos, + size_t dim, + _mlas_fp16_* output +); + +template <> +void +RopeKernel_Fp16_Impl( + const _mlas_fp16_* input, + const _mlas_fp16_* sin, + const _mlas_fp16_* cos, + size_t dim, + _mlas_fp16_* output +) { + const size_t half_dim = dim >> 1; + size_t i = 0, j = half_dim; + if (i + 7 < half_dim) { + float16x8_t real = MlasLoadFloat16x8(input + i); + float16x8_t imag = MlasLoadFloat16x8(input + j); + float16x8_t sin_val = MlasLoadFloat16x8(sin + i); + float16x8_t cos_val = MlasLoadFloat16x8(cos + i); + for (; i + 15 < half_dim; i += 8, j += 8) { + float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val); + float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val); + MlasStoreFloat16x8(output + i, real_out); + MlasStoreFloat16x8(output + j, imag_out); + real = MlasLoadFloat16x8(input + i + 8); + imag = MlasLoadFloat16x8(input + j + 8); + sin_val = MlasLoadFloat16x8(sin + i + 8); + cos_val = MlasLoadFloat16x8(cos + i + 8); + } + float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val); + float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val); + MlasStoreFloat16x8(output + i, real_out); + MlasStoreFloat16x8(output + j, imag_out); + i += 8, j += 8; + } + for (; i + 3 < half_dim; i += 4, j += 4) { + float16x4_t real = MlasLoadFloat16x4(input + i); + float16x4_t imag = MlasLoadFloat16x4(input + j); + float16x4_t sin_val = MlasLoadFloat16x4(sin + i); + float16x4_t cos_val = MlasLoadFloat16x4(cos + i); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreFloat16x4(output + i, real_out); + MlasStoreFloat16x4(output 
+ j, imag_out); + } + if (half_dim - i == 3) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + real = MlasLoadLaneFloat16x4<1>(input + i + 1, real); + real = MlasLoadLaneFloat16x4<2>(input + i + 2, real); + imag = MlasLoadLaneFloat16x4<0>(input + j, imag); + imag = MlasLoadLaneFloat16x4<1>(input + j + 1, imag); + imag = MlasLoadLaneFloat16x4<2>(input + j + 2, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val); + sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val); + cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<1>(output + i + 1, real_out); + MlasStoreLaneFloat16x4<2>(output + i + 2, real_out); + MlasStoreLaneFloat16x4<0>(output + j, imag_out); + MlasStoreLaneFloat16x4<1>(output + j + 1, imag_out); + MlasStoreLaneFloat16x4<2>(output + j + 2, imag_out); + } else if (half_dim - i == 2) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + real = MlasLoadLaneFloat16x4<1>(input + i + 1, real); + imag = MlasLoadLaneFloat16x4<0>(input + j, imag); + imag = MlasLoadLaneFloat16x4<1>(input + j + 1, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + cos_val = 
MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<1>(output + i + 1, real_out); + MlasStoreLaneFloat16x4<0>(output + j, imag_out); + MlasStoreLaneFloat16x4<1>(output + j + 1, imag_out); + } else if (half_dim - i == 1) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + imag = MlasLoadLaneFloat16x4<0>(input + j, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<0>(output + j, imag_out); + } +} + +template <> +void +RopeKernel_Fp16_Impl( + const _mlas_fp16_* input, + const _mlas_fp16_* sin, + const _mlas_fp16_* cos, + size_t dim, + _mlas_fp16_* output +) { + size_t i = 0; + if (i + 15 < dim) { + float16x8_t x0 = MlasLoadFloat16x8(input + i); + float16x8_t x1 = MlasLoadFloat16x8(input + i + 8); + float16x8_t sin_val = MlasLoadFloat16x8(sin + i); + float16x8_t cos_val = MlasLoadFloat16x8(cos + i); + for (; i + 31 < dim; i += 16) { + float16x8_t real = vuzp1q_f16(x0, x1); + float16x8_t imag = vuzp2q_f16(x0, x1); + float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val); + float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val); + float16x8_t y0 = vzip1q_f16(real_out, imag_out); + float16x8_t y1 = vzip2q_f16(real_out, imag_out); + MlasStoreFloat16x8(output + i, y0); + MlasStoreFloat16x8(output + i + 8, y1); + x0 = MlasLoadFloat16x8(input + i + 16); + 
x1 = MlasLoadFloat16x8(input + i + 24); + sin_val = MlasLoadFloat16x8(sin + i + 16); + cos_val = MlasLoadFloat16x8(cos + i + 16); + } + float16x8_t real = vuzp1q_f16(x0, x1); + float16x8_t imag = vuzp2q_f16(x0, x1); + float16x8_t real_out = vfmsq_f16(vmulq_f16(real, cos_val), imag, sin_val); + float16x8_t imag_out = vfmaq_f16(vmulq_f16(real, sin_val), imag, cos_val); + float16x8_t y0 = vzip1q_f16(real_out, imag_out); + float16x8_t y1 = vzip2q_f16(real_out, imag_out); + MlasStoreFloat16x8(output + i, y0); + MlasStoreFloat16x8(output + i + 8, y1); + i += 16; + } + for (; i + 7 < dim; i += 8) { + float16x4_t x0 = MlasLoadFloat16x4(input + i); + float16x4_t x1 = MlasLoadFloat16x4(input + i + 4); + float16x4_t real = vuzp1_f16(x0, x1); + float16x4_t imag = vuzp2_f16(x0, x1); + float16x4_t sin_val = MlasLoadFloat16x4(sin + i); + float16x4_t cos_val = MlasLoadFloat16x4(cos + i); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + float16x4_t y0 = vzip1_f16(real_out, imag_out); + float16x4_t y1 = vzip2_f16(real_out, imag_out); + MlasStoreFloat16x4(output + i, y0); + MlasStoreFloat16x4(output + i + 4, y1); + } + if (dim - i == 6) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag); + real = MlasLoadLaneFloat16x4<1>(input + i + 2, real); + imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag); + real = MlasLoadLaneFloat16x4<2>(input + i + 4, real); + imag = MlasLoadLaneFloat16x4<2>(input + i + 5, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val); + sin_val = MlasLoadLaneFloat16x4<2>(sin + i + 2, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + cos_val = 
MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val); + cos_val = MlasLoadLaneFloat16x4<2>(cos + i + 2, cos_val); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out); + MlasStoreLaneFloat16x4<1>(output + i + 2, real_out); + MlasStoreLaneFloat16x4<1>(output + i + 3, imag_out); + MlasStoreLaneFloat16x4<2>(output + i + 4, real_out); + MlasStoreLaneFloat16x4<2>(output + i + 5, imag_out); + } else if (dim - i == 4) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag); + real = MlasLoadLaneFloat16x4<1>(input + i + 2, real); + imag = MlasLoadLaneFloat16x4<1>(input + i + 3, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + sin_val = MlasLoadLaneFloat16x4<1>(sin + i + 1, sin_val); + cos_val = MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + cos_val = MlasLoadLaneFloat16x4<1>(cos + i + 1, cos_val); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out); + MlasStoreLaneFloat16x4<1>(output + i + 2, real_out); + MlasStoreLaneFloat16x4<1>(output + i + 3, imag_out); + } else if (dim - i == 2) { + float16x4_t real = MlasZeroFloat16x4(); + float16x4_t imag = MlasZeroFloat16x4(); + float16x4_t sin_val = MlasZeroFloat16x4(); + float16x4_t cos_val = MlasZeroFloat16x4(); + real = MlasLoadLaneFloat16x4<0>(input + i, real); + imag = MlasLoadLaneFloat16x4<0>(input + i + 1, imag); + sin_val = MlasLoadLaneFloat16x4<0>(sin + i, sin_val); + cos_val = 
MlasLoadLaneFloat16x4<0>(cos + i, cos_val); + float16x4_t real_out = vfms_f16(vmul_f16(real, cos_val), imag, sin_val); + float16x4_t imag_out = vfma_f16(vmul_f16(real, sin_val), imag, cos_val); + MlasStoreLaneFloat16x4<0>(output + i, real_out); + MlasStoreLaneFloat16x4<0>(output + i + 1, imag_out); + } +} + +} // namespace + +void +RopeKernel_Fp16( + const MLAS_FP16* input, + const MLAS_FP16* sin, + const MLAS_FP16* cos, + size_t dim, + bool interleaved, + MLAS_FP16* output +) { + // real part and imaginary part must be paired + assert(dim % 2 == 0); + + const auto* input_impl = reinterpret_cast(input); + const auto* sin_impl = reinterpret_cast(sin); + const auto* cos_impl = reinterpret_cast(cos); + auto* output_impl = reinterpret_cast<_mlas_fp16_*>(output); + + if (interleaved) { + RopeKernel_Fp16_Impl(input_impl, sin_impl, cos_impl, dim, output_impl); + } else { + RopeKernel_Fp16_Impl(input_impl, sin_impl, cos_impl, dim, output_impl); + } +} + +} // namespace rope_neon diff --git a/src/lib/saturation_check.cpp b/src/lib/saturation_check.cpp new file mode 100644 index 0000000..7b022a7 --- /dev/null +++ b/src/lib/saturation_check.cpp @@ -0,0 +1,42 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + saturation_check.cpp + +Abstract: + + This module implements logic to check saturation of the VPMADDUBSW + instruction. 
+ +--*/ + +#include "mlasi.h" + +namespace onnxruntime +{ + +#if defined(MLAS_TARGET_AMD64) + +std::atomic saturation_count{0}; + +void +reset_saturation_count() +{ + saturation_count = 0; +} + +#else + +void +reset_saturation_count() +{ +} + +#endif + +} // namespace onnxruntime diff --git a/src/lib/sgemm.cpp b/src/lib/sgemm.cpp index f8b25fb..616622a 100644 --- a/src/lib/sgemm.cpp +++ b/src/lib/sgemm.cpp @@ -1580,14 +1580,7 @@ MlasGemmBatch( const double Complexity = double(M) * double(N) * double(K); - ptrdiff_t TargetThreadCount; - - if (Complexity < double(MLAS_SGEMM_THREAD_COMPLEXITY * GetMlasPlatform().MaximumThreadCount)) { - TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1; - } else { - TargetThreadCount = GetMlasPlatform().MaximumThreadCount; - } - + ptrdiff_t TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1; ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); if (TargetThreadCount >= MaximumThreadCount) { diff --git a/src/lib/softmax.h b/src/lib/softmax.h new file mode 100644 index 0000000..69fe1ae --- /dev/null +++ b/src/lib/softmax.h @@ -0,0 +1,129 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + softmax.h + +Abstract: + + This module includes kernel function prototypes and helper functions for + softmax. + +--*/ + +#pragma once + +#include "mlasi.h" + +struct MLAS_SOFTMAX_DISPATCH { + /** + * @brief Compute the hyperbolic tangent function for each element of the input array + * @param Input Address of the input array. Valid in [-3.51562, 3.51562]. + * @param Output Address of the output array. Could be the same as the input array. 
+ * @param N Number of elements in the input array + */ + typedef void(Tanh_Fp16_Fn)( + const MLAS_FP16* Input, + MLAS_FP16* Output, + size_t N + ); + + Tanh_Fp16_Fn* Tanh_Fp16 = nullptr; + + /** + * @brief Compute the softcap function for each element of the input array. Use tanh activation. + * @param Input Address of the input array. Valid if input / softcap in [-3.51562, 3.51562]. + * @param Output Address of the output array. Could be the same as the input array. + * @param N Number of elements in the input array + * @param Softcap The softcap value + */ + typedef void(Softcap_Fp16_Fn)( + const MLAS_FP16* Input, + MLAS_FP16* Output, + size_t N, + const MLAS_FP16 Softcap + ); + + Softcap_Fp16_Fn* Softcap_Fp16 = nullptr; + + /** + * @brief Compute the exponential function for each element of the input array. + * @param Input Address of the input array. Valid in [-17.3287, 11.0904]. + * @param Output Address of the output array. Could be the same as the input array. + * @param N Number of elements in the input array + */ + typedef void(Exp_Fp16_Fn)( + const MLAS_FP16* Input, + MLAS_FP16* Output, + size_t N + ); + + Exp_Fp16_Fn* Exp_Fp16 = nullptr; + + /** + * @brief Find the max value among the input array + * @param Input Address of the input array + * @param N Number of elements in the input array + */ + typedef MLAS_FP16(ReduceMax_Fp16_Fn)( + const MLAS_FP16* Input, + size_t N + ); + + ReduceMax_Fp16_Fn* ReduceMax_Fp16 = nullptr; + + /** + * @brief Compute the expotential function for each element of the input array and returnt he sum. It has smaller + * dynamic range for the input than Exp_Fp16_Fn thus is faster. + * @param Input Address of the input array. Valid in [-10.7438, 10.7438] + * @param Output Address of the output array. Could be the same as the input array or nullptr. 
+     * @param N Number of elements in the input array
+     * @param NegativeMaximum The negative of the maximum value in the input array
+     */
+    typedef MLAS_FP16(SumExp_Fp16_Fn)(
+        const MLAS_FP16* Input,
+        MLAS_FP16* Output,
+        size_t N,
+        const MLAS_FP16 NegativeMaximum
+    );
+
+    SumExp_Fp16_Fn* SumExp_Fp16 = nullptr;
+
+    /**
+     * @brief Compute the softmax output for each element of the input array. input / sum.
+     * @param Input Address of the input array. Values of exp(x)
+     * @param Output Address of the output array. Could be the same as the input array.
+     * @param N Number of elements in the input array
+     * @param Sum Sum of exp(input)
+     */
+    typedef void(Softmax_Fp16_Fn)(
+        const MLAS_FP16* Input,
+        MLAS_FP16* Output,
+        size_t N,
+        const MLAS_FP16 Sum
+    );
+
+    Softmax_Fp16_Fn* Softmax_Fp16 = nullptr;
+
+    /**
+     * @brief Compute the log softmax output for each element of the input array. input - max - logSum
+     * @param Input Address of the input array
+     * @param Output Address of the output array. Could be the same as the input array.
+     * @param N Number of elements in the input array
+     * @param NegativeMaximum The negative of the maximum value in the input array
+     * @param LogSum The logarithm of the sum of the exponential function of the input array
+     */
+    typedef void(LogSoftmax_Fp16_Fn)(
+        const MLAS_FP16* Input,
+        MLAS_FP16* Output,
+        size_t N,
+        const MLAS_FP16 NegativeMaximum,
+        const MLAS_FP16 LogSum
+    );
+
+    LogSoftmax_Fp16_Fn* LogSoftmax_Fp16 = nullptr;
+};
diff --git a/src/lib/softmax_kernel_neon.cpp b/src/lib/softmax_kernel_neon.cpp
new file mode 100644
index 0000000..77ad4b9
--- /dev/null
+++ b/src/lib/softmax_kernel_neon.cpp
@@ -0,0 +1,38 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    softmax_kernel_neon.cpp
+
+Abstract:
+
+    This module implements the softmax kernels for ARM NEON.
+ +--*/ + +#include "softmax.h" +#include "softmax_kernel_neon.h" + +// +// Kernel dispatch structure definition. +// +const MLAS_SOFTMAX_DISPATCH MlasSoftmaxDispatchNeon = []() { + MLAS_SOFTMAX_DISPATCH d; + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + if (MlasFp16AccelerationSupported()) { + d.Tanh_Fp16 = softmax_neon::Tanh_Kernel_Fp16; + d.Softcap_Fp16 = softmax_neon::Softcap_Kernel_Fp16; + d.Exp_Fp16 = softmax_neon::Exp_Kernel_Fp16; + d.ReduceMax_Fp16 = softmax_neon::ReduceMax_Kernel_Fp16; + d.SumExp_Fp16 = softmax_neon::SumExp_Kernel_Fp16; + d.Softmax_Fp16 = softmax_neon::Softmax_Kernel_Fp16; + d.LogSoftmax_Fp16 = softmax_neon::LogSoftmax_Kernel_Fp16; + } +#endif + return d; +}(); diff --git a/src/lib/softmax_kernel_neon.h b/src/lib/softmax_kernel_neon.h new file mode 100644 index 0000000..e362e5d --- /dev/null +++ b/src/lib/softmax_kernel_neon.h @@ -0,0 +1,40 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + softmax_kernel_neon.h + +Abstract: + + This module includes function declarations and common helper functions for + softmax on ARM cpu. 
+
+--*/
+
+#pragma once
+
+#include <arm_neon.h>
+
+#include "mlasi.h"
+
+namespace softmax_neon {
+
+void Tanh_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N);
+
+void Softcap_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, const MLAS_FP16 Softcap);
+
+void Exp_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N);
+
+MLAS_FP16 ReduceMax_Kernel_Fp16(const MLAS_FP16* Input, size_t N);
+
+MLAS_FP16 SumExp_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, const MLAS_FP16 NegativeMaximum);
+
+void Softmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, const MLAS_FP16 Sum);
+
+void LogSoftmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, const MLAS_FP16 NegativeMaximum, const MLAS_FP16 LogSum);
+
+} // namespace softmax_neon
diff --git a/src/lib/softmax_kernel_neon_fp16.cpp b/src/lib/softmax_kernel_neon_fp16.cpp
new file mode 100644
index 0000000..dfd65d9
--- /dev/null
+++ b/src/lib/softmax_kernel_neon_fp16.cpp
@@ -0,0 +1,917 @@
+/*++
+
+Copyright (c) Microsoft Corporation. All rights reserved.
+
+Licensed under the MIT License.
+
+Module Name:
+
+    softmax_kernel_neon_fp16.cpp
+
+Abstract:
+
+    This module implements the fp16 softmax kernels for ARM NEON.
+ +--*/ +#include +#include + +#include "fp16_common.h" +#include "softmax.h" +#include "softmax_kernel_neon.h" + +namespace softmax_neon { + +template +struct MlasExpConstants { + T LowerRange; + T UpperRange; + T LowerRangeSumExp; + T UpperRangeSumExp; + T RoundingBias; + T Log2Reciprocal; + T Log2High; + T Log2Mid; + T Log2Low; + T poly_0; + T poly_1; + T poly_2; + T poly_3; + T poly_4; + T poly_56; + T MinimumExponent; + T MaximumExponent; +}; + +constexpr MlasExpConstants<_mlas_fp16_> ExpConstantsFp16 = { + 0xcc55, // -25 * ln2 + 0x498c, // 16 * ln2 + 0xc95f, // -15.5 * ln2 + 0x495f, // 15.5 * ln2 + 0x6600, // 1.5 * 2^10 + 0x3dc5, // 1/ln2 + 0xb98b, // -6.9287109375e-1f16 + 0x8c85, // -2.758502960205078e-4f16 + 0x8004, // -2.384185791015625e-7f16 + 0x15b0, // 1/6! + 0x2044, // 1/5! + 0x2955, // 1/4! + 0x3155, // 1/3! + 0x3800, // 1/2! + 0x3c00, // 1/1! + 0xC800, // -14 + 0x3C00, // 15 +}; + +template +MLAS_FORCEINLINE +const MlasExpConstants& Get_Exp_Constants(); + +template <> +MLAS_FORCEINLINE +const MlasExpConstants& Get_Exp_Constants() { + const static MlasExpConstants ExpConstantsFp16x8 = { + MlasBroadcastFloat16x8(ExpConstantsFp16.LowerRange), + MlasBroadcastFloat16x8(ExpConstantsFp16.UpperRange), + MlasBroadcastFloat16x8(ExpConstantsFp16.LowerRangeSumExp), + MlasBroadcastFloat16x8(ExpConstantsFp16.UpperRangeSumExp), + MlasBroadcastFloat16x8(ExpConstantsFp16.RoundingBias), + MlasBroadcastFloat16x8(ExpConstantsFp16.Log2Reciprocal), + MlasBroadcastFloat16x8(ExpConstantsFp16.Log2High), + MlasBroadcastFloat16x8(ExpConstantsFp16.Log2Mid), + MlasBroadcastFloat16x8(ExpConstantsFp16.Log2Low), + MlasBroadcastFloat16x8(ExpConstantsFp16.poly_0), + MlasBroadcastFloat16x8(ExpConstantsFp16.poly_1), + MlasBroadcastFloat16x8(ExpConstantsFp16.poly_2), + MlasBroadcastFloat16x8(ExpConstantsFp16.poly_3), + MlasBroadcastFloat16x8(ExpConstantsFp16.poly_4), + MlasBroadcastFloat16x8(ExpConstantsFp16.poly_56), + MlasBroadcastFloat16x8(ExpConstantsFp16.MinimumExponent), + 
MlasBroadcastFloat16x8(ExpConstantsFp16.MaximumExponent), + }; + return ExpConstantsFp16x8; +} + +template <> +MLAS_FORCEINLINE +const MlasExpConstants& Get_Exp_Constants() { + const static MlasExpConstants ExpConstantsFp16x4 = { + MlasBroadcastFloat16x4(ExpConstantsFp16.LowerRange), + MlasBroadcastFloat16x4(ExpConstantsFp16.UpperRange), + MlasBroadcastFloat16x4(ExpConstantsFp16.LowerRangeSumExp), + MlasBroadcastFloat16x4(ExpConstantsFp16.UpperRangeSumExp), + MlasBroadcastFloat16x4(ExpConstantsFp16.RoundingBias), + MlasBroadcastFloat16x4(ExpConstantsFp16.Log2Reciprocal), + MlasBroadcastFloat16x4(ExpConstantsFp16.Log2High), + MlasBroadcastFloat16x4(ExpConstantsFp16.Log2Mid), + MlasBroadcastFloat16x4(ExpConstantsFp16.Log2Low), + MlasBroadcastFloat16x4(ExpConstantsFp16.poly_0), + MlasBroadcastFloat16x4(ExpConstantsFp16.poly_1), + MlasBroadcastFloat16x4(ExpConstantsFp16.poly_2), + MlasBroadcastFloat16x4(ExpConstantsFp16.poly_3), + MlasBroadcastFloat16x4(ExpConstantsFp16.poly_4), + MlasBroadcastFloat16x4(ExpConstantsFp16.poly_56), + MlasBroadcastFloat16x4(ExpConstantsFp16.MinimumExponent), + MlasBroadcastFloat16x4(ExpConstantsFp16.MaximumExponent), + }; + return ExpConstantsFp16x4; +} + +// Range reduction + polynomial approximation. Refer algorithm details to MlasComputeExpVector. 
+template +MLAS_FORCEINLINE +T Exp_Vector_Fp16(T x) { + const auto& constants = Get_Exp_Constants(); + auto clamped_x = MlasClampFloat16(x, constants.LowerRange, constants.UpperRange); + + // integral + auto biased = MlasMultiplyAddFloat16(clamped_x, constants.Log2Reciprocal, constants.RoundingBias); + auto m = MlasSubtractFloat16(biased, constants.RoundingBias); + + // residual + auto r = MlasMultiplyAddFloat16(m, constants.Log2High, clamped_x); + r = MlasMultiplyAddFloat16(m, constants.Log2Mid, r); + r = MlasMultiplyAddFloat16(m, constants.Log2Low, r); + + // handle overflow + auto overflow = MlasShiftLeftInt16<10>(MlasReinterpretFloat16AsInt16(biased)); + auto normal = overflow; + + auto minimum_exponent = MlasReinterpretFloat16AsInt16(constants.MinimumExponent); + auto maximum_exponent = MlasReinterpretFloat16AsInt16(constants.MaximumExponent); + normal = MlasClampInt16(normal, minimum_exponent, maximum_exponent); + + overflow = MlasSubtractInt16(overflow, normal); + overflow = MlasAddInt16(overflow, maximum_exponent); + normal = MlasAddInt16(normal, maximum_exponent); + + // polynomial approximation + auto p = MlasMultiplyAddFloat16(constants.poly_0, r, constants.poly_1); + p = MlasMultiplyAddFloat16(p, r, constants.poly_2); + p = MlasMultiplyAddFloat16(p, r, constants.poly_3); + p = MlasMultiplyAddFloat16(p, r, constants.poly_4); + p = MlasMultiplyAddFloat16(p, r, constants.poly_56); + + auto overflow_f = MlasReinterpretInt16AsFloat16(overflow); + r = MlasMultiplyFloat16(r, overflow_f); + p = MlasMultiplyAddFloat16(p, r, overflow_f); + p = MlasMultiplyFloat16(p, MlasReinterpretInt16AsFloat16(normal)); + + return p; +} + +void Exp_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N) { + const auto* input = reinterpret_cast(Input); + auto* output = reinterpret_cast<_mlas_fp16_*>(Output); + + while (N >= 32) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + auto v2 = MlasLoadFloat16x8(input + 16); + auto v3 = 
MlasLoadFloat16x8(input + 24); + + auto r0 = Exp_Vector_Fp16(v0); + auto r1 = Exp_Vector_Fp16(v1); + auto r2 = Exp_Vector_Fp16(v2); + auto r3 = Exp_Vector_Fp16(v3); + + MlasStoreFloat16x8(output, r0); + MlasStoreFloat16x8(output + 8, r1); + MlasStoreFloat16x8(output + 16, r2); + MlasStoreFloat16x8(output + 24, r3); + + input += 32; + output += 32; + N -= 32; + } + + if (N & 16) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + + auto r0 = Exp_Vector_Fp16(v0); + auto r1 = Exp_Vector_Fp16(v1); + + MlasStoreFloat16x8(output, r0); + MlasStoreFloat16x8(output + 8, r1); + + input += 16; + output += 16; + N -= 16; + } + + if (N & 8) { + auto v0 = MlasLoadFloat16x8(input); + auto r0 = Exp_Vector_Fp16(v0); + MlasStoreFloat16x8(output, r0); + + input += 8; + output += 8; + N -= 8; + } + + if (N & 4) { + auto v0 = MlasLoadFloat16x4(input); + auto r0 = Exp_Vector_Fp16(v0); + MlasStoreFloat16x4(output, r0); + + input += 4; + output += 4; + N -= 4; + } + + if (N == 3) { + auto v0 = MlasLoadPartialFloat16x4(input, 3); + auto r0 = Exp_Vector_Fp16(v0); + MlasStorePartialFloat16x4(output, r0, 3); + } else if (N == 2) { + auto v0 = MlasLoadPartialFloat16x4(input, 2); + auto r0 = Exp_Vector_Fp16(v0); + MlasStorePartialFloat16x4(output, r0, 2); + } else if (N == 1) { + auto v0 = MlasLoadPartialFloat16x4(input, 1); + auto r0 = Exp_Vector_Fp16(v0); + MlasStorePartialFloat16x4(output, r0, 1); + } +} + +// assume no overflow +template +MLAS_FORCEINLINE +T SumExp_Vector_Fp16(T x, T negative_maximum) { + const auto& constants = Get_Exp_Constants(); + auto clamped_x = MlasMaximumFloat16(MlasAddFloat16(x, negative_maximum), constants.LowerRangeSumExp); + + // integral + auto biased = MlasMultiplyAddFloat16(clamped_x, constants.Log2Reciprocal, constants.RoundingBias); + auto m = MlasSubtractFloat16(biased, constants.RoundingBias); + + // residual + auto r = MlasMultiplyAddFloat16(m, constants.Log2High, clamped_x); + r = MlasMultiplyAddFloat16(m, 
constants.Log2Mid, r); + r = MlasMultiplyAddFloat16(m, constants.Log2Low, r); + + // 2^m + auto normal = MlasShiftLeftInt16<10>(MlasReinterpretFloat16AsInt16(biased)); + normal = MlasAddInt16(normal, MlasReinterpretFloat16AsInt16(constants.MaximumExponent)); + + // polynomial approximation + auto p = MlasMultiplyAddFloat16(constants.poly_0, r, constants.poly_1); + p = MlasMultiplyAddFloat16(p, r, constants.poly_2); + p = MlasMultiplyAddFloat16(p, r, constants.poly_3); + p = MlasMultiplyAddFloat16(p, r, constants.poly_4); + p = MlasMultiplyAddFloat16(p, r, constants.poly_56); + p = MlasMultiplyAddFloat16(p, r, constants.poly_56); + + p = MlasMultiplyFloat16(p, MlasReinterpretInt16AsFloat16(normal)); + + return p; +} + +MLAS_FORCEINLINE +float16x8_t AddUp(float16x8_t v0, float16x8_t v1, float16x8_t v2, float16x8_t v3, float16x8_t v4) { + v0 = MlasAddFloat16(v0, v1); + v2 = MlasAddFloat16(v2, v3); + return MlasAddFloat16(MlasAddFloat16(v0, v2), v4); +} + +MLAS_FORCEINLINE +float16x8_t AddUp(float16x8_t v0, float16x8_t v1, float16x8_t v2) { + return MlasAddFloat16(MlasAddFloat16(v0, v1), v2); +} + +MLAS_FP16 SumExp_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, const MLAS_FP16 NegativeMaximum) { + const auto* input = reinterpret_cast(Input); + auto* output = reinterpret_cast<_mlas_fp16_*>(Output); + float16x8_t negative_maximum8 = MlasBroadcastFloat16x8(NegativeMaximum.val); + float16x4_t negative_maximum4 = MlasBroadcastFloat16x4(NegativeMaximum.val); + const bool store_output = Output != nullptr; + float16x8_t accumulator8 = MlasZeroFloat16x8(); + float16x4_t accumulator4 = MlasZeroFloat16x4(); + + while (N >= 32) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + auto v2 = MlasLoadFloat16x8(input + 16); + auto v3 = MlasLoadFloat16x8(input + 24); + + auto r0 = SumExp_Vector_Fp16(v0, negative_maximum8); + auto r1 = SumExp_Vector_Fp16(v1, negative_maximum8); + auto r2 = SumExp_Vector_Fp16(v2, negative_maximum8); + auto 
r3 = SumExp_Vector_Fp16(v3, negative_maximum8); + + accumulator8 = AddUp(r0, r1, r2, r3, accumulator8); + + if (store_output) { + MlasStoreFloat16x8(output, r0); + MlasStoreFloat16x8(output + 8, r1); + MlasStoreFloat16x8(output + 16, r2); + MlasStoreFloat16x8(output + 24, r3); + output += 32; + } + + input += 32; + N -= 32; + } + + if (N & 16) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + + auto r0 = SumExp_Vector_Fp16(v0, negative_maximum8); + auto r1 = SumExp_Vector_Fp16(v1, negative_maximum8); + + accumulator8 = AddUp(r0, r1, accumulator8); + + if (store_output) { + MlasStoreFloat16x8(output, r0); + MlasStoreFloat16x8(output + 8, r1); + output += 16; + } + + input += 16; + N -= 16; + } + + if (N & 8) { + auto v0 = MlasLoadFloat16x8(input); + auto r0 = SumExp_Vector_Fp16(v0, negative_maximum8); + accumulator8 = MlasAddFloat16(r0, accumulator8); + + if (store_output) { + MlasStoreFloat16x8(output, r0); + output += 8; + } + + input += 8; + N -= 8; + } + + if (N & 4) { + auto v0 = MlasLoadFloat16x4(input); + auto r0 = SumExp_Vector_Fp16(v0, negative_maximum4); + accumulator4 = MlasAddFloat16(r0, accumulator4); + + if (store_output) { + MlasStoreFloat16x4(output, r0); + output += 4; + } + + input += 4; + N -= 4; + } + + if (N == 3) { + auto v0 = MlasLoadPartialFloat16x4(input, 3); + auto r0 = SumExp_Vector_Fp16(v0, negative_maximum4); + + if (store_output) { + MlasStorePartialFloat16x4(output, r0, 3); + } + + r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0), MlasReinterpretFloat16AsInt16(r0), 3)); + accumulator4 = MlasAddFloat16(r0, accumulator4); + } else if (N == 2) { + auto v0 = MlasLoadPartialFloat16x4(input, 2); + auto r0 = SumExp_Vector_Fp16(v0, negative_maximum4); + + if (store_output) { + MlasStorePartialFloat16x4(output, r0, 2); + } + + r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0), MlasReinterpretFloat16AsInt16(r0), 3)); + r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0), 
MlasReinterpretFloat16AsInt16(r0), 2)); + accumulator4 = MlasAddFloat16(r0, accumulator4); + } else if (N == 1) { + auto v0 = MlasLoadPartialFloat16x4(input, 1); + auto r0 = SumExp_Vector_Fp16(v0, negative_maximum4); + + if (store_output) { + MlasStorePartialFloat16x4(output, r0, 1); + } + + r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0), MlasReinterpretFloat16AsInt16(r0), 3)); + r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0), MlasReinterpretFloat16AsInt16(r0), 2)); + r0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0), MlasReinterpretFloat16AsInt16(r0), 1)); + accumulator4 = MlasAddFloat16(r0, accumulator4); + } + + auto t = MlasAddFloat16(vget_low_f16(accumulator8), vget_high_f16(accumulator8)); + t = MlasAddFloat16(t, accumulator4); + _mlas_fp16_ result = MlasReduceAddFloat16(t); + return MLAS_FP16::FromBits(result); +} + +template +struct MlasTanhConstants { + T LowerRange; + T UpperRange; + T alpha_7; + T alpha_5; + T alpha_3; + T alpha_1; + T beta_6; + T beta_4; + T beta_2; + T beta_0; +}; + +constexpr MlasTanhConstants<_mlas_fp16_> TanhConstantsFp16 = { + 0xc308, // -3.51562 + 0x4308, // 3.51562 + 0x0001, + 0x00f9, + 0x1138, + 0x1d03, + 0x0014, + 0x07c5, + 0x18a5, + 0x1d03, +}; + +template +MLAS_FORCEINLINE +const MlasTanhConstants& Get_Tanh_Constants(); + +template <> +MLAS_FORCEINLINE +const MlasTanhConstants& Get_Tanh_Constants() { + const static MlasTanhConstants TanhConstantsFp16x8 = { + MlasBroadcastFloat16x8(TanhConstantsFp16.LowerRange), + MlasBroadcastFloat16x8(TanhConstantsFp16.UpperRange), + MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_7), + MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_5), + MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_3), + MlasBroadcastFloat16x8(TanhConstantsFp16.alpha_1), + MlasBroadcastFloat16x8(TanhConstantsFp16.beta_6), + MlasBroadcastFloat16x8(TanhConstantsFp16.beta_4), + MlasBroadcastFloat16x8(TanhConstantsFp16.beta_2), + MlasBroadcastFloat16x8(TanhConstantsFp16.beta_0), 
+ }; + return TanhConstantsFp16x8; +} + +template <> +MLAS_FORCEINLINE +const MlasTanhConstants& Get_Tanh_Constants() { + const static MlasTanhConstants TanhConstantsFp16x4 = { + MlasBroadcastFloat16x4(TanhConstantsFp16.LowerRange), + MlasBroadcastFloat16x4(TanhConstantsFp16.UpperRange), + MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_7), + MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_5), + MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_3), + MlasBroadcastFloat16x4(TanhConstantsFp16.alpha_1), + MlasBroadcastFloat16x4(TanhConstantsFp16.beta_6), + MlasBroadcastFloat16x4(TanhConstantsFp16.beta_4), + MlasBroadcastFloat16x4(TanhConstantsFp16.beta_2), + MlasBroadcastFloat16x4(TanhConstantsFp16.beta_0), + }; + return TanhConstantsFp16x4; +} + +// TODO(fajin): optimize polynomial coefficients +template +MLAS_FORCEINLINE +T Tanh_Vector_Fp16(T x) { + const auto& constants = Get_Tanh_Constants(); + x = MlasClampFloat16(x, constants.LowerRange, constants.UpperRange); + + T x_2 = MlasMultiplyFloat16(x, x); + + T p = MlasMultiplyAddFloat16(constants.alpha_7, x_2, constants.alpha_5); + p = MlasMultiplyAddFloat16(p, x_2, constants.alpha_3); + p = MlasMultiplyAddFloat16(p, x_2, constants.alpha_1); + p = MlasMultiplyFloat16(p, x); + + T q = MlasMultiplyAddFloat16(constants.beta_6, x_2, constants.beta_4); + q = MlasMultiplyAddFloat16(q, x_2, constants.beta_2); + q = MlasMultiplyAddFloat16(q, x_2, constants.beta_0); + + return MlasDivideFloat16(p, q); +} + +void Tanh_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N) { + const auto* input = reinterpret_cast(Input); + auto* output = reinterpret_cast<_mlas_fp16_*>(Output); + + while (N >= 32) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + auto v2 = MlasLoadFloat16x8(input + 16); + auto v3 = MlasLoadFloat16x8(input + 24); + + auto r0 = Tanh_Vector_Fp16(v0); + auto r1 = Tanh_Vector_Fp16(v1); + auto r2 = Tanh_Vector_Fp16(v2); + auto r3 = Tanh_Vector_Fp16(v3); + + 
MlasStoreFloat16x8(output, r0); + MlasStoreFloat16x8(output + 8, r1); + MlasStoreFloat16x8(output + 16, r2); + MlasStoreFloat16x8(output + 24, r3); + + input += 32; + output += 32; + N -= 32; + } + + if (N & 16) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + + auto r0 = Tanh_Vector_Fp16(v0); + auto r1 = Tanh_Vector_Fp16(v1); + + MlasStoreFloat16x8(output, r0); + MlasStoreFloat16x8(output + 8, r1); + + input += 16; + output += 16; + N -= 16; + } + + if (N & 8) { + auto v0 = MlasLoadFloat16x8(input); + auto r0 = Tanh_Vector_Fp16(v0); + MlasStoreFloat16x8(output, r0); + + input += 8; + output += 8; + N -= 8; + } + + if (N & 4) { + auto v0 = MlasLoadFloat16x4(input); + auto r0 = Tanh_Vector_Fp16(v0); + MlasStoreFloat16x4(output, r0); + + input += 4; + output += 4; + N -= 4; + } + + if (N == 3) { + auto v0 = MlasLoadPartialFloat16x4(input, 3); + auto r0 = Tanh_Vector_Fp16(v0); + MlasStorePartialFloat16x4(output, r0, 3); + } else if (N == 2) { + auto v0 = MlasLoadPartialFloat16x4(input, 2); + auto r0 = Tanh_Vector_Fp16(v0); + MlasStorePartialFloat16x4(output, r0, 2); + } else if (N == 1) { + auto v0 = MlasLoadPartialFloat16x4(input, 1); + auto r0 = Tanh_Vector_Fp16(v0); + MlasStorePartialFloat16x4(output, r0, 1); + } +} + +void Softcap_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, const MLAS_FP16 Softcap) { + const auto* input = reinterpret_cast(Input); + auto* output = reinterpret_cast<_mlas_fp16_*>(Output); + auto softcap8 = MlasBroadcastFloat16x8(Softcap.val); + auto softcap4 = MlasBroadcastFloat16x4(Softcap.val); + auto one8 = MlasBroadcastFloat16x8((_mlas_fp16_)0x3c00); + auto one4 = MlasBroadcastFloat16x4((_mlas_fp16_)0x3c00); + auto softcap_reciprocal8 = MlasDivideFloat16(one8, softcap8); + auto softcap_reciprocal4 = MlasDivideFloat16(one4, softcap4); + + while (N >= 32) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + auto v2 = MlasLoadFloat16x8(input + 16); + auto v3 = 
MlasLoadFloat16x8(input + 24); + + v0 = MlasMultiplyFloat16(v0, softcap_reciprocal8); + v1 = MlasMultiplyFloat16(v1, softcap_reciprocal8); + v2 = MlasMultiplyFloat16(v2, softcap_reciprocal8); + v3 = MlasMultiplyFloat16(v3, softcap_reciprocal8); + + v0 = Tanh_Vector_Fp16(v0); + v1 = Tanh_Vector_Fp16(v1); + v2 = Tanh_Vector_Fp16(v2); + v3 = Tanh_Vector_Fp16(v3); + + v0 = MlasMultiplyFloat16(v0, softcap8); + v1 = MlasMultiplyFloat16(v1, softcap8); + v2 = MlasMultiplyFloat16(v2, softcap8); + v3 = MlasMultiplyFloat16(v3, softcap8); + + MlasStoreFloat16x8(output, v0); + MlasStoreFloat16x8(output + 8, v1); + MlasStoreFloat16x8(output + 16, v2); + MlasStoreFloat16x8(output + 24, v3); + + input += 32; + output += 32; + N -= 32; + } + + if (N & 16) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + + v0 = MlasMultiplyFloat16(v0, softcap_reciprocal8); + v1 = MlasMultiplyFloat16(v1, softcap_reciprocal8); + + v0 = Tanh_Vector_Fp16(v0); + v1 = Tanh_Vector_Fp16(v1); + + v0 = MlasMultiplyFloat16(v0, softcap8); + v1 = MlasMultiplyFloat16(v1, softcap8); + + MlasStoreFloat16x8(output, v0); + MlasStoreFloat16x8(output + 8, v1); + + input += 16; + output += 16; + N -= 16; + } + + if (N & 8) { + auto v0 = MlasLoadFloat16x8(input); + v0 = MlasMultiplyFloat16(v0, softcap_reciprocal8); + v0 = Tanh_Vector_Fp16(v0); + v0 = MlasMultiplyFloat16(v0, softcap8); + MlasStoreFloat16x8(output, v0); + + input += 8; + output += 8; + N -= 8; + } + + if (N & 4) { + auto v0 = MlasLoadFloat16x4(input); + v0 = MlasMultiplyFloat16(v0, softcap_reciprocal4); + v0 = Tanh_Vector_Fp16(v0); + v0 = MlasMultiplyFloat16(v0, softcap4); + MlasStoreFloat16x4(output, v0); + + input += 4; + output += 4; + N -= 4; + } + + if (N == 3) { + auto v0 = MlasLoadPartialFloat16x4(input, 3); + v0 = MlasMultiplyFloat16(v0, softcap_reciprocal4); + v0 = Tanh_Vector_Fp16(v0); + v0 = MlasMultiplyFloat16(v0, softcap4); + MlasStorePartialFloat16x4(output, v0, 3); + } else if (N == 2) { + auto v0 = 
MlasLoadPartialFloat16x4(input, 2); + v0 = MlasMultiplyFloat16(v0, softcap_reciprocal4); + v0 = Tanh_Vector_Fp16(v0); + v0 = MlasMultiplyFloat16(v0, softcap4); + MlasStorePartialFloat16x4(output, v0, 2); + } else if (N == 1) { + auto v0 = MlasLoadPartialFloat16x4(input, 1); + v0 = MlasMultiplyFloat16(v0, softcap_reciprocal4); + v0 = Tanh_Vector_Fp16(v0); + v0 = MlasMultiplyFloat16(v0, softcap4); + MlasStorePartialFloat16x4(output, v0, 1); + } +} + +MLAS_FP16 ReduceMax_Kernel_Fp16(const MLAS_FP16* Input, size_t N) { + const auto* input = reinterpret_cast(Input); + auto max8 = MlasBroadcastFloat16x8((_mlas_fp16_)0xfbff); + auto max4 = MlasBroadcastFloat16x4((_mlas_fp16_)0xfbff); + + while (N >= 32) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + auto v2 = MlasLoadFloat16x8(input + 16); + auto v3 = MlasLoadFloat16x8(input + 24); + + v0 = MlasMaximumFloat16(v0, v1); + v2 = MlasMaximumFloat16(v2, v3); + v0 = MlasMaximumFloat16(v0, v2); + max8 = MlasMaximumFloat16(max8, v0); + + input += 32; + N -= 32; + } + + if (N & 16) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + + v0 = MlasMaximumFloat16(v0, v1); + max8 = MlasMaximumFloat16(max8, v0); + + input += 16; + N -= 16; + } + + if (N & 8) { + auto v0 = MlasLoadFloat16x8(input); + max8 = MlasMaximumFloat16(max8, v0); + + input += 8; + N -= 8; + } + + if (N & 4) { + auto v0 = MlasLoadFloat16x4(input); + max4 = MlasMaximumFloat16(max4, v0); + + input += 4; + N -= 4; + } + + if (N == 3) { + auto v0 = MlasLoadPartialFloat16x4(input, 3); + v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0xfbff), MlasReinterpretFloat16AsInt16(v0), 3)); + max4 = MlasMaximumFloat16(max4, v0); + } else if (N == 2) { + auto v0 = MlasLoadPartialFloat16x4(input, 2); + v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0xfbff), MlasReinterpretFloat16AsInt16(v0), 3)); + v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0xfbff), 
MlasReinterpretFloat16AsInt16(v0), 2)); + max4 = MlasMaximumFloat16(max4, v0); + } else if (N == 1) { + auto v0 = MlasLoadPartialFloat16x4(input, 1); + v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0xfbff), MlasReinterpretFloat16AsInt16(v0), 3)); + v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0xfbff), MlasReinterpretFloat16AsInt16(v0), 2)); + v0 = MlasReinterpretInt16AsFloat16(vset_lane_s16(static_cast(0xfbff), MlasReinterpretFloat16AsInt16(v0), 1)); + max4 = MlasMaximumFloat16(max4, v0); + } + + auto t = MlasMaximumFloat16(vget_low_f16(max8), vget_high_f16(max8)); + t = MlasMaximumFloat16(t, max4); + _mlas_fp16_ result = MlasReduceMaximumFloat16(t); + + return MLAS_FP16::FromBits(result); +} + +void Softmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, const MLAS_FP16 Sum) { + const auto* input = reinterpret_cast(Input); + auto* output = reinterpret_cast<_mlas_fp16_*>(Output); + auto sum8 = MlasBroadcastFloat16x8(Sum.val); + auto sum4 = MlasBroadcastFloat16x4(Sum.val); + auto scale8 = MlasDivideFloat16(MlasBroadcastFloat16x8((_mlas_fp16_)0x3c00), sum8); + auto scale4 = MlasDivideFloat16(MlasBroadcastFloat16x4((_mlas_fp16_)0x3c00), sum4); + + while (N >= 32) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + auto v2 = MlasLoadFloat16x8(input + 16); + auto v3 = MlasLoadFloat16x8(input + 24); + + v0 = MlasMultiplyFloat16(v0, scale8); + v1 = MlasMultiplyFloat16(v1, scale8); + v2 = MlasMultiplyFloat16(v2, scale8); + v3 = MlasMultiplyFloat16(v3, scale8); + + MlasStoreFloat16x8(output, v0); + MlasStoreFloat16x8(output + 8, v1); + MlasStoreFloat16x8(output + 16, v2); + MlasStoreFloat16x8(output + 24, v3); + + input += 32; + output += 32; + N -= 32; + } + + if (N & 16) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + + v0 = MlasMultiplyFloat16(v0, scale8); + v1 = MlasMultiplyFloat16(v1, scale8); + + MlasStoreFloat16x8(output, v0); + 
MlasStoreFloat16x8(output + 8, v1); + + input += 16; + output += 16; + N -= 16; + } + + if (N & 8) { + auto v0 = MlasLoadFloat16x8(input); + v0 = MlasMultiplyFloat16(v0, scale8); + MlasStoreFloat16x8(output, v0); + + input += 8; + output += 8; + N -= 8; + } + + if (N & 4) { + auto v0 = MlasLoadFloat16x4(input); + v0 = MlasMultiplyFloat16(v0, scale4); + MlasStoreFloat16x4(output, v0); + + input += 4; + output += 4; + N -= 4; + } + + if (N == 3) { + auto v0 = MlasLoadPartialFloat16x4(input, 3); + v0 = MlasMultiplyFloat16(v0, scale4); + MlasStorePartialFloat16x4(output, v0, 3); + } else if (N == 2) { + auto v0 = MlasLoadPartialFloat16x4(input, 2); + v0 = MlasMultiplyFloat16(v0, scale4); + MlasStorePartialFloat16x4(output, v0, 2); + } else if (N == 1) { + auto v0 = MlasLoadPartialFloat16x4(input, 1); + v0 = MlasMultiplyFloat16(v0, scale4); + MlasStorePartialFloat16x4(output, v0, 1); + } +} + +void LogSoftmax_Kernel_Fp16(const MLAS_FP16* Input, MLAS_FP16* Output, size_t N, const MLAS_FP16 NegativeMaximum, const MLAS_FP16 LogSum) { + const auto* input = reinterpret_cast(Input); + auto* output = reinterpret_cast<_mlas_fp16_*>(Output); + auto negative_maximum8 = MlasBroadcastFloat16x8(NegativeMaximum.val); + auto negative_maximum4 = MlasBroadcastFloat16x4(NegativeMaximum.val); + auto log_sum8 = MlasBroadcastFloat16x8(LogSum.val); + auto log_sum4 = MlasBroadcastFloat16x4(LogSum.val); + + while (N >= 32) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + auto v2 = MlasLoadFloat16x8(input + 16); + auto v3 = MlasLoadFloat16x8(input + 24); + + v0 = MlasAddFloat16(v0, negative_maximum8); + v1 = MlasAddFloat16(v1, negative_maximum8); + v2 = MlasAddFloat16(v2, negative_maximum8); + v3 = MlasAddFloat16(v3, negative_maximum8); + + v0 = MlasSubtractFloat16(v0, log_sum8); + v1 = MlasSubtractFloat16(v1, log_sum8); + v2 = MlasSubtractFloat16(v2, log_sum8); + v3 = MlasSubtractFloat16(v3, log_sum8); + + MlasStoreFloat16x8(output, v0); + 
MlasStoreFloat16x8(output + 8, v1); + MlasStoreFloat16x8(output + 16, v2); + MlasStoreFloat16x8(output + 24, v3); + + input += 32; + output += 32; + N -= 32; + } + + if (N & 16) { + auto v0 = MlasLoadFloat16x8(input); + auto v1 = MlasLoadFloat16x8(input + 8); + + v0 = MlasAddFloat16(v0, negative_maximum8); + v1 = MlasAddFloat16(v1, negative_maximum8); + + v0 = MlasSubtractFloat16(v0, log_sum8); + v1 = MlasSubtractFloat16(v1, log_sum8); + + MlasStoreFloat16x8(output, v0); + MlasStoreFloat16x8(output + 8, v1); + + input += 16; + output += 16; + N -= 16; + } + + if (N & 8) { + auto v0 = MlasLoadFloat16x8(input); + v0 = MlasAddFloat16(v0, negative_maximum8); + v0 = MlasSubtractFloat16(v0, log_sum8); + MlasStoreFloat16x8(output, v0); + + input += 8; + output += 8; + N -= 8; + } + + if (N & 4) { + auto v0 = MlasLoadFloat16x4(input); + v0 = MlasAddFloat16(v0, negative_maximum4); + v0 = MlasSubtractFloat16(v0, log_sum4); + MlasStoreFloat16x4(output, v0); + + input += 4; + output += 4; + N -= 4; + } + + if (N == 3) { + auto v0 = MlasLoadPartialFloat16x4(input, 3); + v0 = MlasAddFloat16(v0, negative_maximum4); + v0 = MlasSubtractFloat16(v0, log_sum4); + MlasStorePartialFloat16x4(output, v0, 3); + } else if (N == 2) { + auto v0 = MlasLoadPartialFloat16x4(input, 2); + v0 = MlasAddFloat16(v0, negative_maximum4); + v0 = MlasSubtractFloat16(v0, log_sum4); + MlasStorePartialFloat16x4(output, v0, 2); + } else if (N == 1) { + auto v0 = MlasLoadPartialFloat16x4(input, 1); + v0 = MlasAddFloat16(v0, negative_maximum4); + v0 = MlasSubtractFloat16(v0, log_sum4); + MlasStorePartialFloat16x4(output, v0, 1); + } +} + +} // namespace rope_neon diff --git a/src/lib/sqnbitgemm_kernel_avx2.cpp b/src/lib/sqnbitgemm_kernel_avx2.cpp index 5d42f22..5f80a81 100644 --- a/src/lib/sqnbitgemm_kernel_avx2.cpp +++ b/src/lib/sqnbitgemm_kernel_avx2.cpp @@ -585,6 +585,88 @@ SQ4BitGemmKernel_BlkSum_CompInt8_avx2( return CountM; } +template +MLAS_FORCEINLINE +size_t +SQ8BitGemmKernel_BlkSum_CompInt8_avx2( + 
const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* /*QuantBZeroPoint*/, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum +) +{ + if (BlkLen == 16) { + MlasQ8Int8GemmKernelBlkLen16Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + CountK, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 32) { + MlasQ8Int8GemmKernelBlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + CountK, + BlockCountK, + Bias, + ldc + ); + } else { + MlasQ8Int8GemmKernelBlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } + + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); + + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + return CountM; +} + size_t SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni( const size_t BlkLen, @@ -1310,9 +1392,9 @@ SQ4BitGemmPackQuantBDataAndBlkSum( MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, - bool has_zp_input, + bool HasZeroPoint, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& PackedQuantB, MLAS_THREADPOOL* ThreadPool ) { @@ -1325,7 +1407,34 @@ SQ4BitGemmPackQuantBDataAndBlkSum( if (BlkLen == 32 && ComputeType == SQNBIT_CompInt8) { SubBlkLen = 64; } - PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, 
QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); + PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, + HasZeroPoint, QuantBZPBegin, PackedQuantB, ThreadPool); +} + +static void +SQ8BitGemmPackQuantBDataAndBlkSum( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + if (ComputeType == SQNBIT_CompInt8) { + SubBlkLen = 64; + } + Q8PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, + HasZeroPoint, QuantBZPBegin, PackedQuantB, ThreadPool); } // @@ -1334,17 +1443,20 @@ SQ4BitGemmPackQuantBDataAndBlkSum( const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { MLAS_QNBIT_GEMM_DISPATCH d; - d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<4>; + d.Q8BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<8>; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; + d.SQ8BitGemmPackQuantBDataAndBlkSum = SQ8BitGemmPackQuantBDataAndBlkSum; - d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; - d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; + d.QNBitGemmPerGemmWorkspaceSize = QNBitGemmPerGemmWorkspaceSize; + d.QNBitGemmPerGemmWorkspaceAlignment = QNBitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = 
SQ4BitGemmKernel_BlkSum_CompInt8_avx2; + d.SQ8BitGemmKernel_BlkSum_CompInt8 = SQ8BitGemmKernel_BlkSum_CompInt8_avx2; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; return d; @@ -1353,17 +1465,20 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() { const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() { MLAS_QNBIT_GEMM_DISPATCH d; - d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<4>; + d.Q8BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<8>; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum; + d.SQ8BitGemmPackQuantBDataAndBlkSum = SQ8BitGemmPackQuantBDataAndBlkSum; - d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; - d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; + d.QNBitGemmPerGemmWorkspaceSize = QNBitGemmPerGemmWorkspaceSize; + d.QNBitGemmPerGemmWorkspaceAlignment = QNBitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2; d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni; + d.SQ8BitGemmKernel_BlkSum_CompInt8 = SQ8BitGemmKernel_BlkSum_CompInt8_avx2; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2; return d; diff --git a/src/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h b/src/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h index 445ead3..aec5dc9 100644 --- a/src/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h +++ b/src/lib/sqnbitgemm_kernel_avx2_int8_blklen16.h @@ -6,6 +6,13 @@ #include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" +MLAS_DECLSPEC_ALIGN(static const uint32_t MasksAvx2BlkLen16[40], 32) = { + 0x00000000, 0x00000000, 0x00000002, 0x00000002, 0x00000001, 0x00000001, 0x00000003, 0x00000003, + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 
0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, + 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, + 0x00000000, 0x00000000, 0x00000000, 0x00000000, 0x00000001, 0x00000001, 0x00000001, 0x00000001 +}; MLAS_FORCEINLINE __m256 load_and_broadcast_4_scale_2(const float* scale) @@ -152,6 +159,208 @@ accumulate_blklen16_r2c1blk4_avx2( scale_a0, scale_a1, scale_b, acc0, acc1); } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen16_r1c1blk4_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m256& acc0 +) +{ + // 00000000 00000000, 11111111 11111111 + const __m256i bv0_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + // 22222222 22222222, 33333333 33333333 + const __m256i bv1_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr + 32)); + // 00 22, 11 33 + const __m256i scale_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16)); + // 0123, 0123 + __m256 scale_b_4_ps = _mm256_broadcast_ps((const __m128*)scale_b); + __m256 scale_a0_4_ps = _mm256_broadcast_ps((const __m128*)scale_a0); + __m256 scale_a0b_4_ps = _mm256_mul_ps(scale_b_4_ps, scale_a0_4_ps); + __m256 scale_a0b_4_shuffle_ps = _mm256_permutevar_ps(scale_a0b_4_ps, scale_mask); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + // 0000, 1111 + const __m256i dot00_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + // 2222, 3333 + const __m256i dot01_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); + // 0022, 1133 + const __m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); + __m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale_a0b_4_shuffle_ps, acc0); + } + else +#endif + { + // 
2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16 + 8)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16 + 16)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16 + 24)); + + // 00000000 00000000, 11111111 11111111 + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + const __m256i bv1_low_32_epi8 = _mm256_and_si256(bv1_32_epi8, low_mask); + const __m256i bv1_high_32_epi8 = _mm256_and_si256(bv1_32_epi8, high_mask); + // 0000 0000, 1111 1111 + const __m256i dot00_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + const __m256i dot01_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av01_32_epi8); + const __m256i dot01_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av01_32_epi8); + // 00 00, 11 11 + const __m256i dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i dot00_8_epi32 = _mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + // 22 22, 33 33 + const __m256i dot01_low_8_epi32 = _mm256_madd_epi16(one_mask, dot01_low_16_epi16); + const __m256i dot01_high_8_epi32 = _mm256_madd_epi16(one_mask, dot01_high_16_epi16); + const __m256i dot01_8_epi32 = _mm256_add_epi32(dot01_low_8_epi32, dot01_high_8_epi32); + // 00 22, 11 33 + const __m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); + __m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale_a0b_4_shuffle_ps, acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen16_r2c1blk4_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + 
const __m256i& av11_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m256& acc0, + __m256& acc1 +) +{ + // 00000000 00000000, 11111111 11111111 + const __m256i bv0_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + // 22222222 22222222, 33333333 33333333 + const __m256i bv1_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr + 32)); + // 00 22, 11 33 + const __m256i scale_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16)); + // 0123, 0123 + __m256 scale_b_4_ps = _mm256_broadcast_ps((const __m128*)scale_b); + __m256 scale_a0_4_ps = _mm256_broadcast_ps((const __m128*)scale_a0); + __m256 scale_a0b_4_ps = _mm256_mul_ps(scale_b_4_ps, scale_a0_4_ps); + __m256 scale_a0b_4_shuffle_ps = _mm256_permutevar_ps(scale_a0b_4_ps, scale_mask); + __m256 scale_a1_4_ps = _mm256_broadcast_ps((const __m128*)scale_a1); + __m256 scale_a1b_4_ps = _mm256_mul_ps(scale_b_4_ps, scale_a1_4_ps); + __m256 scale_a1b_4_shuffle_ps = _mm256_permutevar_ps(scale_a1b_4_ps, scale_mask); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + // 0000, 1111 + const __m256i dot00_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + // 2222, 3333 + const __m256i dot01_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); + // 0022, 1133 + const __m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); + __m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale_a0b_4_shuffle_ps, acc0); + + const __m256i dot10_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av10_32_epi8); + const __m256i dot11_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av11_32_epi8); + const __m256i sum1_8_epi32 = _mm256_hadd_epi32(dot10_8_epi32, dot11_8_epi32); + __m256 sum1_8_ps = _mm256_cvtepi32_ps(sum1_8_epi32); + acc1 = _mm256_fmadd_ps(sum1_8_ps, 
scale_a1b_4_shuffle_ps, acc1); + } + else +#endif + { + // 2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16 + 8)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16 + 16)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16 + 24)); + + // 00000000 00000000, 11111111 11111111 + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + const __m256i bv1_low_32_epi8 = _mm256_and_si256(bv1_32_epi8, low_mask); + const __m256i bv1_high_32_epi8 = _mm256_and_si256(bv1_32_epi8, high_mask); + + // row 0 + // 0000 0000, 1111 1111 + const __m256i dot00_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + const __m256i dot01_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av01_32_epi8); + const __m256i dot01_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av01_32_epi8); + // 00 00, 11 11 + const __m256i dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i dot00_8_epi32 = _mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + // 22 22, 33 33 + const __m256i dot01_low_8_epi32 = _mm256_madd_epi16(one_mask, dot01_low_16_epi16); + const __m256i dot01_high_8_epi32 = _mm256_madd_epi16(one_mask, dot01_high_16_epi16); + const __m256i dot01_8_epi32 = _mm256_add_epi32(dot01_low_8_epi32, dot01_high_8_epi32); + // 00 22, 11 33 + const __m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); + __m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale_a0b_4_shuffle_ps, acc0); + + // row 1 + const __m256i dot10_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av10_32_epi8); + const 
__m256i dot10_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av10_32_epi8); + const __m256i dot11_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av11_32_epi8); + const __m256i dot11_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av11_32_epi8); + + const __m256i dot10_low_8_epi32 = _mm256_madd_epi16(one_mask, dot10_low_16_epi16); + const __m256i dot10_high_8_epi32 = _mm256_madd_epi16(one_mask, dot10_high_16_epi16); + const __m256i dot10_8_epi32 = _mm256_add_epi32(dot10_low_8_epi32, dot10_high_8_epi32); + + const __m256i dot11_low_8_epi32 = _mm256_madd_epi16(one_mask, dot11_low_16_epi16); + const __m256i dot11_high_8_epi32 = _mm256_madd_epi16(one_mask, dot11_high_16_epi16); + const __m256i dot11_8_epi32 = _mm256_add_epi32(dot11_low_8_epi32, dot11_high_8_epi32); + + const __m256i sum1_8_epi32 = _mm256_hadd_epi32(dot10_8_epi32, dot11_8_epi32); + __m256 sum1_8_ps = _mm256_cvtepi32_ps(sum1_8_epi32); + acc1 = _mm256_fmadd_ps(sum1_8_ps, scale_a1b_4_shuffle_ps, acc1); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen16_r1c1blk1_avx2( + const __m128i& av00_16_epi8, + const std::byte* QuantBDataPtr, + float scale_a0b, + __m256& acc0 +) +{ + const __m128i bv0_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantBDataPtr)); + __m256 scale_a0b_1_ps = _mm256_set1_ps(scale_a0b); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + const __m128i dot00_4_epi32 = _mm_dpbusds_avx_epi32(_mm_setzero_si128(), bv0_16_epi8, av00_16_epi8); + const __m256i dot00_8_epi32 = _mm256_cvtepu32_epi64(dot00_4_epi32); + __m256 sum0_8_ps = _mm256_cvtepi32_ps(dot00_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale_a0b_1_ps, acc0); + } + else +#endif + { + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen16 + 24)); + const __m256i bv0_32_epi8 = _mm256_cvtepu8_epi16(bv0_16_epi8); + const __m256i av00_32_epi8 = _mm256_cvtepu8_epi16(av00_16_epi8); + const __m256i dot00_16_epi16 = _mm256_maddubs_epi16(bv0_32_epi8, 
av00_32_epi8); + const __m256i dot00_8_epi32 = _mm256_madd_epi16(one_mask, dot00_16_epi16); + __m256 sum0_8_ps = _mm256_cvtepi32_ps(dot00_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale_a0b_1_ps, acc0); + } +} + static MLAS_FORCEINLINE void accumulate_blklen16_r1c1blk4_avx2( const __m256i& av0_32_epi8, @@ -332,6 +541,118 @@ Q4Int8GemmR2xC4BlkLen16Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBDataCol = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBData4 = BlkDataSizeInBytes * PerAccuBlk4; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 3; k_blks_remaining -= PerAccuBlk4) { + const __m256i av_00_epi8 = 
_mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda + 32)); + + accumulate_q8_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_q8_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData4, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk4, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData4, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk4, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData4, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk4, acc[3], acc[NCols4 + 3]); + + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk4 * NCols4; + QuantBScalePtr += PerAccuBlk4 * NCols4; + } + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m128i av_00_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr)); + const __m128i av_10_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr + lda)); + const float scale_a00 = *QuantAScalePtr; + const float scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float scale_b0 = *QuantBScalePtr; + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_a00 * scale_b0, acc[0]); + accumulate_q8_blklen16_r1c1blk1_avx2(av_10_epi8, QuantBDataPtr, scale_a10 * scale_b0, acc[NCols4]); + + const float scale_b1 = *(QuantBScalePtr + 1); + 
accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_a00 * scale_b1, acc[1]); + accumulate_q8_blklen16_r1c1blk1_avx2(av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_a10 * scale_b1, acc[NCols4 + 1]); + + const float scale_b2 = *(QuantBScalePtr + 2); + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_a00 * scale_b2, acc[2]); + accumulate_q8_blklen16_r1c1blk1_avx2(av_10_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_a10 * scale_b2, acc[NCols4 + 2]); + + const float scale_b3 = *(QuantBScalePtr + 3); + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_a00 * scale_b3, acc[3]); + accumulate_q8_blklen16_r1c1blk1_avx2(av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_a10 * scale_b3, acc[NCols4 + 3]); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr+= NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen16Avx2( const std::byte* QuantA, const float* QuantAScale, @@ -437,6 +758,108 @@ void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen16Avx2( } } +template +void MLAS_FORCEINLINE +Q8Int8GemmR2xC1BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + + // process 4 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const std::byte* QuantABlk00 = QuantAPtr; + const std::byte* QuantABlk01 = QuantABlk00 + 32; + const std::byte* QuantABlk10 = QuantAPtr + lda; + const std::byte* QuantABlk11 = QuantABlk10 + 32; + + // load A: + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk00); 
+ const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk01); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk10); + const __m256i av_11_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk11); + + accumulate_q8_blklen16_r2c1blk4_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + for (; k_blks_remaining > 0; --k_blks_remaining) { + // load A + const std::byte* QuantABlk0 = QuantAPtr; + const __m128i av0_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantABlk0)); + const __m128i av1_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantABlk0 + lda)); + const float scale_a00 = *QuantAScalePtr; + const float scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float scale_b0 = *QuantBScalePtr; + accumulate_q8_blklen16_r1c1blk1_avx2(av0_16_epi8, QuantBDataPtr, scale_a00 * scale_b0, acc0); + accumulate_q8_blklen16_r1c1blk1_avx2(av1_16_epi8, QuantBDataPtr, scale_a10 * scale_b0, acc1); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + MLAS_FORCEINLINE void Q4Int8GemmR1xC4BlkLen16Avx2( const std::byte* QuantA, @@ -549,6 +972,106 @@ Q4Int8GemmR1xC4BlkLen16Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBDataCol = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + const size_t StrideQuantBData4 = BlkDataSizeInBytes * PerAccuBlk4; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + accumulate_q8_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, 
QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData4, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_q8_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData4, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_q8_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData4, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += StrideQuantBData4 * NCols4; + QuantBScalePtr += PerAccuBlk4 * NCols4; + } + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m128i av_00_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr)); + const float scale_a00 = *QuantAScalePtr; + + const float scale_b0 = *QuantBScalePtr; + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_a00 * scale_b0, acc[0]); + + const float scale_b1 = *(QuantBScalePtr + 1); + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_a00 * scale_b1, acc[1]); + + const float scale_b2 = *(QuantBScalePtr + 2); + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_a00 * scale_b2, acc[2]); + + const float scale_b3 = *(QuantBScalePtr + 3); + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_a00 * scale_b3, acc[3]); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen16Avx2( const std::byte* QuantA, @@ -634,6 +1157,90 @@ Q4Int8GemmR1xC1BlkLen16Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1BlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + + // process 4 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; // track half byte increments with this index instead of a pointer + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_01_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + 32)); + + accumulate_q8_blklen16_r1c1blk4_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, 
QuantAScalePtr, QuantBScalePtr, acc0); + + QuantAPtr += BlkLen16 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m128i av_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr)); + const float scale_a00 = *QuantAScalePtr; + + const float scale_a0b = scale_a00 * (*QuantBScalePtr); + accumulate_q8_blklen16_r1c1blk1_avx2(av_16_epi8, QuantBDataPtr, scale_a0b, acc0); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + MLAS_FORCEINLINE size_t MlasQ4Int8GemmKernelBlkLen16Avx2( @@ -725,3 +1332,95 @@ MLAS_FORCEINLINE return CountM; } + +template +MLAS_FORCEINLINE +size_t +MlasQ8Int8GemmKernelBlkLen16Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen16 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q8Int8GemmR2xC4BlkLen16Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + 
multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + + if (remainingCols > 0 && multipleRows > 0) { + Q8Int8GemmR2xC1BlkLen16Avx2( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q8Int8GemmR1xC4BlkLen16Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q8Int8GemmR1xC1BlkLen16Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/src/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/src/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h index 5dab809..d2d9886 100644 --- a/src/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h +++ b/src/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h @@ -6,6 +6,11 @@ #include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" +MLAS_DECLSPEC_ALIGN(static const uint32_t MasksAvx2BlkLen32[24], 32) = { + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, + 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001 +}; MLAS_FORCEINLINE void accumulate_1blk_dot(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, @@ -115,6 +120,156 @@ accumulate_blklen32_r2c1blk2_avx2( #endif } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r1c1blk2_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m256& acc0 +) +{ + const __m256i bv0_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i bv1_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); // 01 01 01 01 + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + __m256 scale_a0b_2_ps = _mm256_mul_ps(scale_b_2_ps, scale_a0_2_ps); + __m256 scale0_8_ps = _mm256_shuffle_ps(scale_a0b_2_ps, scale_a0b_2_ps, _MM_SHUFFLE(1, 1, 0, 0)); // 00 11 00 11 + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + const __m256i dot00_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + const __m256i dot01_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); + const 
__m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); // 00 11 00 11 + const __m256 sum0_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_ps, scale0_8_ps, acc0); + } + else +#endif + { + // 2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 8)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 16)); + + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + const __m256i bv1_low_32_epi8 = _mm256_and_si256(bv1_32_epi8, low_mask); + const __m256i bv1_high_32_epi8 = _mm256_and_si256(bv1_32_epi8, high_mask); + + const __m256i dot00_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + const __m256i dot01_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av01_32_epi8); + const __m256i dot01_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av01_32_epi8); + + const __m256i dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i dot00_8_epi32 = _mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + + const __m256i dot01_low_8_epi32 = _mm256_madd_epi16(one_mask, dot01_low_16_epi16); + const __m256i dot01_high_8_epi32 = _mm256_madd_epi16(one_mask, dot01_high_16_epi16); + const __m256i dot01_8_epi32 = _mm256_add_epi32(dot01_low_8_epi32, dot01_high_8_epi32); + + const __m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); // 00 11, 00 11 + __m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale0_8_ps, acc0); + } +} + +template +static MLAS_FORCEINLINE void 
+accumulate_q8_blklen32_r2c1blk2_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m256& acc0, + __m256& acc1 +) +{ + const __m256i bv0_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + const __m256i bv1_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + __m256 scale_b_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); // 01 01 01 01 + __m256 scale_a0_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + __m256 scale_a0b_2_ps = _mm256_mul_ps(scale_b_2_ps, scale_a0_2_ps); + __m256 scale0_8_ps = _mm256_shuffle_ps(scale_a0b_2_ps, scale_a0b_2_ps, _MM_SHUFFLE(1, 1, 0, 0)); // 00 11 00 11 + __m256 scale_a1_2_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1)); + __m256 scale_a1b_2_ps = _mm256_mul_ps(scale_b_2_ps, scale_a1_2_ps); + __m256 scale1_8_ps = _mm256_shuffle_ps(scale_a1b_2_ps, scale_a1b_2_ps, _MM_SHUFFLE(1, 1, 0, 0)); // 00 11 00 11 + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + const __m256i dot00_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + const __m256i dot01_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av01_32_epi8); + const __m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); // 00 11 00 11 + const __m256 sum0_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_ps, scale0_8_ps, acc0); + + const __m256i dot10_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av10_32_epi8); + const __m256i dot11_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv1_32_epi8, av11_32_epi8); + const __m256i sum1_8_epi32 = _mm256_hadd_epi32(dot10_8_epi32, dot11_8_epi32); // 00 11 00 11 + const __m256 sum1_ps = _mm256_cvtepi32_ps(sum1_8_epi32); + acc1 = 
_mm256_fmadd_ps(sum1_ps, scale1_8_ps, acc1); + } + else +#endif + { + // 2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 8)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 16)); + + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + const __m256i bv1_low_32_epi8 = _mm256_and_si256(bv1_32_epi8, low_mask); + const __m256i bv1_high_32_epi8 = _mm256_and_si256(bv1_32_epi8, high_mask); + + // row 0 + const __m256i dot00_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + const __m256i dot01_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av01_32_epi8); + const __m256i dot01_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av01_32_epi8); + + const __m256i dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i dot00_8_epi32 = _mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + + const __m256i dot01_low_8_epi32 = _mm256_madd_epi16(one_mask, dot01_low_16_epi16); + const __m256i dot01_high_8_epi32 = _mm256_madd_epi16(one_mask, dot01_high_16_epi16); + const __m256i dot01_8_epi32 = _mm256_add_epi32(dot01_low_8_epi32, dot01_high_8_epi32); + + const __m256i sum0_8_epi32 = _mm256_hadd_epi32(dot00_8_epi32, dot01_8_epi32); // 00 11, 00 11 + __m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale0_8_ps, acc0); + + // row 1 + const __m256i dot10_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av10_32_epi8); + const __m256i dot10_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av10_32_epi8); + const __m256i 
dot11_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av11_32_epi8); + const __m256i dot11_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av11_32_epi8); + + const __m256i dot10_low_8_epi32 = _mm256_madd_epi16(one_mask, dot10_low_16_epi16); + const __m256i dot10_high_8_epi32 = _mm256_madd_epi16(one_mask, dot10_high_16_epi16); + const __m256i dot10_8_epi32 = _mm256_add_epi32(dot10_low_8_epi32, dot10_high_8_epi32); + + const __m256i dot11_low_8_epi32 = _mm256_madd_epi16(one_mask, dot11_low_16_epi16); + const __m256i dot11_high_8_epi32 = _mm256_madd_epi16(one_mask, dot11_high_16_epi16); + const __m256i dot11_8_epi32 = _mm256_add_epi32(dot11_low_8_epi32, dot11_high_8_epi32); + + const __m256i sum1_8_epi32 = _mm256_hadd_epi32(dot10_8_epi32, dot11_8_epi32); // 00 11, 00 11 + __m256 sum1_8_ps = _mm256_cvtepi32_ps(sum1_8_epi32); + acc1 = _mm256_fmadd_ps(sum1_8_ps, scale1_8_ps, acc1); + } +} + template static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk2_avx2( @@ -196,6 +351,100 @@ accumulate_blklen32_r2c1blk1_avx2( #endif } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r1c1blk1_avx2( + const __m256i& av00_32_epi8, + const std::byte* QuantBDataPtr, + float combined_scale00, + __m256& acc0 +) +{ + const __m256i bv0_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + accumulate_1blk_dot_vnni(av00_32_epi8, bv0_32_epi8, combined_scale00, acc0); + } + else +#endif + { + // 2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 8)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 16)); + + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + + const __m256i dot00_low_16_epi16 = 
_mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + + const __m256i dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i dot00_8_epi32 = _mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + + __m256 dot00_8_ps = _mm256_cvtepi32_ps(dot00_8_epi32); + acc0 = _mm256_fmadd_ps(dot00_8_ps, _mm256_set1_ps(combined_scale00), acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r2c1blk1_avx2( + const __m256i& av00_32_epi8, + const __m256i& av10_32_epi8, + const std::byte* QuantBDataPtr, + float combined_scale00, + float combined_scale10, + __m256& acc0, + __m256& acc1 +) +{ + const __m256i bv0_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + accumulate_1blk_dot_vnni(av00_32_epi8, bv0_32_epi8, combined_scale00, acc0); + accumulate_1blk_dot_vnni(av10_32_epi8, bv0_32_epi8, combined_scale10, acc1); + } + else +#endif + { + // 2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 8)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen32 + 16)); + + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + + // row 0 + const __m256i dot00_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + + const __m256i dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i 
dot00_8_epi32 = _mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + + __m256 dot00_8_ps = _mm256_cvtepi32_ps(dot00_8_epi32); + acc0 = _mm256_fmadd_ps(dot00_8_ps, _mm256_set1_ps(combined_scale00), acc0); + + // row 1 + const __m256i dot10_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av10_32_epi8); + const __m256i dot10_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av10_32_epi8); + + const __m256i dot10_low_8_epi32 = _mm256_madd_epi16(one_mask, dot10_low_16_epi16); + const __m256i dot10_high_8_epi32 = _mm256_madd_epi16(one_mask, dot10_high_16_epi16); + const __m256i dot10_8_epi32 = _mm256_add_epi32(dot10_low_8_epi32, dot10_high_8_epi32); + + __m256 dot10_8_ps = _mm256_cvtepi32_ps(dot10_8_epi32); + acc1 = _mm256_fmadd_ps(dot10_8_ps, _mm256_set1_ps(combined_scale10), acc1); + } +} + template static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk1_avx2( @@ -367,6 +616,116 @@ Q4Int8Gemm2x4x2BlkLen32Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + constexpr size_t PerAccuBlk2 = 2; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBDataCol = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBData2 = PerAccuBlk2 * BlkDataSizeInBytes; + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { 
+ const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + // load A: + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + BlkLen32)); + const __m256i av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); + + accumulate_q8_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_q8_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk2, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk2, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk2, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += StrideQuantBData2 * NCols4; + QuantBScalePtr += PerAccuBlk2 * 
NCols4; + } // k_blks_remaining + + if (k_blks_remaining > 0) { + // load A + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr)); + const __m256i av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); + + float scale_a00 = *QuantAScalePtr; + float scale_a10 = *(QuantAScalePtr + BlockCountK); + + float scale_00 = scale_a00 * (QuantBScalePtr)[0], scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_q8_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc[0], acc[NCols4]); + + float scale_01 = scale_a00 * (QuantBScalePtr + 1)[0], scale_11 = scale_a10 * (QuantBScalePtr + 1)[0]; + accumulate_q8_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_01, scale_11, acc[1], acc[NCols4 + 1]); + + float scale_02 = scale_a00 * (QuantBScalePtr + 2)[0], scale_12 = scale_a10 * (QuantBScalePtr + 2)[0]; + accumulate_q8_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_02, scale_12, acc[2], acc[NCols4 + 2]); + + float scale_03 = scale_a00 * (QuantBScalePtr + 3)[0], scale_13 = scale_a10 * (QuantBScalePtr + 3)[0]; + accumulate_q8_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_03, scale_13, acc[3], acc[NCols4 + 3]); + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + template void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( const std::byte* QuantA, @@ -460,6 +819,95 @@ void MLAS_FORCEINLINE Q4Int8Gemm2xXBlkLen32Avx2( } } +template +void MLAS_FORCEINLINE +Q8Int8GemmR2xC1BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + BlkLen32)); + const __m256i av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); + const __m256i av_11_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda + BlkLen32)); + + 
accumulate_q8_blklen32_r2c1blk2_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + if (k_blks_remaining > 0) { + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + const __m256i av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); + + const float scale_a00 = *QuantAScalePtr; + const float scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_q8_blklen32_r2c1blk1_avx2(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc0, acc1); + } + + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmXx4BlkLen32Avx2( @@ -589,6 +1037,100 @@ Q4Int8GemmXx4BlkLen32Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + constexpr size_t PerAccuBlk2 = 2; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBDataCol = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBData2 = PerAccuBlk2 * BlkDataSizeInBytes; + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps()}; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + BlkLen32)); + + accumulate_q8_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, 
QuantBDataPtr + StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + PerAccuBlk2, acc[1]); + accumulate_q8_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk2, acc[2]); + accumulate_q8_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData2, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk2, acc[3]); + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += StrideQuantBData2 * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + } + + if (k_blks_remaining > 0) { + // load A + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + const float scale_a00 = *QuantAScalePtr; + + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_q8_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc[0]); + + const float scale_01 = scale_a00 * (QuantBScalePtr + 1)[0]; + accumulate_q8_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_01, acc[1]); + + const float scale_02 = scale_a00 * (QuantBScalePtr + 2)[0]; + accumulate_q8_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_02, acc[2]); + + const float scale_03 = scale_a00 * (QuantBScalePtr + 3)[0]; + accumulate_q8_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_03, acc[3]); + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmXxXBlkLen32Avx2( @@ -672,6 +1214,81 @@ Q4Int8GemmXxXBlkLen32Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1BlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes16; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr)); + const __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + BlkLen32)); + accumulate_q8_blklen32_r1c1blk2_avx2(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + if 
(k_blks_remaining > 0) { + const __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + const float& scale_a00 = *QuantAScalePtr; + const float& scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_q8_blklen32_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc0); + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += BlockCountK; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE size_t @@ -765,6 +1382,98 @@ MLAS_FORCEINLINE return CountM; } +template +MLAS_FORCEINLINE +size_t +MlasQ8Int8GemmKernelBlkLen32Avx2( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen32 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q8Int8GemmR2xC4BlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + + if (remainingCols > 0 && multipleRows > 0) { + Q8Int8GemmR2xC1BlkLen32Avx2( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q8Int8GemmR1xC4BlkLen32Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q8Int8GemmR1xC1BlkLen32Avx2( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} + // this function is to explore larger NCols. With Avx2 it does not improve performance. // Leave it here until the same is implemented in avx512. template accumulator> diff --git a/src/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h b/src/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h index d4b89bd..2058374 100644 --- a/src/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h +++ b/src/lib/sqnbitgemm_kernel_avx2_int8_blklen64.h @@ -6,6 +6,12 @@ #include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" +MLAS_DECLSPEC_ALIGN(static const uint32_t MasksAvx2BlkLen64[24], 32) = { + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, + 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001, 0x00010001 +}; + template static MLAS_FORCEINLINE void accumulate_blklen64_r2c1blk1_avx2( @@ -76,6 +82,141 @@ accumulate_blklen64_r2c1blk1_avx2( #endif } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen64_r1c1blk1_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& bv0_32_epi8, + const __m256i& bv1_32_epi8, + float scale_a0b, + __m256& acc0 +) +{ + __m256 
scale_8_ps = _mm256_set1_ps(scale_a0b); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + __m256i sum_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + sum_8_epi32 = _mm256_dpbusds_avx_epi32(sum_8_epi32, bv1_32_epi8, av01_32_epi8); + __m256 sum_ps = _mm256_cvtepi32_ps(sum_8_epi32); + acc0 = _mm256_fmadd_ps(sum_ps, scale_8_ps, acc0); + } + else +#endif + { + // 2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen64)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen64 + 8)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen64 + 16)); + + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + const __m256i bv1_low_32_epi8 = _mm256_and_si256(bv1_32_epi8, low_mask); + const __m256i bv1_high_32_epi8 = _mm256_and_si256(bv1_32_epi8, high_mask); + + const __m256i dot00_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + const __m256i dot01_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av01_32_epi8); + const __m256i dot01_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av01_32_epi8); + + const __m256i dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i dot00_8_epi32 = _mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + + const __m256i dot01_low_8_epi32 = _mm256_madd_epi16(one_mask, dot01_low_16_epi16); + const __m256i dot01_high_8_epi32 = _mm256_madd_epi16(one_mask, dot01_high_16_epi16); + const __m256i dot01_8_epi32 = _mm256_add_epi32(dot01_low_8_epi32, dot01_high_8_epi32); + + const __m256i sum0_8_epi32 = _mm256_add_epi32(dot00_8_epi32, dot01_8_epi32); + 
__m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale_8_ps, acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen64_r2c1blk1_avx2( + const __m256i& av00_32_epi8, + const __m256i& av01_32_epi8, + const __m256i& av10_32_epi8, + const __m256i& av11_32_epi8, + const __m256i& bv0_32_epi8, + const __m256i& bv1_32_epi8, + float scale_a0b, + float scale_a1b, + __m256& acc0, + __m256& acc1 +) +{ + __m256 scale0_8_ps = _mm256_set1_ps(scale_a0b); + __m256 scale1_8_ps = _mm256_set1_ps(scale_a1b); + +#if !defined(__GNUC__) || (__GNUC__ > 10) + if constexpr (vnni) + { + __m256i sum0_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av00_32_epi8); + sum0_8_epi32 = _mm256_dpbusds_avx_epi32(sum0_8_epi32, bv1_32_epi8, av01_32_epi8); + __m256 sum0_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_ps, scale0_8_ps, acc0); + + __m256i sum1_8_epi32 = _mm256_dpbusds_avx_epi32(_mm256_setzero_si256(), bv0_32_epi8, av10_32_epi8); + sum1_8_epi32 = _mm256_dpbusds_avx_epi32(sum1_8_epi32, bv1_32_epi8, av11_32_epi8); + __m256 sum1_ps = _mm256_cvtepi32_ps(sum1_8_epi32); + acc1 = _mm256_fmadd_ps(sum1_ps, scale1_8_ps, acc1); + } + else + #endif + { + // 2 x i8 x i8 may be larger than i16 + const __m256i low_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen64)); + const __m256i high_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen64 + 8)); + const __m256i one_mask = _mm256_load_si256(reinterpret_cast(MasksAvx2BlkLen64 + 16)); + + const __m256i bv0_low_32_epi8 = _mm256_and_si256(bv0_32_epi8, low_mask); + const __m256i bv0_high_32_epi8 = _mm256_and_si256(bv0_32_epi8, high_mask); + const __m256i bv1_low_32_epi8 = _mm256_and_si256(bv1_32_epi8, low_mask); + const __m256i bv1_high_32_epi8 = _mm256_and_si256(bv1_32_epi8, high_mask); + + // row 0 + const __m256i dot00_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av00_32_epi8); + const __m256i dot00_high_16_epi16 = 
_mm256_maddubs_epi16(bv0_high_32_epi8, av00_32_epi8); + const __m256i dot01_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av01_32_epi8); + const __m256i dot01_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av01_32_epi8); + + const __m256i dot00_low_8_epi32 = _mm256_madd_epi16(one_mask, dot00_low_16_epi16); + const __m256i dot00_high_8_epi32 = _mm256_madd_epi16(one_mask, dot00_high_16_epi16); + const __m256i dot00_8_epi32 = _mm256_add_epi32(dot00_low_8_epi32, dot00_high_8_epi32); + + const __m256i dot01_low_8_epi32 = _mm256_madd_epi16(one_mask, dot01_low_16_epi16); + const __m256i dot01_high_8_epi32 = _mm256_madd_epi16(one_mask, dot01_high_16_epi16); + const __m256i dot01_8_epi32 = _mm256_add_epi32(dot01_low_8_epi32, dot01_high_8_epi32); + + const __m256i sum0_8_epi32 = _mm256_add_epi32(dot00_8_epi32, dot01_8_epi32); + __m256 sum0_8_ps = _mm256_cvtepi32_ps(sum0_8_epi32); + acc0 = _mm256_fmadd_ps(sum0_8_ps, scale0_8_ps, acc0); + + // row 1 + const __m256i dot10_low_16_epi16 = _mm256_maddubs_epi16(bv0_low_32_epi8, av10_32_epi8); + const __m256i dot10_high_16_epi16 = _mm256_maddubs_epi16(bv0_high_32_epi8, av10_32_epi8); + const __m256i dot11_low_16_epi16 = _mm256_maddubs_epi16(bv1_low_32_epi8, av11_32_epi8); + const __m256i dot11_high_16_epi16 = _mm256_maddubs_epi16(bv1_high_32_epi8, av11_32_epi8); + + const __m256i dot10_low_8_epi32 = _mm256_madd_epi16(one_mask, dot10_low_16_epi16); + const __m256i dot10_high_8_epi32 = _mm256_madd_epi16(one_mask, dot10_high_16_epi16); + const __m256i dot10_8_epi32 = _mm256_add_epi32(dot10_low_8_epi32, dot10_high_8_epi32); + + const __m256i dot11_low_8_epi32 = _mm256_madd_epi16(one_mask, dot11_low_16_epi16); + const __m256i dot11_high_8_epi32 = _mm256_madd_epi16(one_mask, dot11_high_16_epi16); + const __m256i dot11_8_epi32 = _mm256_add_epi32(dot11_low_8_epi32, dot11_high_8_epi32); + + const __m256i sum1_8_epi32 = _mm256_add_epi32(dot10_8_epi32, dot11_8_epi32); + __m256 sum1_8_ps = _mm256_cvtepi32_ps(sum1_8_epi32); + acc1 
= _mm256_fmadd_ps(sum1_8_ps, scale1_8_ps, acc1); + } +} + template static MLAS_FORCEINLINE void accumulate_blklen64_r1c1blk1_avx2( @@ -212,6 +353,138 @@ Q4Int8GemmR2xC4BlkLen64Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4 * NRows2] = { + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), + _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps(), _mm256_setzero_ps() + }; + + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0 = *QuantAScalePtr; + const float scale_a1 = *(QuantAScalePtr + BlockCountK); + const float scale_a0b0 = (*QuantBScalePtr) * scale_a0; + const float scale_a1b0 = (*QuantBScalePtr) * scale_a1; + const float 
scale_a0b1 = (*(QuantBScalePtr + 1)) * scale_a0; + const float scale_a1b1 = (*(QuantBScalePtr + 1)) * scale_a1; + const float scale_a0b2 = (*(QuantBScalePtr + 2)) * scale_a0; + const float scale_a1b2 = (*(QuantBScalePtr + 2)) * scale_a1; + const float scale_a0b3 = (*(QuantBScalePtr + 3)) * scale_a0; + const float scale_a1b3 = (*(QuantBScalePtr + 3)) * scale_a1; + + __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + __m256i av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); + __m256i av_11_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda + 32)); + + __m256i bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv01_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + __m256i bv10_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes)); + __m256i bv11_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes + 32)); + __m256i bv20_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes)); + __m256i bv21_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes + 32)); + __m256i bv30_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes)); + __m256i bv31_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes + 32)); + + for (size_t kk = 0; kk < PerBlkSubblkCount - 1; kk++) { + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, scale_a1b0, acc[0], acc[NCols4]); + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv10_32_epi8, bv11_32_epi8, scale_a0b1, scale_a1b1, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv20_32_epi8, bv21_32_epi8, scale_a0b2, scale_a1b2, 
acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv30_32_epi8, bv31_32_epi8, scale_a0b3, scale_a1b3, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + + av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); + av_11_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda + 32)); + + bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + bv01_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + bv10_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes)); + bv11_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes + 32)); + bv20_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes)); + bv21_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes + 32)); + bv30_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes)); + bv31_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes + 32)); + } + + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, scale_a1b0, acc[0], acc[NCols4]); + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv10_32_epi8, bv11_32_epi8, scale_a0b1, scale_a1b1, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv20_32_epi8, bv21_32_epi8, scale_a0b2, scale_a1b2, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv30_32_epi8, bv31_32_epi8, scale_a0b3, scale_a1b3, acc[3], acc[NCols4 + 3]); + + // increment block 
pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + + QuantAScalePtr++; + QuantBScalePtr += NCols4; + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen64Avx2( @@ -292,6 +565,109 @@ Q4Int8GemmR2xC1BlkLen64Avx2( } } +template +void MLAS_FORCEINLINE +Q8Int8GemmR2xC1BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m 
* ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(), acc1 = _mm256_setzero_ps(); + + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0 = *QuantAScalePtr; + const float scale_a1 = *(QuantAScalePtr + BlockCountK); + const float scale_a0b0 = (*QuantBScalePtr) * scale_a0; + const float scale_a1b0 = (*QuantBScalePtr) * scale_a1; + + __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + __m256i av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); + __m256i av_11_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda + 32)); + + __m256i bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv01_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + + for (size_t kk = 0; kk < PerBlkSubblkCount - 1; kk++) { + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, scale_a1b0, acc0, acc1); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + + av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + av_10_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda)); + av_11_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + lda + 32)); + + bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + bv01_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + } + + accumulate_q8_blklen64_r2c1blk1_avx2(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, scale_a1b0, acc0, acc1); + + // increment block pointers + QuantAPtr += 
SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + *(SumPtr + ldc) = hsum_float_8(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC4BlkLen64Avx2( @@ -371,6 +747,120 @@ Q4Int8GemmR1xC4BlkLen64Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc[NCols4] = {_mm256_setzero_ps(), _mm256_setzero_ps(), 
_mm256_setzero_ps(), _mm256_setzero_ps()}; + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0 = *QuantAScalePtr; + const float scale_a0b0 = (*QuantBScalePtr) * scale_a0; + const float scale_a0b1 = (*(QuantBScalePtr + 1)) * scale_a0; + const float scale_a0b2 = (*(QuantBScalePtr + 2)) * scale_a0; + const float scale_a0b3 = (*(QuantBScalePtr + 3)) * scale_a0; + + __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + + __m256i bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv01_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + __m256i bv10_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes)); + __m256i bv11_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes + 32)); + __m256i bv20_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes)); + __m256i bv21_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes + 32)); + __m256i bv30_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes)); + __m256i bv31_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes + 32)); + + for (size_t kk = 0; kk < PerBlkSubblkCount - 1; kk++) { + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, acc[0]); + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv10_32_epi8, bv11_32_epi8, scale_a0b1, acc[1]); + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv20_32_epi8, bv21_32_epi8, scale_a0b2, acc[2]); + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv30_32_epi8, bv31_32_epi8, scale_a0b3, acc[3]); + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + + av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); 
+ av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + + bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + bv01_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + bv10_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes)); + bv11_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + SubblkDataSizeInBytes + 32)); + bv20_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes)); + bv21_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 2 * SubblkDataSizeInBytes + 32)); + bv30_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes)); + bv31_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 3 * SubblkDataSizeInBytes + 32)); + } + + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, acc[0]); + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv10_32_epi8, bv11_32_epi8, scale_a0b1, acc[1]); + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv20_32_epi8, bv21_32_epi8, scale_a0b2, acc[2]); + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv30_32_epi8, bv31_32_epi8, scale_a0b3, acc[3]); + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + + QuantAScalePtr++; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen64Avx2( @@ -447,6 +937,97 @@ Q4Int8GemmR1xC1BlkLen64Avx2( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1BlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 64; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m256 acc0 = _mm256_setzero_ps(); + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0 = *QuantAScalePtr; + const float scale_a0b0 = (*QuantBScalePtr) * scale_a0; + + __m256i av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + __m256i av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + + __m256i bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + __m256i bv01_32_epi8 = 
_mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + + for (size_t kk = 0; kk < PerBlkSubblkCount - 1; kk++) { + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, acc0); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + + av_00_epi8 = _mm256_load_si256((const __m256i*)QuantAPtr); + av_01_epi8 = _mm256_load_si256((const __m256i*)(QuantAPtr + 32)); + + bv00_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr)); + bv01_32_epi8 = _mm256_load_si256(reinterpret_cast(QuantBDataPtr + 32)); + } + + accumulate_q8_blklen64_r1c1blk1_avx2(av_00_epi8, av_01_epi8, bv00_32_epi8, bv01_32_epi8, scale_a0b0, acc0); + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE size_t MlasQ4Int8GemmKernelBlkLen64Avx2( @@ -539,3 +1120,96 @@ MlasQ4Int8GemmKernelBlkLen64Avx2( return CountM; } + +template +MLAS_FORCEINLINE size_t +MlasQ8Int8GemmKernelBlkLen64Avx2( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q8Int8GemmR2xC4BlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q8Int8GemmR2xC1BlkLen64Avx2( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q8Int8GemmR1xC4BlkLen64Avx2( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q8Int8GemmR1xC1BlkLen64Avx2( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/src/lib/sqnbitgemm_kernel_avx512.cpp b/src/lib/sqnbitgemm_kernel_avx512.cpp index 592b244..f917cf7 100644 --- a/src/lib/sqnbitgemm_kernel_avx512.cpp +++ b/src/lib/sqnbitgemm_kernel_avx512.cpp @@ -249,6 +249,99 @@ SQ4BitGemmKernel_BlkSum_CompInt8_avx512( return CountM; } +MLAS_FORCEINLINE +size_t +SQ8BitGemmKernel_BlkSum_CompInt8_avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* /*QuantBZeroPoint*/, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum +) +{ + if (BlkLen == 16) { + MlasQ8Int8GemmKernelBlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 32) { + MlasQ8Int8GemmKernelBlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 64) { + MlasQ8Int8GemmKernelBlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, 
+ ldc + ); + } else { + MlasQ8Int8GemmKernelBlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } + + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); + + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + return CountM; +} + void MLASCALL QuantizeARow_CompInt8_avx512( size_t BlkLen, @@ -337,9 +430,35 @@ SQ4BitGemmPackQuantBDataAndBlkSum512( MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, - bool has_zp_input, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 
32 : 64); + if (ComputeType == SQNBIT_CompInt8) { + SubBlkLen = 128; + } + PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, + HasZeroPoint, QuantBZPBegin, PackedQuantB, ThreadPool); +} + +static void +SQ8BitGemmPackQuantBDataAndBlkSum512( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& PackedQuantB, MLAS_THREADPOOL* ThreadPool ) { @@ -351,23 +470,27 @@ SQ4BitGemmPackQuantBDataAndBlkSum512( if (ComputeType == SQNBIT_CompInt8) { SubBlkLen = 128; } - PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); + Q8PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, + HasZeroPoint, QuantBZPBegin, PackedQuantB, ThreadPool); } const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() { MLAS_QNBIT_GEMM_DISPATCH d; - d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<4>; + d.Q8BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<8>; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512; + d.SQ8BitGemmPackQuantBDataAndBlkSum = SQ8BitGemmPackQuantBDataAndBlkSum512; - d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; - d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; + d.QNBitGemmPerGemmWorkspaceSize = QNBitGemmPerGemmWorkspaceSize; + d.QNBitGemmPerGemmWorkspaceAlignment = QNBitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512; d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; 
d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512; + d.SQ8BitGemmKernel_BlkSum_CompInt8 = SQ8BitGemmKernel_BlkSum_CompInt8_avx512; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; return d; diff --git a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h index d79554c..6e8cebe 100644 --- a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h +++ b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen128.h @@ -35,6 +35,13 @@ // bv1_64_epi8 = _mm512_inserti64x4(_mm512_castsi256_si512(bv0_higher), bv1_higher, 1); //} +MLAS_DECLSPEC_ALIGN(static const uint32_t MasksAvx512BlkLen128[32], 64) = { + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00 +}; + static MLAS_FORCEINLINE void dot_accumulate_1blk( const __m512i& bv0_64_epi8, @@ -139,6 +146,117 @@ accumulate_blklen128_r2c1blk1_avx512( } } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen128_r1c1blk1_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const std::byte* QuantBDataPtr, + float scale_a0b, + __m512& acc0 +) +{ + __m512i bv0_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr)); + __m512i bv1_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr + 64)); + + if constexpr (vnni) { + dot_accumulate_1blkvnni(bv0_64_epi8, bv1_64_epi8, av00_64_epi8, av01_64_epi8, scale_a0b, acc0); + } else { + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen128)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen128 + 16)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i 
bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i dot00_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); + const __m512i dot00_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); + const __m512i dot01_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + const __m512i sum0_16_epi32 = _mm512_add_epi32(dot00_16_epi32, dot01_16_epi32); + + const __m512 sum0_16_ps = _mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, _mm512_set1_ps(scale_a0b), acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen128_r2c1blk1_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + float scale_a0b, + float scale_a1b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr)); + __m512i bv1_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr + 64)); + + if constexpr (vnni) { + dot_accumulate_1blkvnni(bv0_64_epi8, bv1_64_epi8, av00_64_epi8, 
av01_64_epi8, scale_a0b, acc0); + dot_accumulate_1blkvnni(bv0_64_epi8, bv1_64_epi8, av10_64_epi8, av11_64_epi8, scale_a1b, acc1); + } else { + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen128)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen128 + 16)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i dot00_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); + const __m512i dot00_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); + const __m512i dot01_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + const __m512i sum0_16_epi32 = _mm512_add_epi32(dot00_16_epi32, dot01_16_epi32); + + const __m512 sum0_16_ps = _mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, _mm512_set1_ps(scale_a0b), acc0); + + // row 1 + const __m512i dot10_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av10_64_epi8); + const __m512i dot10_high_32_epi16 = 
_mm512_maddubs_epi16(bv0_high_64_epi8, av10_64_epi8); + const __m512i dot11_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av11_64_epi8); + const __m512i dot11_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av11_64_epi8); + + const __m512i dot10_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_low_32_epi16); + const __m512i dot10_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_high_32_epi16); + const __m512i dot11_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot11_low_32_epi16); + const __m512i dot11_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot11_high_32_epi16); + + const __m512i dot10_16_epi32 = _mm512_add_epi32(dot10_low_16_epi32, dot10_high_16_epi32); + const __m512i dot11_16_epi32 = _mm512_add_epi32(dot11_low_16_epi32, dot11_high_16_epi32); + const __m512i sum1_16_epi32 = _mm512_add_epi32(dot10_16_epi32, dot11_16_epi32); + + const __m512 sum1_16_ps = _mm512_cvtepi32_ps(sum1_16_epi32); + acc1 = _mm512_fmadd_ps(sum1_16_ps, _mm512_set1_ps(scale_a1b), acc1); + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR2xC4BlkLen128Avx512( @@ -251,6 +369,110 @@ Q4Int8GemmR2xC4BlkLen128Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 
0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0b0 = (*QuantAScalePtr) * (*QuantBScalePtr); + const float scale_a0b1 = (*QuantAScalePtr) * (*(QuantBScalePtr + 1)); + const float scale_a0b2 = (*QuantAScalePtr) * (*(QuantBScalePtr + 2)); + const float scale_a0b3 = (*QuantAScalePtr) * (*(QuantBScalePtr + 3)); + const float scale_a1b0 = (*(QuantAScalePtr + BlockCountK)) * (*QuantBScalePtr); + const float scale_a1b1 = (*(QuantAScalePtr + BlockCountK)) * (*(QuantBScalePtr + 1)); + const float scale_a1b2 = (*(QuantAScalePtr + BlockCountK)) * (*(QuantBScalePtr + 2)); + const float scale_a1b3 = (*(QuantAScalePtr + BlockCountK)) * (*(QuantBScalePtr + 3)); + + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av00_64_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av01_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + const __m512i av10_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av11_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda + SubblkLen / 2)); + + accumulate_q8_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr, scale_a0b0, scale_a1b0, acc[0], acc[NCols4]); + accumulate_q8_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, 
av10_64_epi8, av11_64_epi8, QuantBDataPtr + SubblkDataSizeInBytes, scale_a0b1, scale_a1b1, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, scale_a0b2, scale_a1b2, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, scale_a0b3, scale_a1b3, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr += NCols4; + } // k_blks_remaining + + *SumPtr = _mm512_reduce_add_ps(acc[0]); + *(SumPtr + 1) = _mm512_reduce_add_ps(acc[1]); + *(SumPtr + 2) = _mm512_reduce_add_ps(acc[2]); + *(SumPtr + 3) = _mm512_reduce_add_ps(acc[3]); + *(SumPtr + ldc) = _mm512_reduce_add_ps(acc[NCols4]); + *(SumPtr + ldc + 1) = _mm512_reduce_add_ps(acc[NCols4 + 1]); + *(SumPtr + ldc + 2) = _mm512_reduce_add_ps(acc[NCols4 + 2]); + *(SumPtr + ldc + 3) = _mm512_reduce_add_ps(acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + *SumPtr += *BiasPtr; + *(SumPtr + 1) += *(BiasPtr + 1); + *(SumPtr + 2) += *(BiasPtr + 2); + *(SumPtr + 3) += *(BiasPtr + 3); + *(SumPtr + ldc) += *BiasPtr; + *(SumPtr + ldc + 1) += *(BiasPtr + 1); + *(SumPtr + ldc + 2) += *(BiasPtr + 2); + *(SumPtr + ldc + 3) += *(BiasPtr + 3); + } + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + template void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen128Avx512( @@ -332,6 +554,89 @@ Q4Int8GemmR2xC1BlkLen128Avx512( } } +template +void MLAS_FORCEINLINE +Q8Int8GemmR2xC1BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0b0 = (*QuantAScalePtr) * (*QuantBScalePtr); + const float scale_a1b0 = (*(QuantAScalePtr + BlockCountK)) * (*QuantBScalePtr); + + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av00_64_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av01_64_epi8 = _mm512_load_si512((const 
__m512i*)(QuantAPtr + SubblkLen / 2)); + const __m512i av10_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av11_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda + SubblkLen / 2)); + + accumulate_q8_blklen128_r2c1blk1_avx512(av00_64_epi8, av01_64_epi8, av10_64_epi8, av11_64_epi8, QuantBDataPtr, scale_a0b0, scale_a1b0, acc0, acc1); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + *(SumPtr + ldc) = hsum_float_16(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC4BlkLen128Avx512( @@ -411,6 +716,90 @@ Q4Int8GemmR1xC4BlkLen128Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = 
Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = {_mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps()}; + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0b0 = (*QuantAScalePtr) * (*QuantBScalePtr); + const float scale_a0b1 = (*QuantAScalePtr) * (*(QuantBScalePtr + 1)); + const float scale_a0b2 = (*QuantAScalePtr) * (*(QuantBScalePtr + 2)); + const float scale_a0b3 = (*QuantAScalePtr) * (*(QuantBScalePtr + 3)); + + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av0_64_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + + accumulate_q8_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, scale_a0b0, acc[0]); + accumulate_q8_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + SubblkDataSizeInBytes, scale_a0b1, acc[1]); + accumulate_q8_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * SubblkDataSizeInBytes, scale_a0b2, acc[2]); + accumulate_q8_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * SubblkDataSizeInBytes, scale_a0b3, acc[3]); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += NCols4 * SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr +=NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBData; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen128Avx512( @@ -487,6 +876,82 @@ Q4Int8GemmR1xC1BlkLen128Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1BlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t SubblkLen = 128; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + + const size_t PerBlkSubblkCount = BlkLen / SubblkLen; + const size_t SubblkDataSizeInBytes = BlkDataSizeInBytes / PerBlkSubblkCount; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + for (size_t k = 0; k < BlockCountK; ++k) { + const float scale_a0b0 = (*QuantAScalePtr) * (*QuantBScalePtr); + + for (size_t kk = 0; kk < PerBlkSubblkCount; kk++) { + const __m512i av0_64_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + SubblkLen / 2)); + + accumulate_q8_blklen128_r1c1blk1_avx512(av0_64_epi8, av1_64_epi8, 
QuantBDataPtr, scale_a0b0, acc0); + + // increment block pointers + QuantAPtr += SubblkLen; + QuantBDataPtr += SubblkDataSizeInBytes; + } + QuantAScalePtr++; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE size_t MlasQ4Int8GemmKernelBlkLen128Avx512( @@ -579,3 +1044,97 @@ MlasQ4Int8GemmKernelBlkLen128Avx512( return CountM; } + +template +MLAS_FORCEINLINE size_t +MlasQ8Int8GemmKernelBlkLen128Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q8Int8GemmR2xC4BlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + + if (remainingCols > 0 && multipleRows > 0) { + Q8Int8GemmR2xC1BlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q8Int8GemmR1xC4BlkLen128Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q8Int8GemmR1xC1BlkLen128Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h index 0306488..b720c45 100644 --- a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h +++ b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen16.h @@ -9,7 +9,14 @@ #include "sqnbitgemm_kernel_avx512_int8_blklen32.h" #include "sqnbitgemm_kernel_avx512_int8_blklen64.h" - +MLAS_DECLSPEC_ALIGN(static const uint32_t MasksAvx512BlkLen16[48], 64) = { + 0x00000000, 0x00000000, 0x00000004, 0x00000004, 0x00000001, 0x00000001, 0x00000005, 0x00000005, + 0x00000002, 0x00000002, 0x00000006, 0x00000006, 0x00000003, 0x00000003, 0x00000007, 0x00000007, + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00 +}; static MLAS_FORCEINLINE void load_4blk_4b_packed_blklen16(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) @@ -120,6 +127,131 @@ accumulate_blklen16_r2c1blk4_avx512( } 
} +static MLAS_FORCEINLINE void +accumulate_q8_blklen16_r2c1blk8_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + const __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen16 + 16)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen16 + 32)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 + const __m512i idx = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen16)); // 0044115522663377 + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a0); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_insertf32x8(_mm512_setzero_ps(), scale_a0b_ps, 0); // 0123456700000000 + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); + + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); // 0x8 1x8 2x8 3x8 + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); // 4x8 5x8 6x8 7x8 + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i dot00_low_16_epi32 = 
_mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); // 0000111122223333 + const __m512i dot00_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); // 4444555566667777 + const __m512i dot01_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + + const __m512i t01 = _mm512_unpacklo_epi64(dot00_16_epi32, dot01_16_epi32); // 0044115522663377 + const __m512i t02 = _mm512_unpackhi_epi64(dot00_16_epi32, dot01_16_epi32); // 0044115522663377 + const __m512i sum0_16_epi32 = _mm512_add_epi32(t01, t02); + const __m512 sum0_16_ps = _mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, scale_a0b_16_ps, acc0); + + // row 1 + const __m256 scale_a1_ps = _mm256_loadu_ps(scale_a1); + const __m256 scale_a1b_ps = _mm256_mul_ps(scale_b_ps, scale_a1_ps); + __m512 scale_a1b_16_ps = _mm512_insertf32x8(_mm512_setzero_ps(), scale_a1b_ps, 0); + scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); + + const __m512i dot10_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av10_64_epi8); // 0x8 1x8 2x8 3x8 + const __m512i dot10_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av10_64_epi8); + const __m512i dot11_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av11_64_epi8); // 4x8 5x8 6x8 7x8 + const __m512i dot11_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av11_64_epi8); + + const __m512i dot10_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_low_32_epi16); // 0000111122223333 + const __m512i dot10_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_high_32_epi16); + const __m512i dot11_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot11_low_32_epi16); // 4444555566667777 + const __m512i dot11_high_16_epi32 = 
_mm512_madd_epi16(one_32_epi16, dot11_high_32_epi16); + + const __m512i dot10_16_epi32 = _mm512_add_epi32(dot10_low_16_epi32, dot10_high_16_epi32); + const __m512i dot11_16_epi32 = _mm512_add_epi32(dot11_low_16_epi32, dot11_high_16_epi32); + + const __m512i t11 = _mm512_unpacklo_epi64(dot10_16_epi32, dot11_16_epi32); // 0044115522663377 + const __m512i t12 = _mm512_unpackhi_epi64(dot10_16_epi32, dot11_16_epi32); // 0044115522663377 + const __m512i sum1_16_epi32 = _mm512_add_epi32(t11, t12); + const __m512 sum1_16_ps = _mm512_cvtepi32_ps(sum1_16_epi32); + acc1 = _mm512_fmadd_ps(sum1_16_ps, scale_a1b_16_ps, acc1); +} + +static MLAS_FORCEINLINE void +accumulate_q8_blklen16_r1c1blk8_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m512& acc0 +) +{ + const __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen16 + 16)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen16 + 32)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 + const __m512i idx = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen16)); // 0044115522663377 + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a0); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_insertf32x8(_mm512_setzero_ps(), scale_a0b_ps, 0); // 0123456700000000 + 
scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); + + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); // 0x8 1x8 2x8 3x8 + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); // 4x8 5x8 6x8 7x8 + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i dot00_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); // 0000111122223333 + const __m512i dot00_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); // 4444555566667777 + const __m512i dot01_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + + const __m512i t01 = _mm512_unpacklo_epi64(dot00_16_epi32, dot01_16_epi32); // 0044115522663377 + const __m512i t02 = _mm512_unpackhi_epi64(dot00_16_epi32, dot01_16_epi32); // 0044115522663377 + const __m512i sum0_16_epi32 = _mm512_add_epi32(t01, t02); + const __m512 sum0_16_ps = _mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, scale_a0b_16_ps, acc0); +} + static MLAS_FORCEINLINE void accumulate_blklen16_r1c1blk8_avx512vnni( const __m512i& av0_64_epi8, @@ -214,6 +346,36 @@ accumulate_blklen16_r2c1blk4_avx512vnni( } } +static MLAS_FORCEINLINE void +accumulate_q8_blklen16_r1c1blk8_avx512vnni( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m512& acc0 +) +{ + const __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i bv1_64_epi8 = 
_mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + const __m512i idx = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen16)); // 0044115522663377 + const __m256 scale_b_ps = _mm256_loadu_ps(scale_b); // 01234567 + + const __m256 scale_a0_ps = _mm256_loadu_ps(scale_a0); // 01234567 + const __m256 scale_a0b_ps = _mm256_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_insertf32x8(_mm512_setzero_ps(), scale_a0b_ps, 0); // 01234567 00000000 + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); + + const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av00_64_epi8); // 0000111122223333 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av01_64_epi8); // 4444555566667777 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0044115522663377 + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); +} + template MLAS_FORCEINLINE void Q4Int8GemmR2xC4BlkLen16Avx512( @@ -399,6 +561,152 @@ Q4Int8GemmR2xC4BlkLen16Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBDataCol = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBData8 = 
BlkDataSizeInBytes * PerAccuBlk8; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + if constexpr (vnni) { + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[NCols4]); + + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + PerAccuBlk8, acc[1]); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData8, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk8, acc[NCols4 + 1]); + + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData8, 
QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk8, acc[2]); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData8, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk8, acc[NCols4 + 2]); + + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk8, acc[3]); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData8, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk8, acc[NCols4 + 3]); + } else { + accumulate_q8_blklen16_r2c1blk8_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_q8_blklen16_r2c1blk8_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData8, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk8, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen16_r2c1blk8_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData8, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk8, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen16_r2c1blk8_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData8, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk8, acc[3], acc[NCols4 + 3]); + } + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += StrideQuantBData8 * NCols4; + QuantBScalePtr += PerAccuBlk8 * NCols4; + } // k_blks_remaining + + __m256 acc2[NCols4 * NRows2] = { + h_add_512(acc[0]), + h_add_512(acc[1]), + h_add_512(acc[2]), + h_add_512(acc[3]), + h_add_512(acc[4]), + h_add_512(acc[5]), + h_add_512(acc[6]), + h_add_512(acc[7]) + }; + + // In A, the bytes beyond K has set to 0. 
+ for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m128i av0_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr)); + const __m128i av1_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr + lda)); + + const float scale_a00 = *QuantAScalePtr; + const float scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float scale_a0b0 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_a1b0 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av0_16_epi8, QuantBDataPtr, scale_a0b0, acc2[0]); + accumulate_q8_blklen16_r1c1blk1_avx2(av1_16_epi8, QuantBDataPtr, scale_a1b0, acc2[NCols4]); + + const float scale_a0b1 = scale_a00 * (QuantBScalePtr + 1)[0]; + const float scale_a1b1 = scale_a10 * (QuantBScalePtr + 1)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av0_16_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_a0b1, acc2[1]); + accumulate_q8_blklen16_r1c1blk1_avx2(av1_16_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_a1b1, acc2[NCols4 + 1]); + + const float scale_a0b2 = scale_a00 * (QuantBScalePtr + 2)[0]; + const float scale_a1b2 = scale_a10 * (QuantBScalePtr + 2)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av0_16_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_a0b2, acc2[2]); + accumulate_q8_blklen16_r1c1blk1_avx2(av1_16_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_a1b2, acc2[NCols4 + 2]); + + const float scale_a0b3 = scale_a00 * (QuantBScalePtr + 3)[0]; + const float scale_a1b3 = scale_a10 * (QuantBScalePtr + 3)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av0_16_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_a0b3, acc2[3]); + accumulate_q8_blklen16_r1c1blk1_avx2(av1_16_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_a1b3, acc2[NCols4 + 3]); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + __m128 acc_r1 = FoldAccumulators(acc2[NCols4 + 0], acc2[NCols4 
+ 1], acc2[NCols4 + 2], acc2[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template void MLAS_FORCEINLINE Q4Int8GemmR2C1BlkLen16Avx512( @@ -509,6 +817,108 @@ Q4Int8GemmR2C1BlkLen16Avx512( } } +template +void MLAS_FORCEINLINE +Q8Int8GemmR2xC1BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBDataCol = BlockCountK * BlkDataSizeInBytes; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time + for (; 
k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + if constexpr (vnni) { + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc1); + } else { + accumulate_q8_blklen16_r2c1blk8_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + } + + // increment block pointers + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; + QuantBScalePtr += PerAccuBlk8; + } + + __m256 acc20 = h_add_512(acc0); + __m256 acc21 = h_add_512(acc1); + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m128i av0_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr)); + const __m128i av1_16_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr + lda)); + + const float scale_a00 = *QuantAScalePtr; + const float scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av0_16_epi8, QuantBDataPtr, scale_00, acc20); + accumulate_q8_blklen16_r1c1blk1_avx2(av1_16_epi8, QuantBDataPtr, scale_10, acc21); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc20); + *(SumPtr + ldc) = hsum_float_8(acc21); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += 
*BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBDataCol; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC4BlkLen16Avx512( @@ -628,6 +1038,118 @@ Q4Int8GemmR1xC4BlkLen16Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBDataCol = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBData8 = PerAccuBlk8 * BlkDataSizeInBytes; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = 
_mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + PerAccuBlk8, acc[1]); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk8, acc[2]); + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk8, acc[3]); + } else { + accumulate_q8_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + PerAccuBlk8, acc[1]); + accumulate_q8_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk8, acc[2]); + accumulate_q8_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData8, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk8, acc[3]); + } + + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += StrideQuantBData8 * NCols4; + QuantBScalePtr += PerAccuBlk8 * NCols4; + } + + __m256 acc2[NCols4] = { + h_add_512(acc[0]), h_add_512(acc[1]), h_add_512(acc[2]), h_add_512(acc[3]) + }; + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m128i av_00_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr)); + const float scale_a00 = *QuantAScalePtr; + + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2[0]); + + const float scale_01 = scale_a00 * (QuantBScalePtr + 1)[0]; + 
accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_01, acc2[1]); + + const float scale_02 = scale_a00 * (QuantBScalePtr + 2)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_02, acc2[2]); + + const float scale_03 = scale_a00 * (QuantBScalePtr + 3)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_03, acc2[3]); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * StrideQuantBDataCol; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen16Avx512( @@ -719,6 +1241,94 @@ Q4Int8GemmR1xC1BlkLen16Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1BlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + + constexpr size_t PerAccuBlk8 = 8; + + const size_t lda = BlockCountK * BlkLen16; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; 
+ const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk8; k_blks_remaining -= PerAccuBlk8) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_q8_blklen16_r1c1blk8_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } else { + accumulate_q8_blklen16_r1c1blk8_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } + + QuantAPtr += BlkLen16 * PerAccuBlk8; + QuantAScalePtr += PerAccuBlk8; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk8; + QuantBScalePtr += PerAccuBlk8; + } + + __m256 acc2 = h_add_512(acc0); + while (k_blks_remaining-- > 0) { + const __m128i av_00_epi8 = _mm_lddqu_si128(reinterpret_cast(QuantAPtr)); + + const float scale_a00 = *QuantAScalePtr; + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_q8_blklen16_r1c1blk1_avx2(av_00_epi8, QuantBDataPtr, scale_00, acc2); + + QuantAPtr += BlkLen16; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc2); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE size_t @@ -810,3 +1420,94 @@ MlasQ4Int8GemmKernelBlkLen16Avx512( return CountM; } + +template +MLAS_FORCEINLINE +size_t +MlasQ8Int8GemmKernelBlkLen16Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen16 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q8Int8GemmR2xC4BlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + + if (remainingCols > 0 && multipleRows > 0) { + Q8Int8GemmR2xC1BlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q8Int8GemmR1xC4BlkLen16Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q8Int8GemmR1xC1BlkLen16Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h index 3b1096a..f630883 100644 --- a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h +++ b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen32.h @@ -8,6 +8,15 @@ #include "sqnbitgemm_kernel_avx2_int8_blklen32.h" #include "sqnbitgemm_kernel_avx512_int8_blklen64.h" +MLAS_DECLSPEC_ALIGN(static const uint32_t MasksAvx512BlkLen32[48], 64) = { + 0x00000000, 0x00000000, 0x00000002, 0x00000002, 0x00000000, 0x00000000, 0x00000002, 0x00000002, + 0x00000001, 0x00000001, 0x00000003, 0x00000003, 0x00000001, 0x00000001, 0x00000003, 0x00000003, + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00 +}; + static MLAS_FORCEINLINE void load_4blk_4b_packed_blklen32(const std::byte* QuantBDataPtr, __m512i& bv0_64_epi8, __m512i& bv1_64_epi8) { @@ -115,6 +124,139 @@ accumulate_blklen32_r2c1blk4_avx512( } } +static 
MLAS_FORCEINLINE void +accumulate_q8_blklen32_r1c1blk4_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m512& acc0 +) +{ + const __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen32 + 16)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen32 + 32)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + + const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 + const __m512i idx = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen32)); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m128 scale_a0_ps = _mm_loadu_ps(scale_a0); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_insertf32x4(_mm512_setzero_ps(), scale_a0b_ps, 0); + + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 00220022 11331133 + + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i dot00_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); // 00000000 11111111 + const __m512i dot00_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i 
dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); // 22222222 33333333 + const __m512i dot01_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + + const __m512i t01 = _mm512_unpacklo_epi64(dot00_16_epi32, dot01_16_epi32); // 00220022 11331133 + const __m512i t02 = _mm512_unpackhi_epi64(dot00_16_epi32, dot01_16_epi32); + const __m512i sum0_16_epi32 = _mm512_add_epi32(t01, t02); + + const __m512 sum0_16_ps = _mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, scale_a0b_16_ps, acc0); +} + +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r2c1blk4_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + const __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + const __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen32 + 16)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen32 + 32)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + + const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 + const __m512i idx = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen32)); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m128 scale_a0_ps = 
_mm_loadu_ps(scale_a0); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_insertf32x4(_mm512_setzero_ps(), scale_a0b_ps, 0); + + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 00220022 11331133 + + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i dot00_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); // 00000000 11111111 + const __m512i dot00_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); // 22222222 33333333 + const __m512i dot01_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + + const __m512i t01 = _mm512_unpacklo_epi64(dot00_16_epi32, dot01_16_epi32); // 00220022 11331133 + const __m512i t02 = _mm512_unpackhi_epi64(dot00_16_epi32, dot01_16_epi32); + const __m512i sum0_16_epi32 = _mm512_add_epi32(t01, t02); + + const __m512 sum0_16_ps = _mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, scale_a0b_16_ps, acc0); + + // row 1 + const __m128 scale_a1_ps = _mm_loadu_ps(scale_a1); // 0123 + const __m128 scale_a1b_ps = _mm_mul_ps(scale_b_ps, scale_a1_ps); + __m512 scale_a1b_16_ps = _mm512_insertf32x4(_mm512_setzero_ps(), scale_a1b_ps, 0); + + scale_a1b_16_ps = _mm512_permutexvar_ps(idx, scale_a1b_16_ps); // 00220022 11331133 + + const __m512i dot10_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, 
av10_64_epi8); + const __m512i dot10_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av10_64_epi8); + const __m512i dot11_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av11_64_epi8); + const __m512i dot11_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av11_64_epi8); + + const __m512i dot10_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_low_32_epi16); // 00000000 11111111 + const __m512i dot10_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_high_32_epi16); + const __m512i dot11_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot11_low_32_epi16); // 22222222 33333333 + const __m512i dot11_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot11_high_32_epi16); + + const __m512i dot10_16_epi32 = _mm512_add_epi32(dot10_low_16_epi32, dot10_high_16_epi32); + const __m512i dot11_16_epi32 = _mm512_add_epi32(dot11_low_16_epi32, dot11_high_16_epi32); + + const __m512i t11 = _mm512_unpacklo_epi64(dot10_16_epi32, dot11_16_epi32); // 00220022 11331133 + const __m512i t12 = _mm512_unpackhi_epi64(dot10_16_epi32, dot11_16_epi32); + const __m512i sum1_16_epi32 = _mm512_add_epi32(t11, t12); + + const __m512 sum1_16_ps = _mm512_cvtepi32_ps(sum1_16_epi32); + acc1 = _mm512_fmadd_ps(sum1_16_ps, scale_a1b_16_ps, acc1); +} + static MLAS_FORCEINLINE void accumulate_blklen32_r1c1blk4_avx512vnni( const __m512i& av0_64_epi8, @@ -203,6 +345,38 @@ accumulate_blklen32_r2c1blk4_avx512vnni( } } +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r1c1blk4_avx512vnni( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m512& acc0 +) +{ + __m512i bv0_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr)); + __m512i bv1_64_epi8 = _mm512_loadu_si512(reinterpret_cast(QuantBDataPtr + 64)); + __m512i idx = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen32)); + + const __m128 scale_b_ps = _mm_loadu_ps(scale_b); // 0123 + + const __m128 scale_a0_ps = 
_mm_loadu_ps(scale_a0); // 0123 + const __m128 scale_a0b_ps = _mm_mul_ps(scale_b_ps, scale_a0_ps); + __m512 scale_a0b_16_ps = _mm512_insertf32x4(_mm512_setzero_ps(), scale_a0b_ps, 0); + + scale_a0b_16_ps = _mm512_permutexvar_ps(idx, scale_a0b_16_ps); // 0022002211331133 + + const __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv0_64_epi8, av00_64_epi8); // 0000000011111111 + const __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv1_64_epi8, av01_64_epi8); // 2222222233333333 + + const __m512i t1_16_epi32 = _mm512_unpacklo_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 + const __m512i t2_16_epi32 = _mm512_unpackhi_epi64(dot0_16_epi32, dot1_16_epi32); // 0022002211331133 + const __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // 0022002211331133 + const __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); + acc0 = _mm512_fmadd_ps(sum_16_ps, scale_a0b_16_ps, acc0); +} + MLAS_FORCEINLINE void accumulate_1blk_dot_avx512vnni(const __m256i& av_32_epi8, const __m256i& bv_32_epi8, const float& combined_scale, __m256& acc) { @@ -256,6 +430,44 @@ accumulate_blklen32_r2c1blk1_avx512( } } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r1c1blk1_avx512( + const __m256i& av00_32_epi8, + const std::byte* QuantBDataPtr, + float combined_scale00, + __m256& acc0 +) +{ + if constexpr (vnni) { + const __m256i bv_32_epi8 = _mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + accumulate_1blk_dot_avx512vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); + } else { + accumulate_q8_blklen32_r1c1blk1_avx2(av00_32_epi8, QuantBDataPtr, combined_scale00, acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen32_r2c1blk1_avx512( + const __m256i& av00_32_epi8, + const __m256i& av10_32_epi8, + const std::byte* QuantBDataPtr, + float combined_scale00, + float combined_scale10, + __m256& acc0, + __m256& acc1 +) +{ + if constexpr (vnni) { + const __m256i bv_32_epi8 = 
_mm256_loadu_si256(reinterpret_cast(QuantBDataPtr)); + accumulate_1blk_dot_avx512vnni(av00_32_epi8, bv_32_epi8, combined_scale00, acc0); + accumulate_1blk_dot_avx512vnni(av10_32_epi8, bv_32_epi8, combined_scale10, acc1); + } else { + accumulate_q8_blklen32_r2c1blk1_avx2(av00_32_epi8, av10_32_epi8, QuantBDataPtr, combined_scale00, combined_scale10, acc0, acc1); + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR2xC4BlkLen32Avx512( @@ -437,6 +649,142 @@ Q4Int8GemmR2xC4BlkLen32Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = PerAccuBlk4 * BlkDataSizeInBytes; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; 
k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); + + if constexpr (vnni) { + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[NCols4]); + + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk4, acc[NCols4 + 1]); + + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk4, acc[NCols4 + 2]); + + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk4, acc[NCols4 + 3]); + } else { + accumulate_q8_blklen32_r2c1blk4_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_q8_blklen32_r2c1blk4_avx512(av_00_epi8, av_01_epi8, av_10_epi8, 
av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk4, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen32_r2c1blk4_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk4, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen32_r2c1blk4_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk4, acc[3], acc[NCols4 + 3]); + } + + // increment block pointers + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += StrideQuantBData * NCols4; + QuantBScalePtr += PerAccuBlk4 * NCols4; + } // k_blks_remaining + + __m256 acc2[NCols4 * NRows2] = { + h_add_512(acc[0]), + h_add_512(acc[1]), + h_add_512(acc[2]), + h_add_512(acc[3]), + h_add_512(acc[4]), + h_add_512(acc[5]), + h_add_512(acc[6]), + h_add_512(acc[7]) + }; + + for (; k_blks_remaining > 0; --k_blks_remaining) { + // load A + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + + const float scale_a00 = *QuantAScalePtr; + const float scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float scale_00 = scale_a00 * (QuantBScalePtr)[0], scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_q8_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc2[0], acc2[NCols4]); + + const float scale_01 = scale_a00 * (QuantBScalePtr + 1)[0], scale_11 = scale_a10 * (QuantBScalePtr + 1)[0]; + accumulate_q8_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes, scale_01, scale_11, acc2[1], acc2[NCols4 + 1]); + + const float scale_02 = scale_a00 * (QuantBScalePtr + 2)[0], scale_12 = scale_a10 * (QuantBScalePtr + 2)[0]; + 
accumulate_q8_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, scale_02, scale_12, acc2[2], acc2[NCols4 + 2]); + + const float scale_03 = scale_a00 * (QuantBScalePtr + 3)[0], scale_13 = scale_a10 * (QuantBScalePtr + 3)[0]; + accumulate_q8_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, scale_03, scale_13, acc2[3], acc2[NCols4 + 3]); + + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } // k_blks_remaining + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + __m128 acc_r1 = FoldAccumulators(acc2[NCols4 + 0], acc2[NCols4 + 1], acc2[NCols4 + 2], acc2[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + template void MLAS_FORCEINLINE Q4Int8GemmR2C1BlkLen32Avx512( @@ -548,8 +896,8 @@ Q4Int8GemmR2C1BlkLen32Avx512( } template -MLAS_FORCEINLINE void -Q4Int8GemmR1xC4BlkLen32Avx512( +void MLAS_FORCEINLINE +Q8Int8GemmR2C1BlkLen32Avx512( const std::byte* QuantA, const float* QuantAScale, const std::byte* QuantBData, @@ -559,69 +907,170 @@ Q4Int8GemmR1xC4BlkLen32Avx512( size_t CountN, size_t BlockCountK, const float* Bias, - size_t ldc -) + size_t ldc) { constexpr size_t BlkLen32 = 32; - constexpr size_t BlkBitWidth4 = 4; - constexpr size_t NCols4 = 4; - [[maybe_unused]] constexpr size_t NRows2 = 2; - constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - - // process 2 blks of 64 4b weights a time + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); constexpr size_t PerAccuBlk4 = 4; const size_t lda = BlockCountK * BlkLen32; - //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); - //const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; - assert(CountM < NRows2); - assert(CountN % NCols4 == 0); + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); - for (size_t m = 0; m < CountM; m++) { + for (size_t m = 0; m < CountM; m += NRows2) { const std::byte* QuantBDataColPtr = QuantBData; const float* QuantBScaleColPtr = QuantBScale; const float* BiasPtr = Bias; - auto* SumPtr = C + m * ldc; + float* SumPtr = C + m * ldc; - for (size_t n = 0; n < CountN; n += NCols4) { + for (size_t n = 0; n < CountN; n++) { const std::byte* QuantAPtr = QuantA + m * lda; const float* QuantAScalePtr = QuantAScale + m * BlockCountK; const std::byte* QuantBDataPtr = QuantBDataColPtr; const float* 
QuantBScalePtr = QuantBScaleColPtr; - __m512 acc[NCols4] = { - _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() - }; + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + // process 2 blks of 64 4b weights a time for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + lda + 64)); if constexpr (vnni) { - accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); - accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); - accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc1); } else { - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, 
av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); - accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + accumulate_q8_blklen32_r2c1blk4_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); } + // increment block pointers QuantAPtr += BlkLen32 * PerAccuBlk4; QuantAScalePtr += PerAccuBlk4; - QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4 * NCols4; - QuantBScalePtr += PerAccuBlk4 * NCols4; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; } - __m256 acc2[NCols4] = { - h_add_512(acc[0]), h_add_512(acc[1]), h_add_512(acc[2]), h_add_512(acc[3]) - }; - - while (k_blks_remaining-- > 0) { + __m256 acc20 = h_add_512(acc0); + __m256 acc21 = h_add_512(acc1); + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const __m256i av_10_epi8 = _mm256_loadu_si256((const __m256i*)(QuantAPtr + lda)); + + const float scale_a00 = *QuantAScalePtr; + const float scale_a10 = *(QuantAScalePtr + BlockCountK); + + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + const float scale_10 = scale_a10 * (QuantBScalePtr)[0]; + accumulate_q8_blklen32_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, scale_00, scale_10, acc20, acc21); + + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc20); + *(SumPtr + ldc) = hsum_float_8(acc21); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + +template +MLAS_FORCEINLINE void +Q4Int8GemmR1xC4BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + + // process 2 blks of 64 4b weights a time + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + //const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen32); + //const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, 
QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + } else { + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + } + + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4 * NCols4; + QuantBScalePtr += PerAccuBlk4 * NCols4; + } + + __m256 acc2[NCols4] = { + h_add_512(acc[0]), h_add_512(acc[1]), h_add_512(acc[2]), h_add_512(acc[3]) + }; + + while (k_blks_remaining-- > 0) { // load A const std::byte* QuantABlk0 = QuantAPtr; const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantABlk0); @@ -667,6 +1116,114 @@ Q4Int8GemmR1xC4BlkLen32Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + 
constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + } else { + accumulate_q8_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, 
acc[0]); + accumulate_q8_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + PerAccuBlk4, acc[1]); + accumulate_q8_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 2 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk4, acc[2]); + accumulate_q8_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr + 3 * PerAccuBlk4 * BlkDataSizeInBytes16, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk4, acc[3]); + } + + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes16 * PerAccuBlk4 * NCols4; + QuantBScalePtr += PerAccuBlk4 * NCols4; + } + + __m256 acc2[NCols4] = { + h_add_512(acc[0]), h_add_512(acc[1]), h_add_512(acc[2]), h_add_512(acc[3]) + }; + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + const float scale_a00 = *QuantAScalePtr; + + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_q8_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, scale_00, acc2[0]); + + const float scale_01 = scale_a00 * (QuantBScalePtr + 1)[0]; + accumulate_q8_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr + BlkDataSizeInBytes16, scale_01, acc2[1]); + + const float scale_02 = scale_a00 * (QuantBScalePtr + 2)[0]; + accumulate_q8_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes16, scale_02, acc2[2]); + + const float scale_03 = scale_a00 * (QuantBScalePtr + 3)[0]; + accumulate_q8_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes16, scale_03, acc2[3]); + + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes16 * NCols4; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc2[0], acc2[1], acc2[2], acc2[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + 
_mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes16; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen32Avx512( @@ -759,6 +1316,94 @@ Q4Int8GemmR1xC1BlkLen32Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1BlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + constexpr size_t PerAccuBlk4 = 4; + + const size_t lda = BlockCountK * BlkLen32; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk4; k_blks_remaining -= PerAccuBlk4) { + const __m512i av_00_epi8 = _mm512_loadu_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_loadu_si512((const __m512i*)(QuantAPtr + 64)); + + if constexpr (vnni) { + 
accumulate_q8_blklen32_r1c1blk4_avx512vnni(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } else { + accumulate_q8_blklen32_r1c1blk4_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + } + + QuantAPtr += BlkLen32 * PerAccuBlk4; + QuantAScalePtr += PerAccuBlk4; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk4; + QuantBScalePtr += PerAccuBlk4; + } + + __m256 acc2 = h_add_512(acc0); + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m256i av_00_epi8 = _mm256_loadu_si256((const __m256i*)QuantAPtr); + + const float scale_a00 = *QuantAScalePtr; + const float scale_00 = scale_a00 * (QuantBScalePtr)[0]; + accumulate_q8_blklen32_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, scale_00, acc2); + + QuantAPtr += BlkLen32; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_8(acc2); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE size_t @@ -850,3 +1495,93 @@ MlasQ4Int8GemmKernelBlkLen32Avx512( return CountM; } + +template +MLAS_FORCEINLINE +size_t +MlasQ8Int8GemmKernelBlkLen32Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen32 = 32; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen32 * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen32); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q8Int8GemmR2xC4BlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + if (remainingCols > 0 && multipleRows > 0) { + Q8Int8GemmR2C1BlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? 
Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q8Int8GemmR1xC4BlkLen32Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q8Int8GemmR1xC1BlkLen32Avx512( + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h index 72ce28d..68bf1da 100644 --- a/src/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h +++ b/src/lib/sqnbitgemm_kernel_avx512_int8_blklen64.h @@ -6,6 +6,13 @@ #include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" +MLAS_DECLSPEC_ALIGN(static const uint32_t MasksAvx512BlkLen64[32], 64) = { + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, 0x00ff00ff, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, + 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00, 0xff00ff00 +}; + static MLAS_FORCEINLINE __m256 h_add_512(__m512 a) { @@ -125,7 +132,7 @@ dot_accumulate_2blkvnni( __m512i t1_16_epi32 = _mm512_unpacklo_epi32(dot0_16_epi32, dot1_16_epi32); __m512i t2_16_epi32 = _mm512_unpackhi_epi32(dot0_16_epi32, dot1_16_epi32); - __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // sum for blk: 0 0 1 1 0 0 1 1... 
+ __m512i sum_16_epi32 = _mm512_add_epi32(t1_16_epi32, t2_16_epi32); // sum for blk: 0 1 0 1 0 1 0 1... __m512 sum_16_ps = _mm512_cvtepi32_ps(sum_16_epi32); __m256 scale_a_8_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a)); @@ -182,6 +189,146 @@ accumulate_blklen64_r2c1blk2_avx512( } } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen64_r1c1blk2_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m512& acc0 +) +{ + __m512i bv0_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr)); + __m512i bv1_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr + 64)); + + const __m256 scale_b_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x8(scale_b_ps); + + if constexpr (vnni) { + dot_accumulate_2blkvnni(av00_64_epi8, av01_64_epi8, scale_a0, bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, acc0); + } else { + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64 + 16)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m256 scale_a0_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + const __m512 scale_a0_16_ps = _mm512_broadcast_f32x8(scale_a0_ps); + const __m512 scale_a0b_16_ps = _mm512_mul_ps(scale_b_16_ps, scale_a0_16_ps); + + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i 
dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i dot00_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); + const __m512i dot00_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); + const __m512i dot01_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + + const __m512i t01 = _mm512_unpacklo_epi32(dot00_16_epi32, dot01_16_epi32); // 01010101 01010101 + const __m512i t02 = _mm512_unpackhi_epi32(dot00_16_epi32, dot01_16_epi32); + const __m512i sum0_16_epi32 = _mm512_add_epi32(t01, t02); + + const __m512 sum0_16_ps = _mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, scale_a0b_16_ps, acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen64_r2c1blk2_avx512( + const __m512i& av00_64_epi8, + const __m512i& av01_64_epi8, + const __m512i& av10_64_epi8, + const __m512i& av11_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv0_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr)); + __m512i bv1_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr + 64)); + + const __m256 scale_b_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_b)); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x8(scale_b_ps); + + if constexpr (vnni) { + dot_accumulate_2blkvnni(av00_64_epi8, av01_64_epi8, scale_a0, bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, acc0); + dot_accumulate_2blkvnni(av10_64_epi8, av11_64_epi8, scale_a1, bv0_64_epi8, bv1_64_epi8, scale_b_16_ps, 
acc1); + } else { + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64 + 16)); + const __m512i bv0_low_64_epi8 = _mm512_and_si512(bv0_64_epi8, low_mask); + const __m512i bv0_high_64_epi8 = _mm512_and_si512(bv0_64_epi8, high_mask); + const __m512i bv1_low_64_epi8 = _mm512_and_si512(bv1_64_epi8, low_mask); + const __m512i bv1_high_64_epi8 = _mm512_and_si512(bv1_64_epi8, high_mask); + const __m512i one_32_epi16 = generate_ones_32_epi16(); + + // row 0 + const __m256 scale_a0_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a0)); + const __m512 scale_a0_16_ps = _mm512_broadcast_f32x8(scale_a0_ps); + const __m512 scale_a0b_16_ps = _mm512_mul_ps(scale_b_16_ps, scale_a0_16_ps); + + const __m512i dot00_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av00_64_epi8); + const __m512i dot00_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av00_64_epi8); + const __m512i dot01_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av01_64_epi8); + const __m512i dot01_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av01_64_epi8); + + const __m512i dot00_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_low_32_epi16); + const __m512i dot00_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot00_high_32_epi16); + const __m512i dot01_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_low_32_epi16); + const __m512i dot01_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot01_high_32_epi16); + + const __m512i dot00_16_epi32 = _mm512_add_epi32(dot00_low_16_epi32, dot00_high_16_epi32); + const __m512i dot01_16_epi32 = _mm512_add_epi32(dot01_low_16_epi32, dot01_high_16_epi32); + + const __m512i t01 = _mm512_unpacklo_epi32(dot00_16_epi32, dot01_16_epi32); // 01010101 01010101 + const __m512i t02 = _mm512_unpackhi_epi32(dot00_16_epi32, dot01_16_epi32); + const __m512i sum0_16_epi32 = _mm512_add_epi32(t01, t02); + + const __m512 sum0_16_ps = 
_mm512_cvtepi32_ps(sum0_16_epi32); + acc0 = _mm512_fmadd_ps(sum0_16_ps, scale_a0b_16_ps, acc0); + + // row 1 + const __m256 scale_a1_ps = _mm256_castpd_ps(_mm256_broadcast_sd((double*)scale_a1)); + const __m512 scale_a1_16_ps = _mm512_broadcast_f32x8(scale_a1_ps); + const __m512 scale_a1b_16_ps = _mm512_mul_ps(scale_b_16_ps, scale_a1_16_ps); + + const __m512i dot10_low_32_epi16 = _mm512_maddubs_epi16(bv0_low_64_epi8, av10_64_epi8); + const __m512i dot10_high_32_epi16 = _mm512_maddubs_epi16(bv0_high_64_epi8, av10_64_epi8); + const __m512i dot11_low_32_epi16 = _mm512_maddubs_epi16(bv1_low_64_epi8, av11_64_epi8); + const __m512i dot11_high_32_epi16 = _mm512_maddubs_epi16(bv1_high_64_epi8, av11_64_epi8); + + const __m512i dot10_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_low_32_epi16); + const __m512i dot10_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot10_high_32_epi16); + const __m512i dot11_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot11_low_32_epi16); + const __m512i dot11_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot11_high_32_epi16); + + const __m512i dot10_16_epi32 = _mm512_add_epi32(dot10_low_16_epi32, dot10_high_16_epi32); + const __m512i dot11_16_epi32 = _mm512_add_epi32(dot11_low_16_epi32, dot11_high_16_epi32); + + const __m512i t11 = _mm512_unpacklo_epi32(dot10_16_epi32, dot11_16_epi32); + const __m512i t12 = _mm512_unpackhi_epi32(dot10_16_epi32, dot11_16_epi32); + const __m512i sum1_16_epi32 = _mm512_add_epi32(t11, t12); + + const __m512 sum1_16_ps = _mm512_cvtepi32_ps(sum1_16_epi32); + acc1 = _mm512_fmadd_ps(sum1_16_ps, scale_a1b_16_ps, acc1); + } +} + template static MLAS_FORCEINLINE void accumulate_blklen64_r1c1blk2_avx512( @@ -283,6 +430,112 @@ accumulate_blklen64_r2c1blk1_avx512( } } +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen64_r1c1blk1_avx512( + const __m512i& av0_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_b, + __m512& acc0 +) +{ + __m512i bv_64_epi8 = 
_mm512_load_si512(reinterpret_cast(QuantBDataPtr)); + const __m128 scale_b_ps = _mm_broadcast_ss(scale_b); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x2(scale_b_ps); + + if constexpr (vnni) { + __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av0_64_epi8); + __m512 sum0_16_ps = _mm512_cvtepi32_ps(dot0_16_epi32); + __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); + __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); + acc0 = _mm512_fmadd_ps(sum0_16_ps, _mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); + } else { + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64 + 16)); + __m512i bv_low_64_epi8 = _mm512_and_si512(bv_64_epi8, low_mask); + __m512i bv_high_64_epi8 = _mm512_and_si512(bv_64_epi8, high_mask); + + // row 0 + __m512i dot0_low_32_epi16 = _mm512_maddubs_epi16(bv_low_64_epi8, av0_64_epi8); + __m512i dot0_high_32_epi16 = _mm512_maddubs_epi16(bv_high_64_epi8, av0_64_epi8); + __m512i dot0_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot0_low_32_epi16); + __m512i dot0_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot0_high_32_epi16); + __m512i dot0_16_epi32 = _mm512_add_epi32(dot0_low_16_epi32, dot0_high_16_epi32); + __m512 sum0_16_ps = _mm512_cvtepi32_ps(dot0_16_epi32); + + __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); + __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); + + acc0 = _mm512_fmadd_ps(sum0_16_ps, _mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); + } +} + +template +static MLAS_FORCEINLINE void +accumulate_q8_blklen64_r2c1blk1_avx512( + const __m512i& av0_64_epi8, + const __m512i& av1_64_epi8, + const std::byte* QuantBDataPtr, + const float* scale_a0, + const float* scale_a1, + const float* scale_b, + __m512& acc0, + __m512& acc1 +) +{ + __m512i bv_64_epi8 = _mm512_load_si512(reinterpret_cast(QuantBDataPtr)); + 
const __m128 scale_b_ps = _mm_broadcast_ss(scale_b); + const __m512 scale_b_16_ps = _mm512_broadcast_f32x2(scale_b_ps); + + if constexpr (vnni) { + __m512i dot0_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av0_64_epi8); + __m512 sum0_16_ps = _mm512_cvtepi32_ps(dot0_16_epi32); + __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); + __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); + acc0 = _mm512_fmadd_ps(sum0_16_ps, _mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); + + __m512i dot1_16_epi32 = _mm512_dpbusd_epi32(_mm512_setzero_epi32(), bv_64_epi8, av1_64_epi8); + __m512 sum1_16_ps = _mm512_cvtepi32_ps(dot1_16_epi32); + __m128 scale_a1_ps = _mm_broadcast_ss(scale_a1); + __m512 scale_a1_16_ps = _mm512_broadcast_f32x2(scale_a1_ps); + acc1 = _mm512_fmadd_ps(sum1_16_ps, _mm512_mul_ps(scale_a1_16_ps, scale_b_16_ps), acc1); + } else { + const __m512i one_32_epi16 = generate_ones_32_epi16(); + const __m512i low_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64)); + const __m512i high_mask = _mm512_load_si512(reinterpret_cast(MasksAvx512BlkLen64 + 16)); + __m512i bv_low_64_epi8 = _mm512_and_si512(bv_64_epi8, low_mask); + __m512i bv_high_64_epi8 = _mm512_and_si512(bv_64_epi8, high_mask); + + // row 0 + __m512i dot0_low_32_epi16 = _mm512_maddubs_epi16(bv_low_64_epi8, av0_64_epi8); + __m512i dot0_high_32_epi16 = _mm512_maddubs_epi16(bv_high_64_epi8, av0_64_epi8); + __m512i dot0_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot0_low_32_epi16); + __m512i dot0_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot0_high_32_epi16); + __m512i dot0_16_epi32 = _mm512_add_epi32(dot0_low_16_epi32, dot0_high_16_epi32); + __m512 sum0_16_ps = _mm512_cvtepi32_ps(dot0_16_epi32); + + __m128 scale_a0_ps = _mm_broadcast_ss(scale_a0); + __m512 scale_a0_16_ps = _mm512_broadcast_f32x2(scale_a0_ps); + + acc0 = _mm512_fmadd_ps(sum0_16_ps, _mm512_mul_ps(scale_a0_16_ps, scale_b_16_ps), acc0); + + // row 1 + __m512i dot1_low_32_epi16 = 
_mm512_maddubs_epi16(bv_low_64_epi8, av1_64_epi8); + __m512i dot1_high_32_epi16 = _mm512_maddubs_epi16(bv_high_64_epi8, av1_64_epi8); + __m512i dot1_low_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot1_low_32_epi16); + __m512i dot1_high_16_epi32 = _mm512_madd_epi16(one_32_epi16, dot1_high_32_epi16); + __m512i dot1_16_epi32 = _mm512_add_epi32(dot1_low_16_epi32, dot1_high_16_epi32); + __m512 sum1_16_ps = _mm512_cvtepi32_ps(dot1_16_epi32); + + __m128 scale_a1_ps = _mm_broadcast_ss(scale_a1); + __m512 scale_a1_16_ps = _mm512_broadcast_f32x2(scale_a1_ps); + + acc1 = _mm512_fmadd_ps(sum1_16_ps, _mm512_mul_ps(scale_a1_16_ps, scale_b_16_ps), acc1); + } +} + template static MLAS_FORCEINLINE void accumulate_blklen64_r1c1blk1_avx512( @@ -448,6 +701,106 @@ Q4Int8GemmR2xC4BlkLen64Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4BlkLen64Avx512( + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkLen64 = 64; + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen64); + + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen64; + const size_t StrideQuantBData = PerAccuBlk2 * BlkDataSizeInBytes; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + 
__m512 acc[NCols4 * NRows2] = { + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), + _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps() + }; + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m512i av_00_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda + 64)); + + accumulate_q8_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_q8_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + PerAccuBlk2, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 2 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2 * PerAccuBlk2, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen64_r2c1blk2_avx512(av_00_epi8, av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr + 3 * StrideQuantBData, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3 * PerAccuBlk2, acc[3], acc[NCols4 + 3]); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += StrideQuantBData * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + } // k_blks_remaining + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m512i av_00_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av_10_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda)); + + 
accumulate_q8_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc[0], acc[NCols4]); + accumulate_q8_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 1, acc[1], acc[NCols4 + 1]); + accumulate_q8_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 2, acc[2], acc[NCols4 + 2]); + accumulate_q8_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr + 3, acc[3], acc[NCols4 + 3]); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + __m128 acc_r1 = FoldAccumulators(acc[NCols4 + 0], acc[NCols4 + 1], acc[NCols4 + 2], acc[NCols4 + 3]); + if (BiasPtr != nullptr) { + const __m128 bias_4_ps = _mm_loadu_ps(BiasPtr); + acc_r0 = _mm_add_ps(acc_r0, bias_4_ps); + acc_r1 = _mm_add_ps(acc_r1, bias_4_ps); + } + _mm_storeu_ps(SumPtr, acc_r0); + _mm_storeu_ps(SumPtr + ldc, acc_r1); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + template void MLAS_FORCEINLINE Q4Int8GemmR2xC1BlkLen64Avx512( @@ -540,6 +893,95 @@ Q4Int8GemmR2xC1BlkLen64Avx512( } } +template +void MLAS_FORCEINLINE +Q8Int8GemmR2xC1BlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t BlkLen64 = 64; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM % NRows2 == 0); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m += NRows2) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + float* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(), acc1 = _mm512_setzero_ps(); + + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m512i av_00_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + 64)); + const __m512i av_10_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda)); + const __m512i av_11_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda + 64)); + + accumulate_q8_blklen64_r2c1blk2_avx512(av_00_epi8, 
av_01_epi8, av_10_epi8, av_11_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m512i av_00_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av_10_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + lda)); + + accumulate_q8_blklen64_r2c1blk1_avx512(av_00_epi8, av_10_epi8, QuantBDataPtr, QuantAScalePtr, QuantAScalePtr + BlockCountK, QuantBScalePtr, acc0, acc1); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + *(SumPtr + ldc) = hsum_float_16(acc1); + if (BiasPtr) { + *SumPtr += *BiasPtr; + *(SumPtr + ldc) += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC4BlkLen64Avx512( @@ -633,6 +1075,94 @@ Q4Int8GemmR1xC4BlkLen64Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4BlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkLen64 = 64; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen; + + assert(CountM < NRows2); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc[NCols4] = {_mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps(), _mm512_setzero_ps()}; + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining >= PerAccuBlk2; k_blks_remaining -= PerAccuBlk2) { + const __m512i av0_64_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av1_64_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + 64)); + accumulate_q8_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + PerAccuBlk2, acc[1]); + 
accumulate_q8_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 2 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2 * PerAccuBlk2, acc[2]); + accumulate_q8_blklen64_r1c1blk2_avx512(av0_64_epi8, av1_64_epi8, QuantBDataPtr + 3 * PerAccuBlk2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3 * PerAccuBlk2, acc[3]); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantAScalePtr += PerAccuBlk2; + QuantBDataPtr += PerAccuBlk2 * BlkDataSizeInBytes * NCols4; + QuantBScalePtr += PerAccuBlk2 * NCols4; + } + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m512i av_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + + accumulate_q8_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc[0]); + accumulate_q8_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 1, acc[1]); + accumulate_q8_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 2 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 2, acc[2]); + accumulate_q8_blklen64_r1c1blk1_avx512(av_epi8, QuantBDataPtr + 3 * BlkDataSizeInBytes, QuantAScalePtr, QuantBScalePtr + 3, acc[3]); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes * NCols4; + QuantBScalePtr += NCols4; + } + + __m128 acc_r0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + if (BiasPtr != nullptr) { + acc_r0 = _mm_add_ps(acc_r0, _mm_loadu_ps(BiasPtr)); + } + + _mm_storeu_ps(SumPtr, acc_r0); + + // move to next NCols columns + QuantBDataColPtr += NCols4 * BlockCountK * BlkDataSizeInBytes; + QuantBScaleColPtr += NCols4 * BlockCountK; + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + template MLAS_FORCEINLINE void Q4Int8GemmR1xC1BlkLen64Avx512( @@ -718,6 +1248,88 @@ Q4Int8GemmR1xC1BlkLen64Avx512( } } +template +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1BlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + [[maybe_unused]] constexpr size_t NCols4 = 4; + [[maybe_unused]] constexpr size_t NRows2 = 2; + constexpr size_t BlkLen64 = 64; + const size_t BlkDataSizeInBytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + constexpr size_t PerAccuBlk2 = 2; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBData = BlockCountK * BlkDataSizeInBytes; + const size_t StrideQuantBScale = BlockCountK; + + assert(CountM < NRows2); + assert(CountN < NCols4); + + for (size_t m = 0; m < CountM; m++) { + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n++) { + const std::byte* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + const std::byte* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + __m512 acc0 = _mm512_setzero_ps(); + size_t k_blks_remaining = BlockCountK; + for (; k_blks_remaining > 1; k_blks_remaining -= PerAccuBlk2) { + const __m512i av_00_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + const __m512i av_01_epi8 = _mm512_load_si512((const __m512i*)(QuantAPtr + 64)); + + accumulate_q8_blklen64_r1c1blk2_avx512(av_00_epi8, av_01_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + + // increment block pointers + QuantAPtr += BlkLen64 * PerAccuBlk2; + QuantBDataPtr += BlkDataSizeInBytes * PerAccuBlk2; + 
QuantAScalePtr += PerAccuBlk2; + QuantBScalePtr += PerAccuBlk2; + } + + for (; k_blks_remaining > 0; --k_blks_remaining) { + const __m512i av_00_epi8 = _mm512_load_si512((const __m512i*)QuantAPtr); + + accumulate_q8_blklen64_r1c1blk1_avx512(av_00_epi8, QuantBDataPtr, QuantAScalePtr, QuantBScalePtr, acc0); + + QuantAPtr += BlkLen64; + QuantAScalePtr++; + QuantBDataPtr += BlkDataSizeInBytes; + QuantBScalePtr++; + } + + *SumPtr = hsum_float_16(acc0); + if (BiasPtr) { + *SumPtr += *BiasPtr; + } + + // move to next column + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } + } +} + template MLAS_FORCEINLINE size_t MlasQ4Int8GemmKernelBlkLen64Avx512( @@ -838,3 +1450,96 @@ MlasQ4Int8GemmKernelBlkLen64Avx512( return CountM; } + +template +MLAS_FORCEINLINE size_t +MlasQ8Int8GemmKernelBlkLen64Avx512( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + + const size_t lda = BlockCountK * BlkLen * sizeof(int8_t); + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t remainingCols = CountN % NCols4; + size_t multipleCols = CountN - remainingCols; + + if (multipleRows > 0 && multipleCols > 0) { + Q8Int8GemmR2xC4BlkLen64Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols, + BlockCountK, + Bias, + ldc + ); + } + + if (remainingCols > 0 && multipleRows > 0) { + Q8Int8GemmR2xC1BlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + 
QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleCols, + multipleRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + if (remainingRows > 0 && multipleCols > 0) { + Q8Int8GemmR1xC4BlkLen64Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols, + BlockCountK, + Bias, + ldc); + } + + if (remainingCols > 0 && remainingRows > 0) { + Q8Int8GemmR1xC1BlkLen64Avx512( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols * StrideQuantBData, + QuantBScale + multipleCols * StrideQuantBScale, + C + multipleRows * ldc + multipleCols, + remainingRows, + remainingCols, + BlockCountK, + Bias ? Bias + multipleCols : nullptr, + ldc); + } + + return CountM; +} diff --git a/src/lib/sqnbitgemm_kernel_avx512vnni.cpp b/src/lib/sqnbitgemm_kernel_avx512vnni.cpp index a4468bb..ea5eebd 100644 --- a/src/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/src/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -299,6 +299,99 @@ SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni( return CountM; } +MLAS_FORCEINLINE +size_t +SQ8BitGemmKernel_BlkSum_CompInt8_avx512vnni( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* /*QuantBZeroPoint*/, + float* C, + size_t CountM, + size_t CountN, + size_t /*CountK*/, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum +) +{ + if (BlkLen == 16) { + MlasQ8Int8GemmKernelBlkLen16Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 32) { + MlasQ8Int8GemmKernelBlkLen32Avx512( + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + 
BlockCountK, + Bias, + ldc + ); + } else if (BlkLen == 64) { + MlasQ8Int8GemmKernelBlkLen64Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } else { + MlasQ8Int8GemmKernelBlkLen128Avx512( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + CountM, + CountN, + BlockCountK, + Bias, + ldc + ); + } + + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = GetMlasPlatform().GemmFloatKernel( + a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false + ); + + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + return CountM; +} + void MLASCALL QuantizeARow_CompInt8_avx512( size_t BlkLen, @@ -317,9 +410,35 @@ SQ4BitGemmPackQuantBDataAndBlkSum512vnni( MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, - bool has_zp_input, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 
32 : 64); + if (ComputeType == SQNBIT_CompInt8) { + SubBlkLen = 128; + } + PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, + HasZeroPoint, QuantBZPBegin, PackedQuantB, ThreadPool); +} + +static void +SQ8BitGemmPackQuantBDataAndBlkSum512vnni( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& PackedQuantB, MLAS_THREADPOOL* ThreadPool ) { @@ -331,7 +450,8 @@ SQ4BitGemmPackQuantBDataAndBlkSum512vnni( if (ComputeType == SQNBIT_CompInt8) { SubBlkLen = 128; } - PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, has_zp_input, QuantBZPBegin, packed_quant_b, ThreadPool); + Q8PackQuantBDataAndBlkSum(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, + HasZeroPoint, QuantBZPBegin, PackedQuantB, ThreadPool); } // @@ -340,17 +460,20 @@ SQ4BitGemmPackQuantBDataAndBlkSum512vnni( const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() { MLAS_QNBIT_GEMM_DISPATCH d; - d.Q4BitGemmPackQuantBDataSize = Q4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<4>; + d.Q8BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize<8>; d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum512vnni; + d.SQ8BitGemmPackQuantBDataAndBlkSum = SQ8BitGemmPackQuantBDataAndBlkSum512vnni; - d.Q4BitGemmPerGemmWorkspaceSize = Q4BitGemmPerGemmWorkspaceSize; - d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment; + d.QNBitGemmPerGemmWorkspaceSize = QNBitGemmPerGemmWorkspaceSize; + d.QNBitGemmPerGemmWorkspaceAlignment = QNBitGemmPerGemmWorkspaceAlignment; d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32; 
d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2; d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni; + d.SQ8BitGemmKernel_BlkSum_CompInt8 = SQ8BitGemmKernel_BlkSum_CompInt8_avx512vnni; d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512; return d; diff --git a/src/lib/sqnbitgemm_kernel_avx_common.h b/src/lib/sqnbitgemm_kernel_avx_common.h index b0367b7..bb38f37 100644 --- a/src/lib/sqnbitgemm_kernel_avx_common.h +++ b/src/lib/sqnbitgemm_kernel_avx_common.h @@ -6,23 +6,24 @@ // Quantized B data packing function implementation. // +template static size_t -Q4BitGemmPackQuantBDataSize( +QNBitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkLen, + bool /* HasZeroPoint */, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { - constexpr size_t BlkBitWidth = 4; const size_t BlockCountK = MlasDivRoundup(K, BlkLen); if (ComputeType == SQNBIT_CompInt8) { size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); const size_t ScaleSize = N * BlockCountK * sizeof(float); size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); - // _mm256_load_si256 requires alignment on a 32-byte boundary - constexpr size_t PackedQuantBDataAlignment = 32; + // avx512 requires alignment on a 64-byte boundary + constexpr size_t PackedQuantBDataAlignment = 64; PackedQuantBDataSize += PackedQuantBDataAlignment - 1; constexpr size_t BlkSumAlignment = MlasQNBitQuantBBlkSumAlignment(); BlkSumSize += BlkSumAlignment - 1; @@ -246,6 +247,60 @@ PackQuantB( ); } +static void +Q8PackQuantB( + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t N, + const size_t BlockCountK, + const size_t BlkLen, + const size_t SubBlkLen) +{ + constexpr size_t BlkBitWidth = 8; + const size_t StrideN = BlockCountK * BlkLen; + const size_t BlkSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t SubBlkSize = 
MlasQNBitBlkDataSizeInBytes(BlkBitWidth, SubBlkLen); + const size_t SubBlkCountK = MlasDivRoundup(StrideN, SubBlkLen); + const size_t RemainderBlockCountK = BlockCountK % (SubBlkLen > BlkLen ? SubBlkLen / BlkLen : 1); + const size_t Iterations = N * SubBlkCountK; // one iteration per sub block + + // SubBlkLen rows x 4 columns pack together, then remainder BlkLen x 4 columns if SubBlkLen > BlkLen. + // remainder columns keep the original order. + // SubBlkLen >= 16 and is multiple of 16 + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t c = tid / SubBlkCountK; + const size_t c_4 = c & (~3), c_res = c & 3; + const size_t r_subblk = tid % SubBlkCountK; + + const std::byte* src = QuantBDataBegin + c * StrideN + r_subblk * SubBlkLen; + + if (c_4 + 4 <= N) { // full 4 cols + if (RemainderBlockCountK && r_subblk == SubBlkCountK - 1) { // remainder blocks + std::byte* dest = + PackedQuantBDataBegin + c_4 * StrideN + r_subblk * SubBlkSize * 4 + c_res * BlkSize; + for (size_t i = 0; i < RemainderBlockCountK; i++) { + std::copy(src, src + BlkSize, dest); + src += BlkSize; + dest += BlkSize * 4; + } + } else { // full subblock + std::byte* dest = + PackedQuantBDataBegin + c_4 * StrideN + r_subblk * SubBlkSize * 4 + c_res * SubBlkSize; + std::copy(src, src + SubBlkSize, dest); + } + } else { // remainder cols + std::byte* dest = + PackedQuantBDataBegin + c * StrideN + r_subblk * SubBlkSize; + std::copy(src, src + std::min(SubBlkSize, StrideN - r_subblk * SubBlkSize), dest); + } + } + ); +} + //#include static void @@ -294,6 +349,61 @@ ComputePackBlkSum( ); } +static void +Q8ComputePackBlkSum( + size_t BlkLen, + size_t SubBlkLen, + size_t N, + float* QuantBScaleBegin, + const std::byte* QuantBZPBegin, + float* BlockSumBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t BlockCountK) +{ + std::vector QuantBScaleBeginCopy(N * BlockCountK); + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, QuantBScaleBeginCopy.begin()); + + 
MlasTrySimpleParallel(ThreadPool, N * BlockCountK, [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t n_4 = n & (~3), n_res = n & 3; + const size_t k_blk = tid % BlockCountK; + + const size_t src_blk_offset = n * BlockCountK + k_blk; + const float& QuantBScale = QuantBScaleBeginCopy[src_blk_offset]; + uint8_t zp = 128; + if (QuantBZPBegin) { + const std::byte* QuantBZP = QuantBZPBegin + src_blk_offset; + zp = (uint8_t)(*QuantBZP); + } + + // BlockSum is a width 16 row major matrix + const size_t dst_offset = ((n / 16) * BlockCountK + k_blk) * 16 + n % 16; + *(BlockSumBegin + dst_offset) = -QuantBScale * zp; + + // re-arrange scale to the same order as packed data + if (n_4 + 4 > N) { + *(QuantBScaleBegin + n * BlockCountK + k_blk) = QuantBScale; + } else if (BlkLen >= SubBlkLen) { + *(QuantBScaleBegin + n_4 * BlockCountK + k_blk * 4 + n_res) = QuantBScale; + } else { + size_t blks_per_sub = SubBlkLen / BlkLen; + size_t remainder_blk = BlockCountK % blks_per_sub; + size_t sub_blk_count_k = MlasDivRoundup(BlockCountK, blks_per_sub); + size_t k_subblk = k_blk / blks_per_sub; + size_t k_blk_res = k_blk % blks_per_sub; + size_t dest_offset; + + if (remainder_blk && k_subblk == sub_blk_count_k - 1) { // remainder blocks + dest_offset = n_4 * BlockCountK + k_blk * 4 + n_res; + } else { // full subblock + dest_offset = n_4 * BlockCountK + k_subblk * blks_per_sub * 4 + n_res * blks_per_sub + k_blk_res; + } + + *(QuantBScaleBegin + dest_offset) = QuantBScale; + } + }); +} + static void PackQuantBDataAndBlkSum( size_t N, @@ -302,22 +412,49 @@ PackQuantBDataAndBlkSum( size_t SubBlkLen, const std::byte* QuantBDataBegin, const float* QuantBScaleBegin, - bool has_zp_input, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool +) +{ + if (QuantBDataBegin) { + PackQuantB(QuantBDataBegin, PackedQuantB.PackedQuantBData, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); + } + + if 
(QuantBScaleBegin) { + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, PackedQuantB.PackedQuantBScale); + } + + if ((QuantBScaleBegin && !HasZeroPoint) || QuantBZPBegin) { + ComputePackBlkSum(BlkLen, SubBlkLen, N, PackedQuantB.PackedQuantBScale, QuantBZPBegin, PackedQuantB.QuantBBlkSum, ThreadPool, BlockCountK); + } +} + +static void +Q8PackQuantBDataAndBlkSum( + size_t N, + size_t BlockCountK, + size_t BlkLen, + size_t SubBlkLen, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, const std::byte* QuantBZPBegin, - PackedQuantBDataStruct& packed_quant_b, + PackedQuantBDataStruct& PackedQuantB, MLAS_THREADPOOL* ThreadPool ) { if (QuantBDataBegin) { - PackQuantB(QuantBDataBegin, packed_quant_b.PackedQuantBData, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); + Q8PackQuantB(QuantBDataBegin, PackedQuantB.PackedQuantBData, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); } if (QuantBScaleBegin) { - std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, packed_quant_b.PackedQuantBScale); + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, PackedQuantB.PackedQuantBScale); } - if ((QuantBScaleBegin && !has_zp_input) || QuantBZPBegin) { - ComputePackBlkSum(BlkLen, SubBlkLen, N, packed_quant_b.PackedQuantBScale, QuantBZPBegin, packed_quant_b.QuantBBlkSum, ThreadPool, BlockCountK); + if ((QuantBScaleBegin && !HasZeroPoint) || QuantBZPBegin) { + Q8ComputePackBlkSum(BlkLen, SubBlkLen, N, PackedQuantB.PackedQuantBScale, QuantBZPBegin, PackedQuantB.QuantBBlkSum, ThreadPool, BlockCountK); } } @@ -326,11 +463,12 @@ PackQuantBDataAndBlkSum( // static size_t -Q4BitGemmPerGemmWorkspaceSize( +QNBitGemmPerGemmWorkspaceSize( size_t M, size_t N, size_t K, size_t BlkLen, + bool /* HasZeroPoint */, MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { @@ -351,7 +489,7 @@ Q4BitGemmPerGemmWorkspaceSize( } static size_t -Q4BitGemmPerGemmWorkspaceAlignment( +QNBitGemmPerGemmWorkspaceAlignment( size_t BlkLen, 
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) diff --git a/src/lib/sqnbitgemm_kernel_neon_int8.cpp b/src/lib/sqnbitgemm_kernel_neon_int8.cpp index 73beb06..8dbd339 100644 --- a/src/lib/sqnbitgemm_kernel_neon_int8.cpp +++ b/src/lib/sqnbitgemm_kernel_neon_int8.cpp @@ -1,7 +1,6 @@ /*++ Copyright (c) Microsoft Corporation. All rights reserved. - Licensed under the MIT License. Module Name: @@ -20,11 +19,17 @@ Module Name: #include #include +#include #include "qnbitgemm.h" #include "qnbitgemm_kernel_neon.h" #include "sqnbitgemm_q8_block.h" +#ifdef USE_KLEIDIAI +#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" +#include "kai_ukernel_interface.h" +#endif + namespace sqnbitgemm_neon { @@ -126,6 +131,41 @@ QuantizeBlock( } // namespace +bool +UsePacked_CompInt8(size_t K, size_t BlkLen, bool HasZp) +{ + return UseKleidiAI(K, BlkLen, HasZp); +} + +#ifdef USE_KLEIDIAI +void +QuantizeA_Packed_CompInt8( + size_t, + const float* A, + size_t CountM, + size_t CountK, + std::byte* QuantA +) +{ + const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel = + CountM == 1? 
GetKleidiAIGemvUKernel() : GetKleidiAIGemmUKernel(); + + const size_t mr = ukernel.get_mr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + + const size_t src_stride = CountK * sizeof(float); + const size_t lhs_offset = kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(0, src_stride); + const size_t lhs_packed_offset = kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32( + 0, CountK, mr, kr, sr); + + const float* src_ptr = reinterpret_cast(reinterpret_cast(A) + lhs_offset); + void* dst_ptr = QuantA + lhs_packed_offset; + + kai_run_lhs_quant_pack_qai8dxp_f32(CountM, CountK, mr, kr, sr, 0, src_ptr, src_stride, dst_ptr); +} +#endif + void QuantizeARow_CompInt8( size_t BlkLen, @@ -1399,4 +1439,47 @@ SQ4BitGemmKernel_CompInt8( return CountM; } +#ifdef USE_KLEIDIAI +void +SQ4BitGemmKernel_Packed_CompInt8( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* PackedQuantBData, + float* C, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN, + size_t CountK, + size_t ldc, + const float* Bias +) +{ + const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel ukernel = + RangeCountM == 1 && RangeStartM == 0? 
GetKleidiAIGemvUKernel() : GetKleidiAIGemmUKernel(); + + const size_t dst_stride = ldc * sizeof(float); + + const size_t lhs_packed_offset = ukernel.get_lhs_packed_offset(RangeStartM, CountK); + const size_t rhs_packed_offset = ukernel.get_rhs_packed_offset(RangeStartN, CountK, BlkLen); + const size_t dst_offset = ukernel.get_dst_offset(RangeStartM, RangeStartN, dst_stride); + + const void* lhs_ptr = QuantA + lhs_packed_offset; + const void* rhs_ptr = PackedQuantBData + rhs_packed_offset; + float* dst_ptr = reinterpret_cast(reinterpret_cast(C) + dst_offset); + + ukernel.run_matmul( + RangeCountM, RangeCountN, CountK, BlkLen, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), + -std::numeric_limits::max(), std::numeric_limits::max()); + + if (Bias != nullptr) { + for (size_t m = RangeStartM; m < RangeStartM + RangeCountM; m++) { + for (size_t n = RangeStartN; n < RangeStartN + RangeCountN; n++) { + C[m * ldc + n] += Bias[n]; + } + } + } +} +#endif + } // namespace sqnbitgemm_neon diff --git a/src/lib/sqnbitgemm_q8_block.h b/src/lib/sqnbitgemm_q8_block.h index 80af2f4..9dc217c 100644 --- a/src/lib/sqnbitgemm_q8_block.h +++ b/src/lib/sqnbitgemm_q8_block.h @@ -66,5 +66,5 @@ MLAS_FORCEINLINE constexpr size_t Q8BlkAlignment() { - return alignof(float); + return 16 * alignof(float); } diff --git a/src/lib/tanh.cpp b/src/lib/tanh.cpp index 9750337..63bd744 100644 --- a/src/lib/tanh.cpp +++ b/src/lib/tanh.cpp @@ -21,6 +21,7 @@ Module Name: --*/ #include "mlasi.h" +#include "softmax.h" // // Bundles the floating point constants for use by kernels written in assembly. 
@@ -119,9 +120,9 @@ Return Value: float Value = *Input++; - // This odd two-step process exists to ensure an input value of NaN carries through - // without modification because "std::min" and "std::max" return unreliable results - // when NaNs are involved, and it's clear from the test's reference outputs that + // This odd two-step process exists to ensure an input value of NaN carries through + // without modification because "std::min" and "std::max" return unreliable results + // when NaNs are involved, and it's clear from the test's reference outputs that // they want a NaN on output whenever the input is a NaN. float v_tmp; v_tmp = (Value < MlasTanhConstants.LowerRange) ? MlasTanhConstants.LowerRange : Value; @@ -149,9 +150,10 @@ Return Value: } } +template <> void MLASCALL -MlasComputeTanh( +MlasComputeTanh( const float* Input, float* Output, size_t N @@ -182,3 +184,50 @@ Return Value: MlasTanhKernel(Input, Output, N); #endif } + +template <> +void +MLASCALL +MlasComputeTanh( + const MLAS_FP16* Input, + MLAS_FP16* Output, + size_t N +) { + const auto* dispatch = GetMlasPlatform().SoftmaxDispatch; + if (dispatch == nullptr || dispatch->Tanh_Fp16 == nullptr) { + MLAS_THROW_EX(std::runtime_error, "Tanh_Fp16 is not supported."); + } + dispatch->Tanh_Fp16(Input, Output, N); +} + +template <> +void +MLASCALL +MlasComputeSoftcap( + const float* Input, + float* Output, + size_t N, + float cap +) { + for (size_t i = 0; i < N; i++) { + Output[i] = Input[i] / cap; + Output[i] = std::tanh(Output[i]); + Output[i] = Output[i] * cap; + } +} + +template <> +void +MLASCALL +MlasComputeSoftcap( + const MLAS_FP16* Input, + MLAS_FP16* Output, + size_t N, + MLAS_FP16 cap +) { + const auto* dispatch = GetMlasPlatform().SoftmaxDispatch; + if (dispatch == nullptr || dispatch->Softcap_Fp16 == nullptr) { + MLAS_THROW_EX(std::runtime_error, "Softcap_Fp16 is not supported."); + } + dispatch->Softcap_Fp16(Input, Output, N, cap); +} diff --git a/src/lib/transpose.cpp 
b/src/lib/transpose.cpp index a758a0e..61c3796 100644 --- a/src/lib/transpose.cpp +++ b/src/lib/transpose.cpp @@ -16,6 +16,20 @@ Module Name: #include "mlasi.h" +// +// Define the parameters to execute segments of a transpose operation on worker +// threads. +// + +template +struct MLAS_TRANPOSE_WORK_BLOCK { + ptrdiff_t ThreadCountM; + const ElementType* Input; + ElementType* Output; + size_t M; + size_t N; +}; + #if defined(MLAS_SSE2_INTRINSICS) MLAS_FORCEINLINE @@ -470,20 +484,20 @@ MlasTranspose8x8Block( __m128i c3 = __lsx_vilvh_h(b3, b2); __m128 d0 = (__m128)(__lsx_vilvl_w(c2, c0)); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 0], 0), __lsx_vpickve2gr_d(d0, 0), 0), (__m128i *)&Output[OutputStride * 0], 0); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 1], 0), __lsx_vpickve2gr_d(d0, 1), 0), (__m128i *)&Output[OutputStride * 1], 0); + __lsx_vstelm_d(d0, &Output[OutputStride * 0], 0, 0); + __lsx_vstelm_d(d0, &Output[OutputStride * 1], 0, 1); __m128 d1 = (__m128)(__lsx_vilvh_w(c2, c0)); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 2], 0), __lsx_vpickve2gr_d(d1, 0), 0), (__m128i *)&Output[OutputStride * 2], 0); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 3], 0), __lsx_vpickve2gr_d(d1, 1), 0), (__m128i *)&Output[OutputStride * 3], 0); + __lsx_vstelm_d(d1, &Output[OutputStride * 2], 0, 0); + __lsx_vstelm_d(d1, &Output[OutputStride * 3], 0, 1); __m128 d2 = (__m128)(__lsx_vilvl_w(c3, c1)); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 4], 0), __lsx_vpickve2gr_d(d2, 0), 0), (__m128i *)&Output[OutputStride * 4], 0); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 5], 0), __lsx_vpickve2gr_d(d2, 1), 0), (__m128i *)&Output[OutputStride * 5], 0); + __lsx_vstelm_d(d2, &Output[OutputStride * 4], 0, 0); + __lsx_vstelm_d(d2, &Output[OutputStride * 5], 0, 1); __m128 d3 = (__m128)(__lsx_vilvh_w(c3, c1)); - 
__lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 6], 0), __lsx_vpickve2gr_d(d3, 0), 0), (__m128i *)&Output[OutputStride * 6], 0); - __lsx_vst(__lsx_vinsgr2vr_d(__lsx_vld((__m128i *)&Output[OutputStride * 7], 0), __lsx_vpickve2gr_d(d3, 1), 0), (__m128i *)&Output[OutputStride * 7], 0); + __lsx_vstelm_d(d3, &Output[OutputStride * 6], 0, 0); + __lsx_vstelm_d(d3, &Output[OutputStride * 7], 0, 1); } #endif @@ -541,51 +555,69 @@ MlasTranspose8xNVector( MlasTranspose4xNVector(&Input[InputStride * 4], InputStride, &Output[OutputStride * 4], OutputStride); } +template void -MLASCALL -MlasTranspose( - const uint32_t* Input, - uint32_t* Output, - size_t M, - size_t N - ) +MlasTransposeThreaded( + void* Context, + ptrdiff_t ThreadId +); /*++ Routine Description: - This routine transposes the input matrix (M rows by N columns) to the - output matrix (N rows by M columns). + This routine is invoked from a worker thread to execute a segment of a transpose Arguments: - Input - Supplies the input buffer. - - Output - Supplies the output buffer. - - M - Supplies the number of rows for the input matrix and the number of - columns for the output matrix. + Context - Supplies the pointer to the context for the threaded operation. - N - Supplies the number of columns for the input matrix and the number of - rows for the output matrix. + ThreadId - Supplies the current index of the threaded operation. Return Value: None. --*/ + +template<> +void +MlasTransposeThreaded( + void* Context, + ptrdiff_t ThreadId + ) { - size_t n = N; + const auto* WorkBlock = (MLAS_TRANPOSE_WORK_BLOCK*)Context; + + // + // Partition the operation along the M dimension. + // + + size_t IndexM; + size_t CountM; + MlasPartitionWork(ThreadId, WorkBlock->ThreadCountM, WorkBlock->M, &IndexM, &CountM); + + // + // Set transpose parameters. 
+ // + + const size_t M = WorkBlock->M; + const size_t N = WorkBlock->N; + + const uint32_t* Input = WorkBlock->Input + IndexM * N; + uint32_t* Output = WorkBlock->Output + IndexM; // // Transpose elements from the input matrix to the output matrix 4 columns // at a time. // + size_t n = N; + while (n >= 4) { const uint32_t* s = Input; uint32_t* d = Output; - size_t m = M; + size_t m = CountM; #if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_TARGET_POWER) || \ defined(MLAS_LSX_INTRINSICS) @@ -624,7 +656,7 @@ Return Value: const uint32_t* s = Input; uint32_t* d = Output; - size_t m = M; + size_t m = CountM; while (m >= 4) { @@ -650,68 +682,45 @@ Return Value: } } +template<> void -MLASCALL -MlasTranspose( - const float* Input, - float* Output, - size_t M, - size_t N +MlasTransposeThreaded( + void* Context, + ptrdiff_t ThreadId ) { - MlasTranspose( - reinterpret_cast(Input), - reinterpret_cast(Output), - M, - N); -} - - -void -MLASCALL -MlasTranspose( - const uint16_t* Input, - uint16_t* Output, - size_t M, - size_t N - ) -/*++ - -Routine Description: - - This routine transposes the input matrix (M rows by N columns) to the - output matrix (N rows by M columns). - -Arguments: - - Input - Supplies the input buffer. + const auto* WorkBlock = (MLAS_TRANPOSE_WORK_BLOCK*)Context; - Output - Supplies the output buffer. - - M - Supplies the number of rows for the input matrix and the number of - columns for the output matrix. + // + // Partition the operation along the M dimension. + // - N - Supplies the number of columns for the input matrix and the number of - rows for the output matrix. + size_t IndexM; + size_t CountM; + MlasPartitionWork(ThreadId, WorkBlock->ThreadCountM, WorkBlock->M, &IndexM, &CountM); -Return Value: + // + // Set transpose parameters. + // - None. 
+ const size_t M = WorkBlock->M; + const size_t N = WorkBlock->N; ---*/ -{ - size_t n = N; + const uint16_t* Input = WorkBlock->Input + IndexM * N; + uint16_t* Output = WorkBlock->Output + IndexM; // // Transpose elements from the input matrix to the output matrix 4 columns // at a time. // + size_t n = N; + while (n >= 4) { const uint16_t* s = Input; uint16_t* d = Output; - size_t m = M; + size_t m = CountM; #if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_LSX_INTRINSICS) @@ -749,7 +758,7 @@ Return Value: const uint16_t* s = Input; uint16_t* d = Output; - size_t m = M; + size_t m = CountM; while (m >= 4) { @@ -775,52 +784,46 @@ Return Value: } } - +template<> void -MLASCALL -MlasTranspose( - const uint8_t* Input, - uint8_t* Output, - size_t M, - size_t N +MlasTransposeThreaded( + void* Context, + ptrdiff_t ThreadId ) -/*++ - -Routine Description: - - This routine transposes the input matrix (M rows by N columns) to the - output matrix (N rows by M columns). - -Arguments: - - Input - Supplies the input buffer. - - Output - Supplies the output buffer. +{ + const auto* WorkBlock = (MLAS_TRANPOSE_WORK_BLOCK*)Context; - M - Supplies the number of rows for the input matrix and the number of - columns for the output matrix. + // + // Partition the operation along the M dimension. + // - N - Supplies the number of columns for the input matrix and the number of - rows for the output matrix. + size_t IndexM; + size_t CountM; + MlasPartitionWork(ThreadId, WorkBlock->ThreadCountM, WorkBlock->M, &IndexM, &CountM); -Return Value: + // + // Set transpose parameters. + // - None. + const size_t M = WorkBlock->M; + const size_t N = WorkBlock->N; ---*/ -{ - size_t n = N; + const uint8_t* Input = WorkBlock->Input + IndexM * N; + uint8_t* Output = WorkBlock->Output + IndexM; // // Transpose elements from the input matrix to the output matrix 8 columns // at a time. 
// + + size_t n = N; + #if defined(MLAS_TARGET_POWER) while (n >= 16) { const uint8_t* s = Input; uint8_t* d = Output; - size_t m = M; + size_t m = CountM; while (m >= 16) { MlasTranspose16x16Block(s, N, d, M); @@ -848,7 +851,7 @@ Return Value: const uint8_t* s = Input; uint8_t* d = Output; - size_t m = M; + size_t m = CountM; #if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_LSX_INTRINSICS) @@ -886,7 +889,7 @@ Return Value: const uint8_t* s = Input; uint8_t* d = Output; - size_t m = M; + size_t m = CountM; while (m >= 8) { @@ -912,17 +915,140 @@ Return Value: } } +template void MLASCALL MlasTranspose( + const DataType* Input, + DataType* Output, + size_t M, + size_t N, + MLAS_THREADPOOL* ThreadPool + ) +/*++ + +Routine Description: + + This routine transposes the input matrix (M rows by N columns) to the + output matrix (N rows by M columns). + +Arguments: + + Input - Supplies the input buffer. + + Output - Supplies the output buffer. + + M - Supplies the number of rows for the input matrix and the number of + columns for the output matrix. + + N - Supplies the number of columns for the input matrix and the number of + rows for the output matrix. + + ThreadPool - Supplies the thread pool object to use, else nullptr if the + base library threading support should be used. + +Return Value: + + None. + +--*/ +{ + MLAS_TRANPOSE_WORK_BLOCK WorkBlock; + + // + // Capture the transpose parameters to the work block. + // + + WorkBlock.Input = Input; + WorkBlock.Output = Output; + WorkBlock.M = M; + WorkBlock.N = N; + + // + // Compute the number of target threads given the complexity of the transpose + // operation. Limit the number of threads to the number of rows and try to + // keep each thread processing a minimum number of elements before using + // another thread. 
+ // + + ptrdiff_t ThreadCountM = MlasGetMaximumThreadCount(ThreadPool); + + if (size_t(ThreadCountM) > M) { + ThreadCountM = ptrdiff_t(M); + } + + WorkBlock.ThreadCountM = ThreadCountM; + + MlasExecuteThreaded(MlasTransposeThreaded, &WorkBlock, ThreadCountM, ThreadPool); +} + +template +void +MLASCALL +MlasTranspose( + const uint32_t* Input, + uint32_t* Output, + size_t M, + size_t N, + MLAS_THREADPOOL* ThreadPool + ); + +template +void +MLASCALL +MlasTranspose( + const uint16_t* Input, + uint16_t* Output, + size_t M, + size_t N, + MLAS_THREADPOOL* ThreadPool + ); + +template +void +MLASCALL +MlasTranspose( + const uint8_t* Input, + uint8_t* Output, + size_t M, + size_t N, + MLAS_THREADPOOL* ThreadPool + ); + +template<> +void +MLASCALL +MlasTranspose( const int8_t* Input, int8_t* Output, size_t M, - size_t N) + size_t N, + MLAS_THREADPOOL* ThreadPool + ) { MlasTranspose( reinterpret_cast(Input), reinterpret_cast(Output), M, - N); + N, + ThreadPool); +} + +template<> +void +MLASCALL +MlasTranspose( + const float* Input, + float* Output, + size_t M, + size_t N, + MLAS_THREADPOOL* ThreadPool + ) +{ + MlasTranspose( + reinterpret_cast(Input), + reinterpret_cast(Output), + M, + N, + ThreadPool); } diff --git a/src/lib/x86_64/ConvSymKernelAvx2.S b/src/lib/x86_64/ConvSymKernelAvx2.S index 3004599..194f210 100644 --- a/src/lib/x86_64/ConvSymKernelAvx2.S +++ b/src/lib/x86_64/ConvSymKernelAvx2.S @@ -23,6 +23,91 @@ Abstract: .intel_syntax noprefix + .extern CheckSaturationForVPMADDUBSW + + .macro CheckSaturation VecReg1Num, VecReg2Num + +// +// Save all caller-saved registers (RAX, RCX, RDX, RSI, RDI, R8, R9, R10, R11) +// + + push rax + push rcx + push rdx + push rsi + push rdi + push r8 + push r9 + push r10 + push r11 + + sub rsp, 512 # reserve space for 16 YMM registers (32 bytes) + +// +// Save YMM registers (YMM0 to YMM15) +// + + vmovdqu [rsp], ymm0 + vmovdqu [rsp+32], ymm1 + vmovdqu [rsp+64], ymm2 + vmovdqu [rsp+96], ymm3 + vmovdqu [rsp+128], ymm4 + vmovdqu 
[rsp+160], ymm5 + vmovdqu [rsp+192], ymm6 + vmovdqu [rsp+224], ymm7 + vmovdqu [rsp+256], ymm8 + vmovdqu [rsp+288], ymm9 + vmovdqu [rsp+320], ymm10 + vmovdqu [rsp+352], ymm11 + vmovdqu [rsp+384], ymm12 + vmovdqu [rsp+416], ymm13 + vmovdqu [rsp+448], ymm14 + vmovdqu [rsp+480], ymm15 + + lea rdi, [rsp+32*\VecReg1Num\()] # first operand (unsigned) + lea rsi, [rsp+32*\VecReg2Num\()] # second operand (signed) + + call CheckSaturationForVPMADDUBSW + +// +// Restore YMM registers +// + + vmovdqu ymm0, [rsp] + vmovdqu ymm1, [rsp+32] + vmovdqu ymm2, [rsp+64] + vmovdqu ymm3, [rsp+96] + vmovdqu ymm4, [rsp+128] + vmovdqu ymm5, [rsp+160] + vmovdqu ymm6, [rsp+192] + vmovdqu ymm7, [rsp+224] + vmovdqu ymm8, [rsp+256] + vmovdqu ymm9, [rsp+288] + vmovdqu ymm10, [rsp+320] + vmovdqu ymm11, [rsp+352] + vmovdqu ymm12, [rsp+384] + vmovdqu ymm13, [rsp+416] + vmovdqu ymm14, [rsp+448] + vmovdqu ymm15, [rsp+480] + + add rsp, 512 # clean up the reserved stack space + +// +// Restore all caller-saved registers (RAX, RCX, RDX, RSI, RDI, R8, R9, R10, R11) +// + + pop r11 + pop r10 + pop r9 + pop r8 + pop rdi + pop rsi + pop rdx + pop rcx + pop rax + + .endm + /*++ Macro Description: @@ -52,9 +137,15 @@ Implicit Arguments: .macro MultiplyAccumulateRowAvx2 Vec1Reg, Vec2Reg +#if defined(ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER) + CheckSaturation 2,0 +#endif vpmaddubsw ymm3,ymm2,ymm0 vpmaddwd ymm3,ymm3,ymm12 vpaddd \Vec1Reg\(),\Vec1Reg\(),ymm3 +#if defined(ENABLE_CONVSYMKERNELAVX2_SAT_CHECKER) + CheckSaturation 2,1 +#endif vpmaddubsw ymm2,ymm2,ymm1 vpmaddwd ymm2,ymm2,ymm12 vpaddd \Vec2Reg\(),\Vec2Reg\(),ymm2 diff --git a/src/ort_include/core/common/common.h b/src/ort_include/core/common/common.h index 0822eba..adfd341 100644 --- a/src/ort_include/core/common/common.h +++ b/src/ort_include/core/common/common.h @@ -148,6 +148,26 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch abort(); \ } while (false) +#define ORT_THROW_FROM_STATUS(status) \ + do { \ + 
::onnxruntime::PrintFinalMessage( \ + ::onnxruntime::OnnxRuntimeException( \ + ORT_WHERE_WITH_STACK, status.ToString()) \ + .what()); \ + abort(); \ + } while (false) + +#define ORT_THROW_WITH_CATEGORY_AND_CODE(category, code, ...) \ + do { \ + ::onnxruntime::PrintFinalMessage( \ + ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, \ + ::onnxruntime::MakeString(__VA_ARGS__), \ + ::onnxruntime::common::category, \ + ::onnxruntime::common::code) \ + .what()); \ + abort(); \ + } while (false) + #else #define ORT_TRY try @@ -180,6 +200,16 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch #define ORT_THROW_EX(ex, ...) \ throw ex(__VA_ARGS__) +#define ORT_THROW_FROM_STATUS(status) \ + throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, status.ToString(), status.Category(), \ + static_cast<::onnxruntime::common::StatusCode>(status.Code())) + +#define ORT_THROW_WITH_CATEGORY_AND_CODE(category, code, ...) \ + throw ::onnxruntime::OnnxRuntimeException(ORT_WHERE_WITH_STACK, \ + ::onnxruntime::MakeString(__VA_ARGS__), \ + ::onnxruntime::common::category, \ + ::onnxruntime::common::code) + #endif #define ORT_MAKE_STATUS(category, code, ...) \ @@ -237,7 +267,7 @@ void LogRuntimeError(uint32_t session_id, const common::Status& status, const ch auto _status = (expr); \ if ((!_status.IsOK())) { \ ::onnxruntime::LogRuntimeError(0, _status, __FILE__, static_cast(__FUNCTION__), __LINE__); \ - ORT_THROW(_status); \ + ORT_THROW_FROM_STATUS(_status); \ } \ } while (0) @@ -272,7 +302,7 @@ inline std::wstring ToWideString(const std::wstring& s) { return s; } inline std::string ToWideString(const std::string& s) { return s; } #endif -constexpr size_t kMaxStrLen = 2048; +constexpr size_t kMaxStrLen = 4096; // Returns whether `key` is in `container`. // Like C++20's map/set contains() member function. 
diff --git a/src/ort_include/core/common/cpuid_info.h b/src/ort_include/core/common/cpuid_info.h index 4c9e7e8..9c67ebb 100644 --- a/src/ort_include/core/common/cpuid_info.h +++ b/src/ort_include/core/common/cpuid_info.h @@ -15,6 +15,14 @@ class CPUIDInfo { return cpuid_info; } + std::string_view GetCPUVendor() const { + return vendor_; + } + + uint32_t GetCPUVendorId() const { + return vendor_id_; + } + bool HasAMX_BF16() const { return has_amx_bf16_; } bool HasAVX() const { return has_avx_; } bool HasAVX2() const { return has_avx2_; } @@ -25,6 +33,7 @@ class CPUIDInfo { bool HasSSE3() const { return has_sse3_; } bool HasSSE4_1() const { return has_sse4_1_; } bool IsHybrid() const { return is_hybrid_; } + bool HasTPAUSE() const { return has_tpause_; } // ARM bool HasArmNeonDot() const { return has_arm_neon_dot_; } @@ -104,6 +113,7 @@ class CPUIDInfo { bool has_sse3_{false}; bool has_sse4_1_{false}; bool is_hybrid_{false}; + bool has_tpause_{false}; std::vector core_uarchs_; // micro-arch of each core @@ -118,9 +128,15 @@ class CPUIDInfo { bool has_arm_sve_i8mm_{false}; bool has_arm_neon_bf16_{false}; + std::string vendor_; + uint32_t vendor_id_; + + uint32_t GetVendorId(const std::string& vendor); + #if defined(CPUIDINFO_ARCH_X86) void X86Init(); + std::string GetX86Vendor(int32_t* data); #elif defined(CPUIDINFO_ARCH_ARM) @@ -136,6 +152,7 @@ class CPUIDInfo { #elif defined(_WIN32) void ArmWindowsInit(); + std::string GetArmWindowsVendor(); #elif defined(__APPLE__) diff --git a/src/ort_include/core/common/exceptions.h b/src/ort_include/core/common/exceptions.h index 494a770..6d0f6ed 100644 --- a/src/ort_include/core/common/exceptions.h +++ b/src/ort_include/core/common/exceptions.h @@ -11,6 +11,7 @@ #include #include "core/common/common.h" +#include "core/common/status.h" #include "core/common/code_location.h" namespace onnxruntime { @@ -35,12 +36,44 @@ class OnnxRuntimeException : public std::exception { /** Create a new exception that captures the location it was 
thrown from. @param location Location in the source code the exception is being thrown from + @param msg Message containing additional information about the exception cause. + @param category Error category + @param code Error code + */ + + OnnxRuntimeException(const CodeLocation& location, + const std::string& message, + common::StatusCategory category, + common::StatusCode code) noexcept + : OnnxRuntimeException(location, nullptr, message, category, code) { + } + + /** + Create a new exception that captures the location it was thrown from. + The instance will be created with ONNXRUNTIME category and FAIL code. + @param location Location in the source code the exception is being thrown from @param failed_condition Optional string containing the condition that failed. e.g. "tensor.Size() == input.Size()". May be nullptr. @param msg Message containing additional information about the exception cause. */ - OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg) - : location_{location} { + OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg) noexcept + : OnnxRuntimeException(location, failed_condition, msg, + common::StatusCategory::ONNXRUNTIME, common::StatusCode::FAIL) { + } + + /** + Create a new exception that captures the location it was thrown from. + @param location Location in the source code the exception is being thrown from + @param failed_condition Optional string containing the condition that failed. + e.g. "tensor.Size() == input.Size()". May be nullptr. + @param msg Message containing additional information about the exception cause. 
+ @param category Error category + @param code Error code + */ + OnnxRuntimeException(const CodeLocation& location, const char* failed_condition, const std::string& msg, + common::StatusCategory category, + common::StatusCode code) + : location_{location}, category_(category), code_(code) { std::ostringstream ss; ss << location.ToString(CodeLocation::kFilenameAndPath); // output full path in case just the filename is ambiguous @@ -58,6 +91,14 @@ class OnnxRuntimeException : public std::exception { what_ = ss.str(); } + common::StatusCategory Category() const noexcept { + return category_; + } + + common::StatusCode Code() const noexcept { + return code_; + } + const char* what() const noexcept override { return what_.c_str(); } @@ -66,6 +107,8 @@ class OnnxRuntimeException : public std::exception { const CodeLocation location_; const std::vector stacktrace_; std::string what_; + common::StatusCategory category_; + common::StatusCode code_; }; } // namespace onnxruntime diff --git a/src/ort_include/core/common/logging/logging.h b/src/ort_include/core/common/logging/logging.h index 3ad27d3..508c22d 100644 --- a/src/ort_include/core/common/logging/logging.h +++ b/src/ort_include/core/common/logging/logging.h @@ -50,8 +50,9 @@ */ -namespace onnxruntime { +struct OrtLogger; // opaque API type. 
is always an instance of Logger +namespace onnxruntime { namespace logging { using Timestamp = std::chrono::time_point; @@ -84,7 +85,7 @@ using Timestamp = std::chrono::time_point; #endif #endif // __APPLE__ -#if _WIN32 || ORT_USE_CXX20_STD_CHRONO +#if ORT_USE_CXX20_STD_CHRONO namespace timestamp_ns = std::chrono; #else namespace timestamp_ns = ::date; @@ -351,6 +352,10 @@ class Logger { logging_manager_->SendProfileEvent(eventRecord); } + // convert to API type for custom ops and plugin EPs + OrtLogger* ToExternal() { return reinterpret_cast(this); } + const OrtLogger* ToExternal() const { return reinterpret_cast(this); } + private: const LoggingManager* logging_manager_; const std::string id_; diff --git a/src/ort_include/core/common/parse_string.h b/src/ort_include/core/common/parse_string.h index 941e3f3..6345b2a 100644 --- a/src/ort_include/core/common/parse_string.h +++ b/src/ort_include/core/common/parse_string.h @@ -3,6 +3,7 @@ #pragma once +#include #include #include #include @@ -12,18 +13,45 @@ namespace onnxruntime { +namespace detail { + +// Whether we will use std::from_chars() to parse to `T`. +#if defined(_LIBCPP_VERSION) +// Note: Currently (e.g., in LLVM 19), libc++'s std::from_chars() doesn't support floating point types yet. +template +constexpr bool ParseWithFromChars = !std::is_same_v && std::is_integral_v; +#else +template +constexpr bool ParseWithFromChars = !std::is_same_v && (std::is_integral_v || std::is_floating_point_v); +#endif + +} // namespace detail + /** * Tries to parse a value from an entire string. + * If successful, sets `value` and returns true. Otherwise, does not modify `value` and returns false. 
*/ template -bool TryParseStringWithClassicLocale(std::string_view str, T& value) { - if constexpr (std::is_integral::value && std::is_unsigned::value) { - // if T is unsigned integral type, reject negative values which will wrap - if (!str.empty() && str[0] == '-') { - return false; - } +std::enable_if_t, bool> +TryParseStringWithClassicLocale(std::string_view str, T& value) { + T parsed_value{}; + const auto [ptr, ec] = std::from_chars(str.data(), str.data() + str.size(), parsed_value); + + if (ec != std::errc{}) { + return false; } + if (ptr != str.data() + str.size()) { + return false; + } + + value = parsed_value; + return true; +} + +template +std::enable_if_t, bool> +TryParseStringWithClassicLocale(std::string_view str, T& value) { // don't allow leading whitespace if (!str.empty() && std::isspace(str[0], std::locale::classic())) { return false; diff --git a/src/ort_include/core/common/profiler_common.h b/src/ort_include/core/common/profiler_common.h index 0074d5e..ab97325 100644 --- a/src/ort_include/core/common/profiler_common.h +++ b/src/ort_include/core/common/profiler_common.h @@ -81,8 +81,8 @@ class EpProfiler { virtual ~EpProfiler() = default; virtual bool StartProfiling(TimePoint profiling_start_time) = 0; // called when profiling starts virtual void EndProfiling(TimePoint start_time, Events& events) = 0; // called when profiling ends, save all captures numbers to "events" - virtual void Start(uint64_t){}; // called before op start, accept an id as argument to identify the op - virtual void Stop(uint64_t){}; // called after op stop, accept an id as argument to identify the op + virtual void Start(uint64_t) {} // called before op start, accept an id as argument to identify the op + virtual void Stop(uint64_t) {} // called after op stop, accept an id as argument to identify the op }; // Demangle C++ symbols diff --git a/src/ort_include/core/common/spin_pause.h b/src/ort_include/core/common/spin_pause.h index 49b71e5..4d987f1 100644 --- 
a/src/ort_include/core/common/spin_pause.h +++ b/src/ort_include/core/common/spin_pause.h @@ -3,26 +3,11 @@ #pragma once -#if defined(_M_AMD64) -#include -#endif - -#if defined(__x86_64__) -#include -#endif - namespace onnxruntime { - namespace concurrency { // Intrinsic to use in spin-loops - -inline void SpinPause() { -#if defined(_M_AMD64) || defined(__x86_64__) - _mm_pause(); -#endif -} +void SpinPause(); } // namespace concurrency - } // namespace onnxruntime diff --git a/src/ort_include/core/common/status.h b/src/ort_include/core/common/status.h index 8f171da..da9735a 100644 --- a/src/ort_include/core/common/status.h +++ b/src/ort_include/core/common/status.h @@ -43,7 +43,9 @@ enum StatusCode { MODEL_LOADED = 8, NOT_IMPLEMENTED = 9, INVALID_GRAPH = 10, - EP_FAIL = 11 + EP_FAIL = 11, + MODEL_LOAD_CANCELED = 12, + MODEL_REQUIRES_COMPILATION = 13, }; constexpr const char* StatusCodeToString(StatusCode status) noexcept { @@ -72,6 +74,10 @@ constexpr const char* StatusCodeToString(StatusCode status) noexcept { return "INVALID_GRAPH"; case StatusCode::EP_FAIL: return "EP_FAIL"; + case StatusCode::MODEL_LOAD_CANCELED: + return "MODEL_LOAD_CANCELED"; + case StatusCode::MODEL_REQUIRES_COMPILATION: + return "MODEL_REQUIRES_COMPILATION"; default: return "GENERAL ERROR"; } @@ -104,6 +110,10 @@ constexpr HRESULT StatusCodeToHRESULT(StatusCode status) noexcept { return HRESULT_FROM_WIN32(ERROR_FILE_CORRUPT); case StatusCode::EP_FAIL: return HRESULT_FROM_WIN32(ERROR_INTERNAL_ERROR); + case StatusCode::MODEL_LOAD_CANCELED: + return HRESULT_FROM_WIN32(ERROR_CANCELLED); + case StatusCode::MODEL_REQUIRES_COMPILATION: + return HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED); default: return E_FAIL; } diff --git a/src/ort_include/core/framework/callback.h b/src/ort_include/core/framework/callback.h index e69de29..88f14d7 100644 --- a/src/ort_include/core/framework/callback.h +++ b/src/ort_include/core/framework/callback.h @@ -0,0 +1,74 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. +#pragma once +#include "core/common/common.h" + +namespace onnxruntime { +struct OrtCallback { + void (*f)(void* param) noexcept; + void* param; +}; + +/** + * f will be freed in this call + */ +void OrtRunCallback(OrtCallback* f) noexcept; + +/** + * Invokes the contained OrtCallback with operator()(T). + * Useful for something like a std::unique_ptr<> deleter. + */ +struct OrtCallbackInvoker { + OrtCallbackInvoker() noexcept + : callback{nullptr, nullptr} {} + + OrtCallbackInvoker(OrtCallback callback_to_invoke) noexcept + : callback(callback_to_invoke) {} + + OrtCallback callback; + + template + void operator()(T) noexcept { + if (callback.f) { + callback.f(callback.param); + } + } +}; + +/** + * Invokes the contained OrtCallback upon destruction or being assigned to. + */ +class ScopedOrtCallbackInvoker { + public: + explicit ScopedOrtCallbackInvoker(OrtCallback callback) noexcept + : callback_(callback) {} + + ScopedOrtCallbackInvoker(ScopedOrtCallbackInvoker&& other) noexcept + : callback_(other.callback_) { + other.callback_.f = nullptr; + other.callback_.param = nullptr; + } + + ScopedOrtCallbackInvoker& operator=(ScopedOrtCallbackInvoker&& other) noexcept { + if (callback_.f) { + callback_.f(callback_.param); + } + + callback_ = other.callback_; + other.callback_.f = nullptr; + other.callback_.param = nullptr; + + return *this; + } + + ~ScopedOrtCallbackInvoker() noexcept { + if (callback_.f) { + callback_.f(callback_.param); + } + } + + private: + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ScopedOrtCallbackInvoker); + OrtCallback callback_; +}; +} // namespace onnxruntime diff --git a/src/ort_include/core/framework/float16.h b/src/ort_include/core/framework/float16.h index dac0a01..97420ff 100644 --- a/src/ort_include/core/framework/float16.h +++ b/src/ort_include/core/framework/float16.h @@ -261,19 +261,19 @@ struct BFloat16 : onnxruntime_float16::BFloat16Impl { // initializers with MLFloat16 and BFloat16 
from unsigned short // E.g 10_f16 or 10_b16 #if !defined(__CUDACC__) && !defined(__HIPCC__) -inline MLFloat16 operator"" _f16(unsigned long long int v) noexcept { +inline MLFloat16 operator""_f16(unsigned long long int v) noexcept { return MLFloat16::FromBits(narrow(v)); } -inline MLFloat16 operator"" _fp16(long double v) noexcept { +inline MLFloat16 operator""_fp16(long double v) noexcept { return MLFloat16(static_cast(v)); } -inline BFloat16 operator"" _b16(unsigned long long int v) noexcept { +inline BFloat16 operator""_b16(unsigned long long int v) noexcept { return BFloat16::FromBits((narrow(v))); } -inline BFloat16 operator"" _bfp16(long double v) noexcept { +inline BFloat16 operator""_bfp16(long double v) noexcept { return BFloat16(static_cast(v)); } #endif diff --git a/src/ort_include/core/platform/EigenNonBlockingThreadPool.h b/src/ort_include/core/platform/EigenNonBlockingThreadPool.h index 38a4c59..263fa35 100644 --- a/src/ort_include/core/platform/EigenNonBlockingThreadPool.h +++ b/src/ort_include/core/platform/EigenNonBlockingThreadPool.h @@ -10,7 +10,7 @@ /* Modifications Copyright (c) Microsoft. 
*/ #include -#include + #pragma once #include "onnxruntime_config.h" // build/external/eigen/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h:162:71: @@ -217,18 +217,18 @@ class ThreadPoolProfiler { WAIT_REVOKE, MAX_EVENT }; - ThreadPoolProfiler(int, const CHAR_TYPE*) {}; + ThreadPoolProfiler(int, const CHAR_TYPE*) {} ~ThreadPoolProfiler() = default; ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ThreadPoolProfiler); - void Start() {}; + void Start() {} std::string Stop() { return "not available for minimal build"; } - void LogStart() {}; - void LogEnd(ThreadPoolEvent){}; - void LogEndAndStart(ThreadPoolEvent){}; - void LogStartAndCoreAndBlock(std::ptrdiff_t){}; - void LogCoreAndBlock(std::ptrdiff_t){}; - void LogThreadId(int) {}; - void LogRun(int) {}; + void LogStart() {} + void LogEnd(ThreadPoolEvent) {} + void LogEndAndStart(ThreadPoolEvent) {} + void LogStartAndCoreAndBlock(std::ptrdiff_t) {} + void LogCoreAndBlock(std::ptrdiff_t) {} + void LogThreadId(int) {} + void LogRun(int) {} std::string DumpChildThreadStat() { return {}; } }; #else @@ -255,7 +255,7 @@ class ThreadPoolProfiler { void LogCoreAndBlock(std::ptrdiff_t block_size); // called in main thread to log core and block size for task breakdown void LogThreadId(int thread_idx); // called in child thread to log its id void LogRun(int thread_idx); // called in child thread to log num of run - std::string DumpChildThreadStat(); // return all child statitics collected so far + std::string DumpChildThreadStat(); // return all child statistics collected so far private: static const char* GetEventName(ThreadPoolEvent); @@ -738,7 +738,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter // Allocate a new tag to use to identify work items from a given // thread in a parallel section. 
Ideally, threads will have // unique tags, but re-use is not incorrect if the counter wraps - // (for intsance, if a long-running workload is calling into ORT + // (for instance, if a long-running workload is calling into ORT // from a fresh thread for each request). We must not re-use the // default tag 0 which is used to identify work items added via // Schedule as opposed to requests for help in parallel sections. @@ -911,7 +911,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter } } - // Now we know that dispatch is finshed, we synchronize with the + // Now we know that dispatch is finished, we synchronize with the // tasks that were created (if any) for the parallel section. We // revoke tasks still in queues, and then wait for any that are // still running. @@ -1003,7 +1003,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter // * First, suppose that a single job is running a series of loops. // Its main thread enters a parallel loop. Initially, let's assume // its preferred worker array is [_,0,1,2], writing "_" for the - // unusued element for the par_idx=0 work that the main thread will + // unused element for the par_idx=0 work that the main thread will // run. // // The main thread schedules the dispatcher task onto worker 0. 
@@ -1466,11 +1466,14 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter status = ThreadStatus::Spinning; } - void SetBlocked(std::function should_block, + bool SetBlocked(std::function should_block, std::function post_block) { std::unique_lock lk(mutex); - assert(GetStatus() == ThreadStatus::Spinning); - status.store(ThreadStatus::Blocking, std::memory_order_relaxed); + auto old_status = status.exchange(ThreadStatus::Blocking, std::memory_order_seq_cst); + if (old_status != ThreadStatus::Spinning) { + // Encountered a logical error + return false; + } if (should_block()) { status.store(ThreadStatus::Blocked, std::memory_order_relaxed); do { @@ -1479,6 +1482,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter post_block(); } status.store(ThreadStatus::Spinning, std::memory_order_relaxed); + return true; } private: @@ -1557,62 +1561,66 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter // Attempt to block if (!t) { - td.SetBlocked( // Pre-block test - [&]() -> bool { - bool should_block = true; - // Check whether work was pushed to us while attempting to block. We make - // this test while holding the per-thread status lock, and after setting - // our status to ThreadStatus::Blocking. - // - // This synchronizes with ThreadPool::Schedule which pushes work to the queue - // and then tests for ThreadStatus::Blocking/Blocked (via EnsureAwake): - // - // Main thread: Worker: - // #1 Push work #A Set status blocking - // #2 Read worker status #B Check queue - // #3 Wake if blocking/blocked - // - // If #A is before #2 then main sees worker blocked and wakes - // - // If #A if after #2 then #B will see #1, and we abandon blocking - assert(!t); - t = q.PopFront(); - if (t) { - should_block = false; - } - - // No work pushed to us, continue attempting to block. The remaining - // test is to synchronize with termination requests. 
If we are - // shutting down and all worker threads blocked without work, that's - // we are done. - if (should_block) { - blocked_++; - if (done_ && blocked_ == num_threads_) { - should_block = false; - // Almost done, but need to re-check queues. - // Consider that all queues are empty and all worker threads are preempted - // right after incrementing blocked_ above. Now a free-standing thread - // submits work and calls destructor (which sets done_). If we don't - // re-check queues, we will exit leaving the work unexecuted. - if (NonEmptyQueueIndex() != -1) { - // Note: we must not pop from queues before we decrement blocked_, - // otherwise the following scenario is possible. Consider that instead - // of checking for emptiness we popped the only element from queues. - // Now other worker threads can start exiting, which is bad if the - // work item submits other work. So we just check emptiness here, - // which ensures that all worker threads exit at the same time. - blocked_--; - } else { - should_exit = true; + if (!td.SetBlocked( // Pre-block test + [&]() -> bool { + bool should_block = true; + // Check whether work was pushed to us while attempting to block. We make + // this test while holding the per-thread status lock, and after setting + // our status to ThreadStatus::Blocking. + // + // This synchronizes with ThreadPool::Schedule which pushes work to the queue + // and then tests for ThreadStatus::Blocking/Blocked (via EnsureAwake): + // + // Main thread: Worker: + // #1 Push work #A Set status blocking + // #2 Read worker status #B Check queue + // #3 Wake if blocking/blocked + // + // If #A is before #2 then main sees worker blocked and wakes + // + // If #A if after #2 then #B will see #1, and we abandon blocking + assert(!t); + t = q.PopFront(); + if (t) { + should_block = false; + } + + // No work pushed to us, continue attempting to block. The remaining + // test is to synchronize with termination requests. 
If we are + // shutting down and all worker threads blocked without work, that's + // we are done. + if (should_block) { + blocked_++; + if (done_ && blocked_ == num_threads_) { + should_block = false; + // Almost done, but need to re-check queues. + // Consider that all queues are empty and all worker threads are preempted + // right after incrementing blocked_ above. Now a free-standing thread + // submits work and calls destructor (which sets done_). If we don't + // re-check queues, we will exit leaving the work unexecuted. + if (NonEmptyQueueIndex() != -1) { + // Note: we must not pop from queues before we decrement blocked_, + // otherwise the following scenario is possible. Consider that instead + // of checking for emptiness we popped the only element from queues. + // Now other worker threads can start exiting, which is bad if the + // work item submits other work. So we just check emptiness here, + // which ensures that all worker threads exit at the same time. + blocked_--; + } else { + should_exit = true; + } + } } - } - } - return should_block; - }, - // Post-block update (executed only if we blocked) - [&]() { - blocked_--; - }); + return should_block; + }, + // Post-block update (executed only if we blocked) + [&]() { + blocked_--; + })) { + // Encountered a fatal logic error in SetBlocked + should_exit = true; + break; + } // Thread just unblocked. 
Unless we picked up work while // blocking, or are exiting, then either work was pushed to // us, or it was pushed to an overloaded queue diff --git a/src/ort_include/core/platform/env.h b/src/ort_include/core/platform/env.h index 47e74ea..970567a 100644 --- a/src/ort_include/core/platform/env.h +++ b/src/ort_include/core/platform/env.h @@ -37,7 +37,9 @@ namespace Eigen { class ThreadPoolInterface; } namespace onnxruntime { - +namespace concurrency { + inline void SpinPause(){} +} #ifdef _WIN32 using PIDType = unsigned long; using FileOffsetType = int64_t; @@ -78,7 +80,6 @@ struct ThreadOptions { // Set or unset denormal as zero. bool set_denormal_as_zero = false; - }; std::ostream& operator<<(std::ostream& os, const LogicalProcessors&); @@ -157,9 +158,6 @@ class Env { virtual common::Status ReadFileIntoBuffer(_In_z_ const ORTCHAR_T* file_path, FileOffsetType offset, size_t length, gsl::span buffer) const = 0; - - - /** Gets the canonical form of a file path (symlinks resolved). */ virtual common::Status GetCanonicalPath( const PathString& path, diff --git a/src/ort_include/core/session/onnxruntime_c_api.h b/src/ort_include/core/session/onnxruntime_c_api.h index a187fbd..a2f518a 100644 --- a/src/ort_include/core/session/onnxruntime_c_api.h +++ b/src/ort_include/core/session/onnxruntime_c_api.h @@ -1,47 +1,6246 @@ -#pragma once - -#ifndef _WIN32 -#define _In_ -#define _In_z_ -#define _In_opt_ -#define _In_opt_z_ -#define _Out_ -#define _Outptr_ -#define _Out_opt_ -#define _Inout_ -#define _Inout_opt_ -#define _Frees_ptr_opt_ -#define _Ret_maybenull_ -#define _Ret_notnull_ -#define _Check_return_ -#define _Outptr_result_maybenull_ -#define _In_reads_(X) -#define _Inout_updates_(X) -#define _Out_writes_(X) -#define _Inout_updates_all_(X) -#define _Out_writes_bytes_all_(X) -#define _Out_writes_all_(X) -#define _Success_(X) -#define _Outptr_result_buffer_maybenull_(X) -#else -#include -#endif - -#ifdef _WIN32 -#define ORTCHAR_T wchar_t -#else -#define ORTCHAR_T char 
-#endif - -/// ORTCHAR_T, ORT_TSTR are reserved specifically for path handling. -/// All other strings are UTF-8 encoded, use char and std::string -#ifndef ORT_TSTR -#ifdef _WIN32 -#define ORT_TSTR(X) L##X -// When X is a macro, L##X is not defined. In this case, we need to use ORT_TSTR_ON_MACRO. -#define ORT_TSTR_ON_MACRO(X) L"" X -#else -#define ORT_TSTR(X) X -#define ORT_TSTR_ON_MACRO(X) X -#endif -#endif \ No newline at end of file +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// See docs\c_cxx\README.md on generating the Doxygen documentation from this file + +/** \mainpage ONNX Runtime + * + * ONNX Runtime is a high-performance inference and training graph execution engine for deep learning models. + * + * ONNX Runtime's C, C++ APIs offer an easy to use interface to onboard and execute onnx models. + * - \subpage c_cpp_api "Core C, C++ APIs" + * - \subpage training_c_cpp_api "Training C, C++ APIs for on-device training" + * + * \page c_cpp_api Core C, C++ APIs + *

C

+ * + * ::OrtApi - Click here to go to the structure with all C API functions. + * + *

C++

+ * + * ::Ort - Click here to go to the namespace holding all of the C++ wrapper classes + * + * It is a set of header only wrapper classes around the C API. The goal is to turn the C style return value error codes into C++ exceptions, and to + * automate memory management through standard C++ RAII principles. + * + * \addtogroup Global + * ONNX Runtime C API + * @{ + */ + +#pragma once +#include +#include +#include +#include + +/** \brief The API version defined in this header + * + * This value is used by some API functions to behave as this version of the header expects. + */ +#define ORT_API_VERSION 23 + +#ifdef __cplusplus +extern "C" { +#endif + +//! @} +// SAL2 Definitions +#ifndef _MSC_VER +#define _In_ +#define _In_z_ +#define _In_opt_ +#define _In_opt_z_ +#define _Out_ +#define _Outptr_ +#define _Out_opt_ +#define _Inout_ +#define _Inout_opt_ +#define _Frees_ptr_opt_ +#define _Ret_maybenull_ +#define _Ret_notnull_ +#define _Check_return_ +#define _Outptr_result_maybenull_ +#define _In_reads_(X) +#define _Inout_updates_(X) +#define _Out_writes_(X) +#define _Inout_updates_all_(X) +#define _Out_writes_bytes_all_(X) +#define _Out_writes_all_(X) +#define _Success_(X) +#define _Outptr_result_buffer_maybenull_(X) +#define ORT_ALL_ARGS_NONNULL __attribute__((nonnull)) +#else +#include +#define ORT_ALL_ARGS_NONNULL +#endif + +#ifdef _WIN32 +// Define ORT_DLL_IMPORT if your program is dynamically linked to Ort. +// dllexport is not used, we use a .def file. 
+#ifdef ORT_DLL_IMPORT +#define ORT_EXPORT __declspec(dllimport) +#else +#define ORT_EXPORT +#endif +#define ORT_API_CALL _stdcall +#define ORT_MUST_USE_RESULT +#define ORTCHAR_T wchar_t +#else +// To make symbols visible on macOS/iOS +#ifdef __APPLE__ +#define ORT_EXPORT __attribute__((visibility("default"))) +#else +#define ORT_EXPORT +#endif +#define ORT_API_CALL +#define ORT_MUST_USE_RESULT __attribute__((warn_unused_result)) +#define ORTCHAR_T char +#endif + +/// ORTCHAR_T, ORT_TSTR are reserved specifically for path handling. +/// All other strings are UTF-8 encoded, use char and std::string +#ifndef ORT_TSTR +#ifdef _WIN32 +#define ORT_TSTR(X) L##X +// When X is a macro, L##X is not defined. In this case, we need to use ORT_TSTR_ON_MACRO. +#define ORT_TSTR_ON_MACRO(X) L"" X +#else +#define ORT_TSTR(X) X +#define ORT_TSTR_ON_MACRO(X) X +#endif +#endif + +// On Windows, ORT_FILE is a wchar_t version of the __FILE__ macro. +// Otherwise, ORT_FILE is equivalent to __FILE__. +#ifndef ORT_FILE +#define ORT_FILE_INTERNAL(x) ORT_TSTR(x) +#define ORT_FILE ORT_FILE_INTERNAL(__FILE__) +#endif + +// Any pointer marked with _In_ or _Out_, cannot be NULL. + +// Windows users should use unicode paths when possible to bypass the MAX_PATH limitation +// Every pointer marked with _In_ or _Out_, cannot be NULL. Caller should ensure that. +// for ReleaseXXX(...) functions, they can accept NULL pointer. + +#ifdef __cplusplus +// For any compiler with C++11 support, MSVC 2015 and greater, or Clang version supporting noexcept. +// Such complex condition is needed because compilers set __cplusplus value differently. 
+#ifndef __has_feature +#define __has_feature(x) 0 +#endif +#if ((__cplusplus >= 201103L) || (_MSC_VER >= 1900) || (defined(__has_feature) && __has_feature(cxx_noexcept))) +#define NO_EXCEPTION noexcept +#else +#define NO_EXCEPTION throw() +#endif +#else +#define NO_EXCEPTION +#endif + +// __VA_ARGS__ on Windows and Linux are different +#define ORT_API(RETURN_TYPE, NAME, ...) RETURN_TYPE ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION + +#define ORT_API_STATUS(NAME, ...) \ + _Success_(return == 0) _Check_return_ _Ret_maybenull_ OrtStatusPtr ORT_API_CALL NAME(__VA_ARGS__) \ + NO_EXCEPTION ORT_MUST_USE_RESULT + +// XXX: Unfortunately, SAL annotations are known to not work with function pointers +#define ORT_API2_STATUS(NAME, ...) \ + _Check_return_ _Ret_maybenull_ OrtStatusPtr(ORT_API_CALL* NAME)(__VA_ARGS__) NO_EXCEPTION ORT_MUST_USE_RESULT + +// Used in *.cc files. Almost as same as ORT_API_STATUS, except without ORT_MUST_USE_RESULT and ORT_EXPORT +#define ORT_API_STATUS_IMPL(NAME, ...) \ + _Success_(return == 0) _Check_return_ _Ret_maybenull_ OrtStatusPtr ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION + +#define ORT_CLASS_RELEASE(X) void(ORT_API_CALL * Release##X)(_Frees_ptr_opt_ Ort##X * input) + +#ifdef __DOXYGEN__ +#undef ORT_API_STATUS +#define ORT_API_STATUS(NAME, ...) OrtStatus* NAME(__VA_ARGS__) +#undef ORT_API2_STATUS +#define ORT_API2_STATUS(NAME, ...) 
OrtStatus* NAME(__VA_ARGS__) +#undef ORT_CLASS_RELEASE +#define ORT_CLASS_RELEASE(X) void Release##X(Ort##X* input) +#undef NO_EXCEPTION +#define NO_EXCEPTION +#endif +/** \addtogroup Global + * ONNX Runtime C API + * @{ + */ + +/** Copied from TensorProto::DataType + * Currently, Ort doesn't support complex64, complex128 + */ +typedef enum ONNXTensorElementDataType { + ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED, + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, // maps to c type float + ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, // maps to c type uint8_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, // maps to c type int8_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, // maps to c type uint16_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, // maps to c type int16_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, // maps to c type int32_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, // maps to c type int64_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, // maps to c++ type std::string + ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, + ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, // maps to c type double + ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, // maps to c type uint32_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, // maps to c type uint64_t + ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64, // complex with float32 real and imaginary components + ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128, // complex with float64 real and imaginary components + ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, // Non-IEEE floating-point format based on IEEE754 single-precision + // float 8 types were introduced in onnx 1.14, see https://onnx.ai/onnx/technical/float8.html + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, // Non-IEEE floating-point format based on IEEE754 single-precision + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2, // Non-IEEE floating-point format based on IEEE754 single-precision + 
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision + // Int4 types were introduced in ONNX 1.16. See https://onnx.ai/onnx/technical/int4.html + ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4, // maps to a pair of packed uint4 values (size == 1 byte) + ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 // maps to a pair of packed int4 values (size == 1 byte) +} ONNXTensorElementDataType; + +// Synced with onnx TypeProto oneof +typedef enum ONNXType { + ONNX_TYPE_UNKNOWN, + ONNX_TYPE_TENSOR, + ONNX_TYPE_SEQUENCE, + ONNX_TYPE_MAP, + ONNX_TYPE_OPAQUE, + ONNX_TYPE_SPARSETENSOR, + ONNX_TYPE_OPTIONAL +} ONNXType; + +// These types are synced with internal +// SparseFormatFlags +typedef enum OrtSparseFormat { + ORT_SPARSE_UNDEFINED = 0, + ORT_SPARSE_COO = 0x1, + ORT_SPARSE_CSRC = 0x2, + ORT_SPARSE_BLOCK_SPARSE = 0x4 +} OrtSparseFormat; + +// Enum allows to query sparse tensor indices +enum OrtSparseIndicesFormat { + ORT_SPARSE_COO_INDICES, + ORT_SPARSE_CSR_INNER_INDICES, + ORT_SPARSE_CSR_OUTER_INDICES, + ORT_SPARSE_BLOCK_SPARSE_INDICES +}; + +/** \brief Logging severity levels + * + * In typical API usage, specifying a logging severity level specifies the minimum severity of log messages to show. + */ +typedef enum OrtLoggingLevel { + ORT_LOGGING_LEVEL_VERBOSE, ///< Verbose informational messages (least severe). + ORT_LOGGING_LEVEL_INFO, ///< Informational messages. + ORT_LOGGING_LEVEL_WARNING, ///< Warning messages. + ORT_LOGGING_LEVEL_ERROR, ///< Error messages. + ORT_LOGGING_LEVEL_FATAL, ///< Fatal error messages (most severe). 
+} OrtLoggingLevel; + +typedef enum OrtErrorCode { + ORT_OK, + ORT_FAIL, + ORT_INVALID_ARGUMENT, + ORT_NO_SUCHFILE, + ORT_NO_MODEL, + ORT_ENGINE_ERROR, + ORT_RUNTIME_EXCEPTION, + ORT_INVALID_PROTOBUF, + ORT_MODEL_LOADED, + ORT_NOT_IMPLEMENTED, + ORT_INVALID_GRAPH, + ORT_EP_FAIL, + ORT_MODEL_LOAD_CANCELED, + ORT_MODEL_REQUIRES_COMPILATION, +} OrtErrorCode; + +typedef enum OrtOpAttrType { + ORT_OP_ATTR_UNDEFINED = 0, + ORT_OP_ATTR_INT, + ORT_OP_ATTR_INTS, + ORT_OP_ATTR_FLOAT, + ORT_OP_ATTR_FLOATS, + ORT_OP_ATTR_STRING, + ORT_OP_ATTR_STRINGS, +} OrtOpAttrType; + +//! @} +#define ORT_RUNTIME_CLASS(X) \ + struct Ort##X; \ + typedef struct Ort##X Ort##X + +/** \addtogroup Global + * ONNX Runtime C API + * @{ + */ +// The actual types defined have an Ort prefix +ORT_RUNTIME_CLASS(Env); +ORT_RUNTIME_CLASS(Status); // nullptr for Status* indicates success +ORT_RUNTIME_CLASS(MemoryInfo); +ORT_RUNTIME_CLASS(IoBinding); +ORT_RUNTIME_CLASS(Session); // Don't call ReleaseSession from Dllmain (because session owns a thread pool) +ORT_RUNTIME_CLASS(Value); +ORT_RUNTIME_CLASS(RunOptions); +ORT_RUNTIME_CLASS(TypeInfo); +ORT_RUNTIME_CLASS(TensorTypeAndShapeInfo); +ORT_RUNTIME_CLASS(MapTypeInfo); +ORT_RUNTIME_CLASS(SequenceTypeInfo); +ORT_RUNTIME_CLASS(OptionalTypeInfo); +ORT_RUNTIME_CLASS(SessionOptions); +ORT_RUNTIME_CLASS(CustomOpDomain); +ORT_RUNTIME_CLASS(ModelMetadata); +ORT_RUNTIME_CLASS(ThreadPoolParams); +ORT_RUNTIME_CLASS(ThreadingOptions); +ORT_RUNTIME_CLASS(ArenaCfg); +ORT_RUNTIME_CLASS(PrepackedWeightsContainer); +ORT_RUNTIME_CLASS(TensorRTProviderOptionsV2); +ORT_RUNTIME_CLASS(NvTensorRtRtxProviderOptions); +ORT_RUNTIME_CLASS(CUDAProviderOptionsV2); +ORT_RUNTIME_CLASS(CANNProviderOptions); +ORT_RUNTIME_CLASS(DnnlProviderOptions); +ORT_RUNTIME_CLASS(Op); +ORT_RUNTIME_CLASS(OpAttr); +ORT_RUNTIME_CLASS(Logger); +ORT_RUNTIME_CLASS(ShapeInferContext); +ORT_RUNTIME_CLASS(LoraAdapter); +ORT_RUNTIME_CLASS(ValueInfo); +ORT_RUNTIME_CLASS(Node); +ORT_RUNTIME_CLASS(Graph); 
+ORT_RUNTIME_CLASS(Model); +ORT_RUNTIME_CLASS(ModelCompilationOptions); +ORT_RUNTIME_CLASS(HardwareDevice); +ORT_RUNTIME_CLASS(EpDevice); +ORT_RUNTIME_CLASS(KeyValuePairs); + +#ifdef _MSC_VER +typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; +#else +typedef OrtStatus* OrtStatusPtr; +#endif + +/** \brief Memory allocation interface + * + * Structure of function pointers that defines a memory allocator. This can be created and filled in by the user for custom allocators. + * + * When an allocator is passed to any function, be sure that the allocator object is not destroyed until the last allocated object using it is freed. + */ +typedef struct OrtAllocator { + uint32_t version; ///< Must be initialized to ORT_API_VERSION + void*(ORT_API_CALL* Alloc)(struct OrtAllocator* this_, size_t size); ///< Returns a pointer to an allocated block of `size` bytes + void(ORT_API_CALL* Free)(struct OrtAllocator* this_, void* p); ///< Free a block of memory previously allocated with OrtAllocator::Alloc + const struct OrtMemoryInfo*(ORT_API_CALL* Info)(const struct OrtAllocator* this_); ///< Return a pointer to an ::OrtMemoryInfo that describes this allocator + /** + * @brief Optional allocation function to use for memory allocations made during session initialization. + * Use this function if you want to separate allocations made by ORT during Run() calls from + * those made during session initialization. This allows for separate memory management strategies for these allocations. 
+ */ + void*(ORT_API_CALL* Reserve)(struct OrtAllocator* this_, size_t size); ///< Returns a pointer to an allocated block of `size` bytes +} OrtAllocator; + +typedef void(ORT_API_CALL* OrtLoggingFunction)( + void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, + const char* message); + +/** \brief Graph optimization level + * + * Refer to https://www.onnxruntime.ai/docs/performance/graph-optimizations.html#graph-optimization-levels + * for an in-depth understanding of the Graph Optimization Levels. + */ +typedef enum GraphOptimizationLevel { + ORT_DISABLE_ALL = 0, + ORT_ENABLE_BASIC = 1, + ORT_ENABLE_EXTENDED = 2, + ORT_ENABLE_ALL = 99 +} GraphOptimizationLevel; + +typedef enum ExecutionMode { + ORT_SEQUENTIAL = 0, + ORT_PARALLEL = 1, +} ExecutionMode; + +/** \brief Language projection identifiers + * /see OrtApi::SetLanguageProjection + */ +typedef enum OrtLanguageProjection { + ORT_PROJECTION_C = 0, + ORT_PROJECTION_CPLUSPLUS = 1, + ORT_PROJECTION_CSHARP = 2, + ORT_PROJECTION_PYTHON = 3, + ORT_PROJECTION_JAVA = 4, + ORT_PROJECTION_WINML = 5, + ORT_PROJECTION_NODEJS = 6, +} OrtLanguageProjection; + +struct OrtKernelInfo; +typedef struct OrtKernelInfo OrtKernelInfo; +struct OrtKernelContext; +typedef struct OrtKernelContext OrtKernelContext; +struct OrtCustomOp; +typedef struct OrtCustomOp OrtCustomOp; + +typedef enum OrtAllocatorType { + OrtInvalidAllocator = -1, + OrtDeviceAllocator = 0, + OrtArenaAllocator = 1 +} OrtAllocatorType; + +/** \brief Memory types for allocated memory, execution provider specific types should be extended in each provider. + */ +// Whenever this struct is updated, please also update the MakeKey function in onnxruntime / core / framework / execution_provider.cc +typedef enum OrtMemType { + OrtMemTypeCPUInput = -2, ///< Any CPU memory used by non-CPU execution provider + OrtMemTypeCPUOutput = -1, ///< CPU accessible memory outputted by non-CPU execution provider, i.e. 
CUDA_PINNED + OrtMemTypeCPU = OrtMemTypeCPUOutput, ///< Temporary CPU accessible memory allocated by non-CPU execution provider, i.e. CUDA_PINNED + OrtMemTypeDefault = 0, ///< The default allocator for execution provider +} OrtMemType; + +/** \brief This mimics OrtDevice type constants so they can be returned in the API + */ +typedef enum OrtMemoryInfoDeviceType { + OrtMemoryInfoDeviceType_CPU = 0, + OrtMemoryInfoDeviceType_GPU = 1, + OrtMemoryInfoDeviceType_FPGA = 2 +} OrtMemoryInfoDeviceType; + +typedef enum OrtHardwareDeviceType { + OrtHardwareDeviceType_CPU, + OrtHardwareDeviceType_GPU, + OrtHardwareDeviceType_NPU +} OrtHardwareDeviceType; + +/** \brief These are the default EP selection policies used by ORT when doing automatic EP selection. + */ +typedef enum OrtExecutionProviderDevicePolicy { + OrtExecutionProviderDevicePolicy_DEFAULT, + OrtExecutionProviderDevicePolicy_PREFER_CPU, + OrtExecutionProviderDevicePolicy_PREFER_NPU, + OrtExecutionProviderDevicePolicy_PREFER_GPU, + OrtExecutionProviderDevicePolicy_MAX_PERFORMANCE, + OrtExecutionProviderDevicePolicy_MAX_EFFICIENCY, + OrtExecutionProviderDevicePolicy_MIN_OVERALL_POWER, +} OrtExecutionProviderDevicePolicy; + +/** \brief Delegate to allow providing custom OrtEpDevice selection logic + * + * This delegate is called by the EP selection code to allow the user to provide custom device selection logic. + * The user can use this to select OrtEpDevice instances from the list of available devices. + * + * \param ep_devices The list of available devices. + * \param num_devices The number of available devices. + * \param model_metadata The model metadata. + * \param runtime_metadata The runtime metadata. May be nullptr. + * \param selected Pre-allocated array to populate with selected OrtEpDevice pointers from ep_devices. + * \param max_selected The maximum number of devices that can be selected in the pre-allocated array. + Currently the maximum is 8. + * \param num_selected The number of selected devices. 
+ * \param state Opaque pointer. Required to use the delegate from other languages like C# and python. + * + * \return OrtStatus* Selection status. Return nullptr on success. + * Use CreateStatus to provide error info. Use ORT_FAIL as the error code. + * ORT will release the OrtStatus* if not null. + */ +typedef OrtStatus*(ORT_API_CALL* EpSelectionDelegate)(_In_ const OrtEpDevice** ep_devices, + _In_ size_t num_devices, + _In_ const OrtKeyValuePairs* model_metadata, + _In_opt_ const OrtKeyValuePairs* runtime_metadata, + _Inout_ const OrtEpDevice** selected, + _In_ size_t max_selected, + _Out_ size_t* num_selected, + _In_ void* state); + +/** \brief Algorithm to use for cuDNN Convolution Op + */ +typedef enum OrtCudnnConvAlgoSearch { + OrtCudnnConvAlgoSearchExhaustive, // expensive exhaustive benchmarking using cudnnFindConvolutionForwardAlgorithmEx + OrtCudnnConvAlgoSearchHeuristic, // lightweight heuristic based search using cudnnGetConvolutionForwardAlgorithm_v7 + OrtCudnnConvAlgoSearchDefault, // default algorithm using CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM +} OrtCudnnConvAlgoSearch; + +/** \brief CUDA Provider Options + * + * \see OrtApi::SessionOptionsAppendExecutionProvider_CUDA + */ +typedef struct OrtCUDAProviderOptions { +#ifdef __cplusplus + OrtCUDAProviderOptions() + : device_id{}, + cudnn_conv_algo_search{OrtCudnnConvAlgoSearchExhaustive}, + gpu_mem_limit{SIZE_MAX}, + arena_extend_strategy{}, + do_copy_in_default_stream{1}, + has_user_compute_stream{}, + user_compute_stream{}, + default_memory_arena_cfg{}, + tunable_op_enable{false}, + tunable_op_tuning_enable{false}, + tunable_op_max_tuning_duration_ms{} {} +#endif + + /** \brief CUDA device Id + * Defaults to 0. + */ + int device_id; + + /** \brief CUDA Convolution algorithm search configuration. + * See enum OrtCudnnConvAlgoSearch for more details. + * Defaults to OrtCudnnConvAlgoSearchExhaustive. 
+ */ + OrtCudnnConvAlgoSearch cudnn_conv_algo_search; + + /** \brief CUDA memory limit (To use all possible memory pass in maximum size_t) + * Defaults to SIZE_MAX. + * \note If a ::OrtArenaCfg has been applied, it will override this field + */ + size_t gpu_mem_limit; + + /** \brief Strategy used to grow the memory arena + * 0 = kNextPowerOfTwo
+ * 1 = kSameAsRequested
+ * Defaults to 0. + * \note If a ::OrtArenaCfg has been applied, it will override this field + */ + int arena_extend_strategy; + + /** \brief Flag indicating if copying needs to take place on the same stream as the compute stream in the CUDA EP + * 0 = Use separate streams for copying and compute. + * 1 = Use the same stream for copying and compute. + * Defaults to 1. + * WARNING: Setting this to 0 may result in data races for some models. + * Please see issue #4829 for more details. + */ + int do_copy_in_default_stream; + + /** \brief Flag indicating if there is a user provided compute stream + * Defaults to 0. + */ + int has_user_compute_stream; + + /** \brief User provided compute stream. + * If provided, please set `has_user_compute_stream` to 1. + */ + void* user_compute_stream; + + /** \brief CUDA memory arena configuration parameters + */ + OrtArenaCfg* default_memory_arena_cfg; + + /** \brief Enable TunableOp for using. + * Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default. + * This option can be overridden by environment variable ORT_CUDA_TUNABLE_OP_ENABLE. + */ + int tunable_op_enable; + + /** \brief Enable TunableOp for tuning. + * Set it to 1/0 to enable/disable TunableOp tuning. Otherwise, it is disabled by default. + * This option can be overridden by environment variable ORT_CUDA_TUNABLE_OP_TUNING_ENABLE. + */ + int tunable_op_tuning_enable; + + /** \brief Max tuning duration time limit for each instance of TunableOp. + * Defaults to 0 to disable the limit. 
+ */ + int tunable_op_max_tuning_duration_ms; + +} OrtCUDAProviderOptions; + +/** \brief ROCM Provider Options + * + * \see OrtApi::SessionOptionsAppendExecutionProvider_ROCM + */ +typedef struct OrtROCMProviderOptions { +#ifdef __cplusplus + OrtROCMProviderOptions() + : device_id{}, + miopen_conv_exhaustive_search{0}, + gpu_mem_limit{SIZE_MAX}, + arena_extend_strategy{}, + do_copy_in_default_stream{1}, + has_user_compute_stream{}, + user_compute_stream{}, + default_memory_arena_cfg{}, + enable_hip_graph{false}, + tunable_op_enable{false}, + tunable_op_tuning_enable{false}, + tunable_op_max_tuning_duration_ms{} {} +#endif + + /** \brief ROCM device Id + * Defaults to 0. + */ + int device_id; + + /** \brief ROCM MIOpen Convolution algorithm exhaustive search option. + * Defaults to 0 (false). + */ + int miopen_conv_exhaustive_search; + + /** \brief ROCM memory limit (To use all possible memory pass in maximum size_t) + * Defaults to SIZE_MAX. + * \note If a ::OrtArenaCfg has been applied, it will override this field + */ + size_t gpu_mem_limit; + + /** \brief Strategy used to grow the memory arena + * 0 = kNextPowerOfTwo
+ * 1 = kSameAsRequested
+ * Defaults to 0. + * \note If a ::OrtArenaCfg has been applied, it will override this field + */ + int arena_extend_strategy; + + /** \brief Flag indicating if copying needs to take place on the same stream as the compute stream in the ROCM EP + * 0 = Use separate streams for copying and compute. + * 1 = Use the same stream for copying and compute. + * Defaults to 1. + * WARNING: Setting this to 0 may result in data races for some models. + * Please see issue #4829 for more details. + */ + int do_copy_in_default_stream; + + /** \brief Flag indicating if there is a user provided compute stream + * Defaults to 0. + */ + int has_user_compute_stream; + + /** \brief User provided compute stream. + * If provided, please set `has_user_compute_stream` to 1. + */ + void* user_compute_stream; + + /** \brief ROCM memory arena configuration parameters + */ + OrtArenaCfg* default_memory_arena_cfg; + + int enable_hip_graph; + + /** \brief Enable TunableOp for using. + * Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default. + * This option can be overridden by environment variable ORT_ROCM_TUNABLE_OP_ENABLE. + */ + int tunable_op_enable; + + /** \brief Enable TunableOp for tuning. + * Set it to 1/0 to enable/disable TunableOp tuning. Otherwise, it is disabled by default. + * This option can be overridden by environment variable ORT_ROCM_TUNABLE_OP_TUNING_ENABLE. + */ + int tunable_op_tuning_enable; + + /** \brief Max tuning duration time limit for each instance of TunableOp. + * Defaults to 0 to disable the limit. + */ + int tunable_op_max_tuning_duration_ms; + +} OrtROCMProviderOptions; + +/** \brief TensorRT Provider Options + * + * \see OrtApi::SessionOptionsAppendExecutionProvider_TensorRT + */ +typedef struct OrtTensorRTProviderOptions { + int device_id; ///< CUDA device id (0 = default device) + int has_user_compute_stream; // indicator of user specified CUDA compute stream. + void* user_compute_stream; // user specified CUDA compute stream. 
+ int trt_max_partition_iterations; // maximum iterations for TensorRT parser to get capability + int trt_min_subgraph_size; // minimum size of TensorRT subgraphs + size_t trt_max_workspace_size; // maximum workspace size for TensorRT. + int trt_fp16_enable; // enable TensorRT FP16 precision. Default 0 = false, nonzero = true + int trt_int8_enable; // enable TensorRT INT8 precision. Default 0 = false, nonzero = true + const char* trt_int8_calibration_table_name; // TensorRT INT8 calibration table name. + int trt_int8_use_native_calibration_table; // use native TensorRT generated calibration table. Default 0 = false, nonzero = true + int trt_dla_enable; // enable DLA. Default 0 = false, nonzero = true + int trt_dla_core; // DLA core number. Default 0 + int trt_dump_subgraphs; // dump TRT subgraph. Default 0 = false, nonzero = true + int trt_engine_cache_enable; // enable engine caching. Default 0 = false, nonzero = true + const char* trt_engine_cache_path; // specify engine cache path + int trt_engine_decryption_enable; // enable engine decryption. Default 0 = false, nonzero = true + const char* trt_engine_decryption_lib_path; // specify engine decryption library path + int trt_force_sequential_engine_build; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true + // This is the legacy struct and don't add new fields here. + // For new field that can be represented by string, please add it in include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h + // For non-string field, need to create a new separate api to handle it. +} OrtTensorRTProviderOptions; + +/** \brief MIGraphX Provider Options + * + * \see OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX + */ +typedef struct OrtMIGraphXProviderOptions { + int device_id; // hip device id. + int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true + int migraphx_int8_enable; // MIGraphX INT8 precision. 
Default 0 = false, nonzero = true
+  int migraphx_use_native_calibration_table;         // MIGraphx INT8 cal table. Default 0 = false, nonzero = true
+  const char* migraphx_int8_calibration_table_name;  // MIGraphx INT8 calibration table name
+  int migraphx_save_compiled_model;                  // migraphx save compiled model. Default 0 = false, nonzero = true
+  const char* migraphx_save_model_path;              // migraphx model path name
+  int migraphx_load_compiled_model;                  // migraphx load compiled model. Default 0 = false, nonzero = true
+  const char* migraphx_load_model_path;              // migraphx model path name
+  bool migraphx_exhaustive_tune;                     // migraphx tuned compile Default = false
+} OrtMIGraphXProviderOptions;
+
+/** \brief OpenVINO Provider Options
+ * \brief This Struct is frozen since ORT 1.13.0. It's maintained as part of the Legacy API for compatibility.
+ * \brief For latest OpenVINO Provider Options update to the ProviderOptions map.
+ * \brief Latest OpenVINO Provider Options are listed in the
+ * \htmlonly
+ * onnxruntime document.
+ * \endhtmlonly + * \see OrtApi::SessionOptionsAppendExecutionProvider() + */ +typedef struct OrtOpenVINOProviderOptions { +#ifdef __cplusplus + OrtOpenVINOProviderOptions() : device_type{}, + enable_npu_fast_compile{}, + device_id{}, + num_of_threads{}, + cache_dir{}, + context{}, + enable_opencl_throttling{}, + enable_dynamic_shapes{} {} +#endif + /** \brief Device type string + * + * Valid settings are one of: "CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16" + */ + const char* device_type; + unsigned char enable_npu_fast_compile; ///< 0 = disabled, nonzero = enabled + const char* device_id; + size_t num_of_threads; ///< 0 = Use default number of threads + const char* cache_dir; // path is set to empty by default + void* context; + unsigned char enable_opencl_throttling; ///< 0 = disabled, nonzero = enabled + unsigned char enable_dynamic_shapes; ///< 0 = disabled, nonzero = enabled +} OrtOpenVINOProviderOptions; + +struct OrtApi; +typedef struct OrtApi OrtApi; + +struct OrtTrainingApi; +typedef struct OrtTrainingApi OrtTrainingApi; + +struct OrtModelEditorApi; +typedef struct OrtModelEditorApi OrtModelEditorApi; + +struct OrtCompileApi; +typedef struct OrtCompileApi OrtCompileApi; + +struct OrtEpApi; +typedef struct OrtEpApi OrtEpApi; + +/** \brief The helper interface to get the right version of OrtApi + * + * Get a pointer to this structure through ::OrtGetApiBase + */ +struct OrtApiBase { + /** \brief Get a pointer to the requested version of the ::OrtApi + * + * \param[in] version Must be ::ORT_API_VERSION + * \return The ::OrtApi for the version requested, nullptr will be returned if this version is unsupported, for example when using a runtime + * older than the version created with this header file. + * + * One can call GetVersionString() to get the version of the Onnxruntime library for logging + * and error reporting purposes. 
+ */
+  const OrtApi*(ORT_API_CALL* GetApi)(uint32_t version)NO_EXCEPTION;
+
+  /** \brief Returns a null terminated string of the version of the Onnxruntime library (eg: "1.8.1")
+   *
+   * \return UTF-8 encoded version string. Do not deallocate the returned buffer.
+   */
+  const char*(ORT_API_CALL* GetVersionString)(void)NO_EXCEPTION;
+};
+
+typedef struct OrtApiBase OrtApiBase;
+
+/** \brief The Onnxruntime library's entry point to access the C API
+ *
+ * Call this to get a pointer to an ::OrtApiBase
+ */
+ORT_EXPORT const OrtApiBase* ORT_API_CALL OrtGetApiBase(void) NO_EXCEPTION;
+
+/** \brief Thread work loop function
+ *
+ * Onnxruntime will provide the working loop on custom thread creation
+ * Argument is an onnxruntime built-in type which will be provided when thread pool calls OrtCustomCreateThreadFn
+ */
+typedef void (*OrtThreadWorkerFn)(void* ort_worker_fn_param);
+
+typedef const struct OrtCustomHandleType {
+  char __place_holder;
+}* OrtCustomThreadHandle;
+
+/** \brief Ort custom thread creation function
+ *
+ * The function should return a thread handle to be used in onnxruntime thread pools
+ * Onnxruntime will throw exception on return value of nullptr or 0, indicating that the function failed to create a thread
+ */
+typedef OrtCustomThreadHandle (*OrtCustomCreateThreadFn)(void* ort_custom_thread_creation_options, OrtThreadWorkerFn ort_thread_worker_fn, void* ort_worker_fn_param);
+
+/** \brief Custom thread join function
+ *
+ * Onnxruntime thread pool destructor will call the function to join a custom thread.
+ * Argument ort_custom_thread_handle is the value returned by OrtCustomCreateThreadFn
+ */
+typedef void (*OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_handle);
+
+typedef OrtStatus*(ORT_API_CALL* RegisterCustomOpsFn)(OrtSessionOptions* options, const OrtApiBase* api);
+
+/** \brief Callback function for RunAsync
+ *
+ * \param[in] user_data User specific data that is passed back to the callback
+ * \param[out] outputs On success, outputs host inference results, on error, the value will be nullptr
+ * \param[out] num_outputs Number of outputs, on error, the value will be zero
+ * \param[out] status On error, status will provide details
+ */
+typedef void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status);
+
+/** \brief The C API
+ *
+ * All C API functions are defined inside this structure as pointers to functions.
+ * Call OrtApiBase::GetApi to get a pointer to it
+ *
+ * \nosubgrouping
+ */
+struct OrtApi {
+  /// \name OrtStatus
+  /// @{
+
+  /**
+   * \brief Create an OrtStatus from a null terminated string
+   *
+   * \param[in] code
+   * \param[in] msg A null-terminated string. Its contents will be copied.
+   * \return A new OrtStatus object, must be destroyed with OrtApi::ReleaseStatus
+   */
+  OrtStatus*(ORT_API_CALL* CreateStatus)(OrtErrorCode code, _In_ const char* msg)NO_EXCEPTION ORT_ALL_ARGS_NONNULL;
+
+  /** \brief Get OrtErrorCode from OrtStatus
+   *
+   * \param[in] status
+   * \return OrtErrorCode that \p status was created with
+   */
+  OrtErrorCode(ORT_API_CALL* GetErrorCode)(_In_ const OrtStatus* status) NO_EXCEPTION ORT_ALL_ARGS_NONNULL;
+
+  /** \brief Get error string from OrtStatus
+   *
+   * \param[in] status
+   * \return The error message inside the `status`. Do not free the returned value.
+ */ + const char*(ORT_API_CALL* GetErrorMessage)(_In_ const OrtStatus* status)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + + /// @} + /// \name OrtEnv + /// @{ + + /** \brief Create an OrtEnv + * + * \note Invoking this function will return the same instance of the environment as that returned by a previous call + * to another env creation function; all arguments to this function will be ignored. + * \param[in] log_severity_level The log severity level. + * \param[in] logid The log identifier. + * \param[out] out Returned newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateEnv, OrtLoggingLevel log_severity_level, _In_ const char* logid, _Outptr_ OrtEnv** out); + + /** \brief Create an OrtEnv + * + * \note Invoking this function will return the same instance of the environment as that returned by a previous call + * to another env creation function; all arguments to this function will be ignored. If you want to provide your + * own logging function, consider setting it using the SetUserLoggingFunction API instead. + * \param[in] logging_function A pointer to a logging function. + * \param[in] logger_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to + * `logging_function`. This parameter is optional. + * \param[in] log_severity_level The log severity level. + * \param[in] logid The log identifier. + * \param[out] out Returned newly created OrtEnv. 
Must be freed with OrtApi::ReleaseEnv + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateEnvWithCustomLogger, _In_ OrtLoggingFunction logging_function, _In_opt_ void* logger_param, + _In_ OrtLoggingLevel log_severity_level, _In_ const char* logid, _Outptr_ OrtEnv** out); + + /** \brief Enable Telemetry + * + * \note Telemetry events are on by default since they are lightweight + * \param[in] env + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(EnableTelemetryEvents, _In_ const OrtEnv* env); + /** \brief Disable Telemetry + * + * \see OrtApi::EnableTelemetryEvents + * \param[in] env + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(DisableTelemetryEvents, _In_ const OrtEnv* env); + + /// @} + /// \name OrtSession + /// @{ + + /** \brief Create an OrtSession from a model file + * + * \param[in] env + * \param[in] model_path + * \param[in] options + * \param[out] out Returned newly created OrtSession. Must be freed with OrtApi::ReleaseSession + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + // TODO: document the path separator convention? '/' vs '\' + // TODO: should specify the access characteristics of model_path. Is this read only during the + // execution of CreateSession, or does the OrtSession retain a handle to the file/directory + // and continue to access throughout the OrtSession lifetime? + // What sort of access is needed to model_path : read or read/write? + ORT_API2_STATUS(CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); + + /** \brief Create an OrtSession from memory + * + * \param[in] env + * \param[in] model_data + * \param[in] model_data_length + * \param[in] options + * \param[out] out Returned newly created OrtSession. 
Must be freed with OrtApi::ReleaseSession + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateSessionFromArray, _In_ const OrtEnv* env, + _In_ const void* model_data, size_t model_data_length, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); + + /** \brief Run the model in an ::OrtSession + * + * Will not return until the model run has completed. Multiple threads might be used to run the model based on + * the options in the ::OrtSession and settings used when creating the ::OrtEnv + * + * \param[in] session + * \param[in] run_options If nullptr, will use a default ::OrtRunOptions + * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names + * \param[in] inputs Array of ::OrtValue%s of the input values + * \param[in] input_len Number of elements in the input_names and inputs arrays + * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names + * \param[in] output_names_len Number of elements in the output_names and outputs array + * \param[out] outputs Array of ::OrtValue%s that the outputs are stored in. This can also be + * an array of nullptr values, in this case ::OrtValue objects will be allocated and pointers + * to them will be set into the `outputs` array. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(Run, _Inout_ OrtSession* session, _In_opt_ const OrtRunOptions* run_options, + _In_reads_(input_len) const char* const* input_names, + _In_reads_(input_len) const OrtValue* const* inputs, size_t input_len, + _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, + _Inout_updates_all_(output_names_len) OrtValue** outputs); + + /// @} + /// \name OrtSessionOptions + /// @{ + + /** \brief Create an ::OrtSessionOptions object + * + * To use additional providers, you must build ORT with the extra providers enabled. Then call one of these + * functions to enable them in the session:
+ * OrtSessionOptionsAppendExecutionProvider_CPU
+ * OrtSessionOptionsAppendExecutionProvider_CUDA
+ * OrtSessionOptionsAppendExecutionProvider_(remaining providers...)
+ * The order they are called indicates the preference order as well. In other words call this method + * on your most preferred execution provider first followed by the less preferred ones. + * If none are called Ort will use its internal CPU execution provider. + * + * \param[out] options The newly created OrtSessionOptions. Must be freed with OrtApi::ReleaseSessionOptions + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateSessionOptions, _Outptr_ OrtSessionOptions** options); + + /** \brief Set filepath to save optimized model after graph level transformations + * + * \param[in] options + * \param[in] optimized_model_filepath + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetOptimizedModelFilePath, _Inout_ OrtSessionOptions* options, + _In_ const ORTCHAR_T* optimized_model_filepath); + + /** \brief Create a copy of an existing ::OrtSessionOptions + * + * \param[in] in_options OrtSessionOptions to copy + * \param[out] out_options Returned newly created ::OrtSessionOptions. Must be freed with OrtApi::ReleaseSessionOptions + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CloneSessionOptions, _In_ const OrtSessionOptions* in_options, + _Outptr_ OrtSessionOptions** out_options); + + /** \brief Set execution mode + * + * Controls whether you want to execute operators in your graph sequentially or in parallel. Usually when the model + * has many branches, setting this option to ExecutionMode.ORT_PARALLEL will give you better performance. + * See [docs/ONNX_Runtime_Perf_Tuning.md] for more details. 
+ * + * \param[in] options + * \param[in] execution_mode + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetSessionExecutionMode, _Inout_ OrtSessionOptions* options, ExecutionMode execution_mode); + + /** \brief Enable profiling for a session + * + * \param[in] options + * \param[in] profile_file_prefix + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(EnableProfiling, _Inout_ OrtSessionOptions* options, _In_ const ORTCHAR_T* profile_file_prefix); + + /** \brief Disable profiling for a session + * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(DisableProfiling, _Inout_ OrtSessionOptions* options); + + /** \brief Enable the memory pattern optimization + * + * The idea is if the input shapes are the same, we could trace the internal memory allocation + * and generate a memory pattern for future request. So next time we could just do one allocation + * with a big chunk for all the internal memory allocation. + * \note Memory pattern optimization is only available when Sequential Execution mode is enabled (see OrtApi::SetSessionExecutionMode) + * + * \see OrtApi::DisableMemPattern + * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(EnableMemPattern, _Inout_ OrtSessionOptions* options); + + /** \brief Disable the memory pattern optimization + * + * \see OrtApi::EnableMemPattern + * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(DisableMemPattern, _Inout_ OrtSessionOptions* options); + + /** \brief Enable the memory arena on CPU + * + * Arena may pre-allocate memory for future usage. 
+ * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(EnableCpuMemArena, _Inout_ OrtSessionOptions* options); + + /** \brief Disable the memory arena on CPU + * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(DisableCpuMemArena, _Inout_ OrtSessionOptions* options); + + /** \brief Set session log id + * + * \param[in] options + * \param[in] logid The log identifier. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetSessionLogId, _Inout_ OrtSessionOptions* options, const char* logid); + + /** \brief Set session log verbosity level + * + * Applies to session load, initialization, etc + * + * \param[in] options + * \param[in] session_log_verbosity_level \snippet{doc} snippets.dox Log Verbosity Level + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetSessionLogVerbosityLevel, _Inout_ OrtSessionOptions* options, int session_log_verbosity_level); + + /** \brief Set session log severity level + * + * \param[in] options + * \param[in] session_log_severity_level The log severity level (refer to ::OrtLoggingLevel for possible values). 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetSessionLogSeverityLevel, _Inout_ OrtSessionOptions* options, int session_log_severity_level); + + /** \brief Set the optimization level to apply when loading a graph + * + * Please see https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html for an in-depth explanation + * \param[in,out] options The session options object + * \param[in] graph_optimization_level The optimization level + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetSessionGraphOptimizationLevel, _Inout_ OrtSessionOptions* options, + GraphOptimizationLevel graph_optimization_level); + + /** \brief Sets the number of threads used to parallelize the execution within nodes + * + * When running a single node operation, ex. add, this sets the maximum number of threads to use. + * + * \note If built with OpenMP, this has no effect on the number of threads used. In this case + * use the OpenMP env variables to configure the number of intra op num threads. + * + * \param[in] options + * \param[in] intra_op_num_threads Number of threads to use
+ * A value of 0 will use the default number of threads
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetIntraOpNumThreads, _Inout_ OrtSessionOptions* options, int intra_op_num_threads); + + /** \brief Sets the number of threads used to parallelize the execution of the graph + * + * If nodes can be run in parallel, this sets the maximum number of threads to use to run them in parallel. + * + * \note If sequential execution is enabled this value is ignored, it acts as if it was set to 1. + * + * \param[in] options + * \param[in] inter_op_num_threads Number of threads to use
+ * A value of 0 will use the default number of threads
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetInterOpNumThreads, _Inout_ OrtSessionOptions* options, int inter_op_num_threads); + + /// @} + /// \name OrtCustomOpDomain + /// @{ + + /** \brief Create a custom op domain + * + * \param[in] domain + * \param[out] out Newly created domain. Must be freed with OrtApi::ReleaseCustomOpDomain + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateCustomOpDomain, _In_ const char* domain, _Outptr_ OrtCustomOpDomain** out); + + /** \brief Add a custom op to a custom op domain + * + * \note The OrtCustomOp* pointer must remain valid until the ::OrtCustomOpDomain using it is released + * + * \param[in] custom_op_domain + * \param[in] op + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CustomOpDomain_Add, _Inout_ OrtCustomOpDomain* custom_op_domain, _In_ const OrtCustomOp* op); + + /// @} + /// \name OrtSessionOptions + /// @{ + + /** \brief Add custom op domain to a session options + * + * \note The OrtCustomOpDomain* must not be deleted until all sessions using it are released + * + * \param[in] options + * \param[in] custom_op_domain + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(AddCustomOpDomain, _Inout_ OrtSessionOptions* options, _In_ OrtCustomOpDomain* custom_op_domain); + + /** \deprecated Use OrtApi::RegisterCustomOpsLibrary_V2. + * + * Registers custom ops from a shared library. + * + * Loads a shared library (dll on windows, so on linux, etc) named 'library_path' and looks for this entry point: + * OrtStatus* RegisterCustomOps(OrtSessionOptions * options, const OrtApiBase* api); + * It then passes in the provided session options to this function along with the api base. + * The handle to the loaded library is returned in library_handle. It can be freed by the caller after all sessions using the passed in + * session options are destroyed, or if an error occurs and it is non null. 
+ * + * \param[in] options + * \param[in] library_path + * \param[out] library_handle OS specific handle to the loaded library (Use FreeLibrary on Windows, dlclose on Linux, etc.. to unload) + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions* options, _In_ const char* library_path, _Outptr_ void** library_handle); + + /// @} + /// \name OrtSession + /// @{ + + /** \brief Get input count for a session + * + * This number must also match the number of inputs passed to OrtApi::Run + * + * \see OrtApi::SessionGetInputTypeInfo, OrtApi::SessionGetInputName, OrtApi::Session + * + * \param[in] session + * \param[out] out Number of inputs + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionGetInputCount, _In_ const OrtSession* session, _Out_ size_t* out); + + /** \brief Get output count for a session + * + * This number must also match the number of outputs returned by OrtApi::Run + * + * \see OrtApi::SessionGetOutputTypeInfo, OrtApi::SessionGetOutputName, OrtApi::Session + * + * \param[in] session + * \param[out] out Number of outputs + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionGetOutputCount, _In_ const OrtSession* session, _Out_ size_t* out); + + /** \brief Get overridable initializer count + * + * \see OrtApi::SessionGetOverridableInitializerTypeInfo, OrtApi::SessionGetOverridableInitializerName + * + * \param[in] session + * \param[in] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionGetOverridableInitializerCount, _In_ const OrtSession* session, _Out_ size_t* out); + + /** \brief Get input type information + * + * \param[in] session + * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetInputCount returns (exclusive) + * \param[out] type_info Must be freed with OrtApi::ReleaseTypeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + 
*/ + ORT_API2_STATUS(SessionGetInputTypeInfo, _In_ const OrtSession* session, size_t index, _Outptr_ OrtTypeInfo** type_info); + + /** \brief Get output type information + * + * \param[in] session + * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOutputCount returns (exclusive) + * \param[out] type_info Must be freed with OrtApi::ReleaseTypeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionGetOutputTypeInfo, _In_ const OrtSession* session, size_t index, _Outptr_ OrtTypeInfo** type_info); + + /** \brief Get overridable initializer type information + * + * \param[in] session + * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOverridableInitializerCount returns (exclusive) + * \param[out] type_info Must be freed with OrtApi::ReleaseTypeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionGetOverridableInitializerTypeInfo, _In_ const OrtSession* session, size_t index, _Outptr_ OrtTypeInfo** type_info); + + /** \brief Get input name + * + * \param[in] session + * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetInputCount returns (exclusive) + * \param[in] allocator + * \param[out] value Set to a null terminated UTF-8 encoded string allocated using `allocator`. Must be freed using `allocator`. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionGetInputName, _In_ const OrtSession* session, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** value); + + /** \brief Get output name + * + * \param[in] session + * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOutputCount returns (exclusive) + * \param[in] allocator + * \param[out] value Set to a null terminated UTF-8 encoded string allocated using `allocator`. Must be freed using `allocator`. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionGetOutputName, _In_ const OrtSession* session, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** value); + + /** \brief Get overridable initializer name + * + * \param[in] session + * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOverridableInitializerCount returns (exclusive) + * \param[in] allocator + * \param[out] value Set to a null terminated UTF-8 encoded string allocated using `allocator`. Must be freed using `allocator`. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionGetOverridableInitializerName, _In_ const OrtSession* session, size_t index, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); + + /// @} + /// \name OrtRunOptions + /// @{ + + /** \brief Create an OrtRunOptions + * + * \param[out] out Returned newly created ::OrtRunOptions. Must be freed with OrtApi::ReleaseRunOptions + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateRunOptions, _Outptr_ OrtRunOptions** out); + + /** \brief Set per-run log verbosity level + * + * \see OrtApi::RunOptionsGetRunLogVerbosityLevel + * + * \param[in] options + * \param[in] log_verbosity_level \snippet{doc} snippets.dox Log Verbosity Level + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(RunOptionsSetRunLogVerbosityLevel, _Inout_ OrtRunOptions* options, int log_verbosity_level); + + /** \brief Set per-run log severity level + * + * \see OrtApi::RunOptionsGetRunLogSeverityLevel + * + * \param[in] options + * \param[in] log_severity_level The log severity level (refer to ::OrtLoggingLevel for possible values). + */ + ORT_API2_STATUS(RunOptionsSetRunLogSeverityLevel, _Inout_ OrtRunOptions* options, int log_severity_level); + + /** \brief Set per-run tag + * + * This is used in a per-run log identifier. 
+ * + * \see OrtApi::RunOptionsGetRunTag + * + * \param[in] options + * \param[in] run_tag The run tag. + */ + ORT_API2_STATUS(RunOptionsSetRunTag, _Inout_ OrtRunOptions* options, _In_ const char* run_tag); + + /** \brief Get per-run log verbosity level + * + * \see OrtApi::RunOptionsSetRunLogVerbosityLevel + * + * \param[in] options + * \param[out] log_verbosity_level \snippet{doc} snippets.dox Log Verbosity Level + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(RunOptionsGetRunLogVerbosityLevel, _In_ const OrtRunOptions* options, + _Out_ int* log_verbosity_level); + + /** \brief Get per-run log severity level + * + * \see OrtApi::RunOptionsSetRunLogSeverityLevel + * + * \param[in] options + * \param[out] log_severity_level The log severity level (refer to ::OrtLoggingLevel for possible values). + */ + ORT_API2_STATUS(RunOptionsGetRunLogSeverityLevel, _In_ const OrtRunOptions* options, _Out_ int* log_severity_level); + + /** \brief Get per-run tag + * + * This is used in a per-run log identifier. + * + * \see OrtApi::RunOptionsSetRunTag + * + * \param[in] options + * \param[out] run_tag The run tag. + * Do not free this value, it is owned by `options`. It will be invalidated if the run tag + * changes (i.e., with OrtApi::RunOptionsSetRunTag) or `options` is freed. + */ + ORT_API2_STATUS(RunOptionsGetRunTag, _In_ const OrtRunOptions* options, _Out_ const char** run_tag); + + /** \brief Set terminate flag + * + * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error. 
+ * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(RunOptionsSetTerminate, _Inout_ OrtRunOptions* options); + + /** \brief Clears the terminate flag + * + * Used so the OrtRunOptions instance can be used in a new OrtApi::Run call without it instantly terminating + * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(RunOptionsUnsetTerminate, _Inout_ OrtRunOptions* options); + + /// @} + /// \name OrtValue + /// @{ + + /** \brief Create a tensor + * + * Create a tensor using a supplied ::OrtAllocator + * + * \param[in] allocator + * \param[in] shape Pointer to the tensor shape dimensions. + * \param[in] shape_len The number of tensor shape dimensions. + * \param[in] type + * \param[out] out Returns newly created ::OrtValue. Must be freed with OrtApi::ReleaseValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type, _Outptr_ OrtValue** out); + + /** \brief Create a tensor backed by a user supplied buffer + * + * Create a tensor with user's buffer. You can fill the buffer either before calling this function or after. + * p_data is owned by caller. ReleaseValue won't release p_data. + * + * If you wish to transfer ownership of p_data to ORT use CreateTensorWithDataAndDeleterAsOrtValue. + * + * \param[in] info Memory description of where the p_data buffer resides (CPU vs GPU etc). + * \param[in] p_data Pointer to the data buffer. + * \param[in] p_data_len The number of bytes in the data buffer. + * \param[in] shape Pointer to the tensor shape dimensions. + * \param[in] shape_len The number of tensor shape dimensions. + * \param[in] type The data type. + * \param[out] out Returns newly created ::OrtValue. 
Must be freed with OrtApi::ReleaseValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateTensorWithDataAsOrtValue, _In_ const OrtMemoryInfo* info, _Inout_ void* p_data, + size_t p_data_len, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, + _Outptr_ OrtValue** out); + + /** \brief Return if an ::OrtValue is a tensor type + * + * \param[in] value A tensor type (string tensors are not supported) + * \param[out] out Set to 1 iff ::OrtValue is a tensor, 0 otherwise + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(IsTensor, _In_ const OrtValue* value, _Out_ int* out); + + /** \brief Get a pointer to the raw data inside a tensor + * + * Used to read/write/modify the internal tensor data directly. + * \note The returned pointer is valid until the \p value is destroyed. + * + * \param[in] value A tensor type (string tensors are not supported) + * \param[out] out Filled in with a pointer to the internal storage + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetTensorMutableData, _In_ OrtValue* value, _Outptr_ void** out); + + /** \brief Set all strings at once in a string tensor + * + * \param[in,out] value A tensor of type ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING + * \param[in] s An array of strings. Each string in this array must be null terminated. 
+ * \param[in] s_len Count of strings in s (Must match the size of \p value's tensor shape) + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(FillStringTensor, _Inout_ OrtValue* value, _In_ const char* const* s, size_t s_len); + + /** \brief Get total byte length for all strings in a string tensor + * + * Typically used with OrtApi::GetStringTensorContent + * + * \param[in] value A tensor of type ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING + * \param[out] len Total byte length of all strings (does not include trailing nulls) + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetStringTensorDataLength, _In_ const OrtValue* value, _Out_ size_t* len); + + /** \brief Get all strings from a string tensor + * + * An example of the results:
+ * Given \p value is a string tensor with the strings { "This" "is" "a" "test" }
+ * \p s must have a size of 11 bytes
+ * \p offsets must have 4 elements
+ * After the call, these values will be filled in:
+ * \p s will contain "Thisisatest"
+ * \p offsets will contain { 0, 4, 6, 7 }
+ * The length of the last string is just s_len - offsets[last] + * + * \param[in] value A tensor of type ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING + * \param[in] s Buffer to sequentially write all tensor strings to. Each string is NOT null-terminated. + * \param[in] s_len Number of bytes of buffer pointed to by \p s (Get it from OrtApi::GetStringTensorDataLength) + * \param[out] offsets Array of start offsets into the strings written to \p s + * \param[in] offsets_len Number of elements in offsets + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetStringTensorContent, _In_ const OrtValue* value, _Out_writes_bytes_all_(s_len) void* s, + size_t s_len, _Out_writes_all_(offsets_len) size_t* offsets, size_t offsets_len); + + /// @} + /// \name OrtTypeInfo + /// @{ + + /** \brief Get ::OrtTensorTypeAndShapeInfo from an ::OrtTypeInfo + * + * \param[in] type_info + * \param[out] out Do not free this value, it will be valid until type_info is freed. + * If type_info does not represent tensor, this value will be set to nullptr. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CastTypeInfoToTensorInfo, _In_ const OrtTypeInfo* type_info, + _Outptr_result_maybenull_ const OrtTensorTypeAndShapeInfo** out); + + /** \brief Get ::ONNXType from ::OrtTypeInfo + * + * \param[in] type_info + * \param[out] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetOnnxTypeFromTypeInfo, _In_ const OrtTypeInfo* type_info, _Out_ enum ONNXType* out); + + /// @} + /// \name OrtTensorTypeAndShapeInfo + /// @{ + + /** \brief Create an ::OrtTensorTypeAndShapeInfo object + * + * \param[out] out Returns newly created ::OrtTensorTypeAndShapeInfo. 
Must be freed with OrtApi::ReleaseTensorTypeAndShapeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateTensorTypeAndShapeInfo, _Outptr_ OrtTensorTypeAndShapeInfo** out); + + /** \brief Set element type in ::OrtTensorTypeAndShapeInfo + * + * \param[in] info + * \param[in] type + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetTensorElementType, _Inout_ OrtTensorTypeAndShapeInfo* info, enum ONNXTensorElementDataType type); + + /** \brief Set shape information in ::OrtTensorTypeAndShapeInfo + * + * \param[in] info + * \param[in] dim_values Array with `dim_count` elements. Can contain negative values. + * \param[in] dim_count Number of elements in `dim_values` + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetDimensions, OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count); + + /** \brief Get element type in ::OrtTensorTypeAndShapeInfo + * + * \see OrtApi::SetTensorElementType + * + * \param[in] info + * \param[out] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetTensorElementType, _In_ const OrtTensorTypeAndShapeInfo* info, + _Out_ enum ONNXTensorElementDataType* out); + + /** \brief Get dimension count in ::OrtTensorTypeAndShapeInfo + * + * \see OrtApi::GetDimensions + * + * \param[in] info + * \param[out] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetDimensionsCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out); + + /** \brief Get dimensions in ::OrtTensorTypeAndShapeInfo + * + * \param[in] info + * \param[out] dim_values Array with `dim_values_length` elements. On return, filled with the dimensions stored in the ::OrtTensorTypeAndShapeInfo + * \param[in] dim_values_length Number of elements in `dim_values`. 
Use OrtApi::GetDimensionsCount to get this value + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, + size_t dim_values_length); + + /** \brief Get symbolic dimension names in ::OrtTensorTypeAndShapeInfo + * + * \param[in] info + * \param[in] dim_params Array with `dim_params_length` elements. On return filled with pointers to null terminated strings of the dimension names + * \param[in] dim_params_length Number of elements in `dim_params`. Use OrtApi::GetDimensionsCount to get this value + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetSymbolicDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, + _Out_writes_all_(dim_params_length) const char* dim_params[], size_t dim_params_length); + + /** \brief Get total number of elements in a tensor shape from an ::OrtTensorTypeAndShapeInfo + * + * Return the number of elements specified by the tensor shape (all dimensions multiplied by each other). + * For 0 dimensions, 1 is returned. If any dimension is less than 0, the result is always -1. + * + * Examples:
+ * [] = 1
+ * [1,3,4] = 12
+ * [2,0,4] = 0
+ * [-1,3,4] = -1
+ * + * \param[in] info + * \param[out] out Number of elements + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out); + + /// @} + /// \name OrtValue + /// @{ + + /** \brief Get type and shape information from a tensor ::OrtValue + * + * \param[in] value Must be a tensor (not a map/sequence/etc) or will return failure + * \param[out] out Newly created ::OrtTensorTypeAndShapeInfo. Must be freed with OrtApi::ReleaseTensorTypeAndShapeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetTensorTypeAndShape, _In_ const OrtValue* value, _Outptr_ OrtTensorTypeAndShapeInfo** out); + + /** \brief Get type information of an OrtValue + * + * \param[in] value + * \param[out] out Newly created ::OrtTypeInfo. Must be freed with OrtApi::ReleaseTypeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetTypeInfo, _In_ const OrtValue* value, _Outptr_result_maybenull_ OrtTypeInfo** out); + + /** \brief Get ONNXType of an ::OrtValue + * + * \param[in] value + * \param[out] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetValueType, _In_ const OrtValue* value, _Out_ enum ONNXType* out); + + /// @} + /// \name OrtMemoryInfo + /// @{ + + /** \brief Create an ::OrtMemoryInfo + * + * \param[in] name + * \param[in] type + * \param[in] id + * \param[in] mem_type + * \param[out] out Newly created ::OrtMemoryInfo. Must be freed with OrtAPi::ReleaseMemoryInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateMemoryInfo, _In_ const char* name, enum OrtAllocatorType type, int id, + enum OrtMemType mem_type, _Outptr_ OrtMemoryInfo** out); + + /** \brief Create an ::OrtMemoryInfo for CPU memory + * + * Special case version of OrtApi::CreateMemoryInfo for CPU based memory. Same as using OrtApi::CreateMemoryInfo with name = "Cpu" and id = 0. 
+ * + * \param[in] type + * \param[in] mem_type + * \param[out] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateCpuMemoryInfo, enum OrtAllocatorType type, enum OrtMemType mem_type, + _Outptr_ OrtMemoryInfo** out); + + /** \brief Compare ::OrtMemoryInfo objects for equality + * + * Compares all settings of each ::OrtMemoryInfo for equality + * + * \param[in] info1 + * \param[in] info2 + * \param[out] out Set to 0 if equal, -1 if not equal + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CompareMemoryInfo, _In_ const OrtMemoryInfo* info1, _In_ const OrtMemoryInfo* info2, _Out_ int* out); + + /** \brief Get name from ::OrtMemoryInfo + * + * \param[in] ptr + * \param[out] out Writes null terminated string to this pointer. Do NOT free the returned pointer. It is valid for the lifetime of the ::OrtMemoryInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(MemoryInfoGetName, _In_ const OrtMemoryInfo* ptr, _Out_ const char** out); + + /** \brief Get the id from ::OrtMemoryInfo + */ + ORT_API2_STATUS(MemoryInfoGetId, _In_ const OrtMemoryInfo* ptr, _Out_ int* out); + + /** \brief Get the ::OrtMemType from ::OrtMemoryInfo + */ + ORT_API2_STATUS(MemoryInfoGetMemType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtMemType* out); + + /** \brief Get the ::OrtAllocatorType from ::OrtMemoryInfo + */ + ORT_API2_STATUS(MemoryInfoGetType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtAllocatorType* out); + + /// @} + /// \name OrtAllocator + /// @{ + + /// \brief Calls OrtAllocator::Alloc function + ORT_API2_STATUS(AllocatorAlloc, _Inout_ OrtAllocator* ort_allocator, size_t size, _Outptr_ void** out); + /// \brief Calls OrtAllocator::Free function + ORT_API2_STATUS(AllocatorFree, _Inout_ OrtAllocator* ort_allocator, void* p); + /// \brief Calls OrtAllocator::Info function + ORT_API2_STATUS(AllocatorGetInfo, _In_ const OrtAllocator* ort_allocator, _Outptr_ const struct OrtMemoryInfo** out); + 
+ /** \brief Get the default allocator + * + * The default allocator is a CPU based, non-arena. Always returns the same pointer to the same default allocator. + * + * \param[out] out Returned value should NOT be freed + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetAllocatorWithDefaultOptions, _Outptr_ OrtAllocator** out); + + /// @} + /// \name OrtSessionOptions + /// @{ + + /** \brief Override session symbolic dimensions + * + * Override symbolic dimensions (by specific denotation strings) with actual values if known at session initialization time to enable + * optimizations that can take advantage of fixed values (such as memory planning, etc) + * + * \param[in] options + * \param[in] dim_denotation + * \param[in] dim_value + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(AddFreeDimensionOverride, _Inout_ OrtSessionOptions* options, _In_ const char* dim_denotation, + _In_ int64_t dim_value); + + /// @} + /// \name OrtValue + /// @{ + + /* Internal information (not seen in Doxygen) + * + * APIs to support non-tensor types - map and sequence. + * Currently only the following types are supported + * Note: the following types should be kept in sync with data_types.h + * Map types + * ========= + * std::map + * std::map + * std::map + * std::map + * std::map + * std::map + * std::map + * std::map + * + * Sequence types + * ============== + * std::vector + * std::vector + * std::vector + * std::vector + * std::vector> + * std::vector + */ + + /** \brief Get non tensor data from an ::OrtValue + * + * If `value` is of type ONNX_TYPE_MAP, you need to retrieve the keys and values + * separately. Use index=0 to retrieve keys and index=1 to retrieve values. + * If `value` is of type ONNX_TYPE_SEQUENCE, use index to retrieve the index'th element + * of the sequence. 
+ * + * \param[in] value + * \param[in] index See above for usage based on `value` type + * \param[in] allocator Allocator used to allocate ::OrtValue + * \param[out] out Created ::OrtValue that holds the element requested. Must be freed with OrtApi::ReleaseValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetValue, _In_ const OrtValue* value, int index, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** out); + + /** \brief Get non tensor value count from an ::OrtValue + * + * If `value` is of type ONNX_TYPE_MAP 2 will always be returned. For ONNX_TYPE_SEQUENCE + * the number of elements in the sequence will be returned + * + * \param[in] value + * \param[out] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetValueCount, _In_ const OrtValue* value, _Out_ size_t* out); + + /** \brief Create a map or sequence ::OrtValue + * + * To construct a map (ONNX_TYPE_MAP), use num_values = 2 and `in` should be an array of 2 ::OrtValue%s + * representing keys and values.
+ * + * To construct a sequence (ONNX_TYPE_SEQUENCE), use num_values = N where N is the number of the elements in the + * sequence. 'in' should be an array of N ::OrtValue%s. + * + * \param[in] in See above for details + * \param[in] num_values + * \param[in] value_type Must be either ONNX_TYPE_MAP or ONNX_TYPE_SEQUENCE + * \param[out] out Newly created ::OrtValue. Must be freed with OrtApi::ReleaseValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateValue, _In_reads_(num_values) const OrtValue* const* in, size_t num_values, + enum ONNXType value_type, _Outptr_ OrtValue** out); + + /** \brief Create an opaque (custom user defined type) ::OrtValue + * + * Constructs an ::OrtValue that contains a value of non-standard type created for + * experiments or while awaiting standardization. ::OrtValue in this case would contain + * an internal representation of the Opaque type. Opaque types are distinguished from + * each other by two strings 1) domain and 2) type name. The combination of the two + * must be unique, so the type representation is properly identified internally. The combination + * must be properly registered from within ORT at both compile/run time or by another API. + * + * To construct the ::OrtValue pass domain and type names, also a pointer to a data container + * the type of which must be known to both ORT and the client program. That data container may or may + * not match the internal representation of the Opaque type. The sizeof(data_container) is passed for + * verification purposes. + * + * \param[in] domain_name Null terminated string of the domain name + * \param[in] type_name Null terminated string of the type name + * \param[in] data_container User pointer Data to populate ::OrtValue + * \param[in] data_container_size Size in bytes of what `data_container` points to + * \param[out] out Newly created ::OrtValue. 
Must be freed with OrtApi::ReleaseValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateOpaqueValue, _In_z_ const char* domain_name, _In_z_ const char* type_name, + _In_ const void* data_container, size_t data_container_size, _Outptr_ OrtValue** out); + + /** \brief Get internal data from an opaque (custom user defined type) ::OrtValue + * + * Copies internal data from an opaque value into a user provided buffer + * + * \see OrtApi::CreateOpaqueValue + * + * \param[in] domain_name Null terminated string of the domain name + * \param[in] type_name Null terminated string of the type name + * \param[in] in The opaque ::OrtValue + * \param[out] data_container Buffer to copy data into + * \param[out] data_container_size Size in bytes of the buffer pointed to by data_container. Must match the size of the internal buffer. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetOpaqueValue, _In_ const char* domain_name, _In_ const char* type_name, _In_ const OrtValue* in, + _Out_ void* data_container, size_t data_container_size); + + /// @} + /// \name OrtKernelInfo + /// Custom operator APIs. 
+ /// @{
+
+ /** \brief Get a float stored as an attribute in the graph node
+ *
+ * \param[in] info ::OrtKernelInfo instance
+ * \param[in] name Null terminated string of the name of the attribute
+ * \param[out] out Pointer to memory where the attribute will be stored
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ ORT_API2_STATUS(KernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name,
+ _Out_ float* out);
+
+ /** \brief Fetch a 64-bit int stored as an attribute in the graph node
+ *
+ * \param[in] info ::OrtKernelInfo instance
+ * \param[in] name Null terminated string of the name of the attribute
+ * \param[out] out Pointer to memory where the attribute will be stored
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ ORT_API2_STATUS(KernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name,
+ _Out_ int64_t* out);
+
+ /** \brief Fetch a string stored as an attribute in the graph node
+ *
+ * If `out` is nullptr, the value of `size` is set to the true size of the string
+ * attribute, and a success status is returned.
+ *
+ * If the `size` parameter is greater than or equal to the actual string attribute's size,
+ * the value of `size` is set to the true size of the string attribute, the provided memory
+ * is filled with the attribute's contents, and a success status is returned.
+ *
+ * If the `size` parameter is less than the actual string attribute's size and `out`
+ * is not nullptr, the value of `size` is set to the true size of the string attribute
+ * and a failure status is returned.
+ *
+ * \param[in] info ::OrtKernelInfo instance
+ * \param[in] name Null terminated string of the name of the attribute
+ * \param[out] out Pointer to memory where the attribute will be stored
+ * \param[in,out] size See above comments for details
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ ORT_API2_STATUS(KernelInfoGetAttribute_string, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ char* out,
+ _Inout_ size_t* size);
+
+ /// @}
+ /// \name OrtKernelContext
+ /// Custom operator APIs.
+ /// @{
+
+ /** \brief Used for custom operators, get the input count of a kernel
+ *
+ * \see ::OrtCustomOp
+ */
+ ORT_API2_STATUS(KernelContext_GetInputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out);
+
+ /** \brief Used for custom operators, get the output count of a kernel
+ *
+ * \see ::OrtCustomOp
+ */
+ ORT_API2_STATUS(KernelContext_GetOutputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out);
+
+ /** \brief Used for custom operators, get an input of a kernel
+ *
+ * The function attempts to fetch the input of the kernel. If the input is optional
+ * and not present, the function returns success and out is set to nullptr.
+ *
+ * \param[in] context ::OrtKernelContext instance
+ * \param[in] index See KernelContext_GetInputCount for boundaries check.
+ * \param[out] out OrtValue if the input is present, otherwise is set to nullptr
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ ORT_API2_STATUS(KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index,
+ _Out_ const OrtValue** out);
+
+ /** \brief Used for custom operators, get an output of a kernel
+ *
+ * The function attempts to fetch the output of the kernel. If the output is optional
+ * and not present, the function returns success and out is set to nullptr.
+ *
+ * \param[in] context ::OrtKernelContext instance
+ * \param[in] index See KernelContext_GetOutputCount for boundaries check. 
+ * \param[in] dim_values output dimensions + * \param[in] dim_count number of dimensions + * \param[out] out a ptr to OrtValue to output otherwise set to nullptr + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, + _In_ const int64_t* dim_values, size_t dim_count, _Outptr_ OrtValue** out); + + /// @} + /// \name OrtEnv + /// @{ + ORT_CLASS_RELEASE(Env); + /// @} + /// \name OrtStatus + /// @{ + ORT_CLASS_RELEASE(Status); + /// @} + /// \name OrtMemoryInfo + /// @{ + ORT_CLASS_RELEASE(MemoryInfo); + /// @} + /// \name OrtSession + /// @{ + ORT_CLASS_RELEASE(Session); // Don't call ReleaseSession from Dllmain (because session owns a thread pool) + /// @} + /// \name OrtValue + /// @{ + ORT_CLASS_RELEASE(Value); + /// @} + /// \name OrtRunOptions + /// @{ + ORT_CLASS_RELEASE(RunOptions); + /// @} + /// \name OrtTypeInfo + /// @{ + ORT_CLASS_RELEASE(TypeInfo); + /// @} + /// \name OrtTensorTypeAndShapeInfo + /// @{ + ORT_CLASS_RELEASE(TensorTypeAndShapeInfo); + /// @} + /// \name OrtSessionOptions + /// @{ + ORT_CLASS_RELEASE(SessionOptions); + /// @} + /// \name OrtCustomOpDomain + /// @{ + ORT_CLASS_RELEASE(CustomOpDomain); + + /// @} + /// \name OrtTypeInfo + /// @{ + + /** \brief Get denotation from type information + * + * Augments ::OrtTypeInfo to return denotations on the type. + * + * This is used by WinML to determine if an input/output is intended to be an Image or a Tensor. + * + * \param[in] type_info + * \param[out] denotation Pointer to the null terminated denotation string is written to this pointer. This pointer is valid until the object is destroyed or the name is changed, do not free. 
+ * \param[out] len Length in bytes of the string returned in `denotation`
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ ORT_API2_STATUS(GetDenotationFromTypeInfo, _In_ const OrtTypeInfo* type_info, _Out_ const char** const denotation,
+ _Out_ size_t* len);
+
+ /** \brief Get detailed map information from an ::OrtTypeInfo
+ *
+ * This augments ::OrtTypeInfo to return an ::OrtMapTypeInfo when the type is a map.
+ * The OrtMapTypeInfo has additional information about the map's key type and value type.
+ *
+ * This is used by WinML to support model reflection APIs.
+ *
+ * \param[in] type_info
+ * \param[out] out A pointer to the ::OrtMapTypeInfo. Do not free this value. If type_info
+ * does not contain a map, this value will be set to nullptr.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ ORT_API2_STATUS(CastTypeInfoToMapTypeInfo, _In_ const OrtTypeInfo* type_info,
+ _Outptr_result_maybenull_ const OrtMapTypeInfo** out);
+
+ /** \brief Cast ::OrtTypeInfo to an ::OrtSequenceTypeInfo
+ *
+ * This api augments ::OrtTypeInfo to return an ::OrtSequenceTypeInfo when the type is a sequence.
+ * The ::OrtSequenceTypeInfo has additional information about the sequence's element type.
+ *
+ * This is used by WinML to support model reflection APIs.
+ *
+ * \param[in] type_info
+ * \param[out] out A pointer to the OrtSequenceTypeInfo. Do not free this value. If type_info
+ * does not contain a sequence, this value will be set to nullptr.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ ORT_API2_STATUS(CastTypeInfoToSequenceTypeInfo, _In_ const OrtTypeInfo* type_info,
+ _Outptr_result_maybenull_ const OrtSequenceTypeInfo** out);
+
+ /// @}
+ /// \name OrtMapTypeInfo
+ /// @{
+
+ /** \brief Get key type from an ::OrtMapTypeInfo
+ *
+ * Key types are restricted to being scalar types.
+ *
+ * This is used by WinML to support model reflection APIs. 
+ * + * \param[in] map_type_info + * \param[out] out + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetMapKeyType, _In_ const OrtMapTypeInfo* map_type_info, _Out_ enum ONNXTensorElementDataType* out); + + /** \brief Get the value type from an ::OrtMapTypeInfo + * + * \param[in] map_type_info + * \param[out] type_info A copy of the OrtTypeInfo for the map value type. + * The user must free this value with ReleaseTypeInfo. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetMapValueType, _In_ const OrtMapTypeInfo* map_type_info, _Outptr_ OrtTypeInfo** type_info); + + /// @} + /// \name OrtSequenceTypeInfo + /// @{ + + /** \brief Get element type from an ::OrtSequenceTypeInfo + * + * This is used by WinML to support model reflection APIs. + * + * \param[in] sequence_type_info + * \param[out] type_info A copy of the OrtTypeInfo for the sequence element type. + * The user must free this value with ReleaseTypeInfo. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetSequenceElementType, _In_ const OrtSequenceTypeInfo* sequence_type_info, + _Outptr_ OrtTypeInfo** type_info); + + /// @} + /// \name OrtMapTypeInfo + /// @{ + ORT_CLASS_RELEASE(MapTypeInfo); + /// @} + /// \name OrtSequenceTypeInfo + /// @{ + ORT_CLASS_RELEASE(SequenceTypeInfo); + + /// @} + /// \name OrtSession + /// @{ + + /** \brief End profiling and return filename of the profile data + * + * Profiling is turned on through OrtApi::EnableProfiling + * + * \param[in] session + * \param[in] allocator + * \param[out] out Null terminated string of the filename, allocated using `allocator`. 
Must be freed using `allocator` + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionEndProfiling, _In_ OrtSession* session, _Inout_ OrtAllocator* allocator, _Outptr_ char** out); + + /** \brief Get ::OrtModelMetadata from an ::OrtSession + * + * \param[in] session + * \param[out] out Newly created ::OrtModelMetadata. Must be freed using OrtApi::ReleaseModelMetadata + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionGetModelMetadata, _In_ const OrtSession* session, _Outptr_ OrtModelMetadata** out); + + /// @} + /// \name OrtModelMetadata + /// @{ + + /** \brief Get `producer name` from an ::OrtModelMetadata + * + * \param[in] model_metadata + * \param[in] allocator + * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(ModelMetadataGetProducerName, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); + + /** \brief Get `graph name` from an ::OrtModelMetadata + * + * \param[in] model_metadata + * \param[in] allocator + * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(ModelMetadataGetGraphName, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); + + /** \brief Get `domain` from an ::OrtModelMetadata + * + * \param[in] model_metadata + * \param[in] allocator + * \param[out] value Set to a null terminated string allocated using `allocator`. 
Must be freed using `allocator` + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(ModelMetadataGetDomain, _In_ const OrtModelMetadata* model_metadata, _Inout_ OrtAllocator* allocator, + _Outptr_ char** value); + + /** \brief Get `description` from an ::OrtModelMetadata + * + * \param[in] model_metadata + * \param[in] allocator + * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(ModelMetadataGetDescription, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); + + /** \brief Return data for a key in the custom metadata map in an ::OrtModelMetadata + * + * \param[in] model_metadata + * \param[in] allocator + * \param[in] key Null terminated string + * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` + * `value` will be set to nullptr if the given key is not found in the custom metadata map. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(ModelMetadataLookupCustomMetadataMap, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _In_ const char* key, _Outptr_result_maybenull_ char** value); + + /** \brief Get version number from an ::OrtModelMetadata + * + * \param[in] model_metadata + * \param[out] value Set to the version number + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(ModelMetadataGetVersion, _In_ const OrtModelMetadata* model_metadata, _Out_ int64_t* value); + + ORT_CLASS_RELEASE(ModelMetadata); + + /// @} + /// \name OrtEnv + /// @{ + + /** \brief Create an OrtEnv + * + * Create an environment with global threadpools that will be shared across sessions. + * Use this in conjunction with OrtApi::DisablePerSessionThreads or else the session will use + * its own thread pools. 
+ * + * \param[in] log_severity_level The log severity level. + * \param[in] logid The log identifier. + * \param[in] tp_options + * \param[out] out Returned newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateEnvWithGlobalThreadPools, OrtLoggingLevel log_severity_level, _In_ const char* logid, + _In_ const OrtThreadingOptions* tp_options, _Outptr_ OrtEnv** out); + + /// @} + /// \name OrtSessionOptions + /// @{ + + /** \brief Use global thread pool on a session + * + * Disable using per session thread pool and use the shared global threadpool. + * This should be used in conjunction with OrtApi::CreateEnvWithGlobalThreadPools. + * + * \param[in] options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(DisablePerSessionThreads, _Inout_ OrtSessionOptions* options); + + /// @} + /// \name OrtThreadingOptions + /// @{ + + /** \brief Create an ::OrtThreadingOptions + * + * \param[out] out Newly created ::OrtThreadingOptions. Must be freed with OrtApi::ReleaseThreadingOptions + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateThreadingOptions, _Outptr_ OrtThreadingOptions** out); + + ORT_CLASS_RELEASE(ThreadingOptions); + + /// @} + /// \name OrtModelMetadata + /// @{ + + /** + * + * \param[in] model_metadata + * \param[in] allocator + * \param[out] keys Array of null terminated strings (array count = num_keys) allocated using `allocator`. + * The strings and the pointer array must be freed using `allocator` + * `keys` will be set to nullptr if the custom metadata map is empty. 
+ * \param[out] num_keys Set to the number of elements in the `keys` array + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(ModelMetadataGetCustomMetadataMapKeys, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_result_buffer_maybenull_(*num_keys) char*** keys, _Out_ int64_t* num_keys); + + /// @} + /// \name OrtSessionOptions + /// @{ + + /** + * + * Override symbolic dimensions (by specific name strings) with actual values + * if known at session initialization time to enable optimizations that can + * take advantage of fixed values (such as memory planning, etc) + * + */ + ORT_API2_STATUS(AddFreeDimensionOverrideByName, + _Inout_ OrtSessionOptions* options, _In_ const char* dim_name, + _In_ int64_t dim_value); + + /// @} + /// \name Misc + /// @{ + + /** \brief Get the names of all available providers + * + * \note The providers in the list are not guaranteed to be usable. They may fail to load due to missing system dependencies. + * For example, if the CUDA/cuDNN libraries are not installed, the CUDA provider will report an error when it is added to the session options. + * + * \param[out] out_ptr Set to a pointer to an array of null terminated strings of the available providers. The entries and the + * array itself must be freed using OrtApi::ReleaseAvailableProviders + * \param[out] provider_length Set to the number of entries in the `out_ptr` array + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetAvailableProviders, _Outptr_ char*** out_ptr, _Out_ int* provider_length); + + /** \brief Release data from OrtApi::GetAvailableProviders. This API will never fail + * so you can rely on it in a noexcept code. + * + * \param[in] ptr The `out_ptr` result from OrtApi::GetAvailableProviders. 
+ * \param[in] providers_length The `provider_length` result from OrtApi::GetAvailableProviders + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(ReleaseAvailableProviders, _In_ char** ptr, + _In_ int providers_length); + + /// @} + /// \name OrtValue + /// @{ + + /** \brief Get the length of a single string in a string tensor + * + * \param[in] value A string tensor + * \param[in] index Index of the string in the tensor + * \param[out] out Set to number of bytes of the string element + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetStringTensorElementLength, _In_ const OrtValue* value, size_t index, _Out_ size_t* out); + + /** \brief Get a single string from a string tensor + * + * \param[in] value A string tensor + * \param[in] s_len Number of bytes in the `s` buffer. Must match the value returned by OrtApi::GetStringTensorElementLength. + * \param[in] index Index of the string in the tensor + * \param[out] s The string element contents in UTF-8 encoding. The string is NOT null-terminated. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetStringTensorElement, _In_ const OrtValue* value, size_t s_len, size_t index, _Out_writes_bytes_all_(s_len) void* s); + + /** \brief Set a single string in a string tensor + * + * \param[in] value A string tensor + * \param[in] s A null terminated UTF-8 encoded string + * \param[in] index Index of the string in the tensor to set + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(FillStringTensorElement, _Inout_ OrtValue* value, _In_ const char* s, size_t index); + + /// @} + /// \name OrtSessionOptions + /// @{ + + /** \brief Set a session configuration entry as a pair of strings + * + * If a configuration with same key exists, this will overwrite the configuration with the given config_value. 
+ * + * The config_key and the format of config_value are defined in onnxruntime_session_options_config_keys.h + * + * \param[in] options + * \param[in] config_key A null terminated string representation of the config key + * \param[in] config_value A null terminated string representation of the config value + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(AddSessionConfigEntry, _Inout_ OrtSessionOptions* options, + _In_z_ const char* config_key, _In_z_ const char* config_value); + + /// @} + /// \name OrtAllocator + /// @{ + + /** \brief Create an allocator for an ::OrtSession following an ::OrtMemoryInfo + * + * \param[in] session + * \param[in] mem_info valid ::OrtMemoryInfo instance + * \param[out] out Newly created ::OrtAllocator. Must be freed with OrtApi::ReleaseAllocator + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateAllocator, _In_ const OrtSession* session, _In_ const OrtMemoryInfo* mem_info, + _Outptr_ OrtAllocator** out); + + /** \brief Release an ::OrtAllocator obtained from OrtApi::CreateAllocator + */ + ORT_CLASS_RELEASE(Allocator); + + /// @} + /// \name OrtSession + /// @{ + + /** \brief Run a model using Io Bindings for the inputs & outputs + * + * \see OrtApi::Run + * + * \param[in] session + * \param[in] run_options + * \param[in] binding_ptr + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(RunWithBinding, _Inout_ OrtSession* session, _In_ const OrtRunOptions* run_options, _In_ const OrtIoBinding* binding_ptr); + + /** \brief Create an ::OrtIoBinding instance + * + * An IoBinding object allows one to bind pre-allocated ::OrtValue%s to input names. + * Thus if you want to use a raw on device buffer as input or output you can avoid + * extra copy during runtime. + * + * \param[in] session + * \param[out] out Newly created ::OrtIoBinding. 
Must be freed with OrtApi::ReleaseIoBinding + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateIoBinding, _Inout_ OrtSession* session, _Outptr_ OrtIoBinding** out); + + /// @} + /// \name OrtIoBinding + /// @{ + + /** \brief Release an ::OrtIoBinding obtained from OrtApi::CreateIoBinding + */ + ORT_CLASS_RELEASE(IoBinding); + + /** \brief Bind an ::OrtValue to an ::OrtIoBinding input + * + * When using OrtApi::RunWithBinding this value is used for the named input + * + * \param[in] binding_ptr + * \param[in] name Name for the model input + * \param[in] val_ptr ::OrtValue of Tensor type. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(BindInput, _Inout_ OrtIoBinding* binding_ptr, _In_ const char* name, _In_ const OrtValue* val_ptr); + + /** \brief Bind an ::OrtValue to an ::OrtIoBinding output + * + * When using OrtApi::RunWithBinding this value is used for the named output + * + * \param[in] binding_ptr + * \param[in] name Null terminated string of the model output name + * \param[in] val_ptr ::OrtValue of Tensor type. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(BindOutput, _Inout_ OrtIoBinding* binding_ptr, _In_ const char* name, _In_ const OrtValue* val_ptr); + + /** \brief Bind an ::OrtIoBinding output to a device + * + * Binds the ::OrtValue to a device which is specified by ::OrtMemoryInfo. + * You can either create an instance of ::OrtMemoryInfo with a device id or obtain one from the allocator that you have created/are using + * This is useful when one or more outputs have dynamic shapes and, it is hard to pre-allocate and bind a chunk of + * memory within ::OrtValue ahead of time. 
+ * + * \see OrtApi::RunWithBinding + * + * \param[in] binding_ptr + * \param[in] name Null terminated string of the device name + * \param[in] mem_info_ptr + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(BindOutputToDevice, _Inout_ OrtIoBinding* binding_ptr, _In_ const char* name, _In_ const OrtMemoryInfo* mem_info_ptr); + + /** \brief Get the names of an ::OrtIoBinding's outputs + * + * Returns the names of the outputs in the order they were bound. This is useful after running the model + * with bound outputs because the returned names are in order in which output ::OrtValue are returned. This is useful if + * the order of outputs and their names is not known. + * + * \param[in] binding_ptr + * \param[in] allocator Allocator used to allocate continuous buffers for output strings and lengths. + * \param[out] buffer Returns an array of non-null terminated UTF-8 strings. The number of strings stored is returned in the count parameter. + * This buffer is allocated using `allocator` and must be freed using it. + * \param[out] lengths Returns an array of `count` lengths of the strings returned in `buffer` + * This buffer is allocated using `allocator` and must be freed using it. + * \param[out] count Number of strings returned. If `binding_ptr` has no bound outputs, zero is returned, + * no memory allocation is performed and buffer and lengths are set to nullptr. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetBoundOutputNames, _In_ const OrtIoBinding* binding_ptr, _In_ OrtAllocator* allocator, + _Out_ char** buffer, _Out_writes_all_(count) size_t** lengths, _Out_ size_t* count); + + /** \brief Get the output ::OrtValue objects from an ::OrtIoBinding + * + * Returns an array of pointers to individually allocated ::OrtValue%s that contain results of a model execution with OrtApi::RunWithBinding + * The array contains the same number of ::OrtValue%s and they are in the same order as they were bound with OrtApi::BindOutput + * or OrtApi::BindOutputToDevice. + * + * The returned ::OrtValue%s must be released using OrtApi::ReleaseValue after they are no longer needed. + * The array is allocated using the specified instance of the allocator and must be freed using the same allocator after + * all the ::OrtValue%s contained therein are individually released. + * + * \param[in] binding_ptr + * \param[in] allocator Allocator used to allocate output array + * \param[out] output Set to the allocated array of allocated ::OrtValue outputs. Set to nullptr if there are 0 outputs. 
+ * \param[out] output_count Set to number of ::OrtValue%s returned + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetBoundOutputValues, _In_ const OrtIoBinding* binding_ptr, _In_ OrtAllocator* allocator, + _Out_writes_all_(output_count) OrtValue*** output, _Out_ size_t* output_count); + + /** \brief Clears any previously set Inputs for an ::OrtIoBinding + */ + void(ORT_API_CALL* ClearBoundInputs)(_Inout_ OrtIoBinding* binding_ptr) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + + /** \brief Clears any previously set Outputs for an ::OrtIoBinding + */ + void(ORT_API_CALL* ClearBoundOutputs)(_Inout_ OrtIoBinding* binding_ptr) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; + + /// @} + /// \name OrtValue + /// @{ + + /** \brief Direct memory access to a specified tensor element + * + * For example, given a tensor with shape of [3,224,224], a pointer to the element at location [2,150,128] can be retrieved + * + * This function only works for numeric type tensors (No strings, etc). + * This is a no-copy method whose returned pointer is valid until the passed in ::OrtValue is free'd. + * + * \param[in] value + * \param[in] location_values Pointer to an array of index values that specify an element's location relative to its shape + * \param[in] location_values_count Number of elements in location_values. Must match the number of elements in the tensor's shape. + * \param[out] out Set to a pointer to the element specified + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(TensorAt, _Inout_ OrtValue* value, const int64_t* location_values, size_t location_values_count, _Outptr_ void** out); + + /// @} + /// \name OrtEnv + /// @{ + + /** \brief Create an allocator and register it with the ::OrtEnv + * + * Enables sharing the allocator between multiple sessions that use the same env instance. + * Lifetime of the created allocator will be valid for the duration of the environment. 
+ * Returns an error if an allocator with the same ::OrtMemoryInfo is already registered. + * + * See https://onnxruntime.ai/docs/get-started/with-c.html for details. + * + * \param[in] env ::OrtEnv instance + * \param[in] mem_info + * \param[in] arena_cfg Pass nullptr for defaults + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateAndRegisterAllocator, _Inout_ OrtEnv* env, _In_ const OrtMemoryInfo* mem_info, + _In_ const OrtArenaCfg* arena_cfg); + + /** \brief Set language projection + * + * Set the language projection for collecting telemetry data when Env is created. + * + * The default is ORT_PROJECTION_C, which means it will classify the language not in the list to C also. + * + * \param[in] ort_env + * \param[in] projection + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetLanguageProjection, _In_ const OrtEnv* ort_env, _In_ OrtLanguageProjection projection); + + /// @} + /// \name OrtSession + /// @{ + + /** \brief Return the time that profiling was started + * + * \note The timer precision varies per platform. On Windows and MacOS, the precision will be ~100ns + * + * \param[in] session + * \param[out] out nanoseconds of profiling's start time + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionGetProfilingStartTimeNs, _In_ const OrtSession* session, _Outptr_ uint64_t* out); + + /// @} + /// \name OrtThreadingOptions + /// @{ + + /** \brief Set global intra-op thread count + * + * This configures the global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools + * + * \param[in] tp_options + * \param[in] intra_op_num_threads Number of threads, special values:
+ * 0 = Use default thread count
+ * 1 = The invoking thread will be used; no threads will be created in the thread pool. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetGlobalIntraOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int intra_op_num_threads); + + /** \brief Set global inter-op thread count + * + * This configures the global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools + * + * \param[in] tp_options + * \param[in] inter_op_num_threads Number of threads, special values:
+ * 0 = Use default thread count
+ * 1 = The invoking thread will be used; no threads will be created in the thread pool. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetGlobalInterOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int inter_op_num_threads); + + /** \brief Set global spin control options + * + * This will configure the global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools. + * Allow spinning of thread pools when their queues are empty. This will set the value for both + * inter_op and intra_op threadpools. + * + * \param[in] tp_options + * \param[in] allow_spinning Valid values are 0 or 1.
+ * 0 = It won't spin (recommended if CPU usage is high)
+ * 1 = Threadpool will spin to wait for queue to become non-empty + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetGlobalSpinControl, _Inout_ OrtThreadingOptions* tp_options, int allow_spinning); + + /// @} + /// \name OrtSessionOptions + /// @{ + + /** \brief Add a pre-allocated initializer to a session + * + * If a model contains an initializer with a name that is same as the name passed to this call, + * ORT will use this initializer instance instead of deserializing one from the model file. This + * is useful when you want to share the same initializer across sessions. + * + * \param[in] options + * \param[in] name Null terminated string of the initializer name + * \param[in] val ::OrtValue containing the initializer. Its lifetime and the underlying initializer buffer must be + * managed by the user (created using the OrtApi::CreateTensorWithDataAsOrtValue) and it must outlive the session object + * to which it is added. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(AddInitializer, _Inout_ OrtSessionOptions* options, _In_z_ const char* name, + _In_ const OrtValue* val); + + /// @} + /// \name OrtEnv + /// @{ + + /** + * Create a custom environment with global threadpools and logger that will be shared across sessions. + * Use this in conjunction with OrtApi::DisablePerSessionThreads or else the session will use + * its own thread pools. + * + * \param[in] logging_function A pointer to a logging function. + * \param[in] logger_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to + * `logging_function`. + * \param[in] log_severity_level The log severity level. + * \param[in] logid The log identifier. + * \param[in] tp_options + * \param[out] out Newly created OrtEnv. 
Must be freed with OrtApi::ReleaseEnv + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateEnvWithCustomLoggerAndGlobalThreadPools, OrtLoggingFunction logging_function, _In_opt_ void* logger_param, OrtLoggingLevel log_severity_level, + _In_ const char* logid, _In_ const struct OrtThreadingOptions* tp_options, _Outptr_ OrtEnv** out); + + /// @} + /// \name OrtSessionOptions + /// @{ + + /** \brief Append CUDA provider to session options + * + * If CUDA is not available (due to a non CUDA enabled build, or if CUDA is not installed on the system), this function will return failure. + * + * \param[in] options + * \param[in] cuda_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_CUDA, + _In_ OrtSessionOptions* options, _In_ const OrtCUDAProviderOptions* cuda_options); + + /** \brief Append ROCM execution provider to the session options + * + * If ROCM is not available (due to a non ROCM enabled build, or if ROCM is not installed on the system), this function will return failure. + * + * \param[in] options + * \param[in] rocm_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_ROCM, + _In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* rocm_options); + + /** \brief Append OpenVINO execution provider to the session options + * + * If OpenVINO is not available (due to a non OpenVINO enabled build, or if OpenVINO is not installed on the system), this function will fail. 
+ * + * \param[in] options + * \param[in] provider_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_OpenVINO, + _In_ OrtSessionOptions* options, _In_ const OrtOpenVINOProviderOptions* provider_options); + + /// @} + /// \name OrtThreadingOptions + /// @{ + + /** \brief Set threading flush-to-zero and denormal-as-zero + * + * Sets global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools. + * Flush-to-zero and denormal-as-zero are applied to threads in both intra and inter global thread pool. + * \note This option is not needed if the models used have no denormals. Having no denormals is recommended as this option may hurt model accuracy. + * + * \param[in] tp_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetGlobalDenormalAsZero, _Inout_ OrtThreadingOptions* tp_options); + + /// @} + /// \name OrtArenaCfg + /// @{ + + /** \deprecated Use OrtApi::CreateArenaCfgV2 + * + * This will create the configuration of an arena that can eventually be used to define an arena based allocator's behavior + * + * \param[in] max_mem Use 0 to allow ORT to choose the default + * \param[in] arena_extend_strategy Use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested + * \param[in] initial_chunk_size_bytes Use -1 to allow ORT to choose the default + * \param[in] max_dead_bytes_per_chunk Use -1 to allow ORT to choose the default + * \param[in] out A pointer to an OrtArenaCfg instance + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateArenaCfg, _In_ size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, + int max_dead_bytes_per_chunk, _Outptr_ OrtArenaCfg** out); + + ORT_CLASS_RELEASE(ArenaCfg); + + /// @} + /// \name OrtModelMetadata + /// @{ + + /** + * Use this to obtain the description of the graph present in the model + * (doc_string field of the 
GraphProto message within the ModelProto message). + * If it doesn't exist, an empty string will be returned. + * + * \param[in] model_metadata An instance of ::OrtModelMetadata + * \param[in] allocator Allocator used to allocate the string that will be returned back + * \param[out] value Set to a null terminated string allocated using `allocator`. The caller is responsible for freeing it using `allocator` + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(ModelMetadataGetGraphDescription, _In_ const OrtModelMetadata* model_metadata, + _Inout_ OrtAllocator* allocator, _Outptr_ char** value); + + /// @} + /// \name OrtSessionOptions + /// @{ + + /** \brief Append TensorRT provider to session options + * + * If TensorRT is not available (due to a non TensorRT enabled build, or if TensorRT is not installed on the system), this function will return failure. + * + * \param[in] options + * \param[in] tensorrt_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_TensorRT, + _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptions* tensorrt_options); + + /// @} + /// \name Misc + /// @{ + + /** \brief Set current GPU device ID + * + * Set the current device id of the GPU execution provider (CUDA/tensorrt/rocm). The device id should be less + * than the total number of devices available. This is only useful when multiple-GPUs are installed and it is + * required to restrict execution to a single GPU. + * + * \param[in] device_id + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetCurrentGpuDeviceId, _In_ int device_id); + + /** \brief Get current GPU device ID + * + * Get the current device id of the GPU execution provider (CUDA/tensorrt/rocm). 
+ *
+ * \see OrtApi::SetCurrentGpuDeviceId
+ *
+ * \param[out] device_id
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ ORT_API2_STATUS(GetCurrentGpuDeviceId, _In_ int* device_id);
+
+ /// @}
+ /// \name OrtKernelInfo
+ /// Custom operator APIs.
+ /// @{
+
+ /** \brief Fetch an array of float values stored as an attribute in the graph node
+ *
+ *
+ * If `out` is nullptr, the value of `size` is set to the true size of the attribute
+ * array's size, and a success status is returned.
+ *
+ * If the `size` parameter is greater than or equal to the actual attribute array's size,
+ * the value of `size` is set to the true size of the attribute array's size,
+ * the provided memory is filled with the attribute's contents,
+ * and a success status is returned.
+ *
+ * If the `size` parameter is less than the actual attribute array's size and `out`
+ * is not nullptr, the value of `size` is set to the true size of the attribute array's size
+ * and a failure status is returned.
+ *
+ * \param[in] info instance
+ * \param[in] name name of the attribute to be parsed
+ * \param[out] out pointer to memory where the attribute's contents are to be stored
+ * \param[in, out] size actual size of attribute array
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ ORT_API2_STATUS(KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name,
+ _Out_ float* out, _Inout_ size_t* size);
+
+ /** \brief Fetch an array of int64_t values stored as an attribute in the graph node
+ *
+ * If `out` is nullptr, the value of `size` is set to the true size of the attribute
+ * array's size, and a success status is returned.
+ *
+ * If the `size` parameter is greater than or equal to the actual attribute array's size,
+ * the value of `size` is set to the true size of the attribute array's size,
+ * the provided memory is filled with the attribute's contents,
+ * and a success status is returned.
+ * + * If the `size` parameter is less than the actual attribute array's size and `out` + * is not nullptr, the value of `size` is set to the true size of the attribute array's size + * and a failure status is returned.) + * + * \param[in] info instance + * \param[in] name name of the attribute to be parsed + * \param[out] out pointer to memory where the attribute's contents are to be stored + * \param[in, out] size actual size of attribute array + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, + _Out_ int64_t* out, _Inout_ size_t* size); + + /// @} + /// \name OrtArenaCfg + /// @{ + + /** \brief Create an ::OrtArenaCfg + * + * Create the configuration of an arena that can eventually be used to define an arena based allocator's behavior. + * + * Supported keys are (See https://onnxruntime.ai/docs/get-started/with-c.html for details on what the + * following parameters mean and how to choose these values.): + * "max_mem": Maximum memory that can be allocated by the arena based allocator. + * Use 0 for ORT to pick the best value. Default is 0. + * "arena_extend_strategy": 0 = kNextPowerOfTwo, 1 = kSameAsRequested. + * Use -1 to allow ORT to choose the default. + * "initial_chunk_size_bytes": (Possible) Size of the first allocation in the arena. + * Only relevant if arena strategy is `kNextPowerOfTwo`. Use -1 to allow ORT to choose the default. + * Ultimately, the first allocation size is determined by the allocation memory request. + * "max_dead_bytes_per_chunk": Threshold of unused memory in an allocated chunk of arena memory after + * crossing which the current chunk is chunked into 2. + * "initial_growth_chunk_size_bytes": (Possible) Size of the second allocation in the arena. + * Only relevant if arena strategy is `kNextPowerOfTwo`. Use -1 to allow ORT to choose the default. 
+ * "max_power_of_two_extend_bytes": The maximum extend size if arena strategy is `kNextPowerOfTwo`.
+ * It is not an allocation limit, it is only a limit for extension when the requested bytes are less than the limit.
+ * When the requested bytes are more than the limit, the allocator will still return as requested.
+ * Use -1 to allow ORT to choose the default 1GB for max_power_of_two_extend_bytes.
+ * Ultimately, the allocation size is determined by the allocation memory request.
+ * Further allocation sizes are governed by the arena extend strategy.
+ *
+ * \param[in] arena_config_keys Keys to configure the arena
+ * \param[in] arena_config_values Values to configure the arena
+ * \param[in] num_keys Number of keys in `arena_config_keys` and `arena_config_values`
+ * \param[out] out Newly created ::OrtArenaCfg. Must be freed with OrtApi::ReleaseArenaCfg
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ ORT_API2_STATUS(CreateArenaCfgV2, _In_reads_(num_keys) const char* const* arena_config_keys,
+ _In_reads_(num_keys) const size_t* arena_config_values, _In_ size_t num_keys,
+ _Outptr_ OrtArenaCfg** out);
+
+ /// @}
+ /// \name OrtRunOptions
+ /// @{
+
+ /** \brief Set a single run configuration entry as a pair of strings
+ *
+ * If a configuration with same key exists, this will overwrite the configuration with the given config_value
+ *
+ * The config_key and the format of config_value are defined in onnxruntime_run_options_config_keys.h
+ *
+ * \param[in] options
+ * \param[in] config_key A null terminated string representation of the config key
+ * \param[in] config_value A null terminated string representation of the config value
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ ORT_API2_STATUS(AddRunConfigEntry, _Inout_ OrtRunOptions* options,
+ _In_z_ const char* config_key, _In_z_ const char* config_value);
+
+ /// @}
+ /// \name OrtPrepackedWeightsContainer
+ /// @{
+
+ /** \brief Create an ::OrtPrepackedWeightsContainer
+ *
+ * This
container will hold pre-packed buffers of shared initializers for sharing between sessions + * (i.e.) if there are shared initializers that can be shared between sessions, the pre-packed buffers + * of these (if any) may possibly be shared to provide memory footprint savings. Pass this container + * to sessions that you would like to share pre-packed buffers of shared initializers at session + * creation time. + * + * \param[out] out Newly created ::OrtPrepackedWeightsContainer. Must be freed with OrtApi::ReleasePrepackedWeightsContainer + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreatePrepackedWeightsContainer, _Outptr_ OrtPrepackedWeightsContainer** out); + + /** \brief Release OrtPrepackedWeightsContainer instance + * + * \note instance must not be released until the sessions using it are released + */ + ORT_CLASS_RELEASE(PrepackedWeightsContainer); + + /// @} + /// \name OrtSession + /// @{ + + /** \brief Create session with prepacked weights container + * + * Same functionality offered by OrtApi::CreateSession except that a container that contains + * pre-packed weights' buffers is written into/read from by the created session. + * This is useful when used in conjunction with OrtApi::AddInitializer which injects + * shared initializer info into sessions. Wherever possible, the pre-packed versions of these + * shared initializers are cached in this container so that multiple sessions can just re-use + * these instead of duplicating these in memory. + * + * \param[in] env OrtEnv instance instance + * \param[in] model_path Null terminated string of the path (wchar on Windows, char otherwise) + * \param[in] options + * \param[in] prepacked_weights_container + * \param[out] out Newly created ::OrtSession. 
Must be freed with OrtApi::ReleaseSession + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateSessionWithPrepackedWeightsContainer, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, + _In_ const OrtSessionOptions* options, + _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, + _Outptr_ OrtSession** out); + + /** \brief Create session from memory with prepacked weights container + * + * Same functionality offered by OrtApi::CreateSessionFromArray except that a container that contains + * pre-packed weights' buffers is written into/read from by the created session. + * This is useful when used in conjunction with OrtApi::AddInitializer which injects + * shared initializer info into sessions. Wherever possible, the pre-packed versions of these + * shared initializers are cached in this container so that multiple sessions can just re-use + * these instead of duplicating these in memory. + * + * \param[in] env + * \param[in] model_data Array of bytes holding the model + * \param[in] model_data_length Number of bytes in `model_data_model` + * \param[in] options + * \param[in] prepacked_weights_container + * \param[out] out Newly created ::OrtSession. Must be freed with OrtApi::ReleaseSession + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateSessionFromArrayWithPrepackedWeightsContainer, _In_ const OrtEnv* env, + _In_ const void* model_data, size_t model_data_length, + _In_ const OrtSessionOptions* options, + _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, + _Outptr_ OrtSession** out); + + /// @} + /// \name OrtSessionOptions + /// @{ + + /** \brief Append TensorRT execution provider to the session options + * + * If TensorRT is not available (due to a non TensorRT enabled build), this function will return failure. 
+ *
+ * This is slightly different from OrtApi::SessionOptionsAppendExecutionProvider_TensorRT, which takes an
+ * ::OrtTensorRTProviderOptions which is publicly defined. This takes an opaque ::OrtTensorRTProviderOptionsV2
+ * which must be created with OrtApi::CreateTensorRTProviderOptions.
+ *
+ * For OrtApi::SessionOptionsAppendExecutionProvider_TensorRT, the user needs to instantiate ::OrtTensorRTProviderOptions
+ * as well as allocate/release buffers for some members of ::OrtTensorRTProviderOptions.
+ * Here, OrtApi::CreateTensorRTProviderOptions and OrtApi::ReleaseTensorRTProviderOptions will do the memory management for you.
+ *
+ * \param[in] options
+ * \param[in] tensorrt_options
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_TensorRT_V2,
+ _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options);
+
+ /// @}
+ /// \name OrtTensorRTProviderOptionsV2
+ /// @{
+
+ /** \brief Create an OrtTensorRTProviderOptionsV2
+ *
+ * \param[out] out Newly created ::OrtTensorRTProviderOptionsV2. Must be released with OrtApi::ReleaseTensorRTProviderOptions
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ ORT_API2_STATUS(CreateTensorRTProviderOptions, _Outptr_ OrtTensorRTProviderOptionsV2** out);
+
+ /** \brief Set options in a TensorRT Execution Provider.
+ *
+ * Please refer to https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#cc
+ * to know the available keys and values. Key should be in null terminated string format of the member of ::OrtTensorRTProviderOptionsV2
+ * and value should be its related range. Recreates the options and only sets the supplied values.
+ * + * For example, key="trt_max_workspace_size" and value="2147483648" + * + * \param[in] tensorrt_options + * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys + * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values + * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(UpdateTensorRTProviderOptions, _Inout_ OrtTensorRTProviderOptionsV2* tensorrt_options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); + + /** \brief Get serialized TensorRT provider options string. + * + * For example, "trt_max_workspace_size=2147483648;trt_max_partition_iterations=10;trt_int8_enable=1;......" + * + * \param tensorrt_options - OrtTensorRTProviderOptionsV2 instance + * \param allocator - a ptr to an instance of OrtAllocator obtained with OrtApi::CreateAllocator or OrtApi::GetAllocatorWithDefaultOptions + * the specified allocator will be used to allocate continuous buffers for output strings and lengths. + * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetTensorRTProviderOptionsAsString, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); + + /** \brief Release an ::OrtTensorRTProviderOptionsV2 + * + * \note This is an exception in the naming convention of other Release* functions, as the name of the method does not have the V2 suffix, but the type does + */ + void(ORT_API_CALL* ReleaseTensorRTProviderOptions)(_Frees_ptr_opt_ OrtTensorRTProviderOptionsV2* input); + + /// @} + /// \name OrtSessionOptions + /// @{ + + /** \brief Enable custom operators + * + * See onnxruntime-extensions: https://github.com/microsoft/onnxruntime-extensions.git + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(EnableOrtCustomOps, _Inout_ OrtSessionOptions* options); + + /// @} + /// \name OrtAllocator + /// @{ + + /** \brief Register a custom allocator + * + * Enables sharing between multiple sessions that use the same env instance. + * Returns an error if an allocator with the same ::OrtMemoryInfo is already registered. + * + * The behavior of this is exactly the same as OrtApi::CreateAndRegisterAllocator except + * instead of ORT creating an allocator based on provided info, in this case + * ORT uses the user-provided custom allocator. + * See https://onnxruntime.ai/docs/get-started/with-c.html for details. + * + * \param[in] env + * \param[in] allocator User provided allocator + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(RegisterAllocator, _Inout_ OrtEnv* env, _In_ OrtAllocator* allocator); + + /** \brief Unregister a custom allocator + * + * It is an error if you provide an ::OrtMemoryInfo not corresponding to any + * registered allocators for sharing. 
+ * + * \param[in] env + * \param[in] mem_info + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(UnregisterAllocator, _Inout_ OrtEnv* env, + _In_ const OrtMemoryInfo* mem_info); + + /// @} + /// \name OrtValue + /// @{ + + /** \brief Sets *out to 1 iff an ::OrtValue is a SparseTensor, and 0 otherwise + * + * \param[in] value existing ::OrtValue + * \param[out] out unless an error occurs, contains 1 iff the value contains an instance + * of sparse tensor or 0 otherwise. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(IsSparseTensor, _In_ const OrtValue* value, _Out_ int* out); + + /** \brief Create an ::OrtValue with a sparse tensor that is empty. + * + * Use FillSparseTensor() functions to populate sparse tensor with non-zero values and + * format specific indices data. + * Use ReleaseValue to destroy the sparse tensor, this will also release the buffer inside the output value + * if any was allocated. + * \param[in,out] allocator allocator to use when performing an allocation. Allocation will be performed + * by FillSparseTensor() APIs. The lifespan of the allocator instance must eclipse the lifespan + * this sparse tensor instance as the same allocator will be used to free memory. + * \param[in] dense_shape shape of the original dense tensor + * \param[in] dense_shape_len number of shape dimensions being passed + * \param[in] type must be one of TENSOR_ELEMENT_DATA_TYPE_xxxx + * \param[out] out Should be freed by calling ReleaseValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(CreateSparseTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* dense_shape, + size_t dense_shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out); + + /** + * This fills populates an empty tensor that was created using OrtApi::CreateSparseTensorAsOrtValue. 
+ * This will allocate required memory and copy the supplied NNZ values and COO indices into that memory allocation.
+ * Memory allocation is performed using the allocator that was specified with OrtApi::CreateSparseTensorAsOrtValue.
+ *
+ * \param[in,out] ort_value ::OrtValue to populate with data
+ * \param[in] data_mem_info serves to identify the location of the data to be copied. If the allocator specified
+ * at the creation time has memory info that is not the same as mem_info argument to this function an X-device copy will be performed.
+ * String data is assumed to be on CPU and will only be copied into a CPU allocated buffer.
+ * \param[in] values_shape pointer to values shape array
+ * \param[in] values_shape_len length of the values_shape
+ * \param[in] values pointer to an array of values. For strings, pass const char**.
+ * \param[in] indices_data pointer to a location of COO indices
+ * \param[in] indices_num number of COO indices
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ ORT_API2_STATUS(FillSparseTensorCoo, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info,
+ _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values,
+ _In_ const int64_t* indices_data, size_t indices_num);
+
+ /**
+ * This populates an empty tensor that was created using OrtApi::CreateSparseTensorAsOrtValue.
+ * This will allocate required memory and copy the supplied NNZ values and CSR indices into that memory allocation.
+ * Memory allocation is performed using the allocator that was specified with OrtApi::CreateSparseTensorAsOrtValue.
+ *
+ * \param[in,out] ort_value ::OrtValue to populate with data
+ * \param[in] data_mem_info serves to identify the location of the data to be copied. If the allocator specified
+ * at the creation time has memory info that is not the same as mem_info argument to this function an X-device copy will be performed.
+ * String data is assumed to be on CPU and will only be copied into a CPU allocated buffer. + * \param[in] values_shape pointer to values shape array + * \param[in] values_shape_len length of the values_shape + * \param[in] values - pointer to an array of values. For strings, pass const char**. + * \param[in] inner_indices_data pointer to a location of CSR inner indices + * \param[in] inner_indices_num number of CSR inner indices + * \param[in] outer_indices_data pointer to a location of CSR outer indices + * \param[in] outer_indices_num number of CSR outer indices + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(FillSparseTensorCsr, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info, + _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values, + _In_ const int64_t* inner_indices_data, size_t inner_indices_num, + _In_ const int64_t* outer_indices_data, size_t outer_indices_num); + + /** + * This fills populates an empty tensor that was created using OrtApi::CreateSparseTensorAsOrtValue. + * This will allocate required memory and copy the supplied NNZ values and BlockSparse indices into that memory allocation. + * Memory allocation is performed using the allocator that was specified with OrtApi::CreateSparseTensorAsOrtValue. + * + * \param[in,out] ort_value ::OrtValue to populate with data + * \param[in] data_mem_info serves to identify the location of the data to be copied. If the allocator specified + * at the creation time has memory info that is not the same as mem_info argument to this function a X-device copy will be performed. + * String data is assumed to be on CPU and will only be copied into a CPU allocated buffer. 
+ * \param[in] values_shape + * \param[in] values_shape_len + * \param[in] values structure with values information + * \param[in] indices_shape_data pointer to a location of indices shape + * \param[in] indices_shape_len length of the block sparse indices shape + * \param[in] indices_data pointer to a location of indices data. Shape will determine the length of the indices data. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(FillSparseTensorBlockSparse, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info, + _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values, + _In_ const int64_t* indices_shape_data, size_t indices_shape_len, + _In_ const int32_t* indices_data); + + /** + * Create an ::OrtValue with a sparse tensor. This is the first step. + * Next, use UseIndices() functions to supply sparse tensor with + * format specific indices data and set its sparse format to a specific enum value. + * This will not perform memory allocations. It will + * use supplied user buffer which should outlive the created sparse tensor. + * Use OrtApi::ReleaseValue to destroy the sparse tensor. It would not release the supplied values buffer. + * This function can not be used to map strings from the user allocated memory. Strings must always be copied + * and have UTF-8 encoding. Therefore, use OrtApi::CreateSparseTensorAsOrtValue above and then fill it with data + * using appropriate Make*() function. + * + * \param[in] info memory info where sparse values reside. + * \param[in,out] p_data pointer to a user allocated buffer with values. To create a full sparse tensor with no non-zero + * values, pass nullptr + * \param[in] dense_shape shape of the original dense tensor + * \param[in] dense_shape_len number of shape dimensions being passed + * \param[in] values_shape shape of the values data. To create a fully sparse tensor with no non-zero values, + * pass {0} shape. 
+ * \param[in] values_shape_len number of values shape dimensions
+ * \param[in] type must be one of TENSOR_ELEMENT_DATA_TYPE_xxxx
+ * \param[out] out Should be freed by calling ReleaseValue
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ ORT_API2_STATUS(CreateSparseTensorWithValuesAsOrtValue, _In_ const OrtMemoryInfo* info, _Inout_ void* p_data,
+ _In_ const int64_t* dense_shape, size_t dense_shape_len,
+ _In_ const int64_t* values_shape, size_t values_shape_len,
+ ONNXTensorElementDataType type, _Outptr_ OrtValue** out);
+
+ /**
+ * This assigns Coo format indices to the SparseTensor that was created by
+ * OrtApi::CreateSparseTensorWithValuesAsOrtValue above. It also sets OrtSparseFormat to
+ * ORT_SPARSE_COO. This will not allocate any additional memory for data. The life span of
+ * indices_data buffer should eclipse the life span of this ::OrtValue.
+ *
+ * \param[in,out] ort_value ::OrtValue instance constructed with OrtApi::CreateSparseTensorWithValuesAsOrtValue
+ * \param[in,out] indices_data pointer to a user pre-allocated buffer or nullptr for fully sparse tensors.
+ * \param[in] indices_num number of COO indices. Should either be 0 for fully sparse tensors, be equal
+ * to the number of nnz values specified to OrtApi::CreateSparseTensorWithValuesAsOrtValue for 1-D {nnz} indices or
+ * be twice the number of nnz values for a 2-D indices {nnz, 2}
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ ORT_API2_STATUS(UseCooIndices, _Inout_ OrtValue* ort_value, _Inout_ int64_t* indices_data, size_t indices_num);
+
+ /**
+ * This assigns CSR format indices to the SparseTensor that was created by
+ * OrtApi::CreateSparseTensorWithValuesAsOrtValue above. It also sets OrtSparseFormat to
+ * ORT_SPARSE_CSRC. This will not allocate any additional memory for data. The life spans of
+ * inner_data and outer_data buffers should eclipse the life span of this ::OrtValue.
+ *
+ * \param[in,out] ort_value ::OrtValue instance constructed with OrtApi::CreateSparseTensorWithValuesAsOrtValue
+ * \param[in,out] inner_data pointer to a user pre-allocated buffer or nullptr for fully sparse tensors.
+ * \param[in] inner_num number of inner CSR indices. Should either be 0 for fully sparse tensors or be equal
+ * to the number of nnz values specified to OrtApi::CreateSparseTensorWithValuesAsOrtValue.
+ * \param[in,out] outer_data pointer to user pre-allocated buffer or nullptr for fully sparse tensors.
+ * \param[in] outer_num number of CSR outer indices. Should either be 0 for fully sparse tensors or
+ * equal to rows + 1 of the dense shape.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ ORT_API2_STATUS(UseCsrIndices, _Inout_ OrtValue* ort_value, _Inout_ int64_t* inner_data, size_t inner_num,
+ _Inout_ int64_t* outer_data, size_t outer_num);
+
+ /**
+ * This assigns BlockSparse format indices to the SparseTensor that was created by
+ * OrtApi::CreateSparseTensorWithValuesAsOrtValue above. It also sets OrtSparseFormat to
+ * ORT_SPARSE_BLOCK_SPARSE. This will not allocate any additional memory for data. The life span of
+ * indices_data buffer must eclipse the lifespan of this ::OrtValue.
+ *
+ * \param[in,out] ort_value OrtValue instance constructed with OrtApi::CreateSparseTensorWithValuesAsOrtValue
+ * \param[in] indices_shape pointer to indices shape. Use {0} for fully sparse tensors
+ * \param[in] indices_shape_len length of the indices shape
+ * \param[in,out] indices_data pointer to user pre-allocated buffer or nullptr for fully sparse tensors.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ ORT_API2_STATUS(UseBlockSparseIndices, _Inout_ OrtValue* ort_value, const int64_t* indices_shape, size_t indices_shape_len, _Inout_ int32_t* indices_data);
+
+ /** \brief Returns sparse tensor format enum iff a given ort value contains an instance of sparse tensor.
+ * + * \param[in] ort_value ::OrtValue that contains an instance of sparse tensor + * \param[out] out pointer to out parameter + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetSparseTensorFormat, _In_ const OrtValue* ort_value, _Out_ enum OrtSparseFormat* out); + + /** \brief Returns data type and shape of sparse tensor values (nnz) iff ::OrtValue contains a SparseTensor. + * + * \param[in] ort_value An ::OrtValue that contains a fully constructed sparse tensor + * \param[out] out Must be freed by OrtApi::ReleaseTensorTypeAndShapeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetSparseTensorValuesTypeAndShape, _In_ const OrtValue* ort_value, _Outptr_ OrtTensorTypeAndShapeInfo** out); + + /** \brief Returns numeric data for sparse tensor values (nnz). For string values use GetStringTensor*(). + * + * \param[in] ort_value an instance of ::OrtValue containing sparse tensor + * \param[out] out returns a pointer to values data. Do not attempt to free this ptr. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetSparseTensorValues, _In_ const OrtValue* ort_value, _Outptr_ const void** out); + + /** \brief Returns data type, shape for the type of indices specified by indices_format. + * + * \param[in] ort_value ::OrtValue containing sparse tensor. + * \param[in] indices_format One of the indices formats. It is an error to request a format that the sparse + * tensor does not contain. + * \param[out] out an instance of ::OrtTensorTypeAndShapeInfo. 
Must be freed by OrtApi::ReleaseTensorTypeAndShapeInfo + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetSparseTensorIndicesTypeShape, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Outptr_ OrtTensorTypeAndShapeInfo** out); + + /** \brief Returns indices data for the type of the indices specified by indices_format + * + * \param[in] ort_value ::OrtValue containing sparse tensor. + * \param[in] indices_format One of the indices formats. It is an error to request a format that the sparse tensor does not contain. + * \param[out] num_indices Pointer to where the number of indices entries is returned + * \param[out] indices Returned pointer to the indices data. Do not free the returned pointer as it refers to internal data owned by the ::OrtValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetSparseTensorIndices, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Out_ size_t* num_indices, _Outptr_ const void** indices); + /// @} + /// \name OrtSessionOptions + /// @{ + + /** + * \brief Sets out to 1 iff an optional type OrtValue has an element, 0 otherwise (OrtValue is None) + * Use this API to find if the optional type OrtValue is None or not. + * If the optional type OrtValue is not None, use the OrtValue just like any other OrtValue. + * For example, if you get an OrtValue that corresponds to Optional(tensor) and + * if HasValue() returns true, use it as tensor and so on. + + * \param[in] value Input OrtValue. + * \param[out] out indicating if the input OrtValue contains data (1) or if it is a None (0) + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(HasValue, _In_ const OrtValue* value, _Out_ int* out); + + /// @} + /// \name OrtKernelContext + /// Custom operator APIs. 
+ /// @{
+
+ /** \brief Used for custom operators, gets the GPU compute stream to use to launch a custom GPU kernel
+ * \see ::OrtCustomOp
+ * \param[in] context OrtKernelContext instance
+ * \param[out] out Returns pointer to a GPU compute stream that can be used to launch the custom GPU kernel.
+ * If retrieving the GPU compute stream is not relevant (GPU not enabled in the build, kernel partitioned to
+ * some other EP), then a nullptr is returned as the output param.
+ * Do not free or mutate the returned pointer as it refers to internal data owned by the underlying session.
+ * Only use it for custom kernel launching.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ ORT_API2_STATUS(KernelContext_GetGPUComputeStream, _In_ const OrtKernelContext* context, _Outptr_ void** out);
+
+ /// @}
+ /// \name GetTensorMemoryInfo
+ /// @{
+ /** \brief Returns a pointer to the ::OrtMemoryInfo of a Tensor
+ * \param[in] value ::OrtValue containing tensor.
+ * \param[out] mem_info ::OrtMemoryInfo of the tensor. Do NOT free the returned pointer. It is valid for the lifetime of the ::OrtValue
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ */
+ ORT_API2_STATUS(GetTensorMemoryInfo, _In_ const OrtValue* value, _Out_ const OrtMemoryInfo** mem_info);
+
+ /// @}
+ /// \name GetExecutionProviderApi
+ /// @{
+ /** \brief Get a pointer to the requested version of the Execution Provider specific
+ * API extensions to the OrtApi
+ * \param[in] provider_name The name of the execution provider. Currently only the following
+ * values are supported: "DML".
+ * \param[in] version Must be ::ORT_API_VERSION.
+ * \param[out] provider_api A void pointer containing a reference to the execution provider versioned api structure.
+ * For example, the provider_api pointer can be cast to the OrtDmlApi* when the provider_name is "DML".
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetExecutionProviderApi, _In_ const char* provider_name, _In_ uint32_t version, _Outptr_ const void** provider_api); + + /// @} + + /// \name SessionOptions + /// @{ + /** \brief Set custom thread creation function + * + * \param[in] options Session options + * \param[in] ort_custom_create_thread_fn Custom thread creation function + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsSetCustomCreateThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn); + + /** \brief Set creation options for custom thread + * + * \param[in] options Session options + * \param[in] ort_custom_thread_creation_options Custom thread creation options (can be nullptr) + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsSetCustomThreadCreationOptions, _Inout_ OrtSessionOptions* options, _In_ void* ort_custom_thread_creation_options); + + /** \brief Set custom thread join function + * + * \param[in] options Session options + * \param[in] ort_custom_join_thread_fn Custom join thread function, must not be nullptr when ort_custom_create_thread_fn is set + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn); + /// @} + + /// \name OrtThreadingOptions + /// @{ + /** \brief Set custom thread creation function for global thread pools + * + * \param[inout] tp_options + * \param[in] ort_custom_create_thread_fn Custom thread creation function + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetGlobalCustomCreateThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn); + + /** \brief Set custom thread creation options for global thread pools + * + * 
\param[inout] tp_options + * \param[in] ort_custom_thread_creation_options Custom thread creation options (can be nullptr) + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetGlobalCustomThreadCreationOptions, _Inout_ OrtThreadingOptions* tp_options, _In_ void* ort_custom_thread_creation_options); + + /** \brief Set custom thread join function for global thread pools + * + * \param[inout] tp_options + * \param[in] ort_custom_join_thread_fn Custom thread join function, must not be nullptr when global ort_custom_create_thread_fn is set + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SetGlobalCustomJoinThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn); + /// @} + + /** \brief Synchronize bound inputs. The call may be necessary for some providers, such as cuda, + * in case the system that allocated bound memory operated on a different stream. However, the + * operation is provider specific and could be a no-op. + * + * \param[inout] binding_ptr + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SynchronizeBoundInputs, _Inout_ OrtIoBinding* binding_ptr); + + /** \brief Synchronize bound outputs. The call may be necessary for some providers, such as cuda, + * in case the system that allocated bound memory operated on a different stream. However, the + * operation is provider specific and could be a no-op. + * + * \param[inout] binding_ptr + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SynchronizeBoundOutputs, _Inout_ OrtIoBinding* binding_ptr); + + /// \name OrtSessionOptions + /// @{ + + /** \brief Append CUDA execution provider to the session options + * + * If CUDA is not available (due to a non CUDA enabled build), this function will return failure. 
+ *
+ * This is slightly different from OrtApi::SessionOptionsAppendExecutionProvider_CUDA, it takes an
+ * ::OrtCUDAProviderOptions which is publicly defined. This takes an opaque ::OrtCUDAProviderOptionsV2
+ * which must be created with OrtApi::CreateCUDAProviderOptions.
+ *
+ * For OrtApi::SessionOptionsAppendExecutionProvider_CUDA, the user needs to instantiate ::OrtCUDAProviderOptions
+ * as well as allocate/release buffers for some members of ::OrtCUDAProviderOptions.
+ * Here, OrtApi::CreateCUDAProviderOptions and OrtApi::ReleaseCUDAProviderOptions will do the memory management for you.
+ *
+ * \param[in] options
+ * \param[in] cuda_options
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.11.
+ */
+ ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_CUDA_V2,
+ _In_ OrtSessionOptions* options, _In_ const OrtCUDAProviderOptionsV2* cuda_options);
+
+ /// @}
+ /// \name OrtCUDAProviderOptionsV2
+ /// @{
+
+ /** \brief Create an OrtCUDAProviderOptionsV2
+ *
+ * \param[out] out Newly created ::OrtCUDAProviderOptionsV2. Must be released with OrtApi::ReleaseCUDAProviderOptions
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.11.
+ */
+ ORT_API2_STATUS(CreateCUDAProviderOptions, _Outptr_ OrtCUDAProviderOptionsV2** out);
+
+ /** \brief Set options in a CUDA Execution Provider.
+ *
+ * Please refer to https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#configuration-options
+ * to know the available keys and values. Key should be in null terminated string format of the member of ::OrtCUDAProviderOptionsV2
+ * and value should be its related range. Recreates the options and only sets the supplied values.
+ * + * For example, key="device_id" and value="0" + * + * \param[in] cuda_options + * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys + * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values + * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.11. + */ + ORT_API2_STATUS(UpdateCUDAProviderOptions, _Inout_ OrtCUDAProviderOptionsV2* cuda_options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); + + /** + * Get serialized CUDA provider options string. + * + * For example, "device_id=0;arena_extend_strategy=0;......" + * + * \param cuda_options - OrtCUDAProviderOptionsV2 instance + * \param allocator - a ptr to an instance of OrtAllocator obtained with CreateAllocator() or GetAllocatorWithDefaultOptions() + * the specified allocator will be used to allocate continuous buffers for output strings and lengths. + * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.11. + */ + ORT_API2_STATUS(GetCUDAProviderOptionsAsString, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); + + /** \brief Release an ::OrtCUDAProviderOptionsV2 + * + * \note This is an exception in the naming convention of other Release* functions, as the name of the method does not have the V2 suffix, but the type does + * + * \since Version 1.11. 
+ */ + void(ORT_API_CALL* ReleaseCUDAProviderOptions)(_Frees_ptr_opt_ OrtCUDAProviderOptionsV2* input); + + /// @} + + /** \brief Append MIGraphX provider to session options + * + * If MIGraphX is not available (due to a non MIGraphX enabled build, or if MIGraphX is not installed on the system), this function will return failure. + * + * \param[in] options + * \param[in] migraphx_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.11. + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_MIGraphX, + _In_ OrtSessionOptions* options, _In_ const OrtMIGraphXProviderOptions* migraphx_options); + + /** \brief Replace initialized Tensors with external data with the data provided in initializers. + * + * The function will find the initialized TensorProtos with external data in the graph with the provided names and + * replace them with the provided tensors. The API verifies that the TensorProto being replaced + * has an external data reference and has the same name, dimensions and data type as its replacement. The replacement + * will occur before any of the optimizations take place. The data will be copied into the graph + * since TensorProto can't refer to the user provided buffers. + * + * Once the model has been loaded, the OrtValue(s) added to SessionOptions instance will be removed + * from the internal SessionOptions copy to save memory, the user provided buffers can then be deallocated + * and the SessionOptions instance that refers to them can be destroyed. + * + * \param[in] options + * \param[in] initializer_names Array of null terminated UTF-8 encoded strings of the initializers names. + * \param[in] initializers Array of ::OrtValue type + * \param[in] num_initializers Number of elements in the initializer_names and initializers + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.12. 
+ */ + ORT_API2_STATUS(AddExternalInitializers, _In_ OrtSessionOptions* options, + _In_reads_(num_initializers) const char* const* initializer_names, + _In_reads_(num_initializers) const OrtValue* const* initializers, size_t num_initializers); + + /** \brief: Create attribute of onnxruntime operator + * + * \param[in] name Name of the attribute + * \param[in] data Data content of the attribute + * \param[in] len Number of bytes stored in data + * \param[in] type Data type + * \param[out] op_attr Attribute that has been created, which must be released by OrtApi::ReleaseOpAttr + * + * \since Version 1.12. + */ + ORT_API2_STATUS(CreateOpAttr, + _In_ const char* name, + _In_ const void* data, + _In_ int len, + _In_ OrtOpAttrType type, + _Outptr_ OrtOpAttr** op_attr); + + /* \brief: Release op attribute + * + * \param[in] opAttr Attribute created by OrtApi::CreateOpAttr + * + * \since Version 1.12. + */ + ORT_CLASS_RELEASE(OpAttr); + + /** \brief: Create onnxruntime native operator + * + * \param[in] info Kernel info + * \param[in] op_name Operator name + * \param[in] domain Operator domain + * \param[in] version Operator opset version + * \param[in] type_constraint_names Name of the type constraints, such as "T" or "T1" + * \param[in] type_constraint_values Type of each constraints + * \param[in] type_constraint_count Number of constraints + * \param[in] attr_values Attributes used to initialize the operator + * \param[in] attr_count Number of the attributes + * \param[in] input_count Number of inputs + * \param[in] output_count Number of outputs + * \param[out] ort_op Operator that has been created + * + * \since Version 1.12. 
+ */ + ORT_API2_STATUS(CreateOp, + _In_ const OrtKernelInfo* info, + _In_z_ const char* op_name, + _In_z_ const char* domain, + int version, + _In_reads_(type_constraint_count) const char** type_constraint_names, + _In_reads_(type_constraint_count) const ONNXTensorElementDataType* type_constraint_values, + int type_constraint_count, + _In_reads_(attr_count) const OrtOpAttr* const* attr_values, + int attr_count, + int input_count, + int output_count, + _Outptr_ OrtOp** ort_op); + + /** \brief: Invoke the operator created by OrtApi::CreateOp + * The inputs must follow the order as specified in onnx specification + * + * \param[in] context Kernel context + * \param[in] ort_op Operator that has been created + * \param[in] input_values Array of inputs + * \param[in] input_count Number of inputs + * \param[in] output_values Array of outputs + * \param[in] output_count Number of outputs + * + * \since Version 1.12. + */ + ORT_API2_STATUS(InvokeOp, + _In_ const OrtKernelContext* context, + _In_ const OrtOp* ort_op, + _In_ const OrtValue* const* input_values, + _In_ int input_count, + _Inout_ OrtValue* const* output_values, + _In_ int output_count); + + /* \brief: Release an onnxruntime operator + * + * \param[in] Op Operator created by OrtApi::CreateOp + * + * \since Version 1.12. + */ + ORT_CLASS_RELEASE(Op); + + /** \brief: Append execution provider to the session options. + * \param[in] options + * \param[in] provider_name - provider to add. 
+ * \param[in] provider_options_keys - keys to configure the provider options + * \param[in] provider_options_values - values to configure the provider options + * \param[in] num_keys - number of keys passed in + * + * Currently supported provider names: + * QNNExecutionProvider (or QNN) + * OpenVINOExecutionProvider (or OpenVINO) + * XnnpackExecutionProvider (or XNNPACK) + * WebNNExecutionProvider (or WEBNN) + * WebGpuExecutionProvider (or WebGPU) + * AzureExecutionProvider (or AZURE) + * JsExecutionProvider (or JS) + * VitisAIExecutionProvider (or VitisAI) + * CoreMLExecutionProvider (or CoreML) + * + * Note: If an execution provider has a dedicated SessionOptionsAppendExecutionProvider_ function + * that should be used to add it. + * + * QNN supported keys: + * "backend_type": Type of QNN backend. Specifies a backend path that is the associated QNN backend library file + * name. E.g., given backend type "htp", on Windows, the backend path would be "QnnHtp.dll", and on other + * platforms, it would be "libQnnHtp.so". Mutually exclusive with "backend_path". + * Available options: + * -# "cpu" + * -# "gpu" + * -# "htp": Default. + * -# "saver" + * -# "ir" + * "backend_path": File path to QNN backend library. Mutually exclusive with "backend_type". + * "profiling_level": QNN profiling level. + * Available options: + * -# "off": Default. + * -# "basic" + * -# "detailed" + * "profiling_file_path": QNN profiling file path if ETW not enabled. + * "rpc_control_latency": QNN RPC control latency. + * "vtcm_mb": QNN VTCM size in MB. default to 0(not set). + * "htp_performance_mode": QNN performance mode. + * Available options: + * -# "burst" + * -# "balanced" + * -# "default": Default. 
+ * -# "high_performance" + * -# "high_power_saver" + * -# "low_balanced" + * -# "extreme_power_saver" + * -# "low_power_saver" + * -# "power_saver" + * -# "sustained_high_performance" + * "dump_qnn_ir_dlc": Use the QnnIr backend library to write .dlc files for each subgraph dispatched to QNN. When + * enabled, inference results will be incorrect. Use only for debugging. + * -# "0": Default: disabled + * -# "1": enabled + * "dump_qnn_ir_dlc_dir": Set the directory into which QnnIr will be configured to write QNN graphs as .dlc files. + * Default is current working directory. + * "qnn_ir_backend_path": File path to the QnnIr backend library. If "dump_qnn_ir_dlc" is enabled, use this path + * instead of looking for the Ir backend in the standard location. + * "qnn_saver_path": File path to the QNN Saver backend library. If specified, QNN Saver will be enabled and will + * dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and + * may alter model/EP partitioning. Use only for debugging. + * "qnn_context_priority": QNN context priority. + * Available options: + * -# "low" + * -# "normal": Default. + * -# "normal_high" + * -# "high" + * "htp_graph_finalization_optimization_mode": Set the optimization mode for graph finalization on the HTP backend. + * Available options: + * -# "0": Default. + * -# "1": Faster preparation time, less optimal graph. + * -# "2": Longer preparation time, more optimal graph. + * -# "3": Longest preparation time, most likely even more optimal graph. See QNN SDK documentation for specific + * details. + * "soc_model": The SoC model number. Refer to the QNN SDK documentation for valid values. + * Defaults to "0" (unknown). + * "htp_arch": The minimum HTP architecture the driver will use to select compatible QNN operators. + * Available options: + * -# "0": Default (none). + * -# "68" + * -# "69" + * -# "73" + * -# "75" + * "device_id": The ID of the device to use when setting 'htp_arch'. 
Defaults to "0" (for single device). + * "enable_htp_fp16_precision": Used for float32 model for HTP backend. + * Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision. + * -# "0": With fp32 precision. + * -# "1": Default. With fp16 precision. + * "offload_graph_io_quantization": Offload graph input quantization and graph output dequantization to another + * execution provider (typically CPU EP). + * -# "0": Disabled. QNN EP will handle quantization and dequantization of graph I/O. + * -# "1": Enabled. This is the default value. + * "enable_htp_spill_fill_buffer": Enable HTP spill fill buffer setting. The flag is used while generating context + * binary. + * -# "0": Default. Disabled. + * -# "1": Enabled. + * "enable_htp_shared_memory_allocator": Enable the QNN HTP shared memory allocator. Requires libcdsprpc.so/dll to + * be available. + * -# "0": Default. Disabled. + * -# "1": Enabled. + * "dump_json_qnn_graph": Set to "1" to dump QNN graphs generated by QNN EP as JSON files. Each graph partition + * assigned to QNN EP is dumped to a separate file. + * "json_qnn_graph_dir": Directory in which to dump QNN JSON graphs. If not specified, QNN graphs are dumped in the + * program's current working directory. Ignored if "dump_json_qnn_graph" is not set. + * + * XNNPACK supported keys: + * "intra_op_num_threads": number of thread-pool size to use for XNNPACK execution provider. + * default value is 0, which means to use the session thread-pool size. + * + * \since Version 1.12. + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider, _In_ OrtSessionOptions* options, + _In_ const char* provider_name, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); + + /* \brief: Get a copy of kernel info + * + * \param[in] info Kernel info + * \param[out] info_copy Copy of kernel info + * + * \since Version 1.12. 
+ */ + ORT_API2_STATUS(CopyKernelInfo, + _In_ const OrtKernelInfo* info, + _Outptr_ OrtKernelInfo** info_copy); + + /* \brief: Release kernel info + * + * \param[in] KernelInfo A copy of kernel info returned by CopyKernelInfo + * + * \since Version 1.12. + */ + ORT_CLASS_RELEASE(KernelInfo); + + /// \name Ort Training + /// @{ + /** \brief Gets the Training C Api struct + * + * Call this function to access the ::OrtTrainingApi structure that holds pointers to functions that enable + * training with onnxruntime. + * \note A NULL pointer will be returned and no error message will be printed if the training api + * is not supported with this build. A NULL pointer will be returned and an error message will be + * printed if the provided version is unsupported, for example when using a runtime older than the + * version created with this header file. + * + * \param[in] version Must be ::ORT_API_VERSION + * \return The ::OrtTrainingApi struct for the version requested. + * + * \since Version 1.13 + */ + const OrtTrainingApi*(ORT_API_CALL* GetTrainingApi)(uint32_t version)NO_EXCEPTION; + + /// @} + + /** \brief Append CANN provider to session options + * + * If CANN is not available (due to a non CANN enabled build, or if CANN is not installed on the system), this function will return failure. + * + * \param[in] options + * \param[in] cann_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.13. + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_CANN, + _In_ OrtSessionOptions* options, _In_ const OrtCANNProviderOptions* cann_options); + + /** \brief Create an OrtCANNProviderOptions + * + * \param[out] out created ::OrtCANNProviderOptions. Must be released with OrtApi::ReleaseCANNProviderOptions + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.13. + */ + ORT_API2_STATUS(CreateCANNProviderOptions, _Outptr_ OrtCANNProviderOptions** out); + + /** \brief Set options in a CANN Execution Provider. 
+ *
+ * \param[in] cann_options
+ * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys
+ * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values
+ * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.13.
+ */
+ ORT_API2_STATUS(UpdateCANNProviderOptions, _Inout_ OrtCANNProviderOptions* cann_options,
+ _In_reads_(num_keys) const char* const* provider_options_keys,
+ _In_reads_(num_keys) const char* const* provider_options_values,
+ _In_ size_t num_keys);
+
+ /** \brief Get serialized CANN provider options string.
+ *
+ * \param[in] cann_options OrtCANNProviderOptions instance
+ * \param[in] allocator a ptr to an instance of OrtAllocator obtained with CreateAllocator()
+ * or GetAllocatorWithDefaultOptions(), the specified allocator will be used to allocate
+ * continuous buffers for output strings and lengths.
+ * \param[out] ptr is a UTF-8 null terminated string allocated using 'allocator'.
+ * The caller is responsible for using the same allocator to free it.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.13.
+ */
+ ORT_API2_STATUS(GetCANNProviderOptionsAsString, _In_ const OrtCANNProviderOptions* cann_options,
+ _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr);
+
+ /** \brief Release an OrtCANNProviderOptions
+ *
+ * \param[in] input The pointer of OrtCANNProviderOptions which will be deleted
+ *
+ * \since Version 1.13.
+ */ + void(ORT_API_CALL* ReleaseCANNProviderOptions)(_Frees_ptr_opt_ OrtCANNProviderOptions* input); + + /* \brief Get OrtDevice type from MemoryInfo + * + * \since Version 1.14 + */ + void(ORT_API_CALL* MemoryInfoGetDeviceType)(_In_ const OrtMemoryInfo* ptr, _Out_ OrtMemoryInfoDeviceType* out); + + /* \brief Update the OrtEnv instance with custom log severity level + * + * \param[in] ort_env The OrtEnv instance being used + * \param[in] log_severity_level The log severity level. + * + * \since Version 1.14. + */ + ORT_API2_STATUS(UpdateEnvWithCustomLogLevel, _In_ OrtEnv* ort_env, OrtLoggingLevel log_severity_level); + + /* \brief Set affinities for intra op threads + * + * Affinity string follows format: + * logical_processor_id,logical_processor_id;logical_processor_id,logical_processor_id + * Semicolon isolates configurations among threads, while comma split processors where ith thread expected to attach to. + * e.g. 1,2,3;4,5 + * specifies affinities for two threads, with the 1st thread attach to the 1st, 2nd, and 3rd processor, and 2nd thread to the 4th and 5th. + * To ease the configuration, an "interval" is also allowed: + * e.g. 1-8;8-16;17-24 + * orders that the 1st thread runs on first eight processors, 2nd thread runs on next eight processors, and so forth. + * Note: + * 1. Once set, the number of thread affinities must equal to intra_op_num_threads - 1, + * ort does not set affinity on the main thread which is started and managed by the calling app; + * 2. For windows, ort will infer the group id from a logical processor id, for example, assuming there are two groups with each has 64 logical processors, + * an id of 64 will be inferred as the last processor of the 1st group, while 65 will be interpreted as the 1st processor of the second group. + * Hence 64-65 is an invalid configuration, because a windows thread cannot be attached to processors across group boundary. 
+ * + * \since Version 1.14 + */ + ORT_API2_STATUS(SetGlobalIntraOpThreadAffinity, _Inout_ OrtThreadingOptions* tp_options, const char* affinity_string); + + /** \brief Register custom ops from a shared library. + * + * Loads a shared library (.dll on windows, .so on linux, etc) named 'library_name' and looks for this entry point: + * OrtStatus* RegisterCustomOps(OrtSessionOptions * options, const OrtApiBase* api); + * It then passes in the provided session options to this function along with the api base. + * + * The handle to the loaded library is automatically released by ORT when the last OrtSession that references the + * library handle is released. If no OrtSession is created, then the library handle is released when the provided + * OrtSessionOptions is released. + * + * \param[in] options The session options. + * \param[in] library_name The name of the shared library to load and register. Refer to OS-specific dynamic library + * loading utilities (e.g., LoadLibraryEx on Windows or dlopen on Linux/MacOS) for information + * on the format of library names and search paths. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(RegisterCustomOpsLibrary_V2, _Inout_ OrtSessionOptions* options, _In_ const ORTCHAR_T* library_name); + + /** \brief Register custom ops by calling a RegisterCustomOpsFn function. + * + * Searches for registration_func_name and if found calls it. + * + * The library containing the function must either be linked against or previously loaded by the executable. + * + * If you want ONNX Runtime to load the library and manage its lifetime, use RegisterCustomOpsLibrary_V2. + * + * RegisterCustomOpsUsingFunction can be used in scenarios where it may not be possible for ONNX Runtime to load + * the library from a path. e.g. mobile platforms where the library must be linked into the app. 
+ * + * The registration function must have the signature of RegisterCustomOpsFn: + * OrtStatus* (*fn)(OrtSessionOptions* options, const OrtApiBase* api); + * + * See https://onnxruntime.ai/docs/reference/operators/add-custom-op.html for details on how the registration + * function should be implemented. + * + * \param[in] options OrtSessionOptions that is passed through as the first argument in the call to the + * registration function. + * \param[in] registration_func_name Name of registration function to use. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(RegisterCustomOpsUsingFunction, _Inout_ OrtSessionOptions* options, + _In_ const char* registration_func_name); + + /// \name OrtKernelInfo + /// Custom operator APIs. + /// @{ + + /** \brief Get the number of inputs from ::OrtKernelInfo. + * + * Used in the CreateKernel callback of an OrtCustomOp to query the number of inputs + * during kernel/session creation. + * + * \param[in] info Instance of ::OrtKernelInfo. + * \param[out] out Pointer to variable assigned with the result on success. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(KernelInfo_GetInputCount, _In_ const OrtKernelInfo* info, _Out_ size_t* out); + + /** \brief Get the number of outputs from ::OrtKernelInfo. + * + * Used in the CreateKernel callback of an OrtCustomOp to query the number of outputs + * during kernel/session creation. + * + * \param[in] info Instance of ::OrtKernelInfo. + * \param[out] out Pointer to variable assigned with the result on success. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(KernelInfo_GetOutputCount, _In_ const OrtKernelInfo* info, _Out_ size_t* out); + + /** \brief Get the name of a ::OrtKernelInfo's input. + * + * Used in the CreateKernel callback of an OrtCustomOp to query an input's name + * during kernel/session creation. 
+ * + * If `out` is nullptr, the value of `size` is set to the size of the name + * string (including null-terminator), and a success status is returned. + * + * If the `size` parameter is greater than or equal to the name string's size, + * the value of `size` is set to the true size of the string (including null-terminator), + * the provided memory is filled with the string's contents, and a success status is returned. + * + * If the `size` parameter is less than the actual string's size and `out` + * is not nullptr, the value of `size` is set to the true size of the string + * and a failure status is returned. + * + * \param[in] info An instance of ::OrtKernelInfo. + * \param[in] index The index of the input name to get. Returns a failure status if out-of-bounds. + * \param[out] out Memory location into which to write the UTF-8 null-terminated string representing the input's name. + * \param[in,out] size Pointer to the size of the `out` buffer. See above comments for details. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(KernelInfo_GetInputName, _In_ const OrtKernelInfo* info, size_t index, _Out_ char* out, + _Inout_ size_t* size); + + /** \brief Get the name of a ::OrtKernelInfo's output. + * + * Used in the CreateKernel callback of an OrtCustomOp to query an output's name + * during kernel/session creation. + * + * If `out` is nullptr, the value of `size` is set to the size of the name + * string (including null-terminator), and a success status is returned. + * + * If the `size` parameter is greater than or equal to the name string's size, + * the value of `size` is set to the true size of the string (including null-terminator), + * the provided memory is filled with the string's contents, and a success status is returned. 
+   *
+   * If the `size` parameter is less than the actual string's size and `out`
+   * is not nullptr, the value of `size` is set to the true size of the string
+   * and a failure status is returned.
+   *
+   * \param[in] info An instance of ::OrtKernelInfo.
+   * \param[in] index The index of the output name to get. Returns a failure status if out-of-bounds.
+   * \param[out] out Memory location into which to write the UTF-8 null-terminated string representing the output's
+   *                 name.
+   * \param[in,out] size Pointer to the size of the `out` buffer. See above comments for details.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   * \since Version 1.14
+   */
+  ORT_API2_STATUS(KernelInfo_GetOutputName, _In_ const OrtKernelInfo* info, size_t index, _Out_ char* out,
+                  _Inout_ size_t* size);
+
+  /** \brief Get the type information for a ::OrtKernelInfo's input.
+   *
+   * Used in the CreateKernel callback of an OrtCustomOp to query the shape and type information
+   * of an input during kernel/session creation.
+   *
+   * \param[in] info An instance of ::OrtKernelInfo.
+   * \param[in] index Which input to get the type information for
+   * \param[out] type_info Pointer set to the resulting ::OrtTypeInfo. Must be freed with OrtApi::ReleaseTypeInfo.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   * \since Version 1.14
+   */
+  ORT_API2_STATUS(KernelInfo_GetInputTypeInfo, _In_ const OrtKernelInfo* info, size_t index,
+                  _Outptr_ OrtTypeInfo** type_info);
+
+  /** \brief Get the type information for a ::OrtKernelInfo's output.
+   *
+   * Used in the CreateKernel callback of an OrtCustomOp to query the shape and type information
+   * of an output during kernel/session creation.
+   *
+   * \param[in] info An instance of ::OrtKernelInfo.
+   * \param[in] index Which output to get the type information for
+   * \param[out] type_info Pointer set to the resulting ::OrtTypeInfo. Must be freed with OrtApi::ReleaseTypeInfo.
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(KernelInfo_GetOutputTypeInfo, _In_ const OrtKernelInfo* info, size_t index, + _Outptr_ OrtTypeInfo** type_info); + + /** \brief Get a ::OrtValue tensor stored as an attribute in the graph node. + * + * Used in the CreateKernel callback of an OrtCustomOp to get a tensor attribute. + * + * \param[in] info ::OrtKernelInfo instance. + * \param[in] name UTF-8 null-terminated string representing the attribute's name. + * \param[in] allocator Allocator used to allocate the internal tensor state. + * \param[out] out Returns newly created ::OrtValue. Must be freed with OrtApi::ReleaseValue, + * which will also free internal tensor state allocated with the provided allocator. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(KernelInfoGetAttribute_tensor, _In_ const OrtKernelInfo* info, _In_z_ const char* name, + _Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out); + + /// @} + /// \name OrtSessionOptions + /// Custom operator APIs + /// @{ + + /** \brief Checks if the given session configuration entry exists. + * + * The config_key formats are defined in onnxruntime_session_options_config_keys.h + * + * Can be used in a custom operator library to check for session configuration entries + * that target one or more custom operators in the library. Example: The config entry + * custom_op.myop.some_key targets a custom op named "myop". + * + * \param[in] options The ::OrtSessionOptions instance. + * \param[in] config_key A null-terminated UTF-8 string representation of the configuration key. + * \param[out] out Pointer set to 1 if the entry exists and 0 otherwise. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(HasSessionConfigEntry, _In_ const OrtSessionOptions* options, + _In_z_ const char* config_key, _Out_ int* out); + + /** \brief Get a session configuration value. 
+ * + * Returns a failure status if the configuration key does not exist. + * The config_key and the format of config_value are defined in onnxruntime_session_options_config_keys.h + * + * If `config_value` is nullptr, the value of `size` is set to the true size of the string + * value (including null-terminator), and a success status is returned. + * + * If the `size` parameter is greater than or equal to the actual string value's size, + * the value of `size` is set to the true size of the string value, the provided memory + * is filled with the value's contents, and a success status is returned. + * + * If the `size` parameter is less than the actual string value's size and `config_value` + * is not nullptr, the value of `size` is set to the true size of the string value + * and a failure status is returned. + * + * Can be used in a custom operator library to get session configuration entries + * that target one or more custom operators in the library. Example: The config entry + * custom_op.myop.some_key targets a custom op named "myop". + * + * \param[in] options The session options. + * \param[in] config_key A null-terminated UTF-8 string representation of the config key. + * \param[in] config_value Pointer to memory where the null-terminated UTF-8 string value will be stored. + * \param[in,out] size Pointer to the size of the `config_value` buffer. See above comments for details. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.14 + */ + ORT_API2_STATUS(GetSessionConfigEntry, _In_ const OrtSessionOptions* options, + _In_z_ const char* config_key, _Out_ char* config_value, _Inout_ size_t* size); + + /// @} + + /** \brief Append dnnl provider to session options + * + * If oneDNN is not available, this function will return failure. + * + * \param[in] options + * \param[in] dnnl_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.15. 
+ */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_Dnnl, + _In_ OrtSessionOptions* options, _In_ const OrtDnnlProviderOptions* dnnl_options); + + /** \brief Create an OrtDnnlProviderOptions + * + * \param[out] out Newly created ::OrtDnnlProviderOptions. Must be released with OrtApi::ReleaseDnnlProviderOptions + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.15. + */ + ORT_API2_STATUS(CreateDnnlProviderOptions, _Outptr_ OrtDnnlProviderOptions** out); + + /** \brief Set options in a oneDNN Execution Provider. + * + * Key should be in null terminated string format of the member of ::OrtDnnlProviderOptions + * and value should be its related range. + * + * For example, key="use_arena" and value="1" + * + * \param[in] dnnl_options + * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys + * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values + * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.15. + */ + ORT_API2_STATUS(UpdateDnnlProviderOptions, _Inout_ OrtDnnlProviderOptions* dnnl_options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); + + /** + * Get serialized oneDNN provider options string. + * + * For example, "use_arena=1;......" + * + * \param dnnl_options - OrtDnnlProviderOptions instance + * \param allocator - a ptr to an instance of OrtAllocator obtained with CreateAllocator() or GetAllocatorWithDefaultOptions() + * the specified allocator will be used to allocate continuous buffers for output strings and lengths. + * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.15. + */ + ORT_API2_STATUS(GetDnnlProviderOptionsAsString, _In_ const OrtDnnlProviderOptions* dnnl_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); + + /** \brief Release an ::OrtDnnlProviderOptions + * + * \since Version 1.15. + */ + void(ORT_API_CALL* ReleaseDnnlProviderOptions)(_Frees_ptr_opt_ OrtDnnlProviderOptions* input); + + /// \name OrtKernelInfo + /// Custom operator APIs. + /// @{ + + /** \brief Get the graph node name from ::OrtKernelInfo. + * + * If `out` is nullptr, the value of `size` is set to the size of the name + * string (including null-terminator), and a success status is returned. + * + * If the `size` parameter is greater than or equal to the name string's size, + * the value of `size` is set to the true size of the string (including null-terminator), + * the provided memory is filled with the string's contents, and a success status is returned. + * + * If the `size` parameter is less than the actual string's size and `out` + * is not nullptr, the value of `size` is set to the true size of the string + * and a failure status is returned. + * + * Can be used in a custom operator's CreateKernel callback to get the name of the operator's node name in the graph. + * + * \param[in] info An instance of ::OrtKernelInfo. + * \param[out] out Memory location into which to write the UTF-8 null-terminated string representing the name. + * \param[in,out] size Pointer to the size of the `out` buffer. See above comments for details. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.15 + */ + ORT_API2_STATUS(KernelInfo_GetNodeName, _In_ const OrtKernelInfo* info, _Out_ char* out, _Inout_ size_t* size); + + /** \brief Get the session logger from ::OrtKernelInfo. + * + * Used in the CreateKernel callback of an OrtCustomOp to get a logger that can be used to log + * messages. + * + * \param[in] info An instance of ::OrtKernelInfo. 
+ * \param[out] logger Pointer set to the session's ::OrtLogger. Owned by ONNX Runtime, so do not free. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.15 + */ + ORT_API2_STATUS(KernelInfo_GetLogger, _In_ const OrtKernelInfo* info, _Outptr_ const OrtLogger** logger); + + /// @} + /// \name OrtKernelContext + /// Custom operator APIs. + /// @{ + + /** \brief Get the runtime logger from ::OrtKernelContext. + * + * Used in the KernelCompute callback of an OrtCustomOp to get a logger that can be used to log + * messages during inference. + * + * \param[in] context An instance of ::OrtKernelContext. + * \param[out] logger Pointer set to the kernel context's ::OrtLogger. Owned by ONNX Runtime, so do not free. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.15 + */ + ORT_API2_STATUS(KernelContext_GetLogger, _In_ const OrtKernelContext* context, _Outptr_ const OrtLogger** logger); + + /// @} + /// \name OrtLogger + /// Custom operator APIs. + /// @{ + + /** \brief Logs a message at the given severity level using the provided ::OrtLogger. + * + * Only messages with a severity level equal or greater than the ::OrtLogger's logging severity level + * are logged. Use OrtApi::Logger_GetLoggingSeverityLevel to get the ::OrtLogger's logging severity + * level. + * + * Can be used in custom operators to log messages with the logger retrieved via OrtApi::KernelInfo_GetLogger. + * + * \param[in] logger The ::OrtLogger instance. + * \param[in] log_severity_level The message's severity level. + * \param[in] message The message to log. + * \param[in] file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE. + * \param[in] line_number The file line number in which the message is logged. Usually the value of __LINE__. + * \param[in] func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.15 + */ + ORT_API2_STATUS(Logger_LogMessage, _In_ const OrtLogger* logger, OrtLoggingLevel log_severity_level, + _In_z_ const char* message, _In_z_ const ORTCHAR_T* file_path, int line_number, + _In_z_ const char* func_name); + + /** \brief Get the logging severity level of the ::OrtLogger. + * + * Can be used in a custom operator to get the logging severity level of the ::OrtLogger associated with + * the ::OrtKernelInfo. + * + * \param[in] logger The ::OrtLogger instance. + * \param[out] out Pointer to variable assigned with the logging severity level on success. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.15 + */ + ORT_API2_STATUS(Logger_GetLoggingSeverityLevel, _In_ const OrtLogger* logger, _Out_ OrtLoggingLevel* out); + + /// @} + + /** \brief Get a ::OrtValue tensor stored as a constant initializer in the graph node. + * + * Used in the CreateKernel callback of an OrtCustomOp to get a tensor value. + * + * \param[in] info ::OrtKernelInfo instance. + * \param[in] index The node index. + * \param[out] is_constant Is it a constant node input or not. + * \param[out] out The OrtValue tensor value. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.15. + */ + ORT_API2_STATUS(KernelInfoGetConstantInput_tensor, _In_ const OrtKernelInfo* info, size_t index, _Out_ int* is_constant, _Outptr_ const OrtValue** out); + + /** \brief Get Optional Type information from an ::OrtTypeInfo + * + * This augments ::OrtTypeInfo to return an ::OrtOptionalTypeInfo when the type is optional. + * The OrtOptionalTypeInfo also has a nested ::OrtTypeInfo that describes the type of the optional value. + * ::OrtOptionalTypeInfo type can only appear within model metadata to describe inputs/outputs. + * The actual OrtValues that are supplied in place of optional type inputs should contain + * specific type that is described by ::OrtOptionalTypeInfo. 
+ * + * So the picture: ::OrtTypeInfo -> ::OrtOptionalTypeInfo -> ::OrtTypeInfo (describes the type that can be supplied + * in place of the optional type when creating the actual ::OrtValue). + * + * \param[in] type_info + * \param[out] out A pointer to the ::OrtOptionalTypeInfo. Do not free this value, + * it is owned by OrtTypeInfo instance. When the type_info does not represent + * optional type, nullptr is returned in out. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.15. + */ + ORT_API2_STATUS(CastTypeInfoToOptionalTypeInfo, _In_ const OrtTypeInfo* type_info, + _Outptr_result_maybenull_ const OrtOptionalTypeInfo** out); + + /** \brief Get OrtTypeInfo for the allowed contained type from an ::OrtOptionalTypeInfo. + * + * This augments ::OrtOptionalTypeInfo to return an ::OrtTypeInfo for the contained type. + * The OrtOptionalTypeInfo has a nested ::OrtTypeInfo that describes the type of the optional value. + * ::OrtOptionalTypeInfo type can only appear within model metadata to describe inputs/outputs. + * The actual OrtValues that are supplied in place of optional type inputs should contain + * specific type that is described by the returned ::OrtTypeInfo. + * + * \param[in] optional_type_info + * \param[out] out A copy of ::OrtTypeInfo for what the optional value could be. + * The user must free this value with ReleaseTypeInfo. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.15. + */ + ORT_API2_STATUS(GetOptionalContainedTypeInfo, _In_ const OrtOptionalTypeInfo* optional_type_info, + _Outptr_ OrtTypeInfo** out); + + /** \brief Set a single string in a string tensor + * Do not zero terminate the string data. 
+ * + * \param[in] value A string tensor + * \param[in] index - flat index of the element + * \param[in] length_in_bytes length of the buffer in utf-8 bytes (without the null terminator) + * \param[inout] buffer - address of return value + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(GetResizedStringTensorElementBuffer, _Inout_ OrtValue* value, _In_ size_t index, _In_ size_t length_in_bytes, _Inout_ char** buffer); + + /** \brief Get Allocator from KernelContext for a specific memoryInfo. Please use C API ReleaseAllocator to release out object + * + * \param[in] context OrtKernelContext instance + * \param[in] mem_info OrtMemoryInfo instance + * \param[out] out A pointer to OrtAllocator. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.15. + */ + ORT_API2_STATUS(KernelContext_GetAllocator, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _Outptr_ OrtAllocator** out); + + /** \brief Returns a null terminated string of the build info including git info and cxx flags + * + * \return UTF-8 encoded version string. Do not deallocate the returned buffer. + * + * \since Version 1.15. + */ + const char*(ORT_API_CALL* GetBuildInfoString)(void); + + /// \name OrtROCMProviderOptions + /// @{ + + /** \brief Create an OrtROCMProviderOptions + * + * \param[out] out Newly created ::OrtROCMProviderOptions. Must be released with OrtApi::ReleaseROCMProviderOptions + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.16. + */ + ORT_API2_STATUS(CreateROCMProviderOptions, _Outptr_ OrtROCMProviderOptions** out); + + /** \brief Set options in a ROCm Execution Provider. + * + * Please refer to https://onnxruntime.ai/docs/execution-providers/ROCm-ExecutionProvider.html + * to know the available keys and values. Key should be in null terminated string format of the member of + * ::OrtROCMProviderOptions and value should be its related range. 
+   *
+   * For example, key="device_id" and value="0"
+   *
+   * \param[in] rocm_options
+   * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys
+   * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values
+   * \param[in] num_keys Number of elements in the `provider_options_keys` and `provider_options_values` arrays
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   * \since Version 1.16.
+   */
+  ORT_API2_STATUS(UpdateROCMProviderOptions, _Inout_ OrtROCMProviderOptions* rocm_options,
+                  _In_reads_(num_keys) const char* const* provider_options_keys,
+                  _In_reads_(num_keys) const char* const* provider_options_values,
+                  _In_ size_t num_keys);
+
+  /**
+   * Get serialized ROCm provider options string.
+   *
+   * For example, "device_id=0;arena_extend_strategy=0;......"
+   *
+   * \param rocm_options - OrtROCMProviderOptions instance
+   * \param allocator - a ptr to an instance of OrtAllocator obtained with CreateAllocator() or GetAllocatorWithDefaultOptions()
+   *                    the specified allocator will be used to allocate continuous buffers for output strings and lengths.
+   * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   * \since Version 1.16.
+   */
+  ORT_API2_STATUS(GetROCMProviderOptionsAsString, _In_ const OrtROCMProviderOptions* rocm_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr);
+
+  /** \brief Release an ::OrtROCMProviderOptions
+   *
+   * \note This is an exception in the naming convention of other Release* functions, as the name of the method does not have the V2 suffix, but the type does
+   *
+   * \since Version 1.16.
+ */ + void(ORT_API_CALL* ReleaseROCMProviderOptions)(_Frees_ptr_opt_ OrtROCMProviderOptions* input); + + /** \brief Create an allocator with specific type and register it with the ::OrtEnv + * This API enhance CreateAndRegisterAllocator that it can create an allocator with specific type, not just CPU allocator + * Enables sharing the allocator between multiple sessions that use the same env instance. + * Lifetime of the created allocator will be valid for the duration of the environment. + * Returns an error if an allocator with the same ::OrtMemoryInfo is already registered. + * \param[in] env OrtEnv instance + * \param[in] provider_type ExecutionProvider type + * \param[in] mem_info OrtMemoryInfo instance + * \param[in] arena_cfg Arena configuration + * \param[in] provider_options_keys key of the provider options map + * \param[in] provider_options_values value of the provider options map + * \param[in] num_keys Length of the provider options map + */ + ORT_API2_STATUS(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg, + _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); + + /** \brief Run the model asynchronously in a thread owned by intra op thread pool + * + * \param[in] session + * \param[in] run_options If nullptr, will use a default ::OrtRunOptions + * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names + * \param[in] input Array of ::OrtValue%s of the input values + * \param[in] input_len Number of elements in the input_names and inputs arrays + * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names + * \param[in] output_names_len Number of elements in the output_names and outputs array + * \param[out] output OrtValue* array of size output_names_len. 
+ * On calling RunAsync, output[i] could either be a null or a pointer to a preallocated OrtValue. + * Later, the output array will be passed to run_async_callback with all null(s) filled with valid + * OrtValue pointer(s) allocated by onnxruntime. + * NOTE: it is customer's duty to finally release the output array and each of its member, + * regardless of whether the member (OrtValue*) is allocated by onnxruntime or preallocated by the customer. + * \param[in] run_async_callback Callback function on model run completion + * \param[in] user_data User data that pass back to run_async_callback + */ + ORT_API2_STATUS(RunAsync, _Inout_ OrtSession* session, _In_opt_ const OrtRunOptions* run_options, + _In_reads_(input_len) const char* const* input_names, + _In_reads_(input_len) const OrtValue* const* input, size_t input_len, + _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, + _Inout_updates_all_(output_names_len) OrtValue** output, + _In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data); + + /** + * Update TensorRT EP provider option where its data type is pointer, for example 'user_compute_stream'. + * If the data type of the provider option can be represented by string please use UpdateTensorRTProviderOptions. + * + * Note: It's caller's responsibility to properly manage the lifetime of the instance pointed by this pointer. + * + * \param tensorrt_options - OrtTensorRTProviderOptionsV2 instance + * \param key - Name of the provider option + * \param value - A pointer to the instance that will be assigned to this provider option + * + * \since Version 1.16. + */ + ORT_API2_STATUS(UpdateTensorRTProviderOptionsWithValue, _Inout_ OrtTensorRTProviderOptionsV2* tensorrt_options, _In_ const char* key, _In_ void* value); + + /** + * Get TensorRT EP provider option where its data type is pointer. + * If the data type of the provider option can be represented by string please use GetTensorRTProviderOptionsAsString. 
+ * + * \param tensorrt_options - OrtTensorRTProviderOptionsV2 instance + * \param key - Name of the provider option + * \param ptr - A pointer to the instance that is kept by the provider option + * + * \since Version 1.16. + */ + ORT_API2_STATUS(GetTensorRTProviderOptionsByName, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options, _In_ const char* key, _Outptr_ void** ptr); + + /** + * Update CUDA EP provider option where its data type is pointer, for example 'user_compute_stream'. + * If the data type of the provider option can be represented by string please use UpdateCUDAProviderOptions. + * + * Note: It's caller's responsibility to properly manage the lifetime of the instance pointed by this pointer. + * + * \param cuda_options - OrtCUDAProviderOptionsV2 instance + * \param key - Name of the provider option + * \param value - A pointer to the instance that will be assigned to this provider option + * + * \since Version 1.16. + */ + ORT_API2_STATUS(UpdateCUDAProviderOptionsWithValue, _Inout_ OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _In_ void* value); + + /** + * Get CUDA EP provider option where its data type is pointer. + * If the data type of the provider option can be represented by string please use GetCUDAProviderOptionsAsString. + * + * \param cuda_options - OrtCUDAProviderOptionsV2 instance + * \param key - Name of the provider option + * \param ptr - A pointer to the instance that is kept by the provider option + * + * \since Version 1.16. + */ + ORT_API2_STATUS(GetCUDAProviderOptionsByName, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _Outptr_ void** ptr); + + /** + * Get a EP resource. + * E.g. a cuda stream or a cublas handle + * + * \param context - Kernel context + * \param resource_version - Version of the resource + * \param resource_id - Type of resource + * \param resource - A pointer to returned resource + * + * \since Version 1.16. 
+ */ + ORT_API2_STATUS(KernelContext_GetResource, _In_ const OrtKernelContext* context, _In_ int resource_version, + _In_ int resource_id, _Outptr_ void** resource); + + /** \brief Set user logging function + * + * By default the logger created by the CreateEnv* functions is used to create the session logger as well. + * This function allows a user to override this default session logger with a logger of their own choosing. This way + * the user doesn't have to create a separate environment with a custom logger. This addresses the problem when + * the user already created an env but now wants to use a different logger for a specific session (for debugging or + * other reasons). + * + * \param[in] options + * \param[in] user_logging_function A pointer to a logging function. + * \param[in] user_logging_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to + * `user_logging_function`. This parameter is optional. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.17. + */ + ORT_API2_STATUS(SetUserLoggingFunction, _Inout_ OrtSessionOptions* options, + _In_ OrtLoggingFunction user_logging_function, _In_opt_ void* user_logging_param); + + /** + * Get number of input from OrtShapeInferContext + * + * \param[in] context + * \param[out] out The number of inputs + * + * \since Version 1.17. + */ + ORT_API2_STATUS(ShapeInferContext_GetInputCount, _In_ const OrtShapeInferContext* context, _Out_ size_t* out); + + /** + * Get type and shape info of an input + * + * \param[in] context + * \param[in] index The index of the input + * \param[out] info Type shape info of the input + * + * \since Version 1.17. + */ + ORT_API2_STATUS(ShapeInferContext_GetInputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _Outptr_ OrtTensorTypeAndShapeInfo** info); + + /** + * Get attribute from OrtShapeInferContext. 
Note that OrtShapeInferContext is a per-node context: one can only read attributes from the current node.
+   *
+   * \param[in] context
+   * \param[in] attr_name Name of the attribute
+   * \param[out] attr Handle of the attribute fetched
+   *
+   * \since Version 1.17.
+   */
+  ORT_API2_STATUS(ShapeInferContext_GetAttribute, _In_ const OrtShapeInferContext* context, _In_ const char* attr_name, _Outptr_ const OrtOpAttr** attr);
+
+  /**
+   * Set type and shape info of an output
+   *
+   * \param[in] context
+   * \param[in] index The index of the output
+   * \param[in] info Type shape info of the output
+   *
+   * \since Version 1.17.
+   */
+  ORT_API2_STATUS(ShapeInferContext_SetOutputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _In_ const OrtTensorTypeAndShapeInfo* info);
+
+  /**
+   * Set symbolic shape to type shape info
+   *
+   * \param[in] info Type shape info
+   * \param[in] dim_params Symbolic strings
+   * \param[in] dim_params_length Number of strings
+   *
+   * \since Version 1.17.
+   */
+  ORT_API2_STATUS(SetSymbolicDimensions, _In_ OrtTensorTypeAndShapeInfo* info, _In_ const char* dim_params[], _In_ size_t dim_params_length);
+
+  /**
+   * Read contents of an attribute to data
+   *
+   * \param[in] op_attr
+   * \param[in] type Attribute type
+   * \param[out] data Memory address to save raw content of the attribute
+   * \param[in] len Number of bytes allowed to store in data
+   * \param[out] out Number of bytes required to save the data when the call failed, or the real number of bytes saved to data on success
+   *
+   * \since Version 1.17.
+   */
+  ORT_API2_STATUS(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out);
+
+  /** \brief Set whether to use deterministic compute.
+   *
+   * Default is false. If set to true, this will enable deterministic compute for GPU kernels where possible.
+   * Note that this most likely will have a performance cost.
+ * + * \param[in] options + * \param[in] value + * + * \since Version 1.17. + */ + ORT_API2_STATUS(SetDeterministicCompute, _Inout_ OrtSessionOptions* options, bool value); + + /** + * Run fn in parallel + * + * \param[in] context + * \param[in] fn Function accepting usr_data and an integer as iterator + * \param[in] total The number of times fn is to be invoked + * \param[in] num_batch Number of batches by which the "total" is to be divided in maximum. When zero, there is no limit + * \param[in] usr_data User data to be passed back to fn + * + * \since Version 1.17. + */ + ORT_API2_STATUS(KernelContext_ParallelFor, _In_ const OrtKernelContext* context, _In_ void (*fn)(void*, size_t), _In_ size_t total, _In_ size_t num_batch, _In_ void* usr_data); + + /** \brief Append OpenVINO execution provider to the session options + * + * If OpenVINO is not available (due to a non OpenVINO enabled build, or if OpenVINO is not installed on the system), this function will fail. + * + * \param[in] options + * \param[in] provider_options_keys + * \param[in] provider_options_values + * \param[in] num_keys + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.17. + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_OpenVINO_V2, + _In_ OrtSessionOptions* options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); + + /** \brief Append VitisAI provider to session options + * + * If VitisAI is not available (due to a non VitisAI enabled build, or if VitisAI is not installed on the system), this function will return failure. + * + * \param[in] options + * \param[in] provider_options_keys + * \param[in] provider_options_values + * \param[in] num_keys + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.18. 
+ */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_VitisAI, + _In_ OrtSessionOptions* options, + _In_reads_(num_keys) const char* const* provider_options_keys, + _In_reads_(num_keys) const char* const* provider_options_values, + _In_ size_t num_keys); + + /** \brief Get scratch buffer from the corresponding allocator under the specific OrtMemoryInfo object. + * NOTE: callers are responsible to release this scratch buffer from the corresponding allocator + * \param[in] context OrtKernelContext instance + * \param[in] mem_info OrtMemoryInfo instance + * \param[in] count_or_bytes How many bytes is this scratch buffer + * \param[out] out A pointer to the scratch buffer + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.18. + */ + ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out); + + /** \brief Get allocator from KernelInfo for a specific memory type. Please use C API ReleaseAllocator to release out object + * + * \param[in] info OrtKernelInfo instance + * \param[in] mem_type OrtMemType object + * \param[out] out A pointer to OrtAllocator + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.18. + */ + ORT_API2_STATUS(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out); + + /** \brief Replace initialized Tensors with external data with the provided files in memory + * + * The function will find the initialized TensorProtos with external data in the graph with the provided + * external file names and the file content in memory. The API gets the external file name, offset, data length + * from TensorProto, and locate the tensor data from the file in memory buffer. + * It creates a Tensor to replace the existing Tensor in graph. The replacement + * will occur before any of the optimizations take place. 
The data will be copied into the graph + * since TensorProto can't refer to the user provided buffers. + * + * \param[in] options + * \param[in] external_initializer_file_names Array of null terminated UTF-8 encoded strings of the file names + * which holds the external initializers. + * \param[in] external_initializer_file_buffer_array Array of pointers to the buffer of the file content. + * The buffer can be freed after session creation. + * \param[in] external_initializer_file_lengths Array of size_t to indicate the length of file content + * \param[in] num_external_initializer_files Number of external files + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.18. + */ + ORT_API2_STATUS(AddExternalInitializersFromFilesInMemory, _In_ OrtSessionOptions* options, + _In_reads_(num_external_initializer_files) const ORTCHAR_T* const* external_initializer_file_names, + _In_reads_(num_external_initializer_files) char* const* external_initializer_file_buffer_array, + _In_reads_(num_external_initializer_files) const size_t* external_initializer_file_lengths, + size_t num_external_initializer_files); + + /** \brief Create an OrtLoraAdapter + * + * The function attempts to locate file specified by adapter_file_path, read it and create an OrtLoraAdapter + * instance. The adapter_file_path should be a valid path to a file that contains a valid Lora Adapter + * format. The function attempts to validate the format at load time. The file will always be memory mapped, unless + * the platform does not support memory mapping, in which case the file will be read into memory. + * + * \param[in] adapter_file_path adapter file path. + * \param[in] allocator optional pointer to a device allocator. If specified + * data is copied to the device at some point before Run() is invoked. If nullptr, data stays on CPU. + * The data would still be copied to device if required by the model at inference time. 
+ * \param[out] out A pointer to a newly created OrtLoraAdapter instance. Must be released with + * OrtApi::ReleaseLoraAdapter. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.20. + */ + ORT_API2_STATUS(CreateLoraAdapter, const ORTCHAR_T* adapter_file_path, _In_ OrtAllocator* allocator, + _Outptr_ OrtLoraAdapter** out); + + /** \brief Create an OrtLoraAdapter + * + * The function copies the bytes from the array and creates an OrtLoraAdapter instance. + * + * + * \param[in] bytes pointer to a valid Lora Adapter format buffer. + * \param[in] num_bytes length of bytes buffer. + * \param[in] allocator optional pointer to a device allocator. If specified + * data is copied to the device at some point before Run() is invoked. If nullptr, data stays on CPU. + * The data would still be copied to device if required by the model at inference time. + * \param[out] out A pointer to a newly created OrtLoraAdapter instance. Must be released with + * OrtApi::ReleaseLoraAdapter. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.20. + */ + ORT_API2_STATUS(CreateLoraAdapterFromArray, _In_ const void* bytes, size_t num_bytes, _In_ OrtAllocator* allocator, + _Outptr_ OrtLoraAdapter** out); + + /** \brief Release an ::OrtLoraAdapter obtained from OrtApi::CreateLoraAdapter + */ + ORT_CLASS_RELEASE(LoraAdapter); + + /** \brief Add the Lora Adapter to the list of active adapters. + * + * The function adds the Lora Adapter to the list of active adapters. The Lora Adapter must be created with + * OrtApi::CreateLoraAdapter or FromArray. The Lora Adapter will be used by the session to run the model. + * The instance of the OrtRunOptions can then be used to customize the Run() calls. + * More than one OrtLoraAdapter can be active at the same time. Lora Parameters that belong to different + * Lora adapters that will be active at the same time must not overlap. + * This setting does not affect RunWithBinding. 
+ * + * \param[in] options OrtRunOptions instance + * \param[in] adapter OrtLoraAdapter instance + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.20. + */ + ORT_API2_STATUS(RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* options, _In_ const OrtLoraAdapter* adapter); + + /// @} + /// \name OrtEpDynamicOptions + /// @{ + + /** \brief Set DynamicOptions for EPs (Execution Providers) + * + * Valid options can be found in `include\onnxruntime\core\session\onnxruntime_session_options_config_keys.h` + * Look for `kOrtEpDynamicOptions` + * + * \param[in] sess OrtSession + * \param[in] keys Array of null terminated UTF8 encoded strings of EP dynamic option keys + * \param[in] values Array of null terminated UTF8 encoded string of EP dynamic option values + * \param[in] kv_len Number of elements in the keys and values arrays + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.20. + */ + ORT_API2_STATUS(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv_len) const char* const* keys, + _In_reads_(kv_len) const char* const* values, _In_ size_t kv_len); + + /** \brief Release an OrtValueInfo instance if it was not added to an OrtGraph. + * \since Version 1.22. + */ + ORT_CLASS_RELEASE(ValueInfo); + + /** \brief Release an OrtNode if it was not added to an OrtGraph. + * \since Version 1.22. + */ + ORT_CLASS_RELEASE(Node); + + /** \brief Release an OrtGraph. + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.22. + */ + ORT_CLASS_RELEASE(Graph); + + /** \brief Release an OrtModel. + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.22. + */ + ORT_CLASS_RELEASE(Model); + + /** \brief Get the value name from an OrtValueInfo instance. + * \param[in] value_info The OrtValueInfo instance. + * \param[out] name The name of the OrtValueInfo + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.22. 
+ */ + ORT_API2_STATUS(GetValueInfoName, _In_ const OrtValueInfo* value_info, _Out_ const char** name); + + /** \brief Get the type information from an OrtValueInfo instance. + * \param[in] value_info The OrtValueInfo instance. + * \param[out] type_info The type info of the OrtValueInfo + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.22. + */ + ORT_API2_STATUS(GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtTypeInfo** type_info); + + /** \brief Get the Model Editor API instance + * + * Get the Model Editor API instance to create a new model or augment an existing model. + * + * \return Model Editor API struct + * + * \since Version 1.22. + */ + const OrtModelEditorApi*(ORT_API_CALL* GetModelEditorApi)(); + + /** \brief Create an OrtValue for a Tensor that uses pre-existing memory. + * + * ORT will take ownership of the memory and free it using the provided deleter when no longer in use. + * + * \param[in] deleter OrtAllocator instance that will be used to free the memory. + * Only the OrtAllocator:Info and OrtAllocator::Release functions are required. + * The OrtMemoryInfo returned by OrtAllocator::Info must match the location of p_data. + * \param[in] p_data Pointer to the memory that will be used by the Tensor. ORT will take ownership of the memory. + * \param[in] p_data_len Length of the memory in bytes. + * \param[in] shape Dimensions of the Tensor. All values should be > 0. + * \param[in] shape_len Number of dimensions in the shape array. + * \param[in] type Data type of the Tensor. + * \param[out] out Newly created ::OrtValue. Must be freed with OrtApi::ReleaseValue + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. 
+ */ + ORT_API2_STATUS(CreateTensorWithDataAndDeleterAsOrtValue, _In_ OrtAllocator* deleter, + _In_ void* p_data, size_t p_data_len, + _In_ const int64_t* shape, size_t shape_len, + ONNXTensorElementDataType type, + _Outptr_ OrtValue** out); + + /** \brief sets load cancellation flag to abort session loading process. + * + * \param[in] options instance that was passed to the session at creation time. + * \param[in] cancel setting this to true after model loading process was initiated will + * attempt to cancel the loading process. If cancellation is successful, CreateSession() + * CreateSessionFromArray() or any other session creation API that take session options as an + * argument will return an OrtStatus indicating that session loading was canceled at user request, + * error code ORT_MODEL_LOAD_CANCELED. + * The APIs above would not return any valid Session instance. This is the best case effort and the result + * is not guaranteed. The session may have already been created and initialized + * before the cancellation request was issued. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(SessionOptionsSetLoadCancellationFlag, _Inout_ OrtSessionOptions* options, + _In_ bool cancel); + + /** \brief Get the Compile API instance. + * + * Get the Compile API instance to compile ONNX models. Execution providers that support compilation fuse a subgraph + * into an EPContext node that wraps a provider-specific binary representation of the subgraph. + * For more details about the EPContext design, refer to: + * \htmlonly + * EPContext design document. + * \endhtmlonly + * + * \return Compile API struct instance. + * + * \since Version 1.22. + */ + const OrtCompileApi*(ORT_API_CALL* GetCompileApi)(); + + // + // OrtKeyValuePairs + // + + /** \brief Create an OrtKeyValuePairs instance. + * + * \param[out] out A pointer to a newly created OrtKeyValuePairs instance. 
+ * + * \note Must be released by calling ReleaseKeyValuePairs. + * + * \since Version 1.22. + */ + void(ORT_API_CALL* CreateKeyValuePairs)(_Outptr_ OrtKeyValuePairs** out); + + /** \brief Add a key-value pair to the OrtKeyValuePairs instance. + * + * \param[in] kvps OrtKeyValuePairs instance. + * \param[in] key Key to be added. + * \param[in] value Value to be added. + * + * \note The `key` and `value` are copied internally. + * + * \since Version 1.22. + */ + + void(ORT_API_CALL* AddKeyValuePair)(_In_ OrtKeyValuePairs* kvps, _In_ const char* key, _In_ const char* value); + + /** \brief Get the value associated with a key in the OrtKeyValuePairs instance. + * + * \param[in] kvps OrtKeyValuePairs instance. + * \param[in] key Key to be searched. + * + * \return The value associated with the key, or nullptr if the key does not exist. + * + * \since Version 1.22. + */ + const char*(ORT_API_CALL* GetKeyValue)(_In_ const OrtKeyValuePairs* kvps, _In_ const char* key); + + /** \brief Get all the key-value pairs from the OrtKeyValuePairs instance. + * + * \param[in] kvps OrtKeyValuePairs instance. + * \param[out] keys Array of keys from `kvps`. + * \param[out] values Array of values from `kvps`. + * \param[out] num_entries Number of entries in `keys` and `values`. + * + * \since Version 1.22. + */ + void(ORT_API_CALL* GetKeyValuePairs)(_In_ const OrtKeyValuePairs* kvps, + _Outptr_ const char* const** keys, _Outptr_ const char* const** values, + _Out_ size_t* num_entries); + + /** \brief Remove a key-value pair from the OrtKeyValuePairs instance. + * + * \param[in] kvps OrtKeyValuePairs instance. + * \param[in] key Key to be removed. No error if not found. + * + * \since Version 1.22. + */ + void(ORT_API_CALL* RemoveKeyValuePair)(_In_ OrtKeyValuePairs* kvps, _In_ const char* key); + + /** \brief Release an OrtKeyValuePairs instance. + * + * \param[in] input OrtKeyValuePairs instance to be released. + * + * \since Version 1.22. 
+ */ + ORT_CLASS_RELEASE(KeyValuePairs); + + /** \brief Register an execution provider library with ORT. + * + * The library must export 'CreateEpFactories' and 'ReleaseEpFactory' functions. + * See OrtEpApi for more details. + * + * \param[in] env The OrtEnv instance to register the library in. + * \param[in] registration_name The name to register the execution provider library under. + * \param[in] path The path to the execution provider library. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(RegisterExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* registration_name, + _In_ const ORTCHAR_T* path); + + /** \brief Unregister an execution provider library with ORT. + * + * ORT will call ReleaseEpFactory for all factories created by the library, and unload the library. + * + * You MUST ensure there are no Session instances using execution providers created by the library + * before calling this function. + * + * \param[in] env The OrtEnv instance to unregister the library from. + * \param[in] registration_name The name the execution provider library was registered under. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(UnregisterExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* registration_name); + + /** \brief Get the list of available OrtEpDevice instances. + * + * Each OrtEpDevice instance contains details of the execution provider and the device it will use. + * + * \param[in] env The OrtEnv instance to query. + * \param[out] ep_devices The OrtEpDevice instances that the execution provider will use. + * \param[out] num_ep_devices The number of OrtEpDevice instances returned. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. 
+ */
+  ORT_API2_STATUS(GetEpDevices, _In_ const OrtEnv* env,
+                  _Outptr_ const OrtEpDevice* const** ep_devices, _Out_ size_t* num_ep_devices);
+
+  /** \brief Append the execution provider that is responsible for the selected OrtEpDevice instances
+   * to the session options.
+   *
+   * \param[in] session_options Session options to add execution provider to.
+   * \param[in] env Environment that execution providers were registered with.
+   * \param[in] ep_devices One or more OrtEpDevice instances to create an execution provider for.
+   *                       Obtain from GetEpDevices. All OrtEpDevice instances must be from the same execution
+   *                       provider. It is only necessary to provide multiple OrtEpDevices if you want to use the
+   *                       same execution provider for multiple devices.
+   *                       e.g. the EP is capable of running on GPU and NPU.
+   * \param[in] num_ep_devices Number of OrtEpDevice instances.
+   * \param[in] ep_option_keys Optional keys to configure the execution provider.
+   * \param[in] ep_option_vals Optional values to configure the execution provider.
+   * \param[in] num_ep_options Number of execution provider options to add.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   * \since Version 1.22.
+   */
+  ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_V2, _In_ OrtSessionOptions* session_options,
+                  _In_ OrtEnv* env,
+                  _In_reads_(num_ep_devices) const OrtEpDevice* const* ep_devices, _In_ size_t num_ep_devices,
+                  _In_reads_(num_ep_options) const char* const* ep_option_keys,
+                  _In_reads_(num_ep_options) const char* const* ep_option_vals,
+                  size_t num_ep_options);
+
+  /** \brief Set the execution provider selection policy for the session.
+   *
+   * Allows users to specify a device selection policy for automatic execution provider (EP) selection.
+   * If custom selection is required please use SessionOptionsSetEpSelectionPolicyDelegate instead.
+   *
+   * \param[in] session_options The OrtSessionOptions instance.
+   * \param[in] policy The device selection policy to use (see OrtExecutionProviderDevicePolicy).
+   *
+   * \since Version 1.22
+   */
+  ORT_API2_STATUS(SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* session_options,
+                  _In_ OrtExecutionProviderDevicePolicy policy);
+
+  /** \brief Set the execution provider selection policy delegate for the session.
+   *
+   * Allows users to provide a custom device selection policy for automatic execution provider (EP) selection.
+   *
+   * \param[in] session_options The OrtSessionOptions instance.
+   * \param[in] delegate Delegate callback for custom selection.
+   * \param[in] delegate_state Optional state that will be passed to the delegate callback. nullptr if not required.
+   *
+   * \since Version 1.22
+   */
+  ORT_API2_STATUS(SessionOptionsSetEpSelectionPolicyDelegate, _In_ OrtSessionOptions* session_options,
+                  _In_ EpSelectionDelegate delegate,
+                  _In_opt_ void* delegate_state);
+
+  /** \brief Get the hardware device type.
+   *
+   * \param[in] device The OrtHardwareDevice instance to query.
+   * \return The hardware device type.
+   *
+   * \since Version 1.22.
+   */
+  OrtHardwareDeviceType(ORT_API_CALL* HardwareDevice_Type)(_In_ const OrtHardwareDevice* device);
+
+  /** \brief Get the hardware device's vendor identifier.
+   *
+   * \param[in] device The OrtHardwareDevice instance to query.
+   * \return The hardware device vendor identifier.
+   *
+   * \since Version 1.22.
+   */
+  uint32_t(ORT_API_CALL* HardwareDevice_VendorId)(_In_ const OrtHardwareDevice* device);
+
+  /** \brief Get the hardware device's vendor name.
+   *
+   * \param[in] device The OrtHardwareDevice instance to query.
+   * \return The hardware device's vendor name.
+   *
+   * \since Version 1.22.
+   */
+  const char*(ORT_API_CALL* HardwareDevice_Vendor)(_In_ const OrtHardwareDevice* device);
+
+  /** \brief Get the hardware device's device identifier.
+   *
+   * \param[in] device The OrtHardwareDevice instance to query.
+   * \return The device id.
+ * + * \note This is not a unique identifier. It identifies the hardware type when combined with vendor id. + * \since Version 1.22. + */ + uint32_t(ORT_API_CALL* HardwareDevice_DeviceId)(_In_ const OrtHardwareDevice* device); + + /** \brief Get hardware device metadata. + * + * \param[in] device The OrtHardwareDevice instance to query. + * \return An OrtKeyValuePairs instance containing the metadata for the device. + * Note: ORT owns the instance so the user must not call ReleaseKeyValuePairs with it. + * + * \since Version 1.22. + */ + const OrtKeyValuePairs*(ORT_API_CALL* HardwareDevice_Metadata)(_In_ const OrtHardwareDevice* device); + + /** \brief Get the execution provider name. + * + * \param[in] ep_device The OrtEpDevice instance to query. + * \return The execution provider name. + * + * \since Version 1.22. + */ + const char*(ORT_API_CALL* EpDevice_EpName)(_In_ const OrtEpDevice* ep_device); + + /** \brief Get the execution provider's vendor name. + * + * \param[in] ep_device The OrtEpDevice instance to query. + * \return The execution provider's vendor name. + * + * \since Version 1.22. + */ + const char*(ORT_API_CALL* EpDevice_EpVendor)(_In_ const OrtEpDevice* ep_device); + + /** \brief Get the metadata for the OrtEpDevice. + * + * \param[in] ep_device The OrtEpDevice instance to query. + * \return An OrtKeyValuePairs instance containing the metadata for the device. + * + * \since Version 1.22. + */ + const OrtKeyValuePairs*(ORT_API_CALL* EpDevice_EpMetadata)(_In_ const OrtEpDevice* ep_device); + + /** \brief Get the execution provider options for the OrtEpDevice. + * + * \param[in] ep_device The OrtEpDevice instance to query. + * \return An OrtKeyValuePairs instance containing the execution provider options for the device. + * + * \since Version 1.22. + */ + const OrtKeyValuePairs*(ORT_API_CALL* EpDevice_EpOptions)(_In_ const OrtEpDevice* ep_device); + + /** \brief Get the OrtHardwareDevice instance for the OrtEpDevice. 
+ * + * \param[in] ep_device The OrtEpDevice instance to query. + * \return The OrtHardwareDevice instance for the device. + * + * \since Version 1.22. + */ + const OrtHardwareDevice*(ORT_API_CALL* EpDevice_Device)(_In_ const OrtEpDevice* ep_device); + + /** \brief Get the OrtEpApi instance for implementing an execution provider. + * + * \since Version 1.22. + */ + const OrtEpApi*(ORT_API_CALL* GetEpApi)(); + + /** \brief Compute total size in bytes of the tensor data contained in an OrtValue. + * + * Returns the total number of bytes used to store the tensor data. For numeric tensors, + * this is sizeof(element_type) * total_element_count. OrtValues that are not tensors or + * that are tensors that contain strings will cause an error to be returned. + * + * \param[in] ort_value OrtValue instance containing a tensor + * \param[out] size The total size of the tensor data in bytes + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23 + */ + ORT_API2_STATUS(GetTensorSizeInBytes, _In_ const OrtValue* ort_value, _Out_ size_t* size); +}; + +/* + * Steps to use a custom op: + * 1 Create an OrtCustomOpDomain with the domain name used by the custom ops + * 2 Create an OrtCustomOp structure for each op and add them to the domain + * 3 Call OrtAddCustomOpDomain to add the custom domain of ops to the session options + */ + +// Specifies some characteristics of inputs/outputs of custom ops: +// Specify if the inputs/outputs are one of: +// 1) Non-optional (input/output must be present in the node) +// 2) Optional (input/output may be absent in the node) +// 3) Variadic: A variadic input or output specifies N (i.e., the minimum arity) or more operands. +// Only the last input or output of a custom op may be marked as variadic. +// The homogeneity of the variadic input or output determines whether all operands must be of the same +// tensor element type. 
+typedef enum OrtCustomOpInputOutputCharacteristic { + INPUT_OUTPUT_REQUIRED = 0, + INPUT_OUTPUT_OPTIONAL, + INPUT_OUTPUT_VARIADIC, +} OrtCustomOpInputOutputCharacteristic; + +/* + * The OrtCustomOp structure defines a custom op's schema and its kernel callbacks. The callbacks are filled in by + * the implementor of the custom op. + */ +struct OrtCustomOp { + uint32_t version; // Must be initialized to ORT_API_VERSION + + // This callback creates the kernel, which is a user defined + // parameter that is passed to the Kernel* callbacks below. It is + // recommended to use CreateKernelV2 which allows for a safe error + // propagation by returning an OrtStatusPtr. + void*(ORT_API_CALL* CreateKernel)(_In_ const struct OrtCustomOp* op, _In_ const OrtApi* api, + _In_ const OrtKernelInfo* info); + + // Returns the name of the op + const char*(ORT_API_CALL* GetName)(_In_ const struct OrtCustomOp* op); + + // Returns the type of the execution provider, return nullptr to use CPU execution provider + const char*(ORT_API_CALL* GetExecutionProviderType)(_In_ const struct OrtCustomOp* op); + + // Returns the count and types of the input & output tensors + ONNXTensorElementDataType(ORT_API_CALL* GetInputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); + size_t(ORT_API_CALL* GetInputTypeCount)(_In_ const struct OrtCustomOp* op); + ONNXTensorElementDataType(ORT_API_CALL* GetOutputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); + size_t(ORT_API_CALL* GetOutputTypeCount)(_In_ const struct OrtCustomOp* op); + + // Perform a computation step. It is recommended to use + // KernelComputeV2 which allows for a safe error propagation by + // returning an OrtStatusPtr. 
+ void(ORT_API_CALL* KernelCompute)(_In_ void* op_kernel, _In_ OrtKernelContext* context); + void(ORT_API_CALL* KernelDestroy)(_In_ void* op_kernel); + + // Returns the characteristics of the input & output tensors + OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetInputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index); + OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetOutputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index); + + // Returns the memory type of the input tensors. This API allows the custom op + // to place the inputs on specific devices. By default, it returns + // OrtMemTypeDefault, which means the input is placed on the default device for + // the execution provider. If the inputs need to be with different memory types, + // this function can be overridden to return the specific memory types. + OrtMemType(ORT_API_CALL* GetInputMemoryType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); + + // Returns the minimum number of input arguments expected for the variadic input. + // Applicable only for custom ops that have a variadic input. + int(ORT_API_CALL* GetVariadicInputMinArity)(_In_ const struct OrtCustomOp* op); + + // Returns true (non-zero) if all arguments of a variadic input have to be of the same type (homogeneous), + // and false (zero) otherwise. + // Applicable only for custom ops that have a variadic input. + int(ORT_API_CALL* GetVariadicInputHomogeneity)(_In_ const struct OrtCustomOp* op); + + // Returns the minimum number of output values expected for the variadic output. + // Applicable only for custom ops that have a variadic output. + int(ORT_API_CALL* GetVariadicOutputMinArity)(_In_ const struct OrtCustomOp* op); + + // Returns true (non-zero) if all outputs values of a variadic output have to be of the same type (homogeneous), + // and false (zero) otherwise. + // Applicable only for custom ops that have a variadic output. 
+ int(ORT_API_CALL* GetVariadicOutputHomogeneity)(_In_ const struct OrtCustomOp* op); + + // Create the kernel state which is passed to each compute call. + OrtStatusPtr(ORT_API_CALL* CreateKernelV2)(_In_ const struct OrtCustomOp* op, _In_ const OrtApi* api, + _In_ const OrtKernelInfo* info, + _Out_ void** kernel); + + // Perform the computation step. + OrtStatusPtr(ORT_API_CALL* KernelComputeV2)(_In_ void* op_kernel, _In_ OrtKernelContext* context); + + OrtStatusPtr(ORT_API_CALL* InferOutputShapeFn)(_In_ const struct OrtCustomOp* op, _In_ OrtShapeInferContext*); + + // Get start range + int(ORT_API_CALL* GetStartVersion)(_In_ const struct OrtCustomOp* op); + int(ORT_API_CALL* GetEndVersion)(_In_ const struct OrtCustomOp* op); + + // Get the inplace_map that defines which output can reuse which input + // Callers will provide 2 raw int* and pass in their address, this function will fill these 2 arrays + // when return, output (*output_index)[i] may reuse the input (*input_index[i]). + // The return value is the size of these 2 arrays. + // Callers are responsible to delete these 2 arrays after use by calling OrtCustomOp::ReleaseMayInplace(). + size_t(ORT_API_CALL* GetMayInplace)(_Out_ int** input_index, _Out_ int** output_index); + + // Release the pointer input_index and output_index allocated from GetMayInplace() function. + // If GetMayInplace() is defined, this function MUST be defined as well. + void(ORT_API_CALL* ReleaseMayInplace)(_Frees_ptr_opt_ int* input_index, _Frees_ptr_opt_ int* output_index); + + // Same as GetMayInplace() and ReleaseMayInplace() + size_t(ORT_API_CALL* GetAliasMap)(_Out_ int** input_index, _Out_ int** output_index); + void(ORT_API_CALL* ReleaseAliasMap)(_Frees_ptr_opt_ int* input_index, _Frees_ptr_opt_ int* output_index); +}; + +/** + * ORT Model Editor API + */ + +/** + * \brief The OrtModelEditorApi struct provides functions to create or edit an ONNX model. 
+ * + * See onnxruntime/test/shared_lib/test_model_editor_api.cc for example usage. + * + * \since Version 1.22. + */ +struct OrtModelEditorApi { + // Model building/editing requires a full build. We return nullptr from GetModelEditorApi if this is a minimal + // build, so it doesn't matter if there are no function pointers in this struct as a user will never get an + // OrtModelEditorApi instance. We do however need a dummy field to avoid empty struct warning. +#if defined(ORT_MINIMAL_BUILD) + const bool not_defined_in_this_build; +#else + /** \brief Create an OrtTypeInfo instance for a Tensor. + * + * Create an OrtTypeInfo instance for a Tensor to use as graph inputs/outputs with the Model Editor API. + * + * User can release `tensor_info` after creating the OrtTypeInfo. + * + * \param[in] tensor_info Tensor type and shape information. + * \param[out] type_info TypeInfo instance for the tensor. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(CreateTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, + _Outptr_ OrtTypeInfo** type_info); + + /** \brief Create an OrtTypeInfo instance for a SparseTensor. + * + * Create an OrtTypeInfo instance for a SparseTensor to use as graph inputs/outputs with the Model Editor API. + * + * User can release `tensor_info` after creating the OrtTypeInfo. + * + * \param[in] tensor_info SparseTensor type and shape information. + * \param[out] type_info TypeInfo instance for the tensor. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(CreateSparseTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, + _Outptr_ OrtTypeInfo** type_info); + + /** \brief Create an OrtTypeInfo instance for a Map. + * + * Create an OrtTypeInfo instance for a Map to use as graph inputs/outputs with the Model Editor API. + * + * User can release `map_value_type` after creating the OrtTypeInfo. 
+ * + * \param[in] map_key_type Key type for the map. + * \param[in] map_value_type Value type for the map. + * \param[out] type_info TypeInfo instance for the map. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(CreateMapTypeInfo, ONNXTensorElementDataType map_key_type, _In_ const OrtTypeInfo* map_value_type, + _Outptr_ OrtTypeInfo** type_info); + + /** \brief Create an OrtTypeInfo instance for a Sequence. + * + * Create an OrtTypeInfo instance for a Sequence to use as graph inputs/outputs with the Model Editor API. + * + * User can release `sequence_type` after creating the OrtTypeInfo. + * + * \param[in] sequence_type Sequence type and shape information. + * \param[out] type_info TypeInfo instance for the sequence. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(CreateSequenceTypeInfo, _In_ const OrtTypeInfo* sequence_type, _Outptr_ OrtTypeInfo** type_info); + + /** \brief Create an OrtTypeInfo instance for an Optional. + * + * Create an OrtTypeInfo instance for an Optional to use as graph inputs/outputs with the Model Editor API. + * + * User can release `contained_type` after creating the OrtTypeInfo. + * + * \param[in] contained_type Tensor type and shape information. + * \param[out] type_info TypeInfo instance for the tensor. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(CreateOptionalTypeInfo, _In_ const OrtTypeInfo* contained_type, _Outptr_ OrtTypeInfo** type_info); + + /** \brief Create an OrtValueInfo for use as an OrtGraph input or output. + * + * \param[in] name The name of the input or output. + * \param[in] type_info The type information for the input or output. The provided value is copied. + * \param[out] value_info The OrtValueInfo instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. 
+ */ + ORT_API2_STATUS(CreateValueInfo, _In_ const char* name, _In_ const OrtTypeInfo* type_info, + _Outptr_ OrtValueInfo** value_info); + + /** \brief Create an OrtNode to add to an OrtGraph. + * + * Create an OrtNode. + * + * Create attributes with CreateOpAttr. OrtOpAttr instances are copied. + * + * \param[in] operator_name The name of the operator. + * \param[in] domain_name The domain of the operator. Use an empty string for ONNX operators. + * \param[in] node_name The name of the node. + * \param[in] input_names The names of the inputs. + * \param[in] input_names_len The number of input names. + * \param[in] output_names The names of the outputs. + * \param[in] output_names_len The number of output names. + * \param[in] attributes The optional attributes of the node. + * \param[in] attribs_len The number of attributes. May be zero. + * \param[out] node The OrtNode instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(CreateNode, _In_ const char* operator_name, _In_ const char* domain_name, _In_ const char* node_name, + _In_reads_(input_names_len) const char* const* input_names, size_t input_names_len, + _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, + _In_reads_(attribs_len) _In_opt_ OrtOpAttr** attributes, _In_ size_t attribs_len, + _Outptr_ OrtNode** node); + + /** \brief Create an OrtGraph + * \snippet{doc} snippets.dox OrtStatus Return Value + * \since Version 1.22. + */ + ORT_API2_STATUS(CreateGraph, _Outptr_ OrtGraph** graph); + + /** \brief Set the inputs for the OrtGraph. + * + * Set the graph inputs. This will replace any existing inputs with the new values. + * The OrtGraph takes ownership of the OrtValueInfo instances and you should NOT call ReleaseOrtValueInfo. + * + * \param[in] graph The OrtGraph instance to update. + * \param[in] inputs The input OrtValueInfo instances. + * \param[in] inputs_len The number of input OrtValueInfo instances. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(SetGraphInputs, _Inout_ OrtGraph* graph, + _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len); + + /** \brief Set the outputs for the OrtGraph. + * + * Set the graph outputs. This will replace any existing outputs with the new values. + * The OrtGraph takes ownership of the OrtValueInfo instances provided and you should NOT call ReleaseOrtValueInfo. + * + * \param[in] graph The OrtGraph instance to update. + * \param[in] outputs The output OrtValueInfo instances. + * \param[in] outputs_len The number of output OrtValueInfo instances. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(SetGraphOutputs, _Inout_ OrtGraph* graph, + _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len); + + /** \brief Add an initializer to the OrtGraph + * + * ORT will take ownership of the OrtValue and you should NOT call ReleaseOrtValue. + * + * Two options: + * + * Allocated memory: + * Use CreateTensorAsOrtValue (allocates memory) and populate the tensor with the data. + * Set `data_is_external` to false. + * + * Pre-existing memory: + * Use CreateTensorWithDataAsOrtValue or CreateTensorWithDataAndDeleterAsOrtValue to create an OrtValue + * with a tensor that contains a pointer to the existing data. + * Set `data_is_external` to true. + * + * The pointer must remain valid for the duration of the inference session. + * If using CreateTensorWithDataAsOrtValue you are responsible for freeing the memory after the inference session + * is released. + * If using CreateTensorWithDataAndDeleterAsOrtValue, ORT will free the memory using the provided deleter as + * soon as the OrtValue is no longer in use. + * + * NOTE: A tensor containing pre-existing memory MUST have 128 bytes of data or more. + * For smaller tensors use CreateTensorAsOrtValue. 
+ * + * ONNX shape inferencing does not support external data. An initializer involved in shape inferencing is + * typically small (a single value or limited by the rank of a tensor) and uses less than 128 bytes of + * memory, so this limit acts as a simple catch-all rule to avoid issues. + * e.g. Reshape's `shape`, Clip's `min` and `max`, various ops `axes`. + * + * \param[in] graph The OrtGraph instance to update. + * \param[in] name The value name for the initializer. + * \param[in] tensor The OrtValue instance containing the tensor data. + * \param[in] data_is_external Set to true if the data is external and should not be copied. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(AddInitializerToGraph, _Inout_ OrtGraph* graph, _In_ const char* name, _In_ OrtValue* tensor, + bool data_is_external); + + /** \brief Add an OrtNode to an OrtGraph + * + * Add the node to the graph. The OrtGraph will take ownership of OrtNode and you should NOT call ReleaseOrtNode. + * + * \param[in] graph The OrtGraph instance to update. + * \param[in] node The OrtNode instance to add to the graph. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(AddNodeToGraph, _Inout_ OrtGraph* graph, _In_ OrtNode* node); + + /** \brief Create an OrtModel. + * + * Create an OrtModel. + * + * This can be used to build a new model, or to augment an existing model. + * + * \param[in] domain_names The domain names for the model. + * If augmenting an existing model add additional domains if needed. + * \param[in] opset_versions The opset versions for the model. + * If augmenting an existing model add additional opset versions if needed. + * \param[in] opset_entries_len The number of domain_names and opset_versions entries. + * Domain and opset entries should be 1:1 + * \param[out] model The OrtModel instance. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(CreateModel, + _In_reads_(opset_entries_len) const char* const* domain_names, + _In_reads_(opset_entries_len) const int* opset_versions, + size_t opset_entries_len, + _Outptr_ OrtModel** model); + + /** \brief Add an OrtGraph to an OrtModel. + * + * Add the graph to a model. This should be called once when creating a new model. + * + * The OrtModel takes ownership of the OrtGraph and you should NOT call ReleaseOrtGraph. + * + * \param[in] model The OrtModel instance to update. + * \param[in] graph The OrtGraph instance to add to the model. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(AddGraphToModel, _Inout_ OrtModel* model, _In_ OrtGraph* graph); + + /** \brief Create an OrtSession using the OrtModel. + * + * Create an inference session using the OrtModel instance. + * The OrtModel should have been populated with an OrtGraph containing nodes and initializers, and SetGraphInputs + * and SetGraphOutputs must have been called. + * This will validate the model, run optimizers, and prepare the session for inferencing. + * + * ReleaseOrtModel must be called to free the OrtModel after session creation. + * + * \param[in] env The OrtEnv instance. + * \param[in] model The OrtModel instance. + * \param[in] options The OrtSessionOptions instance. + * \param[out] out The OrtSession instance. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(CreateSessionFromModel, _In_ const OrtEnv* env, _In_ const OrtModel* model, + _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); + + /** \brief Create an OrtSession to augment an existing model. + * + * Create an OrtSession with an existing model that will be augmented with additional nodes and initializers. + * Nodes can be added before or after the existing nodes in the model. 
ONNX Runtime will connect the nodes when the
+   * model is finalized.
+   *
+   * To add nodes and initializers to the existing model, first create an OrtModel using CreateModel.
+   * Add nodes and initializers to the OrtModel using AddNodeToGraph and AddInitializerToGraph.
+   * Graph inputs/outputs should be updated with SetGraphInputs and SetGraphOutputs as needed to reflect changes made
+   * by the new nodes. The list of graph inputs/outputs should be for the overall model and not just the new nodes.
+   *
+   * Add the new information from the OrtModel to the original model using ApplyModelToSession, and prepare the
+   * session for inferencing by calling FinalizeModelEditorSession.
+   *
+   * \param[in] env The OrtEnv instance.
+   * \param[in] model_path The path to the existing ONNX model to augment.
+   * \param[in] options The OrtSessionOptions instance.
+   * \param[out] out The created OrtSession instance.
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   * \since Version 1.22.
+   */
+  ORT_API2_STATUS(CreateModelEditorSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path,
+                  _In_ const OrtSessionOptions* options,
+                  _Outptr_ OrtSession** out);
+
+  /** \brief Create an OrtSession to augment an existing model.
+   *
+   * Create an OrtSession with an existing model that will be augmented with additional nodes and initializers.
+   * Nodes can be added before or after the existing nodes in the model. ONNX Runtime will connect the nodes when the
+   * model is finalized.
+   *
+   * To add nodes and initializers to the existing model, first create an OrtModel using CreateModel.
+   * Add nodes and initializers to the OrtModel using AddNodeToGraph and AddInitializerToGraph.
+   * Graph inputs/outputs should be updated with SetGraphInputs and SetGraphOutputs as needed to reflect changes made
+   * by the new nodes. The list of graph inputs/outputs should be for the overall model and not just the new nodes.
+   *
+   * Add the new information from the OrtModel to the original model using ApplyModelToSession, and prepare the
+   * session for inferencing by calling FinalizeModelEditorSession.
+   *
+   * \param[in] env The OrtEnv instance.
+   * \param[in] model_data The model data for the existing model to augment.
+   * \param[in] model_data_length The length of the model data.
+   * \param[in] options The OrtSessionOptions instance.
+   * \param[out] out The created OrtSession instance.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   * \since Version 1.22.
+   */
+  ORT_API2_STATUS(CreateModelEditorSessionFromArray, _In_ const OrtEnv* env,
+                  _In_ const void* model_data, size_t model_data_length,
+                  _In_ const OrtSessionOptions* options,
+                  _Outptr_ OrtSession** out);
+
+  /** \brief Query the session for the opset version of a domain.
+   *
+   * When using the Model Editor API to augment a model, any new nodes must conform to the opset version of the
+   * original model. To do that the user must be able to discover that opset version.
+   * Returns an error if the domain is not used in the model.
+   *
+   * \param[in] session OrtSession to query
+   * \param[in] domain Domain to query. The ONNX domain is an empty string.
+   * \param[out] opset The opset version of the domain.
+   *
+   * \snippet{doc} snippets.dox OrtStatus Return Value
+   *
+   * \since Version 1.22.
+   */
+  ORT_API2_STATUS(SessionGetOpsetForDomain, _In_ const OrtSession* session, _In_ const char* domain, _Out_ int* opset);
+
+  /** \brief Apply changes to augment the ONNX model in a session created using CreateModelEditorSession[FromArray]
+   *
+   * Adds new nodes and updates graph inputs/outputs using `model` to augment the original ONNX model in the session.
+   * All changes will be validated.
+   * Call FinalizeModelEditorSession to prepare the session for inferencing.
+   *
+   * Existing input/outputs will only be updated if the OrtGraph inputs/outputs are set in the OrtModel.
+   * i.e. 
you don't need to call SetGraphInputs/SetGraphOutputs if they are unchanged. + * + * ReleaseOrtModel must be called to free the OrtModel after it is applied to the session. + * + * \param[in] session OrtSession to update. Session must have been created using CreateModelEditorSession[FromArray]. + * \param[in] model OrtModel containing new nodes, new initializers, and updated graph input and/or output info. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(ApplyModelToModelEditorSession, _Inout_ OrtSession* session, _In_ OrtModel* model); + + /** \brief Finalize the Model Editor session that was created using CreateModelEditorSession[FromArray]. + * + * Finalize the Model Editor session that augmented an ONNX model by adding new nodes. + * This will run optimizers and prepare the session for inferencing. + * + * \param[in] session OrtSession to finalize. Session must have been created using CreateModelEditorSession[FromArray]. + * \param[in] options OrtSessionOptions to use for the session. + * \param[in] prepacked_weights_container Optional OrtPrepackedWeightsContainer to use for the session. + Set to nullptr if not used. + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(FinalizeModelEditorSession, _Inout_ OrtSession* session, _In_ const OrtSessionOptions* options, + _In_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container); +#endif // !defined(ORT_MINIMAL_BUILD) +}; + +/** + * ORT Compile API + */ + +/** \brief Flags representing options to enable when compiling a model. + */ +typedef enum OrtCompileApiFlags { + // Default. Do not enable any additional compilation options. + OrtCompileApiFlags_NONE = 0, + + // Force compilation to return an error (ORT_FAIL) if no nodes were compiled. + // Otherwise, a model with basic optimizations (ORT_ENABLE_BASIC) is still generated by default. 
+  OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED = 1 << 0,
+
+  // Force compilation to return an error (ORT_FAIL) if a file with the same filename as the output model exists.
+  // Otherwise, compilation will automatically overwrite the output file if it exists.
+  OrtCompileApiFlags_ERROR_IF_OUTPUT_FILE_EXISTS = 1 << 1,
+} OrtCompileApiFlags;
+
+/**
+ * \brief The OrtCompileApi struct provides functions to compile ONNX models.
+ *
+ * Execution providers that support compilation fuse a subgraph into an EPContext node that wraps a provider-specific
+ * binary representation of the subgraph.
+ * For more details about the EPContext design, refer to:
+ * \htmlonly
+ * EPContext design document.
+ * \endhtmlonly
+ *
+ * Example (error handling not shown):
+ *   OrtStatus* status = NULL;
+ *   OrtCompileApi* compile_api = ort_api->GetCompileApi();
+ *   OrtModelCompilationOptions* compile_options = NULL;
+ *
+ *   status = compile_api->CreateModelCompilationOptionsFromSessionOptions(env, session_options, &compile_options);
+ *   status = compile_api->ModelCompilationOptions_SetInputModelPath(compile_options, ORT_TSTR("model.onnx"));
+ *   status = compile_api->ModelCompilationOptions_SetOutputModelPath(compile_options, ORT_TSTR("model.compiled.onnx"));
+ *   status = compile_api->CompileModel(env, compile_options);
+ *   compile_api->ReleaseModelCompilationOptions(compile_options);
+ *
+ * \since Version 1.22.
+ */
+struct OrtCompileApi {
+  /// @}
+  /// \name OrtModelCompilationOptions
+  /// @{
+  ORT_CLASS_RELEASE(ModelCompilationOptions);
+
+  /** \brief Creates an OrtModelCompilationOptions object from an existing OrtSessionOptions object.
+   *
+   * An OrtModelCompilationOptions object contains the settings used to generate a compiled ONNX model.
+   * The OrtSessionOptions object has the execution providers with which the model will be compiled.
+   *
+   * ReleaseOrtModelCompilationOptions must be called to free the OrtModelCompilationOptions after calling
+   * CompileModel.
+ * + * \param[in] env OrtEnv object. + * \param[in] session_options The OrtSessionOptions instance from which to create the OrtModelCompilationOptions. + * \param[out] out The created OrtModelCompilationOptions instance. + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(CreateModelCompilationOptionsFromSessionOptions, _In_ const OrtEnv* env, + _In_ const OrtSessionOptions* session_options, _Outptr_ OrtModelCompilationOptions** out); + + /** \brief Sets the file path to the input ONNX model to compile. + * + * The input model's location (e.g., file path or memory buffer) must be set with either + * ModelCompilationOptions_SetInputModelPath or ModelCompilationOptions_SetInputModelFromBuffer. + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] input_model_path Null terminated string of the path (wchar on Windows, char otherwise). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(ModelCompilationOptions_SetInputModelPath, _In_ OrtModelCompilationOptions* model_compile_options, + _In_ const ORTCHAR_T* input_model_path); + + /** \brief Sets the buffer that stores the bytes of the loaded ONNX model to compile. + * + * The input model's location (e.g., file path or memory buffer) must be set with either + * ModelCompilationOptions_SetInputModelPath or ModelCompilationOptions_SetInputModelFromBuffer. + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] input_model_data Buffer containing the loaded ONNX model bytes. + * \param[in] input_model_data_size The number of bytes in the `input_model_data` buffer. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. 
+ */ + ORT_API2_STATUS(ModelCompilationOptions_SetInputModelFromBuffer, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ const void* input_model_data, + size_t input_model_data_size); + + /** \brief Sets the file path for the output ONNX model generated by CompileModel. + * + * The output model's location (e.g., file path or memory buffer) can be set with either + * ModelCompilationOptions_SetOutputModelPath or ModelCompilationOptions_SetOutputModelBuffer. + * + * If the output model's location is not set, ONNX Runtime will generate an output file with a path based on + * the input model's file path. Examples: + * /Path/my_model.onnx -> /Path/my_model_ctx.onnx + * /Path/my_model -> /Path/my_model_ctx.onnx + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] output_model_path Null terminated string of the path (wchar on Windows, char otherwise). + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelPath, _In_ OrtModelCompilationOptions* model_compile_options, + _In_ const ORTCHAR_T* output_model_path); + + /** \brief Optionally sets the file that should store external initializers for the compiled ONNX model. + * If not set, initializers are stored within the model. + * + * Only initializers for nodes that were not compiled are stored in the external initializers file. + * Compiled nodes contain their initializer data within the `ep_cache_context` attribute of EPContext nodes. + * Refer to ModelCompilationOptions_SetEpContextEmbedMode. + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] external_initializers_file_path Null terminated string of the path to the file. + * \param[in] external_initializers_size_threshold Initializers larger than this threshold are stored in the file. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. 
+ */ + ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelExternalInitializersFile, + _In_ OrtModelCompilationOptions* model_compile_options, + _In_ const ORTCHAR_T* external_initializers_file_path, + size_t external_initializers_size_threshold); + + /** \brief Configures model compilation to store the output compiled ONNX model in a buffer. + * + * The caller passes an OrtAllocator that ONNX Runtime uses to allocate memory for the buffer. + * + * The output model's location (e.g., file path or memory buffer) can be set with either + * ModelCompilationOptions_SetOutputModelPath or ModelCompilationOptions_SetOutputModelBuffer. + * + * If the output model's location is not set, ONNX Runtime will generate an output file with a path based on + * the input model's file path. Examples: + * /Path/my_model.onnx -> /Path/my_model_ctx.onnx + * /Path/my_model -> /Path/my_model_ctx.onnx + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] allocator The allocator used to allocate the buffer for the compiled model. + * \param[out] output_model_buffer_ptr Pointer to the buffer that stores the compiled model. + * \param[out] output_model_buffer_size_ptr Pointer set to the size of output model in bytes. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelBuffer, + _In_ OrtModelCompilationOptions* model_compile_options, + _Inout_ OrtAllocator* allocator, + _Outptr_ void** output_model_buffer_ptr, + _Out_ size_t* output_model_buffer_size_ptr); + + /** \brief Enables or disables the embedding of EPContext binary data into the `ep_cache_context` attribute + * of EPContext nodes. Defaults to false. + * + * If enabled, the `ep_cache_context` attribute of EPContext nodes will store the context binary data, which may + * include weights for compiled subgraphs. 
+ * + * If disabled, the `ep_cache_context` attribute of EPContext nodes will contain the path to the file containing the + * context binary data. The path is set by the execution provider creating the EPContext node. + * + * More details relate to EPContext design refers to: + * \htmlonly + * EPContext design document. + * \endhtmlonly + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] embed_ep_context_in_model True to embed EPContext binary data into the EPContext node + * `ep_cache_context` attributes. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModelCompilationOptions* model_compile_options, + bool embed_ep_context_in_model); + + /** \brief Compiles an input ONNX model with the given compilation options. + * + * \param[in] env OrtEnv object. + * \param[in] model_options The compilation options that defines compilation options for a model. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ + ORT_API2_STATUS(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options); + + /** \brief Sets flags from OrtCompileApiFlags that represent one or more boolean options to enable. + * + * \param[in] model_compile_options The OrtModelCompilationOptions instance. + * \param[in] flags bitwise OR of flags in OrtCompileApiFlags to enable. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.23. + */ + ORT_API2_STATUS(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_compile_options, + size_t flags); +}; + +ORT_RUNTIME_CLASS(Ep); +ORT_RUNTIME_CLASS(EpFactory); + +struct OrtEpApi { + /** \brief Create an OrtEpDevice for the EP and an OrtHardwareDevice. + * \param[in] ep_factory Execution provider factory that is creating the instance. 
+   * \param[in] hardware_device Hardware device that the EP can utilize.
+   * \param[in] ep_metadata Optional OrtKeyValuePairs instance for execution provider metadata that may be used
+   *                        during execution provider selection and passed to CreateEp.
+   *                        ep_device will copy this instance and the user should call ReleaseKeyValuePairs.
+   * \param[in] ep_options Optional OrtKeyValuePairs instance for execution provider options that will be added
+   *                       to the Session configuration options if the execution provider is selected.
+   *                       ep_device will copy this instance and the user should call ReleaseKeyValuePairs.
+   * \param[out] ep_device OrtEpDevice that is created.
+   *
+   * \since Version 1.22.
+   */
+  ORT_API2_STATUS(CreateEpDevice, _In_ OrtEpFactory* ep_factory,
+                  _In_ const OrtHardwareDevice* hardware_device,
+                  _In_opt_ const OrtKeyValuePairs* ep_metadata,
+                  _In_opt_ const OrtKeyValuePairs* ep_options,
+                  _Out_ OrtEpDevice** ep_device);
+
+  ORT_CLASS_RELEASE(EpDevice);
+};
+
+/**
+ * \brief The OrtEp struct provides functions to implement for an execution provider.
+ * \since Version 1.22.
+ */
+struct OrtEp {
+  /** \brief The ONNX Runtime version the execution provider was compiled with.
+   *
+   * Implementation should set to ORT_API_VERSION.
+   * ORT will use this to ensure it does not call functions that were not available when the library was compiled.
+   *
+   * \since Version 1.22.
+   */
+  uint32_t ort_version_supported;
+
+  /** \brief Get the execution provider name.
+   *
+   * \param[in] this_ptr The OrtEp instance.
+   * \return The execution provider name.
+   *
+   * \note Returned string is owned by ORT and valid until UnregisterExecutionProviderLibrary is called.
+   *
+   * \since Version 1.22.
+ */ + const char*(ORT_API_CALL* GetName)(const OrtEp* this_ptr); + + // OrtStatus* GetCapability(OrtEp* ep, const OrtGraph* graph, + // size_t* num_supported_subgraphs, + // OrtIndexedSubgraph** supported_subgraphs, OrtAllocator* allocator); + + // OrtStatus* Compile(OrtEp* ep, const OrtGraph** graphs, OrtNode** fused_graph_nodes, + // size_t count, OrtNodeComputeInfo* node_compute_infos); + + // TODO: Implement OrtEpApi and the complete OrtEp interface as the next step. +}; + +/** \brief The function signature that ORT will call to create OrtEpFactory instances. + * + * This must be available in a function called 'CreateEpFactories' in the execution provider library. + * + * \param[in] registered_name The name the execution library is registered with by RegisterExecutionProviderLibrary + * \param[in] ort_api_base The OrtApiBase instance that is used by the factory to get the OrtApi instance for the + * version of ORT that the library was compiled against. + * \param[in,out] factories The implementation should create and add OrtEpFactory instances to this + * pre-allocated array. + * i.e. usage is `factories[0] = new MyEpFactory();` + * \param[in] max_factories The maximum number of OrtEpFactory instances that can be added to `factories`. + * Current default is to allow 4 factories. This can be increased in the future if needed. + * \param[out] num_factories The number of OrtEpFactory instances created by the factory and added to `factories`. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.22. + */ +typedef OrtStatus* (*CreateEpApiFactoriesFn)(_In_ const char* registered_name, _In_ const OrtApiBase* ort_api_base, + _Inout_ OrtEpFactory** factories, _In_ size_t max_factories, + _Out_ size_t* num_factories); + +/** \brief The function signature that ORT will call to release an OrtEpFactory instance. + * + * This must be available in a function called 'ReleaseEpFactory' in the execution provider library. 
+ *
+ * \param[in] factory The OrtEpFactory instance to release.
+ *
+ * \snippet{doc} snippets.dox OrtStatus Return Value
+ *
+ * \since Version 1.22.
+ */
+typedef OrtStatus* (*ReleaseEpApiFactoryFn)(_In_ OrtEpFactory* factory);
+
+/**
+ * \brief The OrtEpFactory provides functions to create and manage execution providers.
+ * \since Version 1.22.
+ */
+struct OrtEpFactory {
+  /** \brief The ONNX Runtime version the execution provider was compiled with.
+   *
+   * Implementation should set to ORT_API_VERSION.
+   * ORT will use this to ensure it does not call functions that were not available when the library was compiled.
+   *
+   * \since Version 1.22.
+   */
+  uint32_t ort_version_supported;
+
+  /** \brief Get the name of the execution provider that the factory creates.
+   *
+   * \param[in] this_ptr The OrtEpFactory instance.
+   * \return The name of the execution provider the factory creates.
+   *
+   * \since Version 1.22.
+   */
+  const char*(ORT_API_CALL* GetName)(const OrtEpFactory* this_ptr);
+
+  /** \brief Get the name of the vendor who owns the execution provider that the factory creates.
+   *
+   * \param[in] this_ptr The OrtEpFactory instance.
+   * \return The vendor name of the execution provider the factory creates.
+   *
+   * \since Version 1.22.
+   */
+  const char*(ORT_API_CALL* GetVendor)(const OrtEpFactory* this_ptr);  // return EP vendor
+
+  /** \brief Get information from the execution provider if it supports the OrtHardwareDevice.
+   *
+   * \param[in] this_ptr The OrtEpFactory instance.
+   *                     Non-const as the factory is passed through to the CreateEp call via the OrtEpDevice.
+   * \param[in] devices The OrtHardwareDevice instances that are available.
+   * \param[in] num_devices The number of OrtHardwareDevice instances.
+   * \param[out] ep_devices OrtEpDevice instances for each OrtHardwareDevice that the EP can use.
+   *                        The implementation should call OrtEpApi::CreateEpDevice to create, and add the OrtEpDevice
+   *                        instances to this pre-allocated array. 
 ORT will take ownership of the values returned. + * i.e. usage is `ep_devices[0] = <created OrtEpDevice>;` + * \param[in] max_ep_devices The maximum number of OrtEpDevices that can be added to ep_devices. + * Current default is 8. This can be increased if needed. + * \param[out] num_ep_devices The number of EP devices added to ep_devices. + * \return true if the factory can create an execution provider that uses `device`. + * + * \note ORT will take ownership of ep_metadata and/or ep_options if they are not null. + * + * \since Version 1.22. + */ + OrtStatus*(ORT_API_CALL* GetSupportedDevices)(_In_ OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_ size_t num_devices, + _Inout_ OrtEpDevice** ep_devices, + _In_ size_t max_ep_devices, + _Out_ size_t* num_ep_devices); + + /** \brief Function to create an OrtEp instance for use in a Session. + * + * ORT will call ReleaseEp to release the instance when it is no longer needed. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[in] devices The OrtHardwareDevice instances that the execution provider was selected to use. + * \param[in] ep_metadata_pairs Execution provider metadata that was provided to OrtEpApi::CreateEpDevice, for each + * device. + * \param[in] num_devices The number of devices the execution provider was selected for. + * \param[in] session_options The OrtSessionOptions instance that contains the configuration options for the + * session. This will include ep_options from GetSupportedDevices as well as any + * user provided overrides. + * Execution provider options will have been added with a prefix of 'ep.[ep name].'. + * The OrtSessionOptions instance will NOT be valid after this call and should not be + * stored for later use. + * \param[in] logger The OrtLogger instance for the session that the execution provider should use for logging. + * \param[out] ep The OrtEp instance created by the factory. 
+ * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version [coming soon]. This is a placeholder. + */ + OrtStatus*(ORT_API_CALL* CreateEp)(_In_ OrtEpFactory* this_ptr, + _In_reads_(num_devices) const OrtHardwareDevice* const* devices, + _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, + _In_ size_t num_devices, + _In_ const OrtSessionOptions* session_options, + _In_ const OrtLogger* logger, _Outptr_ OrtEp** ep); + + /** \brief Release the OrtEp instance. + * + * \param[in] this_ptr The OrtEpFactory instance. + * \param[in] ep The OrtEp instance to release. + * + * \since Version [coming soon]. This is a placeholder. + */ + void(ORT_API_CALL* ReleaseEp)(OrtEpFactory* this_ptr, struct OrtEp* ep); +}; + +/* + * This is the old way to add the CUDA provider to the session, please use SessionOptionsAppendExecutionProvider_CUDA above to access the latest functionality + * This function always exists, but will only succeed if Onnxruntime was built with CUDA support and the CUDA provider shared library exists + * + * \param device_id CUDA device id, starts from zero. + */ +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, int device_id); + +/* + * This is the old way to add the ROCm provider to the session, please use + * SessionOptionsAppendExecutionProvider_ROCM above to access the latest functionality + * This function always exists, but will only succeed if Onnxruntime was built with + * HIP support and the ROCm provider shared library exists + * + * \param device_id HIP device id, starts from zero. 
+ */ +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_ROCM, _In_ OrtSessionOptions* options, int device_id); + +/* + * This is the old way to add the MIGraphX provider to the session, please use + * SessionOptionsAppendExecutionProvider_MIGraphX above to access the latest functionality + * This function always exists, but will only succeed if Onnxruntime was built with + * HIP support and the MIGraphX provider shared library exists + * + * \param device_id HIP device id, starts from zero. + */ +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtSessionOptions* options, int device_id); + +/* + * This is the old way to add the oneDNN provider to the session, please use + * SessionOptionsAppendExecutionProvider_oneDNN above to access the latest functionality + * This function always exists, but will only succeed if Onnxruntime was built with + * oneDNN support and the oneDNN provider shared library exists + * + * \param use_arena zero: false. non-zero: true. + */ +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, int use_arena); + +/* + * This is the old way to add the TensorRT provider to the session, please use SessionOptionsAppendExecutionProvider_TensorRT_V2 above to access the latest functionality + * This function always exists, but will only succeed if Onnxruntime was built with TensorRT support and the TensorRT provider shared library exists + * + * \param device_id CUDA device id, starts from zero. 
+ */ +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id); + +#ifdef __cplusplus +} +#endif +/// @} diff --git a/src/ort_include/core/util/thread_utils.h b/src/ort_include/core/util/thread_utils.h index b146c0d..b5e2516 100644 --- a/src/ort_include/core/util/thread_utils.h +++ b/src/ort_include/core/util/thread_utils.h @@ -7,7 +7,8 @@ #include #include -struct OrtThreadPoolParams { +struct OrtThreadPoolParams +{ // 0: Use default setting. (All the physical cores or half of the logical cores) // 1: Don't create thread pool // n: Create a thread pool with n threads. @@ -37,15 +38,16 @@ struct OrtThreadPoolParams { // meaning ith thread will be attached to first 8 logical processors std::string affinity_str; - const ORTCHAR_T* name = nullptr; + const ORTCHAR_T *name = nullptr; // Set or unset denormal as zero bool set_denormal_as_zero = false; }; -std::ostream& operator<<(std::ostream& os, const OrtThreadPoolParams& params); +std::ostream &operator<<(std::ostream &os, const OrtThreadPoolParams ¶ms); -struct OrtThreadingOptions { +struct OrtThreadingOptions +{ // Params for creating the threads that parallelizes execution of an op OrtThreadPoolParams intra_op_thread_pool_params; @@ -53,14 +55,17 @@ struct OrtThreadingOptions { OrtThreadPoolParams inter_op_thread_pool_params; }; -namespace onnxruntime { +namespace onnxruntime +{ -namespace concurrency { -enum class ThreadPoolType : uint8_t { - INTRA_OP, - INTER_OP -}; -std::unique_ptr CreateThreadPool(Env* env, OrtThreadPoolParams options, - ThreadPoolType tpool_type); -} // namespace concurrency -} // namespace onnxruntime + namespace concurrency + { + enum class ThreadPoolType : uint8_t + { + INTRA_OP, + INTER_OP + }; + std::unique_ptr CreateThreadPool(Env *env, OrtThreadPoolParams options, + ThreadPoolType tpool_type); + } // namespace concurrency +} // namespace onnxruntime diff --git a/tests/bench/bench_cast.cpp b/tests/bench/bench_cast.cpp new file mode 
100644 index 0000000..1dccbe4 --- /dev/null +++ b/tests/bench/bench_cast.cpp @@ -0,0 +1,54 @@ +#include "bench_util.h" +#include "core/mlas/lib/mlasi.h" + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + +void BM_ConvertF16ToF32(benchmark::State& state) { + bool aligned = static_cast(state.range(0)); + const size_t count = 1 << 18; + auto src = RandomVectorUniform(count, 0, 60000); + auto dst = std::vector(count + 16); + auto aligned_dst = (reinterpret_cast(dst.data()) + 15) & (~15); + float* dst_start = aligned ? reinterpret_cast(aligned_dst) + : reinterpret_cast(aligned_dst + 1); + + // Warm up + MlasCastF16ToF32KernelNeon(src.data(), dst_start, count); + + for (auto _ : state) { + MlasCastF16ToF32KernelNeon(src.data(), dst_start, count); + } +} + +void BM_ConvertF32ToF16(benchmark::State& state) { + bool aligned = static_cast(state.range(0)); + const size_t count = 1 << 18; + auto src = RandomVectorUniform(count, -30000.0f, 30000.0f); + auto dst = std::vector(count + 16); + auto aligned_dst = (reinterpret_cast(dst.data()) + 15) & (~15); + unsigned short* dst_start = aligned ? 
reinterpret_cast(aligned_dst) + : reinterpret_cast(aligned_dst + 1); + + // Warm up + MlasCastF32ToF16KernelNeon(src.data(), dst_start, count); + + for (auto _ : state) { + MlasCastF32ToF16KernelNeon(src.data(), dst_start, count); + } +} + +BENCHMARK(BM_ConvertF16ToF32) + ->UseRealTime() + ->Apply([](benchmark::internal::Benchmark* b) { + b->ArgNames({"aligned"}); + b->ArgsProduct({{0, 1}}); + }); + +BENCHMARK(BM_ConvertF32ToF16) + ->UseRealTime() + ->Apply([](benchmark::internal::Benchmark* b) { + b->ArgNames({"aligned"}); + b->ArgsProduct({{0, 1}}); + }); + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) diff --git a/tests/bench/bench_computesoftmax.cpp b/tests/bench/bench_computesoftmax.cpp index 57ab53c..32135b3 100644 --- a/tests/bench/bench_computesoftmax.cpp +++ b/tests/bench/bench_computesoftmax.cpp @@ -5,19 +5,7 @@ #include "core/util/thread_utils.h" #include "bench_util.h" -#ifndef BUILD_MLAS_NO_ONNXRUNTIME using onnxruntime::narrow; -#else -using gsl::narrow; -#define ORT_THROW(X) throw std::runtime_error(X) -#define ORT_ENFORCE(condition, ...) \ - do { \ - if (!(condition)) { \ - abort(); \ - } \ - } while (false) -#define ORT_THROW_EX(X) throw X(); -#endif struct RestrictAlignedPtr { float* ptr; // Aligned pointer within the underlying buffer diff --git a/tests/bench/bench_hgemm.cpp b/tests/bench/bench_hgemm.cpp new file mode 100644 index 0000000..f42c5e5 --- /dev/null +++ b/tests/bench/bench_hgemm.cpp @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "mlas.h" +#include "bench_util.h" +#include "core/util/thread_utils.h" + +#include +#include + +static const std::vector hgemm_bench_arg_names = {"M", "N", "K"}; + +void HGEMM(benchmark::State& state, bool transA, bool transB) { + if (state.range(0) <= 0) throw std::invalid_argument("M must greater than 0!"); + if (state.range(1) <= 0) throw std::invalid_argument("N must greater than 0!"); + if (state.range(2) <= 0) throw std::invalid_argument("K must greater than 0!"); + const size_t M = static_cast(state.range(0)); + const size_t N = static_cast(state.range(1)); + const size_t K = static_cast(state.range(2)); + + auto A = RandomVectorUniform(static_cast(M * K), MLAS_FP16(-1.0f), MLAS_FP16(1.0f)); + auto B = RandomVectorUniform(static_cast(N * K), MLAS_FP16(-1.0f), MLAS_FP16(1.0f)); + std::vector C(static_cast(M * N)); + + MLAS_FP16 alpha = MLAS_FP16(1.0f); + MLAS_FP16 beta = MLAS_FP16(0.0f); + OrtThreadPoolParams tpo; + tpo.thread_pool_size = 8; + tpo.auto_set_affinity = true; + std::unique_ptr tp( + onnxruntime::concurrency::CreateThreadPool(&onnxruntime::Env::Default(), + tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); + MlasGemm( + transA ? CblasTrans : CblasNoTrans, + transB ? CblasTrans : CblasNoTrans, + static_cast(M), + static_cast(N), + static_cast(K), + A.data(), + transA ? M : K, + B.data(), + transB ? K : N, + C.data(), + N, + alpha.val, + beta.val, + tp.get()); + + for (auto _ : state) { + MlasGemm( + transA ? CblasTrans : CblasNoTrans, + transB ? CblasTrans : CblasNoTrans, + static_cast(M), + static_cast(N), + static_cast(K), + A.data(), + transA ? M : K, + B.data(), + transB ? 
K : N, + C.data(), + N, + alpha.val, + beta.val, + tp.get()); + } +} + +static void GemmSizeWithOne(benchmark::internal::Benchmark* b) { + b->ArgNames(hgemm_bench_arg_names); + b->ArgsProduct({{1}, {63, 255, 1023}, {63, 255, 1023}}); + b->ArgsProduct({{63, 255, 1023}, {1}, {63, 255, 1023}}); + b->ArgsProduct({{63, 255, 1023}, {63, 255, 1023}, {1}}); +} +BENCHMARK_CAPTURE(HGEMM, GEMV_TransB, false, true)->Apply(GemmSizeWithOne)->UseRealTime(); +BENCHMARK_CAPTURE(HGEMM, GEMV_B, false, false)->Apply(GemmSizeWithOne)->UseRealTime(); + +static void GemmSizeProducts(benchmark::internal::Benchmark* b) { + b->ArgNames(hgemm_bench_arg_names); + b->ArgsProduct({{63, 255, 1023}, {63, 255, 1023}, {63, 255, 1023}}); +} +BENCHMARK_CAPTURE(HGEMM, NORMAL_TransB, false, true)->Apply(GemmSizeProducts)->UseRealTime(); +BENCHMARK_CAPTURE(HGEMM, NORMAL_B, false, false)->Apply(GemmSizeProducts)->UseRealTime(); + +static void GemmLLMSizeProducts(benchmark::internal::Benchmark* b) { + b->ArgNames(hgemm_bench_arg_names); + b->ArgsProduct({{1, 1024, 2048}, {4096, 11008}, {4096, 11008}}); +} +BENCHMARK_CAPTURE(HGEMM, LLM_TransB, false, true)->Apply(GemmLLMSizeProducts)->UseRealTime(); +BENCHMARK_CAPTURE(HGEMM, LLM_B, false, false)->Apply(GemmLLMSizeProducts)->UseRealTime(); diff --git a/tests/bench/bench_qnbitgemm.cpp b/tests/bench/bench_qnbitgemm.cpp index 64d2298..8ad3b59 100644 --- a/tests/bench/bench_qnbitgemm.cpp +++ b/tests/bench/bench_qnbitgemm.cpp @@ -31,8 +31,8 @@ void RunQNBitGemmBenchmark(size_t BlkLen, } size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes; - MlasBlockwiseQuantizedBufferSizes( - BlkBitWidth, static_cast(BlkLen), /* columnwise */ true, + MlasBlockwiseQuantizedBufferSizes( + static_cast(BlkLen), /* columnwise */ true, static_cast(K), static_cast(N), QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes); @@ -63,13 +63,13 @@ void RunQNBitGemmBenchmark(size_t BlkLen, tp.get()); std::unique_ptr Workspace; - if (const auto 
WorkspaceSize = MlasQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); + if (const auto WorkspaceSize = MlasQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, !Symmetric, ComputeType); WorkspaceSize > 0) { Workspace = std::make_unique(WorkspaceSize); } std::unique_ptr PackedQuantBData; - if (const auto PackedQuantBDataSize = MlasQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); + if (const auto PackedQuantBDataSize = MlasQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, !Symmetric, ComputeType); PackedQuantBDataSize > 0) { PackedQuantBData = std::make_unique(PackedQuantBDataSize); MlasQNBitGemmPackQuantBData(N, K, BlkBitWidth, BlkLen, ComputeType, QuantBData.data(), PackedQuantBData.get(), @@ -135,6 +135,7 @@ static void QNBitGemmArgs(benchmark::internal::Benchmark* b) { } BENCHMARK(QNBITGEMM)->Apply(QNBitGemmArgs)->UseRealTime(); +BENCHMARK(QNBITGEMM)->Apply(QNBitGemmArgs)->UseRealTime(); BENCHMARK(QNBITGEMM)->Apply(QNBitGemmArgs)->UseRealTime(); // This test gets benchmark arguments from environment variables. diff --git a/tests/bench/bench_rope.cpp b/tests/bench/bench_rope.cpp new file mode 100644 index 0000000..b0630b9 --- /dev/null +++ b/tests/bench/bench_rope.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "mlas.h" +#include "benchmark/benchmark.h" +#include "bench_util.h" +#include "core/framework/float16.h" + +using namespace onnxruntime; + +template +void RunRoPEBenchmark(size_t rotary_emb_dim, bool interleaved, benchmark::State& state) { + std::vector input(rotary_emb_dim); + size_t table_len = interleaved ? 
rotary_emb_dim / 2 : rotary_emb_dim; + std::vector sin_data(table_len); + std::vector cos_data(table_len); + std::vector output_ref(rotary_emb_dim), output_impl(rotary_emb_dim); + + for (size_t i = 0; i < rotary_emb_dim; ++i) { + input[i] = static_cast(i + 1.0f); + } + for (size_t i = 0; i < table_len; ++i) { + // https://arxiv.org/pdf/2104.09864 section 3.4.3 + float theta_i = static_cast(pow(10000, -2.0f * i / rotary_emb_dim)); + sin_data[i] = static_cast(std::sin(theta_i)); + cos_data[i] = static_cast(std::cos(theta_i)); + } + + // warm up run + MlasRotaryEmbedOneRow(&input[0], &sin_data[0], &cos_data[0], rotary_emb_dim, interleaved, &output_impl[0]); + + for (auto _ : state) { + MlasRotaryEmbedOneRow(&input[0], &sin_data[0], &cos_data[0], rotary_emb_dim, interleaved, &output_impl[0]); + } +} + +template +void RoPE(benchmark::State& state) { + using onnxruntime::narrow; + + const auto rotary_emb_dim = narrow(state.range(0)); + const auto interleaved = narrow(state.range(1)); + + RunRoPEBenchmark(rotary_emb_dim, interleaved, state); +} + +template +static void RoPEArgs(benchmark::internal::Benchmark* b) { + b->ArgNames({"rotary_emb_dim", "interleaved"}); + + b->ArgsProduct({ + {128, 256, 512, 1024}, // rotary_emb_dim + {int64_t{false}, int64_t{true}}, // interleaved + }); +} + +BENCHMARK(RoPE)->Apply(RoPEArgs)->UseRealTime(); +BENCHMARK(RoPE)->Apply(RoPEArgs)->UseRealTime(); diff --git a/tests/unittest/test_blockq4.cpp b/tests/unittest/test_blockq4.cpp index f75002f..7e8f128 100644 --- a/tests/unittest/test_blockq4.cpp +++ b/tests/unittest/test_blockq4.cpp @@ -18,192 +18,163 @@ Module Name: #include "test_util.h" #include "mlas_q4.h" +#include "mlasi.h" +template +int GetElem(int v, int idx) { + return (v >> (qbits * idx)) & ((1 << qbits) - 1); +} + +template +int SetElem(int v, int idx, int value) { + v &= ~(((1 << qbits) - 1) << (qbits * idx)); + v |= (value & ((1 << qbits) - 1)) << (qbits * idx); + return v; +} + +template class MlasBlockwiseQdqTest : public 
MlasTestBase { private: - MatrixGuardBuffer FpBuf; - MatrixGuardBuffer FpBuf2; + std::random_device rd; + std::mt19937 gen{192373}; + std::uniform_int_distribution dist_int{0, (1 << qbits) - 1}; + std::uniform_real_distribution dist_float{-1.f, 1.f}; + MatrixGuardBuffer FpBuf; + MatrixGuardBuffer FpBuf2; + MatrixGuardBuffer FpBuf3; MatrixGuardBuffer InputElements; - MatrixGuardBuffer InputScales; + MatrixGuardBuffer InputScales; MatrixGuardBuffer InputOffsets; - MatrixGuardBuffer OutputElements; - MatrixGuardBuffer OutputScales; - MatrixGuardBuffer OutputOffsets; MatrixGuardBuffer QDQOutputElements; - MatrixGuardBuffer QDQOutputScales; + MatrixGuardBuffer QDQOutputScales; MatrixGuardBuffer QDQOutputOffsets; MatrixGuardBuffer QDQTransposedOutputElements; - MatrixGuardBuffer QDQTransposedOutputScales; + MatrixGuardBuffer QDQTransposedOutputScales; MatrixGuardBuffer QDQTransposedOutputOffsets; + constexpr static float err_ = qbits == 8 ? 1e-2f : qbits == 4 ? 6e-2f + : 2e-1f; + constexpr static float rel_ = qbits == 8 ? 5e-2f : qbits == 4 ? 
2e-1f + : 5e-1f; + + bool FloatEqual(T a, T b, float err = err_, float rel = rel_) { + float va = static_cast(a); + float vb = static_cast(b); + return std::abs(va - vb) < err + std::abs(va) * rel; + } void Test(int rows, int columns, int block_size, bool columnwise, bool symmetric) { - float* dequant_buf = FpBuf.GetBuffer(rows * columns, true); - float* transposed = FpBuf2.GetBuffer(rows * columns, true); + constexpr int packSize = 8 / qbits; + T* input = FpBuf.GetFilledBuffer(rows * columns, [this](T* start, size_t size) { + for (size_t i = 0; i < size; i++) { + start[i] = T(this->dist_float(this->gen)); + } + }); + T* dequant = FpBuf2.GetBuffer(rows * columns, true); + T* transposed = FpBuf3.GetBuffer(rows * columns, true); size_t scale_size = (rows + block_size - 1) / block_size * columns; - size_t zp_size = (scale_size + 1) / 2; - + size_t zp_size = (scale_size + packSize - 1) / packSize; MLAS_THREADPOOL* threadpool_ptr = GetMlasThreadPool(); int meta_rows; int meta_cols; - MlasBlockwiseQuantMetaShape(block_size, columnwise, rows, columns, meta_rows, meta_cols); + MlasBlockwiseQuantMetaShape(block_size, columnwise, rows, columns, meta_rows, meta_cols); int q_rows; int q_cols; - MlasBlockwiseQuantizedShape(block_size, columnwise, rows, columns, q_rows, q_cols); + MlasBlockwiseQuantizedShape(block_size, columnwise, rows, columns, q_rows, q_cols); size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; - MlasBlockwiseQuantizedBufferSizes(4, block_size, columnwise, rows, columns, - q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); - - uint8_t* elements = InputElements.GetBuffer(q_data_size_in_bytes, true); - uint8_t* qdq_weights = QDQOutputElements.GetBuffer((rows * columns + 1) / 2, true); - uint8_t* qdq_weights_T = QDQTransposedOutputElements.GetBuffer(q_data_size_in_bytes, true); - - int v = 7; - for (int c = 0; c < columns; c++) { - for (int r = 0; r < rows; r += 2) { - int idx = c * q_rows + r / 2; - uint8_t v0 = static_cast(v); - v = (v + 5) % 
16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - uint8_t v1 = 0; - if (r + 1 < rows) { - v1 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - } - - elements[idx] = (v1 << 4) | v0; - } + MlasBlockwiseQuantizedBufferSizes(block_size, columnwise, rows, columns, + q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); + + uint8_t* elements = InputElements.GetBuffer(q_data_size_in_bytes, true); // after quantize + uint8_t* qdq_weights; + uint8_t* qdq_weights_T; + if constexpr (qbits == 4) { + qdq_weights = QDQOutputElements.GetBuffer((rows * columns + packSize - 1) / packSize, true); + qdq_weights_T = QDQTransposedOutputElements.GetBuffer(q_data_size_in_bytes, true); } - float* scales = InputScales.GetBuffer(q_scale_size); - float* qdq_scales = QDQOutputScales.GetBuffer(scale_size); - float* qdq_scales_T = QDQTransposedOutputScales.GetBuffer(q_scale_size); + T* scales = InputScales.GetBuffer(q_scale_size, true); uint8_t* zp = symmetric ? nullptr : InputOffsets.GetBuffer(q_zp_size_in_bytes, true); - uint8_t* qdq_zp = symmetric ? nullptr : QDQOutputOffsets.GetBuffer(zp_size, true); - uint8_t* qdq_zp_T = symmetric ? 
nullptr : QDQTransposedOutputOffsets.GetBuffer(q_zp_size_in_bytes, true); - if (zp) { - for (int c = 0; c < meta_cols; c++) { - for (int r = 0; r < meta_rows; r += 2) { - int idx = c * ((meta_rows + 1) / 2) + r / 2; - uint8_t v0 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - uint8_t v1 = 0; - if (r + 1 < meta_rows) { - v1 = static_cast(v); - v = (v + 5) % 16; - if (v == 11 || v == 7 || v == 3) { - // making the cycle 13 instead of 16, avoiding same values in a row - v = (v + 5) % 16; - } - } - zp[idx] = (v1 << 4) | v0; - } - } + T* qdq_scales; + T* qdq_scales_T; + uint8_t* qdq_zp; + uint8_t* qdq_zp_T; + if constexpr (qbits == 4) { + qdq_scales = QDQOutputScales.GetBuffer(scale_size, true); + qdq_scales_T = QDQTransposedOutputScales.GetBuffer(q_scale_size, true); + qdq_zp = symmetric ? nullptr : QDQOutputOffsets.GetBuffer(zp_size, true); + qdq_zp_T = symmetric ? nullptr : QDQTransposedOutputOffsets.GetBuffer(q_zp_size_in_bytes, true); } - MlasDequantizeBlockwise(dequant_buf, elements, scales, zp, block_size, - columnwise, rows, columns, threadpool_ptr); - - MlasTranspose(dequant_buf, transposed, columns, rows); - - uint8_t* o_elements = OutputElements.GetBuffer(q_rows * q_cols, true); - float* o_scales = OutputScales.GetBuffer(meta_rows * meta_cols); - uint8_t* o_zp = symmetric ? 
nullptr : OutputOffsets.GetBuffer(((meta_rows + 1) / 2) * meta_cols, true); - - MlasQuantizeBlockwise(o_elements, o_scales, o_zp, transposed, block_size, + MlasQuantizeBlockwise(elements, scales, zp, input, block_size, columnwise, rows, columns, columns, threadpool_ptr); - if (columnwise) { - bool signed_quant = MlasQDQQuantizeBlockwise( - transposed, qdq_scales, qdq_zp, qdq_weights, - true, rows, columns, block_size, threadpool_ptr); + MlasDequantizeBlockwise(dequant, elements, scales, zp, block_size, + columnwise, rows, columns, threadpool_ptr); - ASSERT_EQ(symmetric, signed_quant) << "symmetric quantization should be signed"; + MlasTranspose(dequant, transposed, columns, rows, threadpool_ptr); - if (symmetric) { - MlasQDQTransposeBlockwiseQuantized( - qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, + if constexpr (qbits == 4) { + if (columnwise) { + bool signed_quant = MlasQDQQuantizeBlockwise( + input, qdq_scales, qdq_zp, qdq_weights, true, rows, columns, block_size, threadpool_ptr); - } else { - MlasQDQTransposeBlockwiseQuantized( - qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, - true, rows, columns, block_size, threadpool_ptr); + ASSERT_EQ(symmetric, signed_quant) << "symmetric quantization should be signed"; + + if (symmetric) { + MlasQDQTransposeBlockwiseQuantized( + qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, + true, rows, columns, block_size, threadpool_ptr); + + } else { + MlasQDQTransposeBlockwiseQuantized( + qdq_weights, qdq_scales, qdq_zp, qdq_weights_T, qdq_scales_T, qdq_zp_T, + true, rows, columns, block_size, threadpool_ptr); + } } } - for (int c = 0; c < columns; c++) { - for (int r = 0; r < rows; r += 2) { - int idx = c * q_rows + r / 2; - ASSERT_EQ(o_elements[idx] & 0xf, elements[idx] & 0xf) + for (int r = 0; r < rows; r++) { + for (int c = 0; c < columns; c++) { + int idx = r * columns + c; + ASSERT_TRUE(FloatEqual(input[idx], transposed[idx])) + << " input: " << 
input[idx] << ", transposed: " << transposed[idx] << ", index=[" << r << "x" << c << "], shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if (columnwise) { - ASSERT_EQ(qdq_weights_T[idx] & 0xf, elements[idx] & 0xf) - << ", index=[" << r << "x" << c << "], shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - } + } + } - if (r + 1 < rows) { - ASSERT_EQ(o_elements[idx] >> 4, elements[idx] >> 4) - << ", index=[" << r + 1 << "x" << c << "], shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if (columnwise) { - ASSERT_EQ(qdq_weights_T[idx] >> 4, elements[idx] >> 4) - << ", index=[" << r + 1 << "x" << c << "], shape=[" << rows << "x" << columns + if (columnwise && qbits == 4) { + for (int c = 0; c < columns; c++) { + for (int r = 0; r < rows; r += packSize) { + int idx = c * q_rows + r / packSize; + for (int l = 0; l < packSize && l + r < rows; ++l) { + ASSERT_EQ(GetElem(qdq_weights_T[idx], l), GetElem(elements[idx], l)) + << ", qdq index=[" << r + l << "x" << c << "], shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; } } } - } - for (int c = 0; c < meta_cols; c++) { - for (int r = 0; r < meta_rows; r++) { - int idx = c * meta_rows + r; - ASSERT_EQ(o_scales[idx], scales[idx]) - << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - - if (columnwise) { - ASSERT_EQ(qdq_scales_T[idx], scales[idx]) - << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns + for (int c = 0; c < meta_cols; c++) { + for (int r = 0; r < meta_rows; r++) { + int idx = c * meta_rows + r; + ASSERT_TRUE(FloatEqual(qdq_scales_T[idx], scales[idx])) 
+ << ", qdq index=" << r << "x" << c << ", shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; } } - } - if (symmetric) return; - for (int c = 0; c < meta_cols; c++) { - for (int r = 0; r < meta_rows; r += 2) { - int idx = c * ((meta_rows + 1) / 2) + r / 2; - ASSERT_EQ(o_zp[idx] & 0xf, zp[idx] & 0xf) - << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if (columnwise) { - ASSERT_EQ(qdq_zp_T[idx] & 0xf, zp[idx] & 0xf) - << ", index=" << r << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - } - if (r + 1 < meta_rows) { - ASSERT_EQ(o_zp[idx] >> 4, zp[idx] >> 4) - << ", index=" << r + 1 << "x" << c << ", shape=[" << rows << "x" << columns - << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; - if (columnwise) { - ASSERT_EQ(qdq_zp_T[idx] >> 4, zp[idx] >> 4) - << ", index=" << r + 1 << "x" << c << ", shape=[" << rows << "x" << columns + if (symmetric) return; + for (int c = 0; c < meta_cols; c++) { + for (int r = 0; r < meta_rows; r += packSize) { + int idx = c * ((meta_rows + packSize - 1) / packSize) + r / packSize; + for (int l = 0; l < packSize && r + l < meta_rows; ++l) { + ASSERT_EQ(GetElem(qdq_zp_T[idx], l), GetElem(zp[idx], l)) + << ", qdq index=" << r + l << "x" << c << ", shape=[" << rows << "x" << columns << "] block: " << block_size << ", symmetric: " << symmetric << ", columnwise: " << columnwise; } } @@ -213,7 +184,7 @@ class MlasBlockwiseQdqTest : public MlasTestBase { public: static const char* GetTestSuiteName() { - static const std::string suite_name("BlockQ4"); + static const std::string suite_name("BlockQ" + std::to_string(qbits)); return suite_name.c_str(); } @@ -263,7 +234,9 @@ class MlasBlockwiseQdqTest : 
public MlasTestBase { static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { size_t count = 0; if (is_short_execute) { - count += MlasDirectShortExecuteTests::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); } return count; }); diff --git a/tests/unittest/test_eltwise.cpp b/tests/unittest/test_eltwise.cpp new file mode 100644 index 0000000..720ff37 --- /dev/null +++ b/tests/unittest/test_eltwise.cpp @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test_util.h" +#include "mlasi.h" +#include "eltwise.h" + +class MlasEltwiseAddTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferInputLeft; + MatrixGuardBuffer BufferInputRight; + MatrixGuardBuffer BufferOutput; + MatrixGuardBuffer BufferOutputReference; + MatrixGuardBuffer BufferInputLeftFp16; + MatrixGuardBuffer BufferInputRightFp16; + MatrixGuardBuffer BufferOutputFp16; + + void Test(size_t N, float MinimumValue, float MaximumValue, const std::optional& ScalarValue = std::nullopt) { + float* InputLeft = BufferInputLeft.GetBuffer(N); + float* InputRight = BufferInputRight.GetBuffer(N); + float* Output = BufferOutput.GetBuffer(N); + float* OutputReference = BufferOutputReference.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + + for (size_t n = 0; n < N; n++) { + InputLeft[n] = distribution(generator); + InputRight[n] = ScalarValue.value_or(distribution(generator)); + } + + for (size_t n = 0; n < N; n++) { + OutputReference[n] = InputLeft[n] + InputRight[n]; + } + + MlasEltwiseAdd(InputLeft, InputRight, Output, N); + + constexpr float AbsoluteTolerance = 1e-6f; + constexpr float RelativeTolerance = 1e-6f; + + for (size_t n = 0; n 
< N; n++) { + float diff = std::fabs(Output[n] - OutputReference[n]); + ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(OutputReference[n]) * RelativeTolerance) + << " @" << n << " of " << N << ", got: " << Output[n] << ", expecting: " << OutputReference[n]; + } + } + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + + void TestFp16(size_t N, float MinimumValue, float MaximumValue, const std::optional& ScalarValue = std::nullopt) { + MLAS_FP16* InputLeft = BufferInputLeftFp16.GetBuffer(N); + MLAS_FP16* InputRight = BufferInputRightFp16.GetBuffer(N); + MLAS_FP16* Output = BufferOutputFp16.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + + for (size_t n = 0; n < N; n++) { + InputLeft[n] = MLAS_FP16(distribution(generator)); + InputRight[n] = MLAS_FP16(ScalarValue.value_or(distribution(generator))); + } + + MlasEltwiseAdd(InputLeft, InputRight, Output, N); + + constexpr float AbsoluteTolerance = 5e-4f; + constexpr float RelativeTolerance = 1e-3f; + + for (size_t n = 0; n < N; n++) { + float inLeft = InputLeft[n].ToFloat(); + float inRight = InputRight[n].ToFloat(); + float ref = inLeft + inRight; + float out = Output[n].ToFloat(); + float diff = std::fabs(out - ref); + ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(ref) * RelativeTolerance) + << " @ " << inLeft << ", " << inRight << ", got: " << out << ", expecting: " << ref + << ", r-diff: " << diff / std::fabs(ref); + } + } + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + + public: + static const char* GetTestSuiteName() { + static const std::string suite_name("Eltwise_Add"); + return suite_name.c_str(); + } + + void ExecuteShort(void) override { + for (size_t n = 1; n < 128; n++) { + Test(n, -10.f, 10.f); + Test(n, -10.f, 10.f, -5000.f); +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + TestFp16(n, 
-17.f, 11.f); + TestFp16(n, -17.f, 11.f, -5000.f); +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + } + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + return count; +}); diff --git a/tests/unittest/test_halfgemm.h b/tests/unittest/test_halfgemm.h index 006b6a7..4db5c2b 100644 --- a/tests/unittest/test_halfgemm.h +++ b/tests/unittest/test_halfgemm.h @@ -15,7 +15,7 @@ Module Name: --*/ #pragma once -#include + #include "test_fp16.h" /** diff --git a/tests/unittest/test_hgemm_neon.cpp b/tests/unittest/test_hgemm_neon.cpp new file mode 100644 index 0000000..19d41fa --- /dev/null +++ b/tests/unittest/test_hgemm_neon.cpp @@ -0,0 +1,683 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_hgemm_neon.cpp + +Abstract: + + Tests for MLAS fp16 GEMM on ARM CPU. 
+ +--*/ + +#include +#include + +#include "test/mlas/unittest/test_util.h" +#include "mlasi.h" +#include "halfgemm.h" + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + +class MlasNeonHGemmPackBTest : public MlasTestBase { + private: + std::random_device rd_; + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_real_distribution distrib_; + MatrixGuardBuffer input_, ref_, packed_; + + template + MLAS_FORCEINLINE void PackB_TransposedB(const MLAS_FP16* src, MLAS_FP16* dst) { + size_t i = 0; + for (; i + 32 <= N; i += 32) { + for (size_t j = 0; j < K; ++j) { + for (size_t k = 0; k < 32; ++k) { + *dst = src[(i + k) * K + j]; + ++dst; + } + } + } + if (i + 16 <= N) { + for (size_t j = 0; j < K; ++j) { + for (size_t k = 0; k < 16; ++k) { + *dst = src[(i + k) * K + j]; + ++dst; + } + } + i += 16; + } + if (i + 8 <= N) { + for (size_t j = 0; j < K; ++j) { + for (size_t k = 0; k < 8; ++k) { + *dst = src[(i + k) * K + j]; + ++dst; + } + } + i += 8; + } + if (i < N) { + for (size_t j = 0; j < K; ++j) { + for (size_t k = 0; k < N - i; ++k) { + *dst = src[(i + k) * K + j]; + ++dst; + } + dst += 8 - (N - i); + } + } + } + + template + MLAS_FORCEINLINE void PackB_B(const MLAS_FP16* src, MLAS_FP16* dst) { + size_t i = 0; + for (; i + 32 <= N; i += 32) { + for (size_t j = 0; j < K; ++j) { + for (size_t k = 0; k < 32; ++k) { + *dst = src[(i + k) + j * N]; + ++dst; + } + } + } + for (; i + 16 <= N; i += 16) { + for (size_t j = 0; j < K; ++j) { + for (size_t k = 0; k < 16; ++k) { + *dst = src[(i + k) + j * N]; + ++dst; + } + } + } + if (i + 8 <= N) { + for (size_t j = 0; j < K; ++j) { + for (size_t k = 0; k < 8; ++k) { + *dst = src[(i + k) + j * N]; + ++dst; + } + } + i += 8; + } + if (i < N) { + for (size_t j = 0; j < K; ++j) { + for (size_t k = 0; k < N - i; ++k) { + *dst = src[(i + k) + j * N]; + ++dst; + } + dst += 8 - (N - i); + } + } + } + + template + MLAS_FORCEINLINE void Check(const MLAS_FP16* 
packed, const MLAS_FP16* ref) { + size_t j = 0; + for (; j + 31 < N; j += 32) { + for (size_t i = 0; i < 32 * K; ++i) { + ASSERT_EQ(packed[j * K + i].val, ref[j * K + i].val) + << " seed " << seed_ << " K " << i / 32 << " N " << j + i % 32; + } + } + for (; j + 15 < N; j += 16) { + for (size_t i = 0; i < 16 * K; ++i) { + ASSERT_EQ(packed[j * K + i].val, ref[j * K + i].val) + << " seed " << seed_ << " K " << i / 16 << " N " << j + i % 16; + } + } + if (j + 7 < N) { + for (size_t i = 0; i < 8 * K; ++i) { + ASSERT_EQ(packed[j * K + i].val, ref[j * K + i].val) + << " seed " << seed_ << " K " << i / 8 << " N " << j + i % 8; + } + j += 8; + } + if (j < N) { + for (size_t i = 0; i < K; ++i) { + for (size_t k = 0; k < N - j; ++k) { + ASSERT_EQ(packed[j * K + i * 8 + k].val, ref[j * K + i * 8 + k].val) + << " seed " << seed_ << " K " << i << " N " << j + k; + } + } + } + } + + template + void TestPackB_TransposedB() { + auto InitializeBuffer = [this](MLAS_FP16* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib_(gen_)); + } + }; + + const auto* input = input_.GetFilledBuffer(N * K, InitializeBuffer); + auto* packed = packed_.GetBuffer(K * ((N + 7) & ~7), true); + auto* ref = ref_.GetBuffer(K * ((N + 7) & ~7), true); + hgemm_neon::HPackB_TransposedB_Kernel(input, packed, N, K, K); + PackB_TransposedB(input, ref); + Check(packed, ref); + } + + template + void TestPackB_B() { + auto InitializeBuffer = [this](MLAS_FP16* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib_(gen_)); + } + }; + + const auto* input = input_.GetFilledBuffer(N * K, InitializeBuffer); + auto* packed = packed_.GetBuffer(K * ((N + 7) & ~7), true); + auto* ref = ref_.GetBuffer(K * ((N + 7) & ~7), true); + hgemm_neon::HPackB_B_Kernel(input, packed, N, K, N); + PackB_B(input, ref); + Check(packed, ref); + } + + public: + MlasNeonHGemmPackBTest() + : seed_(rd_()), gen_(seed_), distrib_(-100.f, 100.f) { + } + + static const 
char* GetTestSuiteName() { + return "NeonHGemmPackB"; + } + + void ExecuteShort(void) override { + TestPackB_TransposedB<1, 1>(); + TestPackB_TransposedB<1, 15>(); + TestPackB_TransposedB<1, 31>(); + TestPackB_TransposedB<8, 1>(); + TestPackB_TransposedB<8, 16>(); + TestPackB_TransposedB<9, 31>(); + TestPackB_TransposedB<9, 33>(); + TestPackB_TransposedB<31, 33>(); + TestPackB_TransposedB<33, 67>(); + TestPackB_TransposedB<63, 96>(); + TestPackB_TransposedB<271, 263>(); + TestPackB_B<1, 1>(); + TestPackB_B<1, 15>(); + TestPackB_B<1, 31>(); + TestPackB_B<8, 1>(); + TestPackB_B<8, 16>(); + TestPackB_B<9, 31>(); + TestPackB_B<9, 33>(); + TestPackB_B<31, 31>(); + TestPackB_B<63, 33>(); + TestPackB_B<33, 31>(); + TestPackB_B<33, 33>(); + TestPackB_B<65, 67>(); + TestPackB_B<65, 96>(); + TestPackB_B<271, 263>(); + } +}; + +class MlasNeonHGemmTransposedBTest : public MlasTestBase { + private: + std::random_device rd_; + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_real_distribution distrib_; + MatrixGuardBuffer A_, B_, ref_, C_; + + template + MLAS_FORCEINLINE void HGemm( + const MLAS_FP16* A, const MLAS_FP16* B, MLAS_FP16* C, MLAS_FP16 alpha, MLAS_FP16 beta, + size_t lda, size_t ldb, size_t ldc) { + float alphaf = alpha.ToFloat(); + float betaf = beta.ToFloat(); + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[m * lda + k].ToFloat()) * (B[n * ldb + k].ToFloat()); + } + C[m * ldc + n] = MLAS_FP16(accu * alphaf + C[m * ldc + n].ToFloat() * betaf); + } + } + } + + MLAS_FORCEINLINE + bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = v0.ToFloat(), f1 = v1.ToFloat(); + return std::abs(f0 - f1) <= std::abs(f1 * rtol) + atol; + } + + template + MLAS_FORCEINLINE void Check(const MLAS_FP16* C, const MLAS_FP16* ref, size_t ldc) { + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + size_t 
k = i * ldc + j; + ASSERT_TRUE(FloatEqual(C[k], ref[k], 0.02f, 0.055f)) + << " seed " << seed_ << " i " << i << " j " << j + << " M " << M << " N " << N << " ldc " << ldc + << " value " << C[k] << " ref " << ref[k]; + } + } + } + + template + MLAS_FORCEINLINE void Copy(const MLAS_FP16* C, MLAS_FP16* ref, size_t ldc) { + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + size_t k = i * ldc + j; + ref[k] = C[k]; + } + } + } + + template + void TestHGemm(MLAS_FP16 alpha, MLAS_FP16 beta) { + auto InitializeBuffer = [this](MLAS_FP16* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib_(gen_)); + } + }; + + const size_t lda = ((K + 7) & ~7); + const size_t ldb = ((K + 7) & ~7) + 8; + const size_t ldc = ((N + 7) & ~7); + const auto* A = A_.GetFilledBuffer(M * lda, InitializeBuffer); + const auto* B = B_.GetFilledBuffer(ldb * N, InitializeBuffer); + auto* C = C_.GetFilledBuffer(M * ldc, InitializeBuffer); + auto* ref = ref_.GetBuffer(M * ldc, true); + Copy(C, ref, ldc); + hgemm_neon::HGemm_TransposedB_Kernel(A, B, C, M, N, K, lda, ldb, ldc, alpha.val, beta.val); + HGemm(A, B, ref, alpha, beta, lda, ldb, ldc); + Check(C, ref, ldc); + } + + public: + MlasNeonHGemmTransposedBTest() + : seed_(1928375), gen_(seed_), distrib_(-1.f, 1.f) { + } + + static const char* GetTestSuiteName() { + return "NeonHGemmTransposedB"; + } + + void ExecuteShort(void) override { + TestHGemm<2, 1, 1>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<1, 1, 1>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 1, 1>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 15, 17>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 17, 15>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 17, 15>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 33, 31>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 31, 32>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 32, 33>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 78, 263>(MLAS_FP16(0.5f), 
MLAS_FP16(0.0f)); + TestHGemm<2, 267, 79>(MLAS_FP16(1.5f), MLAS_FP16(1.0f)); + } +}; + +class MlasNeonHGemmBTest : public MlasTestBase { + private: + std::random_device rd_; + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_real_distribution distrib_; + MatrixGuardBuffer A_, B_, ref_, C_; + + template + MLAS_FORCEINLINE void HGemm( + const MLAS_FP16* A, const MLAS_FP16* B, MLAS_FP16* C, MLAS_FP16 alpha, MLAS_FP16 beta, + size_t lda, size_t ldb, size_t ldc) { + float alphaf = alpha.ToFloat(); + float betaf = beta.ToFloat(); + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[m * lda + k].ToFloat()) * (B[n + k * ldb].ToFloat()); + } + C[m * ldc + n] = MLAS_FP16(accu * alphaf + C[m * ldc + n].ToFloat() * betaf); + } + } + } + + MLAS_FORCEINLINE + bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = v0.ToFloat(), f1 = v1.ToFloat(); + return std::abs(f0 - f1) <= std::abs(f1 * rtol) + atol; + } + + template + MLAS_FORCEINLINE void Check(const MLAS_FP16* C, const MLAS_FP16* ref, size_t ldc) { + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + size_t idx = i * ldc + j; + ASSERT_TRUE(FloatEqual(C[idx], ref[idx], 0.02f, 0.055f)) + << " seed " << seed_ << " i " << i << " j " << j + << " M " << M << " N " << N << " ldc " << ldc + << " value " << C[idx] << " ref " << ref[idx]; + } + } + } + + template + MLAS_FORCEINLINE void Copy(const MLAS_FP16* C, MLAS_FP16* ref, size_t ldc) { + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + size_t idx = i * ldc + j; + ref[idx] = C[idx]; + } + } + } + + template + void TestHGemm(MLAS_FP16 alpha, MLAS_FP16 beta) { + auto InitializeBuffer = [this](MLAS_FP16* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib_(gen_)); + } + }; + + const size_t lda = ((K + 7) & ~7); + const size_t ldb = ((N + 
7) & ~7) + 8; + const size_t ldc = ((N + 7) & ~7); + const auto* A = A_.GetFilledBuffer(M * lda, InitializeBuffer); + const auto* B = B_.GetFilledBuffer(K * ldb, InitializeBuffer); + auto* C = C_.GetFilledBuffer(M * ldc, InitializeBuffer); + auto* ref = ref_.GetBuffer(M * ldc, true); + Copy(C, ref, ldc); + hgemm_neon::HGemm_B_Kernel(A, B, C, M, N, K, lda, ldb, ldc, alpha.val, beta.val); + HGemm(A, B, ref, alpha, beta, lda, ldb, ldc); + Check(C, ref, ldc); + } + + public: + MlasNeonHGemmBTest() + : seed_(172387), gen_(seed_), distrib_(-1.f, 1.f) { + } + + static const char* GetTestSuiteName() { + return "NeonHGemmB"; + } + + void ExecuteShort(void) override { + TestHGemm<2, 1, 1>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<1, 1, 1>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 1, 1>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 15, 17>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 17, 15>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 17, 15>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 33, 31>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 31, 32>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 32, 33>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 78, 263>(MLAS_FP16(0.5f), MLAS_FP16(0.0f)); + TestHGemm<2, 267, 79>(MLAS_FP16(1.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 1, 1>(MLAS_FP16(1.0f), MLAS_FP16(1.0f)); + TestHGemm<1, 1, 1>(MLAS_FP16(1.f), MLAS_FP16(0.0f)); + TestHGemm<2, 1, 1>(MLAS_FP16(1.f), MLAS_FP16(0.f)); + TestHGemm<1, 15, 17>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 17, 15>(MLAS_FP16(1.f), MLAS_FP16(1.0f)); + TestHGemm<1, 17, 15>(MLAS_FP16(1.f), MLAS_FP16(1.f)); + TestHGemm<1, 33, 31>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 31, 32>(MLAS_FP16(1.f), MLAS_FP16(1.0f)); + TestHGemm<1, 32, 33>(MLAS_FP16(1.f), MLAS_FP16(0.f)); + TestHGemm<1, 78, 263>(MLAS_FP16(1.f), MLAS_FP16(0.0f)); + TestHGemm<2, 267, 79>(MLAS_FP16(1.f), MLAS_FP16(1.0f)); + TestHGemm<2, 65, 65>(MLAS_FP16(1.f), MLAS_FP16(0.0f)); + TestHGemm<2, 63, 
63>(MLAS_FP16(1.f), MLAS_FP16(0.0f)); + TestHGemm<2, 65, 63>(MLAS_FP16(1.f), MLAS_FP16(0.0f)); + TestHGemm<2, 63, 65>(MLAS_FP16(1.f), MLAS_FP16(0.0f)); + } +}; + +class MlasNeonHGemmPackedBTest : public MlasTestBase { + private: + std::random_device rd_; + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_real_distribution distrib_; + MatrixGuardBuffer A_, B_, ref_, C_; + + template + MLAS_FORCEINLINE void HGemm( + const MLAS_FP16* A, const MLAS_FP16* B, MLAS_FP16* C, MLAS_FP16 alpha, MLAS_FP16 beta, + size_t lda, size_t ldc) { + float alphaf = alpha.ToFloat(); + float betaf = beta.ToFloat(); + size_t n = 0; + for (; n + 32 <= N; n += 32) { + for (size_t i = 0; i < 32; ++i) { + for (size_t m = 0; m < M; ++m) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[m * lda + k].ToFloat()) * (B[n * K + k * 32 + i].ToFloat()); + } + C[m * ldc + n + i] = MLAS_FP16(accu * alphaf + C[m * ldc + n + i].ToFloat() * betaf); + } + } + } + for (; n + 16 <= N; n += 16) { + for (size_t i = 0; i < 16; ++i) { + for (size_t m = 0; m < M; ++m) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[m * lda + k].ToFloat()) * (B[n * K + k * 16 + i].ToFloat()); + } + C[m * ldc + n + i] = MLAS_FP16(accu * alphaf + C[m * ldc + n + i].ToFloat() * betaf); + } + } + } + if (n + 8 <= N) { + for (size_t i = 0; i < 8; ++i) { + for (size_t m = 0; m < M; ++m) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[m * lda + k].ToFloat()) * (B[n * K + k * 8 + i].ToFloat()); + } + C[m * ldc + n + i] = MLAS_FP16(accu * alphaf + C[m * ldc + n + i].ToFloat() * betaf); + } + } + n += 8; + } + if (n < N) { + for (size_t i = 0; i < N - n; ++i) { + for (size_t m = 0; m < M; ++m) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[m * lda + k].ToFloat()) * (B[n * K + k * 8 + i].ToFloat()); + } + C[m * ldc + n + i] = MLAS_FP16(accu * alphaf + C[m * ldc + n + i].ToFloat() * betaf); + } + } + } + 
} + + MLAS_FORCEINLINE + bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = v0.ToFloat(), f1 = v1.ToFloat(); + return std::abs(f0 - f1) <= std::abs(f1 * rtol) + atol; + } + + template + MLAS_FORCEINLINE void Check(const MLAS_FP16* C, const MLAS_FP16* ref, const size_t ldc) { + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + size_t k = i * ldc + j; + ASSERT_TRUE(FloatEqual(C[k], ref[k], 0.02f, 0.055f)) + << " seed " << seed_ << " i " << i << " j " << j + << " M " << M << " K " << K << " N " << N + << " value " << C[k] << " ref " << ref[k]; + } + } + } + + template + MLAS_FORCEINLINE void Copy(const MLAS_FP16* C, MLAS_FP16* ref, const size_t ldc) { + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + size_t k = i * ldc + j; + ref[k] = C[k]; + } + } + } + + template + void TestHGemm(MLAS_FP16 alpha, MLAS_FP16 beta) { + auto InitializeBuffer = [this](MLAS_FP16* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib_(gen_)); + } + }; + + const size_t lda = ((K + 7) & ~7) + 8; + const size_t ldc = ((N + 7) & ~7); + const auto* A = A_.GetFilledBuffer(M * lda, InitializeBuffer); + const auto* B = B_.GetFilledBuffer(K * ((N + 7) & ~7), InitializeBuffer); + auto* C = C_.GetFilledBuffer(M * ldc, InitializeBuffer); + auto* ref = ref_.GetBuffer(M * ldc, true); + Copy(C, ref, ldc); + hgemm_neon::HGemm_PackedB_Kernel(A, B, C, M, N, K, lda, ldc, alpha.val, beta.val); + HGemm(A, B, ref, alpha, beta, lda, ldc); + Check(C, ref, ldc); + } + + public: + MlasNeonHGemmPackedBTest() + : seed_(1928372), gen_(), distrib_(-1.f, 1.f) { + } + + static const char* GetTestSuiteName() { + return "NeonHGemmPackedB"; + } + + void ExecuteShort(void) override { + TestHGemm<2, 1, 1>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<1, 1, 1>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 1, 1>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 15, 17>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + 
TestHGemm<2, 17, 15>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 17, 15>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 33, 31>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 31, 32>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 32, 33>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 78, 263>(MLAS_FP16(0.5f), MLAS_FP16(0.0f)); + TestHGemm<2, 267, 79>(MLAS_FP16(1.5f), MLAS_FP16(1.0f)); + } +}; + +class MlasNeonHGemmTest : public MlasTestBase { + private: + std::random_device rd_; + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_real_distribution distrib_; + MatrixGuardBuffer A_, B_, ref_, C_; + + template + MLAS_FORCEINLINE void HGemm(const MLAS_FP16* A, const MLAS_FP16* B, MLAS_FP16* C, MLAS_FP16 alpha, MLAS_FP16 beta, + size_t lda, size_t ldb, size_t ldc) { + float alphaf = alpha.ToFloat(); + float betaf = beta.ToFloat(); + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + float accu = 0.0f; + for (size_t k = 0; k < K; ++k) { + accu += (A[transA ? k * lda + i : i * lda + k].ToFloat()) * (B[transB ? 
j * ldb + k : k * ldb + j].ToFloat()); + } + C[i * ldc + j] = MLAS_FP16(accu * alphaf + C[i * ldc + j].ToFloat() * betaf); + } + } + } + + MLAS_FORCEINLINE + bool FloatEqual(MLAS_FP16 v0, MLAS_FP16 v1, float rtol, float atol) { + float f0 = v0.ToFloat(), f1 = v1.ToFloat(); + return std::abs(f0 - f1) <= std::abs(f1 * rtol) + atol; + } + + template + MLAS_FORCEINLINE void Check(const MLAS_FP16* C, const MLAS_FP16* ref, const size_t ldc) { + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + ASSERT_TRUE(FloatEqual(C[i * ldc + j], ref[i * ldc + j], 0.02f, 0.055f)) + << " seed " << seed_ << " i " << i << " j " << j + << " M " << M << " K " << K << " N " << N + << " value " << C[i * ldc + j] << " ref " << ref[i * ldc + j]; + } + } + } + + template + MLAS_FORCEINLINE void Copy(const MLAS_FP16* C, MLAS_FP16* ref, const size_t ldc) { + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < N; ++j) { + ref[i * ldc + j] = C[i * ldc + j]; + } + } + } + + template + void TestHGemm(MLAS_FP16 alpha, MLAS_FP16 beta) { + auto InitializeBuffer = [this](MLAS_FP16* buffer, size_t count) { + for (size_t i = 0; i < count; i++) { + buffer[i] = MLAS_FP16(distrib_(gen_)); + } + }; + + const size_t lda = transA ? (M + 15) & (~15) : (K + 15) & (~15); + const size_t ldb = transB ? (K + 7) & (~7) : (N + 15) & (~15); + const size_t ldc = (N + 7) & (~7); + const auto* A = A_.GetFilledBuffer(transA ? lda * K : M * lda, InitializeBuffer); + const auto* B = B_.GetFilledBuffer(transB ? ldb * N : K * ldb, InitializeBuffer); + auto* C = C_.GetFilledBuffer(M * ldc, InitializeBuffer); + auto* ref = ref_.GetBuffer(M * ldc, true); + Copy(C, ref, ldc); + MlasGemm(transA ? CblasTrans : CblasNoTrans, transB ? 
CblasTrans : CblasNoTrans, + M, N, K, A, lda, B, ldb, C, ldc, alpha.val, beta.val, nullptr); + HGemm(A, B, ref, alpha, beta, lda, ldb, ldc); + Check(C, ref, ldc); + } + + public: + MlasNeonHGemmTest() + : seed_(192837), gen_(seed_), distrib_(-0.25f, 0.25f) { + } + + static const char* GetTestSuiteName() { + return "NeonHGemm"; + } + + // TODO(fajin): test beta + void ExecuteShort(void) override { + TestHGemm<2, 1, 1, false, true>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<1, 128, 512, false, true>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 128, 513, false, true>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 128, 511, false, true>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 129, 512, false, true>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 127, 512, false, true>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 513, 1023, false, true>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 511, 1025, false, true>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<127, 513, 1023, false, true>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<129, 511, 1025, false, true>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 1, 1, false, false>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<1, 128, 512, false, false>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 128, 513, false, false>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 128, 511, false, false>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<2, 129, 512, false, false>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<1, 127, 512, false, false>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<1, 513, 1023, false, false>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<2, 511, 1025, false, false>(MLAS_FP16(1.5f), MLAS_FP16(0.5f)); + TestHGemm<127, 513, 1023, false, false>(MLAS_FP16(1.0f), MLAS_FP16(0.0f)); + TestHGemm<129, 511, 1025, false, false>(MLAS_FP16(0.5f), MLAS_FP16(1.0f)); + TestHGemm<129, 513, 1025, false, false>(MLAS_FP16(0.5f), MLAS_FP16(0.5f)); + } +}; + +static UNUSED_VARIABLE bool 
added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + return count; +}); + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) diff --git a/tests/unittest/test_rope.cpp b/tests/unittest/test_rope.cpp new file mode 100644 index 0000000..9f08970 --- /dev/null +++ b/tests/unittest/test_rope.cpp @@ -0,0 +1,141 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_rope.h + +Abstract: + + Tests for MLAS RoPE. + +--*/ + +#include "test_util.h" +#include "mlas.h" +#include "core/framework/float16.h" +#include "rotary_embedding.h" + +using namespace onnxruntime; + +template +class MlasRoPETest : public MlasTestBase { + public: + void Test(size_t rotary_emb_dim, bool interleaved) { + std::vector input(rotary_emb_dim); + size_t table_len = interleaved ? 
rotary_emb_dim / 2 : rotary_emb_dim; + std::vector sin_data(table_len); + std::vector cos_data(table_len); + std::vector output_ref(rotary_emb_dim), output_impl(rotary_emb_dim); + + for (size_t i = 0; i < rotary_emb_dim; ++i) { + input[i] = static_cast(i + 1.0f); + } + for (size_t i = 0; i < table_len; ++i) { + // https://arxiv.org/pdf/2104.09864 section 3.4.3 + float theta_i = static_cast(pow(10000, -2.0f * i / rotary_emb_dim)); + sin_data[i] = static_cast(std::sin(theta_i)); + cos_data[i] = static_cast(std::cos(theta_i)); + } + + // Call the function + MlasRotaryEmbedOneRow_FallBack(&input[0], &sin_data[0], &cos_data[0], rotary_emb_dim, interleaved, &output_ref[0]); + MlasRotaryEmbedOneRow(&input[0], &sin_data[0], &cos_data[0], rotary_emb_dim, interleaved, &output_impl[0]); + + for (size_t i = 0; i < rotary_emb_dim; i++) { + ASSERT_TRUE(CloseEnough(output_impl[i], output_ref[i])) + << "Expected: " << output_ref[i] << " Actual: " << output_impl[i] << "@[" << i << "], " + << "rotary_emb_dim=" << rotary_emb_dim << ", interleaved=" << interleaved; + } + } +}; + +// +// Short Execute() test helper to register each test separately by all parameters. 
+// +template +class RoPEShortExecuteTest : public MlasTestFixture> { + public: + explicit RoPEShortExecuteTest(size_t rotary_emb_dim, bool interleaved) + : rotary_emb_dim_(rotary_emb_dim), + interleaved_(interleaved) {} + + void TestBody() override { + MlasTestFixture>::mlas_tester->Test(rotary_emb_dim_, interleaved_); + } + + static size_t RegisterSingleTest(size_t rotary_emb_dim, bool interleaved) { + size_t tests_registered = 0; + + std::string test_suite_name{"RoPE_"}; + if (std::is_same::value) { + test_suite_name += "fp32"; + } else if (std::is_same::value) { + test_suite_name += "fp16"; + } else { + ADD_FAILURE() << "Unknown type passed to test: " << test_suite_name; + return 0; // Return 0 since no test is registered + } + + std::stringstream ss; + ss << "/rotary_emb_dim" << rotary_emb_dim << "/interleaved" << interleaved; + auto test_name = ss.str(); + + testing::RegisterTest( + test_suite_name.c_str(), + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + // Important to use the fixture type as the return type here. 
+ [=]() -> MlasTestFixture>* { + return new RoPEShortExecuteTest(rotary_emb_dim, interleaved); + }); + + tests_registered += 1; + + return tests_registered; + } + + static size_t RegisterShortExecuteTests() { + size_t tests_registered = 0; + tests_registered += RegisterSingleTest(6, false); + tests_registered += RegisterSingleTest(6, true); + tests_registered += RegisterSingleTest(16, false); + tests_registered += RegisterSingleTest(16, true); + tests_registered += RegisterSingleTest(24, false); + tests_registered += RegisterSingleTest(24, true); + tests_registered += RegisterSingleTest(32, false); + tests_registered += RegisterSingleTest(32, true); + tests_registered += RegisterSingleTest(42, false); + tests_registered += RegisterSingleTest(42, true); + tests_registered += RegisterSingleTest(64, false); + tests_registered += RegisterSingleTest(64, true); + tests_registered += RegisterSingleTest(70, false); + tests_registered += RegisterSingleTest(70, true); + return tests_registered; + } + + private: + size_t rotary_emb_dim_; + bool interleaved_; +}; + +// only test float RoPE with avx2 where RopeDispatch is assigned at this moment. +#ifdef MLAS_TARGET_AMD64 +static size_t RoPERegisterAllShortExecuteTests() { + return RoPEShortExecuteTest::RegisterShortExecuteTests() + RoPEShortExecuteTest::RegisterShortExecuteTests(); +} + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister( + [](bool is_short_execute) -> size_t { + if (is_short_execute) { + return RoPERegisterAllShortExecuteTests(); + } + return 0; + }); +#endif diff --git a/tests/unittest/test_softcap.cpp b/tests/unittest/test_softcap.cpp new file mode 100644 index 0000000..3ad1d06 --- /dev/null +++ b/tests/unittest/test_softcap.cpp @@ -0,0 +1,112 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "test_util.h" +#include "mlasi.h" +#include "softmax.h" + +class MlasComputeTanhTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferInputFp16; + MatrixGuardBuffer BufferOutputFp16; + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + void TestFp16(size_t N, float MinimumValue, float MaximumValue) { + MLAS_FP16* Input = BufferInputFp16.GetBuffer(N); + MLAS_FP16* Output = BufferOutputFp16.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + + for (size_t n = 0; n < N; n++) { + Input[n] = MLAS_FP16(distribution(generator)); + } + + MlasComputeTanh(Input, Output, N); + + constexpr float AbsoluteTolerance = 5e-3f; + constexpr float RelativeTolerance = 5e-3f; + + for (size_t n = 0; n < N; n++) { + float in = Input[n].ToFloat(); + float ref = std::tanh(in); + float out = Output[n].ToFloat(); + float diff = std::fabs(out - ref); + ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(ref) * RelativeTolerance) + << " @ " << in << ", got: " << out << ", expecting: " << ref + << ", diff: " << diff << ", r-diff: " << diff / std::fabs(ref); + } + } +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + + public: + static const char* GetTestSuiteName() { + static const std::string suite_name("Tanh"); + return suite_name.c_str(); + } + + void ExecuteShort(void) override { + for (size_t n = 1; n < 128; n++) { +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + TestFp16(n, -3.51562f, 3.51562f); +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + } + } +}; + +class MlasComputeSoftcapTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferInputFp16; + MatrixGuardBuffer BufferOutputFp16; + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + void TestFp16(size_t N, float MinimumValue, float MaximumValue, 
float cap) { + MLAS_FP16* Input = BufferInputFp16.GetBuffer(N); + MLAS_FP16* Output = BufferOutputFp16.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + + for (size_t n = 0; n < N; n++) { + Input[n] = MLAS_FP16(distribution(generator)); + } + + MlasComputeSoftcap(Input, Output, N, MLAS_FP16(cap)); + + constexpr float AbsoluteTolerance = 5e-3f; + constexpr float RelativeTolerance = 5e-3f; + + for (size_t n = 0; n < N; n++) { + float in = Input[n].ToFloat(); + float ref = std::tanh(in / cap) * cap; + float out = Output[n].ToFloat(); + float diff = std::fabs(out - ref); + ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(ref) * RelativeTolerance) + << " @ " << in << ", got: " << out << ", expecting: " << ref << ", r-diff " << diff / std::fabs(ref); + } + } +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + + public: + static const char* GetTestSuiteName() { + static const std::string suite_name("Softcap"); + return suite_name.c_str(); + } + + void ExecuteShort(void) override { + for (size_t n = 1; n < 128; n++) { +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + TestFp16(n, -10.f, 10.f, 3.2f); +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + } + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + return count; +}); diff --git a/tests/unittest/test_softmax.cpp b/tests/unittest/test_softmax.cpp index fb4ebbe..df0c7f6 100644 --- a/tests/unittest/test_softmax.cpp +++ b/tests/unittest/test_softmax.cpp @@ -2,6 +2,126 @@ // Licensed under the MIT License. 
#include "test_util.h" +#include "mlasi.h" +#include "softmax.h" + +class MlasComputeExpTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferInput; + MatrixGuardBuffer BufferOutput; + MatrixGuardBuffer BufferOutputReference; + MatrixGuardBuffer BufferInputFp16; + MatrixGuardBuffer BufferOutputFp16; + + void Test(size_t N, float MinimumValue, float MaximumValue) { + float* Input = BufferInput.GetBuffer(N); + float* Output = BufferOutput.GetBuffer(N); + float* OutputReference = BufferOutputReference.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + + for (size_t n = 0; n < N; n++) { + Input[n] = distribution(generator); + } + + for (size_t n = 0; n < N; n++) { + OutputReference[n] = std::exp(Input[n]); + } + + MlasComputeExp(Input, Output, N); + + constexpr float AbsoluteTolerance = 1e-6f; + constexpr float RelativeTolerance = 1e-6f; + + for (size_t n = 0; n < N; n++) { + float diff = std::fabs(Output[n] - OutputReference[n]); + ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(OutputReference[n]) * RelativeTolerance) + << " @" << n << " of " << N << ", got: " << Output[n] << ", expecting: " << OutputReference[n]; + } + } + +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + + void TestFp16(size_t N, float MinimumValue, float MaximumValue) { + MLAS_FP16* Input = BufferInputFp16.GetBuffer(N); + MLAS_FP16* Output = BufferOutputFp16.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + + for (size_t n = 0; n < N; n++) { + Input[n] = MLAS_FP16(distribution(generator)); + } + + MlasComputeExp(Input, Output, N); + + constexpr float AbsoluteTolerance = 5e-4f; + constexpr float RelativeTolerance = 1e-3f; + + for (size_t n = 0; n < N; n++) { + float in = Input[n].ToFloat(); + float ref = std::exp(in); + float out = Output[n].ToFloat(); + 
float diff = std::fabs(out - ref); + ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(ref) * RelativeTolerance) + << " @ " << in << ", got: " << out << ", expecting: " << ref << ", r-diff: " << diff / std::fabs(ref); + } + } + + void TestSumFp16(size_t N, float MinimumValue, float MaximumValue) { + MLAS_FP16* Input = BufferInputFp16.GetBuffer(N); + MLAS_FP16* Output = BufferOutputFp16.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + + float max_val = std::numeric_limits::lowest(); + for (size_t n = 0; n < N; n++) { + Input[n] = MLAS_FP16(distribution(generator)); + max_val = std::fmax(max_val, Input[n].ToFloat()); + } + + const auto* dispatch = GetMlasPlatform().SoftmaxDispatch; + auto sum = dispatch->SumExp_Fp16(Input, Output, N, MLAS_FP16(-max_val)); + + constexpr float AbsoluteTolerance = 5e-4f; + constexpr float RelativeTolerance = 1e-3f; + + float sum_ref = 0.0f; + for (size_t n = 0; n < N; n++) { + float in = Input[n].ToFloat(); + float ref = std::exp(in - max_val); + sum_ref += ref; + float out = Output[n].ToFloat(); + float diff = std::fabs(out - ref); + ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(ref) * RelativeTolerance) + << " @ " << in << ", got: " << out << ", expecting: " << ref << ", r-diff: " << diff / std::fabs(ref); + } + + float diff = std::fabs(sum.ToFloat() - sum_ref); + ASSERT_TRUE(diff <= 1e-3f || diff <= std::fabs(sum_ref) * 5e-3f) + << " sum: " << sum.ToFloat() << ", expecting: " << sum_ref << ", r-diff: " << diff / std::fabs(sum_ref); + } + +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + + public: + static const char* GetTestSuiteName() { + static const std::string suite_name("Exp"); + return suite_name.c_str(); + } + + void ExecuteShort(void) override { + for (size_t n = 1; n < 128; n++) { + Test(n, -10.f, 10.f); +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && 
defined(MLAS_TARGET_ARM64) + TestFp16(n, -17.f, 11.f); + TestSumFp16(n, -10.f, 10.f); +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + } + } +}; template class MlasSoftmaxTest : public MlasTestBase { @@ -9,6 +129,8 @@ class MlasSoftmaxTest : public MlasTestBase { MatrixGuardBuffer BufferInput; MatrixGuardBuffer BufferOutput; MatrixGuardBuffer BufferOutputReference; + MatrixGuardBuffer BufferInputFp16; + MatrixGuardBuffer BufferOutputFp16; MLAS_THREADPOOL* threadpool_; void Test(size_t N, size_t D, float MinimumValue, float MaximumValue) { @@ -44,6 +166,64 @@ class MlasSoftmaxTest : public MlasTestBase { } } +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + void TestReduceMaxFp16(size_t N, float MinimumValue, float MaximumValue) { + MLAS_FP16* Input = BufferInputFp16.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + + float ref = std::numeric_limits::lowest(); + + for (size_t nd = 0; nd < N; nd++) { + Input[nd] = MLAS_FP16(distribution(generator)); + ref = std::fmax(ref, Input[nd].ToFloat()); + } + + const auto* dispatch = GetMlasPlatform().SoftmaxDispatch; + auto out = dispatch->ReduceMax_Fp16(Input, N).ToFloat(); + + constexpr float AbsoluteTolerance = 1e-3f; + constexpr float RelativeTolerance = 1e-3f; + + float diff = std::fabs(out - ref); + ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(ref) * RelativeTolerance) + << "ReduceMaxFp16: " << N << ", got: " << out << ", expecting: " << ref + << ", diff: " << diff << ", r-diff: " << diff / std::fabs(ref); + } + + void TestFp16(size_t N, size_t D, float MinimumValue, float MaximumValue, bool LogSoftmax, bool SmoothSoftmax) { + MLAS_FP16* Input = BufferInputFp16.GetBuffer(N * D); + MLAS_FP16* Output = BufferOutputFp16.GetBuffer(N * D); + float* InputReference = BufferInput.GetBuffer(N * D); + float* OutputReference = 
BufferOutputReference.GetBuffer(N * D); + + std::default_random_engine generator(static_cast(N * D)); + std::uniform_real_distribution distribution(MinimumValue, MaximumValue); + + for (size_t nd = 0; nd < N * D; nd++) { + Input[nd] = MLAS_FP16(distribution(generator)); + InputReference[nd] = Input[nd].ToFloat(); + } + + MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, threadpool_); + ReferenceSoftmax(InputReference, OutputReference, N, D, LogSoftmax, SmoothSoftmax); + + constexpr float AbsoluteTolerance = 5e-3f; + constexpr float RelativeTolerance = 5e-3f; + + for (size_t nd = 0; nd < N * D; nd++) { + float in = Input[nd].ToFloat(); + float ref = OutputReference[nd]; + float out = Output[nd].ToFloat(); + float diff = std::fabs(out - ref); + ASSERT_TRUE(diff <= AbsoluteTolerance || diff <= std::fabs(ref) * RelativeTolerance) + << "LogSoftmax:" << LogSoftmax << ", SmoothSoftmax: " << SmoothSoftmax << ", input " << in + << ", got: " << out << ", expecting: " << ref << ", diff: " << diff << ", r-diff: " << diff / std::fabs(ref); + } + } +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + void ReferenceSoftmax(const float* Input, float* Output, size_t N, size_t D, bool LogSoftmax, bool SmoothSoftmax) { for (size_t n = 0; n < N; n++) { float MaximumValue = std::numeric_limits::lowest(); @@ -99,11 +279,32 @@ class MlasSoftmaxTest : public MlasTestBase { void ExecuteShort(void) override { for (size_t d = 1; d < 128; d++) { Test(1, d, -10.f, 10.f); +#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + TestReduceMaxFp16(d, -10.f, 10.f); + TestFp16(1, d, -10.f, 10.f, false, true); + TestFp16(1, d, -10.f, 10.f, true, true); + TestFp16(1, d, -10.f, 10.f, false, false); + TestFp16(1, d, -10.f, 10.f, true, false); +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) } Test(3, 128, 20.f, 30.f); Test(63, 95, -150.f, 190.f); Test(16, 211, 20.f, 30.f); +#if 
defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) + TestFp16(3, 128, 3.f, 7.f, false, true); + TestFp16(3, 128, 3.f, 7.f, true, true); + TestFp16(3, 128, 3.f, 7.f, false, false); + TestFp16(3, 128, 3.f, 7.f, true, false); + TestFp16(63, 95, -15.f, 19.f, false, true); + TestFp16(63, 95, -15.f, 19.f, true, true); + TestFp16(63, 95, -15.f, 19.f, false, false); + TestFp16(63, 95, -15.f, 19.f, true, false); + TestFp16(16, 211, -7.f, -3.f, false, true); + TestFp16(16, 211, -7.f, -3.f, true, true); + TestFp16(16, 211, -7.f, -3.f, false, false); + TestFp16(16, 211, -7.f, -3.f, true, false); +#endif // defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) } }; @@ -111,6 +312,7 @@ static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_exe size_t count = 0; if (is_short_execute) { count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests::RegisterShortExecute(); if (GetMlasThreadPool() != nullptr) { count += MlasDirectShortExecuteTests>::RegisterShortExecute(); } diff --git a/tests/unittest/test_sq8bitgemm.cpp b/tests/unittest/test_sq8bitgemm.cpp new file mode 100644 index 0000000..1237cdb --- /dev/null +++ b/tests/unittest/test_sq8bitgemm.cpp @@ -0,0 +1,521 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + test_sq8bitgemm_neon.cpp + +Abstract: + + Tests for MatMul8Bits kernels on x86 CPU with input A type T1 fp32. 
+ +--*/ + +#include +#include + +#include "test_util.h" +#include "mlasi.h" +#include "core/mlas/inc/mlas_q4.h" +#include "qnbitgemm.h" +#include "mlas_qnbit.h" + +class MlasSQ8BitPrepackTest : public MlasTestBase { + private: + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_int_distribution distrib_u8_; + std::uniform_real_distribution distrib_f32_; + MatrixGuardBuffer inputB_, inputZp_, refB_, packedBuffer_; + MatrixGuardBuffer inputScale_, refScale_; + MatrixGuardBuffer inputBlkSum_, refBlkSum_; + + template + void PrepackB(const uint8_t* src, uint8_t* dst) { + constexpr size_t ldb = (K + BlkLen - 1) & (~(BlkLen - 1)); + size_t n = 0; + for (; n + 4 <= N; n += 4) { + size_t k = 0; + for (; k + SubBlkLen <= ldb; k += SubBlkLen) { + for (size_t i = 0; i < 4; ++i) { + std::copy(src + (n + i) * ldb + k, src + (n + i) * ldb + k + SubBlkLen, dst + n * ldb + 4 * k + i * SubBlkLen); + } + } + + for (size_t kk = 0; kk + k + BlkLen <= ldb; kk += BlkLen) { + for (size_t i = 0; i < 4; ++i) { + std::copy(src + (n + i) * ldb + k + kk, src + (n + i) * ldb + k + kk + BlkLen, dst + n * ldb + 4 * k + 4 * kk + i * BlkLen); + } + } + } + +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Waggressive-loop-optimizations" +#endif + for (; n < N; ++n) { + std::copy(src + n * ldb, src + n * ldb + ldb, dst + n * ldb); + } +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic pop +#endif + } + + template + void PrepackBlkSumAndScale(const float* scale, const uint8_t* zp, float* packedScale, float* blkSum) { + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + constexpr size_t BlkPerSubBlk = SubBlkLen > BlkLen ? 
SubBlkLen / BlkLen : 1; + + size_t n = 0; + for (; n + 4 <= N; n += 4) { + size_t k = 0; + for (; k + BlkPerSubBlk <= BlkCount; k += BlkPerSubBlk) { + for (size_t i = 0; i < 4; ++i) { + for (size_t j = 0; j < BlkPerSubBlk; ++j) { + auto srcOffset = (n + i) * BlkCount + k + j; + auto scaleDstOffset = n * BlkCount + 4 * k + i * BlkPerSubBlk + j; + auto sumDstOffset = (((n + i) / 16) * BlkCount + k + j) * 16 + (n + i) % 16; + + auto vSum = -scale[srcOffset] * (zp ? static_cast(zp[srcOffset]) : 128.f); + + packedScale[scaleDstOffset] = scale[srcOffset]; + blkSum[sumDstOffset] = vSum; + } + } + } + for (size_t kk = 0; k + kk < BlkCount; ++kk) { + for (size_t i = 0; i < 4; ++i) { + auto srcOffset = (n + i) * BlkCount + k + kk; + auto scaleDstOffset = n * BlkCount + 4 * k + 4 * kk + i; + auto sumDstOffset = (((n + i) / 16) * BlkCount + k + kk) * 16 + (n + i) % 16; + + auto vSum = -scale[srcOffset] * (zp ? static_cast(zp[srcOffset]) : 128.f); + + packedScale[scaleDstOffset] = scale[srcOffset]; + blkSum[sumDstOffset] = vSum; + } + } + } + +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Waggressive-loop-optimizations" +#endif + for (; n < N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + auto srcOffset = n * BlkCount + k; + auto scaleDstOffset = n * BlkCount + k; + auto sumDstOffset = (((n) / 16) * BlkCount + k) * 16 + (n) % 16; + + auto vSum = -scale[srcOffset] * (zp ? 
static_cast(zp[srcOffset]) : 128.f); + + packedScale[scaleDstOffset] = scale[srcOffset]; + blkSum[sumDstOffset] = vSum; + } + } +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic pop +#endif + } + + template + void CheckB(const uint8_t* packedB, const uint8_t* refB) { + size_t ldb = (K + BlkLen - 1) & (~(BlkLen - 1)); + size_t n = 0, N4 = N & (~3), ldbSub = ldb & (~(SubBlkLen - 1)); + for (; n < N4; ++n) { + size_t k = 0; + for (; k < ldbSub && k < K; ++k) { + size_t idx = (n & (~3)) * ldb + (k & (~(SubBlkLen - 1))) * 4 + (n & 3) * SubBlkLen + (k & (SubBlkLen - 1)); + ASSERT_EQ(packedB[idx], refB[idx]) + << " n " << n << " k " << k; + } + for (; k < K; ++k) { + size_t idx = (n & (~3)) * ldb + (k & (~(BlkLen - 1))) * 4 + (n & 3) * BlkLen + (k & (BlkLen - 1)); + ASSERT_EQ(packedB[idx], refB[idx]) + << " n " << n << " k " << k; + } + } + + for (; n < N; ++n) { + for (size_t k = 0; k < K; ++k) { + ASSERT_EQ(packedB[n * ldb + k], refB[n * ldb + k]) + << " n " << n << " k " << k; + } + } + } + + template + void CheckScale(const float* packedScale, const float* refScale) { + size_t BlkCount = (K + BlkLen - 1) / BlkLen; + size_t BlkPerSubBlk = SubBlkLen > BlkLen ? 
SubBlkLen / BlkLen : 1; + size_t n = 0, N4 = N & (~3), BlkCountSub = BlkCount & (~(BlkPerSubBlk - 1)); + + for (; n < N4; ++n) { + size_t k = 0; + for (; k < BlkCountSub; ++k) { + size_t idx = (n & (~3)) * BlkCount + (k & (~(BlkPerSubBlk - 1))) * 4 + (n & 3) * BlkPerSubBlk + (k & (BlkPerSubBlk - 1)); + ASSERT_EQ(packedScale[idx], refScale[idx]) + << " n " << n << " k " << k; + } + for (; k < BlkCount; ++k) { + size_t idx = (n & (~3)) * BlkCount + k * 4 + (n & 3); + ASSERT_EQ(packedScale[idx], refScale[idx]) + << " n " << n << " k " << k; + } + } + + for (; n < N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + ASSERT_EQ(packedScale[n * BlkCount + k], refScale[n * BlkCount + k]) + << " n " << n << " k " << k; + } + } + } + + template + void CheckBlkSum(const float* packedBlkSum, const float* refBlkSum) { + size_t BlkCount = (K + BlkLen - 1) / BlkLen; + + for (size_t n = 0; n < N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t idx = (((n) / 16) * BlkCount + k) * 16 + (n) % 16; + ASSERT_EQ(packedBlkSum[idx], refBlkSum[idx]) + << " n " << n << " k " << k; + } + } + } + + template + void TestPrepack() { + if (!MlasIsQNBitGemmAvailable(8, BlkLen, SQNBIT_CompInt8)) return; + + constexpr size_t Bits = 8; + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + constexpr size_t Ldb = (((K + BlkLen - 1) & (~(BlkLen - 1))) * Bits + 7) / 8; + constexpr size_t PackBCount = N * Ldb; + constexpr size_t ScaleCount = BlkCount * N; + const size_t BufferSize = MlasQNBitGemmPackQuantBDataSize(N, K, Bits, BlkLen, hasZp, SQNBIT_CompInt8); + + const auto* inputB = inputB_.GetFilledBuffer(PackBCount, [this](uint8_t* p, size_t t) { + for (size_t i = 0; i < t; i++) { + p[i] = static_cast(this->distrib_u8_(this->gen_)); + } + }); + + const auto* inputScale = inputScale_.GetFilledBuffer(ScaleCount, [this](float* p, size_t t) { + for (size_t i = 0; i < t; i++) { + p[i] = this->distrib_f32_(this->gen_); + } + }); + + const auto* inputZp = hasZp ? 
inputZp_.GetFilledBuffer(ScaleCount, [this](uint8_t* p, size_t t) { + for (size_t i = 0; i < t; i++) { + p[i] = static_cast(this->distrib_u8_(this->gen_)); + } + }) + : nullptr; + + auto* packedBuffer = packedBuffer_.GetBuffer(BufferSize, true); + auto* refB = refB_.GetBuffer(PackBCount, true); + auto* refScale = refScale_.GetBuffer(ScaleCount, true); + auto* refBlkSum = refBlkSum_.GetBuffer(((N + 15) & (~15)) * BlkCount, true); + + MlasQNBitGemmPackQuantBData( + N, K, Bits, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, inputB, packedBuffer, + inputScale, hasZp, nullptr, nullptr); + MlasQNBitGemmPackQuantBData( + N, K, Bits, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer, + inputScale, hasZp, nullptr, nullptr); + MlasQNBitGemmPackQuantBData( + N, K, Bits, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer, + nullptr, hasZp, inputZp, nullptr); + + PackedQuantBDataStruct packedQuantB(packedBuffer, N, BlkCount, BlkLen); + + PrepackB(inputB, refB); + PrepackBlkSumAndScale(inputScale, inputZp, refScale, refBlkSum); + + CheckB(refB, reinterpret_cast(packedQuantB.PackedQuantBData)); + CheckScale(refScale, packedQuantB.PackedQuantBScale); + CheckBlkSum(refBlkSum, packedQuantB.QuantBBlkSum); + } + + public: + MlasSQ8BitPrepackTest() + : seed_(19287), gen_(seed_), distrib_u8_(0, 255), distrib_f32_(-10.f, 10.f) { + } + + static const char* GetTestSuiteName() { + return "SQ8BitPrepack"; + } + + template + void Execute(void) { + TestPrepack(); + TestPrepack(); + } + + void ExecuteShort(void) override { + auto& platform = GetMlasPlatform(); + + if (platform.Avx512Supported_) { + Execute<1, 1, 16, 128>(); + Execute<1, 1, 32, 128>(); + Execute<1, 1, 64, 128>(); + Execute<1, 1, 128, 128>(); + Execute<1, 1, 256, 128>(); + + Execute<16, 4, 16, 128>(); + Execute<32, 4, 16, 128>(); + Execute<64, 4, 16, 128>(); + Execute<128, 4, 16, 128>(); + + Execute<15, 5, 16, 128>(); + Execute<15, 5, 32, 128>(); + Execute<15, 5, 64, 
128>(); + Execute<15, 5, 128, 128>(); + Execute<15, 5, 256, 128>(); + + Execute<17, 8, 16, 128>(); + Execute<17, 8, 32, 128>(); + Execute<17, 8, 64, 128>(); + Execute<17, 8, 128, 128>(); + Execute<17, 8, 256, 128>(); + + Execute<256, 16, 16, 128>(); + Execute<257, 17, 32, 128>(); + Execute<255, 15, 64, 128>(); + Execute<256, 17, 128, 128>(); + Execute<257, 16, 256, 128>(); + } else { + Execute<1, 1, 16, 64>(); + Execute<1, 1, 32, 64>(); + Execute<1, 1, 64, 64>(); + Execute<1, 1, 128, 64>(); + Execute<1, 1, 256, 64>(); + + Execute<16, 4, 16, 64>(); + Execute<32, 4, 16, 64>(); + Execute<64, 4, 16, 64>(); + Execute<128, 4, 16, 64>(); + + Execute<15, 5, 16, 64>(); + Execute<15, 5, 32, 64>(); + Execute<15, 5, 64, 64>(); + Execute<15, 5, 128, 64>(); + Execute<15, 5, 256, 64>(); + + Execute<17, 8, 16, 64>(); + Execute<17, 8, 32, 64>(); + Execute<17, 8, 64, 64>(); + Execute<17, 8, 128, 64>(); + Execute<17, 8, 256, 64>(); + + Execute<159, 16, 16, 64>(); + Execute<160, 17, 32, 64>(); + Execute<161, 15, 64, 64>(); + Execute<160, 17, 128, 64>(); + Execute<159, 16, 256, 64>(); + } + } +}; + +class MlasSQ8BitGemmKernelTest : public MlasTestBase { + private: + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_real_distribution distrib_f32_; + MatrixGuardBuffer packedBuffer_, workspace_, packedB_, Zp_; + MatrixGuardBuffer A_, B_, C_, ref_, bias_, scale_; + + bool FloatEqual(float v0, float v1, float rtol, float atol) { + return std::abs(v0 - v1) <= std::abs(v1 * rtol) + atol; + } + + template + void MatMul(const float* A, size_t lda, const float* B, const float* bias, float* C, size_t ldc) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float accu = bias ? 
bias[n] : 0.0f; + for (size_t k = 0; k < K; ++k) { + float a = A[m * lda + k]; + float b = B[n * K + k]; + accu += a * b; + } + C[m * ldc + n] = accu; + } + } + } + + template + void Check(const float* target, const float* ref, size_t ldc, float rtol, float atol) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + size_t i = m * ldc + n; + ASSERT_TRUE(FloatEqual(target[i], ref[i], rtol, atol)) + << " M " << M << " K " << K << " N " << N << " BlkLen " << BlkLen + << " v0 " << target[i] << " v1 " << ref[i] + << " m " << m << " n " << n; + } + } + } + + template + void TestSQ8BitGemmKernel() { + if (!MlasIsQNBitGemmAvailable(8, BlkLen, SQNBIT_CompInt8)) return; + + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + constexpr size_t ldb = BlkCount * BlkLen; + constexpr size_t lda = ldb; + constexpr size_t ldc = (N + 15) & (~15); + const auto* A = A_.GetFilledBuffer(M * lda, [this](float* p, size_t t) { + for (size_t i = 0; i < t; i++) { + p[i] = this->distrib_f32_(this->gen_); + } + }); + + auto* B = B_.GetFilledBuffer(K * N, [this](float* p, size_t t) { + for (size_t i = 0; i < t; i++) { + p[i] = this->distrib_f32_(this->gen_); + } + }); + + int q_rows, q_cols; + MlasBlockwiseQuantizedShape((int)BlkLen, true, (int)K, (int)N, q_rows, q_cols); + + size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; + MlasBlockwiseQuantizedBufferSizes<8>((int)(BlkLen), true, (int)K, (int)N, + q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); + + auto* inputB = packedB_.GetBuffer(q_data_size_in_bytes, true); + auto* inputScale = scale_.GetBuffer(q_scale_size, true); + auto* inputZp = HasZp ? 
Zp_.GetBuffer(q_zp_size_in_bytes, true) : nullptr; + + MlasQuantizeBlockwise( + inputB, + inputScale, + inputZp, + B, + BlkLen, + true, + K, + N, + N, + nullptr); + + MlasDequantizeBlockwise( + B, + inputB, + inputScale, + inputZp, + BlkLen, + true, + K, + N, + nullptr); + + size_t bufferSize = MlasQNBitGemmPackQuantBDataSize(N, K, 8, BlkLen, HasZp, SQNBIT_CompInt8); + auto* packedBuffer = packedBuffer_.GetBuffer(bufferSize, true); + + MlasQNBitGemmPackQuantBData( + N, K, 8, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, inputB, packedBuffer, + inputScale, HasZp, nullptr, nullptr); + MlasQNBitGemmPackQuantBData( + N, K, 8, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer, + inputScale, HasZp, nullptr, nullptr); + MlasQNBitGemmPackQuantBData( + N, K, 8, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer, + nullptr, HasZp, inputZp, nullptr); + + PackedQuantBDataStruct packedQuantB(packedBuffer, N, BlkCount, BlkLen); + + auto* C = C_.GetBuffer(M * ldc, true); + auto* ref = ref_.GetBuffer(M * ldc, true); + + auto* bias = HasBias ? 
bias_.GetFilledBuffer(N, [this](float* p, size_t t) { + for (size_t i = 0; i < t; i++) { + p[i] = this->distrib_f32_(this->gen_); + } + }) + : nullptr; + + const size_t workspace_size = MlasQNBitGemmBatchWorkspaceSize(M, N, K, 1, 8, BlkLen, HasZp, SQNBIT_CompInt8); + auto* workspace = workspace_.GetBuffer(workspace_size, true); + + MLAS_QNBIT_GEMM_DATA_PARAMS data; + data.A = A; + data.lda = lda; + data.QuantBDataWorkspace = packedBuffer; + data.PackedQuantBData = packedQuantB.PackedQuantBData; + data.QuantBScale = inputScale; + data.QuantBZeroPoint = inputZp; + data.Bias = bias; + data.C = C; + data.ldc = ldc; + + MlasQNBitGemmBatch(M, N, K, 1, 8, BlkLen, SQNBIT_CompInt8, &data, workspace, nullptr); + + MatMul(A, lda, B, bias, ref, ldc); + Check(C, ref, ldc, 0.01f, 0.02f); + } + + public: + MlasSQ8BitGemmKernelTest() + : seed_(1234), gen_(seed_), distrib_f32_(-0.25f, 0.25f) { + } + + static const char* GetTestSuiteName() { + return "SQ8BitGemmKernel"; + } + + template + void Execute(void) { + TestSQ8BitGemmKernel(); + TestSQ8BitGemmKernel(); + TestSQ8BitGemmKernel(); + TestSQ8BitGemmKernel(); + } + + void ExecuteShort(void) override { + Execute<1, 1, 1, 16>(); + Execute<7, 128, 4, 16>(); + Execute<8, 497, 5, 16>(); + Execute<1, 3072, 128, 16>(); + Execute<2, 3072, 128, 16>(); + + Execute<1, 1, 1, 32>(); + Execute<8, 33, 5, 32>(); + Execute<8, 513, 9, 32>(); + Execute<1, 3072, 128, 32>(); + Execute<2, 3072, 128, 32>(); + + Execute<1, 1, 1, 64>(); + Execute<8, 497, 9, 64>(); + Execute<1, 3072, 128, 64>(); + Execute<2, 3072, 128, 64>(); + + Execute<1, 1, 1, 128>(); + Execute<6, 255, 7, 128>(); + Execute<5, 257, 9, 128>(); + Execute<1, 3072, 128, 128>(); + Execute<2, 3072, 128, 128>(); + + Execute<1, 1, 1, 256>(); + Execute<7, 255, 7, 256>(); + Execute<6, 257, 7, 256>(); + Execute<1, 3072, 128, 256>(); + Execute<2, 3072, 128, 256>(); + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if 
(is_short_execute) { + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + count += MlasDirectShortExecuteTests::RegisterShortExecute(); + } + return count; +}); diff --git a/tests/unittest/test_sqnbitgemm.cpp b/tests/unittest/test_sqnbitgemm.cpp index e22018a..91ce359 100644 --- a/tests/unittest/test_sqnbitgemm.cpp +++ b/tests/unittest/test_sqnbitgemm.cpp @@ -246,9 +246,9 @@ class MlasSQNBitGemmTest : public MlasTestBase { uint8_t* QuantBZeroPoint = nullptr; { size_t QuantBDataSizeInBytes, QuantBScaleSize, QuantBZeroPointSizeInBytes; - MlasBlockwiseQuantizedBufferSizes(BlkBitWidth, BlkLen, /* columnwise */ true, - static_cast(K), static_cast(N), - QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes); + MlasBlockwiseQuantizedBufferSizes(BlkLen, /* columnwise */ true, + static_cast(K), static_cast(N), + QuantBDataSizeInBytes, QuantBScaleSize, &QuantBZeroPointSizeInBytes); QuantBData = BufferQuantBData.GetBuffer(QuantBDataSizeInBytes); QuantBScale = BufferQuantBScale.GetBuffer(QuantBScaleSize); @@ -265,13 +265,13 @@ class MlasSQNBitGemmTest : public MlasTestBase { } void* Workspace = nullptr; - if (const auto WorkspaceSize = MlasQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, ComputeType); + if (const auto WorkspaceSize = MlasQNBitGemmBatchWorkspaceSize(M, N, K, 1, BlkBitWidth, BlkLen, !Symmetric, ComputeType); WorkspaceSize > 0) { Workspace = BufferWorkspace.GetBuffer(WorkspaceSize); } void* PackedQuantBDataWorkspace = nullptr; - if (const auto PackedQuantBDataSize = MlasQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, ComputeType); + if (const auto PackedQuantBDataSize = MlasQNBitGemmPackQuantBDataSize(N, K, BlkBitWidth, BlkLen, !Symmetric, ComputeType); PackedQuantBDataSize > 0) { PackedQuantBDataWorkspace = BufferPackedQuantBData.GetBuffer(PackedQuantBDataSize); bool has_zp_input = QuantBZeroPoint != nullptr; diff --git a/tests/unittest/test_sqnbitgemm_neon_fp16.cpp b/tests/unittest/test_sqnbitgemm_neon_fp16.cpp 
index 243752b..772c4a9 100644 --- a/tests/unittest/test_sqnbitgemm_neon_fp16.cpp +++ b/tests/unittest/test_sqnbitgemm_neon_fp16.cpp @@ -17,7 +17,7 @@ Module Name: #include #include "test_util.h" -#include "core/mlas/lib/mlasi.h" +#include "mlasi.h" #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) diff --git a/tests/unittest/test_transpose.cpp b/tests/unittest/test_transpose.cpp index 8fa9841..2f446ec 100644 --- a/tests/unittest/test_transpose.cpp +++ b/tests/unittest/test_transpose.cpp @@ -3,12 +3,13 @@ #include "test_util.h" -template +template class MlasTransposeTest : public MlasTestBase { private: MatrixGuardBuffer BufferInput; MatrixGuardBuffer BufferOutput; MatrixGuardBuffer BufferOutputReference; + MLAS_THREADPOOL* threadpool_; void Test(size_t M, size_t N) { @@ -16,7 +17,7 @@ class MlasTransposeTest : public MlasTestBase { ElementType* Output = BufferOutput.GetBuffer(M * N); ElementType* OutputReference = BufferOutputReference.GetBuffer(M * N); - MlasTranspose(Input, Output, M, N); + MlasTranspose(Input, Output, M, N, threadpool_); ReferenceTranspose(Input, OutputReference, M, N); ASSERT_EQ(memcmp(Output, OutputReference, M * N * sizeof(ElementType)), 0) << " [" << M << "," << N << "]"; @@ -31,11 +32,23 @@ class MlasTransposeTest : public MlasTestBase { } public: + MlasTransposeTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) {} + static const char* GetTestSuiteName() { - static const std::string suite_name = std::string("Transpose_Size") + std::to_string(int(sizeof(ElementType))); + static const std::string suite_name = std::string("Transpose_") + + GetTypeString() + + std::string(Threaded ? 
"_Threaded" : "_SingleThread"); return suite_name.c_str(); } + static const std::string GetTypeString() { + if (std::is_same::value) return std::string("FP32"); + if (std::is_same::value) return std::string("U32"); + if (std::is_same::value) return std::string("U16"); + if (std::is_same::value) return std::string("U8"); + return std::string("unknown"); + } + void ExecuteShort(void) override { for (size_t m = 1; m <= 32; m++) { for (size_t n = 1; n <= 32; n++) { @@ -48,9 +61,14 @@ class MlasTransposeTest : public MlasTestBase { static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { size_t count = 0; if (is_short_execute) { - count += MlasDirectShortExecuteTests>::RegisterShortExecute(); - count += MlasDirectShortExecuteTests>::RegisterShortExecute(); - count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); } return count; }); diff --git a/tests/unittest/test_util.h b/tests/unittest/test_util.h index 94c50c9..a000e35 100644 --- a/tests/unittest/test_util.h +++ b/tests/unittest/test_util.h @@ -14,8 +14,6 @@ #include #include #include -#include - #if defined(_WIN32) #include #else @@ -37,40 +35,6 @@ #define _countof(_Array) (sizeof(_Array) / sizeof(_Array[0])) #endif -#ifndef MLAS_THROW_EX -#ifdef MLAS_NO_EXCEPTION - -MLAS_FORCEINLINE -void -MlasPrintFinalMessage(const std::string& msg) -{ -#if defined(__ANDROID__) - __android_log_print(ANDROID_LOG_ERROR, "mlas", "%s", msg.c_str()); -#else - // TODO, consider changing 
the output of the error message from std::cerr to logging when the - // exceptions are disabled, since using std::cerr might increase binary size, and std::cerr - // output might not be easily accesible on some systems such as mobile - // TODO, see if we need to change the output of the error message from std::cerr to NSLog for - // iOS - std::cerr << msg << std::endl; -#endif -} - -#define MLAS_THROW_EX(ex, what) \ - do { \ - std::string msg = #ex; \ - msg.append(what); \ - MlasPrintFinalMessage(msg); \ - abort(); \ - } while (false) - -#else - -#define MLAS_THROW_EX(ex, ...) throw ex(__VA_ARGS__) - -#endif // MLAS_NO_EXCEPTION -#endif - MLAS_THREADPOOL* GetMlasThreadPool(void); template @@ -123,7 +87,7 @@ class MatrixGuardBuffer { #if defined(_WIN32) if (VirtualAlloc(_BaseBuffer, BytesToAllocate, MEM_COMMIT, PAGE_READWRITE) == nullptr) { - MLAS_THROW_EX(std::bad_alloc); + ORT_THROW_EX(std::bad_alloc); } #else if (mprotect(_BaseBuffer, BytesToAllocate, PROT_READ | PROT_WRITE) != 0) { From 81d121a7dc727645ad06dd7c08e23022df38ed6c Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Sat, 31 May 2025 09:06:37 -0700 Subject: [PATCH 12/33] update --- src/ort_include/core/platform/EigenNonBlockingThreadPool.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ort_include/core/platform/EigenNonBlockingThreadPool.h b/src/ort_include/core/platform/EigenNonBlockingThreadPool.h index 263fa35..45b1751 100644 --- a/src/ort_include/core/platform/EigenNonBlockingThreadPool.h +++ b/src/ort_include/core/platform/EigenNonBlockingThreadPool.h @@ -10,7 +10,7 @@ /* Modifications Copyright (c) Microsoft. 
*/ #include - +#include #pragma once #include "onnxruntime_config.h" // build/external/eigen/unsupported/Eigen/CXX11/src/Tensor/TensorEvaluator.h:162:71: From db693b08bc149147ccc913a62d7f895504ec7f52 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Sat, 31 May 2025 09:50:16 -0700 Subject: [PATCH 13/33] update --- CMakeLists.txt | 55 +++++++++- src/core/platform/posix/env.cc | 73 +------------ src/core/platform/windows/env.cc | 157 ++-------------------------- src/core/platform/windows/env.h | 25 +---- src/lib/CMakeLists.txt | 6 +- src/ort_include/core/platform/env.h | 29 +---- tests/unittest/CMakeLists.txt | 2 +- 7 files changed, 73 insertions(+), 274 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a3c0bfc..66bddec 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -13,7 +13,7 @@ cmake_policy(SET CMP0091 NEW) cmake_policy(SET CMP0117 NEW) # Project -project(MLAS C CXX ASM) +project(MLAS C CXX) include(CheckCXXCompilerFlag) @@ -43,9 +43,60 @@ set(ONNXRUNTIME_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/src) set(ONNXRUNTIME_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include) option(MLAS_ENABLE_WEBASSEMBLY_THREADS "Enable this option to create WebAssembly byte codes with multi-threads support" OFF) option(MLAS_ENABLE_WEBASSEMBLY_BROWSER_TESTS "Build all executables as html files" OFF) -option(MLAS_NO_ONNXRUNTIME "Disable ORT related code" OFF) +option(MLAS_NO_ONNXRUNTIME "Disable ONNX Runtime related code as much as possible" OFF) option(MLAS_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING "Enable this option to turn on exception catching" OFF) +if (MSVC) + # Make sure Visual Studio sets __cplusplus macro correctly: https://learn.microsoft.com/en-us/cpp/build/reference/zc-cplusplus + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:__cplusplus") + + if (CMAKE_VS_PLATFORM_NAME) + # Multi-platform generator + set(onnxruntime_target_platform ${CMAKE_VS_PLATFORM_NAME}) + else() + set(onnxruntime_target_platform ${CMAKE_SYSTEM_PROCESSOR}) + endif() + if (onnxruntime_target_platform 
STREQUAL "ARM64") + set(onnxruntime_target_platform "ARM64") + enable_language(ASM_MARMASM) + elseif (onnxruntime_target_platform STREQUAL "ARM64EC") + enable_language(ASM_MARMASM) + elseif (onnxruntime_target_platform STREQUAL "ARM" OR CMAKE_GENERATOR MATCHES "ARM") + set(onnxruntime_target_platform "ARM") + enable_language(ASM_MARMASM) + elseif (onnxruntime_target_platform STREQUAL "x64" OR onnxruntime_target_platform STREQUAL "x86_64" OR onnxruntime_target_platform STREQUAL "AMD64" OR CMAKE_GENERATOR MATCHES "Win64") + set(onnxruntime_target_platform "x64") + enable_language(ASM_MASM) + elseif (onnxruntime_target_platform STREQUAL "Win32" OR onnxruntime_target_platform STREQUAL "x86" OR onnxruntime_target_platform STREQUAL "i386" OR onnxruntime_target_platform STREQUAL "i686") + set(onnxruntime_target_platform "x86") + enable_language(ASM_MASM) + message("Enabling SAFESEH for x86 build") + set(CMAKE_ASM_MASM_FLAGS "${CMAKE_ASM_MASM_FLAGS} /safeseh") + else() + message(FATAL_ERROR "Unknown CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") + endif() +else() + set(onnxruntime_target_platform ${CMAKE_SYSTEM_PROCESSOR}) + enable_language(ASM) +endif() + +function(onnxruntime_configure_target target_name) + target_compile_options(${target_name} PRIVATE + $<$,$>:/MP> + $<$,$>:/MP> + ) +endfunction() + +function(onnxruntime_add_executable target_name) + add_executable(${target_name} ${ARGN}) + onnxruntime_configure_target(${target_name}) +endfunction() + +function(onnxruntime_add_static_library target_name) + add_library(${target_name} STATIC ${ARGN}) + onnxruntime_configure_target(${target_name}) +endfunction() + if(MLAS_ENABLE_WEBASSEMBLY_BROWSER_TESTS) #The variable cannot be set from cmake command line because otherwise emscripten's toolchain file will override it. 
set(CMAKE_EXECUTABLE_SUFFIX ".html") diff --git a/src/core/platform/posix/env.cc b/src/core/platform/posix/env.cc index 2bb17f3..47b12f2 100644 --- a/src/core/platform/posix/env.cc +++ b/src/core/platform/posix/env.cc @@ -62,28 +62,13 @@ namespace { constexpr int OneMillion = 1000000; -class UnmapFileParam { - public: - void* addr; - size_t len; -}; - -static void UnmapFile(void* param) noexcept { - std::unique_ptr p(reinterpret_cast(param)); - int ret = munmap(p->addr, p->len); - if (ret != 0) { - auto [err_no, err_msg] = GetErrnoInfo(); - LOGS_DEFAULT(ERROR) << "munmap failed. error code: " << err_no << " error msg: " << err_msg; - } -} - struct FileDescriptorTraits { using Handle = int; static Handle GetInvalidHandleValue() { return -1; } static void CleanUp(Handle h) { if (close(h) == -1) { auto [err_no, err_msg] = GetErrnoInfo(); - LOGS_DEFAULT(ERROR) << "Failed to close file descriptor " << h << " - error code: " << err_no + std::cout << "Failed to close file descriptor " << h << " - error code: " << err_no << " error msg: " << err_msg; } } @@ -207,7 +192,7 @@ class PosixThread : public EnvThread { } else { // Logical processor id starts from 0 internally, but in ort API, it starts from 1, // that's why id need to increase by 1 when logging. 
- LOGS_DEFAULT(ERROR) << "cpu " << id + 1 << " does not exist, skipping it for affinity setting"; + std::cout << "cpu " << id + 1 << " does not exist, skipping it for affinity setting"; } } auto ret = pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset); @@ -219,7 +204,7 @@ class PosixThread : public EnvThread { errno = ret; auto [err_no, err_msg] = GetErrnoInfo(); #if !defined(USE_MIGRAPHX) - LOGS_DEFAULT(ERROR) << "pthread_setaffinity_np failed for thread: " << syscall(SYS_gettid) + std::cout << "pthread_setaffinity_np failed for thread: " << syscall(SYS_gettid) << ", index: " << p->index << ", mask: " << *p->affinity << ", error code: " << err_no << " error msg: " << err_msg @@ -429,58 +414,6 @@ class PosixEnv : public Env { return Status::OK(); } - common::Status LoadDynamicLibrary(const PathString& library_filename, bool global_symbols, void** handle) const override { - dlerror(); // clear any old error_str - *handle = dlopen(library_filename.c_str(), RTLD_NOW | (global_symbols ? RTLD_GLOBAL : RTLD_LOCAL)); - char* error_str = dlerror(); - if (!*handle) { - return common::Status(common::ONNXRUNTIME, common::FAIL, - "Failed to load library " + library_filename + " with error: " + error_str); - } - return common::Status::OK(); - } - - common::Status UnloadDynamicLibrary(void* handle) const override { - if (!handle) { - return common::Status(common::ONNXRUNTIME, common::FAIL, "Got null library handle"); - } - dlerror(); // clear any old error_str - int retval = dlclose(handle); - char* error_str = dlerror(); - if (retval != 0) { - return common::Status(common::ONNXRUNTIME, common::FAIL, - "Failed to unload library with error: " + std::string(error_str)); - } - return common::Status::OK(); - } - - common::Status GetSymbolFromLibrary(void* handle, const std::string& symbol_name, void** symbol) const override { - dlerror(); // clear any old error str - - // search global space if handle is nullptr. 
- // value of RTLD_DEFAULT differs across posix platforms (-2 on macos, 0 on linux). - handle = handle ? handle : RTLD_DEFAULT; - *symbol = dlsym(handle, symbol_name.c_str()); - - char* error_str = dlerror(); - if (error_str) { - return common::Status(common::ONNXRUNTIME, common::FAIL, - "Failed to get symbol " + symbol_name + " with error: " + error_str); - } - // it's possible to get a NULL symbol in our case when Schemas are not custom. - return common::Status::OK(); - } - - std::string FormatLibraryFileName(const std::string& name, const std::string& version) const override { - std::string filename; - if (version.empty()) { - filename = "lib" + name + ".so"; - } else { - filename = "lib" + name + ".so" + "." + version; - } - return filename; - } - // \brief returns a value for the queried variable name (var_name) std::string GetEnvironmentVar(const std::string& var_name) const override { char* val = getenv(var_name.c_str()); diff --git a/src/core/platform/windows/env.cc b/src/core/platform/windows/env.cc index 300acc9..dc07992 100644 --- a/src/core/platform/windows/env.cc +++ b/src/core/platform/windows/env.cc @@ -32,34 +32,16 @@ limitations under the License. #include "core/common/span_utils.h" #include "core/platform/env.h" #include "core/platform/scoped_resource.h" -#if defined(_M_X64) && !defined(_M_ARM64EC) -#include "core/platform/windows/hardware_core_enumerator.h" -#endif + #include #include #include "core/platform/path_lib.h" // for LoopDir() -#include "core/platform/windows/dll_load_error.h" EXTERN_C IMAGE_DOS_HEADER __ImageBase; namespace onnxruntime { -class UnmapFileParam { - public: - void* addr; - size_t len; -}; - -static void UnmapFile(void* param) noexcept { - std::unique_ptr p(reinterpret_cast(param)); - bool ret = UnmapViewOfFile(p->addr); - if (!ret) { - const auto error_code = GetLastError(); - LOGS_DEFAULT(ERROR) << "unmap view of file failed. 
error code: " << error_code - << " error msg: " << std::system_category().message(error_code); - } -} std::wstring Basename(const std::wstring& path) { auto basename_index = path.find_last_of(L"/\\") + 1; // results in 0 if no separator is found @@ -158,7 +140,7 @@ class WindowsThread : public EnvThread { } else { // Logical processor id starts from 0 internally, but in ort API, it starts from 1, // that's why id need to increase by 1 when logging. - LOGS_DEFAULT(ERROR) << "Cannot set affinity for thread " << GetCurrentThreadId() + std::cout << "Cannot set affinity for thread " << GetCurrentThreadId() << ", processor " << global_processor_id + 1 << " does not exist"; group_id = -1; mask = 0; @@ -167,7 +149,7 @@ class WindowsThread : public EnvThread { if (group_id == -1) { group_id = processor_info.group_id; } else if (group_id != processor_info.group_id) { - LOGS_DEFAULT(ERROR) << "Cannot set cross-group affinity for thread " + std::cout << "Cannot set cross-group affinity for thread " << GetCurrentThreadId() << ", first on group " << group_id << ", then on " << processor_info.group_id; group_id = -1; @@ -180,12 +162,12 @@ class WindowsThread : public EnvThread { thread_affinity.Group = static_cast(group_id); thread_affinity.Mask = mask; if (SetThreadGroupAffinity(GetCurrentThread(), &thread_affinity, nullptr)) { - LOGS_DEFAULT(VERBOSE) << "SetThreadAffinityMask done for thread: " << GetCurrentThreadId() + std::cout << "SetThreadAffinityMask done for thread: " << GetCurrentThreadId() << ", group_id: " << thread_affinity.Group << ", mask: " << thread_affinity.Mask; } else { const auto error_code = GetLastError(); - LOGS_DEFAULT(ERROR) << "SetThreadAffinityMask failed for thread: " << GetCurrentThreadId() + std::cout << "SetThreadAffinityMask failed for thread: " << GetCurrentThreadId() << ", index: " << p->index << ", mask: " << *p->affinity << ", error code: " << error_code @@ -248,44 +230,7 @@ int WindowsEnv::DefaultNumCores() { } int 
WindowsEnv::GetNumPhysicalCpuCores() const { -// EIGEN_NO_CPUID is not defined in any C/C++ source code. It is a compile option. -#if defined(_M_X64) && !defined(_M_ARM64EC) && !defined(EIGEN_NO_CPUID) - // The following code is a temporary fix for a perf problem on Intel's Meteor Lake CPUs. The Intel compute platform has - // a hybrid architecture that some CPU cores runs significant slower than the others. If we distribute our compute work - // evenly to all CPU cores, the slowest CPU core will drag the performance down. So, instead, we reduce the total number - // of threads to exclude the slowest cores out. - // The following code is based on assumptions that: - // 1. All Intel hybrid CPUs should have 3 levels of cache. - // 2. If a CPU core is only associated with two levels of cache, it should be a low performance CPU core and should - // not be used. - // Since we don't know what the next Intel hybrid CPU would be like, later on we may need to rework the following code. - // However, no matter what the code should not cause any crash. The worst is it might return 1 that - // thread pools will not be created, which is just a perf issue and does not impact usability. - // TODO: detect if CPUID instruction is available per instructions at https://wiki.osdev.org/CPUID#Checking_CPUID_availability - int regs[4]; - __cpuid(regs, 0); - bool bIsIntel = - (kVendorID_Intel[0] == regs[1]) && - (kVendorID_Intel[1] == regs[2]) && - (kVendorID_Intel[2] == regs[3]); - if (bIsIntel && regs[0] >= 7) { - // Query Structured Extended Feature Flags Enumeration Leaf - __cpuid(regs, 0x7); - // The bit 15 of EDX indicates if the processor is identified as a hybrid part. - bool ishybrid = regs[3] & (1 << 15); - if (ishybrid) { - // NOTE: even if ishybrid is true, it doesn't mean the processor must have P-cores and E-cores. - // On Intel CPUs we assume the HardwareCoreEnumerator::DefaultIntraOpNumThreads function would never fail. 
- // NOTE: due to resource restrictions, we cannot test this branch in our CI build pipelines. - return std::max(static_cast(1), HardwareCoreEnumerator::DefaultIntraOpNumThreads()); - } else { - return cores_.empty() ? DefaultNumCores() : static_cast(cores_.size()); - } - } else -#endif - { - return cores_.empty() ? DefaultNumCores() : static_cast(cores_.size()); - } + return cores_.empty() ? DefaultNumCores() : static_cast(cores_.size()); } std::vector WindowsEnv::GetDefaultThreadAffinities() const { @@ -394,18 +339,6 @@ Status WindowsEnv::ReadFileIntoBuffer(_In_z_ const ORTCHAR_T* const file_path, c return Status::OK(); } - - -bool WindowsEnv::FileExists(const std::wstring& path) const { - DWORD attributes = GetFileAttributesW(path.c_str()); - return (attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_NORMAL); -} - -bool WindowsEnv::FileExists(const std::string& path) const { - DWORD attributes = GetFileAttributesA(path.c_str()); - return (attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_NORMAL); -} - common::Status WindowsEnv::GetCanonicalPath( const PathString& path, PathString& canonical_path) const { @@ -489,43 +422,6 @@ PathString WindowsEnv::GetRuntimePath() const { return path.substr(0, slash_index + 1); } -Status WindowsEnv::LoadDynamicLibrary(const PathString& wlibrary_filename, bool /*global_symbols*/, void** handle) const { -#if WINAPI_FAMILY == WINAPI_FAMILY_PC_APP - *handle = ::LoadPackagedLibrary(wlibrary_filename.c_str(), 0); -#else - // TODO: in most cases, the path name is a relative path and the behavior of the following line of code is undefined. 
- *handle = ::LoadLibraryExW(wlibrary_filename.c_str(), nullptr, LOAD_WITH_ALTERED_SEARCH_PATH); -#endif - if (!*handle) { - const auto error_code = GetLastError(); - static constexpr DWORD bufferLength = 64 * 1024; - std::wstring s(bufferLength, '\0'); - FormatMessageW( - FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, - NULL, - error_code, - MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - (LPWSTR)s.data(), - bufferLength, NULL); - s.erase(std::remove(s.begin(), s.end(), L'\r'), s.end()); - s.erase(std::remove(s.begin(), s.end(), L'\n'), s.end()); - std::wostringstream oss; - oss << DetermineLoadLibraryError(wlibrary_filename.c_str(), LOAD_WITH_ALTERED_SEARCH_PATH) - << L" (Error " << error_code << ": \"" << s.c_str() << "\")"; - std::wstring errmsg = oss.str(); - common::Status status(common::ONNXRUNTIME, common::FAIL, ToUTF8String(errmsg)); - return status; - } - return Status::OK(); -} - -Status WindowsEnv::UnloadDynamicLibrary(void* handle) const { - if (::FreeLibrary(reinterpret_cast(handle)) == 0) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "FreeLibrary failed with error ", error_code, " - ", std::system_category().message(error_code)); - } - return Status::OK(); -} namespace dlfcn_win32 { // adapted from https://github.com/dlfcn-win32 version 1.3.1. 
@@ -580,39 +476,6 @@ void* SearchModulesForSymbol(const char* name) { } } // namespace dlfcn_win32 -Status WindowsEnv::GetSymbolFromLibrary(void* handle, const std::string& symbol_name, void** symbol) const { - Status status = Status::OK(); - - // global search to replicate dlsym RTLD_DEFAULT if handle is nullptr - if (handle == nullptr) { - *symbol = dlfcn_win32::SearchModulesForSymbol(symbol_name.c_str()); - } else { - *symbol = ::GetProcAddress(reinterpret_cast(handle), symbol_name.c_str()); - } - - if (!*symbol) { - const auto error_code = GetLastError(); - static constexpr DWORD bufferLength = 64 * 1024; - std::wstring s(bufferLength, '\0'); - FormatMessageW(FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, NULL, error_code, - MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - (LPWSTR)s.data(), 0, NULL); - std::wostringstream oss; - oss << L"Failed to find symbol " << ToWideString(symbol_name) << L" in library, error code: " - << error_code << L" \"" << s.c_str() << L"\""; - std::wstring errmsg = oss.str(); - // TODO: trim the ending '\r' and/or '\n' - status = Status(common::ONNXRUNTIME, common::FAIL, ToUTF8String(errmsg)); - } - - return status; -} - -std::string WindowsEnv::FormatLibraryFileName(const std::string& name, const std::string& version) const { - ORT_UNUSED_PARAMETER(name); - ORT_UNUSED_PARAMETER(version); - ORT_NOT_IMPLEMENTED(__FUNCTION__, " is not implemented"); -} // \brief returns a value for the queried variable name (var_name) std::string WindowsEnv::GetEnvironmentVar(const std::string& var_name) const { @@ -673,7 +536,7 @@ void WindowsEnv::InitializeCpuInfo() { if (last_error != ERROR_INSUFFICIENT_BUFFER) { const auto error_code = GetLastError(); if (logging::LoggingManager::HasDefaultLogger()) { - LOGS_DEFAULT(ERROR) << "Failed to calculate byte size for saving cpu info on windows" + std::cout << "Failed to calculate byte size for saving cpu info on windows" << ", error code: " << error_code << ", error msg: " << 
std::system_category().message(error_code); } @@ -686,7 +549,7 @@ void WindowsEnv::InitializeCpuInfo() { if (!GetLogicalProcessorInformationEx(RelationProcessorCore, processorInfos, &returnLength)) { const auto error_code = GetLastError(); if (logging::LoggingManager::HasDefaultLogger()) { - LOGS_DEFAULT(ERROR) << "Failed to fetch cpu info on windows" + std::cout << "Failed to fetch cpu info on windows" << ", error code: " << error_code << ", error msg: " << std::system_category().message(error_code); } @@ -739,7 +602,7 @@ void WindowsEnv::InitializeCpuInfo() { if (last_error != ERROR_INSUFFICIENT_BUFFER) { const auto error_code = GetLastError(); if (logging::LoggingManager::HasDefaultLogger()) { - LOGS_DEFAULT(ERROR) << "Failed to calculate byte size for saving cpu info on windows" + std::cout << "Failed to calculate byte size for saving cpu info on windows" << ", error code: " << error_code << ", error msg: " << std::system_category().message(error_code); } @@ -755,7 +618,7 @@ void WindowsEnv::InitializeCpuInfo() { if (!GetLogicalProcessorInformationEx(RelationCache, processorInfos, &newLength)) { const auto error_code = GetLastError(); if (logging::LoggingManager::HasDefaultLogger()) { - LOGS_DEFAULT(ERROR) << "Failed to fetch cpu info on windows" + std::cout << "Failed to fetch cpu info on windows" << ", error code: " << error_code << ", error msg: " << std::system_category().message(error_code); } diff --git a/src/core/platform/windows/env.h b/src/core/platform/windows/env.h index 05b92bb..c22e194 100644 --- a/src/core/platform/windows/env.h +++ b/src/core/platform/windows/env.h @@ -15,7 +15,6 @@ limitations under the License. 
// Portions Copyright (c) Microsoft Corporation #include "core/platform/env.h" -#include "core/platform/windows/telemetry.h" #include "core/common/inlined_containers.h" #include @@ -62,29 +61,10 @@ class WindowsEnv : public Env { common::Status GetFileLength(int fd, /*out*/ size_t& file_size) const override; Status ReadFileIntoBuffer(_In_z_ const ORTCHAR_T* const file_path, const FileOffsetType offset, const size_t length, const gsl::span buffer) const override; - Status MapFileIntoMemory(_In_z_ const ORTCHAR_T* file_path, - FileOffsetType offset, - size_t length, - MappedMemoryPtr& mapped_memory) const override; - bool FolderExists(const std::wstring& path) const override; - bool FolderExists(const std::string& path) const override; - bool FileExists(const std::wstring& path) const override; - bool FileExists(const std::string& path) const override; - common::Status CreateFolder(const std::wstring& path) const override; - common::Status CreateFolder(const std::string& path) const override; - common::Status DeleteFolder(const PathString& path) const override; - common::Status FileOpenRd(const std::wstring& path, /*out*/ int& fd) const override; - common::Status FileOpenWr(const std::wstring& path, /*out*/ int& fd) const override; - common::Status FileOpenRd(const std::string& path, /*out*/ int& fd) const override; - common::Status FileOpenWr(const std::string& path, /*out*/ int& fd) const override; - common::Status FileClose(int fd) const override; + common::Status GetCanonicalPath(const PathString& path, PathString& canonical_path) const override; PathString GetRuntimePath() const override; - Status LoadDynamicLibrary(const PathString& library_filename, bool /*global_symbols*/, void** handle) const override; - Status UnloadDynamicLibrary(void* handle) const override; - Status GetSymbolFromLibrary(void* handle, const std::string& symbol_name, void** symbol) const override; - std::string FormatLibraryFileName(const std::string& name, const std::string& version) 
const override; - const Telemetry& GetTelemetryProvider() const override; + std::string GetEnvironmentVar(const std::string& var_name) const override; ProcessorInfo GetProcessorAffinityMask(int global_processor_id) const; @@ -138,7 +118,6 @@ class WindowsEnv : public Env { private: void InitializeCpuInfo(); typedef VOID(WINAPI* FnGetSystemTimePreciseAsFileTime)(LPFILETIME); - WindowsTelemetry telemetry_provider_; }; } // namespace onnxruntime diff --git a/src/lib/CMakeLists.txt b/src/lib/CMakeLists.txt index 719423b..493c72f 100644 --- a/src/lib/CMakeLists.txt +++ b/src/lib/CMakeLists.txt @@ -42,7 +42,7 @@ endif() # hardware specific files would cause trouble in # multi-target build # -add_library(onnxruntime_mlas STATIC +onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/mlasi.h ${MLAS_SRC_DIR}/platform.cpp ${MLAS_SRC_DIR}/threading.cpp @@ -500,7 +500,7 @@ else() endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) - add_library(onnxruntime_mlas_arm STATIC64 ${mlas_platform_srcs}) + onnxruntime_add_static_library(onnxruntime_mlas_arm64 ${mlas_platform_srcs}) set_target_properties(onnxruntime_mlas_arm64 PROPERTIES OSX_ARCHITECTURES "arm64") list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_arm64) set(mlas_platform_srcs ) @@ -756,7 +756,7 @@ endif() endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) - add_library(onnxruntime_mlas_x STATIC86_64 ${mlas_platform_srcs}) + onnxruntime_add_static_library(onnxruntime_mlas_x64 ${mlas_platform_srcs}) set_target_properties(onnxruntime_mlas_x86_64 PROPERTIES OSX_ARCHITECTURES "x86_64") list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_x86_64) set(mlas_platform_srcs ) diff --git a/src/ort_include/core/platform/env.h b/src/ort_include/core/platform/env.h index 970567a..e68039e 100644 --- a/src/ort_include/core/platform/env.h +++ b/src/ort_include/core/platform/env.h @@ -166,21 +166,7 @@ class Env { // This functions is always successful. It can't fail. virtual PIDType GetSelfPid() const = 0; - // \brief Load a dynamic library. 
- // - // Pass "library_filename" to a platform-specific mechanism for dynamically - // loading a library. The rules for determining the exact location of the - // library are platform-specific and are not documented here. - // - // global_symbols only has an effect on unix, where a value of true means to load with RTLD_GLOBAL vs RTLD_LOCAL - // - // On success, returns a handle to the library in "*handle" and returns - // OK from the function. - // Otherwise returns nullptr in "*handle" and an error status from the - // function. - virtual common::Status LoadDynamicLibrary(const PathString& library_filename, bool global_symbols, void** handle) const = 0; - - virtual common::Status UnloadDynamicLibrary(void* handle) const = 0; + // \brief Gets the file path of the onnx runtime code // @@ -188,20 +174,7 @@ class Env { // The DNNL provider shared library. Without this path, the module won't be found on windows in all cases. virtual PathString GetRuntimePath() const { return PathString(); } - // \brief Get a pointer to a symbol from a dynamic library. - // - // "handle" should be a pointer returned from a previous call to LoadDynamicLibrary. - // On success, store a pointer to the located symbol in "*symbol" and return - // OK from the function. Otherwise, returns nullptr in "*symbol" and an error - // status from the function. - virtual common::Status GetSymbolFromLibrary(void* handle, const std::string& symbol_name, void** symbol) const = 0; - // \brief build the name of dynamic library. - // - // "name" should be name of the library. 
- // "version" should be the version of the library or NULL - // returns the name that LoadDynamicLibrary() can use - virtual std::string FormatLibraryFileName(const std::string& name, const std::string& version) const = 0; // \brief returns a value for the queried variable name (var_name) // diff --git a/tests/unittest/CMakeLists.txt b/tests/unittest/CMakeLists.txt index 757082b..764355a 100644 --- a/tests/unittest/CMakeLists.txt +++ b/tests/unittest/CMakeLists.txt @@ -1,6 +1,6 @@ find_package(GTest) -add_executable(mlas_unittest test_activation.cpp +onnxruntime_add_executable(mlas_unittest test_activation.cpp test_blkq8.cpp test_blockq4.cpp test_conv2d.cpp From a440bb45b707cb2495377d3bd774ef3f7372e596 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Sat, 31 May 2025 10:02:55 -0700 Subject: [PATCH 14/33] update --- src/common/cpuid_info.cc | 2 +- src/core/platform/posix/env.cc | 6 +++--- src/core/platform/windows/env.cc | 6 +++--- src/ort_include/core/platform/env_var_utils.h | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/common/cpuid_info.cc b/src/common/cpuid_info.cc index aec3ee5..0b675fc 100644 --- a/src/common/cpuid_info.cc +++ b/src/common/cpuid_info.cc @@ -366,7 +366,7 @@ CPUIDInfo::CPUIDInfo() { #if defined(CPUINFO_SUPPORTED) pytorch_cpuinfo_init_ = cpuinfo_initialize(); if (!pytorch_cpuinfo_init_) { - LOGS_DEFAULT(WARNING) << "Failed to initialize PyTorch cpuinfo library. May cause CPU EP performance degradation " + std::cout << "Failed to initialize PyTorch cpuinfo library. 
May cause CPU EP performance degradation " "due to undetected CPU features."; } #endif // defined(CPUINFO_SUPPORTED) diff --git a/src/core/platform/posix/env.cc b/src/core/platform/posix/env.cc index 47b12f2..2318c09 100644 --- a/src/core/platform/posix/env.cc +++ b/src/core/platform/posix/env.cc @@ -96,7 +96,7 @@ int nftw_remove( const auto result = remove(fpath); if (result != 0) { auto [err_no, err_msg] = GetErrnoInfo(); - LOGS_DEFAULT(WARNING) << "remove() failed. Error code: " << err_no << " error msg: " << err_msg + std::cout << "remove() failed. Error code: " << err_no << " error msg: " << err_msg << ", path: " << fpath; } return result; @@ -197,7 +197,7 @@ class PosixThread : public EnvThread { } auto ret = pthread_setaffinity_np(pthread_self(), sizeof(cpu_set_t), &cpuset); if (0 == ret) { - LOGS_DEFAULT(VERBOSE) << "pthread_setaffinity_np succeed for thread: " << syscall(SYS_gettid) + std::cout << "pthread_setaffinity_np succeed for thread: " << syscall(SYS_gettid) << ", index: " << p->index << ", mask: " << *p->affinity; } else { @@ -425,7 +425,7 @@ class PosixEnv : public Env { PosixEnv() { cpuinfo_available_ = cpuinfo_initialize(); if (!cpuinfo_available_) { - LOGS_DEFAULT(INFO) << "cpuinfo_initialize failed"; + std::cout << "cpuinfo_initialize failed"; } } bool cpuinfo_available_{false}; diff --git a/src/core/platform/windows/env.cc b/src/core/platform/windows/env.cc index dc07992..fb4055f 100644 --- a/src/core/platform/windows/env.cc +++ b/src/core/platform/windows/env.cc @@ -643,9 +643,9 @@ void WindowsEnv::InitializeCpuInfo() { } if (logging::LoggingManager::HasDefaultLogger()) { - LOGS_DEFAULT(VERBOSE) << "Found total " << cores_.size() << " core(s) from windows system:"; - LOGS_DEFAULT(VERBOSE) << log_stream.str(); - LOGS_DEFAULT(VERBOSE) << "\nDetected L2 cache size: " << l2_cache_size_ << " bytes"; + std::cout << "Found total " << cores_.size() << " core(s) from windows system:"; + std::cout << log_stream.str(); + std::cout << "\nDetected L2 
cache size: " << l2_cache_size_ << " bytes"; } } } // namespace onnxruntime diff --git a/src/ort_include/core/platform/env_var_utils.h b/src/ort_include/core/platform/env_var_utils.h index 63a2fed..efb80ac 100644 --- a/src/ort_include/core/platform/env_var_utils.h +++ b/src/ort_include/core/platform/env_var_utils.h @@ -83,7 +83,7 @@ std::optional ParseTestOnlyEnvironmentVariable(const std::string& name, std::string default_hint = "End users should opt for provider options or session options."; const std::string& logged_hint = hint.empty() ? default_hint : hint; - LOGS_DEFAULT(WARNING) << "Environment variable " << name << " is used. It is reserved for internal testing purpose. " + std::cout << "Environment variable " << name << " is used. It is reserved for internal testing purpose. " << logged_hint; return env; From eabc550e68328de81024c349abf96d0caac9efef Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Sat, 31 May 2025 10:19:47 -0700 Subject: [PATCH 15/33] update --- src/core/platform/posix/env.cc | 13 ---- src/lib/mlasi.h | 65 +++++++++++++++---- src/ort_include/core/platform/env_var_utils.h | 2 +- tests/unittest/test_util.h | 5 ++ 4 files changed, 59 insertions(+), 26 deletions(-) diff --git a/src/core/platform/posix/env.cc b/src/core/platform/posix/env.cc index 2318c09..43c6c4d 100644 --- a/src/core/platform/posix/env.cc +++ b/src/core/platform/posix/env.cc @@ -89,19 +89,6 @@ long int TempFailureRetry(TFunc retriable_operation, TFuncArgs&&... args) { return result; } -// nftw() callback to remove a file -int nftw_remove( - const char* fpath, const struct stat* /*sb*/, - int /*typeflag*/, struct FTW* /*ftwbuf*/) { - const auto result = remove(fpath); - if (result != 0) { - auto [err_no, err_msg] = GetErrnoInfo(); - std::cout << "remove() failed. 
Error code: " << err_no << " error msg: " << err_msg - << ", path: " << fpath; - } - return result; -} - template struct Freer { void operator()(T* p) { ::free(p); } diff --git a/src/lib/mlasi.h b/src/lib/mlasi.h index 184816a..0dd0165 100644 --- a/src/lib/mlasi.h +++ b/src/lib/mlasi.h @@ -252,25 +252,66 @@ enum MlasUArch { // Define MLAS_FP16 // #include "mlas_float16.h" +#include "../ort_include/core/session/onnxruntime_float16.h" namespace onnxruntime { -struct MLFloat16 { - uint16_t val{0}; +// MLFloat16 +struct MLFloat16 : onnxruntime_float16::Float16Impl { + private: + explicit constexpr MLFloat16(uint16_t x) noexcept { val = x; } - MLFloat16() = default; - explicit constexpr MLFloat16(uint16_t x) : val(x) {} - explicit MLFloat16(float ff) : val(MLAS_Float2Half(ff)) {} + public: + using Base = onnxruntime_float16::Float16Impl; - float ToFloat() const { return MLAS_Half2Float(val); } + MLFloat16() = default; - operator float() const { return ToFloat(); } + constexpr static MLFloat16 FromBits(uint16_t x) noexcept { return MLFloat16(x); } - MLFloat16& operator=(float ff) - { - val = MLAS_Float2Half(ff); - return *this; - } + // Using inherited implementation instead of math floatToHalf allows us to use this + // in other shared providers without having to implement the bridge + explicit MLFloat16(float v) noexcept { val = Base::ToUint16Impl(v); } + + static const MLFloat16 NaN; + static const MLFloat16 NegativeNaN; + static const MLFloat16 Infinity; + static const MLFloat16 NegativeInfinity; + static const MLFloat16 MaxValue; + static const MLFloat16 Zero; + static const MLFloat16 One; + static const MLFloat16 MinusOne; + + // Using inherited implementation instead of math halfToFloat allows us to use this + // in other shared providers without having to implement the bridge + float ToFloat() const noexcept { return Base::ToFloatImpl(); } + + using Base::IsNegative; + + using Base::IsNaN; + + using Base::IsFinite; + + using Base::IsPositiveInfinity; + + using 
Base::IsNegativeInfinity; + + using Base::IsInfinity; + + using Base::IsNaNOrZero; + + using Base::IsNormal; + + using Base::IsSubnormal; + + using Base::Abs; + + using Base::Negate; + + operator float() const noexcept { return ToFloat(); } + + using Base::operator==; + using Base::operator!=; + using Base::operator<; }; inline bool diff --git a/src/ort_include/core/platform/env_var_utils.h b/src/ort_include/core/platform/env_var_utils.h index efb80ac..b7cb6ea 100644 --- a/src/ort_include/core/platform/env_var_utils.h +++ b/src/ort_include/core/platform/env_var_utils.h @@ -4,7 +4,7 @@ #pragma once #include - +#include #include "core/common/common.h" #ifndef SHARED_PROVIDER #include "core/common/logging/logging.h" diff --git a/tests/unittest/test_util.h b/tests/unittest/test_util.h index a000e35..94c4143 100644 --- a/tests/unittest/test_util.h +++ b/tests/unittest/test_util.h @@ -7,6 +7,7 @@ #include "gtest/gtest.h" #include +#include #include #include #include @@ -87,7 +88,11 @@ class MatrixGuardBuffer { #if defined(_WIN32) if (VirtualAlloc(_BaseBuffer, BytesToAllocate, MEM_COMMIT, PAGE_READWRITE) == nullptr) { +#ifdef BUILD_MLAS_NO_ONNXRUNTIME + abort(); +#else ORT_THROW_EX(std::bad_alloc); +#endif } #else if (mprotect(_BaseBuffer, BytesToAllocate, PROT_READ | PROT_WRITE) != 0) { From b11a4c025c13b61c7bdf16e97a85dbbbf71c28e7 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Sat, 31 May 2025 10:20:52 -0700 Subject: [PATCH 16/33] update --- src/lib/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/CMakeLists.txt b/src/lib/CMakeLists.txt index 493c72f..f7d6d97 100644 --- a/src/lib/CMakeLists.txt +++ b/src/lib/CMakeLists.txt @@ -854,7 +854,7 @@ endblock() target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) set_target_properties(onnxruntime_mlas_q4dq PROPERTIES FOLDER "ONNXRuntimeTest") - target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common) + 
target_link_libraries(onnxruntime_mlas_q4dq PRIVATE ${ONNXRUNTIME_MLAS_LIBS}) if(NOT MLAS_NO_ONNXRUNTIME) target_link_libraries(onnxruntime_mlas_q4dq PRIVATE onnxruntime_common) endif() From 8b4cd4c228cb4de1087e351d07b96f2fd9f569c6 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Sat, 31 May 2025 10:22:04 -0700 Subject: [PATCH 17/33] update --- src/lib/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/CMakeLists.txt b/src/lib/CMakeLists.txt index f7d6d97..acabed4 100644 --- a/src/lib/CMakeLists.txt +++ b/src/lib/CMakeLists.txt @@ -756,7 +756,7 @@ endif() endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) - onnxruntime_add_static_library(onnxruntime_mlas_x64 ${mlas_platform_srcs}) + onnxruntime_add_static_library(onnxruntime_mlas_x86_64 ${mlas_platform_srcs}) set_target_properties(onnxruntime_mlas_x86_64 PROPERTIES OSX_ARCHITECTURES "x86_64") list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_x86_64) set(mlas_platform_srcs ) From 93a774d925d81a346e27306b7f042e6e502dcaf3 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Sat, 31 May 2025 10:30:44 -0700 Subject: [PATCH 18/33] update --- cmake/deps.txt | 2 +- tests/unittest/test_halfgemm.h | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/cmake/deps.txt b/cmake/deps.txt index fa296d6..7a9a0bf 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -1,5 +1,5 @@ eigen;https://gitlab.com/libeigen/eigen/-/archive/ff174f79264d3f8dc0115dea7a288f98208b694f/eigen-ff174f79264d3f8dc0115dea7a288f98208b694f.zip;666e2f940faeef0196e72617a5d01241a22b67f3 microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14 -googletest;https://github.com/google/googletest/archive/refs/tags/v1.15.0.zip;9d2d0af8d77ac726ea55d44a8fa727ec98311349 +googletest;https://github.com/google/googletest/archive/refs/tags/v1.17.0.zip;f638fa0e724760e2ba07ff8cfba32cd644e1ce28 
google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.8.5.zip;cd47d3d272faf353600c8cc2fdec2b52d6f69177 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 diff --git a/tests/unittest/test_halfgemm.h b/tests/unittest/test_halfgemm.h index 4db5c2b..5df0944 100644 --- a/tests/unittest/test_halfgemm.h +++ b/tests/unittest/test_halfgemm.h @@ -17,6 +17,8 @@ Module Name: #pragma once #include "test_fp16.h" +#include +#include /** * @brief Test class for half precision GEMM From 48d8c4cadd26cedf69cf16ca13357e89eee80260 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Sat, 31 May 2025 10:34:08 -0700 Subject: [PATCH 19/33] change host --- .github/workflows/linux_ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/linux_ci.yml b/.github/workflows/linux_ci.yml index 4bee28b..9289336 100644 --- a/.github/workflows/linux_ci.yml +++ b/.github/workflows/linux_ci.yml @@ -12,7 +12,7 @@ concurrency: jobs: Linux_arm64_gcc_release: - runs-on: ["self-hosted", "1ES.Pool=mlas-linux-ARM64-CPU"] + runs-on: ubuntu-24.04-arm steps: - uses: actions/checkout@v4 - run: | From 143152a07f4c36425110fdaffcb7cf88cb0140ea Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Sat, 31 May 2025 11:11:57 -0700 Subject: [PATCH 20/33] update --- CMakeLists.txt | 6 +- src/common/threadpool.cc | 221 - src/core/platform/windows/env.cc | 230 - src/core/platform/windows/env.h | 9 +- .../platform/EigenNonBlockingThreadPool.h | 119 +- src/ort_include/core/platform/env.h | 31 - src/ort_include/core/platform/threadpool.h | 10 +- .../platform/windows/TraceLoggingConfig.h | 81 - .../core/platform/windows/readme.txt | 2 - .../core/session/onnxruntime_c_api.h | 6140 ----------------- 10 files changed, 8 insertions(+), 6841 deletions(-) delete mode 100644 src/ort_include/core/platform/windows/TraceLoggingConfig.h delete mode 100644 src/ort_include/core/platform/windows/readme.txt diff --git 
a/CMakeLists.txt b/CMakeLists.txt index 66bddec..0c00a53 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,7 +28,11 @@ if(NOT CMAKE_C_STANDARD) endif() if(NOT CMAKE_CXX_STANDARD) message("Setting C++ standard to 20") - set(CMAKE_CXX_STANDARD 20) + if(WIN32) + set(CMAKE_CXX_STANDARD 23) + else() + set(CMAKE_CXX_STANDARD 20) + endif() endif() set(CMAKE_POSITION_INDEPENDENT_CODE ON) diff --git a/src/common/threadpool.cc b/src/common/threadpool.cc index 0fd2336..0cfdf08 100644 --- a/src/common/threadpool.cc +++ b/src/common/threadpool.cc @@ -46,202 +46,6 @@ namespace onnxruntime { namespace concurrency { -#if !defined(ORT_MINIMAL_BUILD) -ThreadPoolProfiler::ThreadPoolProfiler(int num_threads, const CHAR_TYPE* thread_pool_name) : num_threads_(num_threads) { - child_thread_stats_.assign(num_threads, {}); - if (thread_pool_name) { -#ifdef _WIN32 - thread_pool_name_ = ToUTF8String(thread_pool_name); -#else - thread_pool_name_ = thread_pool_name; -#endif - } else { - thread_pool_name_ = "unnamed_thread_pool"; - } -} - -ThreadPoolProfiler::~ThreadPoolProfiler() { - enabled_ = false; -} - -void ThreadPoolProfiler::Start() { - enabled_ = true; -} - -ThreadPoolProfiler::MainThreadStat& ThreadPoolProfiler::GetMainThreadStat() { - static thread_local std::unique_ptr stat; - if (!stat) { - stat = std::make_unique(); - } - return *stat; -} - -std::string ThreadPoolProfiler::Stop() { - ORT_ENFORCE(enabled_, "Profiler not started yet"); - std::ostringstream ss; - ss << "{\"main_thread\": {" - << "\"thread_pool_name\": \"" - << thread_pool_name_ << "\", " - << GetMainThreadStat().Reset() - << "}, \"sub_threads\": {" - << DumpChildThreadStat() - << "}}"; - return ss.str(); -} - -void ThreadPoolProfiler::LogStartAndCoreAndBlock(std::ptrdiff_t block_size) { - if (enabled_) { - MainThreadStat& stat = GetMainThreadStat(); - stat.LogCore(); - stat.LogBlockSize(block_size); - stat.LogStart(); - } -} - -void ThreadPoolProfiler::LogCoreAndBlock(std::ptrdiff_t block_size) { - if 
(enabled_) { - MainThreadStat& stat = GetMainThreadStat(); - stat.LogCore(); - stat.LogBlockSize(block_size); - } -} - -void ThreadPoolProfiler::LogStart() { - if (enabled_) { - GetMainThreadStat().LogStart(); - } -} - -void ThreadPoolProfiler::LogEnd(ThreadPoolEvent evt) { - if (enabled_) { - GetMainThreadStat().LogEnd(evt); - } -} - -void ThreadPoolProfiler::LogEndAndStart(ThreadPoolEvent evt) { - if (enabled_) { - GetMainThreadStat().LogEndAndStart(evt); - } -} - -void ThreadPoolProfiler::MainThreadStat::LogCore() { -#ifdef _WIN32 - core_ = GetCurrentProcessorNumber(); -#elif defined(__APPLE__) -#if defined(__x86_64__) || defined(__i386__) - uint32_t CPUInfo[4]; - __cpuid_count(1, 0, CPUInfo[0], CPUInfo[1], CPUInfo[2], CPUInfo[3]); - if ((CPUInfo[3] & (1 << 9)) != 0) { - core_ = (unsigned)CPUInfo[1] >> 24; - } -#endif -#elif defined(__wasm__) - core_ = emscripten_num_logical_cores(); -#elif defined(_AIX) - core_ = mycpu(); -#else - core_ = sched_getcpu(); -#endif -} - -void ThreadPoolProfiler::MainThreadStat::LogBlockSize(std::ptrdiff_t block_size) { - blocks_.emplace_back(block_size); -} - -void ThreadPoolProfiler::MainThreadStat::LogStart() { - points_.emplace_back(Clock::now()); -} - -void ThreadPoolProfiler::MainThreadStat::LogEnd(ThreadPoolEvent evt) { - ORT_ENFORCE(!points_.empty(), "LogStart must pair with LogEnd"); - events_[evt] += TimeDiffMicroSeconds(points_.back(), Clock::now()); - points_.pop_back(); -} - -void ThreadPoolProfiler::MainThreadStat::LogEndAndStart(ThreadPoolEvent evt) { - ORT_ENFORCE(!points_.empty(), "LogStart must pair with LogEnd"); - events_[evt] += TimeDiffMicroSeconds(points_.back(), Clock::now()); - points_.back() = Clock::now(); -} - -std::string ThreadPoolProfiler::MainThreadStat::Reset() { - ORT_ENFORCE(points_.empty(), "LogStart must pair with LogEnd"); - std::stringstream ss; - ss << "\"thread_id\": \"" << std::this_thread::get_id() << "\", \"block_size\": ["; - if (!blocks_.empty()) { - std::copy(blocks_.begin(), 
blocks_.end() - 1, std::ostream_iterator(ss, ", ")); - ss << blocks_.back(); - blocks_.clear(); - } - ss << "], \"core\": " << core_ << ", "; - for (int i = 0; i < MAX_EVENT; ++i) { - ss << "\"" << ThreadPoolProfiler::GetEventName(static_cast(i)) - << "\": " << events_[i] << ((i == MAX_EVENT - 1) ? std::string{} : ", "); - } - memset(events_, 0, sizeof(uint64_t) * MAX_EVENT); - return ss.str(); -} - -const char* ThreadPoolProfiler::GetEventName(ThreadPoolEvent event) { - switch (event) { - case DISTRIBUTION: - return "Distribution"; - case DISTRIBUTION_ENQUEUE: - return "DistributionEnqueue"; - case RUN: - return "Run"; - case WAIT: - return "Wait"; - case WAIT_REVOKE: - return "WaitRevoke"; - default: - return "UnknownEvent"; - } -} - -void ThreadPoolProfiler::LogThreadId(int thread_idx) { - child_thread_stats_[thread_idx].thread_id_ = std::this_thread::get_id(); -} - -void ThreadPoolProfiler::LogRun(int thread_idx) { - if (enabled_) { - child_thread_stats_[thread_idx].num_run_++; - auto now = Clock::now(); - if (child_thread_stats_[thread_idx].core_ < 0 || - TimeDiffMicroSeconds(child_thread_stats_[thread_idx].last_logged_point_, now) > 10000) { -#ifdef _WIN32 - child_thread_stats_[thread_idx].core_ = GetCurrentProcessorNumber(); -#elif defined(__APPLE__) -#if defined(__x86_64__) || defined(__i386__) - uint32_t CPUInfo[4]; - __cpuid_count(1, 0, CPUInfo[0], CPUInfo[1], CPUInfo[2], CPUInfo[3]); - if ((CPUInfo[3] & (1 << 9)) != 0) { - child_thread_stats_[thread_idx].core_ = (unsigned)CPUInfo[1] >> 24; - } -#endif -#elif defined(__wasm__) - child_thread_stats_[thread_idx].core_ = emscripten_num_logical_cores(); -#elif defined(_AIX) - child_thread_stats_[thread_idx].core_ = mycpu(); -#else - child_thread_stats_[thread_idx].core_ = sched_getcpu(); -#endif - child_thread_stats_[thread_idx].last_logged_point_ = now; - } - } -} - -std::string ThreadPoolProfiler::DumpChildThreadStat() { - std::stringstream ss; - for (int i = 0; i < num_threads_; ++i) { - ss << "\"" << 
child_thread_stats_[i].thread_id_ << "\": {" - << "\"num_run\": " << child_thread_stats_[i].num_run_ << ", " - << "\"core\": " << child_thread_stats_[i].core_ << "}" - << (i == num_threads_ - 1 ? "" : ","); - } - return ss.str(); -} -#endif // A sharded loop counter distributes loop iterations between a set of worker threads. The iteration space of // the loop is divided (perhaps unevenly) between the shards. Each thread has a home shard (perhaps not uniquely @@ -458,19 +262,7 @@ void ThreadPool::Schedule(std::function fn) { } } -void ThreadPool::StartProfiling() { - if (underlying_threadpool_) { - underlying_threadpool_->StartProfiling(); - } -} -std::string ThreadPool::StopProfiling() { - if (underlying_threadpool_) { - return underlying_threadpool_->StopProfiling(); - } else { - return {}; - } -} namespace { thread_local std::optional current_parallel_section; @@ -628,19 +420,6 @@ int ThreadPool::DegreeOfParallelism(const concurrency::ThreadPool* tp) { } } -void ThreadPool::StartProfiling(concurrency::ThreadPool* tp) { - if (tp) { - tp->StartProfiling(); - } -} - -std::string ThreadPool::StopProfiling(concurrency::ThreadPool* tp) { - if (tp) { - return tp->StopProfiling(); - } else { - return {}; - } -} void ThreadPool::EnableSpinning() { if (extended_eigen_threadpool_) { diff --git a/src/core/platform/windows/env.cc b/src/core/platform/windows/env.cc index fb4055f..9dab8a9 100644 --- a/src/core/platform/windows/env.cc +++ b/src/core/platform/windows/env.cc @@ -217,9 +217,6 @@ Env& Env::Default() { return WindowsEnv::Instance(); } -void WindowsEnv::SleepForMicroseconds(int64_t micros) const { - Sleep(static_cast(micros) / 1000); -} // EIGEN_NO_CPUID is not defined in any C/C++ source code. It is a compile option. 
#if defined(_M_X64) && !defined(_M_ARM64EC) && !defined(EIGEN_NO_CPUID) @@ -250,233 +247,6 @@ PIDType WindowsEnv::GetSelfPid() const { return GetCurrentProcessId(); } -Status WindowsEnv::GetFileLength(_In_z_ const ORTCHAR_T* file_path, size_t& length) const { - wil::unique_hfile file_handle{ - CreateFile2(file_path, FILE_READ_ATTRIBUTES, FILE_SHARE_READ, OPEN_EXISTING, NULL)}; - if (file_handle.get() == INVALID_HANDLE_VALUE) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "open file ", ToUTF8String(Basename(file_path)), " fail, errcode = ", error_code, " - ", std::system_category().message(error_code)); - } - LARGE_INTEGER filesize; - if (!GetFileSizeEx(file_handle.get(), &filesize)) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GetFileSizeEx ", ToUTF8String(Basename(file_path)), " fail, errcode = ", error_code, " - ", std::system_category().message(error_code)); - } - if (static_cast(filesize.QuadPart) > std::numeric_limits::max()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "GetFileLength: File is too large"); - } - length = static_cast(filesize.QuadPart); - return Status::OK(); -} - -common::Status WindowsEnv::GetFileLength(int fd, /*out*/ size_t& file_size) const { - using namespace common; - if (fd < 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid fd was supplied: ", fd); - } - - struct _stat buf; - int rc = _fstat(fd, &buf); - if (rc < 0) { - return Status(SYSTEM, errno); - } - - if (buf.st_size < 0) { - return ORT_MAKE_STATUS(SYSTEM, FAIL, "Received negative size from stat call"); - } - - if (static_cast(buf.st_size) > std::numeric_limits::max()) { - return ORT_MAKE_STATUS(SYSTEM, FAIL, "File is too large."); - } - - file_size = static_cast(buf.st_size); - return Status::OK(); -} - -Status WindowsEnv::ReadFileIntoBuffer(_In_z_ const ORTCHAR_T* const file_path, const FileOffsetType offset, const size_t length, - const gsl::span buffer) const { - 
ORT_RETURN_IF_NOT(file_path, "file_path == nullptr"); - ORT_RETURN_IF_NOT(offset >= 0, "offset < 0"); - ORT_RETURN_IF_NOT(length <= buffer.size(), "length > buffer.size()"); - wil::unique_hfile file_handle{ - CreateFile2(file_path, GENERIC_READ, FILE_SHARE_READ, OPEN_EXISTING, NULL)}; - if (file_handle.get() == INVALID_HANDLE_VALUE) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "open file ", ToUTF8String(Basename(file_path)), " fail, errcode = ", error_code, " - ", std::system_category().message(error_code)); - } - - if (length == 0) - return Status::OK(); - - if (offset > 0) { - LARGE_INTEGER current_position; - current_position.QuadPart = offset; - if (!SetFilePointerEx(file_handle.get(), current_position, ¤t_position, FILE_BEGIN)) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "SetFilePointerEx ", ToUTF8String(Basename(file_path)), " fail, errcode = ", error_code, " - ", std::system_category().message(error_code)); - } - } - - size_t total_bytes_read = 0; - while (total_bytes_read < length) { - constexpr DWORD k_max_bytes_to_read = 1 << 30; // read at most 1GB each time - const size_t bytes_remaining = length - total_bytes_read; - const DWORD bytes_to_read = static_cast(std::min(bytes_remaining, k_max_bytes_to_read)); - DWORD bytes_read; - - if (!ReadFile(file_handle.get(), buffer.data() + total_bytes_read, bytes_to_read, &bytes_read, nullptr)) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ReadFile ", ToUTF8String(Basename(file_path)), " fail, errcode = ", error_code, " - ", std::system_category().message(error_code)); - } - - if (bytes_read != bytes_to_read) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ReadFile ", ToUTF8String(Basename(file_path)), " fail: unexpected end"); - } - - total_bytes_read += bytes_read; - } - - return Status::OK(); -} - -common::Status WindowsEnv::GetCanonicalPath( - const PathString& path, - PathString& 
canonical_path) const { - // adapted from MSVC STL std::filesystem::canonical() implementation - // https://github.com/microsoft/STL/blob/ed3cbf36416a385828e7a5987ca52cb42882d84b/stl/inc/filesystem#L2986 - CREATEFILE2_EXTENDED_PARAMETERS param; - memset(¶m, 0, sizeof(param)); - param.dwSize = sizeof(CREATEFILE2_EXTENDED_PARAMETERS); - param.dwFileFlags = FILE_FLAG_BACKUP_SEMANTICS; - wil::unique_hfile file_handle{CreateFile2( - path.c_str(), - FILE_READ_ATTRIBUTES, - FILE_SHARE_READ | FILE_SHARE_WRITE | FILE_SHARE_DELETE, - OPEN_EXISTING, - ¶m)}; - - if (file_handle.get() == INVALID_HANDLE_VALUE) { - const auto error_code = GetLastError(); - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "open file ", ToUTF8String(Basename(path)), " fail, errcode = ", - error_code, " - ", std::system_category().message(error_code)); - } - - constexpr DWORD initial_buffer_size = MAX_PATH; - std::vector result_buffer{}; - result_buffer.resize(initial_buffer_size); - - while (true) { - const DWORD result_length = GetFinalPathNameByHandleW( - file_handle.get(), - result_buffer.data(), - static_cast(result_buffer.size()), - 0); - - ORT_RETURN_IF_NOT( - result_length > 0, "GetFinalPathNameByHandle() failed: ", GetLastError()); - - if (result_length < result_buffer.size()) { // buffer is large enough - canonical_path.assign(result_buffer.data(), result_length); - break; - } - - // need larger buffer - result_buffer.resize(result_length); - } - - // update prefixes - if (canonical_path.find(ORT_TSTR(R"(\\?\)")) == 0) { - if (canonical_path.size() > 6 && - (ORT_TSTR('A') <= canonical_path[4] && canonical_path[4] <= ORT_TSTR('Z') || - ORT_TSTR('a') <= canonical_path[4] && canonical_path[4] <= ORT_TSTR('z')) && - canonical_path[5] == ORT_TSTR(':')) { - // "\\?\:" -> ":" - canonical_path.erase(0, 4); - } else if (canonical_path.find(ORT_TSTR(R"(UNC\)"), 4) == 4) { - // "\\?\UNC\" -> "\\" - canonical_path.erase(2, 6); - } - } - - return Status::OK(); -} - -// Return the path of the executable/shared 
library for the current running code. This is to make it -// possible to load other shared libraries installed next to our core runtime code. -PathString WindowsEnv::GetRuntimePath() const { - wchar_t buffer[MAX_PATH]; - if (!GetModuleFileNameW(reinterpret_cast(&__ImageBase), buffer, _countof(buffer))) { - return PathString(); - } - - // Remove the filename at the end, but keep the trailing slash - PathString path(buffer); - auto slash_index = path.find_last_of(ORT_TSTR('\\')); - if (slash_index == std::string::npos) { - // Windows supports forward slashes - slash_index = path.find_last_of(ORT_TSTR('/')); - if (slash_index == std::string::npos) { - return PathString(); - } - } - return path.substr(0, slash_index + 1); -} - - -namespace dlfcn_win32 { -// adapted from https://github.com/dlfcn-win32 version 1.3.1. -// Simplified to only support finding symbols in libraries that were linked against. -// If ORT dynamically loads a custom ops library using RegisterCustomOpsLibrary[_V2] the handle from the library load -// is explicitly provided in the call to GetSymbolFromLibrary. 
-// -/* Load Psapi.dll at runtime, this avoids linking caveat */ -bool MyEnumProcessModules(HANDLE hProcess, HMODULE* lphModule, DWORD cb, LPDWORD lpcbNeeded) { - using EnumProcessModulesFn = BOOL(WINAPI*)(HANDLE, HMODULE*, DWORD, LPDWORD); - static EnumProcessModulesFn EnumProcessModulesPtr = []() { - EnumProcessModulesFn fn = nullptr; - // Windows 7 and newer versions have K32EnumProcessModules in Kernel32.dll which is always pre-loaded - HMODULE psapi = GetModuleHandleA("Kernel32.dll"); - if (psapi) { - fn = (EnumProcessModulesFn)(LPVOID)GetProcAddress(psapi, "K32EnumProcessModules"); - } - - return fn; - }(); - - if (EnumProcessModulesPtr == nullptr) { - return false; - } - - return EnumProcessModulesPtr(hProcess, lphModule, cb, lpcbNeeded); -} - -void* SearchModulesForSymbol(const char* name) { - HANDLE current_proc = GetCurrentProcess(); - DWORD size = 0; - void* symbol = nullptr; - - // GetModuleHandle(NULL) only returns the current program file. So if we want to get ALL loaded module including - // those in linked DLLs, we have to use EnumProcessModules(). 
- if (MyEnumProcessModules(current_proc, nullptr, 0, &size) != false) { - size_t num_handles = size / sizeof(HMODULE); - std::unique_ptr modules = std::make_unique(num_handles); - HMODULE* modules_ptr = modules.get(); - DWORD cb_needed = 0; - if (MyEnumProcessModules(current_proc, modules_ptr, size, &cb_needed) != 0 && size == cb_needed) { - for (size_t i = 0; i < num_handles; i++) { - symbol = GetProcAddress(modules[i], name); - if (symbol != nullptr) { - break; - } - } - } - } - - return symbol; -} -} // namespace dlfcn_win32 - - // \brief returns a value for the queried variable name (var_name) std::string WindowsEnv::GetEnvironmentVar(const std::string& var_name) const { // Why getenv() should be avoided on Windows: diff --git a/src/core/platform/windows/env.h b/src/core/platform/windows/env.h index c22e194..e66ca6f 100644 --- a/src/core/platform/windows/env.h +++ b/src/core/platform/windows/env.h @@ -50,20 +50,13 @@ class WindowsEnv : public Env { #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) #endif - void SleepForMicroseconds(int64_t micros) const override; static int DefaultNumCores(); int GetNumPhysicalCpuCores() const override; std::vector GetDefaultThreadAffinities() const override; int GetL2CacheSize() const override; static WindowsEnv& Instance(); PIDType GetSelfPid() const override; - Status GetFileLength(_In_z_ const ORTCHAR_T* file_path, size_t& length) const override; - common::Status GetFileLength(int fd, /*out*/ size_t& file_size) const override; - Status ReadFileIntoBuffer(_In_z_ const ORTCHAR_T* const file_path, const FileOffsetType offset, const size_t length, - const gsl::span buffer) const override; - - common::Status GetCanonicalPath(const PathString& path, PathString& canonical_path) const override; - PathString GetRuntimePath() const override; + std::string GetEnvironmentVar(const std::string& var_name) const override; ProcessorInfo GetProcessorAffinityMask(int global_processor_id) const; diff --git 
a/src/ort_include/core/platform/EigenNonBlockingThreadPool.h b/src/ort_include/core/platform/EigenNonBlockingThreadPool.h index 45b1751..c313944 100644 --- a/src/ort_include/core/platform/EigenNonBlockingThreadPool.h +++ b/src/ort_include/core/platform/EigenNonBlockingThreadPool.h @@ -199,100 +199,6 @@ struct PaddingToAvoidFalseSharing { char padding[ORT_FALSE_SHARING_BYTES]; }; -/* Usage: -1. In executor, call Start() before profiling and Stop() to get profiled numbers; -2. Inside thread pool, call LogStart() before interested section and LogEnd... after to log elapsed time; -3. To extend, just add more events in enum Event before "All", and update GetEventName(...) accordingly; -4. Note LogStart must pair with either LogEnd or LogEndAndStart, otherwise ORT_ENFORCE will fail; -5. ThreadPoolProfiler is thread-safe. -*/ -#ifdef ORT_MINIMAL_BUILD -class ThreadPoolProfiler { - public: - enum ThreadPoolEvent { - DISTRIBUTION = 0, - DISTRIBUTION_ENQUEUE, - RUN, - WAIT, - WAIT_REVOKE, - MAX_EVENT - }; - ThreadPoolProfiler(int, const CHAR_TYPE*) {} - ~ThreadPoolProfiler() = default; - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ThreadPoolProfiler); - void Start() {} - std::string Stop() { return "not available for minimal build"; } - void LogStart() {} - void LogEnd(ThreadPoolEvent) {} - void LogEndAndStart(ThreadPoolEvent) {} - void LogStartAndCoreAndBlock(std::ptrdiff_t) {} - void LogCoreAndBlock(std::ptrdiff_t) {} - void LogThreadId(int) {} - void LogRun(int) {} - std::string DumpChildThreadStat() { return {}; } -}; -#else -class ThreadPoolProfiler { - public: - enum ThreadPoolEvent { - DISTRIBUTION = 0, - DISTRIBUTION_ENQUEUE, - RUN, - WAIT, - WAIT_REVOKE, - MAX_EVENT - }; - ThreadPoolProfiler(int num_threads, const CHAR_TYPE* threal_pool_name); - ~ThreadPoolProfiler(); - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ThreadPoolProfiler); - using Clock = std::chrono::high_resolution_clock; - void Start(); // called by executor to start profiling - std::string Stop(); // called by 
executor to stop profiling and return collected numbers - void LogStart(); // called in main thread to record the starting time point - void LogEnd(ThreadPoolEvent); // called in main thread to calculate and save the time elapsed from last start point - void LogEndAndStart(ThreadPoolEvent); - void LogStartAndCoreAndBlock(std::ptrdiff_t block_size); - void LogCoreAndBlock(std::ptrdiff_t block_size); // called in main thread to log core and block size for task breakdown - void LogThreadId(int thread_idx); // called in child thread to log its id - void LogRun(int thread_idx); // called in child thread to log num of run - std::string DumpChildThreadStat(); // return all child statistics collected so far - - private: - static const char* GetEventName(ThreadPoolEvent); - struct MainThreadStat { - uint64_t events_[MAX_EVENT] = {}; - int32_t core_ = -1; - std::vector blocks_; // block size determined by cost model - std::vector points_; - void LogCore(); - void LogBlockSize(std::ptrdiff_t block_size); - void LogStart(); - void LogEnd(ThreadPoolEvent); - void LogEndAndStart(ThreadPoolEvent); - std::string Reset(); - }; - bool enabled_ = false; - MainThreadStat& GetMainThreadStat(); // return thread local stat - int num_threads_; -#ifdef _MSC_VER -#pragma warning(push) - // C4324: structure was padded due to alignment specifier -#pragma warning(disable : 4324) -#endif // _MSC_VER - struct ORT_ALIGN_TO_AVOID_FALSE_SHARING ChildThreadStat { - std::thread::id thread_id_; - uint64_t num_run_ = 0; - onnxruntime::TimePoint last_logged_point_ = Clock::now(); - int32_t core_ = -1; // core that the child thread is running on - }; -#ifdef _MSC_VER -#pragma warning(pop) -#endif // _MSC_VER - std::vector child_thread_stats_; - std::string thread_pool_name_; -}; -#endif - // Extended Eigen thread pool interface, avoiding the need to modify // the ThreadPoolInterface.h header from the external Eigen // repository. 
@@ -335,8 +241,6 @@ class ExtendedThreadPoolInterface : public Eigen::ThreadPoolInterface { // two loops execute in series in a parallel section. ] virtual void RunInParallel(std::function fn, unsigned n, std::ptrdiff_t block_size) = 0; - virtual void StartProfiling() = 0; - virtual std::string StopProfiling() = 0; }; class ThreadPoolParallelSection { @@ -705,7 +609,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter return 0; } - ThreadPoolProfiler profiler_; void SignalAllAndWait() { done_ = true; @@ -720,13 +623,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter } public: - void StartProfiling() override { - profiler_.Start(); - } - std::string StopProfiling() override { - return profiler_.Stop(); - } struct Tag { constexpr Tag() : v_(0) { @@ -767,7 +664,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter ThreadPoolTempl(const CHAR_TYPE* name, int num_threads, bool allow_spinning, Environment& env, const ThreadOptions& thread_options) - : profiler_(num_threads, name), + : env_(env), num_threads_(num_threads), allow_spinning_(allow_spinning), @@ -915,7 +812,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter // tasks that were created (if any) for the parallel section. We // revoke tasks still in queues, and then wait for any that are // still running. - profiler_.LogStart(); unsigned tasks_started = static_cast(ps.tasks.size()); while (!ps.tasks.empty()) { const auto& item = ps.tasks.back(); @@ -925,7 +821,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter } ps.tasks.pop_back(); } - profiler_.LogEnd(ThreadPoolProfiler::WAIT_REVOKE); // Wait for the dispatch task's own work... 
if (ps.dispatch_q_idx > -1) { @@ -1204,7 +1099,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter ps.work_done.store(true, std::memory_order_release); }; - profiler_.LogStart(); ps.dispatch_q_idx = preferred_workers[current_dop] % num_threads_; WorkerData& dispatch_td = worker_data_[ps.dispatch_q_idx]; Queue& dispatch_que = dispatch_td.queue; @@ -1222,7 +1116,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter } else { ps.dispatch_q_idx = -1; // failed to enqueue dispatch_task } - profiler_.LogEnd(ThreadPoolProfiler::DISTRIBUTION_ENQUEUE); } else { // Synchronous dispatch ScheduleOnPreferredWorkers(pt, ps, preferred_workers, current_dop, new_dop, std::move(worker_fn)); @@ -1240,7 +1133,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter unsigned n, std::ptrdiff_t block_size) override { ORT_ENFORCE(n <= num_threads_ + 1, "More work items than threads"); - profiler_.LogStartAndCoreAndBlock(block_size); PerThread* pt = GetPerThread(); assert(pt->leading_par_section && "RunInParallel, but not in parallel section"); assert((n > 1) && "Trivial parallel section; should be avoided by caller"); @@ -1270,18 +1162,15 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter }; RunInParallelInternal(*pt, ps, n, false, std::move(worker_fn)); assert(ps.dispatch_q_idx == -1); - profiler_.LogEndAndStart(ThreadPoolProfiler::DISTRIBUTION); // Run work in the main thread loop.fn(0); - profiler_.LogEndAndStart(ThreadPoolProfiler::RUN); // Wait for workers to exit the loop ps.current_loop = 0; while (ps.workers_in_loop) { onnxruntime::concurrency::SpinPause(); } - profiler_.LogEnd(ThreadPoolProfiler::WAIT); } // Run a single parallel loop _without_ a parallel section. This is a @@ -1298,16 +1187,12 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter // 1. 
run fn(...); void RunInParallel(std::function fn, unsigned n, std::ptrdiff_t block_size) override { ORT_ENFORCE(n <= num_threads_ + 1, "More work items than threads"); - profiler_.LogStartAndCoreAndBlock(block_size); PerThread* pt = GetPerThread(); ThreadPoolParallelSection ps; StartParallelSectionInternal(*pt, ps); RunInParallelInternal(*pt, ps, n, true, fn); // select dispatcher and do job distribution; - profiler_.LogEndAndStart(ThreadPoolProfiler::DISTRIBUTION); fn(0); // run fn(0) - profiler_.LogEndAndStart(ThreadPoolProfiler::RUN); EndParallelSectionInternal(*pt, ps); // wait for all - profiler_.LogEnd(ThreadPoolProfiler::WAIT); } int NumThreads() const final { @@ -1539,7 +1424,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter const int steal_count = spin_count / 100; SetDenormalAsZero(set_denormal_as_zero_); - profiler_.LogThreadId(thread_id); while (!should_exit) { Task t = q.PopFront(); @@ -1632,7 +1516,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter if (t) { td.SetActive(); t(); - profiler_.LogRun(thread_id); td.SetSpinning(); } } diff --git a/src/ort_include/core/platform/env.h b/src/ort_include/core/platform/env.h index e68039e..ff3d78d 100644 --- a/src/ort_include/core/platform/env.h +++ b/src/ort_include/core/platform/env.h @@ -138,43 +138,12 @@ class Env { virtual int GetL2CacheSize() const = 0; - /// Sleeps/delays the thread for the prescribed number of micro-seconds. - /// On Windows, it's the min time to sleep, not the actual one. - virtual void SleepForMicroseconds(int64_t micros) const = 0; - /** - * Gets the length of the specified file. - */ - virtual common::Status GetFileLength(_In_z_ const ORTCHAR_T* file_path, size_t& length) const = 0; - virtual common::Status GetFileLength(int fd, /*out*/ size_t& file_size) const = 0; - - /** - * Copies the content of the file into the provided buffer. - * @param file_path The path to the file. 
- * @param offset The file offset from which to start reading. - * @param length The length in bytes to read. - * @param buffer The buffer in which to write. - */ - virtual common::Status ReadFileIntoBuffer(_In_z_ const ORTCHAR_T* file_path, FileOffsetType offset, size_t length, - gsl::span buffer) const = 0; - /** Gets the canonical form of a file path (symlinks resolved). */ - virtual common::Status GetCanonicalPath( - const PathString& path, - PathString& canonical_path) const = 0; // This functions is always successful. It can't fail. virtual PIDType GetSelfPid() const = 0; - - - // \brief Gets the file path of the onnx runtime code - // - // Used to help load other shared libraries that live in the same folder as the core code, for example - // The DNNL provider shared library. Without this path, the module won't be found on windows in all cases. - virtual PathString GetRuntimePath() const { return PathString(); } - - // \brief returns a value for the queried variable name (var_name) // diff --git a/src/ort_include/core/platform/threadpool.h b/src/ort_include/core/platform/threadpool.h index 04df6dc..ad5c7a1 100644 --- a/src/ort_include/core/platform/threadpool.h +++ b/src/ort_include/core/platform/threadpool.h @@ -360,11 +360,7 @@ class ThreadPool { // working in combination with the thread initiating the loop. 
static int DegreeOfParallelism(const ThreadPool* tp); - ORT_DISALLOW_COPY_AND_ASSIGNMENT(ThreadPool); - - // StartProfiling and StopProfiling are not to be consumed as public-facing API - static void StartProfiling(concurrency::ThreadPool* tp); - static std::string StopProfiling(concurrency::ThreadPool* tp); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ThreadPool); private: friend class LoopCounter; @@ -411,10 +407,6 @@ class ThreadPool { void Schedule(std::function fn); - void StartProfiling(); - - std::string StopProfiling(); - ThreadOptions thread_options_; // If a thread pool is created with degree_of_parallelism != 1 then an underlying diff --git a/src/ort_include/core/platform/windows/TraceLoggingConfig.h b/src/ort_include/core/platform/windows/TraceLoggingConfig.h deleted file mode 100644 index 2987167..0000000 --- a/src/ort_include/core/platform/windows/TraceLoggingConfig.h +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -/* ++ -Module Name: - TraceLoggingConfig.h -Abstract: - Macro definitions used by this project's TraceLogging ETW providers: - - Configuration macros that select the ETW Provider Groups to be used by - this project. - - Constants for tags that are commonly used in Microsoft's - TraceLogging-based ETW. - Different versions of this file use different definitions for the - TraceLoggingOption configuration macros. The definitions in this file are - empty. As a result, providers using this configuration file will not join - any ETW Provider Groups and will not be given any special treatment by - group-sensitive ETW listeners. -Environment: - User mode or kernel mode. ---*/ - -#pragma once - -// Configuration macro for use in TRACELOGGING_DEFINE_PROVIDER. The definition -// in this file configures the provider as a normal (non-telemetry) provider. 
-#ifndef TraceLoggingOptionMicrosoftTelemetry -#define TraceLoggingOptionMicrosoftTelemetry() \ - TraceLoggingOptionGroup(0000000000, 00000, 00000, 0000, 0000, 0000, 0000, 0000, 000, 0000, 0000) -// Empty definition for TraceLoggingOptionMicrosoftTelemetry -#endif - -// Configuration macro for use in TRACELOGGING_DEFINE_PROVIDER. The definition -// in this file configures the provider as a normal (non-telemetry) provider. -#define TraceLoggingOptionWindowsCoreTelemetry() \ - // Empty definition for TraceLoggingOptionWindowsCoreTelemetry - -// Event privacy tags. Use the PDT macro values for the tag parameter, e.g.: -// TraceLoggingWrite(..., -// TelemetryPrivacyDataTag(PDT_BrowsingHistory | PDT_ProductAndServiceUsage), -// ...); -#define TelemetryPrivacyDataTag(tag) TraceLoggingUInt64((tag), "PartA_PrivTags") -#define PDT_BrowsingHistory 0x0000000000000002u -#define PDT_DeviceConnectivityAndConfiguration 0x0000000000000800u -#define PDT_InkingTypingAndSpeechUtterance 0x0000000000020000u -#define PDT_ProductAndServicePerformance 0x0000000001000000u -#define PDT_ProductAndServiceUsage 0x0000000002000000u -#define PDT_SoftwareSetupAndInventory 0x0000000080000000u - -// Event categories specified via keywords, e.g.: -// TraceLoggingWrite(..., -// TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES), -// ...); -#define MICROSOFT_KEYWORD_CRITICAL_DATA 0x0000800000000000 // Bit 47 -#define MICROSOFT_KEYWORD_MEASURES 0x0000400000000000 // Bit 46 -#define MICROSOFT_KEYWORD_TELEMETRY 0x0000200000000000 // Bit 45 -#define MICROSOFT_KEYWORD_RESERVED_44 0x0000100000000000 // Bit 44 (reserved for future assignment) - -// Event categories specified via event tags, e.g.: -// TraceLoggingWrite(..., -// TraceLoggingEventTag(MICROSOFT_EVENTTAG_REALTIME_LATENCY), -// ...); -#define MICROSOFT_EVENTTAG_DROP_USER_IDS 0x00008000 -#define MICROSOFT_EVENTTAG_AGGREGATE 0x00010000 -#define MICROSOFT_EVENTTAG_DROP_PII_EXCEPT_IP 0x00020000 -#define MICROSOFT_EVENTTAG_COSTDEFERRED_LATENCY 0x00040000 
-#define MICROSOFT_EVENTTAG_CORE_DATA 0x00080000 -#define MICROSOFT_EVENTTAG_INJECT_XTOKEN 0x00100000 -#define MICROSOFT_EVENTTAG_REALTIME_LATENCY 0x00200000 -#define MICROSOFT_EVENTTAG_NORMAL_LATENCY 0x00400000 -#define MICROSOFT_EVENTTAG_CRITICAL_PERSISTENCE 0x00800000 -#define MICROSOFT_EVENTTAG_NORMAL_PERSISTENCE 0x01000000 -#define MICROSOFT_EVENTTAG_DROP_PII 0x02000000 -#define MICROSOFT_EVENTTAG_HASH_PII 0x04000000 -#define MICROSOFT_EVENTTAG_MARK_PII 0x08000000 - -// Field categories specified via field tags, e.g.: -// TraceLoggingWrite(..., -// TraceLoggingString(szUser, "UserName", "User's name", MICROSOFT_FIELDTAG_HASH_PII), -// ...); -#define MICROSOFT_FIELDTAG_DROP_PII 0x04000000 -#define MICROSOFT_FIELDTAG_HASH_PII 0x08000000 diff --git a/src/ort_include/core/platform/windows/readme.txt b/src/ort_include/core/platform/windows/readme.txt deleted file mode 100644 index f1a436f..0000000 --- a/src/ort_include/core/platform/windows/readme.txt +++ /dev/null @@ -1,2 +0,0 @@ -copied from minkernel/published/internal/telemetry/open_source/TraceLoggingConfig.h -this is the official open source edition for these configuration settings \ No newline at end of file diff --git a/src/ort_include/core/session/onnxruntime_c_api.h b/src/ort_include/core/session/onnxruntime_c_api.h index a2f518a..a3ae440 100644 --- a/src/ort_include/core/session/onnxruntime_c_api.h +++ b/src/ort_include/core/session/onnxruntime_c_api.h @@ -34,15 +34,7 @@ #include #include -/** \brief The API version defined in this header - * - * This value is used by some API functions to behave as this version of the header expects. - */ -#define ORT_API_VERSION 23 -#ifdef __cplusplus -extern "C" { -#endif //! @} // SAL2 Definitions @@ -111,6136 +103,4 @@ extern "C" { #endif #endif -// On Windows, ORT_FILE is a wchar_t version of the __FILE__ macro. -// Otherwise, ORT_FILE is equivalent to __FILE__. 
-#ifndef ORT_FILE -#define ORT_FILE_INTERNAL(x) ORT_TSTR(x) -#define ORT_FILE ORT_FILE_INTERNAL(__FILE__) -#endif - -// Any pointer marked with _In_ or _Out_, cannot be NULL. - -// Windows users should use unicode paths when possible to bypass the MAX_PATH limitation -// Every pointer marked with _In_ or _Out_, cannot be NULL. Caller should ensure that. -// for ReleaseXXX(...) functions, they can accept NULL pointer. - -#ifdef __cplusplus -// For any compiler with C++11 support, MSVC 2015 and greater, or Clang version supporting noexcept. -// Such complex condition is needed because compilers set __cplusplus value differently. -#ifndef __has_feature -#define __has_feature(x) 0 -#endif -#if ((__cplusplus >= 201103L) || (_MSC_VER >= 1900) || (defined(__has_feature) && __has_feature(cxx_noexcept))) -#define NO_EXCEPTION noexcept -#else -#define NO_EXCEPTION throw() -#endif -#else -#define NO_EXCEPTION -#endif - -// __VA_ARGS__ on Windows and Linux are different -#define ORT_API(RETURN_TYPE, NAME, ...) RETURN_TYPE ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION - -#define ORT_API_STATUS(NAME, ...) \ - _Success_(return == 0) _Check_return_ _Ret_maybenull_ OrtStatusPtr ORT_API_CALL NAME(__VA_ARGS__) \ - NO_EXCEPTION ORT_MUST_USE_RESULT - -// XXX: Unfortunately, SAL annotations are known to not work with function pointers -#define ORT_API2_STATUS(NAME, ...) \ - _Check_return_ _Ret_maybenull_ OrtStatusPtr(ORT_API_CALL* NAME)(__VA_ARGS__) NO_EXCEPTION ORT_MUST_USE_RESULT - -// Used in *.cc files. Almost as same as ORT_API_STATUS, except without ORT_MUST_USE_RESULT and ORT_EXPORT -#define ORT_API_STATUS_IMPL(NAME, ...) \ - _Success_(return == 0) _Check_return_ _Ret_maybenull_ OrtStatusPtr ORT_API_CALL NAME(__VA_ARGS__) NO_EXCEPTION - -#define ORT_CLASS_RELEASE(X) void(ORT_API_CALL * Release##X)(_Frees_ptr_opt_ Ort##X * input) - -#ifdef __DOXYGEN__ -#undef ORT_API_STATUS -#define ORT_API_STATUS(NAME, ...) 
OrtStatus* NAME(__VA_ARGS__) -#undef ORT_API2_STATUS -#define ORT_API2_STATUS(NAME, ...) OrtStatus* NAME(__VA_ARGS__) -#undef ORT_CLASS_RELEASE -#define ORT_CLASS_RELEASE(X) void Release##X(Ort##X* input) -#undef NO_EXCEPTION -#define NO_EXCEPTION -#endif -/** \addtogroup Global - * ONNX Runtime C API - * @{ - */ - -/** Copied from TensorProto::DataType - * Currently, Ort doesn't support complex64, complex128 - */ -typedef enum ONNXTensorElementDataType { - ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED, - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, // maps to c type float - ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8, // maps to c type uint8_t - ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8, // maps to c type int8_t - ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16, // maps to c type uint16_t - ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16, // maps to c type int16_t - ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32, // maps to c type int32_t - ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64, // maps to c type int64_t - ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING, // maps to c++ type std::string - ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL, - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16, - ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE, // maps to c type double - ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32, // maps to c type uint32_t - ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64, // maps to c type uint64_t - ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64, // complex with float32 real and imaginary components - ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128, // complex with float64 real and imaginary components - ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16, // Non-IEEE floating-point format based on IEEE754 single-precision - // float 8 types were introduced in onnx 1.14, see https://onnx.ai/onnx/technical/float8.html - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN, // Non-IEEE floating-point format based on IEEE754 single-precision - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2, // Non-IEEE 
floating-point format based on IEEE754 single-precision - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ, // Non-IEEE floating-point format based on IEEE754 single-precision - // Int4 types were introduced in ONNX 1.16. See https://onnx.ai/onnx/technical/int4.html - ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT4, // maps to a pair of packed uint4 values (size == 1 byte) - ONNX_TENSOR_ELEMENT_DATA_TYPE_INT4 // maps to a pair of packed int4 values (size == 1 byte) -} ONNXTensorElementDataType; - -// Synced with onnx TypeProto oneof -typedef enum ONNXType { - ONNX_TYPE_UNKNOWN, - ONNX_TYPE_TENSOR, - ONNX_TYPE_SEQUENCE, - ONNX_TYPE_MAP, - ONNX_TYPE_OPAQUE, - ONNX_TYPE_SPARSETENSOR, - ONNX_TYPE_OPTIONAL -} ONNXType; - -// These types are synced with internal -// SparseFormatFlags -typedef enum OrtSparseFormat { - ORT_SPARSE_UNDEFINED = 0, - ORT_SPARSE_COO = 0x1, - ORT_SPARSE_CSRC = 0x2, - ORT_SPARSE_BLOCK_SPARSE = 0x4 -} OrtSparseFormat; - -// Enum allows to query sparse tensor indices -enum OrtSparseIndicesFormat { - ORT_SPARSE_COO_INDICES, - ORT_SPARSE_CSR_INNER_INDICES, - ORT_SPARSE_CSR_OUTER_INDICES, - ORT_SPARSE_BLOCK_SPARSE_INDICES -}; - -/** \brief Logging severity levels - * - * In typical API usage, specifying a logging severity level specifies the minimum severity of log messages to show. - */ -typedef enum OrtLoggingLevel { - ORT_LOGGING_LEVEL_VERBOSE, ///< Verbose informational messages (least severe). - ORT_LOGGING_LEVEL_INFO, ///< Informational messages. - ORT_LOGGING_LEVEL_WARNING, ///< Warning messages. - ORT_LOGGING_LEVEL_ERROR, ///< Error messages. - ORT_LOGGING_LEVEL_FATAL, ///< Fatal error messages (most severe). 
-} OrtLoggingLevel; - -typedef enum OrtErrorCode { - ORT_OK, - ORT_FAIL, - ORT_INVALID_ARGUMENT, - ORT_NO_SUCHFILE, - ORT_NO_MODEL, - ORT_ENGINE_ERROR, - ORT_RUNTIME_EXCEPTION, - ORT_INVALID_PROTOBUF, - ORT_MODEL_LOADED, - ORT_NOT_IMPLEMENTED, - ORT_INVALID_GRAPH, - ORT_EP_FAIL, - ORT_MODEL_LOAD_CANCELED, - ORT_MODEL_REQUIRES_COMPILATION, -} OrtErrorCode; - -typedef enum OrtOpAttrType { - ORT_OP_ATTR_UNDEFINED = 0, - ORT_OP_ATTR_INT, - ORT_OP_ATTR_INTS, - ORT_OP_ATTR_FLOAT, - ORT_OP_ATTR_FLOATS, - ORT_OP_ATTR_STRING, - ORT_OP_ATTR_STRINGS, -} OrtOpAttrType; - -//! @} -#define ORT_RUNTIME_CLASS(X) \ - struct Ort##X; \ - typedef struct Ort##X Ort##X - -/** \addtogroup Global - * ONNX Runtime C API - * @{ - */ -// The actual types defined have an Ort prefix -ORT_RUNTIME_CLASS(Env); -ORT_RUNTIME_CLASS(Status); // nullptr for Status* indicates success -ORT_RUNTIME_CLASS(MemoryInfo); -ORT_RUNTIME_CLASS(IoBinding); -ORT_RUNTIME_CLASS(Session); // Don't call ReleaseSession from Dllmain (because session owns a thread pool) -ORT_RUNTIME_CLASS(Value); -ORT_RUNTIME_CLASS(RunOptions); -ORT_RUNTIME_CLASS(TypeInfo); -ORT_RUNTIME_CLASS(TensorTypeAndShapeInfo); -ORT_RUNTIME_CLASS(MapTypeInfo); -ORT_RUNTIME_CLASS(SequenceTypeInfo); -ORT_RUNTIME_CLASS(OptionalTypeInfo); -ORT_RUNTIME_CLASS(SessionOptions); -ORT_RUNTIME_CLASS(CustomOpDomain); -ORT_RUNTIME_CLASS(ModelMetadata); -ORT_RUNTIME_CLASS(ThreadPoolParams); -ORT_RUNTIME_CLASS(ThreadingOptions); -ORT_RUNTIME_CLASS(ArenaCfg); -ORT_RUNTIME_CLASS(PrepackedWeightsContainer); -ORT_RUNTIME_CLASS(TensorRTProviderOptionsV2); -ORT_RUNTIME_CLASS(NvTensorRtRtxProviderOptions); -ORT_RUNTIME_CLASS(CUDAProviderOptionsV2); -ORT_RUNTIME_CLASS(CANNProviderOptions); -ORT_RUNTIME_CLASS(DnnlProviderOptions); -ORT_RUNTIME_CLASS(Op); -ORT_RUNTIME_CLASS(OpAttr); -ORT_RUNTIME_CLASS(Logger); -ORT_RUNTIME_CLASS(ShapeInferContext); -ORT_RUNTIME_CLASS(LoraAdapter); -ORT_RUNTIME_CLASS(ValueInfo); -ORT_RUNTIME_CLASS(Node); -ORT_RUNTIME_CLASS(Graph); 
-ORT_RUNTIME_CLASS(Model); -ORT_RUNTIME_CLASS(ModelCompilationOptions); -ORT_RUNTIME_CLASS(HardwareDevice); -ORT_RUNTIME_CLASS(EpDevice); -ORT_RUNTIME_CLASS(KeyValuePairs); - -#ifdef _MSC_VER -typedef _Return_type_success_(return == 0) OrtStatus* OrtStatusPtr; -#else -typedef OrtStatus* OrtStatusPtr; -#endif - -/** \brief Memory allocation interface - * - * Structure of function pointers that defines a memory allocator. This can be created and filled in by the user for custom allocators. - * - * When an allocator is passed to any function, be sure that the allocator object is not destroyed until the last allocated object using it is freed. - */ -typedef struct OrtAllocator { - uint32_t version; ///< Must be initialized to ORT_API_VERSION - void*(ORT_API_CALL* Alloc)(struct OrtAllocator* this_, size_t size); ///< Returns a pointer to an allocated block of `size` bytes - void(ORT_API_CALL* Free)(struct OrtAllocator* this_, void* p); ///< Free a block of memory previously allocated with OrtAllocator::Alloc - const struct OrtMemoryInfo*(ORT_API_CALL* Info)(const struct OrtAllocator* this_); ///< Return a pointer to an ::OrtMemoryInfo that describes this allocator - /** - * @brief Optional allocation function to use for memory allocations made during session initialization. - * Use this function if you want to separate allocations made by ORT during Run() calls from - * those made during session initialization. This allows for separate memory management strategies for these allocations. 
- */ - void*(ORT_API_CALL* Reserve)(struct OrtAllocator* this_, size_t size); ///< Returns a pointer to an allocated block of `size` bytes -} OrtAllocator; - -typedef void(ORT_API_CALL* OrtLoggingFunction)( - void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, - const char* message); - -/** \brief Graph optimization level - * - * Refer to https://www.onnxruntime.ai/docs/performance/graph-optimizations.html#graph-optimization-levels - * for an in-depth understanding of the Graph Optimization Levels. - */ -typedef enum GraphOptimizationLevel { - ORT_DISABLE_ALL = 0, - ORT_ENABLE_BASIC = 1, - ORT_ENABLE_EXTENDED = 2, - ORT_ENABLE_ALL = 99 -} GraphOptimizationLevel; - -typedef enum ExecutionMode { - ORT_SEQUENTIAL = 0, - ORT_PARALLEL = 1, -} ExecutionMode; - -/** \brief Language projection identifiers - * /see OrtApi::SetLanguageProjection - */ -typedef enum OrtLanguageProjection { - ORT_PROJECTION_C = 0, - ORT_PROJECTION_CPLUSPLUS = 1, - ORT_PROJECTION_CSHARP = 2, - ORT_PROJECTION_PYTHON = 3, - ORT_PROJECTION_JAVA = 4, - ORT_PROJECTION_WINML = 5, - ORT_PROJECTION_NODEJS = 6, -} OrtLanguageProjection; - -struct OrtKernelInfo; -typedef struct OrtKernelInfo OrtKernelInfo; -struct OrtKernelContext; -typedef struct OrtKernelContext OrtKernelContext; -struct OrtCustomOp; -typedef struct OrtCustomOp OrtCustomOp; - -typedef enum OrtAllocatorType { - OrtInvalidAllocator = -1, - OrtDeviceAllocator = 0, - OrtArenaAllocator = 1 -} OrtAllocatorType; - -/** \brief Memory types for allocated memory, execution provider specific types should be extended in each provider. - */ -// Whenever this struct is updated, please also update the MakeKey function in onnxruntime / core / framework / execution_provider.cc -typedef enum OrtMemType { - OrtMemTypeCPUInput = -2, ///< Any CPU memory used by non-CPU execution provider - OrtMemTypeCPUOutput = -1, ///< CPU accessible memory outputted by non-CPU execution provider, i.e. 
CUDA_PINNED - OrtMemTypeCPU = OrtMemTypeCPUOutput, ///< Temporary CPU accessible memory allocated by non-CPU execution provider, i.e. CUDA_PINNED - OrtMemTypeDefault = 0, ///< The default allocator for execution provider -} OrtMemType; - -/** \brief This mimics OrtDevice type constants so they can be returned in the API - */ -typedef enum OrtMemoryInfoDeviceType { - OrtMemoryInfoDeviceType_CPU = 0, - OrtMemoryInfoDeviceType_GPU = 1, - OrtMemoryInfoDeviceType_FPGA = 2 -} OrtMemoryInfoDeviceType; - -typedef enum OrtHardwareDeviceType { - OrtHardwareDeviceType_CPU, - OrtHardwareDeviceType_GPU, - OrtHardwareDeviceType_NPU -} OrtHardwareDeviceType; - -/** \brief These are the default EP selection policies used by ORT when doing automatic EP selection. - */ -typedef enum OrtExecutionProviderDevicePolicy { - OrtExecutionProviderDevicePolicy_DEFAULT, - OrtExecutionProviderDevicePolicy_PREFER_CPU, - OrtExecutionProviderDevicePolicy_PREFER_NPU, - OrtExecutionProviderDevicePolicy_PREFER_GPU, - OrtExecutionProviderDevicePolicy_MAX_PERFORMANCE, - OrtExecutionProviderDevicePolicy_MAX_EFFICIENCY, - OrtExecutionProviderDevicePolicy_MIN_OVERALL_POWER, -} OrtExecutionProviderDevicePolicy; - -/** \brief Delegate to allow providing custom OrtEpDevice selection logic - * - * This delegate is called by the EP selection code to allow the user to provide custom device selection logic. - * The user can use this to select OrtEpDevice instances from the list of available devices. - * - * \param ep_devices The list of available devices. - * \param num_devices The number of available devices. - * \param model_metadata The model metadata. - * \param runtime_metadata The runtime metadata. May be nullptr. - * \param selected Pre-allocated array to populate with selected OrtEpDevice pointers from ep_devices. - * \param max_selected The maximum number of devices that can be selected in the pre-allocated array. - Currently the maximum is 8. - * \param num_selected The number of selected devices. 
- * \param state Opaque pointer. Required to use the delegate from other languages like C# and python. - * - * \return OrtStatus* Selection status. Return nullptr on success. - * Use CreateStatus to provide error info. Use ORT_FAIL as the error code. - * ORT will release the OrtStatus* if not null. - */ -typedef OrtStatus*(ORT_API_CALL* EpSelectionDelegate)(_In_ const OrtEpDevice** ep_devices, - _In_ size_t num_devices, - _In_ const OrtKeyValuePairs* model_metadata, - _In_opt_ const OrtKeyValuePairs* runtime_metadata, - _Inout_ const OrtEpDevice** selected, - _In_ size_t max_selected, - _Out_ size_t* num_selected, - _In_ void* state); - -/** \brief Algorithm to use for cuDNN Convolution Op - */ -typedef enum OrtCudnnConvAlgoSearch { - OrtCudnnConvAlgoSearchExhaustive, // expensive exhaustive benchmarking using cudnnFindConvolutionForwardAlgorithmEx - OrtCudnnConvAlgoSearchHeuristic, // lightweight heuristic based search using cudnnGetConvolutionForwardAlgorithm_v7 - OrtCudnnConvAlgoSearchDefault, // default algorithm using CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM -} OrtCudnnConvAlgoSearch; - -/** \brief CUDA Provider Options - * - * \see OrtApi::SessionOptionsAppendExecutionProvider_CUDA - */ -typedef struct OrtCUDAProviderOptions { -#ifdef __cplusplus - OrtCUDAProviderOptions() - : device_id{}, - cudnn_conv_algo_search{OrtCudnnConvAlgoSearchExhaustive}, - gpu_mem_limit{SIZE_MAX}, - arena_extend_strategy{}, - do_copy_in_default_stream{1}, - has_user_compute_stream{}, - user_compute_stream{}, - default_memory_arena_cfg{}, - tunable_op_enable{false}, - tunable_op_tuning_enable{false}, - tunable_op_max_tuning_duration_ms{} {} -#endif - - /** \brief CUDA device Id - * Defaults to 0. - */ - int device_id; - - /** \brief CUDA Convolution algorithm search configuration. - * See enum OrtCudnnConvAlgoSearch for more details. - * Defaults to OrtCudnnConvAlgoSearchExhaustive. 
- */ - OrtCudnnConvAlgoSearch cudnn_conv_algo_search; - - /** \brief CUDA memory limit (To use all possible memory pass in maximum size_t) - * Defaults to SIZE_MAX. - * \note If a ::OrtArenaCfg has been applied, it will override this field - */ - size_t gpu_mem_limit; - - /** \brief Strategy used to grow the memory arena - * 0 = kNextPowerOfTwo
- * 1 = kSameAsRequested
- * Defaults to 0. - * \note If a ::OrtArenaCfg has been applied, it will override this field - */ - int arena_extend_strategy; - - /** \brief Flag indicating if copying needs to take place on the same stream as the compute stream in the CUDA EP - * 0 = Use separate streams for copying and compute. - * 1 = Use the same stream for copying and compute. - * Defaults to 1. - * WARNING: Setting this to 0 may result in data races for some models. - * Please see issue #4829 for more details. - */ - int do_copy_in_default_stream; - - /** \brief Flag indicating if there is a user provided compute stream - * Defaults to 0. - */ - int has_user_compute_stream; - - /** \brief User provided compute stream. - * If provided, please set `has_user_compute_stream` to 1. - */ - void* user_compute_stream; - - /** \brief CUDA memory arena configuration parameters - */ - OrtArenaCfg* default_memory_arena_cfg; - - /** \brief Enable TunableOp for using. - * Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default. - * This option can be overridden by environment variable ORT_CUDA_TUNABLE_OP_ENABLE. - */ - int tunable_op_enable; - - /** \brief Enable TunableOp for tuning. - * Set it to 1/0 to enable/disable TunableOp tuning. Otherwise, it is disabled by default. - * This option can be overridden by environment variable ORT_CUDA_TUNABLE_OP_TUNING_ENABLE. - */ - int tunable_op_tuning_enable; - - /** \brief Max tuning duration time limit for each instance of TunableOp. - * Defaults to 0 to disable the limit. 
- */ - int tunable_op_max_tuning_duration_ms; - -} OrtCUDAProviderOptions; - -/** \brief ROCM Provider Options - * - * \see OrtApi::SessionOptionsAppendExecutionProvider_ROCM - */ -typedef struct OrtROCMProviderOptions { -#ifdef __cplusplus - OrtROCMProviderOptions() - : device_id{}, - miopen_conv_exhaustive_search{0}, - gpu_mem_limit{SIZE_MAX}, - arena_extend_strategy{}, - do_copy_in_default_stream{1}, - has_user_compute_stream{}, - user_compute_stream{}, - default_memory_arena_cfg{}, - enable_hip_graph{false}, - tunable_op_enable{false}, - tunable_op_tuning_enable{false}, - tunable_op_max_tuning_duration_ms{} {} -#endif - - /** \brief ROCM device Id - * Defaults to 0. - */ - int device_id; - - /** \brief ROCM MIOpen Convolution algorithm exhaustive search option. - * Defaults to 0 (false). - */ - int miopen_conv_exhaustive_search; - - /** \brief ROCM memory limit (To use all possible memory pass in maximum size_t) - * Defaults to SIZE_MAX. - * \note If a ::OrtArenaCfg has been applied, it will override this field - */ - size_t gpu_mem_limit; - - /** \brief Strategy used to grow the memory arena - * 0 = kNextPowerOfTwo
- * 1 = kSameAsRequested
- * Defaults to 0. - * \note If a ::OrtArenaCfg has been applied, it will override this field - */ - int arena_extend_strategy; - - /** \brief Flag indicating if copying needs to take place on the same stream as the compute stream in the ROCM EP - * 0 = Use separate streams for copying and compute. - * 1 = Use the same stream for copying and compute. - * Defaults to 1. - * WARNING: Setting this to 0 may result in data races for some models. - * Please see issue #4829 for more details. - */ - int do_copy_in_default_stream; - - /** \brief Flag indicating if there is a user provided compute stream - * Defaults to 0. - */ - int has_user_compute_stream; - - /** \brief User provided compute stream. - * If provided, please set `has_user_compute_stream` to 1. - */ - void* user_compute_stream; - - /** \brief ROCM memory arena configuration parameters - */ - OrtArenaCfg* default_memory_arena_cfg; - - int enable_hip_graph; - - /** \brief Enable TunableOp for using. - * Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default. - * This option can be overridden by environment variable ORT_ROCM_TUNABLE_OP_ENABLE. - */ - int tunable_op_enable; - - /** \brief Enable TunableOp for tuning. - * Set it to 1/0 to enable/disable TunableOp tuning. Otherwise, it is disabled by default. - * This option can be overridden by environment variable ORT_ROCM_TUNABLE_OP_TUNING_ENABLE. - */ - int tunable_op_tuning_enable; - - /** \brief Max tuning duration time limit for each instance of TunableOp. - * Defaults to 0 to disable the limit. - */ - int tunable_op_max_tuning_duration_ms; - -} OrtROCMProviderOptions; - -/** \brief TensorRT Provider Options - * - * \see OrtApi::SessionOptionsAppendExecutionProvider_TensorRT - */ -typedef struct OrtTensorRTProviderOptions { - int device_id; ///< CUDA device id (0 = default device) - int has_user_compute_stream; // indicator of user specified CUDA compute stream. - void* user_compute_stream; // user specified CUDA compute stream. 
- int trt_max_partition_iterations; // maximum iterations for TensorRT parser to get capability - int trt_min_subgraph_size; // minimum size of TensorRT subgraphs - size_t trt_max_workspace_size; // maximum workspace size for TensorRT. - int trt_fp16_enable; // enable TensorRT FP16 precision. Default 0 = false, nonzero = true - int trt_int8_enable; // enable TensorRT INT8 precision. Default 0 = false, nonzero = true - const char* trt_int8_calibration_table_name; // TensorRT INT8 calibration table name. - int trt_int8_use_native_calibration_table; // use native TensorRT generated calibration table. Default 0 = false, nonzero = true - int trt_dla_enable; // enable DLA. Default 0 = false, nonzero = true - int trt_dla_core; // DLA core number. Default 0 - int trt_dump_subgraphs; // dump TRT subgraph. Default 0 = false, nonzero = true - int trt_engine_cache_enable; // enable engine caching. Default 0 = false, nonzero = true - const char* trt_engine_cache_path; // specify engine cache path - int trt_engine_decryption_enable; // enable engine decryption. Default 0 = false, nonzero = true - const char* trt_engine_decryption_lib_path; // specify engine decryption library path - int trt_force_sequential_engine_build; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true - // This is the legacy struct and don't add new fields here. - // For new field that can be represented by string, please add it in include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h - // For non-string field, need to create a new separate api to handle it. -} OrtTensorRTProviderOptions; - -/** \brief MIGraphX Provider Options - * - * \see OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX - */ -typedef struct OrtMIGraphXProviderOptions { - int device_id; // hip device id. - int migraphx_fp16_enable; // MIGraphX FP16 precision. Default 0 = false, nonzero = true - int migraphx_int8_enable; // MIGraphX INT8 precision. 
Default 0 = false, nonzero = true - int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true - const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name - int migraphx_save_compiled_model; // migraphx save compiled model. Default 0 = false, noznero = true - const char* migraphx_save_model_path; // migraphx model path name - int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, noznero = true - const char* migraphx_load_model_path; // migraphx model path name - bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false -} OrtMIGraphXProviderOptions; - -/** \brief OpenVINO Provider Options - * \brief This Struct is frozen since ORT 1.13.0. Its maintained part of Legacy API for compatibility. - * \brief For latest OpenVINO Provider Options update to the ProviderOptions map. - * \brief Latest OpenVINO Provider Options are listed in the - * \htmlonly - * onnxruntime document. 
- * \endhtmlonly - * \see OrtApi::SessionOptionsAppendExecutionProvider() - */ -typedef struct OrtOpenVINOProviderOptions { -#ifdef __cplusplus - OrtOpenVINOProviderOptions() : device_type{}, - enable_npu_fast_compile{}, - device_id{}, - num_of_threads{}, - cache_dir{}, - context{}, - enable_opencl_throttling{}, - enable_dynamic_shapes{} {} -#endif - /** \brief Device type string - * - * Valid settings are one of: "CPU_FP32", "CPU_FP16", "GPU_FP32", "GPU_FP16" - */ - const char* device_type; - unsigned char enable_npu_fast_compile; ///< 0 = disabled, nonzero = enabled - const char* device_id; - size_t num_of_threads; ///< 0 = Use default number of threads - const char* cache_dir; // path is set to empty by default - void* context; - unsigned char enable_opencl_throttling; ///< 0 = disabled, nonzero = enabled - unsigned char enable_dynamic_shapes; ///< 0 = disabled, nonzero = enabled -} OrtOpenVINOProviderOptions; - -struct OrtApi; -typedef struct OrtApi OrtApi; - -struct OrtTrainingApi; -typedef struct OrtTrainingApi OrtTrainingApi; - -struct OrtModelEditorApi; -typedef struct OrtModelEditorApi OrtModelEditorApi; - -struct OrtCompileApi; -typedef struct OrtCompileApi OrtCompileApi; - -struct OrtEpApi; -typedef struct OrtEpApi OrtEpApi; - -/** \brief The helper interface to get the right version of OrtApi - * - * Get a pointer to this structure through ::OrtGetApiBase - */ -struct OrtApiBase { - /** \brief Get a pointer to the requested version of the ::OrtApi - * - * \param[in] version Must be ::ORT_API_VERSION - * \return The ::OrtApi for the version requested, nullptr will be returned if this version is unsupported, for example when using a runtime - * older than the version created with this header file. - * - * One can call GetVersionString() to get the version of the Onnxruntime library for logging - * and error reporting purposes. 
- */ - const OrtApi*(ORT_API_CALL* GetApi)(uint32_t version)NO_EXCEPTION; - - /** \brief Returns a null terminated string of the version of the Onnxruntime library (eg: "1.8.1") - * - * \return UTF-8 encoded version string. Do not deallocate the returned buffer. - */ - const char*(ORT_API_CALL* GetVersionString)(void)NO_EXCEPTION; -}; - -typedef struct OrtApiBase OrtApiBase; - -/** \brief The Onnxruntime library's entry point to access the C API - * - * Call this to get the a pointer to an ::OrtApiBase - */ -ORT_EXPORT const OrtApiBase* ORT_API_CALL OrtGetApiBase(void) NO_EXCEPTION; - -/** \brief Thread work loop function - * - * Onnxruntime will provide the working loop on custom thread creation - * Argument is an onnxruntime built-in type which will be provided when thread pool calls OrtCustomCreateThreadFn - */ -typedef void (*OrtThreadWorkerFn)(void* ort_worker_fn_param); - -typedef const struct OrtCustomHandleType { - char __place_holder; -}* OrtCustomThreadHandle; - -/** \brief Ort custom thread creation function - * - * The function should return a thread handle to be used in onnxruntime thread pools - * Onnxruntime will throw exception on return value of nullptr or 0, indicating that the function failed to create a thread - */ -typedef OrtCustomThreadHandle (*OrtCustomCreateThreadFn)(void* ort_custom_thread_creation_options, OrtThreadWorkerFn ort_thread_worker_fn, void* ort_worker_fn_param); - -/** \brief Custom thread join function - * - * Onnxruntime thread pool destructor will call the function to join a custom thread. 
- * Argument ort_custom_thread_handle is the value returned by OrtCustomCreateThreadFn - */ -typedef void (*OrtCustomJoinThreadFn)(OrtCustomThreadHandle ort_custom_thread_handle); - -typedef OrtStatus*(ORT_API_CALL* RegisterCustomOpsFn)(OrtSessionOptions* options, const OrtApiBase* api); - -/** \brief Callback function for RunAsync - * - * \param[in] user_data User specific data that passed back to the callback - * \param[out] outputs On succeed, outputs host inference results, on error, the value will be nullptr - * \param[out] num_outputs Number of outputs, on error, the value will be zero - * \param[out] status On error, status will provide details - */ -typedef void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status); - -/** \brief The C API - * - * All C API functions are defined inside this structure as pointers to functions. - * Call OrtApiBase::GetApi to get a pointer to it - * - * \nosubgrouping - */ -struct OrtApi { - /// \name OrtStatus - /// @{ - - /** - * \brief Create an OrtStatus from a null terminated string - * - * \param[in] code - * \param[in] msg A null-terminated string. Its contents will be copied. - * \return A new OrtStatus object, must be destroyed with OrtApi::ReleaseStatus - */ - OrtStatus*(ORT_API_CALL* CreateStatus)(OrtErrorCode code, _In_ const char* msg)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; - - /** \brief Get OrtErrorCode from OrtStatus - * - * \param[in] status - * \return OrtErrorCode that \p status was created with - */ - OrtErrorCode(ORT_API_CALL* GetErrorCode)(_In_ const OrtStatus* status) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; - - /** \brief Get error string from OrtStatus - * - * \param[in] status - * \return The error message inside the `status`. Do not free the returned value. 
- */ - const char*(ORT_API_CALL* GetErrorMessage)(_In_ const OrtStatus* status)NO_EXCEPTION ORT_ALL_ARGS_NONNULL; - - /// @} - /// \name OrtEnv - /// @{ - - /** \brief Create an OrtEnv - * - * \note Invoking this function will return the same instance of the environment as that returned by a previous call - * to another env creation function; all arguments to this function will be ignored. - * \param[in] log_severity_level The log severity level. - * \param[in] logid The log identifier. - * \param[out] out Returned newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateEnv, OrtLoggingLevel log_severity_level, _In_ const char* logid, _Outptr_ OrtEnv** out); - - /** \brief Create an OrtEnv - * - * \note Invoking this function will return the same instance of the environment as that returned by a previous call - * to another env creation function; all arguments to this function will be ignored. If you want to provide your - * own logging function, consider setting it using the SetUserLoggingFunction API instead. - * \param[in] logging_function A pointer to a logging function. - * \param[in] logger_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to - * `logging_function`. This parameter is optional. - * \param[in] log_severity_level The log severity level. - * \param[in] logid The log identifier. - * \param[out] out Returned newly created OrtEnv. 
Must be freed with OrtApi::ReleaseEnv - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateEnvWithCustomLogger, _In_ OrtLoggingFunction logging_function, _In_opt_ void* logger_param, - _In_ OrtLoggingLevel log_severity_level, _In_ const char* logid, _Outptr_ OrtEnv** out); - - /** \brief Enable Telemetry - * - * \note Telemetry events are on by default since they are lightweight - * \param[in] env - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(EnableTelemetryEvents, _In_ const OrtEnv* env); - /** \brief Disable Telemetry - * - * \see OrtApi::EnableTelemetryEvents - * \param[in] env - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(DisableTelemetryEvents, _In_ const OrtEnv* env); - - /// @} - /// \name OrtSession - /// @{ - - /** \brief Create an OrtSession from a model file - * - * \param[in] env - * \param[in] model_path - * \param[in] options - * \param[out] out Returned newly created OrtSession. Must be freed with OrtApi::ReleaseSession - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - // TODO: document the path separator convention? '/' vs '\' - // TODO: should specify the access characteristics of model_path. Is this read only during the - // execution of CreateSession, or does the OrtSession retain a handle to the file/directory - // and continue to access throughout the OrtSession lifetime? - // What sort of access is needed to model_path : read or read/write? - ORT_API2_STATUS(CreateSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, - _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); - - /** \brief Create an OrtSession from memory - * - * \param[in] env - * \param[in] model_data - * \param[in] model_data_length - * \param[in] options - * \param[out] out Returned newly created OrtSession. 
Must be freed with OrtApi::ReleaseSession - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateSessionFromArray, _In_ const OrtEnv* env, - _In_ const void* model_data, size_t model_data_length, - _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); - - /** \brief Run the model in an ::OrtSession - * - * Will not return until the model run has completed. Multiple threads might be used to run the model based on - * the options in the ::OrtSession and settings used when creating the ::OrtEnv - * - * \param[in] session - * \param[in] run_options If nullptr, will use a default ::OrtRunOptions - * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names - * \param[in] inputs Array of ::OrtValue%s of the input values - * \param[in] input_len Number of elements in the input_names and inputs arrays - * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names - * \param[in] output_names_len Number of elements in the output_names and outputs array - * \param[out] outputs Array of ::OrtValue%s that the outputs are stored in. This can also be - * an array of nullptr values, in this case ::OrtValue objects will be allocated and pointers - * to them will be set into the `outputs` array. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(Run, _Inout_ OrtSession* session, _In_opt_ const OrtRunOptions* run_options, - _In_reads_(input_len) const char* const* input_names, - _In_reads_(input_len) const OrtValue* const* inputs, size_t input_len, - _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, - _Inout_updates_all_(output_names_len) OrtValue** outputs); - - /// @} - /// \name OrtSessionOptions - /// @{ - - /** \brief Create an ::OrtSessionOptions object - * - * To use additional providers, you must build ORT with the extra providers enabled. Then call one of these - * functions to enable them in the session:
- * OrtSessionOptionsAppendExecutionProvider_CPU
- * OrtSessionOptionsAppendExecutionProvider_CUDA
- * OrtSessionOptionsAppendExecutionProvider_(remaining providers...)
- * The order they are called indicates the preference order as well. In other words call this method - * on your most preferred execution provider first followed by the less preferred ones. - * If none are called Ort will use its internal CPU execution provider. - * - * \param[out] options The newly created OrtSessionOptions. Must be freed with OrtApi::ReleaseSessionOptions - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateSessionOptions, _Outptr_ OrtSessionOptions** options); - - /** \brief Set filepath to save optimized model after graph level transformations - * - * \param[in] options - * \param[in] optimized_model_filepath - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetOptimizedModelFilePath, _Inout_ OrtSessionOptions* options, - _In_ const ORTCHAR_T* optimized_model_filepath); - - /** \brief Create a copy of an existing ::OrtSessionOptions - * - * \param[in] in_options OrtSessionOptions to copy - * \param[out] out_options Returned newly created ::OrtSessionOptions. Must be freed with OrtApi::ReleaseSessionOptions - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CloneSessionOptions, _In_ const OrtSessionOptions* in_options, - _Outptr_ OrtSessionOptions** out_options); - - /** \brief Set execution mode - * - * Controls whether you want to execute operators in your graph sequentially or in parallel. Usually when the model - * has many branches, setting this option to ExecutionMode.ORT_PARALLEL will give you better performance. - * See [docs/ONNX_Runtime_Perf_Tuning.md] for more details. 
- * - * \param[in] options - * \param[in] execution_mode - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetSessionExecutionMode, _Inout_ OrtSessionOptions* options, ExecutionMode execution_mode); - - /** \brief Enable profiling for a session - * - * \param[in] options - * \param[in] profile_file_prefix - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(EnableProfiling, _Inout_ OrtSessionOptions* options, _In_ const ORTCHAR_T* profile_file_prefix); - - /** \brief Disable profiling for a session - * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(DisableProfiling, _Inout_ OrtSessionOptions* options); - - /** \brief Enable the memory pattern optimization - * - * The idea is if the input shapes are the same, we could trace the internal memory allocation - * and generate a memory pattern for future request. So next time we could just do one allocation - * with a big chunk for all the internal memory allocation. - * \note Memory pattern optimization is only available when Sequential Execution mode is enabled (see OrtApi::SetSessionExecutionMode) - * - * \see OrtApi::DisableMemPattern - * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(EnableMemPattern, _Inout_ OrtSessionOptions* options); - - /** \brief Disable the memory pattern optimization - * - * \see OrtApi::EnableMemPattern - * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(DisableMemPattern, _Inout_ OrtSessionOptions* options); - - /** \brief Enable the memory arena on CPU - * - * Arena may pre-allocate memory for future usage. 
- * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(EnableCpuMemArena, _Inout_ OrtSessionOptions* options); - - /** \brief Disable the memory arena on CPU - * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(DisableCpuMemArena, _Inout_ OrtSessionOptions* options); - - /** \brief Set session log id - * - * \param[in] options - * \param[in] logid The log identifier. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetSessionLogId, _Inout_ OrtSessionOptions* options, const char* logid); - - /** \brief Set session log verbosity level - * - * Applies to session load, initialization, etc - * - * \param[in] options - * \param[in] session_log_verbosity_level \snippet{doc} snippets.dox Log Verbosity Level - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetSessionLogVerbosityLevel, _Inout_ OrtSessionOptions* options, int session_log_verbosity_level); - - /** \brief Set session log severity level - * - * \param[in] options - * \param[in] session_log_severity_level The log severity level (refer to ::OrtLoggingLevel for possible values). 
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetSessionLogSeverityLevel, _Inout_ OrtSessionOptions* options, int session_log_severity_level); - - /** \brief Set the optimization level to apply when loading a graph - * - * Please see https://onnxruntime.ai/docs/performance/model-optimizations/graph-optimizations.html for an in-depth explanation - * \param[in,out] options The session options object - * \param[in] graph_optimization_level The optimization level - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetSessionGraphOptimizationLevel, _Inout_ OrtSessionOptions* options, - GraphOptimizationLevel graph_optimization_level); - - /** \brief Sets the number of threads used to parallelize the execution within nodes - * - * When running a single node operation, ex. add, this sets the maximum number of threads to use. - * - * \note If built with OpenMP, this has no effect on the number of threads used. In this case - * use the OpenMP env variables to configure the number of intra op num threads. - * - * \param[in] options - * \param[in] intra_op_num_threads Number of threads to use
- * A value of 0 will use the default number of threads
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetIntraOpNumThreads, _Inout_ OrtSessionOptions* options, int intra_op_num_threads); - - /** \brief Sets the number of threads used to parallelize the execution of the graph - * - * If nodes can be run in parallel, this sets the maximum number of threads to use to run them in parallel. - * - * \note If sequential execution is enabled this value is ignored, it acts as if it was set to 1. - * - * \param[in] options - * \param[in] inter_op_num_threads Number of threads to use
- * A value of 0 will use the default number of threads
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetInterOpNumThreads, _Inout_ OrtSessionOptions* options, int inter_op_num_threads); - - /// @} - /// \name OrtCustomOpDomain - /// @{ - - /** \brief Create a custom op domain - * - * \param[in] domain - * \param[out] out Newly created domain. Must be freed with OrtApi::ReleaseCustomOpDomain - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateCustomOpDomain, _In_ const char* domain, _Outptr_ OrtCustomOpDomain** out); - - /** \brief Add a custom op to a custom op domain - * - * \note The OrtCustomOp* pointer must remain valid until the ::OrtCustomOpDomain using it is released - * - * \param[in] custom_op_domain - * \param[in] op - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CustomOpDomain_Add, _Inout_ OrtCustomOpDomain* custom_op_domain, _In_ const OrtCustomOp* op); - - /// @} - /// \name OrtSessionOptions - /// @{ - - /** \brief Add custom op domain to a session options - * - * \note The OrtCustomOpDomain* must not be deleted until all sessions using it are released - * - * \param[in] options - * \param[in] custom_op_domain - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(AddCustomOpDomain, _Inout_ OrtSessionOptions* options, _In_ OrtCustomOpDomain* custom_op_domain); - - /** \deprecated Use OrtApi::RegisterCustomOpsLibrary_V2. - * - * Registers custom ops from a shared library. - * - * Loads a shared library (dll on windows, so on linux, etc) named 'library_path' and looks for this entry point: - * OrtStatus* RegisterCustomOps(OrtSessionOptions * options, const OrtApiBase* api); - * It then passes in the provided session options to this function along with the api base. - * The handle to the loaded library is returned in library_handle. It can be freed by the caller after all sessions using the passed in - * session options are destroyed, or if an error occurs and it is non null. 
- * - * \param[in] options - * \param[in] library_path - * \param[out] library_handle OS specific handle to the loaded library (Use FreeLibrary on Windows, dlclose on Linux, etc.. to unload) - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(RegisterCustomOpsLibrary, _Inout_ OrtSessionOptions* options, _In_ const char* library_path, _Outptr_ void** library_handle); - - /// @} - /// \name OrtSession - /// @{ - - /** \brief Get input count for a session - * - * This number must also match the number of inputs passed to OrtApi::Run - * - * \see OrtApi::SessionGetInputTypeInfo, OrtApi::SessionGetInputName, OrtApi::Session - * - * \param[in] session - * \param[out] out Number of inputs - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionGetInputCount, _In_ const OrtSession* session, _Out_ size_t* out); - - /** \brief Get output count for a session - * - * This number must also match the number of outputs returned by OrtApi::Run - * - * \see OrtApi::SessionGetOutputTypeInfo, OrtApi::SessionGetOutputName, OrtApi::Session - * - * \param[in] session - * \param[out] out Number of outputs - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionGetOutputCount, _In_ const OrtSession* session, _Out_ size_t* out); - - /** \brief Get overridable initializer count - * - * \see OrtApi::SessionGetOverridableInitializerTypeInfo, OrtApi::SessionGetOverridableInitializerName - * - * \param[in] session - * \param[in] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionGetOverridableInitializerCount, _In_ const OrtSession* session, _Out_ size_t* out); - - /** \brief Get input type information - * - * \param[in] session - * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetInputCount returns (exclusive) - * \param[out] type_info Must be freed with OrtApi::ReleaseTypeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - 
*/ - ORT_API2_STATUS(SessionGetInputTypeInfo, _In_ const OrtSession* session, size_t index, _Outptr_ OrtTypeInfo** type_info); - - /** \brief Get output type information - * - * \param[in] session - * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOutputCount returns (exclusive) - * \param[out] type_info Must be freed with OrtApi::ReleaseTypeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionGetOutputTypeInfo, _In_ const OrtSession* session, size_t index, _Outptr_ OrtTypeInfo** type_info); - - /** \brief Get overridable initializer type information - * - * \param[in] session - * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOverridableInitializerCount returns (exclusive) - * \param[out] type_info Must be freed with OrtApi::ReleaseTypeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionGetOverridableInitializerTypeInfo, _In_ const OrtSession* session, size_t index, _Outptr_ OrtTypeInfo** type_info); - - /** \brief Get input name - * - * \param[in] session - * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetInputCount returns (exclusive) - * \param[in] allocator - * \param[out] value Set to a null terminated UTF-8 encoded string allocated using `allocator`. Must be freed using `allocator`. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionGetInputName, _In_ const OrtSession* session, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** value); - - /** \brief Get output name - * - * \param[in] session - * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOutputCount returns (exclusive) - * \param[in] allocator - * \param[out] value Set to a null terminated UTF-8 encoded string allocated using `allocator`. Must be freed using `allocator`. 
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionGetOutputName, _In_ const OrtSession* session, size_t index, _Inout_ OrtAllocator* allocator, _Outptr_ char** value); - - /** \brief Get overridable initializer name - * - * \param[in] session - * \param[in] index Must be between 0 (inclusive) and what OrtApi::SessionGetOverridableInitializerCount returns (exclusive) - * \param[in] allocator - * \param[out] value Set to a null terminated UTF-8 encoded string allocated using `allocator`. Must be freed using `allocator`. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionGetOverridableInitializerName, _In_ const OrtSession* session, size_t index, - _Inout_ OrtAllocator* allocator, _Outptr_ char** value); - - /// @} - /// \name OrtRunOptions - /// @{ - - /** \brief Create an OrtRunOptions - * - * \param[out] out Returned newly created ::OrtRunOptions. Must be freed with OrtApi::ReleaseRunOptions - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateRunOptions, _Outptr_ OrtRunOptions** out); - - /** \brief Set per-run log verbosity level - * - * \see OrtApi::RunOptionsGetRunLogVerbosityLevel - * - * \param[in] options - * \param[in] log_verbosity_level \snippet{doc} snippets.dox Log Verbosity Level - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(RunOptionsSetRunLogVerbosityLevel, _Inout_ OrtRunOptions* options, int log_verbosity_level); - - /** \brief Set per-run log severity level - * - * \see OrtApi::RunOptionsGetRunLogSeverityLevel - * - * \param[in] options - * \param[in] log_severity_level The log severity level (refer to ::OrtLoggingLevel for possible values). - */ - ORT_API2_STATUS(RunOptionsSetRunLogSeverityLevel, _Inout_ OrtRunOptions* options, int log_severity_level); - - /** \brief Set per-run tag - * - * This is used in a per-run log identifier. 
- * - * \see OrtApi::RunOptionsGetRunTag - * - * \param[in] options - * \param[in] run_tag The run tag. - */ - ORT_API2_STATUS(RunOptionsSetRunTag, _Inout_ OrtRunOptions* options, _In_ const char* run_tag); - - /** \brief Get per-run log verbosity level - * - * \see OrtApi::RunOptionsSetRunLogVerbosityLevel - * - * \param[in] options - * \param[out] log_verbosity_level \snippet{doc} snippets.dox Log Verbosity Level - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(RunOptionsGetRunLogVerbosityLevel, _In_ const OrtRunOptions* options, - _Out_ int* log_verbosity_level); - - /** \brief Get per-run log severity level - * - * \see OrtApi::RunOptionsSetRunLogSeverityLevel - * - * \param[in] options - * \param[out] log_severity_level The log severity level (refer to ::OrtLoggingLevel for possible values). - */ - ORT_API2_STATUS(RunOptionsGetRunLogSeverityLevel, _In_ const OrtRunOptions* options, _Out_ int* log_severity_level); - - /** \brief Get per-run tag - * - * This is used in a per-run log identifier. - * - * \see OrtApi::RunOptionsSetRunTag - * - * \param[in] options - * \param[out] run_tag The run tag. - * Do not free this value, it is owned by `options`. It will be invalidated if the run tag - * changes (i.e., with OrtApi::RunOptionsSetRunTag) or `options` is freed. - */ - ORT_API2_STATUS(RunOptionsGetRunTag, _In_ const OrtRunOptions* options, _Out_ const char** run_tag); - - /** \brief Set terminate flag - * - * If a currently executing session needs to be force terminated, this can be called from another thread to force it to fail with an error. 
- * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(RunOptionsSetTerminate, _Inout_ OrtRunOptions* options); - - /** \brief Clears the terminate flag - * - * Used so the OrtRunOptions instance can be used in a new OrtApi::Run call without it instantly terminating - * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(RunOptionsUnsetTerminate, _Inout_ OrtRunOptions* options); - - /// @} - /// \name OrtValue - /// @{ - - /** \brief Create a tensor - * - * Create a tensor using a supplied ::OrtAllocator - * - * \param[in] allocator - * \param[in] shape Pointer to the tensor shape dimensions. - * \param[in] shape_len The number of tensor shape dimensions. - * \param[in] type - * \param[out] out Returns newly created ::OrtValue. Must be freed with OrtApi::ReleaseValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* shape, size_t shape_len, - ONNXTensorElementDataType type, _Outptr_ OrtValue** out); - - /** \brief Create a tensor backed by a user supplied buffer - * - * Create a tensor with user's buffer. You can fill the buffer either before calling this function or after. - * p_data is owned by caller. ReleaseValue won't release p_data. - * - * If you wish to transfer ownership of p_data to ORT use CreateTensorWithDataAndDeleterAsOrtValue. - * - * \param[in] info Memory description of where the p_data buffer resides (CPU vs GPU etc). - * \param[in] p_data Pointer to the data buffer. - * \param[in] p_data_len The number of bytes in the data buffer. - * \param[in] shape Pointer to the tensor shape dimensions. - * \param[in] shape_len The number of tensor shape dimensions. - * \param[in] type The data type. - * \param[out] out Returns newly created ::OrtValue. 
Must be freed with OrtApi::ReleaseValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateTensorWithDataAsOrtValue, _In_ const OrtMemoryInfo* info, _Inout_ void* p_data, - size_t p_data_len, _In_ const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type, - _Outptr_ OrtValue** out); - - /** \brief Return if an ::OrtValue is a tensor type - * - * \param[in] value A tensor type (string tensors are not supported) - * \param[out] out Set to 1 iff ::OrtValue is a tensor, 0 otherwise - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(IsTensor, _In_ const OrtValue* value, _Out_ int* out); - - /** \brief Get a pointer to the raw data inside a tensor - * - * Used to read/write/modify the internal tensor data directly. - * \note The returned pointer is valid until the \p value is destroyed. - * - * \param[in] value A tensor type (string tensors are not supported) - * \param[out] out Filled in with a pointer to the internal storage - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetTensorMutableData, _In_ OrtValue* value, _Outptr_ void** out); - - /** \brief Set all strings at once in a string tensor - * - * \param[in,out] value A tensor of type ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING - * \param[in] s An array of strings. Each string in this array must be null terminated. 
- * \param[in] s_len Count of strings in s (Must match the size of \p value's tensor shape) - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(FillStringTensor, _Inout_ OrtValue* value, _In_ const char* const* s, size_t s_len); - - /** \brief Get total byte length for all strings in a string tensor - * - * Typically used with OrtApi::GetStringTensorContent - * - * \param[in] value A tensor of type ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING - * \param[out] len Total byte length of all strings (does not include trailing nulls) - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetStringTensorDataLength, _In_ const OrtValue* value, _Out_ size_t* len); - - /** \brief Get all strings from a string tensor - * - * An example of the results:
- * Given \p value is a string tensor with the strings { "This" "is" "a" "test" }
- * \p s must have a size of 11 bytes
- * \p offsets must have 4 elements
- * After the call, these values will be filled in:
- * \p s will contain "Thisisatest"
- * \p offsets will contain { 0, 4, 6, 7 }
- * The length of the last string is just s_len - offsets[last] - * - * \param[in] value A tensor of type ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING - * \param[in] s Buffer to sequentially write all tensor strings to. Each string is NOT null-terminated. - * \param[in] s_len Number of bytes of buffer pointed to by \p s (Get it from OrtApi::GetStringTensorDataLength) - * \param[out] offsets Array of start offsets into the strings written to \p s - * \param[in] offsets_len Number of elements in offsets - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetStringTensorContent, _In_ const OrtValue* value, _Out_writes_bytes_all_(s_len) void* s, - size_t s_len, _Out_writes_all_(offsets_len) size_t* offsets, size_t offsets_len); - - /// @} - /// \name OrtTypeInfo - /// @{ - - /** \brief Get ::OrtTensorTypeAndShapeInfo from an ::OrtTypeInfo - * - * \param[in] type_info - * \param[out] out Do not free this value, it will be valid until type_info is freed. - * If type_info does not represent tensor, this value will be set to nullptr. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CastTypeInfoToTensorInfo, _In_ const OrtTypeInfo* type_info, - _Outptr_result_maybenull_ const OrtTensorTypeAndShapeInfo** out); - - /** \brief Get ::ONNXType from ::OrtTypeInfo - * - * \param[in] type_info - * \param[out] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetOnnxTypeFromTypeInfo, _In_ const OrtTypeInfo* type_info, _Out_ enum ONNXType* out); - - /// @} - /// \name OrtTensorTypeAndShapeInfo - /// @{ - - /** \brief Create an ::OrtTensorTypeAndShapeInfo object - * - * \param[out] out Returns newly created ::OrtTensorTypeAndShapeInfo. 
Must be freed with OrtApi::ReleaseTensorTypeAndShapeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateTensorTypeAndShapeInfo, _Outptr_ OrtTensorTypeAndShapeInfo** out); - - /** \brief Set element type in ::OrtTensorTypeAndShapeInfo - * - * \param[in] info - * \param[in] type - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetTensorElementType, _Inout_ OrtTensorTypeAndShapeInfo* info, enum ONNXTensorElementDataType type); - - /** \brief Set shape information in ::OrtTensorTypeAndShapeInfo - * - * \param[in] info - * \param[in] dim_values Array with `dim_count` elements. Can contain negative values. - * \param[in] dim_count Number of elements in `dim_values` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetDimensions, OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count); - - /** \brief Get element type in ::OrtTensorTypeAndShapeInfo - * - * \see OrtApi::SetTensorElementType - * - * \param[in] info - * \param[out] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetTensorElementType, _In_ const OrtTensorTypeAndShapeInfo* info, - _Out_ enum ONNXTensorElementDataType* out); - - /** \brief Get dimension count in ::OrtTensorTypeAndShapeInfo - * - * \see OrtApi::GetDimensions - * - * \param[in] info - * \param[out] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetDimensionsCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out); - - /** \brief Get dimensions in ::OrtTensorTypeAndShapeInfo - * - * \param[in] info - * \param[out] dim_values Array with `dim_values_length` elements. On return, filled with the dimensions stored in the ::OrtTensorTypeAndShapeInfo - * \param[in] dim_values_length Number of elements in `dim_values`. 
Use OrtApi::GetDimensionsCount to get this value - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, - size_t dim_values_length); - - /** \brief Get symbolic dimension names in ::OrtTensorTypeAndShapeInfo - * - * \param[in] info - * \param[in] dim_params Array with `dim_params_length` elements. On return filled with pointers to null terminated strings of the dimension names - * \param[in] dim_params_length Number of elements in `dim_params`. Use OrtApi::GetDimensionsCount to get this value - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetSymbolicDimensions, _In_ const OrtTensorTypeAndShapeInfo* info, - _Out_writes_all_(dim_params_length) const char* dim_params[], size_t dim_params_length); - - /** \brief Get total number of elements in a tensor shape from an ::OrtTensorTypeAndShapeInfo - * - * Return the number of elements specified by the tensor shape (all dimensions multiplied by each other). - * For 0 dimensions, 1 is returned. If any dimension is less than 0, the result is always -1. - * - * Examples:
- * [] = 1
- * [1,3,4] = 12
- * [2,0,4] = 0
- * [-1,3,4] = -1
- * - * \param[in] info - * \param[out] out Number of elements - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetTensorShapeElementCount, _In_ const OrtTensorTypeAndShapeInfo* info, _Out_ size_t* out); - - /// @} - /// \name OrtValue - /// @{ - - /** \brief Get type and shape information from a tensor ::OrtValue - * - * \param[in] value Must be a tensor (not a map/sequence/etc) or will return failure - * \param[out] out Newly created ::OrtTensorTypeAndShapeInfo. Must be freed with OrtApi::ReleaseTensorTypeAndShapeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetTensorTypeAndShape, _In_ const OrtValue* value, _Outptr_ OrtTensorTypeAndShapeInfo** out); - - /** \brief Get type information of an OrtValue - * - * \param[in] value - * \param[out] out Newly created ::OrtTypeInfo. Must be freed with OrtApi::ReleaseTypeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetTypeInfo, _In_ const OrtValue* value, _Outptr_result_maybenull_ OrtTypeInfo** out); - - /** \brief Get ONNXType of an ::OrtValue - * - * \param[in] value - * \param[out] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetValueType, _In_ const OrtValue* value, _Out_ enum ONNXType* out); - - /// @} - /// \name OrtMemoryInfo - /// @{ - - /** \brief Create an ::OrtMemoryInfo - * - * \param[in] name - * \param[in] type - * \param[in] id - * \param[in] mem_type - * \param[out] out Newly created ::OrtMemoryInfo. Must be freed with OrtAPi::ReleaseMemoryInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateMemoryInfo, _In_ const char* name, enum OrtAllocatorType type, int id, - enum OrtMemType mem_type, _Outptr_ OrtMemoryInfo** out); - - /** \brief Create an ::OrtMemoryInfo for CPU memory - * - * Special case version of OrtApi::CreateMemoryInfo for CPU based memory. Same as using OrtApi::CreateMemoryInfo with name = "Cpu" and id = 0. 
- * - * \param[in] type - * \param[in] mem_type - * \param[out] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateCpuMemoryInfo, enum OrtAllocatorType type, enum OrtMemType mem_type, - _Outptr_ OrtMemoryInfo** out); - - /** \brief Compare ::OrtMemoryInfo objects for equality - * - * Compares all settings of each ::OrtMemoryInfo for equality - * - * \param[in] info1 - * \param[in] info2 - * \param[out] out Set to 0 if equal, -1 if not equal - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CompareMemoryInfo, _In_ const OrtMemoryInfo* info1, _In_ const OrtMemoryInfo* info2, _Out_ int* out); - - /** \brief Get name from ::OrtMemoryInfo - * - * \param[in] ptr - * \param[out] out Writes null terminated string to this pointer. Do NOT free the returned pointer. It is valid for the lifetime of the ::OrtMemoryInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(MemoryInfoGetName, _In_ const OrtMemoryInfo* ptr, _Out_ const char** out); - - /** \brief Get the id from ::OrtMemoryInfo - */ - ORT_API2_STATUS(MemoryInfoGetId, _In_ const OrtMemoryInfo* ptr, _Out_ int* out); - - /** \brief Get the ::OrtMemType from ::OrtMemoryInfo - */ - ORT_API2_STATUS(MemoryInfoGetMemType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtMemType* out); - - /** \brief Get the ::OrtAllocatorType from ::OrtMemoryInfo - */ - ORT_API2_STATUS(MemoryInfoGetType, _In_ const OrtMemoryInfo* ptr, _Out_ OrtAllocatorType* out); - - /// @} - /// \name OrtAllocator - /// @{ - - /// \brief Calls OrtAllocator::Alloc function - ORT_API2_STATUS(AllocatorAlloc, _Inout_ OrtAllocator* ort_allocator, size_t size, _Outptr_ void** out); - /// \brief Calls OrtAllocator::Free function - ORT_API2_STATUS(AllocatorFree, _Inout_ OrtAllocator* ort_allocator, void* p); - /// \brief Calls OrtAllocator::Info function - ORT_API2_STATUS(AllocatorGetInfo, _In_ const OrtAllocator* ort_allocator, _Outptr_ const struct OrtMemoryInfo** out); - 
- /** \brief Get the default allocator - * - * The default allocator is a CPU based, non-arena. Always returns the same pointer to the same default allocator. - * - * \param[out] out Returned value should NOT be freed - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetAllocatorWithDefaultOptions, _Outptr_ OrtAllocator** out); - - /// @} - /// \name OrtSessionOptions - /// @{ - - /** \brief Override session symbolic dimensions - * - * Override symbolic dimensions (by specific denotation strings) with actual values if known at session initialization time to enable - * optimizations that can take advantage of fixed values (such as memory planning, etc) - * - * \param[in] options - * \param[in] dim_denotation - * \param[in] dim_value - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(AddFreeDimensionOverride, _Inout_ OrtSessionOptions* options, _In_ const char* dim_denotation, - _In_ int64_t dim_value); - - /// @} - /// \name OrtValue - /// @{ - - /* Internal information (not seen in Doxygen) - * - * APIs to support non-tensor types - map and sequence. - * Currently only the following types are supported - * Note: the following types should be kept in sync with data_types.h - * Map types - * ========= - * std::map - * std::map - * std::map - * std::map - * std::map - * std::map - * std::map - * std::map - * - * Sequence types - * ============== - * std::vector - * std::vector - * std::vector - * std::vector - * std::vector> - * std::vector - */ - - /** \brief Get non tensor data from an ::OrtValue - * - * If `value` is of type ONNX_TYPE_MAP, you need to retrieve the keys and values - * separately. Use index=0 to retrieve keys and index=1 to retrieve values. - * If `value` is of type ONNX_TYPE_SEQUENCE, use index to retrieve the index'th element - * of the sequence. 
- * - * \param[in] value - * \param[in] index See above for usage based on `value` type - * \param[in] allocator Allocator used to allocate ::OrtValue - * \param[out] out Created ::OrtValue that holds the element requested. Must be freed with OrtApi::ReleaseValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetValue, _In_ const OrtValue* value, int index, _Inout_ OrtAllocator* allocator, - _Outptr_ OrtValue** out); - - /** \brief Get non tensor value count from an ::OrtValue - * - * If `value` is of type ONNX_TYPE_MAP 2 will always be returned. For ONNX_TYPE_SEQUENCE - * the number of elements in the sequence will be returned - * - * \param[in] value - * \param[out] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetValueCount, _In_ const OrtValue* value, _Out_ size_t* out); - - /** \brief Create a map or sequence ::OrtValue - * - * To construct a map (ONNX_TYPE_MAP), use num_values = 2 and `in` should be an array of 2 ::OrtValue%s - * representing keys and values.
- * - * To construct a sequence (ONNX_TYPE_SEQUENCE), use num_values = N where N is the number of the elements in the - * sequence. 'in' should be an array of N ::OrtValue%s. - * - * \param[in] in See above for details - * \param[in] num_values - * \param[in] value_type Must be either ONNX_TYPE_MAP or ONNX_TYPE_SEQUENCE - * \param[out] out Newly created ::OrtValue. Must be freed with OrtApi::ReleaseValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateValue, _In_reads_(num_values) const OrtValue* const* in, size_t num_values, - enum ONNXType value_type, _Outptr_ OrtValue** out); - - /** \brief Create an opaque (custom user defined type) ::OrtValue - * - * Constructs an ::OrtValue that contains a value of non-standard type created for - * experiments or while awaiting standardization. ::OrtValue in this case would contain - * an internal representation of the Opaque type. Opaque types are distinguished from - * each other by two strings 1) domain and 2) type name. The combination of the two - * must be unique, so the type representation is properly identified internally. The combination - * must be properly registered from within ORT at both compile/run time or by another API. - * - * To construct the ::OrtValue pass domain and type names, also a pointer to a data container - * the type of which must be known to both ORT and the client program. That data container may or may - * not match the internal representation of the Opaque type. The sizeof(data_container) is passed for - * verification purposes. - * - * \param[in] domain_name Null terminated string of the domain name - * \param[in] type_name Null terminated string of the type name - * \param[in] data_container User pointer Data to populate ::OrtValue - * \param[in] data_container_size Size in bytes of what `data_container` points to - * \param[out] out Newly created ::OrtValue. 
Must be freed with OrtApi::ReleaseValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateOpaqueValue, _In_z_ const char* domain_name, _In_z_ const char* type_name, - _In_ const void* data_container, size_t data_container_size, _Outptr_ OrtValue** out); - - /** \brief Get internal data from an opaque (custom user defined type) ::OrtValue - * - * Copies internal data from an opaque value into a user provided buffer - * - * \see OrtApi::CreateOpaqueValue - * - * \param[in] domain_name Null terminated string of the domain name - * \param[in] type_name Null terminated string of the type name - * \param[in] in The opaque ::OrtValue - * \param[out] data_container Buffer to copy data into - * \param[out] data_container_size Size in bytes of the buffer pointed to by data_container. Must match the size of the internal buffer. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetOpaqueValue, _In_ const char* domain_name, _In_ const char* type_name, _In_ const OrtValue* in, - _Out_ void* data_container, size_t data_container_size); - - /// @} - /// \name OrtKernelInfo - /// Custom operator APIs. 
- /// @{ - - /** \brief Get a float stored as an attribute in the graph node - * - * \param[in] info ::OrtKernelInfo instance - * \param[in] name Null terminated string of the name of the attribute - * \param[out] out Pointer to memory where the attribute will be stored - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(KernelInfoGetAttribute_float, _In_ const OrtKernelInfo* info, _In_ const char* name, - _Out_ float* out); - - /** \brief Fetch a 64-bit int stored as an attribute in the graph node - * - * \param[in] info ::OrtKernelInfo instance - * \param[in] name Null terminated string of the name of the attribute - * \param[out] out Pointer to memory where the attribute will be stored - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(KernelInfoGetAttribute_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, - _Out_ int64_t* out); - - /** \brief Fetch a string stored as an attribute in the graph node - * - * If `out` is nullptr, the value of `size` is set to the true size of the string - * attribute, and a success status is returned. - * - * If the `size` parameter is greater than or equal to the actual string attribute's size, - * the value of `size` is set to the true size of the string attribute, the provided memory - * is filled with the attribute's contents, and a success status is returned. - * - * If the `size` parameter is less than the actual string attribute's size and `out` - * is not nullptr, the value of `size` is set to the true size of the string attribute - * and a failure status is returned.) 
- * - * \param[in] info ::OrtKernelInfo instance - * \param[in] name Null terminated string of the name of the attribute - * \param[out] out Pointer to memory where the attribute will be stored - * \param[in,out] size See above comments for details - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(KernelInfoGetAttribute_string, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ char* out, - _Inout_ size_t* size); - - /// @} - /// \name OrtKernelContext - /// Custom operator APIs. - /// @{ - - /** \brief Used for custom operators, get the input count of a kernel - * - * \see ::OrtCustomOp - */ - ORT_API2_STATUS(KernelContext_GetInputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out); - - /** \brief Used for custom operators, get the output count of a kernel - * - * \see ::OrtCustomOp - */ - ORT_API2_STATUS(KernelContext_GetOutputCount, _In_ const OrtKernelContext* context, _Out_ size_t* out); - - /** \brief Used for custom operators, get an input of a kernel - * - * The function attempts fetches the input of the kernel. If the input is optional - * and not present, the function returns success and out is set to nullptr. - * - * \param[in] context ::OrtKernelContext instance - * \param[in] index See KernelContext_GetInputCount for boundaries check. - * \param[out] out OrtValue if the input is present otherwise is set nullptr - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(KernelContext_GetInput, _In_ const OrtKernelContext* context, _In_ size_t index, - _Out_ const OrtValue** out); - - /** \brief Used for custom operators, get an output of a kernel - * - * The function attempts fetches the output of the kernel. If the output is optional - * and not present, the function returns success and out is set to nullptr. - * - * \param[in] context ::OrtKernelContext instance - * \param[in] index See KernelContext_GetOutputCount for boundaries check. 
- * \param[in] dim_values output dimensions - * \param[in] dim_count number of dimensions - * \param[out] out a ptr to OrtValue to output otherwise set to nullptr - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(KernelContext_GetOutput, _Inout_ OrtKernelContext* context, _In_ size_t index, - _In_ const int64_t* dim_values, size_t dim_count, _Outptr_ OrtValue** out); - - /// @} - /// \name OrtEnv - /// @{ - ORT_CLASS_RELEASE(Env); - /// @} - /// \name OrtStatus - /// @{ - ORT_CLASS_RELEASE(Status); - /// @} - /// \name OrtMemoryInfo - /// @{ - ORT_CLASS_RELEASE(MemoryInfo); - /// @} - /// \name OrtSession - /// @{ - ORT_CLASS_RELEASE(Session); // Don't call ReleaseSession from Dllmain (because session owns a thread pool) - /// @} - /// \name OrtValue - /// @{ - ORT_CLASS_RELEASE(Value); - /// @} - /// \name OrtRunOptions - /// @{ - ORT_CLASS_RELEASE(RunOptions); - /// @} - /// \name OrtTypeInfo - /// @{ - ORT_CLASS_RELEASE(TypeInfo); - /// @} - /// \name OrtTensorTypeAndShapeInfo - /// @{ - ORT_CLASS_RELEASE(TensorTypeAndShapeInfo); - /// @} - /// \name OrtSessionOptions - /// @{ - ORT_CLASS_RELEASE(SessionOptions); - /// @} - /// \name OrtCustomOpDomain - /// @{ - ORT_CLASS_RELEASE(CustomOpDomain); - - /// @} - /// \name OrtTypeInfo - /// @{ - - /** \brief Get denotation from type information - * - * Augments ::OrtTypeInfo to return denotations on the type. - * - * This is used by WinML to determine if an input/output is intended to be an Image or a Tensor. - * - * \param[in] type_info - * \param[out] denotation Pointer to the null terminated denotation string is written to this pointer. This pointer is valid until the object is destroyed or the name is changed, do not free. 
- * \param[out] len Length in bytes of the string returned in `denotation` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetDenotationFromTypeInfo, _In_ const OrtTypeInfo* type_info, _Out_ const char** const denotation, - _Out_ size_t* len); - - /** \brief Get detailed map information from an ::OrtTypeInfo - * - * This augments ::OrtTypeInfo to return an ::OrtMapTypeInfo when the type is a map. - * The OrtMapTypeInfo has additional information about the map's key type and value type. - * - * This is used by WinML to support model reflection APIs. - * - * \param[out] type_info - * \param[out] out A pointer to the ::OrtMapTypeInfo. Do not free this value. If type_info - * does not contain a map, this value will be set to nullptr. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CastTypeInfoToMapTypeInfo, _In_ const OrtTypeInfo* type_info, - _Outptr_result_maybenull_ const OrtMapTypeInfo** out); - - /** \brief Cast ::OrtTypeInfo to an ::OrtSequenceTypeInfo - * - * This api augments ::OrtTypeInfo to return an ::OrtSequenceTypeInfo when the type is a sequence. - * The ::OrtSequenceTypeInfo has additional information about the sequence's element type. - * - * This is used by WinML to support model reflection APIs. - * - * \param[in] type_info - * \param[out] out A pointer to the OrtSequenceTypeInfo. Do not free this value. If type_info - * doesn not contain a sequence, this value will be set to nullptr. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CastTypeInfoToSequenceTypeInfo, _In_ const OrtTypeInfo* type_info, - _Outptr_result_maybenull_ const OrtSequenceTypeInfo** out); - - /// @} - /// \name OrtMapTypeInfo - /// @{ - - /** \brief Get key type from an ::OrtMapTypeInfo - * - * Key types are restricted to being scalar types. - * - * This is used by WinML to support model reflection APIs. 
- * - * \param[in] map_type_info - * \param[out] out - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetMapKeyType, _In_ const OrtMapTypeInfo* map_type_info, _Out_ enum ONNXTensorElementDataType* out); - - /** \brief Get the value type from an ::OrtMapTypeInfo - * - * \param[in] map_type_info - * \param[out] type_info A copy of the OrtTypeInfo for the map value type. - * The user must free this value with ReleaseTypeInfo. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetMapValueType, _In_ const OrtMapTypeInfo* map_type_info, _Outptr_ OrtTypeInfo** type_info); - - /// @} - /// \name OrtSequenceTypeInfo - /// @{ - - /** \brief Get element type from an ::OrtSequenceTypeInfo - * - * This is used by WinML to support model reflection APIs. - * - * \param[in] sequence_type_info - * \param[out] type_info A copy of the OrtTypeInfo for the sequence element type. - * The user must free this value with ReleaseTypeInfo. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetSequenceElementType, _In_ const OrtSequenceTypeInfo* sequence_type_info, - _Outptr_ OrtTypeInfo** type_info); - - /// @} - /// \name OrtMapTypeInfo - /// @{ - ORT_CLASS_RELEASE(MapTypeInfo); - /// @} - /// \name OrtSequenceTypeInfo - /// @{ - ORT_CLASS_RELEASE(SequenceTypeInfo); - - /// @} - /// \name OrtSession - /// @{ - - /** \brief End profiling and return filename of the profile data - * - * Profiling is turned on through OrtApi::EnableProfiling - * - * \param[in] session - * \param[in] allocator - * \param[out] out Null terminated string of the filename, allocated using `allocator`. 
Must be freed using `allocator` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionEndProfiling, _In_ OrtSession* session, _Inout_ OrtAllocator* allocator, _Outptr_ char** out); - - /** \brief Get ::OrtModelMetadata from an ::OrtSession - * - * \param[in] session - * \param[out] out Newly created ::OrtModelMetadata. Must be freed using OrtApi::ReleaseModelMetadata - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionGetModelMetadata, _In_ const OrtSession* session, _Outptr_ OrtModelMetadata** out); - - /// @} - /// \name OrtModelMetadata - /// @{ - - /** \brief Get `producer name` from an ::OrtModelMetadata - * - * \param[in] model_metadata - * \param[in] allocator - * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(ModelMetadataGetProducerName, _In_ const OrtModelMetadata* model_metadata, - _Inout_ OrtAllocator* allocator, _Outptr_ char** value); - - /** \brief Get `graph name` from an ::OrtModelMetadata - * - * \param[in] model_metadata - * \param[in] allocator - * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(ModelMetadataGetGraphName, _In_ const OrtModelMetadata* model_metadata, - _Inout_ OrtAllocator* allocator, _Outptr_ char** value); - - /** \brief Get `domain` from an ::OrtModelMetadata - * - * \param[in] model_metadata - * \param[in] allocator - * \param[out] value Set to a null terminated string allocated using `allocator`. 
Must be freed using `allocator` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(ModelMetadataGetDomain, _In_ const OrtModelMetadata* model_metadata, _Inout_ OrtAllocator* allocator, - _Outptr_ char** value); - - /** \brief Get `description` from an ::OrtModelMetadata - * - * \param[in] model_metadata - * \param[in] allocator - * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(ModelMetadataGetDescription, _In_ const OrtModelMetadata* model_metadata, - _Inout_ OrtAllocator* allocator, _Outptr_ char** value); - - /** \brief Return data for a key in the custom metadata map in an ::OrtModelMetadata - * - * \param[in] model_metadata - * \param[in] allocator - * \param[in] key Null terminated string - * \param[out] value Set to a null terminated string allocated using `allocator`. Must be freed using `allocator` - * `value` will be set to nullptr if the given key is not found in the custom metadata map. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(ModelMetadataLookupCustomMetadataMap, _In_ const OrtModelMetadata* model_metadata, - _Inout_ OrtAllocator* allocator, _In_ const char* key, _Outptr_result_maybenull_ char** value); - - /** \brief Get version number from an ::OrtModelMetadata - * - * \param[in] model_metadata - * \param[out] value Set to the version number - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(ModelMetadataGetVersion, _In_ const OrtModelMetadata* model_metadata, _Out_ int64_t* value); - - ORT_CLASS_RELEASE(ModelMetadata); - - /// @} - /// \name OrtEnv - /// @{ - - /** \brief Create an OrtEnv - * - * Create an environment with global threadpools that will be shared across sessions. - * Use this in conjunction with OrtApi::DisablePerSessionThreads or else the session will use - * its own thread pools. 
- * - * \param[in] log_severity_level The log severity level. - * \param[in] logid The log identifier. - * \param[in] tp_options - * \param[out] out Returned newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateEnvWithGlobalThreadPools, OrtLoggingLevel log_severity_level, _In_ const char* logid, - _In_ const OrtThreadingOptions* tp_options, _Outptr_ OrtEnv** out); - - /// @} - /// \name OrtSessionOptions - /// @{ - - /** \brief Use global thread pool on a session - * - * Disable using per session thread pool and use the shared global threadpool. - * This should be used in conjunction with OrtApi::CreateEnvWithGlobalThreadPools. - * - * \param[in] options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(DisablePerSessionThreads, _Inout_ OrtSessionOptions* options); - - /// @} - /// \name OrtThreadingOptions - /// @{ - - /** \brief Create an ::OrtThreadingOptions - * - * \param[out] out Newly created ::OrtThreadingOptions. Must be freed with OrtApi::ReleaseThreadingOptions - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateThreadingOptions, _Outptr_ OrtThreadingOptions** out); - - ORT_CLASS_RELEASE(ThreadingOptions); - - /// @} - /// \name OrtModelMetadata - /// @{ - - /** - * - * \param[in] model_metadata - * \param[in] allocator - * \param[out] keys Array of null terminated strings (array count = num_keys) allocated using `allocator`. - * The strings and the pointer array must be freed using `allocator` - * `keys` will be set to nullptr if the custom metadata map is empty. 
- * \param[out] num_keys Set to the number of elements in the `keys` array - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(ModelMetadataGetCustomMetadataMapKeys, _In_ const OrtModelMetadata* model_metadata, - _Inout_ OrtAllocator* allocator, _Outptr_result_buffer_maybenull_(*num_keys) char*** keys, _Out_ int64_t* num_keys); - - /// @} - /// \name OrtSessionOptions - /// @{ - - /** - * - * Override symbolic dimensions (by specific name strings) with actual values - * if known at session initialization time to enable optimizations that can - * take advantage of fixed values (such as memory planning, etc) - * - */ - ORT_API2_STATUS(AddFreeDimensionOverrideByName, - _Inout_ OrtSessionOptions* options, _In_ const char* dim_name, - _In_ int64_t dim_value); - - /// @} - /// \name Misc - /// @{ - - /** \brief Get the names of all available providers - * - * \note The providers in the list are not guaranteed to be usable. They may fail to load due to missing system dependencies. - * For example, if the CUDA/cuDNN libraries are not installed, the CUDA provider will report an error when it is added to the session options. - * - * \param[out] out_ptr Set to a pointer to an array of null terminated strings of the available providers. The entries and the - * array itself must be freed using OrtApi::ReleaseAvailableProviders - * \param[out] provider_length Set to the number of entries in the `out_ptr` array - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetAvailableProviders, _Outptr_ char*** out_ptr, _Out_ int* provider_length); - - /** \brief Release data from OrtApi::GetAvailableProviders. This API will never fail - * so you can rely on it in a noexcept code. - * - * \param[in] ptr The `out_ptr` result from OrtApi::GetAvailableProviders. 
- * \param[in] providers_length The `provider_length` result from OrtApi::GetAvailableProviders - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(ReleaseAvailableProviders, _In_ char** ptr, - _In_ int providers_length); - - /// @} - /// \name OrtValue - /// @{ - - /** \brief Get the length of a single string in a string tensor - * - * \param[in] value A string tensor - * \param[in] index Index of the string in the tensor - * \param[out] out Set to number of bytes of the string element - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetStringTensorElementLength, _In_ const OrtValue* value, size_t index, _Out_ size_t* out); - - /** \brief Get a single string from a string tensor - * - * \param[in] value A string tensor - * \param[in] s_len Number of bytes in the `s` buffer. Must match the value returned by OrtApi::GetStringTensorElementLength. - * \param[in] index Index of the string in the tensor - * \param[out] s The string element contents in UTF-8 encoding. The string is NOT null-terminated. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetStringTensorElement, _In_ const OrtValue* value, size_t s_len, size_t index, _Out_writes_bytes_all_(s_len) void* s); - - /** \brief Set a single string in a string tensor - * - * \param[in] value A string tensor - * \param[in] s A null terminated UTF-8 encoded string - * \param[in] index Index of the string in the tensor to set - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(FillStringTensorElement, _Inout_ OrtValue* value, _In_ const char* s, size_t index); - - /// @} - /// \name OrtSessionOptions - /// @{ - - /** \brief Set a session configuration entry as a pair of strings - * - * If a configuration with same key exists, this will overwrite the configuration with the given config_value. 
- * - * The config_key and the format of config_value are defined in onnxruntime_session_options_config_keys.h - * - * \param[in] options - * \param[in] config_key A null terminated string representation of the config key - * \param[in] config_value A null terminated string representation of the config value - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(AddSessionConfigEntry, _Inout_ OrtSessionOptions* options, - _In_z_ const char* config_key, _In_z_ const char* config_value); - - /// @} - /// \name OrtAllocator - /// @{ - - /** \brief Create an allocator for an ::OrtSession following an ::OrtMemoryInfo - * - * \param[in] session - * \param[in] mem_info valid ::OrtMemoryInfo instance - * \param[out] out Newly created ::OrtAllocator. Must be freed with OrtApi::ReleaseAllocator - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateAllocator, _In_ const OrtSession* session, _In_ const OrtMemoryInfo* mem_info, - _Outptr_ OrtAllocator** out); - - /** \brief Release an ::OrtAllocator obtained from OrtApi::CreateAllocator - */ - ORT_CLASS_RELEASE(Allocator); - - /// @} - /// \name OrtSession - /// @{ - - /** \brief Run a model using Io Bindings for the inputs & outputs - * - * \see OrtApi::Run - * - * \param[in] session - * \param[in] run_options - * \param[in] binding_ptr - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(RunWithBinding, _Inout_ OrtSession* session, _In_ const OrtRunOptions* run_options, _In_ const OrtIoBinding* binding_ptr); - - /** \brief Create an ::OrtIoBinding instance - * - * An IoBinding object allows one to bind pre-allocated ::OrtValue%s to input names. - * Thus if you want to use a raw on device buffer as input or output you can avoid - * extra copy during runtime. - * - * \param[in] session - * \param[out] out Newly created ::OrtIoBinding. 
Must be freed with OrtApi::ReleaseIoBinding - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateIoBinding, _Inout_ OrtSession* session, _Outptr_ OrtIoBinding** out); - - /// @} - /// \name OrtIoBinding - /// @{ - - /** \brief Release an ::OrtIoBinding obtained from OrtApi::CreateIoBinding - */ - ORT_CLASS_RELEASE(IoBinding); - - /** \brief Bind an ::OrtValue to an ::OrtIoBinding input - * - * When using OrtApi::RunWithBinding this value is used for the named input - * - * \param[in] binding_ptr - * \param[in] name Name for the model input - * \param[in] val_ptr ::OrtValue of Tensor type. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(BindInput, _Inout_ OrtIoBinding* binding_ptr, _In_ const char* name, _In_ const OrtValue* val_ptr); - - /** \brief Bind an ::OrtValue to an ::OrtIoBinding output - * - * When using OrtApi::RunWithBinding this value is used for the named output - * - * \param[in] binding_ptr - * \param[in] name Null terminated string of the model output name - * \param[in] val_ptr ::OrtValue of Tensor type. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(BindOutput, _Inout_ OrtIoBinding* binding_ptr, _In_ const char* name, _In_ const OrtValue* val_ptr); - - /** \brief Bind an ::OrtIoBinding output to a device - * - * Binds the ::OrtValue to a device which is specified by ::OrtMemoryInfo. - * You can either create an instance of ::OrtMemoryInfo with a device id or obtain one from the allocator that you have created/are using - * This is useful when one or more outputs have dynamic shapes and, it is hard to pre-allocate and bind a chunk of - * memory within ::OrtValue ahead of time. 
- * - * \see OrtApi::RunWithBinding - * - * \param[in] binding_ptr - * \param[in] name Null terminated string of the device name - * \param[in] mem_info_ptr - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(BindOutputToDevice, _Inout_ OrtIoBinding* binding_ptr, _In_ const char* name, _In_ const OrtMemoryInfo* mem_info_ptr); - - /** \brief Get the names of an ::OrtIoBinding's outputs - * - * Returns the names of the outputs in the order they were bound. This is useful after running the model - * with bound outputs because the returned names are in order in which output ::OrtValue are returned. This is useful if - * the order of outputs and their names is not known. - * - * \param[in] binding_ptr - * \param[in] allocator Allocator used to allocate continuous buffers for output strings and lengths. - * \param[out] buffer Returns an array of non-null terminated UTF-8 strings. The number of strings stored is returned in the count parameter. - * This buffer is allocated using `allocator` and must be freed using it. - * \param[out] lengths Returns an array of `count` lengths of the strings returned in `buffer` - * This buffer is allocated using `allocator` and must be freed using it. - * \param[out] count Number of strings returned. If `binding_ptr` has no bound outputs, zero is returned, - * no memory allocation is performed and buffer and lengths are set to nullptr. 
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetBoundOutputNames, _In_ const OrtIoBinding* binding_ptr, _In_ OrtAllocator* allocator, - _Out_ char** buffer, _Out_writes_all_(count) size_t** lengths, _Out_ size_t* count); - - /** \brief Get the output ::OrtValue objects from an ::OrtIoBinding - * - * Returns an array of pointers to individually allocated ::OrtValue%s that contain results of a model execution with OrtApi::RunWithBinding - * The array contains the same number of ::OrtValue%s and they are in the same order as they were bound with OrtApi::BindOutput - * or OrtApi::BindOutputToDevice. - * - * The returned ::OrtValue%s must be released using OrtApi::ReleaseValue after they are no longer needed. - * The array is allocated using the specified instance of the allocator and must be freed using the same allocator after - * all the ::OrtValue%s contained therein are individually released. - * - * \param[in] binding_ptr - * \param[in] allocator Allocator used to allocate output array - * \param[out] output Set to the allocated array of allocated ::OrtValue outputs. Set to nullptr if there are 0 outputs. 
- * \param[out] output_count Set to number of ::OrtValue%s returned - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetBoundOutputValues, _In_ const OrtIoBinding* binding_ptr, _In_ OrtAllocator* allocator, - _Out_writes_all_(output_count) OrtValue*** output, _Out_ size_t* output_count); - - /** \brief Clears any previously set Inputs for an ::OrtIoBinding - */ - void(ORT_API_CALL* ClearBoundInputs)(_Inout_ OrtIoBinding* binding_ptr) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; - - /** \brief Clears any previously set Outputs for an ::OrtIoBinding - */ - void(ORT_API_CALL* ClearBoundOutputs)(_Inout_ OrtIoBinding* binding_ptr) NO_EXCEPTION ORT_ALL_ARGS_NONNULL; - - /// @} - /// \name OrtValue - /// @{ - - /** \brief Direct memory access to a specified tensor element - * - * For example, given a tensor with shape of [3,224,224], a pointer to the element at location [2,150,128] can be retrieved - * - * This function only works for numeric type tensors (No strings, etc). - * This is a no-copy method whose returned pointer is valid until the passed in ::OrtValue is free'd. - * - * \param[in] value - * \param[in] location_values Pointer to an array of index values that specify an element's location relative to its shape - * \param[in] location_values_count Number of elements in location_values. Must match the number of elements in the tensor's shape. - * \param[out] out Set to a pointer to the element specified - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(TensorAt, _Inout_ OrtValue* value, const int64_t* location_values, size_t location_values_count, _Outptr_ void** out); - - /// @} - /// \name OrtEnv - /// @{ - - /** \brief Create an allocator and register it with the ::OrtEnv - * - * Enables sharing the allocator between multiple sessions that use the same env instance. - * Lifetime of the created allocator will be valid for the duration of the environment. 
- * Returns an error if an allocator with the same ::OrtMemoryInfo is already registered. - * - * See https://onnxruntime.ai/docs/get-started/with-c.html for details. - * - * \param[in] env ::OrtEnv instance - * \param[in] mem_info - * \param[in] arena_cfg Pass nullptr for defaults - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateAndRegisterAllocator, _Inout_ OrtEnv* env, _In_ const OrtMemoryInfo* mem_info, - _In_ const OrtArenaCfg* arena_cfg); - - /** \brief Set language projection - * - * Set the language projection for collecting telemetry data when Env is created. - * - * The default is ORT_PROJECTION_C, which means it will classify the language not in the list to C also. - * - * \param[in] ort_env - * \param[in] projection - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetLanguageProjection, _In_ const OrtEnv* ort_env, _In_ OrtLanguageProjection projection); - - /// @} - /// \name OrtSession - /// @{ - - /** \brief Return the time that profiling was started - * - * \note The timer precision varies per platform. On Windows and MacOS, the precision will be ~100ns - * - * \param[in] session - * \param[out] out nanoseconds of profiling's start time - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionGetProfilingStartTimeNs, _In_ const OrtSession* session, _Outptr_ uint64_t* out); - - /// @} - /// \name OrtThreadingOptions - /// @{ - - /** \brief Set global intra-op thread count - * - * This configures the global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools - * - * \param[in] tp_options - * \param[in] intra_op_num_threads Number of threads, special values:
- * 0 = Use default thread count
- * 1 = The invoking thread will be used; no threads will be created in the thread pool. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetGlobalIntraOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int intra_op_num_threads); - - /** \brief Set global inter-op thread count - * - * This configures the global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools - * - * \param[in] tp_options - * \param[in] inter_op_num_threads Number of threads, special values:
- * 0 = Use default thread count
- * 1 = The invoking thread will be used; no threads will be created in the thread pool. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetGlobalInterOpNumThreads, _Inout_ OrtThreadingOptions* tp_options, int inter_op_num_threads); - - /** \brief Set global spin control options - * - * This will configure the global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools. - * Allow spinning of thread pools when their queues are empty. This will set the value for both - * inter_op and intra_op threadpools. - * - * \param[in] tp_options - * \param[in] allow_spinning Valid values are 0 or 1.
- * 0 = It won't spin (recommended if CPU usage is high)
- * 1 = Threadpool will spin to wait for queue to become non-empty - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetGlobalSpinControl, _Inout_ OrtThreadingOptions* tp_options, int allow_spinning); - - /// @} - /// \name OrtSessionOptions - /// @{ - - /** \brief Add a pre-allocated initializer to a session - * - * If a model contains an initializer with a name that is same as the name passed to this call, - * ORT will use this initializer instance instead of deserializing one from the model file. This - * is useful when you want to share the same initializer across sessions. - * - * \param[in] options - * \param[in] name Null terminated string of the initializer name - * \param[in] val ::OrtValue containing the initializer. Its lifetime and the underlying initializer buffer must be - * managed by the user (created using the OrtApi::CreateTensorWithDataAsOrtValue) and it must outlive the session object - * to which it is added. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(AddInitializer, _Inout_ OrtSessionOptions* options, _In_z_ const char* name, - _In_ const OrtValue* val); - - /// @} - /// \name OrtEnv - /// @{ - - /** - * Create a custom environment with global threadpools and logger that will be shared across sessions. - * Use this in conjunction with OrtApi::DisablePerSessionThreads or else the session will use - * its own thread pools. - * - * \param[in] logging_function A pointer to a logging function. - * \param[in] logger_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to - * `logging_function`. - * \param[in] log_severity_level The log severity level. - * \param[in] logid The log identifier. - * \param[in] tp_options - * \param[out] out Newly created OrtEnv. 
Must be freed with OrtApi::ReleaseEnv - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateEnvWithCustomLoggerAndGlobalThreadPools, OrtLoggingFunction logging_function, _In_opt_ void* logger_param, OrtLoggingLevel log_severity_level, - _In_ const char* logid, _In_ const struct OrtThreadingOptions* tp_options, _Outptr_ OrtEnv** out); - - /// @} - /// \name OrtSessionOptions - /// @{ - - /** \brief Append CUDA provider to session options - * - * If CUDA is not available (due to a non CUDA enabled build, or if CUDA is not installed on the system), this function will return failure. - * - * \param[in] options - * \param[in] cuda_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_CUDA, - _In_ OrtSessionOptions* options, _In_ const OrtCUDAProviderOptions* cuda_options); - - /** \brief Append ROCM execution provider to the session options - * - * If ROCM is not available (due to a non ROCM enabled build, or if ROCM is not installed on the system), this function will return failure. - * - * \param[in] options - * \param[in] rocm_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_ROCM, - _In_ OrtSessionOptions* options, _In_ const OrtROCMProviderOptions* rocm_options); - - /** \brief Append OpenVINO execution provider to the session options - * - * If OpenVINO is not available (due to a non OpenVINO enabled build, or if OpenVINO is not installed on the system), this function will fail. 
- * - * \param[in] options - * \param[in] provider_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_OpenVINO, - _In_ OrtSessionOptions* options, _In_ const OrtOpenVINOProviderOptions* provider_options); - - /// @} - /// \name OrtThreadingOptions - /// @{ - - /** \brief Set threading flush-to-zero and denormal-as-zero - * - * Sets global thread pool options to be used in the call to OrtApi::CreateEnvWithGlobalThreadPools. - * Flush-to-zero and denormal-as-zero are applied to threads in both intra and inter global thread pool. - * \note This option is not needed if the models used have no denormals. Having no denormals is recommended as this option may hurt model accuracy. - * - * \param[in] tp_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetGlobalDenormalAsZero, _Inout_ OrtThreadingOptions* tp_options); - - /// @} - /// \name OrtArenaCfg - /// @{ - - /** \deprecated Use OrtApi::CreateArenaCfgV2 - * - * This will create the configuration of an arena that can eventually be used to define an arena based allocator's behavior - * - * \param[in] max_mem Use 0 to allow ORT to choose the default - * \param[in] arena_extend_strategy Use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested - * \param[in] initial_chunk_size_bytes Use -1 to allow ORT to choose the default - * \param[in] max_dead_bytes_per_chunk Use -1 to allow ORT to choose the default - * \param[in] out A pointer to an OrtArenaCfg instance - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateArenaCfg, _In_ size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, - int max_dead_bytes_per_chunk, _Outptr_ OrtArenaCfg** out); - - ORT_CLASS_RELEASE(ArenaCfg); - - /// @} - /// \name OrtModelMetadata - /// @{ - - /** - * Use this to obtain the description of the graph present in the model - * (doc_string field of the 
GraphProto message within the ModelProto message). - * If it doesn't exist, an empty string will be returned. - * - * \param[in] model_metadata An instance of ::OrtModelMetadata - * \param[in] allocator Allocator used to allocate the string that will be returned back - * \param[out] value Set to a null terminated string allocated using `allocator`. The caller is responsible for freeing it using `allocator` - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(ModelMetadataGetGraphDescription, _In_ const OrtModelMetadata* model_metadata, - _Inout_ OrtAllocator* allocator, _Outptr_ char** value); - - /// @} - /// \name OrtSessionOptions - /// @{ - - /** \brief Append TensorRT provider to session options - * - * If TensorRT is not available (due to a non TensorRT enabled build, or if TensorRT is not installed on the system), this function will return failure. - * - * \param[in] options - * \param[in] tensorrt_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_TensorRT, - _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptions* tensorrt_options); - - /// @} - /// \name Misc - /// @{ - - /** \brief Set current GPU device ID - * - * Set the current device id of the GPU execution provider (CUDA/tensorrt/rocm). The device id should be less - * than the total number of devices available. This is only useful when multiple-GPUs are installed and it is - * required to restrict execution to a single GPU. - * - * \param[in] device_id - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetCurrentGpuDeviceId, _In_ int device_id); - - /** \brief Get current GPU device ID - * - * Get the current device id of the GPU execution provider (CUDA/tensorrt/rocm). 
 - * - * \see OrtApi::SetCurrentGpuDeviceId - * - * \param[out] device_id - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetCurrentGpuDeviceId, _In_ int* device_id); - - /// @} - /// \name OrtKernelInfo - /// Custom operator APIs. - /// @{ - - /** \brief Fetch an array of float values stored as an attribute in the graph node - * - * - * If `out` is nullptr, the value of `size` is set to the true size of the attribute - * array's size, and a success status is returned. - * - * If the `size` parameter is greater than or equal to the actual attribute array's size, - * the value of `size` is set to the true size of the attribute array's size, - * the provided memory is filled with the attribute's contents, - * and a success status is returned. - * - * If the `size` parameter is less than the actual attribute array's size and `out` - * is not nullptr, the value of `size` is set to the true size of the attribute array's size - * and a failure status is returned. - * - * \param[in] info instance - * \param[in] name name of the attribute to be parsed - * \param[out] out pointer to memory where the attribute's contents are to be stored - * \param[in, out] size actual size of attribute array - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name, - _Out_ float* out, _Inout_ size_t* size); - - /** \brief Fetch an array of int64_t values stored as an attribute in the graph node - * - * If `out` is nullptr, the value of `size` is set to the true size of the attribute - * array's size, and a success status is returned. - * - * If the `size` parameter is greater than or equal to the actual attribute array's size, - * the value of `size` is set to the true size of the attribute array's size, - * the provided memory is filled with the attribute's contents, - * and a success status is returned. 
- * - * If the `size` parameter is less than the actual attribute array's size and `out` - * is not nullptr, the value of `size` is set to the true size of the attribute array's size - * and a failure status is returned.) - * - * \param[in] info instance - * \param[in] name name of the attribute to be parsed - * \param[out] out pointer to memory where the attribute's contents are to be stored - * \param[in, out] size actual size of attribute array - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(KernelInfoGetAttributeArray_int64, _In_ const OrtKernelInfo* info, _In_ const char* name, - _Out_ int64_t* out, _Inout_ size_t* size); - - /// @} - /// \name OrtArenaCfg - /// @{ - - /** \brief Create an ::OrtArenaCfg - * - * Create the configuration of an arena that can eventually be used to define an arena based allocator's behavior. - * - * Supported keys are (See https://onnxruntime.ai/docs/get-started/with-c.html for details on what the - * following parameters mean and how to choose these values.): - * "max_mem": Maximum memory that can be allocated by the arena based allocator. - * Use 0 for ORT to pick the best value. Default is 0. - * "arena_extend_strategy": 0 = kNextPowerOfTwo, 1 = kSameAsRequested. - * Use -1 to allow ORT to choose the default. - * "initial_chunk_size_bytes": (Possible) Size of the first allocation in the arena. - * Only relevant if arena strategy is `kNextPowerOfTwo`. Use -1 to allow ORT to choose the default. - * Ultimately, the first allocation size is determined by the allocation memory request. - * "max_dead_bytes_per_chunk": Threshold of unused memory in an allocated chunk of arena memory after - * crossing which the current chunk is chunked into 2. - * "initial_growth_chunk_size_bytes": (Possible) Size of the second allocation in the arena. - * Only relevant if arena strategy is `kNextPowerOfTwo`. Use -1 to allow ORT to choose the default. 
 - * "max_power_of_two_extend_bytes": The maximum extend size if arena strategy is `kNextPowerOfTwo`. - * It is not an allocation limit, it is only a limit for extension when requested byte is less than the limit. - * When requested bytes is more than the limit, allocator will still return as requested. - * Use -1 to allow ORT to choose the default 1GB for max_power_of_two_extend_bytes. - * Ultimately, the allocation size is determined by the allocation memory request. - * Further allocation sizes are governed by the arena extend strategy. - * - * \param[in] arena_config_keys Keys to configure the arena - * \param[in] arena_config_values Values to configure the arena - * \param[in] num_keys Number of keys in `arena_config_keys` and `arena_config_values` - * \param[out] out Newly created ::OrtArenaCfg. Must be freed with OrtApi::ReleaseArenaCfg - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateArenaCfgV2, _In_reads_(num_keys) const char* const* arena_config_keys, - _In_reads_(num_keys) const size_t* arena_config_values, _In_ size_t num_keys, - _Outptr_ OrtArenaCfg** out); - - /// @} - /// \name OrtRunOptions - /// @{ - - /** \brief Set a single run configuration entry as a pair of strings - * - * If a configuration with same key exists, this will overwrite the configuration with the given config_value - * - * The config_key and the format of config_value are defined in onnxruntime_run_options_config_keys.h - * - * \param[in] options - * \param[in] config_key A null terminated string representation of the config key - * \param[in] config_value A null terminated string representation of the config value - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(AddRunConfigEntry, _Inout_ OrtRunOptions* options, - _In_z_ const char* config_key, _In_z_ const char* config_value); - - /// @} - /// \name OrtPrepackedWeightsContainer - /// @{ - - /** \brief Create an ::OrtPrepackedWeightsContainer - * - * This 
container will hold pre-packed buffers of shared initializers for sharing between sessions - * (i.e.) if there are shared initializers that can be shared between sessions, the pre-packed buffers - * of these (if any) may possibly be shared to provide memory footprint savings. Pass this container - * to sessions that you would like to share pre-packed buffers of shared initializers at session - * creation time. - * - * \param[out] out Newly created ::OrtPrepackedWeightsContainer. Must be freed with OrtApi::ReleasePrepackedWeightsContainer - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreatePrepackedWeightsContainer, _Outptr_ OrtPrepackedWeightsContainer** out); - - /** \brief Release OrtPrepackedWeightsContainer instance - * - * \note instance must not be released until the sessions using it are released - */ - ORT_CLASS_RELEASE(PrepackedWeightsContainer); - - /// @} - /// \name OrtSession - /// @{ - - /** \brief Create session with prepacked weights container - * - * Same functionality offered by OrtApi::CreateSession except that a container that contains - * pre-packed weights' buffers is written into/read from by the created session. - * This is useful when used in conjunction with OrtApi::AddInitializer which injects - * shared initializer info into sessions. Wherever possible, the pre-packed versions of these - * shared initializers are cached in this container so that multiple sessions can just re-use - * these instead of duplicating these in memory. - * - * \param[in] env OrtEnv instance instance - * \param[in] model_path Null terminated string of the path (wchar on Windows, char otherwise) - * \param[in] options - * \param[in] prepacked_weights_container - * \param[out] out Newly created ::OrtSession. 
Must be freed with OrtApi::ReleaseSession - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateSessionWithPrepackedWeightsContainer, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, - _In_ const OrtSessionOptions* options, - _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, - _Outptr_ OrtSession** out); - - /** \brief Create session from memory with prepacked weights container - * - * Same functionality offered by OrtApi::CreateSessionFromArray except that a container that contains - * pre-packed weights' buffers is written into/read from by the created session. - * This is useful when used in conjunction with OrtApi::AddInitializer which injects - * shared initializer info into sessions. Wherever possible, the pre-packed versions of these - * shared initializers are cached in this container so that multiple sessions can just re-use - * these instead of duplicating these in memory. - * - * \param[in] env - * \param[in] model_data Array of bytes holding the model - * \param[in] model_data_length Number of bytes in `model_data_model` - * \param[in] options - * \param[in] prepacked_weights_container - * \param[out] out Newly created ::OrtSession. Must be freed with OrtApi::ReleaseSession - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateSessionFromArrayWithPrepackedWeightsContainer, _In_ const OrtEnv* env, - _In_ const void* model_data, size_t model_data_length, - _In_ const OrtSessionOptions* options, - _Inout_ OrtPrepackedWeightsContainer* prepacked_weights_container, - _Outptr_ OrtSession** out); - - /// @} - /// \name OrtSessionOptions - /// @{ - - /** \brief Append TensorRT execution provider to the session options - * - * If TensorRT is not available (due to a non TensorRT enabled build), this function will return failure. 
 - * - * This is slightly different from OrtApi::SessionOptionsAppendExecutionProvider_TensorRT, it takes an - * ::OrtTensorRTProviderOptions which is publicly defined. This takes an opaque ::OrtTensorRTProviderOptionsV2 - * which must be created with OrtApi::CreateTensorRTProviderOptions. - * - * For OrtApi::SessionOptionsAppendExecutionProvider_TensorRT, the user needs to instantiate ::OrtTensorRTProviderOptions - * as well as allocate/release buffers for some members of ::OrtTensorRTProviderOptions. - * Here, OrtApi::CreateTensorRTProviderOptions and OrtApi::ReleaseTensorRTProviderOptions will do the memory management for you. - * - * \param[in] options - * \param[in] tensorrt_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_TensorRT_V2, - _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options); - - /// @} - /// \name OrtTensorRTProviderOptionsV2 - /// @{ - - /** \brief Create an OrtTensorRTProviderOptionsV2 - * - * \param[out] out Newly created ::OrtTensorRTProviderOptionsV2. Must be released with OrtApi::ReleaseTensorRTProviderOptions - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateTensorRTProviderOptions, _Outptr_ OrtTensorRTProviderOptionsV2** out); - - /** \brief Set options in a TensorRT Execution Provider. - * - * Please refer to https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#cc - * to know the available keys and values. Key should be in null terminated string format of the member of ::OrtTensorRTProviderOptionsV2 - * and value should be its related range. Recreates the options and only sets the supplied values. 
- * - * For example, key="trt_max_workspace_size" and value="2147483648" - * - * \param[in] tensorrt_options - * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys - * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values - * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(UpdateTensorRTProviderOptions, _Inout_ OrtTensorRTProviderOptionsV2* tensorrt_options, - _In_reads_(num_keys) const char* const* provider_options_keys, - _In_reads_(num_keys) const char* const* provider_options_values, - _In_ size_t num_keys); - - /** \brief Get serialized TensorRT provider options string. - * - * For example, "trt_max_workspace_size=2147483648;trt_max_partition_iterations=10;trt_int8_enable=1;......" - * - * \param tensorrt_options - OrtTensorRTProviderOptionsV2 instance - * \param allocator - a ptr to an instance of OrtAllocator obtained with OrtApi::CreateAllocator or OrtApi::GetAllocatorWithDefaultOptions - * the specified allocator will be used to allocate continuous buffers for output strings and lengths. - * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it. 
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetTensorRTProviderOptionsAsString, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); - - /** \brief Release an ::OrtTensorRTProviderOptionsV2 - * - * \note This is an exception in the naming convention of other Release* functions, as the name of the method does not have the V2 suffix, but the type does - */ - void(ORT_API_CALL* ReleaseTensorRTProviderOptions)(_Frees_ptr_opt_ OrtTensorRTProviderOptionsV2* input); - - /// @} - /// \name OrtSessionOptions - /// @{ - - /** \brief Enable custom operators - * - * See onnxruntime-extensions: https://github.com/microsoft/onnxruntime-extensions.git - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(EnableOrtCustomOps, _Inout_ OrtSessionOptions* options); - - /// @} - /// \name OrtAllocator - /// @{ - - /** \brief Register a custom allocator - * - * Enables sharing between multiple sessions that use the same env instance. - * Returns an error if an allocator with the same ::OrtMemoryInfo is already registered. - * - * The behavior of this is exactly the same as OrtApi::CreateAndRegisterAllocator except - * instead of ORT creating an allocator based on provided info, in this case - * ORT uses the user-provided custom allocator. - * See https://onnxruntime.ai/docs/get-started/with-c.html for details. - * - * \param[in] env - * \param[in] allocator User provided allocator - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(RegisterAllocator, _Inout_ OrtEnv* env, _In_ OrtAllocator* allocator); - - /** \brief Unregister a custom allocator - * - * It is an error if you provide an ::OrtMemoryInfo not corresponding to any - * registered allocators for sharing. 
- * - * \param[in] env - * \param[in] mem_info - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(UnregisterAllocator, _Inout_ OrtEnv* env, - _In_ const OrtMemoryInfo* mem_info); - - /// @} - /// \name OrtValue - /// @{ - - /** \brief Sets *out to 1 iff an ::OrtValue is a SparseTensor, and 0 otherwise - * - * \param[in] value existing ::OrtValue - * \param[out] out unless an error occurs, contains 1 iff the value contains an instance - * of sparse tensor or 0 otherwise. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(IsSparseTensor, _In_ const OrtValue* value, _Out_ int* out); - - /** \brief Create an ::OrtValue with a sparse tensor that is empty. - * - * Use FillSparseTensor() functions to populate sparse tensor with non-zero values and - * format specific indices data. - * Use ReleaseValue to destroy the sparse tensor, this will also release the buffer inside the output value - * if any was allocated. - * \param[in,out] allocator allocator to use when performing an allocation. Allocation will be performed - * by FillSparseTensor() APIs. The lifespan of the allocator instance must eclipse the lifespan - * this sparse tensor instance as the same allocator will be used to free memory. - * \param[in] dense_shape shape of the original dense tensor - * \param[in] dense_shape_len number of shape dimensions being passed - * \param[in] type must be one of TENSOR_ELEMENT_DATA_TYPE_xxxx - * \param[out] out Should be freed by calling ReleaseValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateSparseTensorAsOrtValue, _Inout_ OrtAllocator* allocator, _In_ const int64_t* dense_shape, - size_t dense_shape_len, ONNXTensorElementDataType type, _Outptr_ OrtValue** out); - - /** - * This fills populates an empty tensor that was created using OrtApi::CreateSparseTensorAsOrtValue. 
 - * This will allocate required memory and copy the supplied NNZ values and COO indices into that memory allocation. - * Memory allocation is performed using the allocator that was specified with OrtApi::CreateSparseTensorAsOrtValue. - * - * \param[in,out] ort_value ::OrtValue to populate with data - * \param[in] data_mem_info serves to identify the location of the data to be copied. If the allocator specified - * at the creation time has memory info that is not the same as mem_info argument to this function a X-device copy will be performed. - * String data is assumed to be on CPU and will only be copied into a CPU allocated buffer. - * \param[in] values_shape pointer to values shape array - * \param[in] values_shape_len length of the values_shape - * \param[in] values pointer to an array of values. For strings, pass const char**. - * \param[in] indices_data pointer to a location of COO indices - * \param[in] indices_num number of COO indices - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(FillSparseTensorCoo, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info, - _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values, - _In_ const int64_t* indices_data, size_t indices_num); - - /** - * This populates an empty tensor that was created using OrtApi::CreateSparseTensorAsOrtValue. - * This will allocate required memory and copy the supplied NNZ values and CSR indices into that memory allocation. - * Memory allocation is performed using the allocator that was specified with OrtApi::CreateSparseTensorAsOrtValue. - * - * \param[in,out] ort_value ::OrtValue to populate with data - * \param[in] data_mem_info serves to identify the location of the data to be copied. If the allocator specified - * at the creation time has memory info that is not the same as mem_info argument to this function a X-device copy will be performed. 
- * String data is assumed to be on CPU and will only be copied into a CPU allocated buffer. - * \param[in] values_shape pointer to values shape array - * \param[in] values_shape_len length of the values_shape - * \param[in] values - pointer to an array of values. For strings, pass const char**. - * \param[in] inner_indices_data pointer to a location of CSR inner indices - * \param[in] inner_indices_num number of CSR inner indices - * \param[in] outer_indices_data pointer to a location of CSR outer indices - * \param[in] outer_indices_num number of CSR outer indices - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(FillSparseTensorCsr, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info, - _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values, - _In_ const int64_t* inner_indices_data, size_t inner_indices_num, - _In_ const int64_t* outer_indices_data, size_t outer_indices_num); - - /** - * This fills populates an empty tensor that was created using OrtApi::CreateSparseTensorAsOrtValue. - * This will allocate required memory and copy the supplied NNZ values and BlockSparse indices into that memory allocation. - * Memory allocation is performed using the allocator that was specified with OrtApi::CreateSparseTensorAsOrtValue. - * - * \param[in,out] ort_value ::OrtValue to populate with data - * \param[in] data_mem_info serves to identify the location of the data to be copied. If the allocator specified - * at the creation time has memory info that is not the same as mem_info argument to this function a X-device copy will be performed. - * String data is assumed to be on CPU and will only be copied into a CPU allocated buffer. 
- * \param[in] values_shape - * \param[in] values_shape_len - * \param[in] values structure with values information - * \param[in] indices_shape_data pointer to a location of indices shape - * \param[in] indices_shape_len length of the block sparse indices shape - * \param[in] indices_data pointer to a location of indices data. Shape will determine the length of the indices data. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(FillSparseTensorBlockSparse, _Inout_ OrtValue* ort_value, _In_ const OrtMemoryInfo* data_mem_info, - _In_ const int64_t* values_shape, size_t values_shape_len, _In_ const void* values, - _In_ const int64_t* indices_shape_data, size_t indices_shape_len, - _In_ const int32_t* indices_data); - - /** - * Create an ::OrtValue with a sparse tensor. This is the first step. - * Next, use UseIndices() functions to supply sparse tensor with - * format specific indices data and set its sparse format to a specific enum value. - * This will not perform memory allocations. It will - * use supplied user buffer which should outlive the created sparse tensor. - * Use OrtApi::ReleaseValue to destroy the sparse tensor. It would not release the supplied values buffer. - * This function can not be used to map strings from the user allocated memory. Strings must always be copied - * and have UTF-8 encoding. Therefore, use OrtApi::CreateSparseTensorAsOrtValue above and then fill it with data - * using appropriate Make*() function. - * - * \param[in] info memory info where sparse values reside. - * \param[in,out] p_data pointer to a user allocated buffer with values. To create a full sparse tensor with no non-zero - * values, pass nullptr - * \param[in] dense_shape shape of the original dense tensor - * \param[in] dense_shape_len number of shape dimensions being passed - * \param[in] values_shape shape of the values data. To create a fully sparse tensor with no non-zero values, - * pass {0} shape. 
- * \param[in] values_shape_len number of values shape dimensions - * \param[in] type must be one of TENSOR_ELEMENT_DATA_TYPE_xxxx - * \param[out] out Should be freed by calling ReleaseValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(CreateSparseTensorWithValuesAsOrtValue, _In_ const OrtMemoryInfo* info, _Inout_ void* p_data, - _In_ const int64_t* dense_shape, size_t dense_shape_len, - _In_ const int64_t* values_shape, size_t values_shape_len, - ONNXTensorElementDataType type, _Outptr_ OrtValue** out); - - /** - * This assigns Coo format indices to the SparseTensor that was created by - * OrtApi::CreateSparseTensorWithValuesAsOrtValue above. It also sets OrtSparseFormat to - * ORT_SPARSE_COO. This will not allocate any additional memory for data. The life span of - * indices_data buffer should eclipse the life span of this ::OrtValue. - * - * \param[in,out] ort_value ::OrtValue instance constructed with OrtApi::CreateSparseTensorWithValuesAsOrtValue - * \param[in,out] indices_data pointer to a user pre-allocated buffer or nullptr for fully sparse tensors. - * \param[in] indices_num number of COO indices. Should either be 0 for fully sparse tensors, be equal - * to the number of nnz values specified to OrtApi::CreateSparseTensorWithValuesAsOrtValue for 1-D {nnz} indices or - * be twice as number of nnz values for a 2-D indices {nnz, 2} - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(UseCooIndices, _Inout_ OrtValue* ort_value, _Inout_ int64_t* indices_data, size_t indices_num); - - /** - * The assigns CSR format indices to the SparseTensor that was created by - * OrtApi::CreateSparseTensorWithValuesAsOrtValue above. It also sets OrtSparseFormat to - * ORT_SPARSE_CSRC. This will not allocate any additional memory for data. The life spans of - * inner_data and outer_data buffers should eclipse the life span of this ::OrtValue. 
- * - * \param[in,out] ort_value ::OrtValue instance constructed with OrtApi::CreateSparseTensorWithValuesAsOrtValue - * \param[in,out] inner_data pointer to a user pre-allocated buffer or nullptr for fully sparse tensors. - * \param[in] inner_num number of inner CSR indices. Should either be 0 for fully sparse tensors or be equal - * to the number of nnz values specified to OrtApi::CreateSparseTensorWithValuesAsOrtValue. - * \param[in,out] outer_data pointer to user pre-allocated buffer or nullptr for fully sparse tensors. - * \param[in] outer_num number of CSR outer indices. Should either be 0 for fully sparse tensors or - * equal to rows + 1 of the dense shape. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(UseCsrIndices, _Inout_ OrtValue* ort_value, _Inout_ int64_t* inner_data, size_t inner_num, - _Inout_ int64_t* outer_data, size_t outer_num); - - /** - * The assigns BlockSparse format indices to the SparseTensor that was created by - * OrtApi::CreateSparseTensorWithValuesAsOrtValue above. It also sets OrtSparseFormat to - * ORT_SPARSE_BLOCK_SPARSE. This will not allocate any additional memory for data. The life span of - * indices_data buffer must eclipse the lifespan of this ::OrtValue. - * - * \param[in,out] ort_value OrtValue instance constructed with OrtApi::CreateSparseTensorWithValuesAsOrtValue - * \param[in] indices_shape pointer to indices shape. Use {0} for fully sparse tensors - * \param[in] indices_shape_len length of the indices shape - * \param[in,out] indices_data pointer to user pre-allocated buffer or nullptr for fully sparse tensors. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(UseBlockSparseIndices, _Inout_ OrtValue* ort_value, const int64_t* indices_shape, size_t indices_shape_len, _Inout_ int32_t* indices_data); - - /** \brief Returns sparse tensor format enum iff a given ort value contains an instance of sparse tensor. 
- * - * \param[in] ort_value ::OrtValue that contains an instance of sparse tensor - * \param[out] out pointer to out parameter - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetSparseTensorFormat, _In_ const OrtValue* ort_value, _Out_ enum OrtSparseFormat* out); - - /** \brief Returns data type and shape of sparse tensor values (nnz) iff ::OrtValue contains a SparseTensor. - * - * \param[in] ort_value An ::OrtValue that contains a fully constructed sparse tensor - * \param[out] out Must be freed by OrtApi::ReleaseTensorTypeAndShapeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetSparseTensorValuesTypeAndShape, _In_ const OrtValue* ort_value, _Outptr_ OrtTensorTypeAndShapeInfo** out); - - /** \brief Returns numeric data for sparse tensor values (nnz). For string values use GetStringTensor*(). - * - * \param[in] ort_value an instance of ::OrtValue containing sparse tensor - * \param[out] out returns a pointer to values data. Do not attempt to free this ptr. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetSparseTensorValues, _In_ const OrtValue* ort_value, _Outptr_ const void** out); - - /** \brief Returns data type, shape for the type of indices specified by indices_format. - * - * \param[in] ort_value ::OrtValue containing sparse tensor. - * \param[in] indices_format One of the indices formats. It is an error to request a format that the sparse - * tensor does not contain. - * \param[out] out an instance of ::OrtTensorTypeAndShapeInfo. 
Must be freed by OrtApi::ReleaseTensorTypeAndShapeInfo - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetSparseTensorIndicesTypeShape, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Outptr_ OrtTensorTypeAndShapeInfo** out); - - /** \brief Returns indices data for the type of the indices specified by indices_format - * - * \param[in] ort_value ::OrtValue containing sparse tensor. - * \param[in] indices_format One of the indices formats. It is an error to request a format that the sparse tensor does not contain. - * \param[out] num_indices Pointer to where the number of indices entries is returned - * \param[out] indices Returned pointer to the indices data. Do not free the returned pointer as it refers to internal data owned by the ::OrtValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetSparseTensorIndices, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Out_ size_t* num_indices, _Outptr_ const void** indices); - /// @} - /// \name OrtSessionOptions - /// @{ - - /** - * \brief Sets out to 1 iff an optional type OrtValue has an element, 0 otherwise (OrtValue is None) - * Use this API to find if the optional type OrtValue is None or not. - * If the optional type OrtValue is not None, use the OrtValue just like any other OrtValue. - * For example, if you get an OrtValue that corresponds to Optional(tensor) and - * if HasValue() returns true, use it as tensor and so on. - - * \param[in] value Input OrtValue. - * \param[out] out indicating if the input OrtValue contains data (1) or if it is a None (0) - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(HasValue, _In_ const OrtValue* value, _Out_ int* out); - - /// @} - /// \name OrtKernelContext - /// Custom operator APIs. 
- /// @{ - - /** \brief Used for custom operators, gets the GPU compute stream to use to launch the custom a GPU kernel - * \see ::OrtCustomOp - * \param[in] context OrtKernelContext instance - * \param[out] out Returns pointer to a GPU compute stream that can be used to launch the custom GPU kernel. - * If retrieving the GPU compute stream is not relevant (GPU not enabled in the build, kernel partitioned to - * some other EP), then a nullptr is returned as the output param. - * Do not free or mutate the returned pointer as it refers to internal data owned by the underlying session. - * Only use it for custom kernel launching. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(KernelContext_GetGPUComputeStream, _In_ const OrtKernelContext* context, _Outptr_ void** out); - - /// @} - /// \name GetTensorMemoryInfo - /// @{ - /** \brief Returns a pointer to the ::OrtMemoryInfo of a Tensor - * \param[in] value ::OrtValue containing tensor. - * \param[out] mem_info ::OrtMemoryInfo of the tensor. Do NOT free the returned pointer. It is valid for the lifetime of the ::OrtValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetTensorMemoryInfo, _In_ const OrtValue* value, _Out_ const OrtMemoryInfo** mem_info); - - /// @} - /// \name GetExecutionProviderApi - /// @{ - /** \brief Get a pointer to the requested version of the Execution Provider specific - * API extensions to the OrtApi - * \param[in] provider_name The name of the execution provider name. Currently only the following - * values are supported: "DML". - * \param[in] version Must be ::ORT_API_VERSION. - * \param[out] provider_api A void pointer containing a reference to the execution provider versioned api structure. - * For example, the provider_api pointer can be cast to the OrtDmlApi* when the provider_name is "DML". 
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetExecutionProviderApi, _In_ const char* provider_name, _In_ uint32_t version, _Outptr_ const void** provider_api); - - /// @} - - /// \name SessionOptions - /// @{ - /** \brief Set custom thread creation function - * - * \param[in] options Session options - * \param[in] ort_custom_create_thread_fn Custom thread creation function - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionOptionsSetCustomCreateThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn); - - /** \brief Set creation options for custom thread - * - * \param[in] options Session options - * \param[in] ort_custom_thread_creation_options Custom thread creation options (can be nullptr) - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionOptionsSetCustomThreadCreationOptions, _Inout_ OrtSessionOptions* options, _In_ void* ort_custom_thread_creation_options); - - /** \brief Set custom thread join function - * - * \param[in] options Session options - * \param[in] ort_custom_join_thread_fn Custom join thread function, must not be nullptr when ort_custom_create_thread_fn is set - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SessionOptionsSetCustomJoinThreadFn, _Inout_ OrtSessionOptions* options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn); - /// @} - - /// \name OrtThreadingOptions - /// @{ - /** \brief Set custom thread creation function for global thread pools - * - * \param[inout] tp_options - * \param[in] ort_custom_create_thread_fn Custom thread creation function - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetGlobalCustomCreateThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomCreateThreadFn ort_custom_create_thread_fn); - - /** \brief Set custom thread creation options for global thread pools - * - * 
\param[inout] tp_options - * \param[in] ort_custom_thread_creation_options Custom thread creation options (can be nullptr) - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetGlobalCustomThreadCreationOptions, _Inout_ OrtThreadingOptions* tp_options, _In_ void* ort_custom_thread_creation_options); - - /** \brief Set custom thread join function for global thread pools - * - * \param[inout] tp_options - * \param[in] ort_custom_join_thread_fn Custom thread join function, must not be nullptr when global ort_custom_create_thread_fn is set - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SetGlobalCustomJoinThreadFn, _Inout_ OrtThreadingOptions* tp_options, _In_ OrtCustomJoinThreadFn ort_custom_join_thread_fn); - /// @} - - /** \brief Synchronize bound inputs. The call may be necessary for some providers, such as cuda, - * in case the system that allocated bound memory operated on a different stream. However, the - * operation is provider specific and could be a no-op. - * - * \param[inout] binding_ptr - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SynchronizeBoundInputs, _Inout_ OrtIoBinding* binding_ptr); - - /** \brief Synchronize bound outputs. The call may be necessary for some providers, such as cuda, - * in case the system that allocated bound memory operated on a different stream. However, the - * operation is provider specific and could be a no-op. - * - * \param[inout] binding_ptr - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(SynchronizeBoundOutputs, _Inout_ OrtIoBinding* binding_ptr); - - /// \name OrtSessionOptions - /// @{ - - /** \brief Append CUDA execution provider to the session options - * - * If CUDA is not available (due to a non CUDA enabled build), this function will return failure. 
- * - * This is slightly different from OrtApi::SessionOptionsAppendExecutionProvider_CUDA, it takes an - * ::OrtCUDAProviderOptions which is publicly defined. This takes an opaque ::OrtCUDAProviderOptionsV2 - * which must be created with OrtApi::CreateCUDAProviderOptions. - * - * For OrtApi::SessionOptionsAppendExecutionProvider_CUDA, the user needs to instantiate ::OrtCUDAProviderOptions - * as well as allocate/release buffers for some members of ::OrtCUDAProviderOptions. - * Here, OrtApi::CreateCUDAProviderOptions and Ortapi::ReleaseCUDAProviderOptions will do the memory management for you. - * - * \param[in] options - * \param[in] cuda_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.11. - */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_CUDA_V2, - _In_ OrtSessionOptions* options, _In_ const OrtCUDAProviderOptionsV2* cuda_options); - - /// @} - /// \name OrtCUDAProviderOptionsV2 - /// @{ - - /** \brief Create an OrtCUDAProviderOptionsV2 - * - * \param[out] out Newly created ::OrtCUDAProviderOptionsV2. Must be released with OrtApi::ReleaseCudaProviderOptions - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.11. - */ - ORT_API2_STATUS(CreateCUDAProviderOptions, _Outptr_ OrtCUDAProviderOptionsV2** out); - - /** \brief Set options in a CUDA Execution Provider. - * - * Please refer to https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#configuration-options - * to know the available keys and values. Key should be in null terminated string format of the member of ::OrtCUDAProviderOptionsV2 - * and value should be its related range. Recreates the options and only sets the supplied values. 
- * - * For example, key="device_id" and value="0" - * - * \param[in] cuda_options - * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys - * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values - * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.11. - */ - ORT_API2_STATUS(UpdateCUDAProviderOptions, _Inout_ OrtCUDAProviderOptionsV2* cuda_options, - _In_reads_(num_keys) const char* const* provider_options_keys, - _In_reads_(num_keys) const char* const* provider_options_values, - _In_ size_t num_keys); - - /** - * Get serialized CUDA provider options string. - * - * For example, "device_id=0;arena_extend_strategy=0;......" - * - * \param cuda_options - OrtCUDAProviderOptionsV2 instance - * \param allocator - a ptr to an instance of OrtAllocator obtained with CreateAllocator() or GetAllocatorWithDefaultOptions() - * the specified allocator will be used to allocate continuous buffers for output strings and lengths. - * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.11. - */ - ORT_API2_STATUS(GetCUDAProviderOptionsAsString, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); - - /** \brief Release an ::OrtCUDAProviderOptionsV2 - * - * \note This is an exception in the naming convention of other Release* functions, as the name of the method does not have the V2 suffix, but the type does - * - * \since Version 1.11. 
- */ - void(ORT_API_CALL* ReleaseCUDAProviderOptions)(_Frees_ptr_opt_ OrtCUDAProviderOptionsV2* input); - - /// @} - - /** \brief Append MIGraphX provider to session options - * - * If MIGraphX is not available (due to a non MIGraphX enabled build, or if MIGraphX is not installed on the system), this function will return failure. - * - * \param[in] options - * \param[in] migraphx_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.11. - */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_MIGraphX, - _In_ OrtSessionOptions* options, _In_ const OrtMIGraphXProviderOptions* migraphx_options); - - /** \brief Replace initialized Tensors with external data with the data provided in initializers. - * - * The function will find the initialized TensorProtos with external data in the graph with the provided names and - * replace them with the provided tensors. The API verifies that the TensorProto being replaced - * has an external data reference and has the same name, dimensions and data type as its replacement. The replacement - * will occur before any of the optimizations take place. The data will be copied into the graph - * since TensorProto can't refer to the user provided buffers. - * - * Once the model has been loaded, the OrtValue(s) added to SessionOptions instance will be removed - * from the internal SessionOptions copy to save memory, the user provided buffers can then be deallocated - * and the SessionOptions instance that refers to them can be destroyed. - * - * \param[in] options - * \param[in] initializer_names Array of null terminated UTF-8 encoded strings of the initializers names. - * \param[in] initializers Array of ::OrtValue type - * \param[in] num_initializers Number of elements in the initializer_names and initializers - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.12. 
- */ - ORT_API2_STATUS(AddExternalInitializers, _In_ OrtSessionOptions* options, - _In_reads_(num_initializers) const char* const* initializer_names, - _In_reads_(num_initializers) const OrtValue* const* initializers, size_t num_initializers); - - /** \brief: Create attribute of onnxruntime operator - * - * \param[in] name Name of the attribute - * \param[in] data Data content of the attribute - * \param[in] len Number of bytes stored in data - * \param[in] type Data type - * \param[out] op_attr Attribute that has been created, which must be released by OrtApi::ReleaseOpAttr - * - * \since Version 1.12. - */ - ORT_API2_STATUS(CreateOpAttr, - _In_ const char* name, - _In_ const void* data, - _In_ int len, - _In_ OrtOpAttrType type, - _Outptr_ OrtOpAttr** op_attr); - - /* \brief: Release op attribute - * - * \param[in] opAttr Attribute created by OrtApi::CreateOpAttr - * - * \since Version 1.12. - */ - ORT_CLASS_RELEASE(OpAttr); - - /** \brief: Create onnxruntime native operator - * - * \param[in] info Kernel info - * \param[in] op_name Operator name - * \param[in] domain Operator domain - * \param[in] version Operator opset version - * \param[in] type_constraint_names Name of the type constraints, such as "T" or "T1" - * \param[in] type_constraint_values Type of each constraints - * \param[in] type_constraint_count Number of constraints - * \param[in] attr_values Attributes used to initialize the operator - * \param[in] attr_count Number of the attributes - * \param[in] input_count Number of inputs - * \param[in] output_count Number of outputs - * \param[out] ort_op Operator that has been created - * - * \since Version 1.12. 
- */ - ORT_API2_STATUS(CreateOp, - _In_ const OrtKernelInfo* info, - _In_z_ const char* op_name, - _In_z_ const char* domain, - int version, - _In_reads_(type_constraint_count) const char** type_constraint_names, - _In_reads_(type_constraint_count) const ONNXTensorElementDataType* type_constraint_values, - int type_constraint_count, - _In_reads_(attr_count) const OrtOpAttr* const* attr_values, - int attr_count, - int input_count, - int output_count, - _Outptr_ OrtOp** ort_op); - - /** \brief: Invoke the operator created by OrtApi::CreateOp - * The inputs must follow the order as specified in onnx specification - * - * \param[in] context Kernel context - * \param[in] ort_op Operator that has been created - * \param[in] input_values Array of inputs - * \param[in] input_count Number of inputs - * \param[in] output_values Array of outputs - * \param[in] output_count Number of outputs - * - * \since Version 1.12. - */ - ORT_API2_STATUS(InvokeOp, - _In_ const OrtKernelContext* context, - _In_ const OrtOp* ort_op, - _In_ const OrtValue* const* input_values, - _In_ int input_count, - _Inout_ OrtValue* const* output_values, - _In_ int output_count); - - /* \brief: Release an onnxruntime operator - * - * \param[in] Op Operator created by OrtApi::CreateOp - * - * \since Version 1.12. - */ - ORT_CLASS_RELEASE(Op); - - /** \brief: Append execution provider to the session options. - * \param[in] options - * \param[in] provider_name - provider to add. 
- * \param[in] provider_options_keys - keys to configure the provider options - * \param[in] provider_options_values - values to configure the provider options - * \param[in] num_keys - number of keys passed in - * - * Currently supported provider names: - * QNNExecutionProvider (or QNN) - * OpenVINOExecutionProvider (or OpenVINO) - * XnnpackExecutionProvider (or XNNPACK) - * WebNNExecutionProvider (or WEBNN) - * WebGpuExecutionProvider (or WebGPU) - * AzureExecutionProvider (or AZURE) - * JsExecutionProvider (or JS) - * VitisAIExecutionProvider (or VitisAI) - * CoreMLExecutionProvider (or CoreML) - * - * Note: If an execution provider has a dedicated SessionOptionsAppendExecutionProvider_ function - * that should be used to add it. - * - * QNN supported keys: - * "backend_type": Type of QNN backend. Specifies a backend path that is the associated QNN backend library file - * name. E.g., given backend type "htp", on Windows, the backend path would be "QnnHtp.dll", and on other - * platforms, it would be "libQnnHtp.so". Mutually exclusive with "backend_path". - * Available options: - * -# "cpu" - * -# "gpu" - * -# "htp": Default. - * -# "saver" - * -# "ir" - * "backend_path": File path to QNN backend library. Mutually exclusive with "backend_type". - * "profiling_level": QNN profiling level. - * Available options: - * -# "off": Default. - * -# "basic" - * -# "detailed" - * "profiling_file_path": QNN profiling file path if ETW not enabled. - * "rpc_control_latency": QNN RPC control latency. - * "vtcm_mb": QNN VTCM size in MB. default to 0(not set). - * "htp_performance_mode": QNN performance mode. - * Available options: - * -# "burst" - * -# "balanced" - * -# "default": Default. 
- * -# "high_performance" - * -# "high_power_saver" - * -# "low_balanced" - * -# "extreme_power_saver" - * -# "low_power_saver" - * -# "power_saver" - * -# "sustained_high_performance" - * "dump_qnn_ir_dlc": Use the QnnIr backend library to write .dlc files for each subgraph dispatched to QNN. When - * enabled, inference results will be incorrect. Use only for debugging. - * -# "0": Default: disabled - * -# "1": enabled - * "dump_qnn_ir_dlc_dir": Set the directory into which QnnIr will be configured to write QNN graphs as .dlc files. - * Default is current working directory. - * "qnn_ir_backend_path": File path to the QnnIr backend library. If "dump_qnn_ir_dlc" is enabled, use this path - * instead of looking for the Ir backend in the standard location. - * "qnn_saver_path": File path to the QNN Saver backend library. If specified, QNN Saver will be enabled and will - * dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and - * may alter model/EP partitioning. Use only for debugging. - * "qnn_context_priority": QNN context priority. - * Available options: - * -# "low" - * -# "normal": Default. - * -# "normal_high" - * -# "high" - * "htp_graph_finalization_optimization_mode": Set the optimization mode for graph finalization on the HTP backend. - * Available options: - * -# "0": Default. - * -# "1": Faster preparation time, less optimal graph. - * -# "2": Longer preparation time, more optimal graph. - * -# "3": Longest preparation time, most likely even more optimal graph. See QNN SDK documentation for specific - * details. - * "soc_model": The SoC model number. Refer to the QNN SDK documentation for valid values. - * Defaults to "0" (unknown). - * "htp_arch": The minimum HTP architecture the driver will use to select compatible QNN operators. - * Available options: - * -# "0": Default (none). - * -# "68" - * -# "69" - * -# "73" - * -# "75" - * "device_id": The ID of the device to use when setting 'htp_arch'. 
Defaults to "0" (for single device). - * "enable_htp_fp16_precision": Used for float32 model for HTP backend. - * Enable the float32 model to be inferenced with fp16 precision. Otherwise, it will be fp32 precision. - * -# "0": With fp32 precision. - * -# "1": Default. With fp16 precision. - * "offload_graph_io_quantization": Offload graph input quantization and graph output dequantization to another - * execution provider (typically CPU EP). - * -# "0": Disabled. QNN EP will handle quantization and dequantization of graph I/O. - * -# "1": Enabled. This is the default value. - * "enable_htp_spill_fill_buffer": Enable HTP spill fill buffer setting. The flag is used while generating context - * binary. - * -# "0": Default. Disabled. - * -# "1": Enabled. - * "enable_htp_shared_memory_allocator": Enable the QNN HTP shared memory allocator. Requires libcdsprpc.so/dll to - * be available. - * -# "0": Default. Disabled. - * -# "1": Enabled. - * "dump_json_qnn_graph": Set to "1" to dump QNN graphs generated by QNN EP as JSON files. Each graph partition - * assigned to QNN EP is dumped to a separate file. - * "json_qnn_graph_dir": Directory in which to dump QNN JSON graphs. If not specified, QNN graphs are dumped in the - * program's current working directory. Ignored if "dump_json_qnn_graph" is not set. - * - * XNNPACK supported keys: - * "intra_op_num_threads": number of thread-pool size to use for XNNPACK execution provider. - * default value is 0, which means to use the session thread-pool size. - * - * \since Version 1.12. - */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider, _In_ OrtSessionOptions* options, - _In_ const char* provider_name, - _In_reads_(num_keys) const char* const* provider_options_keys, - _In_reads_(num_keys) const char* const* provider_options_values, - _In_ size_t num_keys); - - /* \brief: Get a copy of kernel info - * - * \param[in] info Kernel info - * \param[out] info_copy Copy of kernel info - * - * \since Version 1.12. 
- */ - ORT_API2_STATUS(CopyKernelInfo, - _In_ const OrtKernelInfo* info, - _Outptr_ OrtKernelInfo** info_copy); - - /* \brief: Release kernel info - * - * \param[in] KernelInfo A copy of kernel info returned by CopyKernelInfo - * - * \since Version 1.12. - */ - ORT_CLASS_RELEASE(KernelInfo); - - /// \name Ort Training - /// @{ - /** \brief Gets the Training C Api struct - * - * Call this function to access the ::OrtTrainingApi structure that holds pointers to functions that enable - * training with onnxruntime. - * \note A NULL pointer will be returned and no error message will be printed if the training api - * is not supported with this build. A NULL pointer will be returned and an error message will be - * printed if the provided version is unsupported, for example when using a runtime older than the - * version created with this header file. - * - * \param[in] version Must be ::ORT_API_VERSION - * \return The ::OrtTrainingApi struct for the version requested. - * - * \since Version 1.13 - */ - const OrtTrainingApi*(ORT_API_CALL* GetTrainingApi)(uint32_t version)NO_EXCEPTION; - - /// @} - - /** \brief Append CANN provider to session options - * - * If CANN is not available (due to a non CANN enabled build, or if CANN is not installed on the system), this function will return failure. - * - * \param[in] options - * \param[in] cann_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.13. - */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_CANN, - _In_ OrtSessionOptions* options, _In_ const OrtCANNProviderOptions* cann_options); - - /** \brief Create an OrtCANNProviderOptions - * - * \param[out] out created ::OrtCANNProviderOptions. Must be released with OrtApi::ReleaseCANNProviderOptions - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.13. - */ - ORT_API2_STATUS(CreateCANNProviderOptions, _Outptr_ OrtCANNProviderOptions** out); - - /** \brief Set options in a CANN Execution Provider. 
- * - * \param[in] cann_options - * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys - * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values - * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.13. - */ - ORT_API2_STATUS(UpdateCANNProviderOptions, _Inout_ OrtCANNProviderOptions* cann_options, - _In_reads_(num_keys) const char* const* provider_options_keys, - _In_reads_(num_keys) const char* const* provider_options_values, - _In_ size_t num_keys); - - /** \brief Get serialized CANN provider options string. - * - * \param[in] cann_options OrtCANNProviderOptions instance - * \param[in] allocator a ptr to an instance of OrtAllocator obtained with CreateAllocator() - * or GetAllocatorWithDefaultOptions(), the specified allocator will be used to allocate - * continuous buffers for output strings and lengths. - * \param[out] ptr is a UTF-8 null terminated string allocated using 'allocator'. - * The caller is responsible for using the same allocator to free it. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.13. - */ - ORT_API2_STATUS(GetCANNProviderOptionsAsString, _In_ const OrtCANNProviderOptions* cann_options, - _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); - - /** \brief Release an OrtCANNProviderOptions - * - * \param[in] input The pointer of OrtCANNProviderOptions which will been deleted - * - * \since Version 1.13. 
- */ - void(ORT_API_CALL* ReleaseCANNProviderOptions)(_Frees_ptr_opt_ OrtCANNProviderOptions* input); - - /* \brief Get OrtDevice type from MemoryInfo - * - * \since Version 1.14 - */ - void(ORT_API_CALL* MemoryInfoGetDeviceType)(_In_ const OrtMemoryInfo* ptr, _Out_ OrtMemoryInfoDeviceType* out); - - /* \brief Update the OrtEnv instance with custom log severity level - * - * \param[in] ort_env The OrtEnv instance being used - * \param[in] log_severity_level The log severity level. - * - * \since Version 1.14. - */ - ORT_API2_STATUS(UpdateEnvWithCustomLogLevel, _In_ OrtEnv* ort_env, OrtLoggingLevel log_severity_level); - - /* \brief Set affinities for intra op threads - * - * Affinity string follows format: - * logical_processor_id,logical_processor_id;logical_processor_id,logical_processor_id - * Semicolon isolates configurations among threads, while comma split processors where ith thread expected to attach to. - * e.g. 1,2,3;4,5 - * specifies affinities for two threads, with the 1st thread attach to the 1st, 2nd, and 3rd processor, and 2nd thread to the 4th and 5th. - * To ease the configuration, an "interval" is also allowed: - * e.g. 1-8;8-16;17-24 - * orders that the 1st thread runs on first eight processors, 2nd thread runs on next eight processors, and so forth. - * Note: - * 1. Once set, the number of thread affinities must equal to intra_op_num_threads - 1, - * ort does not set affinity on the main thread which is started and managed by the calling app; - * 2. For windows, ort will infer the group id from a logical processor id, for example, assuming there are two groups with each has 64 logical processors, - * an id of 64 will be inferred as the last processor of the 1st group, while 65 will be interpreted as the 1st processor of the second group. - * Hence 64-65 is an invalid configuration, because a windows thread cannot be attached to processors across group boundary. 
- * - * \since Version 1.14 - */ - ORT_API2_STATUS(SetGlobalIntraOpThreadAffinity, _Inout_ OrtThreadingOptions* tp_options, const char* affinity_string); - - /** \brief Register custom ops from a shared library. - * - * Loads a shared library (.dll on windows, .so on linux, etc) named 'library_name' and looks for this entry point: - * OrtStatus* RegisterCustomOps(OrtSessionOptions * options, const OrtApiBase* api); - * It then passes in the provided session options to this function along with the api base. - * - * The handle to the loaded library is automatically released by ORT when the last OrtSession that references the - * library handle is released. If no OrtSession is created, then the library handle is released when the provided - * OrtSessionOptions is released. - * - * \param[in] options The session options. - * \param[in] library_name The name of the shared library to load and register. Refer to OS-specific dynamic library - * loading utilities (e.g., LoadLibraryEx on Windows or dlopen on Linux/MacOS) for information - * on the format of library names and search paths. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.14 - */ - ORT_API2_STATUS(RegisterCustomOpsLibrary_V2, _Inout_ OrtSessionOptions* options, _In_ const ORTCHAR_T* library_name); - - /** \brief Register custom ops by calling a RegisterCustomOpsFn function. - * - * Searches for registration_func_name and if found calls it. - * - * The library containing the function must either be linked against or previously loaded by the executable. - * - * If you want ONNX Runtime to load the library and manage its lifetime, use RegisterCustomOpsLibrary_V2. - * - * RegisterCustomOpsUsingFunction can be used in scenarios where it may not be possible for ONNX Runtime to load - * the library from a path. e.g. mobile platforms where the library must be linked into the app. 
- * - * The registration function must have the signature of RegisterCustomOpsFn: - * OrtStatus* (*fn)(OrtSessionOptions* options, const OrtApiBase* api); - * - * See https://onnxruntime.ai/docs/reference/operators/add-custom-op.html for details on how the registration - * function should be implemented. - * - * \param[in] options OrtSessionOptions that is passed through as the first argument in the call to the - * registration function. - * \param[in] registration_func_name Name of registration function to use. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.14 - */ - ORT_API2_STATUS(RegisterCustomOpsUsingFunction, _Inout_ OrtSessionOptions* options, - _In_ const char* registration_func_name); - - /// \name OrtKernelInfo - /// Custom operator APIs. - /// @{ - - /** \brief Get the number of inputs from ::OrtKernelInfo. - * - * Used in the CreateKernel callback of an OrtCustomOp to query the number of inputs - * during kernel/session creation. - * - * \param[in] info Instance of ::OrtKernelInfo. - * \param[out] out Pointer to variable assigned with the result on success. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.14 - */ - ORT_API2_STATUS(KernelInfo_GetInputCount, _In_ const OrtKernelInfo* info, _Out_ size_t* out); - - /** \brief Get the number of outputs from ::OrtKernelInfo. - * - * Used in the CreateKernel callback of an OrtCustomOp to query the number of outputs - * during kernel/session creation. - * - * \param[in] info Instance of ::OrtKernelInfo. - * \param[out] out Pointer to variable assigned with the result on success. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.14 - */ - ORT_API2_STATUS(KernelInfo_GetOutputCount, _In_ const OrtKernelInfo* info, _Out_ size_t* out); - - /** \brief Get the name of a ::OrtKernelInfo's input. - * - * Used in the CreateKernel callback of an OrtCustomOp to query an input's name - * during kernel/session creation. 
- * - * If `out` is nullptr, the value of `size` is set to the size of the name - * string (including null-terminator), and a success status is returned. - * - * If the `size` parameter is greater than or equal to the name string's size, - * the value of `size` is set to the true size of the string (including null-terminator), - * the provided memory is filled with the string's contents, and a success status is returned. - * - * If the `size` parameter is less than the actual string's size and `out` - * is not nullptr, the value of `size` is set to the true size of the string - * and a failure status is returned. - * - * \param[in] info An instance of ::OrtKernelInfo. - * \param[in] index The index of the input name to get. Returns a failure status if out-of-bounds. - * \param[out] out Memory location into which to write the UTF-8 null-terminated string representing the input's name. - * \param[in,out] size Pointer to the size of the `out` buffer. See above comments for details. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.14 - */ - ORT_API2_STATUS(KernelInfo_GetInputName, _In_ const OrtKernelInfo* info, size_t index, _Out_ char* out, - _Inout_ size_t* size); - - /** \brief Get the name of a ::OrtKernelInfo's output. - * - * Used in the CreateKernel callback of an OrtCustomOp to query an output's name - * during kernel/session creation. - * - * If `out` is nullptr, the value of `size` is set to the size of the name - * string (including null-terminator), and a success status is returned. - * - * If the `size` parameter is greater than or equal to the name string's size, - * the value of `size` is set to the true size of the string (including null-terminator), - * the provided memory is filled with the string's contents, and a success status is returned. 
- * - * If the `size` parameter is less than the actual string's size and `out` - * is not nullptr, the value of `size` is set to the true size of the string - * and a failure status is returned. - * - * \param[in] info An instance of ::OrtKernelInfo. - * \param[in] index The index of the output name to get. Returns a failure status if out-of-bounds. - * \param[out] out Memory location into which to write the UTF-8 null-terminated string representing the output's - * name. - * \param[in,out] size Pointer to the size of the `out` buffer. See above comments for details. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.14 - */ - ORT_API2_STATUS(KernelInfo_GetOutputName, _In_ const OrtKernelInfo* info, size_t index, _Out_ char* out, - _Inout_ size_t* size); - - /** \brief Get the type information for a ::OrtKernelInfo's input. - * - * Used in the CreateKernel callback of an OrtCustomOp to query the shape and type information - * of an input during kernel/session creation. - * - * \param[in] info An instance of ::OrtKernelInfo. - * \param[in] index Which input to get the type information for - * \param[out] type_info Pointer set to the resulting ::OrtTypeInfo. Must be freed with OrtApi::ReleaseTypeInfo. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.14 - */ - ORT_API2_STATUS(KernelInfo_GetInputTypeInfo, _In_ const OrtKernelInfo* info, size_t index, - _Outptr_ OrtTypeInfo** type_info); - - /** \brief Get the type information for a ::OrtKernelInfo's output. - * - * Used in the CreateKernel callback of an OrtCustomOp to query the shape and type information - * of an output during kernel/session creation. - * - * \param[in] info An instance of ::OrtKernelInfo. - * \param[in] index Which input to get the type information for - * \param[out] type_info Pointer set to the resulting ::OrtTypeInfo. Must be freed with OrtApi::ReleaseTypeInfo. 
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.14 - */ - ORT_API2_STATUS(KernelInfo_GetOutputTypeInfo, _In_ const OrtKernelInfo* info, size_t index, - _Outptr_ OrtTypeInfo** type_info); - - /** \brief Get a ::OrtValue tensor stored as an attribute in the graph node. - * - * Used in the CreateKernel callback of an OrtCustomOp to get a tensor attribute. - * - * \param[in] info ::OrtKernelInfo instance. - * \param[in] name UTF-8 null-terminated string representing the attribute's name. - * \param[in] allocator Allocator used to allocate the internal tensor state. - * \param[out] out Returns newly created ::OrtValue. Must be freed with OrtApi::ReleaseValue, - * which will also free internal tensor state allocated with the provided allocator. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(KernelInfoGetAttribute_tensor, _In_ const OrtKernelInfo* info, _In_z_ const char* name, - _Inout_ OrtAllocator* allocator, _Outptr_ OrtValue** out); - - /// @} - /// \name OrtSessionOptions - /// Custom operator APIs - /// @{ - - /** \brief Checks if the given session configuration entry exists. - * - * The config_key formats are defined in onnxruntime_session_options_config_keys.h - * - * Can be used in a custom operator library to check for session configuration entries - * that target one or more custom operators in the library. Example: The config entry - * custom_op.myop.some_key targets a custom op named "myop". - * - * \param[in] options The ::OrtSessionOptions instance. - * \param[in] config_key A null-terminated UTF-8 string representation of the configuration key. - * \param[out] out Pointer set to 1 if the entry exists and 0 otherwise. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.14 - */ - ORT_API2_STATUS(HasSessionConfigEntry, _In_ const OrtSessionOptions* options, - _In_z_ const char* config_key, _Out_ int* out); - - /** \brief Get a session configuration value. 
- * - * Returns a failure status if the configuration key does not exist. - * The config_key and the format of config_value are defined in onnxruntime_session_options_config_keys.h - * - * If `config_value` is nullptr, the value of `size` is set to the true size of the string - * value (including null-terminator), and a success status is returned. - * - * If the `size` parameter is greater than or equal to the actual string value's size, - * the value of `size` is set to the true size of the string value, the provided memory - * is filled with the value's contents, and a success status is returned. - * - * If the `size` parameter is less than the actual string value's size and `config_value` - * is not nullptr, the value of `size` is set to the true size of the string value - * and a failure status is returned. - * - * Can be used in a custom operator library to get session configuration entries - * that target one or more custom operators in the library. Example: The config entry - * custom_op.myop.some_key targets a custom op named "myop". - * - * \param[in] options The session options. - * \param[in] config_key A null-terminated UTF-8 string representation of the config key. - * \param[in] config_value Pointer to memory where the null-terminated UTF-8 string value will be stored. - * \param[in,out] size Pointer to the size of the `config_value` buffer. See above comments for details. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.14 - */ - ORT_API2_STATUS(GetSessionConfigEntry, _In_ const OrtSessionOptions* options, - _In_z_ const char* config_key, _Out_ char* config_value, _Inout_ size_t* size); - - /// @} - - /** \brief Append dnnl provider to session options - * - * If oneDNN is not available, this function will return failure. - * - * \param[in] options - * \param[in] dnnl_options - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.15. 
- */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_Dnnl, - _In_ OrtSessionOptions* options, _In_ const OrtDnnlProviderOptions* dnnl_options); - - /** \brief Create an OrtDnnlProviderOptions - * - * \param[out] out Newly created ::OrtDnnlProviderOptions. Must be released with OrtApi::ReleaseDnnlProviderOptions - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.15. - */ - ORT_API2_STATUS(CreateDnnlProviderOptions, _Outptr_ OrtDnnlProviderOptions** out); - - /** \brief Set options in a oneDNN Execution Provider. - * - * Key should be in null terminated string format of the member of ::OrtDnnlProviderOptions - * and value should be its related range. - * - * For example, key="use_arena" and value="1" - * - * \param[in] dnnl_options - * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys - * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values - * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.15. - */ - ORT_API2_STATUS(UpdateDnnlProviderOptions, _Inout_ OrtDnnlProviderOptions* dnnl_options, - _In_reads_(num_keys) const char* const* provider_options_keys, - _In_reads_(num_keys) const char* const* provider_options_values, - _In_ size_t num_keys); - - /** - * Get serialized oneDNN provider options string. - * - * For example, "use_arena=1;......" - * - * \param dnnl_options - OrtDnnlProviderOptions instance - * \param allocator - a ptr to an instance of OrtAllocator obtained with CreateAllocator() or GetAllocatorWithDefaultOptions() - * the specified allocator will be used to allocate continuous buffers for output strings and lengths. - * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it. 
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.15. - */ - ORT_API2_STATUS(GetDnnlProviderOptionsAsString, _In_ const OrtDnnlProviderOptions* dnnl_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); - - /** \brief Release an ::OrtDnnlProviderOptions - * - * \since Version 1.15. - */ - void(ORT_API_CALL* ReleaseDnnlProviderOptions)(_Frees_ptr_opt_ OrtDnnlProviderOptions* input); - - /// \name OrtKernelInfo - /// Custom operator APIs. - /// @{ - - /** \brief Get the graph node name from ::OrtKernelInfo. - * - * If `out` is nullptr, the value of `size` is set to the size of the name - * string (including null-terminator), and a success status is returned. - * - * If the `size` parameter is greater than or equal to the name string's size, - * the value of `size` is set to the true size of the string (including null-terminator), - * the provided memory is filled with the string's contents, and a success status is returned. - * - * If the `size` parameter is less than the actual string's size and `out` - * is not nullptr, the value of `size` is set to the true size of the string - * and a failure status is returned. - * - * Can be used in a custom operator's CreateKernel callback to get the name of the operator's node name in the graph. - * - * \param[in] info An instance of ::OrtKernelInfo. - * \param[out] out Memory location into which to write the UTF-8 null-terminated string representing the name. - * \param[in,out] size Pointer to the size of the `out` buffer. See above comments for details. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.15 - */ - ORT_API2_STATUS(KernelInfo_GetNodeName, _In_ const OrtKernelInfo* info, _Out_ char* out, _Inout_ size_t* size); - - /** \brief Get the session logger from ::OrtKernelInfo. - * - * Used in the CreateKernel callback of an OrtCustomOp to get a logger that can be used to log - * messages. - * - * \param[in] info An instance of ::OrtKernelInfo. 
- * \param[out] logger Pointer set to the session's ::OrtLogger. Owned by ONNX Runtime, so do not free. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.15 - */ - ORT_API2_STATUS(KernelInfo_GetLogger, _In_ const OrtKernelInfo* info, _Outptr_ const OrtLogger** logger); - - /// @} - /// \name OrtKernelContext - /// Custom operator APIs. - /// @{ - - /** \brief Get the runtime logger from ::OrtKernelContext. - * - * Used in the KernelCompute callback of an OrtCustomOp to get a logger that can be used to log - * messages during inference. - * - * \param[in] context An instance of ::OrtKernelContext. - * \param[out] logger Pointer set to the kernel context's ::OrtLogger. Owned by ONNX Runtime, so do not free. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.15 - */ - ORT_API2_STATUS(KernelContext_GetLogger, _In_ const OrtKernelContext* context, _Outptr_ const OrtLogger** logger); - - /// @} - /// \name OrtLogger - /// Custom operator APIs. - /// @{ - - /** \brief Logs a message at the given severity level using the provided ::OrtLogger. - * - * Only messages with a severity level equal or greater than the ::OrtLogger's logging severity level - * are logged. Use OrtApi::Logger_GetLoggingSeverityLevel to get the ::OrtLogger's logging severity - * level. - * - * Can be used in custom operators to log messages with the logger retrieved via OrtApi::KernelInfo_GetLogger. - * - * \param[in] logger The ::OrtLogger instance. - * \param[in] log_severity_level The message's severity level. - * \param[in] message The message to log. - * \param[in] file_path The filepath of the file in which the message is logged. Usually the value of ORT_FILE. - * \param[in] line_number The file line number in which the message is logged. Usually the value of __LINE__. - * \param[in] func_name The name of the function in which the message is logged. Usually the value of __FUNCTION__. 
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.15 - */ - ORT_API2_STATUS(Logger_LogMessage, _In_ const OrtLogger* logger, OrtLoggingLevel log_severity_level, - _In_z_ const char* message, _In_z_ const ORTCHAR_T* file_path, int line_number, - _In_z_ const char* func_name); - - /** \brief Get the logging severity level of the ::OrtLogger. - * - * Can be used in a custom operator to get the logging severity level of the ::OrtLogger associated with - * the ::OrtKernelInfo. - * - * \param[in] logger The ::OrtLogger instance. - * \param[out] out Pointer to variable assigned with the logging severity level on success. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.15 - */ - ORT_API2_STATUS(Logger_GetLoggingSeverityLevel, _In_ const OrtLogger* logger, _Out_ OrtLoggingLevel* out); - - /// @} - - /** \brief Get a ::OrtValue tensor stored as a constant initializer in the graph node. - * - * Used in the CreateKernel callback of an OrtCustomOp to get a tensor value. - * - * \param[in] info ::OrtKernelInfo instance. - * \param[in] index The node index. - * \param[out] is_constant Is it a constant node input or not. - * \param[out] out The OrtValue tensor value. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.15. - */ - ORT_API2_STATUS(KernelInfoGetConstantInput_tensor, _In_ const OrtKernelInfo* info, size_t index, _Out_ int* is_constant, _Outptr_ const OrtValue** out); - - /** \brief Get Optional Type information from an ::OrtTypeInfo - * - * This augments ::OrtTypeInfo to return an ::OrtOptionalTypeInfo when the type is optional. - * The OrtOptionalTypeInfo also has a nested ::OrtTypeInfo that describes the type of the optional value. - * ::OrtOptionalTypeInfo type can only appear within model metadata to describe inputs/outputs. - * The actual OrtValues that are supplied in place of optional type inputs should contain - * specific type that is described by ::OrtOptionalTypeInfo. 
- * - * So the picture: ::OrtTypeInfo -> ::OrtOptionalTypeInfo -> ::OrtTypeInfo (describes the type that can be supplied - * in place of the optional type when creating the actual ::OrtValue). - * - * \param[in] type_info - * \param[out] out A pointer to the ::OrtOptionalTypeInfo. Do not free this value, - * it is owned by OrtTypeInfo instance. When the type_info does not represent - * optional type, nullptr is returned in out. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.15. - */ - ORT_API2_STATUS(CastTypeInfoToOptionalTypeInfo, _In_ const OrtTypeInfo* type_info, - _Outptr_result_maybenull_ const OrtOptionalTypeInfo** out); - - /** \brief Get OrtTypeInfo for the allowed contained type from an ::OrtOptionalTypeInfo. - * - * This augments ::OrtOptionalTypeInfo to return an ::OrtTypeInfo for the contained type. - * The OrtOptionalTypeInfo has a nested ::OrtTypeInfo that describes the type of the optional value. - * ::OrtOptionalTypeInfo type can only appear within model metadata to describe inputs/outputs. - * The actual OrtValues that are supplied in place of optional type inputs should contain - * specific type that is described by the returned ::OrtTypeInfo. - * - * \param[in] optional_type_info - * \param[out] out A copy of ::OrtTypeInfo for what the optional value could be. - * The user must free this value with ReleaseTypeInfo. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.15. - */ - ORT_API2_STATUS(GetOptionalContainedTypeInfo, _In_ const OrtOptionalTypeInfo* optional_type_info, - _Outptr_ OrtTypeInfo** out); - - /** \brief Set a single string in a string tensor - * Do not zero terminate the string data. 
- * - * \param[in] value A string tensor - * \param[in] index - flat index of the element - * \param[in] length_in_bytes length of the buffer in utf-8 bytes (without the null terminator) - * \param[inout] buffer - address of return value - * - * \snippet{doc} snippets.dox OrtStatus Return Value - */ - ORT_API2_STATUS(GetResizedStringTensorElementBuffer, _Inout_ OrtValue* value, _In_ size_t index, _In_ size_t length_in_bytes, _Inout_ char** buffer); - - /** \brief Get Allocator from KernelContext for a specific memoryInfo. Please use C API ReleaseAllocator to release out object - * - * \param[in] context OrtKernelContext instance - * \param[in] mem_info OrtMemoryInfo instance - * \param[out] out A pointer to OrtAllocator. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.15. - */ - ORT_API2_STATUS(KernelContext_GetAllocator, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _Outptr_ OrtAllocator** out); - - /** \brief Returns a null terminated string of the build info including git info and cxx flags - * - * \return UTF-8 encoded version string. Do not deallocate the returned buffer. - * - * \since Version 1.15. - */ - const char*(ORT_API_CALL* GetBuildInfoString)(void); - - /// \name OrtROCMProviderOptions - /// @{ - - /** \brief Create an OrtROCMProviderOptions - * - * \param[out] out Newly created ::OrtROCMProviderOptions. Must be released with OrtApi::ReleaseROCMProviderOptions - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.16. - */ - ORT_API2_STATUS(CreateROCMProviderOptions, _Outptr_ OrtROCMProviderOptions** out); - - /** \brief Set options in a ROCm Execution Provider. - * - * Please refer to https://onnxruntime.ai/docs/execution-providers/ROCm-ExecutionProvider.html - * to know the available keys and values. Key should be in null terminated string format of the member of - * ::OrtROCMProviderOptions and value should be its related range. 
- * - * For example, key="device_id" and value="0" - * - * \param[in] rocm_options - * \param[in] provider_options_keys Array of UTF-8 null-terminated string for provider options keys - * \param[in] provider_options_values Array of UTF-8 null-terminated string for provider options values - * \param[in] num_keys Number of elements in the `provider_option_keys` and `provider_options_values` arrays - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.16. - */ - ORT_API2_STATUS(UpdateROCMProviderOptions, _Inout_ OrtROCMProviderOptions* rocm_options, - _In_reads_(num_keys) const char* const* provider_options_keys, - _In_reads_(num_keys) const char* const* provider_options_values, - _In_ size_t num_keys); - - /** - * Get serialized ROCm provider options string. - * - * For example, "device_id=0;arena_extend_strategy=0;......" - * - * \param rocm_options - OrtROCMProviderOptions instance - * \param allocator - a ptr to an instance of OrtAllocator obtained with CreateAllocator() or GetAllocatorWithDefaultOptions() - * the specified allocator will be used to allocate continuous buffers for output strings and lengths. - * \param ptr - is a UTF-8 null terminated string allocated using 'allocator'. The caller is responsible for using the same allocator to free it. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.16. - */ - ORT_API2_STATUS(GetROCMProviderOptionsAsString, _In_ const OrtROCMProviderOptions* rocm_options, _Inout_ OrtAllocator* allocator, _Outptr_ char** ptr); - - /** \brief Release an ::OrtROCMProviderOptions - * - * \note This is an exception in the naming convention of other Release* functions, as the name of the method does not have the V2 suffix, but the type does - * - * \since Version 1.16. 
- */ - void(ORT_API_CALL* ReleaseROCMProviderOptions)(_Frees_ptr_opt_ OrtROCMProviderOptions* input); - - /** \brief Create an allocator with specific type and register it with the ::OrtEnv - * This API enhance CreateAndRegisterAllocator that it can create an allocator with specific type, not just CPU allocator - * Enables sharing the allocator between multiple sessions that use the same env instance. - * Lifetime of the created allocator will be valid for the duration of the environment. - * Returns an error if an allocator with the same ::OrtMemoryInfo is already registered. - * \param[in] env OrtEnv instance - * \param[in] provider_type ExecutionProvider type - * \param[in] mem_info OrtMemoryInfo instance - * \param[in] arena_cfg Arena configuration - * \param[in] provider_options_keys key of the provider options map - * \param[in] provider_options_values value of the provider options map - * \param[in] num_keys Length of the provider options map - */ - ORT_API2_STATUS(CreateAndRegisterAllocatorV2, _Inout_ OrtEnv* env, _In_ const char* provider_type, _In_ const OrtMemoryInfo* mem_info, _In_ const OrtArenaCfg* arena_cfg, - _In_reads_(num_keys) const char* const* provider_options_keys, _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); - - /** \brief Run the model asynchronously in a thread owned by intra op thread pool - * - * \param[in] session - * \param[in] run_options If nullptr, will use a default ::OrtRunOptions - * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names - * \param[in] input Array of ::OrtValue%s of the input values - * \param[in] input_len Number of elements in the input_names and inputs arrays - * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names - * \param[in] output_names_len Number of elements in the output_names and outputs array - * \param[out] output OrtValue* array of size output_names_len. 
- * On calling RunAsync, output[i] could either be a null or a pointer to a preallocated OrtValue. - * Later, the output array will be passed to run_async_callback with all null(s) filled with valid - * OrtValue pointer(s) allocated by onnxruntime. - * NOTE: it is customer's duty to finally release the output array and each of its member, - * regardless of whether the member (OrtValue*) is allocated by onnxruntime or preallocated by the customer. - * \param[in] run_async_callback Callback function on model run completion - * \param[in] user_data User data that pass back to run_async_callback - */ - ORT_API2_STATUS(RunAsync, _Inout_ OrtSession* session, _In_opt_ const OrtRunOptions* run_options, - _In_reads_(input_len) const char* const* input_names, - _In_reads_(input_len) const OrtValue* const* input, size_t input_len, - _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, - _Inout_updates_all_(output_names_len) OrtValue** output, - _In_ RunAsyncCallbackFn run_async_callback, _In_opt_ void* user_data); - - /** - * Update TensorRT EP provider option where its data type is pointer, for example 'user_compute_stream'. - * If the data type of the provider option can be represented by string please use UpdateTensorRTProviderOptions. - * - * Note: It's caller's responsibility to properly manage the lifetime of the instance pointed by this pointer. - * - * \param tensorrt_options - OrtTensorRTProviderOptionsV2 instance - * \param key - Name of the provider option - * \param value - A pointer to the instance that will be assigned to this provider option - * - * \since Version 1.16. - */ - ORT_API2_STATUS(UpdateTensorRTProviderOptionsWithValue, _Inout_ OrtTensorRTProviderOptionsV2* tensorrt_options, _In_ const char* key, _In_ void* value); - - /** - * Get TensorRT EP provider option where its data type is pointer. - * If the data type of the provider option can be represented by string please use GetTensorRTProviderOptionsAsString. 
- * - * \param tensorrt_options - OrtTensorRTProviderOptionsV2 instance - * \param key - Name of the provider option - * \param ptr - A pointer to the instance that is kept by the provider option - * - * \since Version 1.16. - */ - ORT_API2_STATUS(GetTensorRTProviderOptionsByName, _In_ const OrtTensorRTProviderOptionsV2* tensorrt_options, _In_ const char* key, _Outptr_ void** ptr); - - /** - * Update CUDA EP provider option where its data type is pointer, for example 'user_compute_stream'. - * If the data type of the provider option can be represented by string please use UpdateCUDAProviderOptions. - * - * Note: It's caller's responsibility to properly manage the lifetime of the instance pointed by this pointer. - * - * \param cuda_options - OrtCUDAProviderOptionsV2 instance - * \param key - Name of the provider option - * \param value - A pointer to the instance that will be assigned to this provider option - * - * \since Version 1.16. - */ - ORT_API2_STATUS(UpdateCUDAProviderOptionsWithValue, _Inout_ OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _In_ void* value); - - /** - * Get CUDA EP provider option where its data type is pointer. - * If the data type of the provider option can be represented by string please use GetCUDAProviderOptionsAsString. - * - * \param cuda_options - OrtCUDAProviderOptionsV2 instance - * \param key - Name of the provider option - * \param ptr - A pointer to the instance that is kept by the provider option - * - * \since Version 1.16. - */ - ORT_API2_STATUS(GetCUDAProviderOptionsByName, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _Outptr_ void** ptr); - - /** - * Get a EP resource. - * E.g. a cuda stream or a cublas handle - * - * \param context - Kernel context - * \param resource_version - Version of the resource - * \param resource_id - Type of resource - * \param resource - A pointer to returned resource - * - * \since Version 1.16. 
- */ - ORT_API2_STATUS(KernelContext_GetResource, _In_ const OrtKernelContext* context, _In_ int resource_version, - _In_ int resource_id, _Outptr_ void** resource); - - /** \brief Set user logging function - * - * By default the logger created by the CreateEnv* functions is used to create the session logger as well. - * This function allows a user to override this default session logger with a logger of their own choosing. This way - * the user doesn't have to create a separate environment with a custom logger. This addresses the problem when - * the user already created an env but now wants to use a different logger for a specific session (for debugging or - * other reasons). - * - * \param[in] options - * \param[in] user_logging_function A pointer to a logging function. - * \param[in] user_logging_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to - * `user_logging_function`. This parameter is optional. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.17. - */ - ORT_API2_STATUS(SetUserLoggingFunction, _Inout_ OrtSessionOptions* options, - _In_ OrtLoggingFunction user_logging_function, _In_opt_ void* user_logging_param); - - /** - * Get number of input from OrtShapeInferContext - * - * \param[in] context - * \param[out] out The number of inputs - * - * \since Version 1.17. - */ - ORT_API2_STATUS(ShapeInferContext_GetInputCount, _In_ const OrtShapeInferContext* context, _Out_ size_t* out); - - /** - * Get type and shape info of an input - * - * \param[in] context - * \param[in] index The index of the input - * \param[out] info Type shape info of the input - * - * \since Version 1.17. - */ - ORT_API2_STATUS(ShapeInferContext_GetInputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _Outptr_ OrtTensorTypeAndShapeInfo** info); - - /** - * Get attribute from OrtShapeInferContext. 
Note that OrtShapeInferContext is a per-node context, one could only read attribute from current node. - * - * \param[in] context - * \param[in] attr_name Name of the attribute - * \param[out] attr Handle of the attribute fetched - * - * \since Version 1.17. - */ - ORT_API2_STATUS(ShapeInferContext_GetAttribute, _In_ const OrtShapeInferContext* context, _In_ const char* attr_name, _Outptr_ const OrtOpAttr** attr); - - /** - * Set type and shape info of an output - * - * \param[in] context - * \param[in] index The index of the output - * \param[out] info Type shape info of the output - * - * \since Version 1.17. - */ - ORT_API2_STATUS(ShapeInferContext_SetOutputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _In_ const OrtTensorTypeAndShapeInfo* info); - - /** - * Set symbolic shape to type shape info - * - * \param[in] info Type shape info - * \param[in] dim_params Symbolic strings - * \param[in] dim_params_length Number of strings - * - * \since Version 1.17. - */ - ORT_API2_STATUS(SetSymbolicDimensions, _In_ OrtTensorTypeAndShapeInfo* info, _In_ const char* dim_params[], _In_ size_t dim_params_length); - - /** - * Read contents of an attribute to data - * - * \param[in] op_attr - * \param[in] type Attribute type - * \param[out] data Memory address to save raw content of the attribute - * \param[in] len Number of bytes allowed to store in data - * \param[out] out Number of bytes required to save the data when the call failed, or the real number of bytes saved to data on success - * - * \since Version 1.17. - */ - ORT_API2_STATUS(ReadOpAttr, _In_ const OrtOpAttr* op_attr, _In_ OrtOpAttrType type, _Inout_ void* data, _In_ size_t len, _Out_ size_t* out); - - /** \brief Set whether to use deterministic compute. - * - * Default is false. If set to true, this will enable deterministic compute for GPU kernels where possible. - * Note that this most likely will have a performance cost. 
- * - * \param[in] options - * \param[in] value - * - * \since Version 1.17. - */ - ORT_API2_STATUS(SetDeterministicCompute, _Inout_ OrtSessionOptions* options, bool value); - - /** - * Run fn in parallel - * - * \param[in] context - * \param[in] fn Function accepting usr_data and an integer as iterator - * \param[in] total The number of times fn is to be invoked - * \param[in] num_batch Number of batches by which the "total" is to be divided in maximum. When zero, there is no limit - * \param[in] usr_data User data to be passed back to fn - * - * \since Version 1.17. - */ - ORT_API2_STATUS(KernelContext_ParallelFor, _In_ const OrtKernelContext* context, _In_ void (*fn)(void*, size_t), _In_ size_t total, _In_ size_t num_batch, _In_ void* usr_data); - - /** \brief Append OpenVINO execution provider to the session options - * - * If OpenVINO is not available (due to a non OpenVINO enabled build, or if OpenVINO is not installed on the system), this function will fail. - * - * \param[in] options - * \param[in] provider_options_keys - * \param[in] provider_options_values - * \param[in] num_keys - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.17. - */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_OpenVINO_V2, - _In_ OrtSessionOptions* options, - _In_reads_(num_keys) const char* const* provider_options_keys, - _In_reads_(num_keys) const char* const* provider_options_values, - _In_ size_t num_keys); - - /** \brief Append VitisAI provider to session options - * - * If VitisAI is not available (due to a non VitisAI enabled build, or if VitisAI is not installed on the system), this function will return failure. - * - * \param[in] options - * \param[in] provider_options_keys - * \param[in] provider_options_values - * \param[in] num_keys - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.18. 
- */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_VitisAI, - _In_ OrtSessionOptions* options, - _In_reads_(num_keys) const char* const* provider_options_keys, - _In_reads_(num_keys) const char* const* provider_options_values, - _In_ size_t num_keys); - - /** \brief Get scratch buffer from the corresponding allocator under the specific OrtMemoryInfo object. - * NOTE: callers are responsible to release this scratch buffer from the corresponding allocator - * \param[in] context OrtKernelContext instance - * \param[in] mem_info OrtMemoryInfo instance - * \param[in] count_or_bytes How many bytes is this scratch buffer - * \param[out] out A pointer to the scratch buffer - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.18. - */ - ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out); - - /** \brief Get allocator from KernelInfo for a specific memory type. Please use C API ReleaseAllocator to release out object - * - * \param[in] info OrtKernelInfo instance - * \param[in] mem_type OrtMemType object - * \param[out] out A pointer to OrtAllocator - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.18. - */ - ORT_API2_STATUS(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out); - - /** \brief Replace initialized Tensors with external data with the provided files in memory - * - * The function will find the initialized TensorProtos with external data in the graph with the provided - * external file names and the file content in memory. The API gets the external file name, offset, data length - * from TensorProto, and locate the tensor data from the file in memory buffer. - * It creates a Tensor to replace the existing Tensor in graph. The replacement - * will occur before any of the optimizations take place. 
The data will be copied into the graph - * since TensorProto can't refer to the user provided buffers. - * - * \param[in] options - * \param[in] external_initializer_file_names Array of null terminated UTF-8 encoded strings of the file names - * which holds the external initializers. - * \param[in] external_initializer_file_buffer_array Array of pointers to the buffer of the file content. - * The buffer can be freed after session creation. - * \param[in] external_initializer_file_lengths Array of size_t to indicate the length of file content - * \param[in] num_external_initializer_files Number of external files - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.18. - */ - ORT_API2_STATUS(AddExternalInitializersFromFilesInMemory, _In_ OrtSessionOptions* options, - _In_reads_(num_external_initializer_files) const ORTCHAR_T* const* external_initializer_file_names, - _In_reads_(num_external_initializer_files) char* const* external_initializer_file_buffer_array, - _In_reads_(num_external_initializer_files) const size_t* external_initializer_file_lengths, - size_t num_external_initializer_files); - - /** \brief Create an OrtLoraAdapter - * - * The function attempts to locate file specified by adapter_file_path, read it and create an OrtLoraAdapter - * instance. The adapter_file_path should be a valid path to a file that contains a valid Lora Adapter - * format. The function attempts to validate the format at load time. The file will always be memory mapped, unless - * the platform does not support memory mapping, in which case the file will be read into memory. - * - * \param[in] adapter_file_path adapter file path. - * \param[in] allocator optional pointer to a device allocator. If specified - * data is copied to the device at some point before Run() is invoked. If nullptr, data stays on CPU. - * The data would still be copied to device if required by the model at inference time. 
- * \param[out] out A pointer to a newly created OrtLoraAdapter instance. Must be released with - * OrtApi::ReleaseLoraAdapter. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.20. - */ - ORT_API2_STATUS(CreateLoraAdapter, const ORTCHAR_T* adapter_file_path, _In_ OrtAllocator* allocator, - _Outptr_ OrtLoraAdapter** out); - - /** \brief Create an OrtLoraAdapter - * - * The function copies the bytes from the array and creates an OrtLoraAdapter instance. - * - * - * \param[in] bytes pointer to a valid Lora Adapter format buffer. - * \param[in] num_bytes length of bytes buffer. - * \param[in] allocator optional pointer to a device allocator. If specified - * data is copied to the device at some point before Run() is invoked. If nullptr, data stays on CPU. - * The data would still be copied to device if required by the model at inference time. - * \param[out] out A pointer to a newly created OrtLoraAdapter instance. Must be released with - * OrtApi::ReleaseLoraAdapter. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.20. - */ - ORT_API2_STATUS(CreateLoraAdapterFromArray, _In_ const void* bytes, size_t num_bytes, _In_ OrtAllocator* allocator, - _Outptr_ OrtLoraAdapter** out); - - /** \brief Release an ::OrtLoraAdapter obtained from OrtApi::CreateLoraAdapter - */ - ORT_CLASS_RELEASE(LoraAdapter); - - /** \brief Add the Lora Adapter to the list of active adapters. - * - * The function adds the Lora Adapter to the list of active adapters. The Lora Adapter must be created with - * OrtApi::CreateLoraAdapter or FromArray. The Lora Adapter will be used by the session to run the model. - * The instance of the OrtRunOptions can then be used to customize the Run() calls. - * More than one OrtLoraAdapter can be active at the same time. Lora Parameters that belong to different - * Lora adapters that will be active at the same time must not overlap. - * This setting does not affect RunWithBinding. 
- * - * \param[in] options OrtRunOptions instance - * \param[in] adapter OrtLoraAdapter instance - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.20. - */ - ORT_API2_STATUS(RunOptionsAddActiveLoraAdapter, _Inout_ OrtRunOptions* options, _In_ const OrtLoraAdapter* adapter); - - /// @} - /// \name OrtEpDynamicOptions - /// @{ - - /** \brief Set DynamicOptions for EPs (Execution Providers) - * - * Valid options can be found in `include\onnxruntime\core\session\onnxruntime_session_options_config_keys.h` - * Look for `kOrtEpDynamicOptions` - * - * \param[in] sess OrtSession - * \param[in] keys Array of null terminated UTF8 encoded strings of EP dynamic option keys - * \param[in] values Array of null terminated UTF8 encoded string of EP dynamic option values - * \param[in] kv_len Number of elements in the keys and values arrays - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.20. - */ - ORT_API2_STATUS(SetEpDynamicOptions, _Inout_ OrtSession* sess, _In_reads_(kv_len) const char* const* keys, - _In_reads_(kv_len) const char* const* values, _In_ size_t kv_len); - - /** \brief Release an OrtValueInfo instance if it was not added to an OrtGraph. - * \since Version 1.22. - */ - ORT_CLASS_RELEASE(ValueInfo); - - /** \brief Release an OrtNode if it was not added to an OrtGraph. - * \since Version 1.22. - */ - ORT_CLASS_RELEASE(Node); - - /** \brief Release an OrtGraph. - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.22. - */ - ORT_CLASS_RELEASE(Graph); - - /** \brief Release an OrtModel. - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.22. - */ - ORT_CLASS_RELEASE(Model); - - /** \brief Get the value name from an OrtValueInfo instance. - * \param[in] value_info The OrtValueInfo instance. - * \param[out] name The name of the OrtValueInfo - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.22. 
- */ - ORT_API2_STATUS(GetValueInfoName, _In_ const OrtValueInfo* value_info, _Out_ const char** name); - - /** \brief Get the type information from an OrtValueInfo instance. - * \param[in] value_info The OrtValueInfo instance. - * \param[out] type_info The type info of the OrtValueInfo - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.22. - */ - ORT_API2_STATUS(GetValueInfoTypeInfo, _In_ const OrtValueInfo* value_info, _Outptr_ const OrtTypeInfo** type_info); - - /** \brief Get the Model Editor API instance - * - * Get the Model Editor API instance to create a new model or augment an existing model. - * - * \return Model Editor API struct - * - * \since Version 1.22. - */ - const OrtModelEditorApi*(ORT_API_CALL* GetModelEditorApi)(); - - /** \brief Create an OrtValue for a Tensor that uses pre-existing memory. - * - * ORT will take ownership of the memory and free it using the provided deleter when no longer in use. - * - * \param[in] deleter OrtAllocator instance that will be used to free the memory. - * Only the OrtAllocator:Info and OrtAllocator::Release functions are required. - * The OrtMemoryInfo returned by OrtAllocator::Info must match the location of p_data. - * \param[in] p_data Pointer to the memory that will be used by the Tensor. ORT will take ownership of the memory. - * \param[in] p_data_len Length of the memory in bytes. - * \param[in] shape Dimensions of the Tensor. All values should be > 0. - * \param[in] shape_len Number of dimensions in the shape array. - * \param[in] type Data type of the Tensor. - * \param[out] out Newly created ::OrtValue. Must be freed with OrtApi::ReleaseValue - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. 
- */ - ORT_API2_STATUS(CreateTensorWithDataAndDeleterAsOrtValue, _In_ OrtAllocator* deleter, - _In_ void* p_data, size_t p_data_len, - _In_ const int64_t* shape, size_t shape_len, - ONNXTensorElementDataType type, - _Outptr_ OrtValue** out); - - /** \brief sets load cancellation flag to abort session loading process. - * - * \param[in] options instance that was passed to the session at creation time. - * \param[in] cancel setting this to true after model loading process was initiated will - * attempt to cancel the loading process. If cancellation is successful, CreateSession() - * CreateSessionFromArray() or any other session creation API that take session options as an - * argument will return an OrtStatus indicating that session loading was canceled at user request, - * error code ORT_MODEL_LOAD_CANCELED. - * The APIs above would not return any valid Session instance. This is the best case effort and the result - * is not guaranteed. The session may have already been created and initialized - * before the cancellation request was issued. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(SessionOptionsSetLoadCancellationFlag, _Inout_ OrtSessionOptions* options, - _In_ bool cancel); - - /** \brief Get the Compile API instance. - * - * Get the Compile API instance to compile ONNX models. Execution providers that support compilation fuse a subgraph - * into an EPContext node that wraps a provider-specific binary representation of the subgraph. - * For more details about the EPContext design, refer to: - * \htmlonly - * EPContext design document. - * \endhtmlonly - * - * \return Compile API struct instance. - * - * \since Version 1.22. - */ - const OrtCompileApi*(ORT_API_CALL* GetCompileApi)(); - - // - // OrtKeyValuePairs - // - - /** \brief Create an OrtKeyValuePairs instance. - * - * \param[out] out A pointer to a newly created OrtKeyValuePairs instance. 
- * - * \note Must be released by calling ReleaseKeyValuePairs. - * - * \since Version 1.22. - */ - void(ORT_API_CALL* CreateKeyValuePairs)(_Outptr_ OrtKeyValuePairs** out); - - /** \brief Add a key-value pair to the OrtKeyValuePairs instance. - * - * \param[in] kvps OrtKeyValuePairs instance. - * \param[in] key Key to be added. - * \param[in] value Value to be added. - * - * \note The `key` and `value` are copied internally. - * - * \since Version 1.22. - */ - - void(ORT_API_CALL* AddKeyValuePair)(_In_ OrtKeyValuePairs* kvps, _In_ const char* key, _In_ const char* value); - - /** \brief Get the value associated with a key in the OrtKeyValuePairs instance. - * - * \param[in] kvps OrtKeyValuePairs instance. - * \param[in] key Key to be searched. - * - * \return The value associated with the key, or nullptr if the key does not exist. - * - * \since Version 1.22. - */ - const char*(ORT_API_CALL* GetKeyValue)(_In_ const OrtKeyValuePairs* kvps, _In_ const char* key); - - /** \brief Get all the key-value pairs from the OrtKeyValuePairs instance. - * - * \param[in] kvps OrtKeyValuePairs instance. - * \param[out] keys Array of keys from `kvps`. - * \param[out] values Array of values from `kvps`. - * \param[out] num_entries Number of entries in `keys` and `values`. - * - * \since Version 1.22. - */ - void(ORT_API_CALL* GetKeyValuePairs)(_In_ const OrtKeyValuePairs* kvps, - _Outptr_ const char* const** keys, _Outptr_ const char* const** values, - _Out_ size_t* num_entries); - - /** \brief Remove a key-value pair from the OrtKeyValuePairs instance. - * - * \param[in] kvps OrtKeyValuePairs instance. - * \param[in] key Key to be removed. No error if not found. - * - * \since Version 1.22. - */ - void(ORT_API_CALL* RemoveKeyValuePair)(_In_ OrtKeyValuePairs* kvps, _In_ const char* key); - - /** \brief Release an OrtKeyValuePairs instance. - * - * \param[in] input OrtKeyValuePairs instance to be released. - * - * \since Version 1.22. 
- */ - ORT_CLASS_RELEASE(KeyValuePairs); - - /** \brief Register an execution provider library with ORT. - * - * The library must export 'CreateEpFactories' and 'ReleaseEpFactory' functions. - * See OrtEpApi for more details. - * - * \param[in] env The OrtEnv instance to register the library in. - * \param[in] registration_name The name to register the execution provider library under. - * \param[in] path The path to the execution provider library. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(RegisterExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* registration_name, - _In_ const ORTCHAR_T* path); - - /** \brief Unregister an execution provider library with ORT. - * - * ORT will call ReleaseEpFactory for all factories created by the library, and unload the library. - * - * You MUST ensure there are no Session instances using execution providers created by the library - * before calling this function. - * - * \param[in] env The OrtEnv instance to unregister the library from. - * \param[in] registration_name The name the execution provider library was registered under. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(UnregisterExecutionProviderLibrary, _In_ OrtEnv* env, _In_ const char* registration_name); - - /** \brief Get the list of available OrtEpDevice instances. - * - * Each OrtEpDevice instance contains details of the execution provider and the device it will use. - * - * \param[in] env The OrtEnv instance to query. - * \param[out] ep_devices The OrtEpDevice instances that the execution provider will use. - * \param[out] num_ep_devices The number of OrtEpDevice instances returned. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. 
- */ - ORT_API2_STATUS(GetEpDevices, _In_ const OrtEnv* env, - _Outptr_ const OrtEpDevice* const** ep_devices, _Out_ size_t* num_ep_devices); - - /** \brief Append the execution provider that is responsible for the selected OrtEpDevice instances - * to the session options. - * - * \param[in] session_options Session options to add execution provider to. - * \param[in] env Environment that execution providers were registered with. - * \param[in] ep_devices One or more OrtEpDevice instances to create an execution provider for. - * Obtain from GetEpDevices. All OrtEpDevice instances must be from the same execution - * provider. It is only necessary to provide multiple OrtEpDevices if you want to use the - * same execution provider for multiple devices. - * e.g. the EP is capable of running on GPU and NPU. - * \param[in] num_ep_devices Number of OrtEpDevice instances. - * \param[in] ep_option_keys Optional keys to configure the execution provider. - * \param[in] ep_option_vals Optional values to configure the execution provider. - * \param[in] num_ep_options Number of execution provide options to add. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_V2, _In_ OrtSessionOptions* session_options, - _In_ OrtEnv* env, - _In_reads_(num_ep_devices) const OrtEpDevice* const* ep_devices, _In_ size_t num_ep_devices, - _In_reads_(num_op_options) const char* const* ep_option_keys, - _In_reads_(num_op_options) const char* const* ep_option_vals, - size_t num_ep_options); - - /** \brief Set the execution provider selection policy for the session. - * - * Allows users to specify a device selection policy for automatic execution provider (EP) selection. - * If custom selection is required please use SessionOptionsSetEpSelectionPolicyDelegate instead. - * - * \param[in] session_options The OrtSessionOptions instance. 
- * \param[in] policy The device selection policy to use (see OrtExecutionProviderDevicePolicy). - * - * \since Version 1.22 - */ - ORT_API2_STATUS(SessionOptionsSetEpSelectionPolicy, _In_ OrtSessionOptions* session_options, - _In_ OrtExecutionProviderDevicePolicy policy); - - /** \brief Set the execution provider selection policy delegate for the session. - * - * Allows users to provide a custom device selection policy for automatic execution provider (EP) selection. - * - * \param[in] session_options The OrtSessionOptions instance. - * \param[in] delegate Delegate callback for custom selection. - * \param[in] delegate_state Optional state that will be passed to the delegate callback. nullptr if not required. - * - * \since Version 1.22 - */ - ORT_API2_STATUS(SessionOptionsSetEpSelectionPolicyDelegate, _In_ OrtSessionOptions* session_options, - _In_ EpSelectionDelegate delegate, - _In_opt_ void* delegate_state); - - /** \brief Get the hardware device type. - * - * \param[in] device The OrtHardwareDevice instance to query. - * \return The hardware device type. - * - * \since Version 1.22. - */ - OrtHardwareDeviceType(ORT_API_CALL* HardwareDevice_Type)(_In_ const OrtHardwareDevice* device); - - /** \brief Get the hardware device's vendor identifier. - * - * \param[in] device The OrtHardwareDevice instance to query. - * \return The hardware device vendor identifier. - * - * \since Version 1.22. - */ - uint32_t(ORT_API_CALL* HardwareDevice_VendorId)(_In_ const OrtHardwareDevice* device); - - /** \brief Get the hardware device's vendor name. - * - * \param[in] device The OrtHardwareDevice instance to query. - * \return The hardware device's vendor name. - * - * \since Version 1.22. - */ - const char*(ORT_API_CALL* HardwareDevice_Vendor)(_In_ const OrtHardwareDevice* device); - - /** \brief Get the hardware device's unique identifier. - * - * \param[in] device The OrtHardwareDevice instance to query. - * \return The device id. 
- * - * \note This is not a unique identifier. It identifies the hardware type when combined with vendor id. - * \since Version 1.22. - */ - uint32_t(ORT_API_CALL* HardwareDevice_DeviceId)(_In_ const OrtHardwareDevice* device); - - /** \brief Get hardware device metadata. - * - * \param[in] device The OrtHardwareDevice instance to query. - * \return An OrtKeyValuePairs instance containing the metadata for the device. - * Note: ORT owns the instance so the user must not call ReleaseKeyValuePairs with it. - * - * \since Version 1.22. - */ - const OrtKeyValuePairs*(ORT_API_CALL* HardwareDevice_Metadata)(_In_ const OrtHardwareDevice* device); - - /** \brief Get the execution provider name. - * - * \param[in] ep_device The OrtEpDevice instance to query. - * \return The execution provider name. - * - * \since Version 1.22. - */ - const char*(ORT_API_CALL* EpDevice_EpName)(_In_ const OrtEpDevice* ep_device); - - /** \brief Get the execution provider's vendor name. - * - * \param[in] ep_device The OrtEpDevice instance to query. - * \return The execution provider's vendor name. - * - * \since Version 1.22. - */ - const char*(ORT_API_CALL* EpDevice_EpVendor)(_In_ const OrtEpDevice* ep_device); - - /** \brief Get the metadata for the OrtEpDevice. - * - * \param[in] ep_device The OrtEpDevice instance to query. - * \return An OrtKeyValuePairs instance containing the metadata for the device. - * - * \since Version 1.22. - */ - const OrtKeyValuePairs*(ORT_API_CALL* EpDevice_EpMetadata)(_In_ const OrtEpDevice* ep_device); - - /** \brief Get the execution provider options for the OrtEpDevice. - * - * \param[in] ep_device The OrtEpDevice instance to query. - * \return An OrtKeyValuePairs instance containing the execution provider options for the device. - * - * \since Version 1.22. - */ - const OrtKeyValuePairs*(ORT_API_CALL* EpDevice_EpOptions)(_In_ const OrtEpDevice* ep_device); - - /** \brief Get the OrtHardwareDevice instance for the OrtEpDevice. 
- * - * \param[in] ep_device The OrtEpDevice instance to query. - * \return The OrtHardwareDevice instance for the device. - * - * \since Version 1.22. - */ - const OrtHardwareDevice*(ORT_API_CALL* EpDevice_Device)(_In_ const OrtEpDevice* ep_device); - - /** \brief Get the OrtEpApi instance for implementing an execution provider. - * - * \since Version 1.22. - */ - const OrtEpApi*(ORT_API_CALL* GetEpApi)(); - - /** \brief Compute total size in bytes of the tensor data contained in an OrtValue. - * - * Returns the total number of bytes used to store the tensor data. For numeric tensors, - * this is sizeof(element_type) * total_element_count. OrtValues that are not tensors or - * that are tensors that contain strings will cause an error to be returned. - * - * \param[in] ort_value OrtValue instance containing a tensor - * \param[out] size The total size of the tensor data in bytes - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.23 - */ - ORT_API2_STATUS(GetTensorSizeInBytes, _In_ const OrtValue* ort_value, _Out_ size_t* size); -}; - -/* - * Steps to use a custom op: - * 1 Create an OrtCustomOpDomain with the domain name used by the custom ops - * 2 Create an OrtCustomOp structure for each op and add them to the domain - * 3 Call OrtAddCustomOpDomain to add the custom domain of ops to the session options - */ - -// Specifies some characteristics of inputs/outputs of custom ops: -// Specify if the inputs/outputs are one of: -// 1) Non-optional (input/output must be present in the node) -// 2) Optional (input/output may be absent in the node) -// 3) Variadic: A variadic input or output specifies N (i.e., the minimum arity) or more operands. -// Only the last input or output of a custom op may be marked as variadic. -// The homogeneity of the variadic input or output determines whether all operands must be of the same -// tensor element type. 
-typedef enum OrtCustomOpInputOutputCharacteristic { - INPUT_OUTPUT_REQUIRED = 0, - INPUT_OUTPUT_OPTIONAL, - INPUT_OUTPUT_VARIADIC, -} OrtCustomOpInputOutputCharacteristic; - -/* - * The OrtCustomOp structure defines a custom op's schema and its kernel callbacks. The callbacks are filled in by - * the implementor of the custom op. - */ -struct OrtCustomOp { - uint32_t version; // Must be initialized to ORT_API_VERSION - - // This callback creates the kernel, which is a user defined - // parameter that is passed to the Kernel* callbacks below. It is - // recommended to use CreateKernelV2 which allows for a safe error - // propagation by returning an OrtStatusPtr. - void*(ORT_API_CALL* CreateKernel)(_In_ const struct OrtCustomOp* op, _In_ const OrtApi* api, - _In_ const OrtKernelInfo* info); - - // Returns the name of the op - const char*(ORT_API_CALL* GetName)(_In_ const struct OrtCustomOp* op); - - // Returns the type of the execution provider, return nullptr to use CPU execution provider - const char*(ORT_API_CALL* GetExecutionProviderType)(_In_ const struct OrtCustomOp* op); - - // Returns the count and types of the input & output tensors - ONNXTensorElementDataType(ORT_API_CALL* GetInputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); - size_t(ORT_API_CALL* GetInputTypeCount)(_In_ const struct OrtCustomOp* op); - ONNXTensorElementDataType(ORT_API_CALL* GetOutputType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); - size_t(ORT_API_CALL* GetOutputTypeCount)(_In_ const struct OrtCustomOp* op); - - // Perform a computation step. It is recommended to use - // KernelComputeV2 which allows for a safe error propagation by - // returning an OrtStatusPtr. 
- void(ORT_API_CALL* KernelCompute)(_In_ void* op_kernel, _In_ OrtKernelContext* context); - void(ORT_API_CALL* KernelDestroy)(_In_ void* op_kernel); - - // Returns the characteristics of the input & output tensors - OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetInputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index); - OrtCustomOpInputOutputCharacteristic(ORT_API_CALL* GetOutputCharacteristic)(_In_ const struct OrtCustomOp* op, _In_ size_t index); - - // Returns the memory type of the input tensors. This API allows the custom op - // to place the inputs on specific devices. By default, it returns - // OrtMemTypeDefault, which means the input is placed on the default device for - // the execution provider. If the inputs need to be with different memory types, - // this function can be overridden to return the specific memory types. - OrtMemType(ORT_API_CALL* GetInputMemoryType)(_In_ const struct OrtCustomOp* op, _In_ size_t index); - - // Returns the minimum number of input arguments expected for the variadic input. - // Applicable only for custom ops that have a variadic input. - int(ORT_API_CALL* GetVariadicInputMinArity)(_In_ const struct OrtCustomOp* op); - - // Returns true (non-zero) if all arguments of a variadic input have to be of the same type (homogeneous), - // and false (zero) otherwise. - // Applicable only for custom ops that have a variadic input. - int(ORT_API_CALL* GetVariadicInputHomogeneity)(_In_ const struct OrtCustomOp* op); - - // Returns the minimum number of output values expected for the variadic output. - // Applicable only for custom ops that have a variadic output. - int(ORT_API_CALL* GetVariadicOutputMinArity)(_In_ const struct OrtCustomOp* op); - - // Returns true (non-zero) if all outputs values of a variadic output have to be of the same type (homogeneous), - // and false (zero) otherwise. - // Applicable only for custom ops that have a variadic output. 
- int(ORT_API_CALL* GetVariadicOutputHomogeneity)(_In_ const struct OrtCustomOp* op); - - // Create the kernel state which is passed to each compute call. - OrtStatusPtr(ORT_API_CALL* CreateKernelV2)(_In_ const struct OrtCustomOp* op, _In_ const OrtApi* api, - _In_ const OrtKernelInfo* info, - _Out_ void** kernel); - - // Perform the computation step. - OrtStatusPtr(ORT_API_CALL* KernelComputeV2)(_In_ void* op_kernel, _In_ OrtKernelContext* context); - - OrtStatusPtr(ORT_API_CALL* InferOutputShapeFn)(_In_ const struct OrtCustomOp* op, _In_ OrtShapeInferContext*); - - // Get start range - int(ORT_API_CALL* GetStartVersion)(_In_ const struct OrtCustomOp* op); - int(ORT_API_CALL* GetEndVersion)(_In_ const struct OrtCustomOp* op); - - // Get the inplace_map that defines which output can reuse which input - // Callers will provide 2 raw int* and pass in their address, this function will fill these 2 arrays - // when return, output (*output_index)[i] may reuse the input (*input_index[i]). - // The return value is the size of these 2 arrays. - // Callers are responsible to delete these 2 arrays after use by calling OrtCustomOp::ReleaseMayInplace(). - size_t(ORT_API_CALL* GetMayInplace)(_Out_ int** input_index, _Out_ int** output_index); - - // Release the pointer input_index and output_index allocated from GetMayInplace() function. - // If GetMayInplace() is defined, this function MUST be defined as well. - void(ORT_API_CALL* ReleaseMayInplace)(_Frees_ptr_opt_ int* input_index, _Frees_ptr_opt_ int* output_index); - - // Same as GetMayInplace() and ReleaseMayInplace() - size_t(ORT_API_CALL* GetAliasMap)(_Out_ int** input_index, _Out_ int** output_index); - void(ORT_API_CALL* ReleaseAliasMap)(_Frees_ptr_opt_ int* input_index, _Frees_ptr_opt_ int* output_index); -}; - -/** - * ORT Model Editor API - */ - -/** - * \brief The OrtModelEditorApi struct provides functions to create or edit an ONNX model. 
- * - * See onnxruntime/test/shared_lib/test_model_editor_api.cc for example usage. - * - * \since Version 1.22. - */ -struct OrtModelEditorApi { - // Model building/editing requires a full build. We return nullptr from GetModelEditorApi if this is a minimal - // build, so it doesn't matter if there are no function pointers in this struct as a user will never get an - // OrtModelEditorApi instance. We do however need a dummy field to avoid empty struct warning. -#if defined(ORT_MINIMAL_BUILD) - const bool not_defined_in_this_build; -#else - /** \brief Create an OrtTypeInfo instance for a Tensor. - * - * Create an OrtTypeInfo instance for a Tensor to use as graph inputs/outputs with the Model Editor API. - * - * User can release `tensor_info` after creating the OrtTypeInfo. - * - * \param[in] tensor_info Tensor type and shape information. - * \param[out] type_info TypeInfo instance for the tensor. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(CreateTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, - _Outptr_ OrtTypeInfo** type_info); - - /** \brief Create an OrtTypeInfo instance for a SparseTensor. - * - * Create an OrtTypeInfo instance for a SparseTensor to use as graph inputs/outputs with the Model Editor API. - * - * User can release `tensor_info` after creating the OrtTypeInfo. - * - * \param[in] tensor_info SparseTensor type and shape information. - * \param[out] type_info TypeInfo instance for the tensor. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(CreateSparseTensorTypeInfo, _In_ const OrtTensorTypeAndShapeInfo* tensor_info, - _Outptr_ OrtTypeInfo** type_info); - - /** \brief Create an OrtTypeInfo instance for a Map. - * - * Create an OrtTypeInfo instance for a Map to use as graph inputs/outputs with the Model Editor API. - * - * User can release `map_value_type` after creating the OrtTypeInfo. 
- * - * \param[in] map_key_type Key type for the map. - * \param[in] map_value_type Value type for the map. - * \param[out] type_info TypeInfo instance for the map. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(CreateMapTypeInfo, ONNXTensorElementDataType map_key_type, _In_ const OrtTypeInfo* map_value_type, - _Outptr_ OrtTypeInfo** type_info); - - /** \brief Create an OrtTypeInfo instance for a Sequence. - * - * Create an OrtTypeInfo instance for a Sequence to use as graph inputs/outputs with the Model Editor API. - * - * User can release `sequence_type` after creating the OrtTypeInfo. - * - * \param[in] sequence_type Sequence type and shape information. - * \param[out] type_info TypeInfo instance for the sequence. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(CreateSequenceTypeInfo, _In_ const OrtTypeInfo* sequence_type, _Outptr_ OrtTypeInfo** type_info); - - /** \brief Create an OrtTypeInfo instance for an Optional. - * - * Create an OrtTypeInfo instance for an Optional to use as graph inputs/outputs with the Model Editor API. - * - * User can release `contained_type` after creating the OrtTypeInfo. - * - * \param[in] contained_type Tensor type and shape information. - * \param[out] type_info TypeInfo instance for the tensor. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(CreateOptionalTypeInfo, _In_ const OrtTypeInfo* contained_type, _Outptr_ OrtTypeInfo** type_info); - - /** \brief Create an OrtValueInfo for use as an OrtGraph input or output. - * - * \param[in] name The name of the input or output. - * \param[in] type_info The type information for the input or output. The provided value is copied. - * \param[out] value_info The OrtValueInfo instance. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. 
- */ - ORT_API2_STATUS(CreateValueInfo, _In_ const char* name, _In_ const OrtTypeInfo* type_info, - _Outptr_ OrtValueInfo** value_info); - - /** \brief Create an OrtNode to add to an OrtGraph. - * - * Create an OrtNode. - * - * Create attributes with CreateOpAttr. OrtOpAttr instances are copied. - * - * \param[in] operator_name The name of the operator. - * \param[in] domain_name The domain of the operator. Use an empty string for ONNX operators. - * \param[in] node_name The name of the node. - * \param[in] input_names The names of the inputs. - * \param[in] input_names_len The number of input names. - * \param[in] output_names The names of the outputs. - * \param[in] output_names_len The number of output names. - * \param[in] attributes The optional attributes of the node. - * \param[in] attribs_len The number of attributes. May be zero. - * \param[out] node The OrtNode instance. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(CreateNode, _In_ const char* operator_name, _In_ const char* domain_name, _In_ const char* node_name, - _In_reads_(input_names_len) const char* const* input_names, size_t input_names_len, - _In_reads_(output_names_len) const char* const* output_names, size_t output_names_len, - _In_reads_(attribs_len) _In_opt_ OrtOpAttr** attributes, _In_ size_t attribs_len, - _Outptr_ OrtNode** node); - - /** \brief Create an OrtGraph - * \snippet{doc} snippets.dox OrtStatus Return Value - * \since Version 1.22. - */ - ORT_API2_STATUS(CreateGraph, _Outptr_ OrtGraph** graph); - - /** \brief Set the inputs for the OrtGraph. - * - * Set the graph inputs. This will replace any existing inputs with the new values. - * The OrtGraph takes ownership of the OrtValueInfo instances and you should NOT call ReleaseOrtValueInfo. - * - * \param[in] graph The OrtGraph instance to update. - * \param[in] inputs The input OrtValueInfo instances. - * \param[in] inputs_len The number of input OrtValueInfo instances. 
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(SetGraphInputs, _Inout_ OrtGraph* graph, - _In_reads_(inputs_len) _In_ OrtValueInfo** inputs, _In_ size_t inputs_len); - - /** \brief Set the outputs for the OrtGraph. - * - * Set the graph outputs. This will replace any existing outputs with the new values. - * The OrtGraph takes ownership of the OrtValueInfo instances provided and you should NOT call ReleaseOrtValueInfo. - * - * \param[in] graph The OrtGraph instance to update. - * \param[in] outputs The output OrtValueInfo instances. - * \param[in] outputs_len The number of output OrtValueInfo instances. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(SetGraphOutputs, _Inout_ OrtGraph* graph, - _In_reads_(outputs_len) _In_ OrtValueInfo** outputs, _In_ size_t outputs_len); - - /** \brief Add an initializer to the OrtGraph - * - * ORT will take ownership of the OrtValue and you should NOT call ReleaseOrtValue. - * - * Two options: - * - * Allocated memory: - * Use CreateTensorAsOrtValue (allocates memory) and populate the tensor with the data. - * Set `data_is_external` to false. - * - * Pre-existing memory: - * Use CreateTensorWithDataAsOrtValue or CreateTensorWithDataAndDeleterAsOrtValue to create an OrtValue - * with a tensor that contains a pointer to the existing data. - * Set `data_is_external` to true. - * - * The pointer must remain valid for the duration of the inference session. - * If using CreateTensorWithDataAsOrtValue you are responsible for freeing the memory after the inference session - * is released. - * If using CreateTensorWithDataAndDeleterAsOrtValue, ORT will free the memory using the provided deleter as - * soon as the OrtValue is no longer in use. - * - * NOTE: A tensor containing pre-existing memory MUST have 128 bytes of data or more. - * For smaller tensors use CreateTensorAsOrtValue. 
- * - * ONNX shape inferencing does not support external data. An initializer involved in shape inferencing is - * typically small (a single value or limited by the rank of a tensor) and uses less than 128 bytes of - * memory, so this limit acts as a simple catch-all rule to avoid issues. - * e.g. Reshape's `shape`, Clip's `min` and `max`, various ops `axes`. - * - * \param[in] graph The OrtGraph instance to update. - * \param[in] name The value name for the initializer. - * \param[in] tensor The OrtValue instance containing the tensor data. - * \param[in] data_is_external Set to true if the data is external and should not be copied. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(AddInitializerToGraph, _Inout_ OrtGraph* graph, _In_ const char* name, _In_ OrtValue* tensor, - bool data_is_external); - - /** \brief Add an OrtNode to an OrtGraph - * - * Add the node to the graph. The OrtGraph will take ownership of OrtNode and you should NOT call ReleaseOrtNode. - * - * \param[in] graph The OrtGraph instance to update. - * \param[in] node The OrtNode instance to add to the graph. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(AddNodeToGraph, _Inout_ OrtGraph* graph, _In_ OrtNode* node); - - /** \brief Create an OrtModel. - * - * Create an OrtModel. - * - * This can be used to build a new model, or to augment an existing model. - * - * \param[in] domain_names The domain names for the model. - * If augmenting an existing model add additional domains if needed. - * \param[in] opset_versions The opset versions for the model. - * If augmenting an existing model add additional opset versions if needed. - * \param[in] opset_entries_len The number of domain_names and opset_versions entries. - * Domain and opset entries should be 1:1 - * \param[out] model The OrtModel instance. 
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(CreateModel, - _In_reads_(opset_entries_len) const char* const* domain_names, - _In_reads_(opset_entries_len) const int* opset_versions, - size_t opset_entries_len, - _Outptr_ OrtModel** model); - - /** \brief Add an OrtGraph to an OrtModel. - * - * Add the graph to a model. This should be called once when creating a new model. - * - * The OrtModel takes ownership of the OrtGraph and you should NOT call ReleaseOrtGraph. - * - * \param[in] model The OrtModel instance to update. - * \param[in] graph The OrtGraph instance to add to the model. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(AddGraphToModel, _Inout_ OrtModel* model, _In_ OrtGraph* graph); - - /** \brief Create an OrtSession using the OrtModel. - * - * Create an inference session using the OrtModel instance. - * The OrtModel should have been populated with an OrtGraph containing nodes and initializers, and SetGraphInputs - * and SetGraphOutputs must have been called. - * This will validate the model, run optimizers, and prepare the session for inferencing. - * - * ReleaseOrtModel must be called to free the OrtModel after session creation. - * - * \param[in] env The OrtEnv instance. - * \param[in] model The OrtModel instance. - * \param[in] options The OrtSessionOptions instance. - * \param[out] out The OrtSession instance. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(CreateSessionFromModel, _In_ const OrtEnv* env, _In_ const OrtModel* model, - _In_ const OrtSessionOptions* options, _Outptr_ OrtSession** out); - - /** \brief Create an OrtSession to augment an existing model. - * - * Create an OrtSession with an existing model that will be augmented with additional nodes and initializers. - * Nodes can be added before or after the existing nodes in the model. 
ONNX Runtime will connect the nodes when the - * model is finalized. - * - * To add nodes and initializers to the existing model, first create an OrtModel using CreateModel. - * Add nodes and initializers to the OrtModel using AddNodeToGraph and AddInitializerToGraph. - * Graph inputs/outputs should be updated with SetGraphInputs and SetGraphOutputs as needed to reflect changes made - * by the new nodes. The list of graph inputs/outputs should be for the overall model and not just the new nodes. - * - * Add the new information from the OrtModel to the original model using ApplyModelToSession, and prepare the - * session for inferencing by calling FinalizeModelEditorSession. - * - * \param{in} env The OrtEnv instance. - * \param{in} model_path The path to the existing ONNX model to augment. - * \param{in} options The OrtSessionOptions instance. - * \param{out} out The created OrtSession instance. - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(CreateModelEditorSession, _In_ const OrtEnv* env, _In_ const ORTCHAR_T* model_path, - _In_ const OrtSessionOptions* options, - _Outptr_ OrtSession** out); - - /** \brief Create an OrtSession to augment an existing model. - * - * Create an OrtSession with an existing model that will be augmented with additional nodes and initializers. - * Nodes can be added before or after the existing nodes in the model. ONNX Runtime will connect the nodes when the - * model is finalized. - * - * To add nodes and initializers to the existing model, first create an OrtModel using CreateModel. - * Add nodes and initializers to the OrtModel using AddNodeToGraph and AddInitializerToGraph. - * Graph inputs/outputs should be updated with SetGraphInputs and SetGraphOutputs as needed to reflect changes made - * by the new nodes. The list of graph inputs/outputs should be for the overall model and not just the new nodes. 
- * - * Add the new information from the OrtModel to the original model using ApplyModelToSession, and prepare the - * session for inferencing by calling FinalizeModelEditorSession. - * - * \param{in} env The OrtEnv instance. - * \param{in} model_data The model data for the existing model to augment. - * \param{in} model_data_length The length of the model data. - * \param{in} options The OrtSessionOptions instance. - * \param{out} out The created OrtSession instance. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(CreateModelEditorSessionFromArray, _In_ const OrtEnv* env, - _In_ const void* model_data, size_t model_data_length, - _In_ const OrtSessionOptions* options, - _Outptr_ OrtSession** out); - - /** \brief Query the session for the opset version of a domain. - * - * When using the Model Editor API to augment a model, any new nodes must conform to the opset version of the - * original model. To do that the user must be able to discover that opset version. - * Returns an error if the domain is not used in the model. - * - * \param[in] session OrtSession to query - * \param[in] domain Domain to query. The ONNX domain is an empty string. - * \param[out] opset The opset version of the domain. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(SessionGetOpsetForDomain, _In_ const OrtSession* session, _In_ const char* domain, _Out_ int* opset); - - /** \brief Apply changes to augment the ONNX model in a session created using CreateModelEditorSession[FromArray] - * - * Adds new nodes and updates graph inputs/outputs using `model` to augment the original ONNX model in the session. - * All changes will be validated. - * Call FinalizeModelEditorSession to prepare the session for inferencing. - * - * Existing input/outputs will only be updated if the OrtGraph inputs/outputs are set in the OrtModel. - * i.e. 
you don't need to call SetGraphInputs/SetGraphOutputs if they are unchanged. - * - * ReleaseOrtModel must be called to free the OrtModel after it is applied to the session. - * - * \param[in] session OrtSession to update. Session must have been created using CreateModelEditorSession[FromArray]. - * \param[in] model OrtModel containing new nodes, new initializers, and updated graph input and/or output info. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(ApplyModelToModelEditorSession, _Inout_ OrtSession* session, _In_ OrtModel* model); - - /** \brief Finalize the Model Editor session that was created using CreateModelEditorSession[FromArray]. - * - * Finalize the Model Editor session that augmented an ONNX model by adding new nodes. - * This will run optimizers and prepare the session for inferencing. - * - * \param[in] session OrtSession to finalize. Session must have been created using CreateModelEditorSession[FromArray]. - * \param[in] options OrtSessionOptions to use for the session. - * \param[in] prepacked_weights_container Optional OrtPrepackedWeightsContainer to use for the session. - Set to nullptr if not used. - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(FinalizeModelEditorSession, _Inout_ OrtSession* session, _In_ const OrtSessionOptions* options, - _In_opt_ OrtPrepackedWeightsContainer* prepacked_weights_container); -#endif // !defined(ORT_MINIMAL_BUILD) -}; - -/** - * ORT Compile API - */ - -/** \brief Flags representing options to enable when compiling a model. - */ -typedef enum OrtCompileApiFlags { - // Default. Do not enable any additional compilation options. - OrtCompileApiFlags_NONE = 0, - - // Force compilation to return an error (ORT_FAIL) if no nodes were compiled. - // Otherwise, a model with basic optimizations (ORT_ENABLE_BASIC) is still generated by default. 
- OrtCompileApiFlags_ERROR_IF_NO_NODES_COMPILED = 1 << 0, - - // Force compilation to return an error (ORT_FAIL) if a file with the same filename as the output model exists. - // Otherwise, compilation will automatically overwrite the output file if it exists. - OrtCompileApiFlags_ERROR_IF_OUTPUT_FILE_EXISTS = 1 << 1, -} OrtCompileApiFlags; - -/** - * \brief The OrtCompileApi struct provides functions to compile ONNX models. - * - * Execution providers that support compilation fuse a subgraph into an EPContext node that wraps a provider-specific - * binary representation of the subgraph. - * For more details about the EPContext design, refer to: - * \htmlonly - * EPContext design document. - * \endhtmlonly - * - * Example (error handling not shown): - * OrtStatus* status = NULL; - * OrtCompileApi* compile_api = ort_api->GetCompileApi(); - * OrtModelCompilationOptions* compile_options = NULL; - * - * status = compile_api->CreateModelCompilationOptionsFromSessionOptions(env, session_options, &compile_options); - * status = compile_api->ModelCompilationOptions_SetInputModelPath(compile_options, ORT_TSTR("model.onnx")); - * status = compile_api->ModelCompilationOptions_SetOutputModelPath(compile_options, ORT_TSTR("model.compiled.onnx")); - * status = compile_api->CompileModel(env, compile_options); - * compile_api->ReleaseModelCompilationOptions(compile_options); - * - * \since Version 1.22. - */ -struct OrtCompileApi { - /// @} - /// \name OrtModelCompilationOptions - /// @{ - ORT_CLASS_RELEASE(ModelCompilationOptions); - - /** \brief Creates an OrtModelCompilationOptions object from an existing OrtSessionOptions object. - * - * An OrtModelCompilationOptions object contains the settings used to generate a compiled ONNX model. - * The OrtSessionOptions object has the execution providers with which the model will be compiled. - * - * ReleaseOrtModelCompilationsOptions must be called to free the OrtModelCompilationOptions after calling - * CompileModel. 
- * - * \param[in] env OrtEnv object. - * \param[in] session_options The OrtSessionOptions instance from which to create the OrtModelCompilationOptions. - * \param[out] out The created OrtModelCompilationOptions instance. - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(CreateModelCompilationOptionsFromSessionOptions, _In_ const OrtEnv* env, - _In_ const OrtSessionOptions* session_options, _Outptr_ OrtModelCompilationOptions** out); - - /** \brief Sets the file path to the input ONNX model to compile. - * - * The input model's location (e.g., file path or memory buffer) must be set with either - * ModelCompilationOptions_SetInputModelPath or ModelCompilationOptions_SetInputModelFromBuffer. - * - * \param[in] model_compile_options The OrtModelCompilationOptions instance. - * \param[in] input_model_path Null terminated string of the path (wchar on Windows, char otherwise). - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(ModelCompilationOptions_SetInputModelPath, _In_ OrtModelCompilationOptions* model_compile_options, - _In_ const ORTCHAR_T* input_model_path); - - /** \brief Sets the buffer that stores the bytes of the loaded ONNX model to compile. - * - * The input model's location (e.g., file path or memory buffer) must be set with either - * ModelCompilationOptions_SetInputModelPath or ModelCompilationOptions_SetInputModelFromBuffer. - * - * \param[in] model_compile_options The OrtModelCompilationOptions instance. - * \param[in] input_model_data Buffer containing the loaded ONNX model bytes. - * \param[in] input_model_data_size The number of bytes in the `input_model_data` buffer. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. 
- */ - ORT_API2_STATUS(ModelCompilationOptions_SetInputModelFromBuffer, - _In_ OrtModelCompilationOptions* model_compile_options, - _In_ const void* input_model_data, - size_t input_model_data_size); - - /** \brief Sets the file path for the output ONNX model generated by CompileModel. - * - * The output model's location (e.g., file path or memory buffer) can be set with either - * ModelCompilationOptions_SetOutputModelPath or ModelCompilationOptions_SetOutputModelBuffer. - * - * If the output model's location is not set, ONNX Runtime will generate an output file with a path based on - * the input model's file path. Examples: - * /Path/my_model.onnx -> /Path/my_model_ctx.onnx - * /Path/my_model -> /Path/my_model_ctx.onnx - * - * \param[in] model_compile_options The OrtModelCompilationOptions instance. - * \param[in] output_model_path Null terminated string of the path (wchar on Windows, char otherwise). - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelPath, _In_ OrtModelCompilationOptions* model_compile_options, - _In_ const ORTCHAR_T* output_model_path); - - /** \brief Optionally sets the file that should store external initializers for the compiled ONNX model. - * If not set, initializers are stored within the model. - * - * Only initializers for nodes that were not compiled are stored in the external initializers file. - * Compiled nodes contain their initializer data within the `ep_cache_context` attribute of EPContext nodes. - * Refer to ModelCompilationOptions_SetEpContextEmbedMode. - * - * \param[in] model_compile_options The OrtModelCompilationOptions instance. - * \param[in] external_initializers_file_path Null terminated string of the path to the file. - * \param[in] external_initializers_size_threshold Initializers larger than this threshold are stored in the file. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. 
- */ - ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelExternalInitializersFile, - _In_ OrtModelCompilationOptions* model_compile_options, - _In_ const ORTCHAR_T* external_initializers_file_path, - size_t external_initializers_size_threshold); - - /** \brief Configures model compilation to store the output compiled ONNX model in a buffer. - * - * The caller passes an OrtAllocator that ONNX Runtime uses to allocate memory for the buffer. - * - * The output model's location (e.g., file path or memory buffer) can be set with either - * ModelCompilationOptions_SetOutputModelPath or ModelCompilationOptions_SetOutputModelBuffer. - * - * If the output model's location is not set, ONNX Runtime will generate an output file with a path based on - * the input model's file path. Examples: - * /Path/my_model.onnx -> /Path/my_model_ctx.onnx - * /Path/my_model -> /Path/my_model_ctx.onnx - * - * \param[in] model_compile_options The OrtModelCompilationOptions instance. - * \param[in] allocator The allocator used to allocate the buffer for the compiled model. - * \param[out] output_model_buffer_ptr Pointer to the buffer that stores the compiled model. - * \param[out] output_model_buffer_size_ptr Pointer set to the size of output model in bytes. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(ModelCompilationOptions_SetOutputModelBuffer, - _In_ OrtModelCompilationOptions* model_compile_options, - _Inout_ OrtAllocator* allocator, - _Outptr_ void** output_model_buffer_ptr, - _Out_ size_t* output_model_buffer_size_ptr); - - /** \brief Enables or disables the embedding of EPContext binary data into the `ep_cache_context` attribute - * of EPContext nodes. Defaults to false. - * - * If enabled, the `ep_cache_context` attribute of EPContext nodes will store the context binary data, which may - * include weights for compiled subgraphs. 
- * - * If disabled, the `ep_cache_context` attribute of EPContext nodes will contain the path to the file containing the - * context binary data. The path is set by the execution provider creating the EPContext node. - * - * More details relate to EPContext design refers to: - * \htmlonly - * EPContext design document. - * \endhtmlonly - * - * \param[in] model_compile_options The OrtModelCompilationOptions instance. - * \param[in] embed_ep_context_in_model True to embed EPContext binary data into the EPContext node - * `ep_cache_context` attributes. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(ModelCompilationOptions_SetEpContextEmbedMode, _In_ OrtModelCompilationOptions* model_compile_options, - bool embed_ep_context_in_model); - - /** \brief Compiles an input ONNX model with the given compilation options. - * - * \param[in] env OrtEnv object. - * \param[in] model_options The compilation options that defines compilation options for a model. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ - ORT_API2_STATUS(CompileModel, _In_ const OrtEnv* env, _In_ const OrtModelCompilationOptions* model_options); - - /** \brief Sets flags from OrtCompileApiFlags that represent one or more boolean options to enable. - * - * \param[in] model_compile_options The OrtModelCompilationOptions instance. - * \param[in] flags bitwise OR of flags in OrtCompileApiFlags to enable. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.23. - */ - ORT_API2_STATUS(ModelCompilationOptions_SetFlags, _In_ OrtModelCompilationOptions* model_compile_options, - size_t flags); -}; - -ORT_RUNTIME_CLASS(Ep); -ORT_RUNTIME_CLASS(EpFactory); - -struct OrtEpApi { - /** \brief Create an OrtEpDevice for the EP and an OrtHardwareDevice. - * \param[in] ep_factory Execution provider factory that is creating the instance. 
- * \param[in] hardware_device Hardware device that the EP can utilize. - * \param[in] ep_metadata Optional OrtKeyValuePairs instance for execution provider metadata that may be used - * during execution provider selection and passed to CreateEp. - * ep_device will copy this instance and the user should call ReleaseKeyValuePairs. - * \param[in] ep_options Optional OrtKeyValuePairs instance for execution provider options that will be added - * to the Session configuration options if the execution provider is selected. - * ep_device will copy this instance and the user should call ReleaseKeyValuePairs. - * \param ep_device OrtExecutionDevice that is created. - * - * \since Version 1.22. - */ - ORT_API2_STATUS(CreateEpDevice, _In_ OrtEpFactory* ep_factory, - _In_ const OrtHardwareDevice* hardware_device, - _In_opt_ const OrtKeyValuePairs* ep_metadata, - _In_opt_ const OrtKeyValuePairs* ep_options, - _Out_ OrtEpDevice** ep_device); - - ORT_CLASS_RELEASE(EpDevice); -}; - -/** - * \brief The OrtEp struct provides functions to implement for an execution provider. - * \since Version 1.22. - */ -struct OrtEp { - /** \brief The ONNX Runtime version the execution provider was compiled with. - * - * Implementation should set to ORT_API_VERSION. - * ORT will use this to ensure it does not call functions that were not available when the library was compiled. - * - * \since Version 1.22. - */ - uint32_t ort_version_supported; - - /** \brief Get the execution provider name. - * - * \param[in] this_ptr The OrtEp instance. - * \return The execution provider name. - * - * \note Returned string is owned by ORT and valid until UnregisterExecutionProviderLibrary is called. - * - * \since Version 1.22. 
- */ - const char*(ORT_API_CALL* GetName)(const OrtEp* this_ptr); - - // OrtStatus* GetCapability(OrtEp* ep, const OrtGraph* graph, - // size_t* num_supported_subgraphs, - // OrtIndexedSubgraph** supported_subgraphs, OrtAllocator* allocator); - - // OrtStatus* Compile(OrtEp* ep, const OrtGraph** graphs, OrtNode** fused_graph_nodes, - // size_t count, OrtNodeComputeInfo* node_compute_infos); - - // TODO: Implement OrtEpApi and the complete OrtEp interface as the next step. -}; - -/** \brief The function signature that ORT will call to create OrtEpFactory instances. - * - * This must be available in a function called 'CreateEpFactories' in the execution provider library. - * - * \param[in] registered_name The name the execution library is registered with by RegisterExecutionProviderLibrary - * \param[in] ort_api_base The OrtApiBase instance that is used by the factory to get the OrtApi instance for the - * version of ORT that the library was compiled against. - * \param[in,out] factories The implementation should create and add OrtEpFactory instances to this - * pre-allocated array. - * i.e. usage is `factories[0] = new MyEpFactory();` - * \param[in] max_factories The maximum number of OrtEpFactory instances that can be added to `factories`. - * Current default is to allow 4 factories. This can be increased in the future if needed. - * \param[out] num_factories The number of OrtEpFactory instances created by the factory and added to `factories`. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ -typedef OrtStatus* (*CreateEpApiFactoriesFn)(_In_ const char* registered_name, _In_ const OrtApiBase* ort_api_base, - _Inout_ OrtEpFactory** factories, _In_ size_t max_factories, - _Out_ size_t* num_factories); - -/** \brief The function signature that ORT will call to release an OrtEpFactory instance. - * - * This must be available in a function called 'ReleaseEpFactory' in the execution provider library. 
- * - * \param[in] factory The OrtEpFactory instance to release. - * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version 1.22. - */ -typedef OrtStatus* (*ReleaseEpApiFactoryFn)(_In_ OrtEpFactory* factory); - -/** - * \brief The OrtEpFactory provides functions to create and manage execution providers. - * \since Version 1.22. - */ -struct OrtEpFactory { - /** \brief The ONNX Runtime version the execution provider was compiled with. - * - * Implementation should set to ORT_API_VERSION. - * ORT will use this to ensure it does not call functions that were not available when the library was compiled. - * - * \since Version 1.22. - */ - uint32_t ort_version_supported; - - /** \brief Get the name of the execution provider that the factory creates. - * - * \param[in] this_ptr The OrtEpFactory instance. - * \return The name of the execution provider the factory creates. - * - * \since Version 1.22. - */ - const char*(ORT_API_CALL* GetName)(const OrtEpFactory* this_ptr); - - /** \brief Get the name of the vendor who owns the execution provider that the factory creates. - * - * \param[in] this_ptr The OrtEpFactory instance. - * \return vendor The vendor name of the execution provider the factory creates. - * - * \since Version 1.22. - */ - const char*(ORT_API_CALL* GetVendor)(const OrtEpFactory* this_ptr); // return EP vendor - - /** \brief Get information from the execution provider if it supports the OrtHardwareDevice. - * - * \param[in] this_ptr The OrtEpFactory instance. - * Non-const as the factory is passed through to the CreateEp call via the OrtEpDevice. - * \param[in] devices The OrtHardwareDevice instances that are available. - * \param[in] num_devices The number of OrtHardwareDevice instances. - * \param[out] ep_devices OrtEpDevice instances for each OrtHardwareDevice that the EP can use. - * The implementation should call OrtEpApi::CreateEpDevice to create, and add the OrtEpDevice - * instances to this pre-allocated array. 
ORT will take ownership of the values returned. - * i.e. usage is `ep_devices[0] = <OrtEpDevice created via OrtEpApi::CreateEpDevice>;` - * \param[in] max_ep_devices The maximum number of OrtEpDevices that can be added to ep_devices. - * Current default is 8. This can be increased if needed. - * \param[out] num_ep_devices The number of EP devices added to ep_devices. - * \return true if the factory can create an execution provider that uses `device`. - * - * \note ORT will take ownership of ep_metadata and/or ep_options if they are not null. - * - * \since Version 1.22. - */ - OrtStatus*(ORT_API_CALL* GetSupportedDevices)(_In_ OrtEpFactory* this_ptr, - _In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_ size_t num_devices, - _Inout_ OrtEpDevice** ep_devices, - _In_ size_t max_ep_devices, - _Out_ size_t* num_ep_devices); - - /** \brief Function to create an OrtEp instance for use in a Session. - * - * ORT will call ReleaseEp to release the instance when it is no longer needed. - * - * \param[in] this_ptr The OrtEpFactory instance. - * \param[in] devices The OrtHardwareDevice instances that the execution provider was selected to use. - * \param[in] ep_metadata_pairs Execution provider metadata that was provided to OrtEpApi::CreateEpDevice, for each - * device. - * \param[in] num_devices The number of devices the execution provider was selected for. - * \param[in] session_options The OrtSessionOptions instance that contains the configuration options for the - * session. This will include ep_options from GetSupportedDevices as well as any - * user provided overrides. - * Execution provider options will have been added with a prefix of 'ep.[ep name].'. - * The OrtSessionOptions instance will NOT be valid after this call and should not be - * stored for later use. - * \param[in] logger The OrtLogger instance for the session that the execution provider should use for logging. - * \param[out] ep The OrtEp instance created by the factory. 
- * - * \snippet{doc} snippets.dox OrtStatus Return Value - * - * \since Version [coming soon]. This is a placeholder. - */ - OrtStatus*(ORT_API_CALL* CreateEp)(_In_ OrtEpFactory* this_ptr, - _In_reads_(num_devices) const OrtHardwareDevice* const* devices, - _In_reads_(num_devices) const OrtKeyValuePairs* const* ep_metadata_pairs, - _In_ size_t num_devices, - _In_ const OrtSessionOptions* session_options, - _In_ const OrtLogger* logger, _Outptr_ OrtEp** ep); - - /** \brief Release the OrtEp instance. - * - * \param[in] this_ptr The OrtEpFactory instance. - * \param[in] ep The OrtEp instance to release. - * - * \since Version [coming soon]. This is a placeholder. - */ - void(ORT_API_CALL* ReleaseEp)(OrtEpFactory* this_ptr, struct OrtEp* ep); -}; - -/* - * This is the old way to add the CUDA provider to the session, please use SessionOptionsAppendExecutionProvider_CUDA above to access the latest functionality - * This function always exists, but will only succeed if Onnxruntime was built with CUDA support and the CUDA provider shared library exists - * - * \param device_id CUDA device id, starts from zero. - */ -ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, int device_id); - -/* - * This is the old way to add the ROCm provider to the session, please use - * SessionOptionsAppendExecutionProvider_ROCM above to access the latest functionality - * This function always exists, but will only succeed if Onnxruntime was built with - * HIP support and the ROCm provider shared library exists - * - * \param device_id HIP device id, starts from zero. 
- */ -ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_ROCM, _In_ OrtSessionOptions* options, int device_id); - -/* - * This is the old way to add the MIGraphX provider to the session, please use - * SessionOptionsAppendExecutionProvider_MIGraphX above to access the latest functionality - * This function always exists, but will only succeed if Onnxruntime was built with - * HIP support and the MIGraphX provider shared library exists - * - * \param device_id HIP device id, starts from zero. - */ -ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtSessionOptions* options, int device_id); - -/* - * This is the old way to add the oneDNN provider to the session, please use - * SessionOptionsAppendExecutionProvider_oneDNN above to access the latest functionality - * This function always exists, but will only succeed if Onnxruntime was built with - * oneDNN support and the oneDNN provider shared library exists - * - * \param use_arena zero: false. non-zero: true. - */ -ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, int use_arena); - -/* - * This is the old way to add the TensorRT provider to the session, please use SessionOptionsAppendExecutionProvider_TensorRT_V2 above to access the latest functionality - * This function always exists, but will only succeed if Onnxruntime was built with TensorRT support and the TensorRT provider shared library exists - * - * \param device_id CUDA device id, starts from zero. 
- */ -ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id); - -#ifdef __cplusplus -} -#endif /// @} From 79f2cf9781441db45823eded46efcfa549c835f8 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Sat, 31 May 2025 11:17:31 -0700 Subject: [PATCH 21/33] Add .clang-format --- .clang-format | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 .clang-format diff --git a/.clang-format b/.clang-format new file mode 100644 index 0000000..efda3f8 --- /dev/null +++ b/.clang-format @@ -0,0 +1,24 @@ +--- +# Defaults for all languages. +BasedOnStyle: Google + +# Setting ColumnLimit to 0 so developer choices about where to break lines are maintained. +# Developers are responsible for adhering to the 120 character maximum. +ColumnLimit: 0 +SortIncludes: false +DerivePointerAlignment: false +# Avoid adding spaces between tokens in GSL_SUPPRESS arguments. +# E.g., don't change "GSL_SUPPRESS(r.11)" to "GSL_SUPPRESS(r .11)". +WhitespaceSensitiveMacros: ["GSL_SUPPRESS"] + +# if you want to customize when working locally see https://clang.llvm.org/docs/ClangFormatStyleOptions.html for options. +# See ReformatSource.ps1 for a script to update all source according to the current options in this file. +# e.g. customizations to use Allman bracing and more indenting. +# AccessModifierOffset: -2 +# BreakBeforeBraces: Allman +# CompactNamespaces: false +# IndentCaseLabels: true +# IndentWidth: 4 +# NamespaceIndentation: All + +... 
From b48e4ea8eac0bd49c622bdff31721b171f891752 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Sat, 31 May 2025 12:25:50 -0700 Subject: [PATCH 22/33] update --- .github/workflows/linux_ci.yml | 128 +- .github/workflows/macos_ci.yml | 20 +- .github/workflows/reusable_windows_build.yml | 83 + .github/workflows/win_ci.yml | 146 +- CMakePresets.json | 4817 ++++++++++-------- 5 files changed, 3017 insertions(+), 2177 deletions(-) create mode 100644 .github/workflows/reusable_windows_build.yml diff --git a/.github/workflows/linux_ci.yml b/.github/workflows/linux_ci.yml index 9289336..78c6b56 100644 --- a/.github/workflows/linux_ci.yml +++ b/.github/workflows/linux_ci.yml @@ -7,27 +7,27 @@ on: pull_request: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true - + jobs: Linux_arm64_gcc_release: runs-on: ubuntu-24.04-arm steps: - uses: actions/checkout@v4 - run: | - set -e -x - rm -rf build - cmake --workflow --preset linux_gcc_release_workflow + set -e -x + rm -rf build + cmake --workflow --preset linux_gcc_release_workflow Linux_x64_gcc_ubuntu24_release_no_ort: runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - run: | - set -e -x - rm -rf build - cmake --workflow --preset linux_gcc_release_no_ort_workflow + set -e -x + rm -rf build + cmake --workflow --preset linux_gcc_release_no_ort_workflow Linux_x64_gcc_ubuntu24_release: runs-on: ubuntu-24.04 @@ -43,16 +43,16 @@ jobs: config-file: ./.github/codeql/codeql-config.yml languages: 'cpp' - run: | - set -e -x - rm -rf build - cmake --workflow --preset linux_gcc_release_workflow + set -e -x + rm -rf build + cmake --workflow --preset linux_gcc_release_workflow - name: Perform CodeQL Analysis uses: github/codeql-action/analyze@v3 with: category: "/language:cpp" output: sarif-results upload: failure-only - + - name: filter-sarif uses: advanced-security/filter-sarif@v1 with: @@ -68,92 
+68,92 @@ jobs: uses: github/codeql-action/upload-sarif@v3 with: sarif_file: sarif-results/cpp.sarif - + Linux_x64_gcc_ubuntu22_release: runs-on: ubuntu-22.04 steps: - uses: actions/checkout@v4 - run: | - set -e -x - rm -rf build - cmake --workflow --preset linux_gcc_release_workflow + set -e -x + rm -rf build + cmake --workflow --preset linux_gcc_release_workflow Linux_x64_gcc_ubuntu24_debug: runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - run: | - set -e -x - rm -rf build - cmake --workflow --preset linux_gcc_debug_workflow + set -e -x + rm -rf build + cmake --workflow --preset linux_gcc_debug_workflow Linux_x64_clang_ubuntu24_debug: runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - run: | - set -e -x - rm -rf build - cmake --workflow --preset linux_clang_debug_workflow + set -e -x + rm -rf build + cmake --workflow --preset linux_clang_debug_workflow Linux_x64_gcc_ubuntu24_debug_asan: runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - run: | - set -e -x - rm -rf build - cmake --workflow --preset linux_gcc_debug_asan_workflow + set -e -x + rm -rf build + cmake --workflow --preset linux_gcc_debug_asan_workflow Linux_wasm_debug_asan: runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - run: | - set -e -x - rm -rf build - mkdir -p build - cd build - git clone https://github.com/emscripten-core/emsdk.git - cd emsdk - ./emsdk install latest - ./emsdk activate latest - source emsdk_env.sh - cd .. - CFLAGS="-O0 -g -fsanitize=address" CXXFLAGS="-O0 -g -fsanitize=address" emcmake cmake .. -DCMAKE_BUILD_TYPE=Debug -DMLAS_ENABLE_WEBASSEMBLY_THREADS=ON - make -j $(nproc) all - ctest --output-on-failure + set -e -x + rm -rf build + mkdir -p build + cd build + git clone https://github.com/emscripten-core/emsdk.git + cd emsdk + ./emsdk install latest + ./emsdk activate latest + source emsdk_env.sh + cd .. + CFLAGS="-O0 -g -fsanitize=address" CXXFLAGS="-O0 -g -fsanitize=address" emcmake cmake .. 
-DCMAKE_BUILD_TYPE=Debug -DMLAS_ENABLE_WEBASSEMBLY_THREADS=ON + make -j $(nproc) all + ctest --output-on-failure Linux_wasm_release: runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - run: | - set -e -x - rm -rf build - mkdir -p build - cd build - git clone https://github.com/emscripten-core/emsdk.git - cd emsdk - ./emsdk install latest - ./emsdk activate latest - source emsdk_env.sh - cd .. - CFLAGS="-O2 -DNDEBUG -g" CXXFLAGS="-O2 -DNDEBUG -g" emcmake cmake .. -DCMAKE_BUILD_TYPE=Release -DMLAS_ENABLE_WEBASSEMBLY_THREADS=ON - make -j $(nproc) all - ctest --output-on-failure + set -e -x + rm -rf build + mkdir -p build + cd build + git clone https://github.com/emscripten-core/emsdk.git + cd emsdk + ./emsdk install latest + ./emsdk activate latest + source emsdk_env.sh + cd .. + CFLAGS="-O2 -DNDEBUG -g" CXXFLAGS="-O2 -DNDEBUG -g" emcmake cmake .. -DCMAKE_BUILD_TYPE=Release -DMLAS_ENABLE_WEBASSEMBLY_THREADS=ON + make -j $(nproc) all + ctest --output-on-failure Linux_wasm_release_no_exception: runs-on: ubuntu-24.04 steps: - uses: actions/checkout@v4 - run: | - set -e -x - rm -rf build - mkdir -p build - cd build - git clone https://github.com/emscripten-core/emsdk.git - cd emsdk - ./emsdk install latest - ./emsdk activate latest - source emsdk_env.sh - cd .. - CFLAGS="-O2 -DNDEBUG -g" CXXFLAGS="-O2 -DNDEBUG -g" emcmake cmake .. -DCMAKE_BUILD_TYPE=Release -DMLAS_ENABLE_WEBASSEMBLY_THREADS=ON -DMLAS_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING=ON - make -j $(nproc) all - ctest --output-on-failure + set -e -x + rm -rf build + mkdir -p build + cd build + git clone https://github.com/emscripten-core/emsdk.git + cd emsdk + ./emsdk install latest + ./emsdk activate latest + source emsdk_env.sh + cd .. + CFLAGS="-O2 -DNDEBUG -g" CXXFLAGS="-O2 -DNDEBUG -g" emcmake cmake .. 
-DCMAKE_BUILD_TYPE=Release -DMLAS_ENABLE_WEBASSEMBLY_THREADS=ON -DMLAS_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING=ON + make -j $(nproc) all + ctest --output-on-failure diff --git a/.github/workflows/macos_ci.yml b/.github/workflows/macos_ci.yml index de92f12..d7cbed2 100644 --- a/.github/workflows/macos_ci.yml +++ b/.github/workflows/macos_ci.yml @@ -7,9 +7,9 @@ on: pull_request: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true - + jobs: # The following one doesn't work on macos-12. It has some compiling errors related to std::date MacOS14_arm64_release: @@ -17,15 +17,15 @@ jobs: steps: - uses: actions/checkout@v4 - run: | - set -e -x - rm -rf build - cmake --workflow --preset macos_arm64_release_workflow - + set -e -x + rm -rf build + cmake --workflow --preset macos_arm64_release_workflow + MacOS14_universal2_release: runs-on: macos-14 steps: - uses: actions/checkout@v4 - - run: | - set -e -x - rm -rf build - cmake --workflow --preset macos_universal2_release_workflow \ No newline at end of file + - run: |- + set -e -x + rm -rf build + cmake --workflow --preset macos_universal2_release_workflow diff --git a/.github/workflows/reusable_windows_build.yml b/.github/workflows/reusable_windows_build.yml new file mode 100644 index 0000000..b3275ae --- /dev/null +++ b/.github/workflows/reusable_windows_build.yml @@ -0,0 +1,83 @@ +name: Reusable Windows Build + +on: + workflow_call: + inputs: + job-name: # For display purposes in the reusable workflow logs + required: true + type: string + cmake-workflow-preset: + required: true + type: string + enable-codeql: + required: false + type: boolean + default: false + codeql-config-file: + required: false + type: string + default: ./.github/codeql/codeql-config.yml + codeql-sarif-output-dir: + required: false + type: string + default: sarif-results + +permissions: + actions: read + contents: 
read + # security-events: write is needed only if CodeQL is enabled and uploads SARIF + # We'll set it at the job level within this reusable workflow for clarity. + +jobs: + build_and_optional_analyze: + name: ${{ inputs.job-name }} + runs-on: windows-2022 + # Define permissions here based on whether CodeQL might run. + # If enable-codeql can be true, then security-events: write is needed. + permissions: + actions: read + contents: read + security-events: write # Needed if CodeQL analysis runs & uploads + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Initialize CodeQL (if enabled) + if: ${{ inputs.enable-codeql }} + uses: github/codeql-action/init@v3 + with: + config-file: ${{ inputs.codeql-config-file }} + languages: 'cpp' + + - name: Run CMake Workflow + run: | + cmake --workflow --preset ${{ inputs.cmake-workflow-preset }} + shell: cmd # Ensuring shell is explicit for windows runners if needed + + - name: Perform CodeQL Analysis (if enabled) + if: ${{ inputs.enable-codeql }} + uses: github/codeql-action/analyze@v3 + with: + category: "/language:cpp" # Category for the analysis + output: ${{ inputs.codeql-sarif-output-dir }} # Directory for SARIF files + upload: failure-only # Upload SARIF results only on failure + + - name: Filter SARIF (if CodeQL enabled) + if: ${{ inputs.enable-codeql }} + uses: advanced-security/filter-sarif@v1 + with: + patterns: | + +**/*.cc + +**/*.h + -tests/**/*.* + -build/**/*.* + input: ${{ inputs.codeql-sarif-output-dir }}/cpp.sarif + output: ${{ inputs.codeql-sarif-output-dir }}/cpp.sarif.filtered # Output to a new file + + - name: Upload filtered SARIF (if CodeQL enabled) + if: ${{ inputs.enable-codeql }} + uses: github/codeql-action/upload-sarif@v3 + with: + sarif_file: ${{ inputs.codeql-sarif-output-dir }}/cpp.sarif.filtered + category: cpp-${{ inputs.job-name }} # Make category unique if needed \ No newline at end of file diff --git a/.github/workflows/win_ci.yml b/.github/workflows/win_ci.yml 
index 86815af..e048e2e 100644 --- a/.github/workflows/win_ci.yml +++ b/.github/workflows/win_ci.yml @@ -1,4 +1,5 @@ name: Windows_CI + on: push: branches: @@ -7,103 +8,66 @@ on: pull_request: concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} cancel-in-progress: true - -jobs: - Win32_debug_no_ort: - runs-on: windows-2022 - permissions: - actions: read - contents: read - security-events: write - steps: - - uses: actions/checkout@v4 - - name: Initialize CodeQL - uses: github/codeql-action/init@v3 - with: - config-file: ./.github/codeql/codeql-config.yml - languages: 'cpp' - - run: | - cmake --workflow --preset windows_win32_debug_no_ort_workflow - - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v3 - with: - category: "/language:cpp" - output: sarif-results - upload: failure-only - - - name: filter-sarif - uses: advanced-security/filter-sarif@v1 - with: - patterns: | - +**/*.cc - +**/*.h - -tests/**/*.* - -build/**/*.* - input: sarif-results/cpp.sarif - output: sarif-results/cpp.sarif - - name: Upload SARIF - uses: github/codeql-action/upload-sarif@v3 - with: - sarif_file: sarif-results/cpp.sarif +jobs: + # Win32 Jobs + Win32_debug_no_ort: + uses: ./.github/workflows/reusable_windows_build.yml + with: + job-name: Win32_Debug_NoOrt_CodeQL + cmake-workflow-preset: windows_win32_debug_no_ort_workflow + enable-codeql: true + # codeql-config-file can be omitted if default is fine Win32_release_no_ort: - runs-on: windows-2022 - steps: - - uses: actions/checkout@v4 - - run: | - cmake --workflow --preset windows_win32_release_no_ort_workflow + uses: ./.github/workflows/reusable_windows_build.yml + with: + job-name: Win32_Release_NoOrt + cmake-workflow-preset: windows_win32_release_no_ort_workflow + enable-codeql: false + # WinX64 Jobs WinX64_debug_no_ort: - runs-on: windows-2022 - permissions: - actions: read - contents: read - 
security-events: write - steps: - - uses: actions/checkout@v4 - - name: Initialize CodeQL - uses: github/codeql-action/init@v3 - with: - config-file: ./.github/codeql/codeql-config.yml - languages: 'cpp' - - run: | - cmake --workflow --preset windows_x64_debug_no_ort_workflow - - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@v3 - with: - category: "/language:cpp" - output: sarif-results - upload: failure-only - - - name: filter-sarif - uses: advanced-security/filter-sarif@v1 - with: - patterns: | - +**/*.cc - +**/*.h - -tests/**/*.* - -build/**/*.* - input: sarif-results/cpp.sarif - output: sarif-results/cpp.sarif - - - name: Upload SARIF - uses: github/codeql-action/upload-sarif@v3 - with: - sarif_file: sarif-results/cpp.sarif + uses: ./.github/workflows/reusable_windows_build.yml + with: + job-name: Winx64_Debug_NoOrt_CodeQL + cmake-workflow-preset: windows_x64_debug_no_ort_workflow + enable-codeql: true WinX64_release_no_ort: - runs-on: windows-2022 - steps: - - uses: actions/checkout@v4 - - run: | - cmake --workflow --preset windows_x64_release_no_ort_workflow - + uses: ./.github/workflows/reusable_windows_build.yml + with: + job-name: Winx64_Release_NoOrt + cmake-workflow-preset: windows_x64_release_no_ort_workflow + enable-codeql: false + WinX64_release: - runs-on: windows-2022 - steps: - - uses: actions/checkout@v4 - - run: | - cmake --workflow --preset windows_x64_release_workflow \ No newline at end of file + uses: ./.github/workflows/reusable_windows_build.yml + with: + job-name: Winx64_Release + cmake-workflow-preset: windows_x64_release_workflow + enable-codeql: false + + # Windows ARM64 Jobs (New) + WinARM64_debug_no_ort: + uses: ./.github/workflows/reusable_windows_build.yml + with: + job-name: WinARM64_Debug_NoOrt_CodeQL + cmake-workflow-preset: windows_arm64_debug_no_ort_workflow # Ensure this preset exists + enable-codeql: true + + WinARM64_release_no_ort: + uses: ./.github/workflows/reusable_windows_build.yml + with: + 
job-name: WinARM64_Release_NoOrt + cmake-workflow-preset: windows_arm64_release_no_ort_workflow # Ensure this preset exists + enable-codeql: false + + WinARM64_release: + uses: ./.github/workflows/reusable_windows_build.yml + with: + job-name: WinARM64_Release + cmake-workflow-preset: windows_arm64_release_workflow # Ensure this preset exists + enable-codeql: false diff --git a/CMakePresets.json b/CMakePresets.json index 261050b..d3aff74 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -1,5 +1,477 @@ { - "version": 8, + "$schema": "https://cmake.org/cmake/help/latest/_downloads/3e2d73bff478d88a7de0de736ba5e361/schema.json", + "buildPresets": [ + { + "configuration": "Debug", + "configurePreset": "linux_clang_debug", + "name": "linux_clang_debug" + }, + { + "configuration": "Debug", + "configurePreset": "linux_clang_debug_asan", + "name": "linux_clang_debug_asan" + }, + { + "configuration": "Debug", + "configurePreset": "linux_clang_debug_asan_no_ort", + "name": "linux_clang_debug_asan_no_ort" + }, + { + "configuration": "Debug", + "configurePreset": "linux_clang_debug_cov", + "name": "linux_clang_debug_cov" + }, + { + "configuration": "Debug", + "configurePreset": "linux_clang_debug_cov_no_ort", + "name": "linux_clang_debug_cov_no_ort" + }, + { + "configuration": "Debug", + "configurePreset": "linux_clang_debug_no_ort", + "name": "linux_clang_debug_no_ort" + }, + { + "configuration": "Debug", + "configurePreset": "linux_gcc_debug", + "name": "linux_gcc_debug" + }, + { + "configuration": "Debug", + "configurePreset": "linux_gcc_debug_asan", + "name": "linux_gcc_debug_asan" + }, + { + "configuration": "Debug", + "configurePreset": "linux_gcc_debug_asan_no_ort", + "name": "linux_gcc_debug_asan_no_ort" + }, + { + "configuration": "Debug", + "configurePreset": "linux_gcc_debug_no_ort", + "name": "linux_gcc_debug_no_ort" + }, + { + "configuration": "MinSizeRel", + "configurePreset": "linux_gcc_minsizerel", + "name": "linux_gcc_minsizerel" + }, + { + 
"configuration": "MinSizeRel", + "configurePreset": "linux_gcc_minsizerel_asan", + "name": "linux_gcc_minsizerel_asan" + }, + { + "configuration": "MinSizeRel", + "configurePreset": "linux_gcc_minsizerel_asan_no_ort", + "name": "linux_gcc_minsizerel_asan_no_ort" + }, + { + "configuration": "MinSizeRel", + "configurePreset": "linux_gcc_minsizerel_no_ort", + "name": "linux_gcc_minsizerel_no_ort" + }, + { + "configuration": "Release", + "configurePreset": "linux_gcc_release", + "name": "linux_gcc_release" + }, + { + "configuration": "Release", + "configurePreset": "linux_gcc_release_asan", + "name": "linux_gcc_release_asan" + }, + { + "configuration": "Release", + "configurePreset": "linux_gcc_release_asan_no_ort", + "name": "linux_gcc_release_asan_no_ort" + }, + { + "configuration": "Release", + "configurePreset": "linux_gcc_release_no_ort", + "name": "linux_gcc_release_no_ort" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "linux_gcc_relwithdebinfo", + "name": "linux_gcc_relwithdebinfo" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "linux_gcc_relwithdebinfo_asan", + "name": "linux_gcc_relwithdebinfo_asan" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "linux_gcc_relwithdebinfo_asan_no_ort", + "name": "linux_gcc_relwithdebinfo_asan_no_ort" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "linux_gcc_relwithdebinfo_no_ort", + "name": "linux_gcc_relwithdebinfo_no_ort" + }, + { + "configuration": "Debug", + "configurePreset": "macos_arm64_debug", + "name": "macos_arm64_debug" + }, + { + "configuration": "Debug", + "configurePreset": "macos_arm64_debug_asan", + "name": "macos_arm64_debug_asan" + }, + { + "configuration": "MinSizeRel", + "configurePreset": "macos_arm64_minsizerel", + "name": "macos_arm64_minsizerel" + }, + { + "configuration": "MinSizeRel", + "configurePreset": "macos_arm64_minsizerel_asan", + "name": "macos_arm64_minsizerel_asan" + }, + { + "configuration": "Release", + 
"configurePreset": "macos_arm64_release", + "name": "macos_arm64_release" + }, + { + "configuration": "Release", + "configurePreset": "macos_arm64_release_asan", + "name": "macos_arm64_release_asan" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "macos_arm64_relwithdebinfo", + "name": "macos_arm64_relwithdebinfo" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "macos_arm64_relwithdebinfo_asan", + "name": "macos_arm64_relwithdebinfo_asan" + }, + { + "configuration": "Debug", + "configurePreset": "macos_universal2_debug", + "name": "macos_universal2_debug" + }, + { + "configuration": "Debug", + "configurePreset": "macos_universal2_debug_asan", + "name": "macos_universal2_debug_asan" + }, + { + "configuration": "MinSizeRel", + "configurePreset": "macos_universal2_minsizerel", + "name": "macos_universal2_minsizerel" + }, + { + "configuration": "MinSizeRel", + "configurePreset": "macos_universal2_minsizerel_asan", + "name": "macos_universal2_minsizerel_asan" + }, + { + "configuration": "Release", + "configurePreset": "macos_universal2_release", + "name": "macos_universal2_release" + }, + { + "configuration": "Release", + "configurePreset": "macos_universal2_release_asan", + "name": "macos_universal2_release_asan" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "macos_universal2_relwithdebinfo", + "name": "macos_universal2_relwithdebinfo" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "macos_universal2_relwithdebinfo_asan", + "name": "macos_universal2_relwithdebinfo_asan" + }, + { + "configuration": "Debug", + "configurePreset": "macos_x86_64_debug", + "name": "macos_x86_64_debug" + }, + { + "configuration": "Debug", + "configurePreset": "macos_x86_64_debug_asan", + "name": "macos_x86_64_debug_asan" + }, + { + "configuration": "MinSizeRel", + "configurePreset": "macos_x86_64_minsizerel", + "name": "macos_x86_64_minsizerel" + }, + { + "configuration": "MinSizeRel", + "configurePreset": 
"macos_x86_64_minsizerel_asan", + "name": "macos_x86_64_minsizerel_asan" + }, + { + "configuration": "Release", + "configurePreset": "macos_x86_64_release", + "name": "macos_x86_64_release" + }, + { + "configuration": "Release", + "configurePreset": "macos_x86_64_release_asan", + "name": "macos_x86_64_release_asan" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "macos_x86_64_relwithdebinfo", + "name": "macos_x86_64_relwithdebinfo" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "macos_x86_64_relwithdebinfo_asan", + "name": "macos_x86_64_relwithdebinfo_asan" + }, + { + "configuration": "Debug", + "configurePreset": "windows_arm64_debug", + "name": "windows_arm64_debug" + }, + { + "configuration": "Debug", + "configurePreset": "windows_arm64_debug_asan", + "name": "windows_arm64_debug_asan" + }, + { + "configuration": "Debug", + "configurePreset": "windows_arm64_debug_asan_no_ort", + "name": "windows_arm64_debug_asan_no_ort" + }, + { + "configuration": "Debug", + "configurePreset": "windows_arm64_debug_no_ort", + "name": "windows_arm64_debug_no_ort" + }, + { + "configuration": "MinSizeRel", + "configurePreset": "windows_arm64_minsizerel", + "name": "windows_arm64_minsizerel" + }, + { + "configuration": "MinSizeRel", + "configurePreset": "windows_arm64_minsizerel_asan", + "name": "windows_arm64_minsizerel_asan" + }, + { + "configuration": "MinSizeRel", + "configurePreset": "windows_arm64_minsizerel_asan_no_ort", + "name": "windows_arm64_minsizerel_asan_no_ort" + }, + { + "configuration": "MinSizeRel", + "configurePreset": "windows_arm64_minsizerel_no_ort", + "name": "windows_arm64_minsizerel_no_ort" + }, + { + "configuration": "Release", + "configurePreset": "windows_arm64_release", + "name": "windows_arm64_release" + }, + { + "configuration": "Release", + "configurePreset": "windows_arm64_release_asan", + "name": "windows_arm64_release_asan" + }, + { + "configuration": "Release", + "configurePreset": 
"windows_arm64_release_asan_no_ort", + "name": "windows_arm64_release_asan_no_ort" + }, + { + "configuration": "Release", + "configurePreset": "windows_arm64_release_no_ort", + "name": "windows_arm64_release_no_ort" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "windows_arm64_relwithdebinfo", + "name": "windows_arm64_relwithdebinfo" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "windows_arm64_relwithdebinfo_asan", + "name": "windows_arm64_relwithdebinfo_asan" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "windows_arm64_relwithdebinfo_asan_no_ort", + "name": "windows_arm64_relwithdebinfo_asan_no_ort" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "windows_arm64_relwithdebinfo_no_ort", + "name": "windows_arm64_relwithdebinfo_no_ort" + }, + { + "configuration": "Debug", + "configurePreset": "windows_win32_debug", + "name": "windows_win32_debug" + }, + { + "configuration": "Debug", + "configurePreset": "windows_win32_debug_asan", + "name": "windows_win32_debug_asan" + }, + { + "configuration": "Debug", + "configurePreset": "windows_win32_debug_asan_no_ort", + "name": "windows_win32_debug_asan_no_ort" + }, + { + "configuration": "Debug", + "configurePreset": "windows_win32_debug_no_ort", + "name": "windows_win32_debug_no_ort" + }, + { + "configuration": "MinSizeRel", + "configurePreset": "windows_win32_minsizerel", + "name": "windows_win32_minsizerel" + }, + { + "configuration": "MinSizeRel", + "configurePreset": "windows_win32_minsizerel_asan", + "name": "windows_win32_minsizerel_asan" + }, + { + "configuration": "MinSizeRel", + "configurePreset": "windows_win32_minsizerel_asan_no_ort", + "name": "windows_win32_minsizerel_asan_no_ort" + }, + { + "configuration": "MinSizeRel", + "configurePreset": "windows_win32_minsizerel_no_ort", + "name": "windows_win32_minsizerel_no_ort" + }, + { + "configuration": "Release", + "configurePreset": "windows_win32_release", + "name": 
"windows_win32_release" + }, + { + "configuration": "Release", + "configurePreset": "windows_win32_release_asan", + "name": "windows_win32_release_asan" + }, + { + "configuration": "Release", + "configurePreset": "windows_win32_release_asan_no_ort", + "name": "windows_win32_release_asan_no_ort" + }, + { + "configuration": "Release", + "configurePreset": "windows_win32_release_no_ort", + "name": "windows_win32_release_no_ort" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "windows_win32_relwithdebinfo", + "name": "windows_win32_relwithdebinfo" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "windows_win32_relwithdebinfo_asan", + "name": "windows_win32_relwithdebinfo_asan" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "windows_win32_relwithdebinfo_asan_no_ort", + "name": "windows_win32_relwithdebinfo_asan_no_ort" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "windows_win32_relwithdebinfo_no_ort", + "name": "windows_win32_relwithdebinfo_no_ort" + }, + { + "configuration": "Debug", + "configurePreset": "windows_x64_debug", + "name": "windows_x64_debug" + }, + { + "configuration": "Debug", + "configurePreset": "windows_x64_debug_asan", + "name": "windows_x64_debug_asan" + }, + { + "configuration": "Debug", + "configurePreset": "windows_x64_debug_asan_no_ort", + "name": "windows_x64_debug_asan_no_ort" + }, + { + "configuration": "Debug", + "configurePreset": "windows_x64_debug_no_ort", + "name": "windows_x64_debug_no_ort" + }, + { + "configuration": "MinSizeRel", + "configurePreset": "windows_x64_minsizerel", + "name": "windows_x64_minsizerel" + }, + { + "configuration": "MinSizeRel", + "configurePreset": "windows_x64_minsizerel_asan", + "name": "windows_x64_minsizerel_asan" + }, + { + "configuration": "MinSizeRel", + "configurePreset": "windows_x64_minsizerel_asan_no_ort", + "name": "windows_x64_minsizerel_asan_no_ort" + }, + { + "configuration": "MinSizeRel", + "configurePreset": 
"windows_x64_minsizerel_no_ort", + "name": "windows_x64_minsizerel_no_ort" + }, + { + "configuration": "Release", + "configurePreset": "windows_x64_release", + "name": "windows_x64_release" + }, + { + "configuration": "Release", + "configurePreset": "windows_x64_release_asan", + "name": "windows_x64_release_asan" + }, + { + "configuration": "Release", + "configurePreset": "windows_x64_release_asan_no_ort", + "name": "windows_x64_release_asan_no_ort" + }, + { + "configuration": "Release", + "configurePreset": "windows_x64_release_no_ort", + "name": "windows_x64_release_no_ort" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "windows_x64_relwithdebinfo", + "name": "windows_x64_relwithdebinfo" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "windows_x64_relwithdebinfo_asan", + "name": "windows_x64_relwithdebinfo_asan" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "windows_x64_relwithdebinfo_asan_no_ort", + "name": "windows_x64_relwithdebinfo_asan_no_ort" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "windows_x64_relwithdebinfo_no_ort", + "name": "windows_x64_relwithdebinfo_no_ort" + } + ], "cmakeMinimumRequired": { "major": 3, "minor": 28, @@ -7,2532 +479,2581 @@ }, "configurePresets": [ { - "name": "linux_clang_debug", - "displayName": "linux clang debug", - "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/default/default", - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Linux" - }, "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_C_FLAGS": "-ggdb3 -O0", "CMAKE_CXX_FLAGS": "-ggdb3 -O0", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-ggdb3 -O0", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now 
-Wl,-z,noexecstack", - "CMAKE_CXX_STANDARD": "20" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack" }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Linux", + "type": "equals" + }, + "displayName": "linux clang debug", "environment": { "CC": "clang", "CXX": "clang++" - } + }, + "generator": "Unix Makefiles", + "name": "linux_clang_debug" }, { - "name": "linux_clang_debug_asan", - "displayName": "linux clang debug asan", - "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/asan/default", - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Linux" - }, "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fsanitize=address", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_CXX_STANDARD": "20" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address" }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Linux", + "type": "equals" + }, + "displayName": "linux clang debug asan", "environment": { "CC": "clang", "CXX": "clang++" - } + }, + "generator": "Unix Makefiles", + "name": "linux_clang_debug_asan" }, { - "name": "linux_clang_debug_asan_no_ort", - "displayName": "linux clang debug asan no ort", - "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/asan/no_ort", - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Linux" - }, "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug", - 
"CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fsanitize=address", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_CXX_STANDARD": "20", "MLAS_NO_ONNXRUNTIME": "ON" }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Linux", + "type": "equals" + }, + "displayName": "linux clang debug asan no_ort", "environment": { "CC": "clang", "CXX": "clang++" - } + }, + "generator": "Unix Makefiles", + "name": "linux_clang_debug_asan_no_ort" }, { - "name": "linux_clang_debug_cov", - "displayName": "linux clang debug cov", - "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/cov/default", - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Linux" - }, "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_C_FLAGS": "-ggdb3 -O0 -fprofile-instr-generate -fcoverage-mapping", "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fprofile-instr-generate -fcoverage-mapping", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-ggdb3 -O0 -fprofile-instr-generate -fcoverage-mapping", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_CXX_STANDARD": "20" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Linux", + "type": "equals" }, + 
"displayName": "linux clang debug cov", "environment": { "CC": "clang", "CXX": "clang++" - } + }, + "generator": "Unix Makefiles", + "name": "linux_clang_debug_cov" }, { - "name": "linux_clang_debug_cov_no_ort", - "displayName": "linux clang debug cov no ort", - "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/cov/no_ort", - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Linux" - }, "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_C_FLAGS": "-ggdb3 -O0 -fprofile-instr-generate -fcoverage-mapping", "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fprofile-instr-generate -fcoverage-mapping", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-ggdb3 -O0 -fprofile-instr-generate -fcoverage-mapping", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_CXX_STANDARD": "20", "MLAS_NO_ONNXRUNTIME": "ON" }, - "environment": { - "CC": "clang", - "CXX": "clang++" - } - }, - { - "name": "linux_clang_debug_no_ort", - "displayName": "linux clang debug no ort", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default/no_ort", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Linux" - }, - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_C_FLAGS": "-ggdb3 -O0", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_CXX_STANDARD": "20", - "MLAS_NO_ONNXRUNTIME": "ON" + "rhs": "Linux", + "type": "equals" }, + "displayName": "linux 
clang debug cov no_ort", "environment": { "CC": "clang", "CXX": "clang++" - } - }, - { - "name": "linux_gcc_debug", - "displayName": "linux gcc debug", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default/default", - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Linux" - }, - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_C_FLAGS": "-ggdb3 -O0", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -D_GLIBCXX_DEBUG", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_CXX_STANDARD": "20" - }, - "environment": { - "CC": "gcc", - "CXX": "g++" - } - }, - { - "name": "linux_gcc_debug_asan", - "displayName": "linux gcc debug asan", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/asan/default", - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Linux" - }, - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fsanitize=address -D_GLIBCXX_DEBUG", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_CXX_STANDARD": "20" }, - "environment": { - "CC": "gcc", - "CXX": "g++" - } - }, - { - "name": "linux_gcc_debug_asan_no_ort", - "displayName": "linux gcc debug asan no ort", "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/asan/no_ort", - "condition": { - "type": "equals", - "lhs": 
"${hostSystemName}", - "rhs": "Linux" - }, - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fsanitize=address -D_GLIBCXX_DEBUG", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_CXX_STANDARD": "20", - "MLAS_NO_ONNXRUNTIME": "ON" - }, - "environment": { - "CC": "gcc", - "CXX": "g++" - } + "name": "linux_clang_debug_cov_no_ort" }, { - "name": "linux_gcc_debug_no_ort", - "displayName": "linux gcc debug no ort", - "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/default/no_ort", - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Linux" - }, "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0", + "CMAKE_CXX_STANDARD": "20", "CMAKE_C_FLAGS": "-ggdb3 -O0", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -D_GLIBCXX_DEBUG", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_CXX_STANDARD": "20", "MLAS_NO_ONNXRUNTIME": "ON" }, - "environment": { - "CC": "gcc", - "CXX": "g++" - } - }, - { - "name": "linux_gcc_minsizerel", - "displayName": "linux gcc minsizerel", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default/default", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Linux" - }, - "cacheVariables": { - "CMAKE_BUILD_TYPE": "MinSizeRel", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS 
-fstack-protector-strong -Os -pipe -ggdb3", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_CXX_STANDARD": "20" + "rhs": "Linux", + "type": "equals" }, + "displayName": "linux clang debug no_ort", "environment": { - "CC": "gcc", - "CXX": "g++" - } - }, - { - "name": "linux_gcc_minsizerel_asan", - "displayName": "linux gcc minsizerel asan", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/asan/default", - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Linux" - }, - "cacheVariables": { - "CMAKE_BUILD_TYPE": "MinSizeRel", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_CXX_STANDARD": "20" + "CC": "clang", + "CXX": "clang++" }, - "environment": { - "CC": "gcc", - "CXX": "g++" - } + "generator": "Unix Makefiles", + "name": "linux_clang_debug_no_ort" }, { - "name": "linux_gcc_minsizerel_asan_no_ort", - "displayName": "linux gcc minsizerel asan no ort", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/asan/no_ort", - "condition": { - "type": "equals", - "lhs": 
"${hostSystemName}", - "rhs": "Linux" - }, + "binaryDir": "${sourceDir}/build/default/default", "cacheVariables": { - "CMAKE_BUILD_TYPE": "MinSizeRel", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -D_GLIBCXX_DEBUG", "CMAKE_CXX_STANDARD": "20", - "MLAS_NO_ONNXRUNTIME": "ON" + "CMAKE_C_FLAGS": "-ggdb3 -O0", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack" }, - "environment": { - "CC": "gcc", - "CXX": "g++" - } - }, - { - "name": "linux_gcc_minsizerel_no_ort", - "displayName": "linux gcc minsizerel no ort", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default/no_ort", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Linux" - }, - "cacheVariables": { - "CMAKE_BUILD_TYPE": "MinSizeRel", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_MODULE_LINKER_FLAGS_INIT": 
"-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_CXX_STANDARD": "20", - "MLAS_NO_ONNXRUNTIME": "ON" + "rhs": "Linux", + "type": "equals" }, + "displayName": "linux gcc debug", "environment": { "CC": "gcc", "CXX": "g++" - } - }, - { - "name": "linux_gcc_release", - "displayName": "linux gcc release", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default/default", - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Linux" }, + "generator": "Unix Makefiles", + "name": "linux_gcc_debug" + }, + { + "binaryDir": "${sourceDir}/build/default/no_ort", "cacheVariables": { - "CMAKE_BUILD_TYPE": "Release", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -D_GLIBCXX_DEBUG", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-ggdb3 -O0", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_CXX_STANDARD": "20" + "MLAS_NO_ONNXRUNTIME": "ON" }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Linux", + "type": "equals" + }, + "displayName": "linux gcc debug no_ort", "environment": { "CC": "gcc", "CXX": "g++" - } + }, + "generator": "Unix Makefiles", + "name": "linux_gcc_debug_no_ort" }, { - "name": "linux_gcc_release_asan", - "displayName": "linux gcc release asan", - "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/asan/default", - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Linux" - }, 
"cacheVariables": { - "CMAKE_BUILD_TYPE": "Release", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -D_GLIBCXX_DEBUG -fsanitize=address", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_CXX_STANDARD": "20" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Linux", + "type": "equals" }, + "displayName": "linux gcc debug asan", "environment": { "CC": "gcc", "CXX": "g++" - } + }, + "generator": "Unix Makefiles", + "name": "linux_gcc_debug_asan" }, { - "name": "linux_gcc_release_asan_no_ort", - "displayName": "linux gcc release asan no ort", - "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/asan/no_ort", - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Linux" - }, "cacheVariables": { - "CMAKE_BUILD_TYPE": "Release", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -D_GLIBCXX_DEBUG -fsanitize=address", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": 
"-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_CXX_STANDARD": "20", "MLAS_NO_ONNXRUNTIME": "ON" }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Linux", + "type": "equals" + }, + "displayName": "linux gcc debug asan no_ort", "environment": { "CC": "gcc", "CXX": "g++" - } + }, + "generator": "Unix Makefiles", + "name": "linux_gcc_debug_asan_no_ort" }, { - "name": "linux_gcc_release_no_ort", - "displayName": "linux gcc release no ort", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default/no_ort", - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Linux" - }, + "binaryDir": "${sourceDir}/build/default/default", "cacheVariables": { - "CMAKE_BUILD_TYPE": "Release", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", + "CMAKE_BUILD_TYPE": "MinSizeRel", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_CXX_STANDARD": "20", - "MLAS_NO_ONNXRUNTIME": "ON" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack" + }, + "condition": { + "lhs": 
"${hostSystemName}", + "rhs": "Linux", + "type": "equals" }, + "displayName": "linux gcc minsizerel", "environment": { "CC": "gcc", "CXX": "g++" - } + }, + "generator": "Unix Makefiles", + "name": "linux_gcc_minsizerel" }, { - "name": "linux_gcc_relwithdebinfo", - "displayName": "linux gcc relwithdebinfo", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default/default", - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Linux" - }, + "binaryDir": "${sourceDir}/build/default/no_ort", "cacheVariables": { - "CMAKE_BUILD_TYPE": "RelWithDebInfo", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", + "CMAKE_BUILD_TYPE": "MinSizeRel", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_CXX_STANDARD": "20" + "MLAS_NO_ONNXRUNTIME": "ON" }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Linux", + "type": "equals" + }, + "displayName": "linux gcc minsizerel no_ort", "environment": { "CC": "gcc", "CXX": "g++" - } + }, + "generator": "Unix Makefiles", + "name": "linux_gcc_minsizerel_no_ort" }, { - "name": "linux_gcc_relwithdebinfo_asan", - "displayName": "linux gcc relwithdebinfo asan", - "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/asan/default", - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Linux" - }, "cacheVariables": { - 
"CMAKE_BUILD_TYPE": "RelWithDebInfo", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", + "CMAKE_BUILD_TYPE": "MinSizeRel", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_CXX_STANDARD": "20" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Linux", + "type": "equals" }, + "displayName": "linux gcc minsizerel asan", "environment": { "CC": "gcc", "CXX": "g++" - } + }, + "generator": "Unix Makefiles", + "name": "linux_gcc_minsizerel_asan" }, { - "name": "linux_gcc_relwithdebinfo_asan_no_ort", - "displayName": "linux gcc relwithdebinfo asan no ort", - "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/asan/no_ort", - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Linux" - }, "cacheVariables": { - "CMAKE_BUILD_TYPE": "RelWithDebInfo", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", + "CMAKE_BUILD_TYPE": "MinSizeRel", + 
"CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_CXX_STANDARD": "20", "MLAS_NO_ONNXRUNTIME": "ON" }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Linux", + "type": "equals" + }, + "displayName": "linux gcc minsizerel asan no_ort", "environment": { "CC": "gcc", "CXX": "g++" - } + }, + "generator": "Unix Makefiles", + "name": "linux_gcc_minsizerel_asan_no_ort" }, { - "name": "linux_gcc_relwithdebinfo_no_ort", - "displayName": "linux gcc relwithdebinfo no ort", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default/no_ort", + "binaryDir": "${sourceDir}/build/default/default", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack" + }, "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Linux" + "rhs": "Linux", + "type": "equals" + }, + "displayName": "linux gcc release", + "environment": { + "CC": "gcc", + 
"CXX": "g++" }, + "generator": "Unix Makefiles", + "name": "linux_gcc_release" + }, + { + "binaryDir": "${sourceDir}/build/default/no_ort", "cacheVariables": { - "CMAKE_BUILD_TYPE": "RelWithDebInfo", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_CXX_STANDARD": "20", "MLAS_NO_ONNXRUNTIME": "ON" }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Linux", + "type": "equals" + }, + "displayName": "linux gcc release no_ort", "environment": { "CC": "gcc", "CXX": "g++" - } - }, - { - "name": "macos_arm64_debug", - "displayName": "macos arm64 debug", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default", - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Darwin" }, - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_OSX_ARCHITECTURES": "arm64", - "CMAKE_C_FLAGS": "-ggdb3 -O0", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0", - "CMAKE_CXX_STANDARD": "20" - } - }, - { - "name": "macos_arm64_debug_asan", - "displayName": "macos arm64 debug asan", "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default", - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Darwin" - }, - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_OSX_ARCHITECTURES": "arm64", - 
"CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_CXX_STANDARD": "20" - } + "name": "linux_gcc_release_no_ort" }, { - "name": "macos_arm64_minsizerel", - "displayName": "macos arm64 minsizerel", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default", - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Darwin" - }, + "binaryDir": "${sourceDir}/build/asan/default", "cacheVariables": { - "CMAKE_BUILD_TYPE": "MinSizeRel", - "CMAKE_OSX_ARCHITECTURES": "arm64", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", - "CMAKE_CXX_STANDARD": "20" - } - }, - { - "name": "macos_arm64_minsizerel_asan", - "displayName": "macos arm64 minsizerel asan", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default", + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address" + }, "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin" + "rhs": "Linux", + "type": "equals" }, - "cacheVariables": { - "CMAKE_BUILD_TYPE": "MinSizeRel", - 
"CMAKE_OSX_ARCHITECTURES": "arm64", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_CXX_STANDARD": "20" - } + "displayName": "linux gcc release asan", + "environment": { + "CC": "gcc", + "CXX": "g++" + }, + "generator": "Unix Makefiles", + "name": "linux_gcc_release_asan" }, { - "name": "macos_arm64_release", - "displayName": "macos arm64 release", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default", - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Darwin" - }, + "binaryDir": "${sourceDir}/build/asan/no_ort", "cacheVariables": { "CMAKE_BUILD_TYPE": "Release", - "CMAKE_OSX_ARCHITECTURES": "arm64", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", - "CMAKE_CXX_STANDARD": "20" - } - }, - { - "name": "macos_arm64_release_asan", - "displayName": "macos arm64 release asan", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro 
-Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "MLAS_NO_ONNXRUNTIME": "ON" + }, "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin" + "rhs": "Linux", + "type": "equals" }, - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Release", - "CMAKE_OSX_ARCHITECTURES": "arm64", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_CXX_STANDARD": "20" - } + "displayName": "linux gcc release asan no_ort", + "environment": { + "CC": "gcc", + "CXX": "g++" + }, + "generator": "Unix Makefiles", + "name": "linux_gcc_release_asan_no_ort" }, { - "name": "macos_arm64_relwithdebinfo", - "displayName": "macos arm64 relwithdebinfo", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default", + "binaryDir": "${sourceDir}/build/default/default", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "RelWithDebInfo", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack" + }, "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin" + "rhs": "Linux", + "type": "equals" + }, + "displayName": "linux gcc relwithdebinfo", + "environment": { + "CC": "gcc", + "CXX": "g++" }, + "generator": "Unix Makefiles", + 
"name": "linux_gcc_relwithdebinfo" + }, + { + "binaryDir": "${sourceDir}/build/default/no_ort", "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo", - "CMAKE_OSX_ARCHITECTURES": "arm64", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", - "CMAKE_CXX_STANDARD": "20" - } - }, - { - "name": "macos_arm64_relwithdebinfo_asan", - "displayName": "macos arm64 relwithdebinfo asan", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "MLAS_NO_ONNXRUNTIME": "ON" + }, "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin" + "rhs": "Linux", + "type": "equals" + }, + "displayName": "linux gcc relwithdebinfo no_ort", + "environment": { + "CC": "gcc", + "CXX": "g++" }, + "generator": "Unix Makefiles", + "name": "linux_gcc_relwithdebinfo_no_ort" + }, + { + "binaryDir": "${sourceDir}/build/asan/default", "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo", - "CMAKE_OSX_ARCHITECTURES": "arm64", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_CXX_STANDARD": "20" - } + 
"CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Linux", + "type": "equals" + }, + "displayName": "linux gcc relwithdebinfo asan", + "environment": { + "CC": "gcc", + "CXX": "g++" + }, + "generator": "Unix Makefiles", + "name": "linux_gcc_relwithdebinfo_asan" }, { - "name": "macos_universal2_debug", - "displayName": "macos universal2 debug", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default", + "binaryDir": "${sourceDir}/build/asan/no_ort", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "RelWithDebInfo", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "MLAS_NO_ONNXRUNTIME": "ON" + }, "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin" + "rhs": "Linux", + "type": "equals" }, - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", - "CMAKE_C_FLAGS": "-ggdb3 -O0", - 
"CMAKE_CXX_FLAGS": "-ggdb3 -O0", - "CMAKE_CXX_STANDARD": "20" - } + "displayName": "linux gcc relwithdebinfo asan no_ort", + "environment": { + "CC": "gcc", + "CXX": "g++" + }, + "generator": "Unix Makefiles", + "name": "linux_gcc_relwithdebinfo_asan_no_ort" }, { - "name": "macos_universal2_debug_asan", - "displayName": "macos universal2 debug asan", - "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/default", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-ggdb3 -O0", + "CMAKE_OSX_ARCHITECTURES": "arm64" + }, "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin" + "rhs": "Darwin", + "type": "equals" }, + "displayName": "macos arm64 debug", + "generator": "Unix Makefiles", + "name": "macos_arm64_debug" + }, + { + "binaryDir": "${sourceDir}/build/default", "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", - "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fsanitize=address", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_CXX_STANDARD": "20" - } - }, - { - "name": "macos_universal2_minsizerel", - "displayName": "macos universal2 minsizerel", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default", + "CMAKE_OSX_ARCHITECTURES": "arm64", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address" + }, "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin" + "rhs": "Darwin", + "type": "equals" }, - "cacheVariables": { - "CMAKE_BUILD_TYPE": "MinSizeRel", - "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", - "CMAKE_CXX_FLAGS": 
"-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", - "CMAKE_CXX_STANDARD": "20" - } + "displayName": "macos arm64 debug asan", + "generator": "Unix Makefiles", + "name": "macos_arm64_debug_asan" }, { - "name": "macos_universal2_minsizerel_asan", - "displayName": "macos universal2 minsizerel asan", - "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/default", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "MinSizeRel", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", + "CMAKE_OSX_ARCHITECTURES": "arm64" + }, "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin" + "rhs": "Darwin", + "type": "equals" }, + "displayName": "macos arm64 minsizerel", + "generator": "Unix Makefiles", + "name": "macos_arm64_minsizerel" + }, + { + "binaryDir": "${sourceDir}/build/default", "cacheVariables": { "CMAKE_BUILD_TYPE": "MinSizeRel", - "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_CXX_STANDARD": "20" - } - }, - { - "name": "macos_universal2_release", - "displayName": "macos universal2 release", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default", + "CMAKE_OSX_ARCHITECTURES": "arm64", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address" + }, "condition": { - "type": 
"equals", "lhs": "${hostSystemName}", - "rhs": "Darwin" + "rhs": "Darwin", + "type": "equals" }, - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Release", - "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", - "CMAKE_CXX_STANDARD": "20" - } + "displayName": "macos arm64 minsizerel asan", + "generator": "Unix Makefiles", + "name": "macos_arm64_minsizerel_asan" }, { - "name": "macos_universal2_release_asan", - "displayName": "macos universal2 release asan", - "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/default", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", + "CMAKE_OSX_ARCHITECTURES": "arm64" + }, "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin" + "rhs": "Darwin", + "type": "equals" }, + "displayName": "macos arm64 release", + "generator": "Unix Makefiles", + "name": "macos_arm64_release" + }, + { + "binaryDir": "${sourceDir}/build/default", "cacheVariables": { "CMAKE_BUILD_TYPE": "Release", - "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_CXX_STANDARD": "20" - } - }, - { - "name": 
"macos_universal2_relwithdebinfo", - "displayName": "macos universal2 relwithdebinfo", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default", + "CMAKE_OSX_ARCHITECTURES": "arm64", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address" + }, "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin" + "rhs": "Darwin", + "type": "equals" }, - "cacheVariables": { - "CMAKE_BUILD_TYPE": "RelWithDebInfo", - "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", - "CMAKE_CXX_STANDARD": "20" - } + "displayName": "macos arm64 release asan", + "generator": "Unix Makefiles", + "name": "macos_arm64_release_asan" }, { - "name": "macos_universal2_relwithdebinfo_asan", - "displayName": "macos universal2 relwithdebinfo asan", - "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/default", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "RelWithDebInfo", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", + "CMAKE_OSX_ARCHITECTURES": "arm64" + }, "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin" + "rhs": "Darwin", + "type": "equals" }, + "displayName": "macos arm64 relwithdebinfo", + "generator": "Unix Makefiles", + "name": "macos_arm64_relwithdebinfo" + }, + { + "binaryDir": "${sourceDir}/build/default", "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo", - "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", + 
"CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_CXX_STANDARD": "20" - } - }, - { - "name": "macos_x86_64_debug", - "displayName": "macos x86 64 debug", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default", + "CMAKE_OSX_ARCHITECTURES": "arm64", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address" + }, "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin" + "rhs": "Darwin", + "type": "equals" }, - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_OSX_ARCHITECTURES": "x86_64", - "CMAKE_C_FLAGS": "-ggdb3 -O0", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0", - "CMAKE_CXX_STANDARD": "20" - } + "displayName": "macos arm64 relwithdebinfo asan", + "generator": "Unix Makefiles", + "name": "macos_arm64_relwithdebinfo_asan" }, { - "name": "macos_x86_64_debug_asan", - "displayName": "macos x86 64 debug asan", - "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/default", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-ggdb3 -O0", + "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64" + }, "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin" + "rhs": "Darwin", + "type": "equals" }, + "displayName": "macos universal2 debug", + "generator": "Unix Makefiles", + "name": "macos_universal2_debug" + }, + { + "binaryDir": "${sourceDir}/build/default", "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_OSX_ARCHITECTURES": "x86_64", - "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fsanitize=address", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", 
"CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_CXX_STANDARD": "20" - } - }, - { - "name": "macos_x86_64_minsizerel", - "displayName": "macos x86 64 minsizerel", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address" + }, "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin" + "rhs": "Darwin", + "type": "equals" }, - "cacheVariables": { - "CMAKE_BUILD_TYPE": "MinSizeRel", - "CMAKE_OSX_ARCHITECTURES": "x86_64", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", - "CMAKE_CXX_STANDARD": "20" - } + "displayName": "macos universal2 debug asan", + "generator": "Unix Makefiles", + "name": "macos_universal2_debug_asan" }, { - "name": "macos_x86_64_minsizerel_asan", - "displayName": "macos x86 64 minsizerel asan", - "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/default", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "MinSizeRel", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", + "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64" + }, "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin" + "rhs": "Darwin", + "type": "equals" }, + "displayName": "macos universal2 minsizerel", + "generator": "Unix Makefiles", + "name": "macos_universal2_minsizerel" + }, + { + "binaryDir": "${sourceDir}/build/default", "cacheVariables": { "CMAKE_BUILD_TYPE": "MinSizeRel", - "CMAKE_OSX_ARCHITECTURES": 
"x86_64", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_CXX_STANDARD": "20" - } - }, - { - "name": "macos_x86_64_release", - "displayName": "macos x86 64 release", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default", + "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address" + }, "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin" + "rhs": "Darwin", + "type": "equals" }, - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Release", - "CMAKE_OSX_ARCHITECTURES": "x86_64", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", - "CMAKE_CXX_STANDARD": "20" - } + "displayName": "macos universal2 minsizerel asan", + "generator": "Unix Makefiles", + "name": "macos_universal2_minsizerel_asan" }, { - "name": "macos_x86_64_release_asan", - "displayName": "macos x86 64 release asan", - "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/default", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", + "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64" + }, "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin" + 
"rhs": "Darwin", + "type": "equals" }, + "displayName": "macos universal2 release", + "generator": "Unix Makefiles", + "name": "macos_universal2_release" + }, + { + "binaryDir": "${sourceDir}/build/default", "cacheVariables": { "CMAKE_BUILD_TYPE": "Release", - "CMAKE_OSX_ARCHITECTURES": "x86_64", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_CXX_STANDARD": "20" - } - }, - { - "name": "macos_x86_64_relwithdebinfo", - "displayName": "macos x86 64 relwithdebinfo", - "generator": "Unix Makefiles", - "binaryDir": "${sourceDir}/build/default", + "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address" + }, "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin" + "rhs": "Darwin", + "type": "equals" }, - "cacheVariables": { - "CMAKE_BUILD_TYPE": "RelWithDebInfo", - "CMAKE_OSX_ARCHITECTURES": "x86_64", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", - "CMAKE_CXX_STANDARD": "20" - } + "displayName": "macos universal2 release asan", + "generator": "Unix Makefiles", + "name": "macos_universal2_release_asan" }, { - "name": "macos_x86_64_relwithdebinfo_asan", - "displayName": "macos x86 64 relwithdebinfo asan", - "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/default", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "RelWithDebInfo", + "CMAKE_CXX_FLAGS": "-DNDEBUG 
-Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", + "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64" + }, "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin" + "rhs": "Darwin", + "type": "equals" }, + "displayName": "macos universal2 relwithdebinfo", + "generator": "Unix Makefiles", + "name": "macos_universal2_relwithdebinfo" + }, + { + "binaryDir": "${sourceDir}/build/default", "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo", - "CMAKE_OSX_ARCHITECTURES": "x86_64", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_CXX_STANDARD": "20" - } - }, - { - "name": "windows_win32_debug", - "displayName": "windows win32 debug", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/debug/default", - "cacheVariables": { - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", + 
"CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address" }, - "architecture": "Win32", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } - }, - { - "name": "windows_win32_debug_asan", - "displayName": "windows win32 debug asan", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/debug/asan", - "cacheVariables": { - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + "rhs": "Darwin", + "type": "equals" }, - "architecture": "Win32", - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "displayName": "macos universal2 relwithdebinfo asan", + "generator": "Unix Makefiles", + "name": "macos_universal2_relwithdebinfo_asan" }, { - "name": "windows_win32_debug_asan_no_ort", - "displayName": "windows win32 debug asan no ort", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/debug/asan", + "binaryDir": "${sourceDir}/build/default", "cacheVariables": { - "MLAS_NO_ONNXRUNTIME": "ON", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": 
"/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-ggdb3 -O0", + "CMAKE_OSX_ARCHITECTURES": "x86_64" }, - "architecture": "Win32", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } - }, - { - "name": "windows_win32_debug_no_ort", - "displayName": "windows win32 debug no ort", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/debug/default", - "cacheVariables": { - "MLAS_NO_ONNXRUNTIME": "ON", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + "rhs": "Darwin", + "type": "equals" }, - "architecture": "Win32", - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "displayName": "macos x86_64 debug", + "generator": "Unix Makefiles", + "name": "macos_x86_64_debug" }, { - "name": "windows_win32_minsizerel", - "displayName": "windows win32 minsizerel", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/minsizerel/default", + "binaryDir": "${sourceDir}/build/default", "cacheVariables": { - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", - 
"CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fsanitize=address", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_OSX_ARCHITECTURES": "x86_64", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address" }, - "architecture": "Win32", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Darwin", + "type": "equals" + }, + "displayName": "macos x86_64 debug asan", + "generator": "Unix Makefiles", + "name": "macos_x86_64_debug_asan" }, { - "name": "windows_win32_minsizerel_asan", - "displayName": "windows win32 minsizerel asan", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/minsizerel/asan", + "binaryDir": "${sourceDir}/build/default", "cacheVariables": { - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + "CMAKE_BUILD_TYPE": "MinSizeRel", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", + "CMAKE_OSX_ARCHITECTURES": "x86_64" }, - "architecture": "Win32", "condition": { - 
"type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Darwin", + "type": "equals" + }, + "displayName": "macos x86_64 minsizerel", + "generator": "Unix Makefiles", + "name": "macos_x86_64_minsizerel" }, { - "name": "windows_win32_minsizerel_asan_no_ort", - "displayName": "windows win32 minsizerel asan no ort", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/minsizerel/asan", + "binaryDir": "${sourceDir}/build/default", "cacheVariables": { - "MLAS_NO_ONNXRUNTIME": "ON", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + "CMAKE_BUILD_TYPE": "MinSizeRel", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_OSX_ARCHITECTURES": "x86_64", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address" }, - "architecture": "Win32", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Darwin", + "type": "equals" + }, + "displayName": "macos x86_64 minsizerel asan", + "generator": "Unix Makefiles", + "name": "macos_x86_64_minsizerel_asan" }, { - "name": "windows_win32_minsizerel_no_ort", - "displayName": "windows win32 minsizerel no ort", - 
"generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/minsizerel/default", + "binaryDir": "${sourceDir}/build/default", "cacheVariables": { - "MLAS_NO_ONNXRUNTIME": "ON", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", + "CMAKE_OSX_ARCHITECTURES": "x86_64" }, - "architecture": "Win32", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Darwin", + "type": "equals" + }, + "displayName": "macos x86_64 release", + "generator": "Unix Makefiles", + "name": "macos_x86_64_release" }, { - "name": "windows_win32_release", - "displayName": "windows win32 release", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/release/default", + "binaryDir": "${sourceDir}/build/default", "cacheVariables": { - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - 
"CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_OSX_ARCHITECTURES": "x86_64", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address" }, - "architecture": "Win32", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Darwin", + "type": "equals" + }, + "displayName": "macos x86_64 release asan", + "generator": "Unix Makefiles", + "name": "macos_x86_64_release_asan" }, - { - "name": "windows_win32_release_asan", - "displayName": "windows win32 release asan", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/release/asan", + { + "binaryDir": "${sourceDir}/build/default", "cacheVariables": { - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + "CMAKE_BUILD_TYPE": "RelWithDebInfo", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", + "CMAKE_OSX_ARCHITECTURES": "x86_64" }, - "architecture": "Win32", 
"condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Darwin", + "type": "equals" + }, + "displayName": "macos x86_64 relwithdebinfo", + "generator": "Unix Makefiles", + "name": "macos_x86_64_relwithdebinfo" }, { - "name": "windows_win32_release_asan_no_ort", - "displayName": "windows win32 release asan no ort", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/release/asan", + "binaryDir": "${sourceDir}/build/default", "cacheVariables": { - "MLAS_NO_ONNXRUNTIME": "ON", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + "CMAKE_BUILD_TYPE": "RelWithDebInfo", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", + "CMAKE_CXX_STANDARD": "20", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_OSX_ARCHITECTURES": "x86_64", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address" }, - "architecture": "Win32", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Darwin", + "type": "equals" + }, + "displayName": "macos x86_64 relwithdebinfo asan", + "generator": "Unix Makefiles", + "name": "macos_x86_64_relwithdebinfo_asan" }, { - "name": "windows_win32_release_no_ort", - "displayName": "windows win32 
release no ort", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/release/default", + "architecture": "Win32", + "binaryDir": "${sourceDir}/build/debug/default", "cacheVariables": { - "MLAS_NO_ONNXRUNTIME": "ON", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, - "architecture": "Win32", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows win32 debug", + "generator": "Visual Studio 17 2022", + "name": "windows_win32_debug" }, { - "name": "windows_win32_relwithdebinfo", - "displayName": "windows win32 relwithdebinfo", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/relwithdebinfo/default", + "architecture": "Win32", + "binaryDir": "${sourceDir}/build/debug/default", "cacheVariables": { - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 
/DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "MLAS_NO_ONNXRUNTIME": "ON" }, - "architecture": "Win32", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows win32 debug no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_win32_debug_no_ort" }, { - "name": "windows_win32_relwithdebinfo_asan", - "displayName": "windows win32 relwithdebinfo asan", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/relwithdebinfo/asan", + "architecture": "Win32", + "binaryDir": "${sourceDir}/build/debug/asan", "cacheVariables": { - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 
/fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, - "architecture": "Win32", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows win32 debug asan", + "generator": "Visual Studio 17 2022", + "name": "windows_win32_debug_asan" }, { - "name": "windows_win32_relwithdebinfo_asan_no_ort", - "displayName": "windows win32 relwithdebinfo asan no ort", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/relwithdebinfo/asan", + "architecture": "Win32", + "binaryDir": "${sourceDir}/build/debug/asan", "cacheVariables": { - "MLAS_NO_ONNXRUNTIME": "ON", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "MLAS_NO_ONNXRUNTIME": "ON" }, - "architecture": "Win32", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Windows", + 
"type": "equals" + }, + "displayName": "windows win32 debug asan no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_win32_debug_asan_no_ort" }, { - "name": "windows_win32_relwithdebinfo_no_ort", - "displayName": "windows win32 relwithdebinfo no ort", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/relwithdebinfo/default", + "architecture": "Win32", + "binaryDir": "${sourceDir}/build/minsizerel/default", "cacheVariables": { - "MLAS_NO_ONNXRUNTIME": "ON", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, - "architecture": "Win32", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows win32 minsizerel", + "generator": "Visual Studio 17 2022", + "name": "windows_win32_minsizerel" }, { - "name": "windows_x64_debug", - "displayName": "windows x64 debug", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/debug/default", + "architecture": "Win32", + "binaryDir": "${sourceDir}/build/minsizerel/default", "cacheVariables": { - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS 
/DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "MLAS_NO_ONNXRUNTIME": "ON" }, - "architecture": "x64", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows win32 minsizerel no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_win32_minsizerel_no_ort" }, { - "name": "windows_x64_debug_asan", - "displayName": "windows x64 debug asan", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/debug/asan", + "architecture": "Win32", + "binaryDir": "${sourceDir}/build/minsizerel/asan", "cacheVariables": { - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 
/Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, - "architecture": "x64", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows win32 minsizerel asan", + "generator": "Visual Studio 17 2022", + "name": "windows_win32_minsizerel_asan" }, { - "name": "windows_x64_debug_asan_no_ort", - "displayName": "windows x64 debug asan no ort", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/debug/asan", + "architecture": "Win32", + "binaryDir": "${sourceDir}/build/minsizerel/asan", "cacheVariables": { - "MLAS_NO_ONNXRUNTIME": "ON", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile 
/DYNAMICBASE", + "MLAS_NO_ONNXRUNTIME": "ON" }, - "architecture": "x64", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows win32 minsizerel asan no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_win32_minsizerel_asan_no_ort" }, { - "name": "windows_x64_debug_no_ort", - "displayName": "windows x64 debug no ort", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/debug/default", + "architecture": "Win32", + "binaryDir": "${sourceDir}/build/release/default", "cacheVariables": { - "MLAS_NO_ONNXRUNTIME": "ON", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, - "architecture": "x64", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows win32 release", + "generator": "Visual Studio 17 2022", + "name": "windows_win32_release" }, { - "name": "windows_x64_minsizerel", - "displayName": "windows x64 minsizerel", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/minsizerel/default", + "architecture": 
"Win32", + "binaryDir": "${sourceDir}/build/release/default", "cacheVariables": { - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "MLAS_NO_ONNXRUNTIME": "ON" }, - "architecture": "x64", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows win32 release no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_win32_release_no_ort" }, { - "name": "windows_x64_minsizerel_asan", - "displayName": "windows x64 minsizerel asan", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/minsizerel/asan", + "architecture": "Win32", + "binaryDir": "${sourceDir}/build/release/asan", "cacheVariables": { - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG 
/fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, - "architecture": "x64", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows win32 release asan", + "generator": "Visual Studio 17 2022", + "name": "windows_win32_release_asan" }, { - "name": "windows_x64_minsizerel_asan_no_ort", - "displayName": "windows x64 minsizerel asan no ort", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/minsizerel/asan", + "architecture": "Win32", + "binaryDir": "${sourceDir}/build/release/asan", "cacheVariables": { - "MLAS_NO_ONNXRUNTIME": "ON", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", 
"CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "MLAS_NO_ONNXRUNTIME": "ON" }, - "architecture": "x64", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows win32 release asan no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_win32_release_asan_no_ort" }, { - "name": "windows_x64_minsizerel_no_ort", - "displayName": "windows x64 minsizerel no ort", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/minsizerel/default", + "architecture": "Win32", + "binaryDir": "${sourceDir}/build/relwithdebinfo/default", "cacheVariables": { - "MLAS_NO_ONNXRUNTIME": "ON", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, - "architecture": "x64", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows win32 relwithdebinfo", + "generator": "Visual 
Studio 17 2022", + "name": "windows_win32_relwithdebinfo" }, { - "name": "windows_x64_release", - "displayName": "windows x64 release", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/release/default", + "architecture": "Win32", + "binaryDir": "${sourceDir}/build/relwithdebinfo/default", "cacheVariables": { - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "MLAS_NO_ONNXRUNTIME": "ON" }, - "architecture": "x64", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows win32 relwithdebinfo no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_win32_relwithdebinfo_no_ort" }, { - "name": "windows_x64_release_asan", - "displayName": "windows x64 release asan", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/release/asan", + "architecture": "Win32", + "binaryDir": "${sourceDir}/build/relwithdebinfo/asan", "cacheVariables": { - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 
/D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, - "architecture": "x64", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows win32 relwithdebinfo asan", + "generator": "Visual Studio 17 2022", + "name": "windows_win32_relwithdebinfo_asan" }, { - "name": "windows_x64_release_asan_no_ort", - "displayName": "windows x64 release asan no ort", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/release/asan", + "architecture": "Win32", + "binaryDir": "${sourceDir}/build/relwithdebinfo/asan", "cacheVariables": { - "MLAS_NO_ONNXRUNTIME": "ON", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 
/DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "MLAS_NO_ONNXRUNTIME": "ON" }, - "architecture": "x64", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows win32 relwithdebinfo asan no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_win32_relwithdebinfo_asan_no_ort" }, { - "name": "windows_x64_release_no_ort", - "displayName": "windows x64 release no ort", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/release/default", + "architecture": "x64", + "binaryDir": "${sourceDir}/build/debug/default", "cacheVariables": { - "MLAS_NO_ONNXRUNTIME": "ON", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": 
"/profile /DYNAMICBASE" }, - "architecture": "x64", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows x64 debug", + "generator": "Visual Studio 17 2022", + "name": "windows_x64_debug" }, { - "name": "windows_x64_relwithdebinfo", - "displayName": "windows x64 relwithdebinfo", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/relwithdebinfo/default", + "architecture": "x64", + "binaryDir": "${sourceDir}/build/debug/default", "cacheVariables": { - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "MLAS_NO_ONNXRUNTIME": "ON" }, - "architecture": "x64", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows x64 debug no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_x64_debug_no_ort" }, { - "name": "windows_x64_relwithdebinfo_asan", - "displayName": "windows x64 relwithdebinfo asan", - "generator": "Visual Studio 17 2022", - "binaryDir": 
"${sourceDir}/build/relwithdebinfo/asan", + "architecture": "x64", + "binaryDir": "${sourceDir}/build/debug/asan", "cacheVariables": { - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, - "architecture": "x64", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows x64 debug asan", + "generator": "Visual Studio 17 2022", + "name": "windows_x64_debug_asan" }, { - "name": "windows_x64_relwithdebinfo_asan_no_ort", - "displayName": "windows x64 relwithdebinfo asan no ort", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/relwithdebinfo/asan", + "architecture": "x64", + "binaryDir": "${sourceDir}/build/debug/asan", "cacheVariables": { - "MLAS_NO_ONNXRUNTIME": "ON", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 
/D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "MLAS_NO_ONNXRUNTIME": "ON" }, - "architecture": "x64", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows x64 debug asan no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_x64_debug_asan_no_ort" }, { - "name": "windows_x64_relwithdebinfo_no_ort", - "displayName": "windows x64 relwithdebinfo no ort", - "generator": "Visual Studio 17 2022", - "binaryDir": "${sourceDir}/build/relwithdebinfo/default", + "architecture": "x64", + "binaryDir": "${sourceDir}/build/minsizerel/default", "cacheVariables": { - "MLAS_NO_ONNXRUNTIME": "ON", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 
/D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, - "architecture": "x64", "condition": { - "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows" - } - } - ], - "buildPresets": [ - { - "name": "linux_clang_debug", - "configurePreset": "linux_clang_debug", - "configuration": "Debug" - }, - { - "name": "linux_clang_debug_asan", - "configurePreset": "linux_clang_debug_asan", - "configuration": "Debug" - }, - { - "name": "linux_clang_debug_asan_no_ort", - "configurePreset": "linux_clang_debug_asan_no_ort", - "configuration": "Debug" - }, - { - "name": "linux_clang_debug_cov", - "configurePreset": "linux_clang_debug_cov", - "configuration": "Debug" - }, - { - "name": "linux_clang_debug_cov_no_ort", - "configurePreset": "linux_clang_debug_cov_no_ort", - "configuration": "Debug" - }, - { - "name": "linux_clang_debug_no_ort", - "configurePreset": "linux_clang_debug_no_ort", - "configuration": "Debug" - }, - { - "name": "linux_gcc_debug", - "configurePreset": "linux_gcc_debug", - "configuration": "Debug" - }, - { - "name": "linux_gcc_debug_asan", - "configurePreset": "linux_gcc_debug_asan", - "configuration": "Debug" - }, - { - "name": "linux_gcc_debug_asan_no_ort", - "configurePreset": "linux_gcc_debug_asan_no_ort", - "configuration": "Debug" - }, - { - "name": "linux_gcc_debug_no_ort", - "configurePreset": "linux_gcc_debug_no_ort", - "configuration": "Debug" - }, - { - "name": "linux_gcc_minsizerel", - "configurePreset": "linux_gcc_minsizerel", - "configuration": "MinSizeRel" - }, - { - "name": "linux_gcc_minsizerel_asan", - "configurePreset": "linux_gcc_minsizerel_asan", - "configuration": "MinSizeRel" - }, - { - "name": "linux_gcc_minsizerel_asan_no_ort", - "configurePreset": "linux_gcc_minsizerel_asan_no_ort", - "configuration": "MinSizeRel" - }, - { - "name": 
"linux_gcc_minsizerel_no_ort", - "configurePreset": "linux_gcc_minsizerel_no_ort", - "configuration": "MinSizeRel" - }, - { - "name": "linux_gcc_release", - "configurePreset": "linux_gcc_release", - "configuration": "Release" - }, - { - "name": "linux_gcc_release_asan", - "configurePreset": "linux_gcc_release_asan", - "configuration": "Release" - }, - { - "name": "linux_gcc_release_asan_no_ort", - "configurePreset": "linux_gcc_release_asan_no_ort", - "configuration": "Release" - }, - { - "name": "linux_gcc_release_no_ort", - "configurePreset": "linux_gcc_release_no_ort", - "configuration": "Release" - }, - { - "name": "linux_gcc_relwithdebinfo", - "configurePreset": "linux_gcc_relwithdebinfo", - "configuration": "RelWithDebInfo" - }, - { - "name": "linux_gcc_relwithdebinfo_asan", - "configurePreset": "linux_gcc_relwithdebinfo_asan", - "configuration": "RelWithDebInfo" - }, - { - "name": "linux_gcc_relwithdebinfo_asan_no_ort", - "configurePreset": "linux_gcc_relwithdebinfo_asan_no_ort", - "configuration": "RelWithDebInfo" - }, - { - "name": "linux_gcc_relwithdebinfo_no_ort", - "configurePreset": "linux_gcc_relwithdebinfo_no_ort", - "configuration": "RelWithDebInfo" - }, - { - "name": "macos_arm64_debug", - "configurePreset": "macos_arm64_debug", - "configuration": "Debug" - }, - { - "name": "macos_arm64_debug_asan", - "configurePreset": "macos_arm64_debug_asan", - "configuration": "Debug" - }, - { - "name": "macos_arm64_minsizerel", - "configurePreset": "macos_arm64_minsizerel", - "configuration": "MinSizeRel" - }, - { - "name": "macos_arm64_minsizerel_asan", - "configurePreset": "macos_arm64_minsizerel_asan", - "configuration": "MinSizeRel" - }, - { - "name": "macos_arm64_release", - "configurePreset": "macos_arm64_release", - "configuration": "Release" - }, - { - "name": "macos_arm64_release_asan", - "configurePreset": "macos_arm64_release_asan", - "configuration": "Release" - }, - { - "name": "macos_arm64_relwithdebinfo", - "configurePreset": 
"macos_arm64_relwithdebinfo", - "configuration": "RelWithDebInfo" - }, - { - "name": "macos_arm64_relwithdebinfo_asan", - "configurePreset": "macos_arm64_relwithdebinfo_asan", - "configuration": "RelWithDebInfo" - }, - { - "name": "macos_universal2_debug", - "configurePreset": "macos_universal2_debug", - "configuration": "Debug" - }, - { - "name": "macos_universal2_debug_asan", - "configurePreset": "macos_universal2_debug_asan", - "configuration": "Debug" - }, - { - "name": "macos_universal2_minsizerel", - "configurePreset": "macos_universal2_minsizerel", - "configuration": "MinSizeRel" - }, - { - "name": "macos_universal2_minsizerel_asan", - "configurePreset": "macos_universal2_minsizerel_asan", - "configuration": "MinSizeRel" - }, - { - "name": "macos_universal2_release", - "configurePreset": "macos_universal2_release", - "configuration": "Release" - }, - { - "name": "macos_universal2_release_asan", - "configurePreset": "macos_universal2_release_asan", - "configuration": "Release" - }, - { - "name": "macos_universal2_relwithdebinfo", - "configurePreset": "macos_universal2_relwithdebinfo", - "configuration": "RelWithDebInfo" - }, - { - "name": "macos_universal2_relwithdebinfo_asan", - "configurePreset": "macos_universal2_relwithdebinfo_asan", - "configuration": "RelWithDebInfo" - }, - { - "name": "macos_x86_64_debug", - "configurePreset": "macos_x86_64_debug", - "configuration": "Debug" - }, - { - "name": "macos_x86_64_debug_asan", - "configurePreset": "macos_x86_64_debug_asan", - "configuration": "Debug" - }, - { - "name": "macos_x86_64_minsizerel", - "configurePreset": "macos_x86_64_minsizerel", - "configuration": "MinSizeRel" - }, - { - "name": "macos_x86_64_minsizerel_asan", - "configurePreset": "macos_x86_64_minsizerel_asan", - "configuration": "MinSizeRel" - }, - { - "name": "macos_x86_64_release", - "configurePreset": "macos_x86_64_release", - "configuration": "Release" - }, - { - "name": "macos_x86_64_release_asan", - "configurePreset": 
"macos_x86_64_release_asan", - "configuration": "Release" - }, - { - "name": "macos_x86_64_relwithdebinfo", - "configurePreset": "macos_x86_64_relwithdebinfo", - "configuration": "RelWithDebInfo" - }, - { - "name": "macos_x86_64_relwithdebinfo_asan", - "configurePreset": "macos_x86_64_relwithdebinfo_asan", - "configuration": "RelWithDebInfo" - }, - { - "name": "windows_win32_debug", - "configurePreset": "windows_win32_debug", - "configuration": "Debug" - }, - { - "name": "windows_win32_debug_asan", - "configurePreset": "windows_win32_debug_asan", - "configuration": "Debug" - }, - { - "name": "windows_win32_debug_asan_no_ort", - "configurePreset": "windows_win32_debug_asan_no_ort", - "configuration": "Debug" - }, - { - "name": "windows_win32_debug_no_ort", - "configurePreset": "windows_win32_debug_no_ort", - "configuration": "Debug" - }, - { - "name": "windows_win32_minsizerel", - "configurePreset": "windows_win32_minsizerel", - "configuration": "MinSizeRel" + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows x64 minsizerel", + "generator": "Visual Studio 17 2022", + "name": "windows_x64_minsizerel" }, { - "name": "windows_win32_minsizerel_asan", - "configurePreset": "windows_win32_minsizerel_asan", - "configuration": "MinSizeRel" + "architecture": "x64", + "binaryDir": "${sourceDir}/build/minsizerel/default", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "MLAS_NO_ONNXRUNTIME": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": 
"Windows", + "type": "equals" + }, + "displayName": "windows x64 minsizerel no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_x64_minsizerel_no_ort" }, { - "name": "windows_win32_minsizerel_asan_no_ort", - "configurePreset": "windows_win32_minsizerel_asan_no_ort", - "configuration": "MinSizeRel" + "architecture": "x64", + "binaryDir": "${sourceDir}/build/minsizerel/asan", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows x64 minsizerel asan", + "generator": "Visual Studio 17 2022", + "name": "windows_x64_minsizerel_asan" }, { - "name": "windows_win32_minsizerel_no_ort", - "configurePreset": "windows_win32_minsizerel_no_ort", - "configuration": "MinSizeRel" + "architecture": "x64", + "binaryDir": "${sourceDir}/build/minsizerel/asan", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile 
/DYNAMICBASE", + "MLAS_NO_ONNXRUNTIME": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows x64 minsizerel asan no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_x64_minsizerel_asan_no_ort" }, { - "name": "windows_win32_release", - "configurePreset": "windows_win32_release", - "configuration": "Release" + "architecture": "x64", + "binaryDir": "${sourceDir}/build/release/default", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows x64 release", + "generator": "Visual Studio 17 2022", + "name": "windows_x64_release" }, { - "name": "windows_win32_release_asan", - "configurePreset": "windows_win32_release_asan", - "configuration": "Release" + "architecture": "x64", + "binaryDir": "${sourceDir}/build/release/default", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + 
"MLAS_NO_ONNXRUNTIME": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows x64 release no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_x64_release_no_ort" }, { - "name": "windows_win32_release_asan_no_ort", - "configurePreset": "windows_win32_release_asan_no_ort", - "configuration": "Release" + "architecture": "x64", + "binaryDir": "${sourceDir}/build/release/asan", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows x64 release asan", + "generator": "Visual Studio 17 2022", + "name": "windows_x64_release_asan" }, { - "name": "windows_win32_release_no_ort", - "configurePreset": "windows_win32_release_no_ort", - "configuration": "Release" + "architecture": "x64", + "binaryDir": "${sourceDir}/build/release/asan", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile 
/DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "MLAS_NO_ONNXRUNTIME": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows x64 release asan no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_x64_release_asan_no_ort" }, { - "name": "windows_win32_relwithdebinfo", - "configurePreset": "windows_win32_relwithdebinfo", - "configuration": "RelWithDebInfo" + "architecture": "x64", + "binaryDir": "${sourceDir}/build/relwithdebinfo/default", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows x64 relwithdebinfo", + "generator": "Visual Studio 17 2022", + "name": "windows_x64_relwithdebinfo" }, { - "name": "windows_win32_relwithdebinfo_asan", - "configurePreset": "windows_win32_relwithdebinfo_asan", - "configuration": "RelWithDebInfo" + "architecture": "x64", + "binaryDir": "${sourceDir}/build/relwithdebinfo/default", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + 
"CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "MLAS_NO_ONNXRUNTIME": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows x64 relwithdebinfo no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_x64_relwithdebinfo_no_ort" }, { - "name": "windows_win32_relwithdebinfo_asan_no_ort", - "configurePreset": "windows_win32_relwithdebinfo_asan_no_ort", - "configuration": "RelWithDebInfo" + "architecture": "x64", + "binaryDir": "${sourceDir}/build/relwithdebinfo/asan", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows x64 relwithdebinfo asan", + "generator": "Visual Studio 17 2022", + "name": "windows_x64_relwithdebinfo_asan" }, { - "name": "windows_win32_relwithdebinfo_no_ort", - "configurePreset": "windows_win32_relwithdebinfo_no_ort", - "configuration": "RelWithDebInfo" + "architecture": "x64", + "binaryDir": "${sourceDir}/build/relwithdebinfo/asan", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 
/DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "MLAS_NO_ONNXRUNTIME": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows x64 relwithdebinfo asan no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_x64_relwithdebinfo_asan_no_ort" }, { - "name": "windows_x64_debug", - "configurePreset": "windows_x64_debug", - "configuration": "Debug" + "architecture": "ARM64", + "binaryDir": "${sourceDir}/build/debug/default", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows arm64 debug", + "generator": "Visual Studio 17 2022", + "name": "windows_arm64_debug" }, { - "name": "windows_x64_debug_asan", - "configurePreset": "windows_x64_debug_asan", - "configuration": "Debug" + "architecture": "ARM64", + "binaryDir": "${sourceDir}/build/debug/default", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 
/D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "MLAS_NO_ONNXRUNTIME": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows arm64 debug no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_arm64_debug_no_ort" }, { - "name": "windows_x64_debug_asan_no_ort", - "configurePreset": "windows_x64_debug_asan_no_ort", - "configuration": "Debug" + "architecture": "ARM64", + "binaryDir": "${sourceDir}/build/debug/asan", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows arm64 debug asan", + "generator": "Visual Studio 17 2022", + "name": "windows_arm64_debug_asan" }, { - "name": "windows_x64_debug_no_ort", - "configurePreset": "windows_x64_debug_no_ort", - "configuration": "Debug" + "architecture": "ARM64", + "binaryDir": "${sourceDir}/build/debug/asan", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 
/DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "MLAS_NO_ONNXRUNTIME": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows arm64 debug asan no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_arm64_debug_asan_no_ort" }, { - "name": "windows_x64_minsizerel", - "configurePreset": "windows_x64_minsizerel", - "configuration": "MinSizeRel" + "architecture": "ARM64", + "binaryDir": "${sourceDir}/build/minsizerel/default", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows arm64 minsizerel", + "generator": "Visual Studio 17 2022", + "name": "windows_arm64_minsizerel" }, { - "name": "windows_x64_minsizerel_asan", - "configurePreset": "windows_x64_minsizerel_asan", - "configuration": "MinSizeRel" + "architecture": "ARM64", + "binaryDir": "${sourceDir}/build/minsizerel/default", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS 
/DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "MLAS_NO_ONNXRUNTIME": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows arm64 minsizerel no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_arm64_minsizerel_no_ort" }, { - "name": "windows_x64_minsizerel_asan_no_ort", - "configurePreset": "windows_x64_minsizerel_asan_no_ort", - "configuration": "MinSizeRel" + "architecture": "ARM64", + "binaryDir": "${sourceDir}/build/minsizerel/asan", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows arm64 minsizerel asan", + "generator": "Visual Studio 17 2022", + "name": "windows_arm64_minsizerel_asan" }, { - "name": "windows_x64_minsizerel_no_ort", - "configurePreset": "windows_x64_minsizerel_no_ort", - "configuration": "MinSizeRel" + "architecture": "ARM64", + "binaryDir": "${sourceDir}/build/minsizerel/asan", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG 
/fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "MLAS_NO_ONNXRUNTIME": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows arm64 minsizerel asan no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_arm64_minsizerel_asan_no_ort" }, { - "name": "windows_x64_release", - "configurePreset": "windows_x64_release", - "configuration": "Release" + "architecture": "ARM64", + "binaryDir": "${sourceDir}/build/release/default", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows arm64 release", + "generator": "Visual Studio 17 2022", + "name": "windows_arm64_release" }, { - "name": "windows_x64_release_asan", - "configurePreset": "windows_x64_release_asan", - "configuration": "Release" + "architecture": "ARM64", + "binaryDir": "${sourceDir}/build/release/default", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 
/O2 /Ob2 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "MLAS_NO_ONNXRUNTIME": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows arm64 release no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_arm64_release_no_ort" }, { - "name": "windows_x64_release_asan_no_ort", - "configurePreset": "windows_x64_release_asan_no_ort", - "configuration": "Release" + "architecture": "ARM64", + "binaryDir": "${sourceDir}/build/release/asan", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows arm64 release asan", + "generator": "Visual Studio 17 2022", + "name": "windows_arm64_release_asan" }, { - "name": "windows_x64_release_no_ort", - "configurePreset": "windows_x64_release_no_ort", - "configuration": "Release" + "architecture": "ARM64", + "binaryDir": "${sourceDir}/build/release/asan", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 
/D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "MLAS_NO_ONNXRUNTIME": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows arm64 release asan no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_arm64_release_asan_no_ort" }, { - "name": "windows_x64_relwithdebinfo", - "configurePreset": "windows_x64_relwithdebinfo", - "configuration": "RelWithDebInfo" + "architecture": "ARM64", + "binaryDir": "${sourceDir}/build/relwithdebinfo/default", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows arm64 relwithdebinfo", + "generator": "Visual Studio 17 2022", + "name": "windows_arm64_relwithdebinfo" }, { - "name": "windows_x64_relwithdebinfo_asan", - "configurePreset": "windows_x64_relwithdebinfo_asan", - "configuration": "RelWithDebInfo" + "architecture": "ARM64", + "binaryDir": "${sourceDir}/build/relwithdebinfo/default", + "cacheVariables": { + "CMAKE_CXX_FLAGS": 
"/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "MLAS_NO_ONNXRUNTIME": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows arm64 relwithdebinfo no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_arm64_relwithdebinfo_no_ort" }, { - "name": "windows_x64_relwithdebinfo_asan_no_ort", - "configurePreset": "windows_x64_relwithdebinfo_asan_no_ort", - "configuration": "RelWithDebInfo" + "architecture": "ARM64", + "binaryDir": "${sourceDir}/build/relwithdebinfo/asan", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows arm64 relwithdebinfo asan", + "generator": "Visual Studio 17 2022", + "name": "windows_arm64_relwithdebinfo_asan" }, { - "name": "windows_x64_relwithdebinfo_no_ort", - "configurePreset": "windows_x64_relwithdebinfo_no_ort", - "configuration": "RelWithDebInfo" + 
"architecture": "ARM64", + "binaryDir": "${sourceDir}/build/relwithdebinfo/asan", + "cacheVariables": { + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "MLAS_NO_ONNXRUNTIME": "ON" + }, + "condition": { + "lhs": "${hostSystemName}", + "rhs": "Windows", + "type": "equals" + }, + "displayName": "windows arm64 relwithdebinfo asan no_ort", + "generator": "Visual Studio 17 2022", + "name": "windows_arm64_relwithdebinfo_asan_no_ort" } ], "testPresets": [ { - "name": "linux_clang_debug", "configuration": "Debug", - "configurePreset": "linux_clang_debug" + "configurePreset": "linux_clang_debug", + "name": "linux_clang_debug" }, { - "name": "linux_clang_debug_asan", "configuration": "Debug", - "configurePreset": "linux_clang_debug_asan" + "configurePreset": "linux_clang_debug_asan", + "name": "linux_clang_debug_asan" }, { - "name": "linux_clang_debug_asan_no_ort", "configuration": "Debug", - "configurePreset": "linux_clang_debug_asan_no_ort" + "configurePreset": "linux_clang_debug_asan_no_ort", + "name": "linux_clang_debug_asan_no_ort" }, { - "name": "linux_clang_debug_cov", "configuration": "Debug", - "configurePreset": "linux_clang_debug_cov" + "configurePreset": "linux_clang_debug_cov", + "name": "linux_clang_debug_cov" }, { - "name": "linux_clang_debug_cov_no_ort", "configuration": "Debug", - "configurePreset": "linux_clang_debug_cov_no_ort" + "configurePreset": "linux_clang_debug_cov_no_ort", + "name": "linux_clang_debug_cov_no_ort" }, { - "name": 
"linux_clang_debug_no_ort", "configuration": "Debug", - "configurePreset": "linux_clang_debug_no_ort" + "configurePreset": "linux_clang_debug_no_ort", + "name": "linux_clang_debug_no_ort" }, { - "name": "linux_gcc_debug", "configuration": "Debug", - "configurePreset": "linux_gcc_debug" + "configurePreset": "linux_gcc_debug", + "name": "linux_gcc_debug" }, { - "name": "linux_gcc_debug_asan", "configuration": "Debug", - "configurePreset": "linux_gcc_debug_asan" + "configurePreset": "linux_gcc_debug_asan", + "name": "linux_gcc_debug_asan" }, { - "name": "linux_gcc_debug_asan_no_ort", "configuration": "Debug", - "configurePreset": "linux_gcc_debug_asan_no_ort" + "configurePreset": "linux_gcc_debug_asan_no_ort", + "name": "linux_gcc_debug_asan_no_ort" }, { - "name": "linux_gcc_debug_no_ort", "configuration": "Debug", - "configurePreset": "linux_gcc_debug_no_ort" + "configurePreset": "linux_gcc_debug_no_ort", + "name": "linux_gcc_debug_no_ort" }, { - "name": "linux_gcc_minsizerel", "configuration": "MinSizeRel", - "configurePreset": "linux_gcc_minsizerel" + "configurePreset": "linux_gcc_minsizerel", + "name": "linux_gcc_minsizerel" }, { - "name": "linux_gcc_minsizerel_asan", "configuration": "MinSizeRel", - "configurePreset": "linux_gcc_minsizerel_asan" + "configurePreset": "linux_gcc_minsizerel_asan", + "name": "linux_gcc_minsizerel_asan" }, { - "name": "linux_gcc_minsizerel_asan_no_ort", "configuration": "MinSizeRel", - "configurePreset": "linux_gcc_minsizerel_asan_no_ort" + "configurePreset": "linux_gcc_minsizerel_asan_no_ort", + "name": "linux_gcc_minsizerel_asan_no_ort" }, { - "name": "linux_gcc_minsizerel_no_ort", "configuration": "MinSizeRel", - "configurePreset": "linux_gcc_minsizerel_no_ort" + "configurePreset": "linux_gcc_minsizerel_no_ort", + "name": "linux_gcc_minsizerel_no_ort" }, { - "name": "linux_gcc_release", "configuration": "Release", - "configurePreset": "linux_gcc_release" + "configurePreset": "linux_gcc_release", + "name": "linux_gcc_release" }, { - 
"name": "linux_gcc_release_asan", "configuration": "Release", - "configurePreset": "linux_gcc_release_asan" + "configurePreset": "linux_gcc_release_asan", + "name": "linux_gcc_release_asan" }, { - "name": "linux_gcc_release_asan_no_ort", "configuration": "Release", - "configurePreset": "linux_gcc_release_asan_no_ort" + "configurePreset": "linux_gcc_release_asan_no_ort", + "name": "linux_gcc_release_asan_no_ort" }, { - "name": "linux_gcc_release_no_ort", "configuration": "Release", - "configurePreset": "linux_gcc_release_no_ort" + "configurePreset": "linux_gcc_release_no_ort", + "name": "linux_gcc_release_no_ort" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "linux_gcc_relwithdebinfo", + "name": "linux_gcc_relwithdebinfo" + }, + { + "configuration": "RelWithDebInfo", + "configurePreset": "linux_gcc_relwithdebinfo_asan", + "name": "linux_gcc_relwithdebinfo_asan" }, { - "name": "linux_gcc_relwithdebinfo", "configuration": "RelWithDebInfo", - "configurePreset": "linux_gcc_relwithdebinfo" + "configurePreset": "linux_gcc_relwithdebinfo_asan_no_ort", + "name": "linux_gcc_relwithdebinfo_asan_no_ort" }, { - "name": "linux_gcc_relwithdebinfo_asan", "configuration": "RelWithDebInfo", - "configurePreset": "linux_gcc_relwithdebinfo_asan" + "configurePreset": "linux_gcc_relwithdebinfo_no_ort", + "name": "linux_gcc_relwithdebinfo_no_ort" + }, + { + "configuration": "Debug", + "configurePreset": "macos_arm64_debug", + "name": "macos_arm64_debug" + }, + { + "configuration": "Debug", + "configurePreset": "macos_arm64_debug_asan", + "name": "macos_arm64_debug_asan" + }, + { + "configuration": "MinSizeRel", + "configurePreset": "macos_arm64_minsizerel", + "name": "macos_arm64_minsizerel" + }, + { + "configuration": "MinSizeRel", + "configurePreset": "macos_arm64_minsizerel_asan", + "name": "macos_arm64_minsizerel_asan" + }, + { + "configuration": "Release", + "configurePreset": "macos_arm64_release", + "name": "macos_arm64_release" + }, + { + "configuration": 
"Release", + "configurePreset": "macos_arm64_release_asan", + "name": "macos_arm64_release_asan" }, { - "name": "linux_gcc_relwithdebinfo_asan_no_ort", "configuration": "RelWithDebInfo", - "configurePreset": "linux_gcc_relwithdebinfo_asan_no_ort" + "configurePreset": "macos_arm64_relwithdebinfo", + "name": "macos_arm64_relwithdebinfo" }, { - "name": "linux_gcc_relwithdebinfo_no_ort", "configuration": "RelWithDebInfo", - "configurePreset": "linux_gcc_relwithdebinfo_no_ort" + "configurePreset": "macos_arm64_relwithdebinfo_asan", + "name": "macos_arm64_relwithdebinfo_asan" }, { - "name": "macos_arm64_debug", "configuration": "Debug", - "configurePreset": "macos_arm64_debug" + "configurePreset": "macos_universal2_debug", + "name": "macos_universal2_debug" }, { - "name": "macos_arm64_debug_asan", "configuration": "Debug", - "configurePreset": "macos_arm64_debug_asan" + "configurePreset": "macos_universal2_debug_asan", + "name": "macos_universal2_debug_asan" }, { - "name": "macos_arm64_minsizerel", "configuration": "MinSizeRel", - "configurePreset": "macos_arm64_minsizerel" + "configurePreset": "macos_universal2_minsizerel", + "name": "macos_universal2_minsizerel" }, { - "name": "macos_arm64_minsizerel_asan", "configuration": "MinSizeRel", - "configurePreset": "macos_arm64_minsizerel_asan" + "configurePreset": "macos_universal2_minsizerel_asan", + "name": "macos_universal2_minsizerel_asan" }, { - "name": "macos_arm64_release", "configuration": "Release", - "configurePreset": "macos_arm64_release" + "configurePreset": "macos_universal2_release", + "name": "macos_universal2_release" }, { - "name": "macos_arm64_release_asan", "configuration": "Release", - "configurePreset": "macos_arm64_release_asan" + "configurePreset": "macos_universal2_release_asan", + "name": "macos_universal2_release_asan" }, { - "name": "macos_arm64_relwithdebinfo", "configuration": "RelWithDebInfo", - "configurePreset": "macos_arm64_relwithdebinfo" + "configurePreset": 
"macos_universal2_relwithdebinfo", + "name": "macos_universal2_relwithdebinfo" }, { - "name": "macos_arm64_relwithdebinfo_asan", "configuration": "RelWithDebInfo", - "configurePreset": "macos_arm64_relwithdebinfo_asan" + "configurePreset": "macos_universal2_relwithdebinfo_asan", + "name": "macos_universal2_relwithdebinfo_asan" }, { - "name": "macos_universal2_debug", "configuration": "Debug", - "configurePreset": "macos_universal2_debug" + "configurePreset": "macos_x86_64_debug", + "name": "macos_x86_64_debug" }, { - "name": "macos_universal2_debug_asan", "configuration": "Debug", - "configurePreset": "macos_universal2_debug_asan" + "configurePreset": "macos_x86_64_debug_asan", + "name": "macos_x86_64_debug_asan" }, { - "name": "macos_universal2_minsizerel", "configuration": "MinSizeRel", - "configurePreset": "macos_universal2_minsizerel" + "configurePreset": "macos_x86_64_minsizerel", + "name": "macos_x86_64_minsizerel" }, { - "name": "macos_universal2_minsizerel_asan", "configuration": "MinSizeRel", - "configurePreset": "macos_universal2_minsizerel_asan" + "configurePreset": "macos_x86_64_minsizerel_asan", + "name": "macos_x86_64_minsizerel_asan" }, { - "name": "macos_universal2_release", "configuration": "Release", - "configurePreset": "macos_universal2_release" + "configurePreset": "macos_x86_64_release", + "name": "macos_x86_64_release" }, { - "name": "macos_universal2_release_asan", "configuration": "Release", - "configurePreset": "macos_universal2_release_asan" + "configurePreset": "macos_x86_64_release_asan", + "name": "macos_x86_64_release_asan" }, { - "name": "macos_universal2_relwithdebinfo", "configuration": "RelWithDebInfo", - "configurePreset": "macos_universal2_relwithdebinfo" + "configurePreset": "macos_x86_64_relwithdebinfo", + "name": "macos_x86_64_relwithdebinfo" }, { - "name": "macos_universal2_relwithdebinfo_asan", "configuration": "RelWithDebInfo", - "configurePreset": "macos_universal2_relwithdebinfo_asan" + "configurePreset": 
"macos_x86_64_relwithdebinfo_asan", + "name": "macos_x86_64_relwithdebinfo_asan" + }, + { + "configuration": "Debug", + "configurePreset": "windows_arm64_debug", + "name": "windows_arm64_debug", + "output": { + "outputOnFailure": true + } + }, + { + "configuration": "Debug", + "configurePreset": "windows_arm64_debug_asan", + "name": "windows_arm64_debug_asan", + "output": { + "outputOnFailure": true + } + }, + { + "configuration": "Debug", + "configurePreset": "windows_arm64_debug_asan_no_ort", + "name": "windows_arm64_debug_asan_no_ort", + "output": { + "outputOnFailure": true + } }, { - "name": "macos_x86_64_debug", "configuration": "Debug", - "configurePreset": "macos_x86_64_debug" + "configurePreset": "windows_arm64_debug_no_ort", + "name": "windows_arm64_debug_no_ort", + "output": { + "outputOnFailure": true + } + }, + { + "configuration": "MinSizeRel", + "configurePreset": "windows_arm64_minsizerel", + "name": "windows_arm64_minsizerel", + "output": { + "outputOnFailure": true + } + }, + { + "configuration": "MinSizeRel", + "configurePreset": "windows_arm64_minsizerel_asan", + "name": "windows_arm64_minsizerel_asan", + "output": { + "outputOnFailure": true + } + }, + { + "configuration": "MinSizeRel", + "configurePreset": "windows_arm64_minsizerel_asan_no_ort", + "name": "windows_arm64_minsizerel_asan_no_ort", + "output": { + "outputOnFailure": true + } + }, + { + "configuration": "MinSizeRel", + "configurePreset": "windows_arm64_minsizerel_no_ort", + "name": "windows_arm64_minsizerel_no_ort", + "output": { + "outputOnFailure": true + } + }, + { + "configuration": "Release", + "configurePreset": "windows_arm64_release", + "name": "windows_arm64_release", + "output": { + "outputOnFailure": true + } }, { - "name": "macos_x86_64_debug_asan", - "configuration": "Debug", - "configurePreset": "macos_x86_64_debug_asan" + "configuration": "Release", + "configurePreset": "windows_arm64_release_asan", + "name": "windows_arm64_release_asan", + "output": { + 
"outputOnFailure": true + } }, { - "name": "macos_x86_64_minsizerel", - "configuration": "MinSizeRel", - "configurePreset": "macos_x86_64_minsizerel" + "configuration": "Release", + "configurePreset": "windows_arm64_release_asan_no_ort", + "name": "windows_arm64_release_asan_no_ort", + "output": { + "outputOnFailure": true + } }, { - "name": "macos_x86_64_minsizerel_asan", - "configuration": "MinSizeRel", - "configurePreset": "macos_x86_64_minsizerel_asan" + "configuration": "Release", + "configurePreset": "windows_arm64_release_no_ort", + "name": "windows_arm64_release_no_ort", + "output": { + "outputOnFailure": true + } }, { - "name": "macos_x86_64_release", - "configuration": "Release", - "configurePreset": "macos_x86_64_release" + "configuration": "RelWithDebInfo", + "configurePreset": "windows_arm64_relwithdebinfo", + "name": "windows_arm64_relwithdebinfo", + "output": { + "outputOnFailure": true + } }, { - "name": "macos_x86_64_release_asan", - "configuration": "Release", - "configurePreset": "macos_x86_64_release_asan" + "configuration": "RelWithDebInfo", + "configurePreset": "windows_arm64_relwithdebinfo_asan", + "name": "windows_arm64_relwithdebinfo_asan", + "output": { + "outputOnFailure": true + } }, { - "name": "macos_x86_64_relwithdebinfo", "configuration": "RelWithDebInfo", - "configurePreset": "macos_x86_64_relwithdebinfo" + "configurePreset": "windows_arm64_relwithdebinfo_asan_no_ort", + "name": "windows_arm64_relwithdebinfo_asan_no_ort", + "output": { + "outputOnFailure": true + } }, { - "name": "macos_x86_64_relwithdebinfo_asan", "configuration": "RelWithDebInfo", - "configurePreset": "macos_x86_64_relwithdebinfo_asan" + "configurePreset": "windows_arm64_relwithdebinfo_no_ort", + "name": "windows_arm64_relwithdebinfo_no_ort", + "output": { + "outputOnFailure": true + } }, { - "name": "windows_win32_debug", "configuration": "Debug", "configurePreset": "windows_win32_debug", + "name": "windows_win32_debug", "output": { "outputOnFailure": true } }, { 
- "name": "windows_win32_debug_asan", "configuration": "Debug", "configurePreset": "windows_win32_debug_asan", + "name": "windows_win32_debug_asan", "output": { "outputOnFailure": true } }, { - "name": "windows_win32_debug_asan_no_ort", "configuration": "Debug", "configurePreset": "windows_win32_debug_asan_no_ort", + "name": "windows_win32_debug_asan_no_ort", "output": { "outputOnFailure": true } }, { - "name": "windows_win32_debug_no_ort", "configuration": "Debug", "configurePreset": "windows_win32_debug_no_ort", + "name": "windows_win32_debug_no_ort", "output": { "outputOnFailure": true } }, { - "name": "windows_win32_minsizerel", "configuration": "MinSizeRel", "configurePreset": "windows_win32_minsizerel", + "name": "windows_win32_minsizerel", "output": { "outputOnFailure": true } }, { - "name": "windows_win32_minsizerel_asan", "configuration": "MinSizeRel", "configurePreset": "windows_win32_minsizerel_asan", + "name": "windows_win32_minsizerel_asan", "output": { "outputOnFailure": true } }, { - "name": "windows_win32_minsizerel_asan_no_ort", "configuration": "MinSizeRel", "configurePreset": "windows_win32_minsizerel_asan_no_ort", + "name": "windows_win32_minsizerel_asan_no_ort", "output": { "outputOnFailure": true } }, { - "name": "windows_win32_minsizerel_no_ort", "configuration": "MinSizeRel", "configurePreset": "windows_win32_minsizerel_no_ort", + "name": "windows_win32_minsizerel_no_ort", "output": { "outputOnFailure": true } }, { - "name": "windows_win32_release", "configuration": "Release", "configurePreset": "windows_win32_release", + "name": "windows_win32_release", "output": { "outputOnFailure": true } }, { - "name": "windows_win32_release_asan", "configuration": "Release", "configurePreset": "windows_win32_release_asan", + "name": "windows_win32_release_asan", "output": { "outputOnFailure": true } }, { - "name": "windows_win32_release_asan_no_ort", "configuration": "Release", "configurePreset": "windows_win32_release_asan_no_ort", + "name": 
"windows_win32_release_asan_no_ort", "output": { "outputOnFailure": true } }, { - "name": "windows_win32_release_no_ort", "configuration": "Release", "configurePreset": "windows_win32_release_no_ort", + "name": "windows_win32_release_no_ort", "output": { "outputOnFailure": true } }, { - "name": "windows_win32_relwithdebinfo", "configuration": "RelWithDebInfo", "configurePreset": "windows_win32_relwithdebinfo", + "name": "windows_win32_relwithdebinfo", "output": { "outputOnFailure": true } }, { - "name": "windows_win32_relwithdebinfo_asan", "configuration": "RelWithDebInfo", "configurePreset": "windows_win32_relwithdebinfo_asan", + "name": "windows_win32_relwithdebinfo_asan", "output": { "outputOnFailure": true } }, { - "name": "windows_win32_relwithdebinfo_asan_no_ort", "configuration": "RelWithDebInfo", "configurePreset": "windows_win32_relwithdebinfo_asan_no_ort", + "name": "windows_win32_relwithdebinfo_asan_no_ort", "output": { "outputOnFailure": true } }, { - "name": "windows_win32_relwithdebinfo_no_ort", "configuration": "RelWithDebInfo", "configurePreset": "windows_win32_relwithdebinfo_no_ort", + "name": "windows_win32_relwithdebinfo_no_ort", "output": { "outputOnFailure": true } }, { - "name": "windows_x64_debug", "configuration": "Debug", "configurePreset": "windows_x64_debug", + "name": "windows_x64_debug", "output": { "outputOnFailure": true } }, { - "name": "windows_x64_debug_asan", "configuration": "Debug", "configurePreset": "windows_x64_debug_asan", + "name": "windows_x64_debug_asan", "output": { "outputOnFailure": true } }, { - "name": "windows_x64_debug_asan_no_ort", "configuration": "Debug", "configurePreset": "windows_x64_debug_asan_no_ort", + "name": "windows_x64_debug_asan_no_ort", "output": { "outputOnFailure": true } }, { - "name": "windows_x64_debug_no_ort", "configuration": "Debug", "configurePreset": "windows_x64_debug_no_ort", + "name": "windows_x64_debug_no_ort", "output": { "outputOnFailure": true } }, { - "name": 
"windows_x64_minsizerel", "configuration": "MinSizeRel", "configurePreset": "windows_x64_minsizerel", + "name": "windows_x64_minsizerel", "output": { "outputOnFailure": true } }, { - "name": "windows_x64_minsizerel_asan", "configuration": "MinSizeRel", "configurePreset": "windows_x64_minsizerel_asan", + "name": "windows_x64_minsizerel_asan", "output": { "outputOnFailure": true } }, { - "name": "windows_x64_minsizerel_asan_no_ort", "configuration": "MinSizeRel", "configurePreset": "windows_x64_minsizerel_asan_no_ort", + "name": "windows_x64_minsizerel_asan_no_ort", "output": { "outputOnFailure": true } }, { - "name": "windows_x64_minsizerel_no_ort", "configuration": "MinSizeRel", "configurePreset": "windows_x64_minsizerel_no_ort", + "name": "windows_x64_minsizerel_no_ort", "output": { "outputOnFailure": true } }, { - "name": "windows_x64_release", "configuration": "Release", "configurePreset": "windows_x64_release", + "name": "windows_x64_release", "output": { "outputOnFailure": true } }, { - "name": "windows_x64_release_asan", "configuration": "Release", "configurePreset": "windows_x64_release_asan", + "name": "windows_x64_release_asan", "output": { "outputOnFailure": true } }, { - "name": "windows_x64_release_asan_no_ort", "configuration": "Release", "configurePreset": "windows_x64_release_asan_no_ort", + "name": "windows_x64_release_asan_no_ort", "output": { "outputOnFailure": true } }, { - "name": "windows_x64_release_no_ort", "configuration": "Release", "configurePreset": "windows_x64_release_no_ort", + "name": "windows_x64_release_no_ort", "output": { "outputOnFailure": true } }, { - "name": "windows_x64_relwithdebinfo", "configuration": "RelWithDebInfo", "configurePreset": "windows_x64_relwithdebinfo", + "name": "windows_x64_relwithdebinfo", "output": { "outputOnFailure": true } }, { - "name": "windows_x64_relwithdebinfo_asan", "configuration": "RelWithDebInfo", "configurePreset": "windows_x64_relwithdebinfo_asan", + "name": "windows_x64_relwithdebinfo_asan", 
"output": { "outputOnFailure": true } }, { - "name": "windows_x64_relwithdebinfo_asan_no_ort", "configuration": "RelWithDebInfo", "configurePreset": "windows_x64_relwithdebinfo_asan_no_ort", + "name": "windows_x64_relwithdebinfo_asan_no_ort", "output": { "outputOnFailure": true } }, { - "name": "windows_x64_relwithdebinfo_no_ort", "configuration": "RelWithDebInfo", "configurePreset": "windows_x64_relwithdebinfo_no_ort", + "name": "windows_x64_relwithdebinfo_no_ort", "output": { "outputOnFailure": true } } ], + "version": 8, "workflowPresets": [ { "name": "linux_clang_debug_asan_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "linux_clang_debug_asan_no_ort" + "name": "linux_clang_debug_asan_no_ort", + "type": "configure" }, { - "type": "build", - "name": "linux_clang_debug_asan_no_ort" + "name": "linux_clang_debug_asan_no_ort", + "type": "build" }, { - "type": "test", - "name": "linux_clang_debug_asan_no_ort" + "name": "linux_clang_debug_asan_no_ort", + "type": "test" } ] }, @@ -2540,16 +3061,16 @@ "name": "linux_clang_debug_asan_workflow", "steps": [ { - "type": "configure", - "name": "linux_clang_debug_asan" + "name": "linux_clang_debug_asan", + "type": "configure" }, { - "type": "build", - "name": "linux_clang_debug_asan" + "name": "linux_clang_debug_asan", + "type": "build" }, { - "type": "test", - "name": "linux_clang_debug_asan" + "name": "linux_clang_debug_asan", + "type": "test" } ] }, @@ -2557,16 +3078,16 @@ "name": "linux_clang_debug_cov_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "linux_clang_debug_cov_no_ort" + "name": "linux_clang_debug_cov_no_ort", + "type": "configure" }, { - "type": "build", - "name": "linux_clang_debug_cov_no_ort" + "name": "linux_clang_debug_cov_no_ort", + "type": "build" }, { - "type": "test", - "name": "linux_clang_debug_cov_no_ort" + "name": "linux_clang_debug_cov_no_ort", + "type": "test" } ] }, @@ -2574,16 +3095,16 @@ "name": "linux_clang_debug_cov_workflow", "steps": [ { - "type": 
"configure", - "name": "linux_clang_debug_cov" + "name": "linux_clang_debug_cov", + "type": "configure" }, { - "type": "build", - "name": "linux_clang_debug_cov" + "name": "linux_clang_debug_cov", + "type": "build" }, { - "type": "test", - "name": "linux_clang_debug_cov" + "name": "linux_clang_debug_cov", + "type": "test" } ] }, @@ -2591,16 +3112,16 @@ "name": "linux_clang_debug_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "linux_clang_debug_no_ort" + "name": "linux_clang_debug_no_ort", + "type": "configure" }, { - "type": "build", - "name": "linux_clang_debug_no_ort" + "name": "linux_clang_debug_no_ort", + "type": "build" }, { - "type": "test", - "name": "linux_clang_debug_no_ort" + "name": "linux_clang_debug_no_ort", + "type": "test" } ] }, @@ -2608,16 +3129,16 @@ "name": "linux_clang_debug_workflow", "steps": [ { - "type": "configure", - "name": "linux_clang_debug" + "name": "linux_clang_debug", + "type": "configure" }, { - "type": "build", - "name": "linux_clang_debug" + "name": "linux_clang_debug", + "type": "build" }, { - "type": "test", - "name": "linux_clang_debug" + "name": "linux_clang_debug", + "type": "test" } ] }, @@ -2625,16 +3146,16 @@ "name": "linux_gcc_debug_asan_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "linux_gcc_debug_asan_no_ort" + "name": "linux_gcc_debug_asan_no_ort", + "type": "configure" }, { - "type": "build", - "name": "linux_gcc_debug_asan_no_ort" + "name": "linux_gcc_debug_asan_no_ort", + "type": "build" }, { - "type": "test", - "name": "linux_gcc_debug_asan_no_ort" + "name": "linux_gcc_debug_asan_no_ort", + "type": "test" } ] }, @@ -2642,16 +3163,16 @@ "name": "linux_gcc_debug_asan_workflow", "steps": [ { - "type": "configure", - "name": "linux_gcc_debug_asan" + "name": "linux_gcc_debug_asan", + "type": "configure" }, { - "type": "build", - "name": "linux_gcc_debug_asan" + "name": "linux_gcc_debug_asan", + "type": "build" }, { - "type": "test", - "name": "linux_gcc_debug_asan" + "name": 
"linux_gcc_debug_asan", + "type": "test" } ] }, @@ -2659,16 +3180,16 @@ "name": "linux_gcc_debug_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "linux_gcc_debug_no_ort" + "name": "linux_gcc_debug_no_ort", + "type": "configure" }, { - "type": "build", - "name": "linux_gcc_debug_no_ort" + "name": "linux_gcc_debug_no_ort", + "type": "build" }, { - "type": "test", - "name": "linux_gcc_debug_no_ort" + "name": "linux_gcc_debug_no_ort", + "type": "test" } ] }, @@ -2676,16 +3197,16 @@ "name": "linux_gcc_debug_workflow", "steps": [ { - "type": "configure", - "name": "linux_gcc_debug" + "name": "linux_gcc_debug", + "type": "configure" }, { - "type": "build", - "name": "linux_gcc_debug" + "name": "linux_gcc_debug", + "type": "build" }, { - "type": "test", - "name": "linux_gcc_debug" + "name": "linux_gcc_debug", + "type": "test" } ] }, @@ -2693,16 +3214,16 @@ "name": "linux_gcc_minsizerel_asan_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "linux_gcc_minsizerel_asan_no_ort" + "name": "linux_gcc_minsizerel_asan_no_ort", + "type": "configure" }, { - "type": "build", - "name": "linux_gcc_minsizerel_asan_no_ort" + "name": "linux_gcc_minsizerel_asan_no_ort", + "type": "build" }, { - "type": "test", - "name": "linux_gcc_minsizerel_asan_no_ort" + "name": "linux_gcc_minsizerel_asan_no_ort", + "type": "test" } ] }, @@ -2710,16 +3231,16 @@ "name": "linux_gcc_minsizerel_asan_workflow", "steps": [ { - "type": "configure", - "name": "linux_gcc_minsizerel_asan" + "name": "linux_gcc_minsizerel_asan", + "type": "configure" }, { - "type": "build", - "name": "linux_gcc_minsizerel_asan" + "name": "linux_gcc_minsizerel_asan", + "type": "build" }, { - "type": "test", - "name": "linux_gcc_minsizerel_asan" + "name": "linux_gcc_minsizerel_asan", + "type": "test" } ] }, @@ -2727,16 +3248,16 @@ "name": "linux_gcc_minsizerel_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "linux_gcc_minsizerel_no_ort" + "name": "linux_gcc_minsizerel_no_ort", + "type": 
"configure" }, { - "type": "build", - "name": "linux_gcc_minsizerel_no_ort" + "name": "linux_gcc_minsizerel_no_ort", + "type": "build" }, { - "type": "test", - "name": "linux_gcc_minsizerel_no_ort" + "name": "linux_gcc_minsizerel_no_ort", + "type": "test" } ] }, @@ -2744,16 +3265,16 @@ "name": "linux_gcc_minsizerel_workflow", "steps": [ { - "type": "configure", - "name": "linux_gcc_minsizerel" + "name": "linux_gcc_minsizerel", + "type": "configure" }, { - "type": "build", - "name": "linux_gcc_minsizerel" + "name": "linux_gcc_minsizerel", + "type": "build" }, { - "type": "test", - "name": "linux_gcc_minsizerel" + "name": "linux_gcc_minsizerel", + "type": "test" } ] }, @@ -2761,16 +3282,16 @@ "name": "linux_gcc_release_asan_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "linux_gcc_release_asan_no_ort" + "name": "linux_gcc_release_asan_no_ort", + "type": "configure" }, { - "type": "build", - "name": "linux_gcc_release_asan_no_ort" + "name": "linux_gcc_release_asan_no_ort", + "type": "build" }, { - "type": "test", - "name": "linux_gcc_release_asan_no_ort" + "name": "linux_gcc_release_asan_no_ort", + "type": "test" } ] }, @@ -2778,16 +3299,16 @@ "name": "linux_gcc_release_asan_workflow", "steps": [ { - "type": "configure", - "name": "linux_gcc_release_asan" + "name": "linux_gcc_release_asan", + "type": "configure" }, { - "type": "build", - "name": "linux_gcc_release_asan" + "name": "linux_gcc_release_asan", + "type": "build" }, { - "type": "test", - "name": "linux_gcc_release_asan" + "name": "linux_gcc_release_asan", + "type": "test" } ] }, @@ -2795,16 +3316,16 @@ "name": "linux_gcc_release_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "linux_gcc_release_no_ort" + "name": "linux_gcc_release_no_ort", + "type": "configure" }, { - "type": "build", - "name": "linux_gcc_release_no_ort" + "name": "linux_gcc_release_no_ort", + "type": "build" }, { - "type": "test", - "name": "linux_gcc_release_no_ort" + "name": "linux_gcc_release_no_ort", + 
"type": "test" } ] }, @@ -2812,16 +3333,16 @@ "name": "linux_gcc_release_workflow", "steps": [ { - "type": "configure", - "name": "linux_gcc_release" + "name": "linux_gcc_release", + "type": "configure" }, { - "type": "build", - "name": "linux_gcc_release" + "name": "linux_gcc_release", + "type": "build" }, { - "type": "test", - "name": "linux_gcc_release" + "name": "linux_gcc_release", + "type": "test" } ] }, @@ -2829,16 +3350,16 @@ "name": "linux_gcc_relwithdebinfo_asan_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "linux_gcc_relwithdebinfo_asan_no_ort" + "name": "linux_gcc_relwithdebinfo_asan_no_ort", + "type": "configure" }, { - "type": "build", - "name": "linux_gcc_relwithdebinfo_asan_no_ort" + "name": "linux_gcc_relwithdebinfo_asan_no_ort", + "type": "build" }, { - "type": "test", - "name": "linux_gcc_relwithdebinfo_asan_no_ort" + "name": "linux_gcc_relwithdebinfo_asan_no_ort", + "type": "test" } ] }, @@ -2846,16 +3367,16 @@ "name": "linux_gcc_relwithdebinfo_asan_workflow", "steps": [ { - "type": "configure", - "name": "linux_gcc_relwithdebinfo_asan" + "name": "linux_gcc_relwithdebinfo_asan", + "type": "configure" }, { - "type": "build", - "name": "linux_gcc_relwithdebinfo_asan" + "name": "linux_gcc_relwithdebinfo_asan", + "type": "build" }, { - "type": "test", - "name": "linux_gcc_relwithdebinfo_asan" + "name": "linux_gcc_relwithdebinfo_asan", + "type": "test" } ] }, @@ -2863,16 +3384,16 @@ "name": "linux_gcc_relwithdebinfo_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "linux_gcc_relwithdebinfo_no_ort" + "name": "linux_gcc_relwithdebinfo_no_ort", + "type": "configure" }, { - "type": "build", - "name": "linux_gcc_relwithdebinfo_no_ort" + "name": "linux_gcc_relwithdebinfo_no_ort", + "type": "build" }, { - "type": "test", - "name": "linux_gcc_relwithdebinfo_no_ort" + "name": "linux_gcc_relwithdebinfo_no_ort", + "type": "test" } ] }, @@ -2880,16 +3401,16 @@ "name": "linux_gcc_relwithdebinfo_workflow", "steps": [ { - "type": 
"configure", - "name": "linux_gcc_relwithdebinfo" + "name": "linux_gcc_relwithdebinfo", + "type": "configure" }, { - "type": "build", - "name": "linux_gcc_relwithdebinfo" + "name": "linux_gcc_relwithdebinfo", + "type": "build" }, { - "type": "test", - "name": "linux_gcc_relwithdebinfo" + "name": "linux_gcc_relwithdebinfo", + "type": "test" } ] }, @@ -2897,16 +3418,16 @@ "name": "macos_arm64_debug_asan_workflow", "steps": [ { - "type": "configure", - "name": "macos_arm64_debug_asan" + "name": "macos_arm64_debug_asan", + "type": "configure" }, { - "type": "build", - "name": "macos_arm64_debug_asan" + "name": "macos_arm64_debug_asan", + "type": "build" }, { - "type": "test", - "name": "macos_arm64_debug_asan" + "name": "macos_arm64_debug_asan", + "type": "test" } ] }, @@ -2914,16 +3435,16 @@ "name": "macos_arm64_debug_workflow", "steps": [ { - "type": "configure", - "name": "macos_arm64_debug" + "name": "macos_arm64_debug", + "type": "configure" }, { - "type": "build", - "name": "macos_arm64_debug" + "name": "macos_arm64_debug", + "type": "build" }, { - "type": "test", - "name": "macos_arm64_debug" + "name": "macos_arm64_debug", + "type": "test" } ] }, @@ -2931,16 +3452,16 @@ "name": "macos_arm64_minsizerel_asan_workflow", "steps": [ { - "type": "configure", - "name": "macos_arm64_minsizerel_asan" + "name": "macos_arm64_minsizerel_asan", + "type": "configure" }, { - "type": "build", - "name": "macos_arm64_minsizerel_asan" + "name": "macos_arm64_minsizerel_asan", + "type": "build" }, { - "type": "test", - "name": "macos_arm64_minsizerel_asan" + "name": "macos_arm64_minsizerel_asan", + "type": "test" } ] }, @@ -2948,16 +3469,16 @@ "name": "macos_arm64_minsizerel_workflow", "steps": [ { - "type": "configure", - "name": "macos_arm64_minsizerel" + "name": "macos_arm64_minsizerel", + "type": "configure" }, { - "type": "build", - "name": "macos_arm64_minsizerel" + "name": "macos_arm64_minsizerel", + "type": "build" }, { - "type": "test", - "name": "macos_arm64_minsizerel" + 
"name": "macos_arm64_minsizerel", + "type": "test" } ] }, @@ -2965,16 +3486,16 @@ "name": "macos_arm64_release_asan_workflow", "steps": [ { - "type": "configure", - "name": "macos_arm64_release_asan" + "name": "macos_arm64_release_asan", + "type": "configure" }, { - "type": "build", - "name": "macos_arm64_release_asan" + "name": "macos_arm64_release_asan", + "type": "build" }, { - "type": "test", - "name": "macos_arm64_release_asan" + "name": "macos_arm64_release_asan", + "type": "test" } ] }, @@ -2982,16 +3503,16 @@ "name": "macos_arm64_release_workflow", "steps": [ { - "type": "configure", - "name": "macos_arm64_release" + "name": "macos_arm64_release", + "type": "configure" }, { - "type": "build", - "name": "macos_arm64_release" + "name": "macos_arm64_release", + "type": "build" }, { - "type": "test", - "name": "macos_arm64_release" + "name": "macos_arm64_release", + "type": "test" } ] }, @@ -2999,16 +3520,16 @@ "name": "macos_arm64_relwithdebinfo_asan_workflow", "steps": [ { - "type": "configure", - "name": "macos_arm64_relwithdebinfo_asan" + "name": "macos_arm64_relwithdebinfo_asan", + "type": "configure" }, { - "type": "build", - "name": "macos_arm64_relwithdebinfo_asan" + "name": "macos_arm64_relwithdebinfo_asan", + "type": "build" }, { - "type": "test", - "name": "macos_arm64_relwithdebinfo_asan" + "name": "macos_arm64_relwithdebinfo_asan", + "type": "test" } ] }, @@ -3016,16 +3537,16 @@ "name": "macos_arm64_relwithdebinfo_workflow", "steps": [ { - "type": "configure", - "name": "macos_arm64_relwithdebinfo" + "name": "macos_arm64_relwithdebinfo", + "type": "configure" }, { - "type": "build", - "name": "macos_arm64_relwithdebinfo" + "name": "macos_arm64_relwithdebinfo", + "type": "build" }, { - "type": "test", - "name": "macos_arm64_relwithdebinfo" + "name": "macos_arm64_relwithdebinfo", + "type": "test" } ] }, @@ -3033,16 +3554,16 @@ "name": "macos_universal2_debug_asan_workflow", "steps": [ { - "type": "configure", - "name": "macos_universal2_debug_asan" + 
"name": "macos_universal2_debug_asan", + "type": "configure" }, { - "type": "build", - "name": "macos_universal2_debug_asan" + "name": "macos_universal2_debug_asan", + "type": "build" }, { - "type": "test", - "name": "macos_universal2_debug_asan" + "name": "macos_universal2_debug_asan", + "type": "test" } ] }, @@ -3050,16 +3571,16 @@ "name": "macos_universal2_debug_workflow", "steps": [ { - "type": "configure", - "name": "macos_universal2_debug" + "name": "macos_universal2_debug", + "type": "configure" }, { - "type": "build", - "name": "macos_universal2_debug" + "name": "macos_universal2_debug", + "type": "build" }, { - "type": "test", - "name": "macos_universal2_debug" + "name": "macos_universal2_debug", + "type": "test" } ] }, @@ -3067,16 +3588,16 @@ "name": "macos_universal2_minsizerel_asan_workflow", "steps": [ { - "type": "configure", - "name": "macos_universal2_minsizerel_asan" + "name": "macos_universal2_minsizerel_asan", + "type": "configure" }, { - "type": "build", - "name": "macos_universal2_minsizerel_asan" + "name": "macos_universal2_minsizerel_asan", + "type": "build" }, { - "type": "test", - "name": "macos_universal2_minsizerel_asan" + "name": "macos_universal2_minsizerel_asan", + "type": "test" } ] }, @@ -3084,16 +3605,16 @@ "name": "macos_universal2_minsizerel_workflow", "steps": [ { - "type": "configure", - "name": "macos_universal2_minsizerel" + "name": "macos_universal2_minsizerel", + "type": "configure" }, { - "type": "build", - "name": "macos_universal2_minsizerel" + "name": "macos_universal2_minsizerel", + "type": "build" }, { - "type": "test", - "name": "macos_universal2_minsizerel" + "name": "macos_universal2_minsizerel", + "type": "test" } ] }, @@ -3101,16 +3622,16 @@ "name": "macos_universal2_release_asan_workflow", "steps": [ { - "type": "configure", - "name": "macos_universal2_release_asan" + "name": "macos_universal2_release_asan", + "type": "configure" }, { - "type": "build", - "name": "macos_universal2_release_asan" + "name": 
"macos_universal2_release_asan", + "type": "build" }, { - "type": "test", - "name": "macos_universal2_release_asan" + "name": "macos_universal2_release_asan", + "type": "test" } ] }, @@ -3118,16 +3639,16 @@ "name": "macos_universal2_release_workflow", "steps": [ { - "type": "configure", - "name": "macos_universal2_release" + "name": "macos_universal2_release", + "type": "configure" }, { - "type": "build", - "name": "macos_universal2_release" + "name": "macos_universal2_release", + "type": "build" }, { - "type": "test", - "name": "macos_universal2_release" + "name": "macos_universal2_release", + "type": "test" } ] }, @@ -3135,16 +3656,16 @@ "name": "macos_universal2_relwithdebinfo_asan_workflow", "steps": [ { - "type": "configure", - "name": "macos_universal2_relwithdebinfo_asan" + "name": "macos_universal2_relwithdebinfo_asan", + "type": "configure" }, { - "type": "build", - "name": "macos_universal2_relwithdebinfo_asan" + "name": "macos_universal2_relwithdebinfo_asan", + "type": "build" }, { - "type": "test", - "name": "macos_universal2_relwithdebinfo_asan" + "name": "macos_universal2_relwithdebinfo_asan", + "type": "test" } ] }, @@ -3152,16 +3673,16 @@ "name": "macos_universal2_relwithdebinfo_workflow", "steps": [ { - "type": "configure", - "name": "macos_universal2_relwithdebinfo" + "name": "macos_universal2_relwithdebinfo", + "type": "configure" }, { - "type": "build", - "name": "macos_universal2_relwithdebinfo" + "name": "macos_universal2_relwithdebinfo", + "type": "build" }, { - "type": "test", - "name": "macos_universal2_relwithdebinfo" + "name": "macos_universal2_relwithdebinfo", + "type": "test" } ] }, @@ -3169,16 +3690,16 @@ "name": "macos_x86_64_debug_asan_workflow", "steps": [ { - "type": "configure", - "name": "macos_x86_64_debug_asan" + "name": "macos_x86_64_debug_asan", + "type": "configure" }, { - "type": "build", - "name": "macos_x86_64_debug_asan" + "name": "macos_x86_64_debug_asan", + "type": "build" }, { - "type": "test", - "name": 
"macos_x86_64_debug_asan" + "name": "macos_x86_64_debug_asan", + "type": "test" } ] }, @@ -3186,16 +3707,16 @@ "name": "macos_x86_64_debug_workflow", "steps": [ { - "type": "configure", - "name": "macos_x86_64_debug" + "name": "macos_x86_64_debug", + "type": "configure" }, { - "type": "build", - "name": "macos_x86_64_debug" + "name": "macos_x86_64_debug", + "type": "build" }, { - "type": "test", - "name": "macos_x86_64_debug" + "name": "macos_x86_64_debug", + "type": "test" } ] }, @@ -3203,16 +3724,16 @@ "name": "macos_x86_64_minsizerel_asan_workflow", "steps": [ { - "type": "configure", - "name": "macos_x86_64_minsizerel_asan" + "name": "macos_x86_64_minsizerel_asan", + "type": "configure" }, { - "type": "build", - "name": "macos_x86_64_minsizerel_asan" + "name": "macos_x86_64_minsizerel_asan", + "type": "build" }, { - "type": "test", - "name": "macos_x86_64_minsizerel_asan" + "name": "macos_x86_64_minsizerel_asan", + "type": "test" } ] }, @@ -3220,16 +3741,16 @@ "name": "macos_x86_64_minsizerel_workflow", "steps": [ { - "type": "configure", - "name": "macos_x86_64_minsizerel" + "name": "macos_x86_64_minsizerel", + "type": "configure" }, { - "type": "build", - "name": "macos_x86_64_minsizerel" + "name": "macos_x86_64_minsizerel", + "type": "build" }, { - "type": "test", - "name": "macos_x86_64_minsizerel" + "name": "macos_x86_64_minsizerel", + "type": "test" } ] }, @@ -3237,16 +3758,16 @@ "name": "macos_x86_64_release_asan_workflow", "steps": [ { - "type": "configure", - "name": "macos_x86_64_release_asan" + "name": "macos_x86_64_release_asan", + "type": "configure" }, { - "type": "build", - "name": "macos_x86_64_release_asan" + "name": "macos_x86_64_release_asan", + "type": "build" }, { - "type": "test", - "name": "macos_x86_64_release_asan" + "name": "macos_x86_64_release_asan", + "type": "test" } ] }, @@ -3254,16 +3775,16 @@ "name": "macos_x86_64_release_workflow", "steps": [ { - "type": "configure", - "name": "macos_x86_64_release" + "name": 
"macos_x86_64_release", + "type": "configure" }, { - "type": "build", - "name": "macos_x86_64_release" + "name": "macos_x86_64_release", + "type": "build" }, { - "type": "test", - "name": "macos_x86_64_release" + "name": "macos_x86_64_release", + "type": "test" } ] }, @@ -3271,16 +3792,16 @@ "name": "macos_x86_64_relwithdebinfo_asan_workflow", "steps": [ { - "type": "configure", - "name": "macos_x86_64_relwithdebinfo_asan" + "name": "macos_x86_64_relwithdebinfo_asan", + "type": "configure" }, { - "type": "build", - "name": "macos_x86_64_relwithdebinfo_asan" + "name": "macos_x86_64_relwithdebinfo_asan", + "type": "build" }, { - "type": "test", - "name": "macos_x86_64_relwithdebinfo_asan" + "name": "macos_x86_64_relwithdebinfo_asan", + "type": "test" } ] }, @@ -3288,16 +3809,288 @@ "name": "macos_x86_64_relwithdebinfo_workflow", "steps": [ { - "type": "configure", - "name": "macos_x86_64_relwithdebinfo" + "name": "macos_x86_64_relwithdebinfo", + "type": "configure" + }, + { + "name": "macos_x86_64_relwithdebinfo", + "type": "build" + }, + { + "name": "macos_x86_64_relwithdebinfo", + "type": "test" + } + ] + }, + { + "name": "windows_arm64_debug_asan_no_ort_workflow", + "steps": [ + { + "name": "windows_arm64_debug_asan_no_ort", + "type": "configure" + }, + { + "name": "windows_arm64_debug_asan_no_ort", + "type": "build" + }, + { + "name": "windows_arm64_debug_asan_no_ort", + "type": "test" + } + ] + }, + { + "name": "windows_arm64_debug_asan_workflow", + "steps": [ + { + "name": "windows_arm64_debug_asan", + "type": "configure" + }, + { + "name": "windows_arm64_debug_asan", + "type": "build" + }, + { + "name": "windows_arm64_debug_asan", + "type": "test" + } + ] + }, + { + "name": "windows_arm64_debug_no_ort_workflow", + "steps": [ + { + "name": "windows_arm64_debug_no_ort", + "type": "configure" + }, + { + "name": "windows_arm64_debug_no_ort", + "type": "build" + }, + { + "name": "windows_arm64_debug_no_ort", + "type": "test" + } + ] + }, + { + "name": 
"windows_arm64_debug_workflow", + "steps": [ + { + "name": "windows_arm64_debug", + "type": "configure" + }, + { + "name": "windows_arm64_debug", + "type": "build" + }, + { + "name": "windows_arm64_debug", + "type": "test" + } + ] + }, + { + "name": "windows_arm64_minsizerel_asan_no_ort_workflow", + "steps": [ + { + "name": "windows_arm64_minsizerel_asan_no_ort", + "type": "configure" + }, + { + "name": "windows_arm64_minsizerel_asan_no_ort", + "type": "build" + }, + { + "name": "windows_arm64_minsizerel_asan_no_ort", + "type": "test" + } + ] + }, + { + "name": "windows_arm64_minsizerel_asan_workflow", + "steps": [ + { + "name": "windows_arm64_minsizerel_asan", + "type": "configure" + }, + { + "name": "windows_arm64_minsizerel_asan", + "type": "build" + }, + { + "name": "windows_arm64_minsizerel_asan", + "type": "test" + } + ] + }, + { + "name": "windows_arm64_minsizerel_no_ort_workflow", + "steps": [ + { + "name": "windows_arm64_minsizerel_no_ort", + "type": "configure" + }, + { + "name": "windows_arm64_minsizerel_no_ort", + "type": "build" + }, + { + "name": "windows_arm64_minsizerel_no_ort", + "type": "test" + } + ] + }, + { + "name": "windows_arm64_minsizerel_workflow", + "steps": [ + { + "name": "windows_arm64_minsizerel", + "type": "configure" + }, + { + "name": "windows_arm64_minsizerel", + "type": "build" + }, + { + "name": "windows_arm64_minsizerel", + "type": "test" + } + ] + }, + { + "name": "windows_arm64_release_asan_no_ort_workflow", + "steps": [ + { + "name": "windows_arm64_release_asan_no_ort", + "type": "configure" + }, + { + "name": "windows_arm64_release_asan_no_ort", + "type": "build" + }, + { + "name": "windows_arm64_release_asan_no_ort", + "type": "test" + } + ] + }, + { + "name": "windows_arm64_release_asan_workflow", + "steps": [ + { + "name": "windows_arm64_release_asan", + "type": "configure" + }, + { + "name": "windows_arm64_release_asan", + "type": "build" + }, + { + "name": "windows_arm64_release_asan", + "type": "test" + } + ] + }, + { 
+ "name": "windows_arm64_release_no_ort_workflow", + "steps": [ + { + "name": "windows_arm64_release_no_ort", + "type": "configure" + }, + { + "name": "windows_arm64_release_no_ort", + "type": "build" + }, + { + "name": "windows_arm64_release_no_ort", + "type": "test" + } + ] + }, + { + "name": "windows_arm64_release_workflow", + "steps": [ + { + "name": "windows_arm64_release", + "type": "configure" + }, + { + "name": "windows_arm64_release", + "type": "build" + }, + { + "name": "windows_arm64_release", + "type": "test" + } + ] + }, + { + "name": "windows_arm64_relwithdebinfo_asan_no_ort_workflow", + "steps": [ + { + "name": "windows_arm64_relwithdebinfo_asan_no_ort", + "type": "configure" + }, + { + "name": "windows_arm64_relwithdebinfo_asan_no_ort", + "type": "build" + }, + { + "name": "windows_arm64_relwithdebinfo_asan_no_ort", + "type": "test" + } + ] + }, + { + "name": "windows_arm64_relwithdebinfo_asan_workflow", + "steps": [ + { + "name": "windows_arm64_relwithdebinfo_asan", + "type": "configure" + }, + { + "name": "windows_arm64_relwithdebinfo_asan", + "type": "build" + }, + { + "name": "windows_arm64_relwithdebinfo_asan", + "type": "test" + } + ] + }, + { + "name": "windows_arm64_relwithdebinfo_no_ort_workflow", + "steps": [ + { + "name": "windows_arm64_relwithdebinfo_no_ort", + "type": "configure" + }, + { + "name": "windows_arm64_relwithdebinfo_no_ort", + "type": "build" + }, + { + "name": "windows_arm64_relwithdebinfo_no_ort", + "type": "test" + } + ] + }, + { + "name": "windows_arm64_relwithdebinfo_workflow", + "steps": [ + { + "name": "windows_arm64_relwithdebinfo", + "type": "configure" }, { - "type": "build", - "name": "macos_x86_64_relwithdebinfo" + "name": "windows_arm64_relwithdebinfo", + "type": "build" }, { - "type": "test", - "name": "macos_x86_64_relwithdebinfo" + "name": "windows_arm64_relwithdebinfo", + "type": "test" } ] }, @@ -3305,16 +4098,16 @@ "name": "windows_win32_debug_asan_no_ort_workflow", "steps": [ { - "type": "configure", - 
"name": "windows_win32_debug_asan_no_ort" + "name": "windows_win32_debug_asan_no_ort", + "type": "configure" }, { - "type": "build", - "name": "windows_win32_debug_asan_no_ort" + "name": "windows_win32_debug_asan_no_ort", + "type": "build" }, { - "type": "test", - "name": "windows_win32_debug_asan_no_ort" + "name": "windows_win32_debug_asan_no_ort", + "type": "test" } ] }, @@ -3322,16 +4115,16 @@ "name": "windows_win32_debug_asan_workflow", "steps": [ { - "type": "configure", - "name": "windows_win32_debug_asan" + "name": "windows_win32_debug_asan", + "type": "configure" }, { - "type": "build", - "name": "windows_win32_debug_asan" + "name": "windows_win32_debug_asan", + "type": "build" }, { - "type": "test", - "name": "windows_win32_debug_asan" + "name": "windows_win32_debug_asan", + "type": "test" } ] }, @@ -3339,16 +4132,16 @@ "name": "windows_win32_debug_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "windows_win32_debug_no_ort" + "name": "windows_win32_debug_no_ort", + "type": "configure" }, { - "type": "build", - "name": "windows_win32_debug_no_ort" + "name": "windows_win32_debug_no_ort", + "type": "build" }, { - "type": "test", - "name": "windows_win32_debug_no_ort" + "name": "windows_win32_debug_no_ort", + "type": "test" } ] }, @@ -3356,16 +4149,16 @@ "name": "windows_win32_debug_workflow", "steps": [ { - "type": "configure", - "name": "windows_win32_debug" + "name": "windows_win32_debug", + "type": "configure" }, { - "type": "build", - "name": "windows_win32_debug" + "name": "windows_win32_debug", + "type": "build" }, { - "type": "test", - "name": "windows_win32_debug" + "name": "windows_win32_debug", + "type": "test" } ] }, @@ -3373,16 +4166,16 @@ "name": "windows_win32_minsizerel_asan_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "windows_win32_minsizerel_asan_no_ort" + "name": "windows_win32_minsizerel_asan_no_ort", + "type": "configure" }, { - "type": "build", - "name": "windows_win32_minsizerel_asan_no_ort" + "name": 
"windows_win32_minsizerel_asan_no_ort", + "type": "build" }, { - "type": "test", - "name": "windows_win32_minsizerel_asan_no_ort" + "name": "windows_win32_minsizerel_asan_no_ort", + "type": "test" } ] }, @@ -3390,16 +4183,16 @@ "name": "windows_win32_minsizerel_asan_workflow", "steps": [ { - "type": "configure", - "name": "windows_win32_minsizerel_asan" + "name": "windows_win32_minsizerel_asan", + "type": "configure" }, { - "type": "build", - "name": "windows_win32_minsizerel_asan" + "name": "windows_win32_minsizerel_asan", + "type": "build" }, { - "type": "test", - "name": "windows_win32_minsizerel_asan" + "name": "windows_win32_minsizerel_asan", + "type": "test" } ] }, @@ -3407,16 +4200,16 @@ "name": "windows_win32_minsizerel_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "windows_win32_minsizerel_no_ort" + "name": "windows_win32_minsizerel_no_ort", + "type": "configure" }, { - "type": "build", - "name": "windows_win32_minsizerel_no_ort" + "name": "windows_win32_minsizerel_no_ort", + "type": "build" }, { - "type": "test", - "name": "windows_win32_minsizerel_no_ort" + "name": "windows_win32_minsizerel_no_ort", + "type": "test" } ] }, @@ -3424,16 +4217,16 @@ "name": "windows_win32_minsizerel_workflow", "steps": [ { - "type": "configure", - "name": "windows_win32_minsizerel" + "name": "windows_win32_minsizerel", + "type": "configure" }, { - "type": "build", - "name": "windows_win32_minsizerel" + "name": "windows_win32_minsizerel", + "type": "build" }, { - "type": "test", - "name": "windows_win32_minsizerel" + "name": "windows_win32_minsizerel", + "type": "test" } ] }, @@ -3441,16 +4234,16 @@ "name": "windows_win32_release_asan_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "windows_win32_release_asan_no_ort" + "name": "windows_win32_release_asan_no_ort", + "type": "configure" }, { - "type": "build", - "name": "windows_win32_release_asan_no_ort" + "name": "windows_win32_release_asan_no_ort", + "type": "build" }, { - "type": "test", - 
"name": "windows_win32_release_asan_no_ort" + "name": "windows_win32_release_asan_no_ort", + "type": "test" } ] }, @@ -3458,16 +4251,16 @@ "name": "windows_win32_release_asan_workflow", "steps": [ { - "type": "configure", - "name": "windows_win32_release_asan" + "name": "windows_win32_release_asan", + "type": "configure" }, { - "type": "build", - "name": "windows_win32_release_asan" + "name": "windows_win32_release_asan", + "type": "build" }, { - "type": "test", - "name": "windows_win32_release_asan" + "name": "windows_win32_release_asan", + "type": "test" } ] }, @@ -3475,16 +4268,16 @@ "name": "windows_win32_release_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "windows_win32_release_no_ort" + "name": "windows_win32_release_no_ort", + "type": "configure" }, { - "type": "build", - "name": "windows_win32_release_no_ort" + "name": "windows_win32_release_no_ort", + "type": "build" }, { - "type": "test", - "name": "windows_win32_release_no_ort" + "name": "windows_win32_release_no_ort", + "type": "test" } ] }, @@ -3492,16 +4285,16 @@ "name": "windows_win32_release_workflow", "steps": [ { - "type": "configure", - "name": "windows_win32_release" + "name": "windows_win32_release", + "type": "configure" }, { - "type": "build", - "name": "windows_win32_release" + "name": "windows_win32_release", + "type": "build" }, { - "type": "test", - "name": "windows_win32_release" + "name": "windows_win32_release", + "type": "test" } ] }, @@ -3509,16 +4302,16 @@ "name": "windows_win32_relwithdebinfo_asan_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "windows_win32_relwithdebinfo_asan_no_ort" + "name": "windows_win32_relwithdebinfo_asan_no_ort", + "type": "configure" }, { - "type": "build", - "name": "windows_win32_relwithdebinfo_asan_no_ort" + "name": "windows_win32_relwithdebinfo_asan_no_ort", + "type": "build" }, { - "type": "test", - "name": "windows_win32_relwithdebinfo_asan_no_ort" + "name": "windows_win32_relwithdebinfo_asan_no_ort", + "type": 
"test" } ] }, @@ -3526,16 +4319,16 @@ "name": "windows_win32_relwithdebinfo_asan_workflow", "steps": [ { - "type": "configure", - "name": "windows_win32_relwithdebinfo_asan" + "name": "windows_win32_relwithdebinfo_asan", + "type": "configure" }, { - "type": "build", - "name": "windows_win32_relwithdebinfo_asan" + "name": "windows_win32_relwithdebinfo_asan", + "type": "build" }, { - "type": "test", - "name": "windows_win32_relwithdebinfo_asan" + "name": "windows_win32_relwithdebinfo_asan", + "type": "test" } ] }, @@ -3543,16 +4336,16 @@ "name": "windows_win32_relwithdebinfo_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "windows_win32_relwithdebinfo_no_ort" + "name": "windows_win32_relwithdebinfo_no_ort", + "type": "configure" }, { - "type": "build", - "name": "windows_win32_relwithdebinfo_no_ort" + "name": "windows_win32_relwithdebinfo_no_ort", + "type": "build" }, { - "type": "test", - "name": "windows_win32_relwithdebinfo_no_ort" + "name": "windows_win32_relwithdebinfo_no_ort", + "type": "test" } ] }, @@ -3560,16 +4353,16 @@ "name": "windows_win32_relwithdebinfo_workflow", "steps": [ { - "type": "configure", - "name": "windows_win32_relwithdebinfo" + "name": "windows_win32_relwithdebinfo", + "type": "configure" }, { - "type": "build", - "name": "windows_win32_relwithdebinfo" + "name": "windows_win32_relwithdebinfo", + "type": "build" }, { - "type": "test", - "name": "windows_win32_relwithdebinfo" + "name": "windows_win32_relwithdebinfo", + "type": "test" } ] }, @@ -3577,16 +4370,16 @@ "name": "windows_x64_debug_asan_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "windows_x64_debug_asan_no_ort" + "name": "windows_x64_debug_asan_no_ort", + "type": "configure" }, { - "type": "build", - "name": "windows_x64_debug_asan_no_ort" + "name": "windows_x64_debug_asan_no_ort", + "type": "build" }, { - "type": "test", - "name": "windows_x64_debug_asan_no_ort" + "name": "windows_x64_debug_asan_no_ort", + "type": "test" } ] }, @@ -3594,16 
+4387,16 @@ "name": "windows_x64_debug_asan_workflow", "steps": [ { - "type": "configure", - "name": "windows_x64_debug_asan" + "name": "windows_x64_debug_asan", + "type": "configure" }, { - "type": "build", - "name": "windows_x64_debug_asan" + "name": "windows_x64_debug_asan", + "type": "build" }, { - "type": "test", - "name": "windows_x64_debug_asan" + "name": "windows_x64_debug_asan", + "type": "test" } ] }, @@ -3611,16 +4404,16 @@ "name": "windows_x64_debug_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "windows_x64_debug_no_ort" + "name": "windows_x64_debug_no_ort", + "type": "configure" }, { - "type": "build", - "name": "windows_x64_debug_no_ort" + "name": "windows_x64_debug_no_ort", + "type": "build" }, { - "type": "test", - "name": "windows_x64_debug_no_ort" + "name": "windows_x64_debug_no_ort", + "type": "test" } ] }, @@ -3628,16 +4421,16 @@ "name": "windows_x64_debug_workflow", "steps": [ { - "type": "configure", - "name": "windows_x64_debug" + "name": "windows_x64_debug", + "type": "configure" }, { - "type": "build", - "name": "windows_x64_debug" + "name": "windows_x64_debug", + "type": "build" }, { - "type": "test", - "name": "windows_x64_debug" + "name": "windows_x64_debug", + "type": "test" } ] }, @@ -3645,16 +4438,16 @@ "name": "windows_x64_minsizerel_asan_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "windows_x64_minsizerel_asan_no_ort" + "name": "windows_x64_minsizerel_asan_no_ort", + "type": "configure" }, { - "type": "build", - "name": "windows_x64_minsizerel_asan_no_ort" + "name": "windows_x64_minsizerel_asan_no_ort", + "type": "build" }, { - "type": "test", - "name": "windows_x64_minsizerel_asan_no_ort" + "name": "windows_x64_minsizerel_asan_no_ort", + "type": "test" } ] }, @@ -3662,16 +4455,16 @@ "name": "windows_x64_minsizerel_asan_workflow", "steps": [ { - "type": "configure", - "name": "windows_x64_minsizerel_asan" + "name": "windows_x64_minsizerel_asan", + "type": "configure" }, { - "type": "build", - 
"name": "windows_x64_minsizerel_asan" + "name": "windows_x64_minsizerel_asan", + "type": "build" }, { - "type": "test", - "name": "windows_x64_minsizerel_asan" + "name": "windows_x64_minsizerel_asan", + "type": "test" } ] }, @@ -3679,16 +4472,16 @@ "name": "windows_x64_minsizerel_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "windows_x64_minsizerel_no_ort" + "name": "windows_x64_minsizerel_no_ort", + "type": "configure" }, { - "type": "build", - "name": "windows_x64_minsizerel_no_ort" + "name": "windows_x64_minsizerel_no_ort", + "type": "build" }, { - "type": "test", - "name": "windows_x64_minsizerel_no_ort" + "name": "windows_x64_minsizerel_no_ort", + "type": "test" } ] }, @@ -3696,16 +4489,16 @@ "name": "windows_x64_minsizerel_workflow", "steps": [ { - "type": "configure", - "name": "windows_x64_minsizerel" + "name": "windows_x64_minsizerel", + "type": "configure" }, { - "type": "build", - "name": "windows_x64_minsizerel" + "name": "windows_x64_minsizerel", + "type": "build" }, { - "type": "test", - "name": "windows_x64_minsizerel" + "name": "windows_x64_minsizerel", + "type": "test" } ] }, @@ -3713,16 +4506,16 @@ "name": "windows_x64_release_asan_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "windows_x64_release_asan_no_ort" + "name": "windows_x64_release_asan_no_ort", + "type": "configure" }, { - "type": "build", - "name": "windows_x64_release_asan_no_ort" + "name": "windows_x64_release_asan_no_ort", + "type": "build" }, { - "type": "test", - "name": "windows_x64_release_asan_no_ort" + "name": "windows_x64_release_asan_no_ort", + "type": "test" } ] }, @@ -3730,16 +4523,16 @@ "name": "windows_x64_release_asan_workflow", "steps": [ { - "type": "configure", - "name": "windows_x64_release_asan" + "name": "windows_x64_release_asan", + "type": "configure" }, { - "type": "build", - "name": "windows_x64_release_asan" + "name": "windows_x64_release_asan", + "type": "build" }, { - "type": "test", - "name": "windows_x64_release_asan" + 
"name": "windows_x64_release_asan", + "type": "test" } ] }, @@ -3747,16 +4540,16 @@ "name": "windows_x64_release_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "windows_x64_release_no_ort" + "name": "windows_x64_release_no_ort", + "type": "configure" }, { - "type": "build", - "name": "windows_x64_release_no_ort" + "name": "windows_x64_release_no_ort", + "type": "build" }, { - "type": "test", - "name": "windows_x64_release_no_ort" + "name": "windows_x64_release_no_ort", + "type": "test" } ] }, @@ -3764,16 +4557,16 @@ "name": "windows_x64_release_workflow", "steps": [ { - "type": "configure", - "name": "windows_x64_release" + "name": "windows_x64_release", + "type": "configure" }, { - "type": "build", - "name": "windows_x64_release" + "name": "windows_x64_release", + "type": "build" }, { - "type": "test", - "name": "windows_x64_release" + "name": "windows_x64_release", + "type": "test" } ] }, @@ -3781,16 +4574,16 @@ "name": "windows_x64_relwithdebinfo_asan_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "windows_x64_relwithdebinfo_asan_no_ort" + "name": "windows_x64_relwithdebinfo_asan_no_ort", + "type": "configure" }, { - "type": "build", - "name": "windows_x64_relwithdebinfo_asan_no_ort" + "name": "windows_x64_relwithdebinfo_asan_no_ort", + "type": "build" }, { - "type": "test", - "name": "windows_x64_relwithdebinfo_asan_no_ort" + "name": "windows_x64_relwithdebinfo_asan_no_ort", + "type": "test" } ] }, @@ -3798,16 +4591,16 @@ "name": "windows_x64_relwithdebinfo_asan_workflow", "steps": [ { - "type": "configure", - "name": "windows_x64_relwithdebinfo_asan" + "name": "windows_x64_relwithdebinfo_asan", + "type": "configure" }, { - "type": "build", - "name": "windows_x64_relwithdebinfo_asan" + "name": "windows_x64_relwithdebinfo_asan", + "type": "build" }, { - "type": "test", - "name": "windows_x64_relwithdebinfo_asan" + "name": "windows_x64_relwithdebinfo_asan", + "type": "test" } ] }, @@ -3815,16 +4608,16 @@ "name": 
"windows_x64_relwithdebinfo_no_ort_workflow", "steps": [ { - "type": "configure", - "name": "windows_x64_relwithdebinfo_no_ort" + "name": "windows_x64_relwithdebinfo_no_ort", + "type": "configure" }, { - "type": "build", - "name": "windows_x64_relwithdebinfo_no_ort" + "name": "windows_x64_relwithdebinfo_no_ort", + "type": "build" }, { - "type": "test", - "name": "windows_x64_relwithdebinfo_no_ort" + "name": "windows_x64_relwithdebinfo_no_ort", + "type": "test" } ] }, @@ -3832,16 +4625,16 @@ "name": "windows_x64_relwithdebinfo_workflow", "steps": [ { - "type": "configure", - "name": "windows_x64_relwithdebinfo" + "name": "windows_x64_relwithdebinfo", + "type": "configure" }, { - "type": "build", - "name": "windows_x64_relwithdebinfo" + "name": "windows_x64_relwithdebinfo", + "type": "build" }, { - "type": "test", - "name": "windows_x64_relwithdebinfo" + "name": "windows_x64_relwithdebinfo", + "type": "test" } ] } From 4bd25cca1a57b5ee70dcdc89dcc946cb8c40f294 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Sat, 31 May 2025 12:28:18 -0700 Subject: [PATCH 23/33] update --- src/core/platform/posix/env.cc | 104 +-------------------------------- 1 file changed, 1 insertion(+), 103 deletions(-) diff --git a/src/core/platform/posix/env.cc b/src/core/platform/posix/env.cc index 43c6c4d..9e33bcc 100644 --- a/src/core/platform/posix/env.cc +++ b/src/core/platform/posix/env.cc @@ -284,104 +284,12 @@ class PosixEnv : public Env { #endif } - void SleepForMicroseconds(int64_t micros) const override { - while (micros > 0) { - timespec sleep_time; - sleep_time.tv_sec = 0; - sleep_time.tv_nsec = 0; - - if (micros >= OneMillion) { - sleep_time.tv_sec = static_cast(std::min(micros / OneMillion, - std::numeric_limits::max())); - micros -= static_cast(sleep_time.tv_sec) * OneMillion; - } - if (micros < OneMillion) { - sleep_time.tv_nsec = static_cast(1000 * micros); - micros = 0; - } - while (nanosleep(&sleep_time, &sleep_time) != 0 && errno == EINTR) { - // Ignore signals and wait 
for the full interval to elapse. - } - } - } PIDType GetSelfPid() const override { return getpid(); } - Status GetFileLength(const PathChar* file_path, size_t& length) const override { - ScopedFileDescriptor file_descriptor{open(file_path, O_RDONLY)}; - return GetFileLength(file_descriptor.Get(), length); - } - - common::Status GetFileLength(int fd, /*out*/ size_t& file_size) const override { - using namespace common; - if (fd < 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid fd was supplied: ", fd); - } - - struct stat buf; - int rc = fstat(fd, &buf); - if (rc < 0) { - return ReportSystemError("fstat", ""); - } - - if (buf.st_size < 0) { - return ORT_MAKE_STATUS(SYSTEM, FAIL, "Received negative size from stat call"); - } - - if (static_cast(buf.st_size) > std::numeric_limits::max()) { - return ORT_MAKE_STATUS(SYSTEM, FAIL, "File is too large."); - } - - file_size = static_cast(buf.st_size); - return Status::OK(); - } - - Status ReadFileIntoBuffer(const ORTCHAR_T* file_path, FileOffsetType offset, size_t length, - gsl::span buffer) const override { - ORT_RETURN_IF_NOT(file_path, "file_path == nullptr"); - ORT_RETURN_IF_NOT(offset >= 0, "offset < 0"); - ORT_RETURN_IF_NOT(length <= buffer.size(), "length > buffer.size()"); - - ScopedFileDescriptor file_descriptor{open(file_path, O_RDONLY)}; - if (!file_descriptor.IsValid()) { - return ReportSystemError("open", file_path); - } - - if (length == 0) - return Status::OK(); - - if (offset > 0) { - const FileOffsetType seek_result = lseek(file_descriptor.Get(), offset, SEEK_SET); - if (seek_result == -1) { - return ReportSystemError("lseek", file_path); - } - } - - size_t total_bytes_read = 0; - while (total_bytes_read < length) { - constexpr size_t k_max_bytes_to_read = 1 << 30; // read at most 1GB each time - const size_t bytes_remaining = length - total_bytes_read; - const size_t bytes_to_read = std::min(bytes_remaining, k_max_bytes_to_read); - - const ssize_t bytes_read = - TempFailureRetry(read, 
file_descriptor.Get(), buffer.data() + total_bytes_read, bytes_to_read); - - if (bytes_read == -1) { - return ReportSystemError("read", file_path); - } - - if (bytes_read == 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ReadFileIntoBuffer - unexpected end of file. ", "File: ", file_path, - ", offset: ", offset, ", length: ", length); - } - - total_bytes_read += bytes_read; - } - - return Status::OK(); - } + static common::Status ReportSystemError(const char* operation_name, const std::string& path) { auto [err_no, err_msg] = GetErrnoInfo(); @@ -390,16 +298,6 @@ class PosixEnv : public Env { return common::Status(common::SYSTEM, err_no, oss.str()); } - common::Status GetCanonicalPath( - const PathString& path, - PathString& canonical_path) const override { - MallocdStringPtr canonical_path_cstr{realpath(path.c_str(), nullptr), Freer()}; - if (!canonical_path_cstr) { - return ReportSystemError("realpath", path); - } - canonical_path.assign(canonical_path_cstr.get()); - return Status::OK(); - } // \brief returns a value for the queried variable name (var_name) std::string GetEnvironmentVar(const std::string& var_name) const override { From 47a557a920d810cf73a4f5250493ccd277c9bed7 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Sat, 31 May 2025 12:40:03 -0700 Subject: [PATCH 24/33] update --- .github/workflows/reusable_windows_build.yml | 24 ++++++++--------- .github/workflows/win_ci.yml | 27 ++++++++++---------- 2 files changed, 25 insertions(+), 26 deletions(-) diff --git a/.github/workflows/reusable_windows_build.yml b/.github/workflows/reusable_windows_build.yml index b3275ae..224c5c6 100644 --- a/.github/workflows/reusable_windows_build.yml +++ b/.github/workflows/reusable_windows_build.yml @@ -6,6 +6,9 @@ on: job-name: # For display purposes in the reusable workflow logs required: true type: string + runner-os: # New input for the runner OS + required: true + type: string cmake-workflow-preset: required: true type: string @@ -25,16 +28,13 @@ on: 
permissions: actions: read contents: read - # security-events: write is needed only if CodeQL is enabled and uploads SARIF - # We'll set it at the job level within this reusable workflow for clarity. + security-events: write # Needed if CodeQL analysis runs & uploads jobs: build_and_optional_analyze: name: ${{ inputs.job-name }} - runs-on: windows-2022 - # Define permissions here based on whether CodeQL might run. - # If enable-codeql can be true, then security-events: write is needed. - permissions: + runs-on: ${{ inputs.runner-os }} # Use the input for runs-on + permissions: # Permissions moved here as they are job-specific actions: read contents: read security-events: write # Needed if CodeQL analysis runs & uploads @@ -53,15 +53,15 @@ jobs: - name: Run CMake Workflow run: | cmake --workflow --preset ${{ inputs.cmake-workflow-preset }} - shell: cmd # Ensuring shell is explicit for windows runners if needed + shell: cmd - name: Perform CodeQL Analysis (if enabled) if: ${{ inputs.enable-codeql }} uses: github/codeql-action/analyze@v3 with: - category: "/language:cpp" # Category for the analysis - output: ${{ inputs.codeql-sarif-output-dir }} # Directory for SARIF files - upload: failure-only # Upload SARIF results only on failure + category: "/language:cpp" + output: ${{ inputs.codeql-sarif-output-dir }} + upload: failure-only - name: Filter SARIF (if CodeQL enabled) if: ${{ inputs.enable-codeql }} @@ -73,11 +73,11 @@ jobs: -tests/**/*.* -build/**/*.* input: ${{ inputs.codeql-sarif-output-dir }}/cpp.sarif - output: ${{ inputs.codeql-sarif-output-dir }}/cpp.sarif.filtered # Output to a new file + output: ${{ inputs.codeql-sarif-output-dir }}/cpp.sarif.filtered - name: Upload filtered SARIF (if CodeQL enabled) if: ${{ inputs.enable-codeql }} uses: github/codeql-action/upload-sarif@v3 with: sarif_file: ${{ inputs.codeql-sarif-output-dir }}/cpp.sarif.filtered - category: cpp-${{ inputs.job-name }} # Make category unique if needed \ No newline at end of file + 
category: cpp-${{ inputs.job-name }} \ No newline at end of file diff --git a/.github/workflows/win_ci.yml b/.github/workflows/win_ci.yml index e048e2e..f1bcf6a 100644 --- a/.github/workflows/win_ci.yml +++ b/.github/workflows/win_ci.yml @@ -8,7 +8,7 @@ on: pull_request: concurrency: - group: ${{ github.workflow }}-${{ github.event_name == 'pull_request' && github.ref || github.sha }} + group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true jobs: @@ -17,14 +17,15 @@ jobs: uses: ./.github/workflows/reusable_windows_build.yml with: job-name: Win32_Debug_NoOrt_CodeQL + runner-os: windows-2022 cmake-workflow-preset: windows_win32_debug_no_ort_workflow enable-codeql: true - # codeql-config-file can be omitted if default is fine Win32_release_no_ort: uses: ./.github/workflows/reusable_windows_build.yml with: job-name: Win32_Release_NoOrt + runner-os: windows-2022 cmake-workflow-preset: windows_win32_release_no_ort_workflow enable-codeql: false @@ -33,6 +34,7 @@ jobs: uses: ./.github/workflows/reusable_windows_build.yml with: job-name: Winx64_Debug_NoOrt_CodeQL + runner-os: windows-2022 cmake-workflow-preset: windows_x64_debug_no_ort_workflow enable-codeql: true @@ -40,34 +42,31 @@ jobs: uses: ./.github/workflows/reusable_windows_build.yml with: job-name: Winx64_Release_NoOrt + runner-os: windows-2022 cmake-workflow-preset: windows_x64_release_no_ort_workflow enable-codeql: false - + WinX64_release: uses: ./.github/workflows/reusable_windows_build.yml with: job-name: Winx64_Release + runner-os: windows-2022 cmake-workflow-preset: windows_x64_release_workflow enable-codeql: false - # Windows ARM64 Jobs (New) + # Windows ARM64 Jobs WinARM64_debug_no_ort: uses: ./.github/workflows/reusable_windows_build.yml with: job-name: WinARM64_Debug_NoOrt_CodeQL - cmake-workflow-preset: windows_arm64_debug_no_ort_workflow # Ensure this preset exists + runner-os: windows-11-arm # Use ARM64 runner + cmake-workflow-preset: windows_arm64_debug_no_ort_workflow 
enable-codeql: true - WinARM64_release_no_ort: - uses: ./.github/workflows/reusable_windows_build.yml - with: - job-name: WinARM64_Release_NoOrt - cmake-workflow-preset: windows_arm64_release_no_ort_workflow # Ensure this preset exists - enable-codeql: false - WinARM64_release: uses: ./.github/workflows/reusable_windows_build.yml with: job-name: WinARM64_Release - cmake-workflow-preset: windows_arm64_release_workflow # Ensure this preset exists - enable-codeql: false + runner-os: windows-11-arm # Use ARM64 runner + cmake-workflow-preset: windows_arm64_release_workflow + enable-codeql: false \ No newline at end of file From 84ccb45904f2b8e0381821838329a79cec05552c Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Sat, 31 May 2025 14:45:40 -0700 Subject: [PATCH 25/33] update --- tests/unittest/matrix_buffer.h | 105 ++++++++++++++++++++++++++++ tests/unittest/test_fgemm.h | 2 +- tests/unittest/test_main.cpp | 3 + tests/unittest/test_scaleoutput.cpp | 10 ++- tests/unittest/test_util.h | 6 +- 5 files changed, 120 insertions(+), 6 deletions(-) create mode 100644 tests/unittest/matrix_buffer.h diff --git a/tests/unittest/matrix_buffer.h b/tests/unittest/matrix_buffer.h new file mode 100644 index 0000000..9336402 --- /dev/null +++ b/tests/unittest/matrix_buffer.h @@ -0,0 +1,105 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include // For malloc, free, abort +#include // For size_t +#include // For std::function +#include // For std::fill_n +#include // For std::bad_alloc (alternative to abort) + +// Include crtdbg.h for _malloc_dbg and _free_dbg on Windows debug builds +#if defined(_WIN32) && !defined(NDEBUG) && defined(_DEBUG) +#include +#endif + +template +class MatrixGuardBuffer { +public: + MatrixGuardBuffer() : + buffer_(nullptr), + elements_allocated_(0) { + } + + ~MatrixGuardBuffer() { + ReleaseBuffer(); + } + + // Disable copy and move semantics for simplicity + MatrixGuardBuffer(const MatrixGuardBuffer&) = delete; + MatrixGuardBuffer& operator=(const MatrixGuardBuffer&) = delete; + MatrixGuardBuffer(MatrixGuardBuffer&&) = delete; + MatrixGuardBuffer& operator=(MatrixGuardBuffer&&) = delete; + + T* GetFilledBuffer(size_t elements, const std::function& fill_func) { + if (elements == 0) { + ReleaseBuffer(); + return nullptr; + } + + if (elements > elements_allocated_) { + ReleaseBuffer(); + + size_t bytes_to_allocate = elements * sizeof(T); + if (elements != 0 && bytes_to_allocate / elements != sizeof(T)) { // Check for overflow before multiplication + // Handle overflow, e.g., by aborting or throwing + abort(); + } + + #if defined(_WIN32) && !defined(NDEBUG) && defined(_DEBUG) + buffer_ = static_cast(_malloc_dbg(bytes_to_allocate, _NORMAL_BLOCK, __FILE__, __LINE__)); + #else + buffer_ = static_cast(malloc(bytes_to_allocate)); + #endif + + if (buffer_ == nullptr) { + // Consider `throw std::bad_alloc();` for C++ style error handling. 
+ abort(); + } + elements_allocated_ = elements; + } + + if (fill_func && buffer_ != nullptr) { + fill_func(buffer_, elements); + } else if (buffer_ == nullptr && elements > 0) { + abort(); // Should not happen if allocation failure aborts + } + + return buffer_; + } + + T* GetBuffer(size_t elements, bool zero_fill = false) { + if (zero_fill) { + return GetFilledBuffer( + elements, + [](T* start, size_t count) { + if (start && count > 0) { + std::fill_n(start, count, T{}); // Value-initialize + } + }); + } + + return GetFilledBuffer( + elements, + [](T* start, size_t count) { + //do nothing, so that we can catch read uninitialized values errors + }); + } + + void ReleaseBuffer() { + if (buffer_ != nullptr) { + #if defined(_WIN32) && !defined(NDEBUG) && defined(_DEBUG) + _free_dbg(buffer_, _NORMAL_BLOCK); + #else + free(buffer_); + #endif + buffer_ = nullptr; + } + elements_allocated_ = 0; + } + +private: + T* buffer_; + size_t elements_allocated_; +}; \ No newline at end of file diff --git a/tests/unittest/test_fgemm.h b/tests/unittest/test_fgemm.h index 2bd0941..3187d18 100644 --- a/tests/unittest/test_fgemm.h +++ b/tests/unittest/test_fgemm.h @@ -202,7 +202,7 @@ class MlasFgemmTest : public MlasTestBase { for (size_t m = 0; m < M; m++) { for (size_t n = 0; n < N; n++, f++) { // Sensitive to comparing positive/negative zero. - ASSERT_EQ(C[f], CReference[f]) + ASSERT_NEAR(C[f], CReference[f],1e-5) << " Diff @[" << batch << ", " << m << ", " << n << "] f=" << f << ", " << (Packed ? "Packed" : "NoPack") << "." << (Threaded ? 
"SingleThread" : "Threaded") << "/" diff --git a/tests/unittest/test_main.cpp b/tests/unittest/test_main.cpp index 505c0c0..2c1a094 100644 --- a/tests/unittest/test_main.cpp +++ b/tests/unittest/test_main.cpp @@ -57,6 +57,9 @@ bool AddTestRegister(TestRegister test_register) { } int main(int argc, char** argv) { + unsigned int current_control; + _controlfp_s(¤t_control, 0, 0); // Get current control word + _controlfp_s(¤t_control, ~(_EM_INVALID | _EM_ZERODIVIDE | _EM_DENORMAL), _MCW_EM); // Unmask exceptions bool is_short_execute = (argc <= 1 || strcmp("--long", argv[1]) != 0); std::cout << "-------------------------------------------------------" << std::endl; if (is_short_execute) { diff --git a/tests/unittest/test_scaleoutput.cpp b/tests/unittest/test_scaleoutput.cpp index 34f1784..13de844 100644 --- a/tests/unittest/test_scaleoutput.cpp +++ b/tests/unittest/test_scaleoutput.cpp @@ -22,7 +22,7 @@ class MlasScaleOutputTest : public MlasTestBase { std::numeric_limits::max()); for (size_t s = 0; s < M * N; s++) { - Input[s] = int_distribution(generator); + Input[s] = int_distribution(generator); //It could be zero Output[s] = OutputRef[s] = real_distribution(generator); } @@ -52,10 +52,14 @@ class MlasScaleOutputTest : public MlasTestBase { constexpr float epsilon = 1e-6f; for (size_t n = 0; n < M * N; n++) { - float diff = std::fabs((Output[n] - OutputRef[n]) / OutputRef[n]); + float outvalue = OutputRef[n]; // When `AccumulateMode` is false, there is a high chance that this value could be zero + float diff = std::fabs(Output[n] - outvalue) ; + if (outvalue != 0) { + diff /= outvalue; + } ASSERT_LE(diff, epsilon) << " @[" << n / N << "," << n % N << "], total:[" << M << "," << N << "], got:" - << Output[n] << ", expecting:" << OutputRef[n]; + << Output[n] << ", expecting:" << outvalue; } } diff --git a/tests/unittest/test_util.h b/tests/unittest/test_util.h index 94c4143..95a4ebe 100644 --- a/tests/unittest/test_util.h +++ b/tests/unittest/test_util.h @@ -37,7 
+37,9 @@ #endif MLAS_THREADPOOL* GetMlasThreadPool(void); - +#ifdef BUILD_MLAS_NO_ONNXRUNTIME +#include "matrix_buffer.h" +#else template class MatrixGuardBuffer { public: @@ -163,7 +165,7 @@ class MatrixGuardBuffer { size_t _BaseBufferSize; T* _GuardAddress; }; - +#endif class MlasTestBase { public: virtual ~MlasTestBase(void) {} From 02ab9d532d0dae6cb4549f23fcc4f410164582c6 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Sat, 31 May 2025 20:00:14 -0700 Subject: [PATCH 26/33] udpate --- CMakeLists.txt | 3 +++ tests/unittest/matrix_buffer.h | 15 +++++++-------- tests/unittest/test_fgemm.h | 7 ++++++- tests/unittest/test_main.cpp | 4 +--- 4 files changed, 17 insertions(+), 12 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 0c00a53..15725a7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -89,6 +89,9 @@ function(onnxruntime_configure_target target_name) $<$,$>:/MP> $<$,$>:/MP> ) + if(WIN32) + target_compile_options(${target_name} PRIVATE "$<$:/Zi>") + endif() endfunction() function(onnxruntime_add_executable target_name) diff --git a/tests/unittest/matrix_buffer.h b/tests/unittest/matrix_buffer.h index 9336402..4b58362 100644 --- a/tests/unittest/matrix_buffer.h +++ b/tests/unittest/matrix_buffer.h @@ -46,12 +46,11 @@ class MatrixGuardBuffer { // Handle overflow, e.g., by aborting or throwing abort(); } - - #if defined(_WIN32) && !defined(NDEBUG) && defined(_DEBUG) - buffer_ = static_cast(_malloc_dbg(bytes_to_allocate, _NORMAL_BLOCK, __FILE__, __LINE__)); - #else - buffer_ = static_cast(malloc(bytes_to_allocate)); - #endif +#ifdef _WIN32 + buffer_ = static_cast(_aligned_malloc(bytes_to_allocate, 64)); +#else + buffer_ = static_cast(std::aligned_alloc(64, bytes_to_allocate)); +#endif if (buffer_ == nullptr) { // Consider `throw std::bad_alloc();` for C++ style error handling. 
@@ -89,8 +88,8 @@ class MatrixGuardBuffer { void ReleaseBuffer() { if (buffer_ != nullptr) { - #if defined(_WIN32) && !defined(NDEBUG) && defined(_DEBUG) - _free_dbg(buffer_, _NORMAL_BLOCK); + #if defined(_WIN32) + _aligned_free(buffer_); #else free(buffer_); #endif diff --git a/tests/unittest/test_fgemm.h b/tests/unittest/test_fgemm.h index 3187d18..347e9c6 100644 --- a/tests/unittest/test_fgemm.h +++ b/tests/unittest/test_fgemm.h @@ -195,14 +195,19 @@ class MlasFgemmTest : public MlasTestBase { std::fill_n(C, M * N * BatchSize, -0.5f); std::fill_n(CReference, M * N * BatchSize, -0.5f); + static constexpr float rtol = 1e-5f; + static constexpr float atol = 1e-8f; + PackedContext.TestGemm(TransA, TransB, M, N, K, BatchSize, alpha, A, lda, B, ldb, beta, C, ldc, threadpool_); ReferenceGemm(TransA, TransB, M, N, K, BatchSize, alpha, A, lda, B, ldb, beta, CReference, ldc); for (size_t batch = 0, f = 0; batch < BatchSize; batch++) { for (size_t m = 0; m < M; m++) { for (size_t n = 0; n < N; n++, f++) { + T tolerance = atol + rtol * std::abs(CReference[f]); + // Sensitive to comparing positive/negative zero. - ASSERT_NEAR(C[f], CReference[f],1e-5) + ASSERT_NEAR(C[f], CReference[f], tolerance) << " Diff @[" << batch << ", " << m << ", " << n << "] f=" << f << ", " << (Packed ? "Packed" : "NoPack") << "." << (Threaded ? 
"SingleThread" : "Threaded") << "/" diff --git a/tests/unittest/test_main.cpp b/tests/unittest/test_main.cpp index 2c1a094..3a11b91 100644 --- a/tests/unittest/test_main.cpp +++ b/tests/unittest/test_main.cpp @@ -57,9 +57,7 @@ bool AddTestRegister(TestRegister test_register) { } int main(int argc, char** argv) { - unsigned int current_control; - _controlfp_s(¤t_control, 0, 0); // Get current control word - _controlfp_s(¤t_control, ~(_EM_INVALID | _EM_ZERODIVIDE | _EM_DENORMAL), _MCW_EM); // Unmask exceptions + bool is_short_execute = (argc <= 1 || strcmp("--long", argv[1]) != 0); std::cout << "-------------------------------------------------------" << std::endl; if (is_short_execute) { From bd7fd7a9f1558505d2ea4b50a515d78e1e45cd1b Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Sat, 31 May 2025 20:15:34 -0700 Subject: [PATCH 27/33] update --- tests/unittest/matrix_buffer.h | 16 +++++++++++++++- tests/unittest/test_conv2d.h | 29 ++++++++++++++++------------- 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/tests/unittest/matrix_buffer.h b/tests/unittest/matrix_buffer.h index 4b58362..3fb171b 100644 --- a/tests/unittest/matrix_buffer.h +++ b/tests/unittest/matrix_buffer.h @@ -82,7 +82,21 @@ class MatrixGuardBuffer { return GetFilledBuffer( elements, [](T* start, size_t count) { - //do nothing, so that we can catch read uninitialized values errors + constexpr float offset = -21.f; + constexpr float range = 43.f; + + // The following value will be used in most GEMM/CONV tests. Because this value is an integer that is + // small enough, all the floating point operations will generate exact values instead of approximate + // values. + float FillValue = 11.f; + T* FillAddress = start; + for (size_t i = 0; i < count; i++) { + auto itemv = FillValue - offset; + *FillAddress++ = (T)(itemv); + + FillValue += 7.f; + FillValue = FillValue >= range ? 
FillValue - range : FillValue; + } }); } diff --git a/tests/unittest/test_conv2d.h b/tests/unittest/test_conv2d.h index 20bf0ec..4382ae4 100644 --- a/tests/unittest/test_conv2d.h +++ b/tests/unittest/test_conv2d.h @@ -245,19 +245,22 @@ class MlasConv2DTest : public MlasTestBase { Filter, Bias, OutputReference); - - ASSERT_EQ(memcmp(Output, OutputReference, OutputElements * sizeof(float)), 0) - << "B" << BatchCount << "/" - << "G" << GroupCount << "/" - << "Cpg" << InputChannels << "/" - << "Fpg" << FilterCount << "/" - << "H" << InputHeight << "/" - << "W" << InputWidth << "/" - << "KH" << KernelHeight << "/" - << "KW" << KernelWidth << "/" - << "Pad" << PaddingLeftHeight << "," << PaddingLeftWidth << "," << PaddingRightHeight << "," << PaddingRightWidth << "/" - << "Dilation" << DilationHeight << "," << DilationWidth << "/" - << "Stride" << StrideHeight << "," << StrideWidth; + static constexpr float rtol = 1e-4f; + static constexpr float atol = 1e-6f; + for (size_t i = 0; i != OutputElements; ++i) { + float tolerance = atol + rtol * std::abs(OutputReference[i]); + ASSERT_NEAR(Output[i], OutputReference[i], tolerance) << "B" << BatchCount << "/" + << "G" << GroupCount << "/" + << "Cpg" << InputChannels << "/" + << "Fpg" << FilterCount << "/" + << "H" << InputHeight << "/" + << "W" << InputWidth << "/" + << "KH" << KernelHeight << "/" + << "KW" << KernelWidth << "/" + << "Pad" << PaddingLeftHeight << "," << PaddingLeftWidth << "," << PaddingRightHeight << "," << PaddingRightWidth << "/" + << "Dilation" << DilationHeight << "," << DilationWidth << "/" + << "Stride" << StrideHeight << "," << StrideWidth; + } } void ExecuteLong(void) override { From 1246e51fa26d2579a7196bae44932a75d9924545 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Sat, 31 May 2025 20:16:02 -0700 Subject: [PATCH 28/33] update --- tests/unittest/matrix_buffer.h | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/unittest/matrix_buffer.h b/tests/unittest/matrix_buffer.h index 
3fb171b..0af513e 100644 --- a/tests/unittest/matrix_buffer.h +++ b/tests/unittest/matrix_buffer.h @@ -9,11 +9,6 @@ #include // For std::fill_n #include // For std::bad_alloc (alternative to abort) -// Include crtdbg.h for _malloc_dbg and _free_dbg on Windows debug builds -#if defined(_WIN32) && !defined(NDEBUG) && defined(_DEBUG) -#include -#endif - template class MatrixGuardBuffer { public: From 163eaf0aabf237bb2f71cdc4447fac9e5a7ca115 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 10 Oct 2025 00:22:45 -0700 Subject: [PATCH 29/33] update --- CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 15725a7..9bb06b5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -15,7 +15,6 @@ cmake_policy(SET CMP0117 NEW) # Project project(MLAS C CXX) - include(CheckCXXCompilerFlag) include(CheckLanguage) include(CMakeDependentOption) From 63237fcb705bb8824b8dcf4682da74d0fea79e89 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 10 Oct 2025 00:30:06 -0700 Subject: [PATCH 30/33] Fix: Disable CodeQL for Windows ARM64 jobs --- .github/workflows/win_ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/win_ci.yml b/.github/workflows/win_ci.yml index f1bcf6a..23639c0 100644 --- a/.github/workflows/win_ci.yml +++ b/.github/workflows/win_ci.yml @@ -58,10 +58,10 @@ jobs: WinARM64_debug_no_ort: uses: ./.github/workflows/reusable_windows_build.yml with: - job-name: WinARM64_Debug_NoOrt_CodeQL + job-name: WinARM64_Debug_NoOrt runner-os: windows-11-arm # Use ARM64 runner cmake-workflow-preset: windows_arm64_debug_no_ort_workflow - enable-codeql: true + enable-codeql: false WinARM64_release: uses: ./.github/workflows/reusable_windows_build.yml From 3876c87229967a18b86661ae6b5abc2a6e8b28f3 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 17 Oct 2025 20:37:10 -0700 Subject: [PATCH 31/33] Sync with ONNX Runtime commit id 7cc28b0be3486a1fc22e48ddb76f754bec469a3f --- cmake/deps.txt | 4 +- 
cmake/external_deps.cmake | 17 +- include/mlas.h | 148 +- include/mlas_gemm_postprocessor.h | 2 +- include/mlas_q4.h | 4 +- include/mlas_qnbit.h | 11 + src/common/cpuid_info.cc | 84 +- src/common/cpuid_info_vendor.cc | 244 +++ src/common/helper.cc | 4 +- src/common/safeint.h | 4 +- src/common/semver.h | 33 + src/common/spin_pause.cc | 36 + src/common/string_utils.h | 9 +- src/lib/CMakeLists.txt | 118 +- src/lib/activate.cpp | 2 +- src/lib/cast.cpp | 54 + src/lib/compute.cpp | 46 +- src/lib/convolve.cpp | 131 +- src/lib/convsym.cpp | 3 +- src/lib/dequantize.cpp | 395 +++++ src/lib/dgemm.cpp | 4 +- src/lib/erf.cpp | 3 +- src/lib/kleidiai/convolve_kleidiai.cpp | 720 +++++++++ src/lib/kleidiai/mlasi_kleidiai.h | 151 ++ src/lib/kleidiai/qgemm_kleidiai.cpp | 116 ++ src/lib/kleidiai/sgemm_kleidiai.cpp | 432 +++++ src/lib/logistic.cpp | 9 +- src/lib/mlasi.h | 354 ++++- src/lib/platform.cpp | 96 +- src/lib/pooling.cpp | 2 +- src/lib/power/qgemm_kernel_power10.cpp | 268 +++- src/lib/q4gemm.h | 2 +- src/lib/q4gemm_avx512.cpp | 1 + src/lib/qgemm.cpp | 85 +- src/lib/qgemm.h | 4 + src/lib/qgemm_kernel_amx.cpp | 1 + src/lib/qgemm_kernel_wasmrelaxedsimd.cpp | 325 ++-- src/lib/qgemm_kernel_wasmsimd.cpp | 288 ++-- src/lib/qladd.cpp | 107 ++ src/lib/qladd.h | 95 ++ src/lib/qlgavgpool.cpp | 2 +- src/lib/qlmul.cpp | 92 ++ src/lib/qnbitgemm.cpp | 129 +- src/lib/qnbitgemm.h | 566 +++++++ src/lib/qnbitgemm_kernel_neon.cpp | 313 +++- src/lib/qnbitgemm_kernel_neon.h | 30 + src/lib/quantize.cpp | 202 ++- src/lib/rotary_embedding_kernel_avx2.cpp | 7 +- src/lib/s390x/DgemmKernel.cpp | 87 + src/lib/s390x/DgemmKernelZVECTOR.h | 122 ++ src/lib/s390x/FgemmKernelZVECTOR.h | 333 ++++ src/lib/s390x/Quantize.cpp | 300 ++++ src/lib/s390x/QuantizeZVECTOR.cpp | 149 ++ src/lib/s390x/SgemmKernel.cpp | 87 + src/lib/s390x/SgemmKernelZVECTOR.cpp | 451 ++++++ src/lib/s390x/SgemmKernelZVECTOR.h | 139 ++ src/lib/s390x/qgemm_kernel_zvector.cpp | 1409 +++++++++++++++++ src/lib/sconv.h | 29 + 
src/lib/sconv_kernel_neon.cpp | 524 ++++++ src/lib/sgemm.cpp | 40 +- src/lib/snchwc.cpp | 18 +- src/lib/spool_kernel_neon.cpp | 293 ++++ src/lib/sqnbitgemm_kernel_avx2.cpp | 5 +- .../sqnbitgemm_kernel_avx2_int8_blklen32.h | 17 + src/lib/sqnbitgemm_kernel_avx512.cpp | 4 +- src/lib/sqnbitgemm_kernel_avx512vnni.cpp | 3 +- src/lib/sqnbitgemm_kernel_avx_common.h | 3 +- src/lib/sqnbitgemm_kernel_lasx.cpp | 1089 +++++++++++++ src/lib/sqnbitgemm_kernel_lasx_common.h | 514 ++++++ src/lib/sqnbitgemm_kernel_neon_int8.cpp | 941 +++++++++++ src/lib/sqnbitgemm_kernel_neon_int8_i8mm.cpp | 743 +++++++++ src/lib/sve/elementwise_sve.cpp | 683 ++++++++ src/lib/sve/mlasi_sve.h | 653 ++++++++ src/lib/transpose.cpp | 180 ++- src/ort_include/core/common/common.h | 20 +- .../core/common/const_pointer_container.h | 4 + .../core/common/cpuid_arch_definition.h | 2 +- src/ort_include/core/common/cpuid_info.h | 72 +- .../core/{framework => common}/endian.h | 0 .../core/{framework => common}/float16.h | 4 +- src/ort_include/core/common/float8.h | 937 +++++++++++ src/ort_include/core/common/parse_string.h | 23 +- src/ort_include/core/common/path_string.h | 14 + src/ort_include/core/common/status.h | 5 + src/ort_include/core/common/string_helper.h | 6 +- src/ort_include/core/framework/callback.h | 74 - .../platform/EigenNonBlockingThreadPool.h | 119 +- src/ort_include/core/platform/env.h | 1 - src/ort_include/core/platform/env_var_utils.h | 4 +- src/ort_include/core/platform/threadpool.h | 10 +- .../core/session/onnxruntime_c_api.h | 140 +- src/ort_include/core/util/thread_utils.h | 39 +- tests/bench/bench_cast.cpp | 2 +- tests/bench/bench_computesoftmax.cpp | 4 +- tests/bench/bench_fp16_neon_common.cpp | 2 +- tests/bench/bench_rope.cpp | 2 +- tests/bench/bench_sconv.cpp | 109 ++ tests/bench/bench_sgemm.cpp | 7 +- tests/bench/bench_util.h | 2 +- tests/unittest/test_conv2d.h | 29 +- tests/unittest/test_dequantizelinear.cpp | 75 + tests/unittest/test_dynamic_qgemm.cpp | 172 ++ 
tests/unittest/test_fgemm.h | 13 +- tests/unittest/test_fgemm_fixture.h | 1 + tests/unittest/test_halfgemm.h | 3 +- tests/unittest/test_main.cpp | 1 - tests/unittest/test_rope.cpp | 3 +- tests/unittest/test_scaleoutput.cpp | 6 +- tests/unittest/test_softmax.cpp | 4 +- tests/unittest/test_sq8bitgemm.cpp | 416 ++++- 110 files changed, 15925 insertions(+), 969 deletions(-) create mode 100644 src/common/cpuid_info_vendor.cc create mode 100644 src/common/semver.h create mode 100644 src/common/spin_pause.cc create mode 100644 src/lib/dequantize.cpp create mode 100644 src/lib/kleidiai/convolve_kleidiai.cpp create mode 100644 src/lib/kleidiai/mlasi_kleidiai.h create mode 100644 src/lib/kleidiai/qgemm_kleidiai.cpp create mode 100644 src/lib/kleidiai/sgemm_kleidiai.cpp create mode 100644 src/lib/qnbitgemm.h create mode 100644 src/lib/s390x/DgemmKernel.cpp create mode 100644 src/lib/s390x/DgemmKernelZVECTOR.h create mode 100644 src/lib/s390x/FgemmKernelZVECTOR.h create mode 100644 src/lib/s390x/Quantize.cpp create mode 100644 src/lib/s390x/QuantizeZVECTOR.cpp create mode 100644 src/lib/s390x/SgemmKernel.cpp create mode 100644 src/lib/s390x/SgemmKernelZVECTOR.cpp create mode 100644 src/lib/s390x/SgemmKernelZVECTOR.h create mode 100644 src/lib/s390x/qgemm_kernel_zvector.cpp create mode 100644 src/lib/sconv.h create mode 100644 src/lib/sconv_kernel_neon.cpp create mode 100644 src/lib/spool_kernel_neon.cpp create mode 100644 src/lib/sqnbitgemm_kernel_lasx.cpp create mode 100644 src/lib/sqnbitgemm_kernel_lasx_common.h create mode 100644 src/lib/sqnbitgemm_kernel_neon_int8_i8mm.cpp create mode 100644 src/lib/sve/elementwise_sve.cpp create mode 100644 src/lib/sve/mlasi_sve.h rename src/ort_include/core/{framework => common}/endian.h (100%) rename src/ort_include/core/{framework => common}/float16.h (99%) create mode 100644 src/ort_include/core/common/float8.h delete mode 100644 src/ort_include/core/framework/callback.h create mode 100644 tests/unittest/test_dequantizelinear.cpp 
create mode 100644 tests/unittest/test_dynamic_qgemm.cpp diff --git a/cmake/deps.txt b/cmake/deps.txt index 7a9a0bf..eba65ee 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -1,5 +1,5 @@ -eigen;https://gitlab.com/libeigen/eigen/-/archive/ff174f79264d3f8dc0115dea7a288f98208b694f/eigen-ff174f79264d3f8dc0115dea7a288f98208b694f.zip;666e2f940faeef0196e72617a5d01241a22b67f3 +eigen;https://github.com/eigen-mirror/eigen/archive/1d8b82b0740839c0de7f1242a3585e3390ff5f33/eigen-1d8b82b0740839c0de7f1242a3585e3390ff5f33.zip;05b19b49e6fbb91246be711d801160528c135e34 +kleidiai;https://github.com/ARM-software/kleidiai/archive/refs/tags/v1.10.0.tar.gz;11b62149cb2514b3b9069cc435c3aa7a4e82b97a microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14 googletest;https://github.com/google/googletest/archive/refs/tags/v1.17.0.zip;f638fa0e724760e2ba07ff8cfba32cd644e1ce28 google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.8.5.zip;cd47d3d272faf353600c8cc2fdec2b52d6f69177 -microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 diff --git a/cmake/external_deps.cmake b/cmake/external_deps.cmake index 3e6495f..ecee953 100644 --- a/cmake/external_deps.cmake +++ b/cmake/external_deps.cmake @@ -74,14 +74,22 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME AND BUILD_TESTING) set(GTEST_HAS_ABSL OFF CACHE BOOL "" FORCE) + # Disable the KleidiAI tests + set(KLEIDIAI_BUILD_TESTS OFF) + + onnxruntime_fetchcontent_declare(kleidiai URL ${DEP_URL_kleidiai} URL_HASH SHA1=${DEP_SHA1_kleidiai} EXCLUDE_FROM_ALL) + onnxruntime_fetchcontent_makeavailable(kleidiai) + + # gtest and gmock - FetchContent_Declare( + onnxruntime_fetchcontent_declare( googletest URL ${DEP_URL_googletest} URL_HASH SHA1=${DEP_SHA1_googletest} + EXCLUDE_FROM_ALL FIND_PACKAGE_ARGS 1.14.0...<2.0.0 NAMES GTest ) - FetchContent_MakeAvailable(googletest) + 
onnxruntime_fetchcontent_makeavailable(googletest) #google benchmark doesn't work for Emscripten if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") message("CMAKE_SYSTEM_NAME: ${CMAKE_SYSTEM_NAME}") @@ -90,12 +98,13 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME AND BUILD_TESTING) # We will not need to install benchmark since we link it statically. set(BENCHMARK_ENABLE_INSTALL OFF CACHE BOOL "Disable benchmark install to avoid overwriting vendor install.") - FetchContent_Declare( + onnxruntime_fetchcontent_declare( google_benchmark URL ${DEP_URL_google_benchmark} URL_HASH SHA1=${DEP_SHA1_google_benchmark} + EXCLUDE_FROM_ALL FIND_PACKAGE_ARGS NAMES benchmark ) onnxruntime_fetchcontent_makeavailable(google_benchmark) endif() -endif() \ No newline at end of file +endif() diff --git a/include/mlas.h b/include/mlas.h index 2663709..40361be 100644 --- a/include/mlas.h +++ b/include/mlas.h @@ -57,6 +57,9 @@ Module Name: #if defined(MLAS_TARGET_ARM64) || defined(MLAS_TARGET_ARM64EC) || defined(MLAS_TARGET_ARM) #define MLAS_TARGET_ARM_ANY #endif +#if defined(__s390x__) +#define MLAS_TARGET_S390X +#endif #if defined(__VSX__) #define MLAS_TARGET_POWER @@ -80,7 +83,7 @@ Module Name: // Define the support levels for the target architecture. // -#if defined(MLAS_TARGET_AMD64) || defined (MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_ZVECTOR) #define MLAS_SUPPORTS_GEMM_DOUBLE #endif @@ -631,6 +634,49 @@ MlasGemm( { MlasGemmBatch(Shape, &DataParams, 1, ThreadPool); } +/** + * @brief Parameters that define the shape of a dynamically quantized GEMM operation. 
+ * + * The structure holds the dimensions of the matrices involved in the GEMM + * computation: + * C = A * B + */ +struct MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS { + size_t M = 0; /**< Row size of matrix A */ + size_t N = 0; /**< Column size of matrix B */ + size_t K = 0; /**< Column size of matrix A and Row size of matrix B */ +}; +/** + * @brief Parameters that define the data buffers and layout for a dynamic quant GEMM. + * + * This structure provides the memory pointers and strides for matrices + * involved in a dynamically quantized GEMM operation, along with the packed B format. + */ +struct MLAS_GEMM_DYN_QUANT_DATA_PARAMS { + const float* A = nullptr; /**< Pointer to input matrix A in FP32 format**/ + size_t lda = 0; /**< Number of elements between adjecent rows in A*/ + const void* PackedB = 0; /**< Points to packed weight matrix B */ + float* C = nullptr; /**< Points to output Matric C */ + size_t ldc = 0; /**< Number of elements between adjecent rows in Matrix C*/ + void* Workspace = nullptr; /**< Workspace buffer for LHS Packing Allocation */ + size_t WorkspaceSize = 0; /**< Workspace buffer size */ +}; + +void + MLASCALL + MlasDynamicQGemmBatch( + const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, + const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, + const size_t BatchN, + MLAS_THREADPOOL* ThreadPool); + +inline void +MlasDynamicQGemm( + const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, + const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, + MLAS_THREADPOOL* ThreadPool) { + MlasDynamicQGemmBatch(Shape, DataParams, 1, ThreadPool); +} // // Symmetric QGEMM has limited buffer overrun. 
@@ -683,22 +729,23 @@ MlasSymmQgemmBatch( // size_t -MLASCALL -MlasGemmPackBSize( - size_t N, - size_t K - ); + MLASCALL + MlasGemmPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K); void -MLASCALL -MlasGemmPackB( - CBLAS_TRANSPOSE TransB, - size_t N, - size_t K, - const float* B, - size_t ldb, - void* PackedB - ); + MLASCALL + MlasGemmPackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB); size_t MLASCALL @@ -750,6 +797,22 @@ MlasSymmQgemmPackB( void* PackedB ); +size_t + MLASCALL + MlasDynamicQgemmPackBSize( + size_t N, + size_t K); + +void + MLASCALL + MlasDynamicQgemmPackB( + size_t N, + size_t K, + const int8_t* B, + const float* Scales, + const float* Bias, + void* PackedB); + // // Convolution routines. // @@ -1012,16 +1075,16 @@ MlasComputeLogistic( template void -MLASCALL -MlasComputeSoftmax( - const T* Input, - T* Output, - size_t N, - size_t D, - bool LogSoftmax, - bool SmoothSoftmax, - MLAS_THREADPOOL* ThreadPool - ); + MLASCALL + MlasComputeSoftmax( + const T* Input, + T* Output, + size_t N, + size_t D, + bool LogSoftmax, + bool SmoothSoftmax, + float Sink, + MLAS_THREADPOOL* ThreadPool); template void @@ -1223,6 +1286,20 @@ MlasQuantizeLinearS4( int8_t ZeroPoint ); +// +// Linear dequantization routines. 
+// + +template +void + MLASCALL + MlasDequantizeLinear( + const InputType* Input, + float* Output, + size_t N, + float Scale, + InputType ZeroPoint); + /** * @brief Requantize a block of the intermediate buffer to the output buffer, * optionally adding the supplied bias @@ -1419,6 +1496,16 @@ MlasConvertHalfToFloatBuffer( size_t Count ); +#define MLAS_MIN_TENSOR_SIZE_FOR_HALF_TO_FLOAT_CONVERSION_IN_PARALLEL 128000 + +void + MLASCALL + MlasConvertHalfToFloatBufferInParallel( + const MLAS_FP16* Source, + float* Destination, + size_t Count, + MLAS_THREADPOOL* ThreadPool); + void MLASCALL MlasConvertFloatToHalfBuffer( @@ -1997,3 +2084,14 @@ MlasFlashAttention( MlasFlashAttentionThreadedArgs* args, MLAS_THREADPOOL* ThreadPool ); + +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) +/** + * @brief Function to override the packing mechanism decision if kleidi ai is included + * @param enable enable kleidiai packing (allow or disallow depending on true/false) + * @return + */ +void + MLASCALL + MlasGemmBatchPackUseKleidi(bool enable); +#endif diff --git a/include/mlas_gemm_postprocessor.h b/include/mlas_gemm_postprocessor.h index 7f5ec05..7ea29eb 100644 --- a/include/mlas_gemm_postprocessor.h +++ b/include/mlas_gemm_postprocessor.h @@ -16,7 +16,7 @@ Module Name: --*/ #pragma once -#include + template class MLAS_GEMM_POSTPROCESSOR { diff --git a/include/mlas_q4.h b/include/mlas_q4.h index c5f846f..b43f089 100644 --- a/include/mlas_q4.h +++ b/include/mlas_q4.h @@ -286,9 +286,7 @@ MlasBlockwiseQuantizedBufferSizes( int columns, size_t& q_data_size_in_bytes, size_t& q_scale_num_elements, - size_t* q_zero_point_size_in_bytes -); - + size_t* q_zero_point_size_in_bytes); /** * @brief Blockwise 4 bits quantization, resulting elements and quantization diff --git a/include/mlas_qnbit.h b/include/mlas_qnbit.h index 3627989..2a1e1fc 100644 --- a/include/mlas_qnbit.h +++ b/include/mlas_qnbit.h @@ -48,6 +48,17 @@ struct MLAS_QNBIT_GEMM_DATA_PARAMS { const T* QuantBScale = nullptr; ///< 
address of scale values of quantized B, one per block const void* QuantBZeroPoint = nullptr; ///< optional address of zero point values of quantized B, one per block const T* QuantBBlkSum = nullptr; ///< optional address of scale * zp, one per block + + /// + /// Address of scale * accumulate(quant - zp), one per block, where `scale`, `quant`, `zp` are respectively + /// an individual block's scale, quantized values, and zero point for the input `B`. + /// When converting the activation input (A) to uint8, we first convert the values to int8 and then + /// add a "bias" of +128 to convert the range of values from [-128, +127] to [0, +255]. + /// This input helps to "de-bias" the output of the +128 bias added to the activation input. + /// This input is to be used only when A is quantized to uint8. + /// + const T* BlkUnsignedQuantAZeroPointCorrection = nullptr; + const T* Bias = nullptr; ///< optional address of Bias, vector size N T* C = nullptr; ///< address of result matrix size_t ldc = 0; ///< leading dimension of C diff --git a/src/common/cpuid_info.cc b/src/common/cpuid_info.cc index 0b675fc..0d996a0 100644 --- a/src/common/cpuid_info.cc +++ b/src/common/cpuid_info.cc @@ -1,8 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
#include "core/common/cpuid_info.h" + +#include +#include + #include "core/common/logging/logging.h" #include "core/common/logging/severity.h" +#include "core/platform/check_intel.h" #ifdef __linux__ #if (defined(_M_AMD64) || defined(__x86_64__)) && !defined(__ANDROID__) @@ -24,6 +29,10 @@ #define HWCAP_ASIMDDP (1 << 20) #endif +#ifndef HWCAP_SVE +#define HWCAP_SVE (1 << 22) +#endif + #ifndef HWCAP2_I8MM #define HWCAP2_I8MM (1 << 13) #endif @@ -50,6 +59,14 @@ #endif // _WIN32 +#if defined(__APPLE__) +#if defined(CPUIDINFO_ARCH_ARM) + +#include + +#endif // defined(CPUIDINFO_ARCH_ARM) +#endif // defined(__APPLE__) + #if defined(CPUINFO_SUPPORTED) #include #if defined(CPUIDINFO_ARCH_ARM) @@ -73,6 +90,14 @@ void decodeMIDR(uint32_t midr, uint32_t uarch[1]); namespace onnxruntime { +void CPUIDInfo::LogEarlyWarning(std::string_view message) { + if (logging::LoggingManager::HasDefaultLogger()) { + LOGS_DEFAULT(WARNING) << message; + } else { + std::cerr << "onnxruntime cpuid_info warning: " << message << "\n"; + } +} + #if defined(CPUIDINFO_ARCH_X86) static inline void GetCPUID(int function_id, int data[4]) { // NOLINT @@ -107,9 +132,6 @@ void CPUIDInfo::X86Init() { int data[4] = {-1}; GetCPUID(0, data); - vendor_ = GetX86Vendor(data); - vendor_id_ = GetVendorId(vendor_); - int num_IDs = data[0]; if (num_IDs >= 1) { GetCPUID(1, data); @@ -126,7 +148,9 @@ void CPUIDInfo::X86Init() { has_f16c_ = has_avx_ && (data[2] & (1 << 29)) && (data[3] & (1 << 26)); if (num_IDs >= 7) { - GetCPUID(7, data); + // This change is made to overcome the issue of __get_cpuid returning all zeros, instead use __get_cpuid_count. 
+ // Reference: https://stackoverflow.com/questions/46272579/why-does-get-cpuid-return-all-zeros-for-leaf-4 + GetCPUID(7, 0, data); const uint32_t max_SubLeaves = data[0]; has_amx_bf16_ = (data[3] & (1 << 22)); has_avx2_ = has_avx_ && (data[1] & (1 << 5)); @@ -135,6 +159,7 @@ void CPUIDInfo::X86Init() { // avx512_skylake = avx512f | avx512vl | avx512cd | avx512bw | avx512dq has_avx512_skylake_ = has_avx512 && (data[1] & ((1 << 16) | (1 << 17) | (1 << 28) | (1 << 30) | (1 << 31))); is_hybrid_ = (data[3] & (1 << 15)); + if (max_SubLeaves >= 1) { GetCPUID(7, 1, data); has_avx512_bf16_ = has_avx512 && (data[0] & (1 << 5)); @@ -144,24 +169,8 @@ void CPUIDInfo::X86Init() { } } -std::string CPUIDInfo::GetX86Vendor(int32_t* data) { - char vendor[sizeof(int32_t) * 3 + 1]{}; - *reinterpret_cast(vendor + 0) = data[1]; - *reinterpret_cast(vendor + 4) = data[3]; - *reinterpret_cast(vendor + 8) = data[2]; - return vendor; -} - #endif // defined(CPUIDINFO_ARCH_X86) -uint32_t CPUIDInfo::GetVendorId(const std::string& vendor) { - if (vendor == "GenuineIntel") return 0x8086; - if (vendor == "GenuineAMD") return 0x1022; - if (vendor.find("Qualcomm") == 0) return 'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24); - if (vendor.find("NV") == 0) return 0x10DE; - return 0; -} - #if defined(CPUIDINFO_ARCH_ARM) #if defined(__linux__) @@ -174,8 +183,12 @@ void CPUIDInfo::ArmLinuxInit() { has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot(); has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); + // SVE is enabled only on Linux-based ARM CPUs for now, where it has been tested. 
+ has_arm_sve_ = cpuinfo_has_arm_sve(); has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16(); + has_arm_sme_ = cpuinfo_has_arm_sme(); + has_arm_sme2_ = cpuinfo_has_arm_sme2(); const uint32_t core_cnt = cpuinfo_get_cores_count(); core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); @@ -204,6 +217,7 @@ void CPUIDInfo::ArmLinuxInit() { has_fp16_ |= has_arm_neon_dot_; has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); + has_arm_sve_ = ((getauxval(AT_HWCAP) & HWCAP_SVE) != 0); has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); has_arm_neon_bf16_ = ((getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0); @@ -213,10 +227,6 @@ void CPUIDInfo::ArmLinuxInit() { #elif defined(_WIN32) // ^ defined(__linux__) void CPUIDInfo::ArmWindowsInit() { - // Get the ARM vendor string from the registry - vendor_ = GetArmWindowsVendor(); - vendor_id_ = GetVendorId(vendor_); - // Read MIDR and ID_AA64ISAR1_EL1 register values from Windows registry // There should be one per CPU std::vector midr_values{}, id_aa64isar1_el1_values{}; @@ -308,15 +318,6 @@ void CPUIDInfo::ArmWindowsInit() { #endif // defined(CPUINFO_SUPPORTED) } -std::string CPUIDInfo::GetArmWindowsVendor() { - const int MAX_VALUE_NAME = 256; - const CHAR vendorKey[] = "HARDWARE\\DESCRIPTION\\System\\CentralProcessor\\0"; - CHAR vendorVal[MAX_VALUE_NAME] = ""; - unsigned long vendorSize = sizeof(char) * MAX_VALUE_NAME; - ::RegGetValueA(HKEY_LOCAL_MACHINE, vendorKey, "Vendor Identifier", RRF_RT_REG_SZ | RRF_ZEROONFAILURE, nullptr, &vendorVal, &vendorSize); - return vendorVal; -} - #elif defined(__APPLE__) // ^ defined(_WIN32) void CPUIDInfo::ArmAppleInit() { @@ -328,6 +329,8 @@ void CPUIDInfo::ArmAppleInit() { has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16(); + has_arm_sme_ = cpuinfo_has_arm_sme(); + has_arm_sme2_ = 
cpuinfo_has_arm_sme2(); // Note: We leave is_armv8_narrow_ld_ unset because it only applies to a limited set of uarchs that we don't expect // to encounter on Apple platforms. @@ -360,16 +363,21 @@ uint32_t CPUIDInfo::GetCurrentCoreIdx() const { } CPUIDInfo::CPUIDInfo() { -#ifdef CPUIDINFO_ARCH_X86 - X86Init(); -#elif defined(CPUIDINFO_ARCH_ARM) #if defined(CPUINFO_SUPPORTED) pytorch_cpuinfo_init_ = cpuinfo_initialize(); if (!pytorch_cpuinfo_init_) { - std::cout << "Failed to initialize PyTorch cpuinfo library. May cause CPU EP performance degradation " - "due to undetected CPU features."; + LogEarlyWarning( + "Failed to initialize PyTorch cpuinfo library. May cause CPU EP performance degradation due to undetected CPU " + "features."); } #endif // defined(CPUINFO_SUPPORTED) + + // Note: This should be run after cpuinfo initialization if cpuinfo is enabled. + VendorInfoInit(); + +#ifdef CPUIDINFO_ARCH_X86 + X86Init(); +#elif defined(CPUIDINFO_ARCH_ARM) #if defined(__linux__) ArmLinuxInit(); #elif defined(_WIN32) diff --git a/src/common/cpuid_info_vendor.cc b/src/common/cpuid_info_vendor.cc new file mode 100644 index 0000000..d4d940e --- /dev/null +++ b/src/common/cpuid_info_vendor.cc @@ -0,0 +1,244 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/cpuid_info.h" + +#include +#include +#include + +#if defined(CPUINFO_SUPPORTED) +#include "cpuinfo.h" +#endif + +namespace { + +#if !defined(CPUINFO_SUPPORTED) + +// The `cpuinfo_vendor` enum is defined by the cpuinfo library. +// In case we don't build with cpuinfo, we define our own copy. +// The enum was copied from here: +// https://github.com/pytorch/cpuinfo/blob/8a1772a0c5c447df2d18edf33ec4603a8c9c04a6/include/cpuinfo.h#L154-L307 + +/** Vendor of processor core design */ +enum cpuinfo_vendor { + /** Processor vendor is not known to the library, or the library failed + to get vendor information from the OS. 
*/ + cpuinfo_vendor_unknown = 0, + + /* Active vendors of modern CPUs */ + + /** + * Intel Corporation. Vendor of x86, x86-64, IA64, and ARM processor + * microarchitectures. + * + * Sold its ARM design subsidiary in 2006. The last ARM processor design + * was released in 2004. + */ + cpuinfo_vendor_intel = 1, + /** Advanced Micro Devices, Inc. Vendor of x86 and x86-64 processor + microarchitectures. */ + cpuinfo_vendor_amd = 2, + /** ARM Holdings plc. Vendor of ARM and ARM64 processor + microarchitectures. */ + cpuinfo_vendor_arm = 3, + /** Qualcomm Incorporated. Vendor of ARM and ARM64 processor + microarchitectures. */ + cpuinfo_vendor_qualcomm = 4, + /** Apple Inc. Vendor of ARM and ARM64 processor microarchitectures. */ + cpuinfo_vendor_apple = 5, + /** Samsung Electronics Co., Ltd. Vendir if ARM64 processor + microarchitectures. */ + cpuinfo_vendor_samsung = 6, + /** Nvidia Corporation. Vendor of ARM64-compatible processor + microarchitectures. */ + cpuinfo_vendor_nvidia = 7, + /** MIPS Technologies, Inc. Vendor of MIPS processor microarchitectures. + */ + cpuinfo_vendor_mips = 8, + /** International Business Machines Corporation. Vendor of PowerPC + processor microarchitectures. */ + cpuinfo_vendor_ibm = 9, + /** Ingenic Semiconductor. Vendor of MIPS processor microarchitectures. + */ + cpuinfo_vendor_ingenic = 10, + /** + * VIA Technologies, Inc. Vendor of x86 and x86-64 processor + * microarchitectures. + * + * Processors are designed by Centaur Technology, a subsidiary of VIA + * Technologies. + */ + cpuinfo_vendor_via = 11, + /** Cavium, Inc. Vendor of ARM64 processor microarchitectures. */ + cpuinfo_vendor_cavium = 12, + /** Broadcom, Inc. Vendor of ARM processor microarchitectures. */ + cpuinfo_vendor_broadcom = 13, + /** Applied Micro Circuits Corporation (APM). Vendor of ARM64 processor + microarchitectures. */ + cpuinfo_vendor_apm = 14, + /** + * Huawei Technologies Co., Ltd. Vendor of ARM64 processor + * microarchitectures. 
+ * + * Processors are designed by HiSilicon, a subsidiary of Huawei. + */ + cpuinfo_vendor_huawei = 15, + /** + * Hygon (Chengdu Haiguang Integrated Circuit Design Co., Ltd), Vendor + * of x86-64 processor microarchitectures. + * + * Processors are variants of AMD cores. + */ + cpuinfo_vendor_hygon = 16, + /** SiFive, Inc. Vendor of RISC-V processor microarchitectures. */ + cpuinfo_vendor_sifive = 17, + + /* Active vendors of embedded CPUs */ + + /** Texas Instruments Inc. Vendor of ARM processor microarchitectures. + */ + cpuinfo_vendor_texas_instruments = 30, + /** Marvell Technology Group Ltd. Vendor of ARM processor + * microarchitectures. + */ + cpuinfo_vendor_marvell = 31, + /** RDC Semiconductor Co., Ltd. Vendor of x86 processor + microarchitectures. */ + cpuinfo_vendor_rdc = 32, + /** DM&P Electronics Inc. Vendor of x86 processor microarchitectures. */ + cpuinfo_vendor_dmp = 33, + /** Motorola, Inc. Vendor of PowerPC and ARM processor + microarchitectures. */ + cpuinfo_vendor_motorola = 34, + + /* Defunct CPU vendors */ + + /** + * Transmeta Corporation. Vendor of x86 processor microarchitectures. + * + * Now defunct. The last processor design was released in 2004. + * Transmeta processors implemented VLIW ISA and used binary translation + * to execute x86 code. + */ + cpuinfo_vendor_transmeta = 50, + /** + * Cyrix Corporation. Vendor of x86 processor microarchitectures. + * + * Now defunct. The last processor design was released in 1996. + */ + cpuinfo_vendor_cyrix = 51, + /** + * Rise Technology. Vendor of x86 processor microarchitectures. + * + * Now defunct. The last processor design was released in 1999. + */ + cpuinfo_vendor_rise = 52, + /** + * National Semiconductor. Vendor of x86 processor microarchitectures. + * + * Sold its x86 design subsidiary in 1999. The last processor design was + * released in 1998. + */ + cpuinfo_vendor_nsc = 53, + /** + * Silicon Integrated Systems. Vendor of x86 processor + * microarchitectures. 
+ * + * Sold its x86 design subsidiary in 2001. The last processor design was + * released in 2001. + */ + cpuinfo_vendor_sis = 54, + /** + * NexGen. Vendor of x86 processor microarchitectures. + * + * Now defunct. The last processor design was released in 1994. + * NexGen designed the first x86 microarchitecture which decomposed x86 + * instructions into simple microoperations. + */ + cpuinfo_vendor_nexgen = 55, + /** + * United Microelectronics Corporation. Vendor of x86 processor + * microarchitectures. + * + * Ceased x86 in the early 1990s. The last processor design was released + * in 1991. Designed U5C and U5D processors. Both are 486 level. + */ + cpuinfo_vendor_umc = 56, + /** + * Digital Equipment Corporation. Vendor of ARM processor + * microarchitecture. + * + * Sold its ARM designs in 1997. The last processor design was released + * in 1997. + */ + cpuinfo_vendor_dec = 57, +}; + +#endif // !defined(CPUINFO_SUPPORTED) + +} // namespace + +namespace onnxruntime { + +namespace { + +struct CpuVendorInfo { + cpuinfo_vendor vendor; + std::string_view name; + uint32_t id; +}; + +constexpr auto kUnknownCpuVendorInfo = CpuVendorInfo{cpuinfo_vendor_unknown, "unknown", 0x0000}; + +constexpr std::array kCpuVendorInfos{ + CpuVendorInfo{cpuinfo_vendor_amd, "AMD", 0x1022}, + CpuVendorInfo{cpuinfo_vendor_intel, "Intel", 0x8086}, + CpuVendorInfo{cpuinfo_vendor_qualcomm, "Qualcomm", uint32_t{'Q' | ('C' << 8) | ('O' << 16) | ('M' << 24)}}, + CpuVendorInfo{cpuinfo_vendor_nvidia, "Nvidia", 0x10DE}, + CpuVendorInfo{cpuinfo_vendor_apple, "Apple", 0x106B}, + CpuVendorInfo{cpuinfo_vendor_arm, "ARM", 0x13B5}, + + // TODO add more as needed +}; + +const CpuVendorInfo* FindCpuVendorInfo(cpuinfo_vendor vendor) { + const auto vendor_mapping_it = std::find_if(kCpuVendorInfos.begin(), kCpuVendorInfos.end(), + [vendor](const CpuVendorInfo& entry) { + return entry.vendor == vendor; + }); + + if (vendor_mapping_it != kCpuVendorInfos.end()) { + return &*vendor_mapping_it; + } + + return 
nullptr; +} + +} // namespace + +void CPUIDInfo::VendorInfoInit() { + const cpuinfo_vendor vendor = [&]() { + cpuinfo_vendor result = cpuinfo_vendor_unknown; +#if defined(CPUINFO_SUPPORTED) + if (pytorch_cpuinfo_init_) { + const auto* processor = cpuinfo_get_processor(0); + if (processor && processor->core) { + result = processor->core->vendor; + } + } +#endif // defined(CPUINFO_SUPPORTED) + return result; + }(); + + const auto* vendor_info = FindCpuVendorInfo(vendor); + if (vendor_info == nullptr) { + LogEarlyWarning(MakeString("Unknown CPU vendor. cpuinfo_vendor value: ", static_cast(vendor))); + vendor_info = &kUnknownCpuVendorInfo; + } + + vendor_ = vendor_info->name; + vendor_id_ = vendor_info->id; +} + +} // namespace onnxruntime diff --git a/src/common/helper.cc b/src/common/helper.cc index 6a52db7..07cd167 100644 --- a/src/common/helper.cc +++ b/src/common/helper.cc @@ -18,7 +18,7 @@ namespace onnxruntime { #ifdef _WIN32 -std::string ToUTF8String(const std::wstring& s) { +std::string ToUTF8String(std::wstring_view s) { if (s.size() >= static_cast(std::numeric_limits::max())) ORT_THROW("length overflow"); @@ -33,7 +33,7 @@ std::string ToUTF8String(const std::wstring& s) { return ret; } -std::wstring ToWideString(const std::string& s) { +std::wstring ToWideString(std::string_view s) { if (s.size() >= static_cast(std::numeric_limits::max())) ORT_THROW("length overflow"); diff --git a/src/common/safeint.h b/src/common/safeint.h index 3ee70f3..6aba587 100644 --- a/src/common/safeint.h +++ b/src/common/safeint.h @@ -13,11 +13,11 @@ class SafeIntExceptionHandler; template <> class SafeIntExceptionHandler { public: - static void SafeIntOnOverflow() { + [[noreturn]] static void SafeIntOnOverflow() { ORT_THROW("Integer overflow"); } - static void SafeIntOnDivZero() { + [[noreturn]] static void SafeIntOnDivZero() { ORT_THROW("Divide by zero"); } }; diff --git a/src/common/semver.h b/src/common/semver.h new file mode 100644 index 0000000..98bb6a2 --- /dev/null +++ 
b/src/common/semver.h @@ -0,0 +1,33 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +#include "core/common/status.h" + +namespace onnxruntime { + +// Semantic Versioning version utilities. +// See https://github.com/semver/semver/blob/v2.0.0/semver.md. + +// Semantic Versioning version components. +struct SemVerVersion { + uint32_t major{}; + uint32_t minor{}; + uint32_t patch{}; + std::optional prerelease{}; + std::optional build_metadata{}; +}; + +// Parse a Semantic Versioning version from `version_string`. +// If provided, the parsed version components will be written to `semver_version`. +Status ParseSemVerVersion(std::string_view version_string, SemVerVersion* semver_version); + +// Parse a Semantic Versioning version from `version_string`. +SemVerVersion ParseSemVerVersion(std::string_view version_string); + +} // namespace onnxruntime diff --git a/src/common/spin_pause.cc b/src/common/spin_pause.cc new file mode 100644 index 0000000..7b1de8d --- /dev/null +++ b/src/common/spin_pause.cc @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/common/spin_pause.h" + +#if defined(_M_AMD64) +#include +#endif + +#if defined(__x86_64__) +#include +#endif + +#if defined(_M_AMD64) || defined(__x86_64__) +#include "core/common/cpuid_info.h" +#if defined(__linux__) +#include +#include +#endif +#endif + +namespace onnxruntime { +namespace concurrency { + +// Intrinsic to use in spin-loops +void SpinPause() { +#if (defined(_M_AMD64) || defined(__x86_64__)) && \ + !defined(_M_ARM64EC) && \ + !defined(__ANDROID__) && \ + !defined(__APPLE__) + _mm_pause(); +#endif +} + +} // namespace concurrency +} // namespace onnxruntime diff --git a/src/common/string_utils.h b/src/common/string_utils.h index c2e26f6..d8d943d 100644 --- a/src/common/string_utils.h +++ b/src/common/string_utils.h @@ -61,10 +61,11 @@ inline void TrimStringFromRight(std::string& s) { * @param s The string to trim. * @return The trimmed string. */ -inline std::string TrimString(std::string s) { - TrimStringFromRight(s); - TrimStringFromLeft(s); - return s; +inline std::string TrimString(std::string_view s) { + std::string s_trimmed{s}; + TrimStringFromRight(s_trimmed); + TrimStringFromLeft(s_trimmed); + return s_trimmed; } /** diff --git a/src/lib/CMakeLists.txt b/src/lib/CMakeLists.txt index acabed4..da54462 100644 --- a/src/lib/CMakeLists.txt +++ b/src/lib/CMakeLists.txt @@ -37,6 +37,10 @@ if (MSVC) endif() endif() + +# mlas_private_compile_definitions contains compile definitions that are private to onnxruntime_mlas and targets which +# use internal MLAS headers like mlasi.h. 
+set(mlas_private_compile_definitions) # # All hardware agnostic source files here # hardware specific files would cause trouble in @@ -63,6 +67,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/eltwise.cpp ${MLAS_SRC_DIR}/erf.cpp ${MLAS_SRC_DIR}/compute.cpp + ${MLAS_SRC_DIR}/dequantize.cpp ${MLAS_SRC_DIR}/quantize.cpp ${MLAS_SRC_DIR}/qgemm_kernel_default.cpp ${MLAS_SRC_DIR}/qladd.cpp @@ -70,6 +75,7 @@ onnxruntime_add_static_library(onnxruntime_mlas ${MLAS_SRC_DIR}/qpostprocessor.cpp ${MLAS_SRC_DIR}/qlgavgpool.cpp ${MLAS_SRC_DIR}/qdwconv_kernelsize.cpp + ${MLAS_SRC_DIR}/qnbitgemm.h ${MLAS_SRC_DIR}/qnbitgemm.cpp ${MLAS_SRC_DIR}/sqnbitgemm_q8_block.h ${MLAS_SRC_DIR}/flashattn.cpp @@ -95,6 +101,8 @@ if (NOT onnxruntime_ORT_MINIMAL_BUILD) ) endif() +set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas) + #TODO: set MASM flags properly function(setup_mlas_source_for_windows) @@ -136,6 +144,7 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/eltwise_kernel_neon.h ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp ${MLAS_SRC_DIR}/eltwise_kernel_neon_fp16.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp ) set(mlas_platform_preprocess_srcs @@ -159,7 +168,11 @@ function(setup_mlas_source_for_windows) ${MLAS_SRC_DIR}/arm64/SymQgemmS8KernelSDotLd64.asm ) - if (onnxruntime_USE_KLEIDIAI) + if (onnxruntime_USE_ARM_NEON_NCHWC) + setup_arm_neon_nchwc() + endif() + + if (onnxruntime_USE_KLEIDIAI) setup_kleidiai() endif() else() @@ -295,23 +308,35 @@ function(setup_mlas_source_for_windows) endfunction() function(setup_kleidiai) - target_compile_definitions(onnxruntime_mlas PRIVATE USE_KLEIDIAI) - - # Disable the KleidiAI tests - set(KLEIDIAI_BUILD_TESTS OFF) - - # Fetch KleidiAI sources: - if (NOT TARGET kleidiai) - onnxruntime_fetchcontent_declare(kleidiai URL ${DEP_URL_kleidiai} URL_HASH SHA1=${DEP_SHA1_kleidiai} EXCLUDE_FROM_ALL) - endif() - onnxruntime_fetchcontent_makeavailable(kleidiai) - target_sources(onnxruntime_mlas PRIVATE 
${MLAS_SRC_DIR}/kai_ukernel_interface.cpp + ${MLAS_SRC_DIR}/kleidiai/sgemm_kleidiai.cpp + ${MLAS_SRC_DIR}/kleidiai/convolve_kleidiai.cpp + ${MLAS_SRC_DIR}/kleidiai/qgemm_kleidiai.cpp ) target_link_libraries(onnxruntime_mlas PRIVATE kleidiai) + list(APPEND onnxruntime_EXTERNAL_LIBRARIES kleidiai) + set(onnxruntime_EXTERNAL_LIBRARIES ${onnxruntime_EXTERNAL_LIBRARIES} PARENT_SCOPE) + + if (NOT onnxruntime_BUILD_SHARED_LIB) + install(TARGETS kleidiai EXPORT ${PROJECT_NAME}Targets + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} + FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) + endif() endfunction() +function (setup_arm_neon_nchwc) + target_sources(onnxruntime_mlas PRIVATE + ${MLAS_SRC_DIR}/sconv.h + ${MLAS_SRC_DIR}/sconv_kernel_neon.cpp + ${MLAS_SRC_DIR}/spool_kernel_neon.cpp + ) + list(APPEND mlas_private_compile_definitions MLAS_USE_ARM_NEON_NCHWC) + set(mlas_private_compile_definitions ${mlas_private_compile_definitions} PARENT_SCOPE) +endfunction () + if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") if (onnxruntime_ENABLE_WEBASSEMBLY_SIMD) file(GLOB_RECURSE mlas_platform_srcs @@ -336,7 +361,6 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") elseif(MSVC) setup_mlas_source_for_windows() else() - if(APPLE) get_target_property(ONNXRUNTIME_MLAS_OSX_ARCH onnxruntime_mlas OSX_ARCHITECTURES) @@ -393,6 +417,8 @@ else() set(X86_64 TRUE) elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^loongarch64.*") set(LOONGARCH64 TRUE) + elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^s390x$") + set(S390X TRUE) endif() endif() @@ -456,12 +482,29 @@ else() ${MLAS_SRC_DIR}/softmax_kernel_neon.cpp ${MLAS_SRC_DIR}/eltwise_kernel_neon.h ${MLAS_SRC_DIR}/eltwise_kernel_neon.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp ) - if (onnxruntime_USE_KLEIDIAI) + + # Conditionally add the SVE implementation if compiler supports it + if (onnxruntime_USE_SVE) + list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/mlasi_sve.h) + 
list(APPEND mlas_platform_srcs ${MLAS_SRC_DIR}/sve/elementwise_sve.cpp) + set_source_files_properties(${MLAS_SRC_DIR}/sve/elementwise_sve.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+sve+fp16 ") + list(APPEND mlas_private_compile_definitions MLAS_USE_SVE) + endif() + + if (onnxruntime_USE_ARM_NEON_NCHWC) + setup_arm_neon_nchwc() + endif() + + if (onnxruntime_USE_KLEIDIAI) setup_kleidiai() endif() set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+dotprod") + set_source_files_properties(${MLAS_SRC_DIR}/sqnbitgemm_kernel_neon_int8_i8mm.cpp + PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + if (NOT APPLE) set(mlas_platform_srcs ${mlas_platform_srcs} @@ -500,7 +543,7 @@ else() endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) - onnxruntime_add_static_library(onnxruntime_mlas_arm64 ${mlas_platform_srcs}) + onnxruntime_add_static_library(onnxruntime_mlas_arm64 ${mlas_platform_srcs}) set_target_properties(onnxruntime_mlas_arm64 PROPERTIES OSX_ARCHITECTURES "arm64") list(APPEND ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas_arm64) set(mlas_platform_srcs ) @@ -767,6 +810,7 @@ endif() if(LOONGARCH64 AND MLAS_SOURCE_IS_NOT_SET) set(mlas_platform_srcs ${MLAS_SRC_DIR}/qgemm_kernel_lsx.cpp + ${MLAS_SRC_DIR}/sqnbitgemm_kernel_lasx.cpp ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLasx.S ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLsx.S ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLasx.S @@ -784,6 +828,24 @@ endif() set(MLAS_SOURCE_IS_NOT_SET 0) endif() endif() + if(S390X AND MLAS_SOURCE_IS_NOT_SET) + set(mlas_platform_srcs + ${MLAS_SRC_DIR}/s390x/SgemmKernel.cpp + ${MLAS_SRC_DIR}/s390x/SgemmKernelZVECTOR.cpp + ${MLAS_SRC_DIR}/dgemm.cpp + ${MLAS_SRC_DIR}/s390x/DgemmKernel.cpp + ${MLAS_SRC_DIR}/s390x/Quantize.cpp + ${MLAS_SRC_DIR}/s390x/QuantizeZVECTOR.cpp + ${MLAS_SRC_DIR}/s390x/qgemm_kernel_zvector.cpp + ) + set_source_files_properties(${MLAS_SRC_DIR}/s390x/SgemmKernel.cpp PROPERTIES COMPILE_FLAGS "-DSINGLE") + 
set_source_files_properties(${MLAS_SRC_DIR}/s390x/SgemmKernelZVECTOR.cpp PROPERTIES COMPILE_FLAGS "-DSINGLE") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mvx -mzvector -march=z15") + + if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH) + set(MLAS_SOURCE_IS_NOT_SET 0) + endif() + endif() if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH AND MLAS_SOURCE_IS_NOT_SET) file(GLOB_RECURSE mlas_platform_srcs "${MLAS_SRC_DIR}/scalar/*.cpp") @@ -802,6 +864,8 @@ foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) target_include_directories(${mlas_target} PRIVATE ${ONNXRUNTIME_INCLUDE_DIR} ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) target_link_libraries(${mlas_target} Microsoft.GSL::GSL) + target_compile_definitions(${mlas_target} PRIVATE ${mlas_private_compile_definitions}) + set_target_properties(${mlas_target} PROPERTIES FOLDER "ONNXRuntime") endforeach() @@ -812,12 +876,6 @@ if (WIN32) endif() endif() -if (PLATFORM_NAME STREQUAL "macabi") - # Needed for maccatalyst C compilation - # i.e. the flags below add "--target=x86_64-apple-ios14.0-macabi -ffunction-sections -fdata-sections" - target_compile_options(onnxruntime_mlas PRIVATE ${CMAKE_C_FLAGS}) -endif() - if (NOT onnxruntime_BUILD_SHARED_LIB) install(TARGETS onnxruntime_mlas EXPORT ${PROJECT_NAME}Targets ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} @@ -826,19 +884,7 @@ if (NOT onnxruntime_BUILD_SHARED_LIB) FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) endif() -# set up source group for MLAS source files -block() - set(source_group_srcs) - foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS}) - get_target_property(mlas_target_srcs ${mlas_target} SOURCES) - foreach(mlas_target_src ${mlas_target_srcs}) - cmake_path(IS_PREFIX MLAS_ROOT ${mlas_target_src} in_mlas_root) - if(in_mlas_root) - list(APPEND source_group_srcs ${mlas_target_src}) - endif() - endforeach() - endforeach() -endblock() + @@ -848,7 +894,7 @@ endblock() # based on block-wise quantization of int4 # - add_executable(onnxruntime_mlas_q4dq + onnxruntime_add_executable(onnxruntime_mlas_q4dq 
${MLAS_SRC_DIR}/q4_dq_cli.cpp ) target_include_directories(onnxruntime_mlas_q4dq PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR}) @@ -883,3 +929,5 @@ endblock() set_target_properties(onnxruntime_mlas_q4dq PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1") endif() endif() + + diff --git a/src/lib/activate.cpp b/src/lib/activate.cpp index df3b884..5dd8244 100644 --- a/src/lib/activate.cpp +++ b/src/lib/activate.cpp @@ -141,7 +141,7 @@ struct MLAS_ACTIVATION_FUNCTION return _mm_blendv_ps(ValueTimesAlpha, Value, _mm_cmple_ps(ZeroFloat32x4, Value)); #elif defined(MLAS_SSE2_INTRINSICS) return MlasBlendFloat32x4(ValueTimesAlpha, Value, _mm_cmple_ps(ZeroFloat32x4, Value)); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) return vec_sel(ValueTimesAlpha, Value, vec_cmple(ZeroFloat32x4, Value)); #elif defined(MLAS_LSX_INTRINSICS) return MlasBlendFloat32x4(ValueTimesAlpha, Value, (__m128)__lsx_vfcmp_cle_s(ZeroFloat32x4, Value)); diff --git a/src/lib/cast.cpp b/src/lib/cast.cpp index 9b5800b..6b71fe3 100644 --- a/src/lib/cast.cpp +++ b/src/lib/cast.cpp @@ -33,6 +33,60 @@ MlasConvertHalfToFloatBuffer( } } + +void +MLASCALL +MlasConvertHalfToFloatBufferInParallel( + const MLAS_FP16* Source, + float* Destination, + size_t Count, + MLAS_THREADPOOL* ThreadPool +) +{ +#if defined(BUILD_MLAS_NO_ONNXRUNTIME) + MLAS_UNREFERENCED_PARAMETER(ThreadPool); + + // If the ThreadPool is not available, use the single-threaded version. + MlasConvertHalfToFloatBuffer(Source, Destination, Count); +#else + // Check if the Tensor is long enough to use threads. + // Check if the Thread Pool is available. 
+ // If not, execute single threaded conversion of half to float + if (!((Count > MLAS_MIN_TENSOR_SIZE_FOR_HALF_TO_FLOAT_CONVERSION_IN_PARALLEL) && ThreadPool)) { + MlasConvertHalfToFloatBuffer(Source, Destination, Count); + } + else { + + // Calculate the number of compute cycles per implementation + size_t num_compute_cycles; + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasSSE3()) { + num_compute_cycles = Count >> 1; + } else if (MLAS_CPUIDINFO::GetCPUIDInfo().HasAVX2()) { + num_compute_cycles = Count >> 2; + } else { + num_compute_cycles = Count * 10; + } + + MLAS_THREADPOOL::TryParallelFor( + ThreadPool, Count, + // Tensor Op Cost + { + static_cast(Count * sizeof(MLAS_FP16)), // Size of no. of elements in bytes to be loaded + static_cast(Count * sizeof(float)), // Size of no. of elements in bytes to be stored + static_cast(num_compute_cycles), // No. of compute cycles required for the tensor op + }, + // Lambda function required by TryParallelFor method + [Source, Destination](std::ptrdiff_t first_span, std::ptrdiff_t last_span) { + MlasConvertHalfToFloatBuffer( + Source + first_span, + Destination + first_span, + static_cast(last_span - first_span)); + } + ); + } +#endif // BUILD_MLAS_NO_ONNXRUNTIME +} + void MLASCALL MlasConvertFloatToHalfBuffer( diff --git a/src/lib/compute.cpp b/src/lib/compute.cpp index 96a2398..4916062 100644 --- a/src/lib/compute.cpp +++ b/src/lib/compute.cpp @@ -74,6 +74,7 @@ struct MLAS_SOFTMAX_WORK_BLOCK { ptrdiff_t ThreadCountN; bool LogSoftmax; bool SmoothSoftmax; + float Sink; const T* Input; T* Output; size_t N; @@ -500,7 +501,6 @@ Return Value: Input += 1; N -= 1; } - return Accumulator; } @@ -570,7 +570,6 @@ Return Value: Input += 1; N -= 1; } - return Maximum; } @@ -833,7 +832,7 @@ Return Value: --*/ { const auto* WorkBlock = (MLAS_SOFTMAX_WORK_BLOCK*)Context; - + // // Partition the operation along the N dimension. 
// @@ -850,6 +849,7 @@ Return Value: const size_t D = WorkBlock->D; const bool LogSoftmax = WorkBlock->LogSoftmax; const bool SmoothSoftmax = WorkBlock->SmoothSoftmax; + const float Sink = WorkBlock->Sink; const float* Input = WorkBlock->Input + n * D; float* Output = WorkBlock->Output + n * D; @@ -874,30 +874,34 @@ Return Value: // // Find the maximum value for the row. // + float Maximum; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) - float Maximum = GetMlasPlatform().ReduceMaximumF32Kernel(Input, D); -#else - float Maximum = MlasReduceMaximumF32Kernel(Input, D); +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_USE_SVE) + Maximum = GetMlasPlatform().ReduceMaximumF32Kernel(Input, D); +#else + Maximum = MlasReduceMaximumF32Kernel(Input, D); #endif - float NegativeMaximum = -Maximum; - if (SmoothSoftmax && NegativeMaximum > 0.0f) { - NegativeMaximum = 0.0f; + if (SmoothSoftmax && Sink > Maximum) { + Maximum = Sink; } + float NegativeMaximum = -Maximum; + // // Compute the exponential function for each element of the row (save to Temp if provided) and // compute the sum of these exponential functions. // float* Temp = LogSoftmax ? 
nullptr : Output; -#if defined(MLAS_TARGET_AMD64) - float Accumulation = GetMlasPlatform().ComputeSumExpF32Kernel(Input, Temp, D, &NegativeMaximum); + float Accumulation; + +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_USE_SVE) + Accumulation = GetMlasPlatform().ComputeSumExpF32Kernel(Input, Temp, D, &NegativeMaximum); #else - float Accumulation = MlasComputeSumExpF32Kernel(Input, Temp, D, &NegativeMaximum); + Accumulation = MlasComputeSumExpF32Kernel(Input, Temp, D, &NegativeMaximum); #endif if (SmoothSoftmax) { - Accumulation += expf(NegativeMaximum); + Accumulation += expf(Sink + NegativeMaximum); } if (LogSoftmax) { @@ -906,19 +910,19 @@ Return Value: // float Parameters[] = {NegativeMaximum, std::log(Accumulation)}; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_USE_SVE) GetMlasPlatform().ComputeLogSoftmaxOutputF32Kernel(Input, Output, D, Parameters); -#else +#else + MlasComputeLogSoftmaxOutputF32Kernel(Input, Output, D, Parameters); #endif - } else { // // Normalize the softmax output. // float Parameters[] = {1.0f / Accumulation}; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_USE_SVE) GetMlasPlatform().ComputeSoftmaxOutputF32Kernel(Output, D, Parameters); #else MlasComputeSoftmaxOutputF32Kernel(Output, D, Parameters); @@ -1014,6 +1018,7 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, + float Sink, MLAS_THREADPOOL* ThreadPool ) /*++ @@ -1039,6 +1044,8 @@ Routine Description: SmoothSoftmax - Supplies true if a smooth factor is used in softmax operation. + Sink - Supplies the smooth factor to use in the softmax operation. + ThreadPool - Supplies the thread pool object to use, else nullptr if the base library threading support should be used. 
@@ -1060,6 +1067,7 @@ Return Value: WorkBlock.Output = Output; WorkBlock.N = N; WorkBlock.D = D; + WorkBlock.Sink = Sink; // // Compute the number of target threads given the complexity of the softmax @@ -1097,6 +1105,7 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, + float Sink, MLAS_THREADPOOL* ThreadPool ); @@ -1110,6 +1119,7 @@ MlasComputeSoftmax( size_t D, bool LogSoftmax, bool SmoothSoftmax, + float Sink, MLAS_THREADPOOL* ThreadPool ); diff --git a/src/lib/convolve.cpp b/src/lib/convolve.cpp index ec79641..9518134 100644 --- a/src/lib/convolve.cpp +++ b/src/lib/convolve.cpp @@ -729,6 +729,82 @@ Return Value: } } +void +MlasConvExpandThenGemmSegmentedThreaded( + void* Context, + ptrdiff_t Index +) +/*++ + +Routine Description: + + This routine is invoked from a worker thread to execute a segment of a + convolution operation. + + If using this, the entire convolution operation is parallelized on the + (batch size * group count) parameter and this routine has logic to + perform a specific thread's shard of the entire Convolution operation. + +Arguments: + + Context - Supplies the pointer to the context for the threaded operation. + + Index - Supplies the current index of the threaded operation. + +Return Value: + + None. 
+ +--*/ + +{ + MLAS_CONV_WORK_BLOCK* WorkBlock = (MLAS_CONV_WORK_BLOCK*)Context; + + const MLAS_CONV_PARAMETERS* Parameters = WorkBlock->Parameters; + + const size_t GroupCount = Parameters->GroupCount; + const size_t BatchGroupCount = Parameters->BatchCount * GroupCount; + + const size_t TargetThreadCount = WorkBlock->TargetThreadCount; + + const size_t BatchGroupCountPerThread = BatchGroupCount / TargetThreadCount; + const size_t BatchGroupCountExtra = BatchGroupCount % TargetThreadCount; + + size_t BatchGroupStart; + size_t BatchGroupEnd; + + if (static_cast(Index) < BatchGroupCountExtra) { + BatchGroupStart = (BatchGroupCountPerThread + 1) * Index; + BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread + 1; + } else { + BatchGroupStart = BatchGroupCountPerThread * Index + BatchGroupCountExtra; + BatchGroupEnd = BatchGroupStart + BatchGroupCountPerThread; + } + + const size_t FilterCount = Parameters->FilterCount; + const size_t OutputSize = Parameters->OutputSize; + const size_t K = Parameters->K; + + const size_t InputGroupSize = Parameters->InputChannels * Parameters->InputSize; + const size_t OutputGroupSize = FilterCount * OutputSize; + const size_t FilterGroupSize = FilterCount * K; + + for (size_t bg = BatchGroupStart; bg < BatchGroupEnd; bg++) { + size_t group = bg % GroupCount; + + const float* input = WorkBlock->Input + bg * InputGroupSize; + const float* filter = WorkBlock->Filter + group * FilterGroupSize; + float* output = WorkBlock->Output + bg * OutputGroupSize; + const float* bias = WorkBlock->Bias; + if (bias != nullptr) { + bias += group * FilterCount; + } + float* ColumnBuffer = WorkBlock->WorkingBuffer + Index * OutputSize * K; + + MlasConvOperation(Parameters, input, filter, bias, ColumnBuffer, output, 0, OutputSize); + } +} + inline bool MlasConvTryMultithread( @@ -861,6 +937,12 @@ Return Value: --*/ { + // Override + if(GetMlasPlatform().MlasConvOverride != nullptr && + 
GetMlasPlatform().MlasConvOverride(Parameters,Input,Filter,Bias,WorkingBuffer,Output,ThreadPool)){ + return; + } + const size_t FilterCount = Parameters->FilterCount; const size_t OutputSize = Parameters->OutputSize; const size_t K = Parameters->K; @@ -884,8 +966,8 @@ Return Value: ptrdiff_t TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool); - if (size_t(TargetThreadCount) >= BatchGroupCount) { - TargetThreadCount = ptrdiff_t(BatchGroupCount); + if (static_cast(TargetThreadCount) >= BatchGroupCount) { + TargetThreadCount = static_cast(BatchGroupCount); } MLAS_CONV_WORK_BLOCK WorkBlock; @@ -913,6 +995,30 @@ Return Value: #endif + if (Algorithm == MlasConvAlgorithmExpandThenGemmSegmented && ((BatchCount > 1) || (GroupCount > 1))) { + const size_t BatchGroupCount = BatchCount * GroupCount; + + ptrdiff_t TargetThreadCount = MlasGetMaximumThreadCount(ThreadPool); + + if (static_cast(TargetThreadCount) >= BatchGroupCount) { + TargetThreadCount = static_cast(BatchGroupCount); + } + + MLAS_CONV_WORK_BLOCK WorkBlock; + + WorkBlock.Parameters = Parameters; + WorkBlock.Input = Input; + WorkBlock.Filter = Filter; + WorkBlock.Bias = Bias; + WorkBlock.WorkingBuffer = WorkingBuffer; + WorkBlock.Output = Output; + WorkBlock.TargetThreadCount = TargetThreadCount; + + MlasExecuteThreaded(MlasConvExpandThenGemmSegmentedThreaded, &WorkBlock, TargetThreadCount, ThreadPool); + + return; + } + // // Iterate over each batch and group. // @@ -1094,6 +1200,13 @@ Return Value: --*/ { + // Override + if (GetMlasPlatform().MlasConvPrepareOverride != nullptr && + GetMlasPlatform().MlasConvPrepareOverride(Parameters, Dimensions, BatchCount, GroupCount, InputChannels, + InputShape,KernelShape,DilationShape, Padding, StrideShape, OutputShape, FilterCount, + Activation, WorkingBufferSize, Beta, ThreadPool)){ + return; + } // // Save the convolution parameters. 
// @@ -1295,8 +1408,20 @@ Return Value: Parameters->u.ExpandThenGemmSegmented.ThreadStrideN = StrideN; *WorkingBufferSize = TargetThreadCount * MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD; + + if (Parameters->BatchCount > 1 || Parameters->GroupCount > 1) { + + size_t WorkingBufferSizePerThread = std::max({Parameters->OutputSize * Parameters->K, + Parameters->FilterCount * Parameters->OutputSize, + static_cast(MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD)}); + TargetThreadCount = MaximumThreadCount; + if (static_cast(TargetThreadCount) >= Parameters->BatchCount * Parameters->GroupCount) { + TargetThreadCount = static_cast(Parameters->BatchCount * Parameters->GroupCount); + } + *WorkingBufferSize = TargetThreadCount * WorkingBufferSizePerThread; + } } } #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) -#endif \ No newline at end of file +#endif diff --git a/src/lib/convsym.cpp b/src/lib/convsym.cpp index 0ea7bef..5591aa4 100644 --- a/src/lib/convsym.cpp +++ b/src/lib/convsym.cpp @@ -16,7 +16,8 @@ Module Name: --*/ #include "mlasi.h" -#include +#include + // // Define the prototypes of the platform optimized routines. // diff --git a/src/lib/dequantize.cpp b/src/lib/dequantize.cpp new file mode 100644 index 0000000..175d3f6 --- /dev/null +++ b/src/lib/dequantize.cpp @@ -0,0 +1,395 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + dequantize.cpp + +Abstract: + + This module implements routines to dequantize buffers. + + The dequantization formula as specified in the ONNX operator documentation is: + + Output = (Input - ZeroPoint) * Scale + +--*/ + +#include "mlasi.h" + +// +// DequantizeLinear reference implementation using the C++ runtime. 
+// + +template +static +MLAS_FORCEINLINE +void +MlasDequantizeLinearRefImpl( + const InputType* Input, + float* Output, + size_t N, + float Scale, + InputType ZeroPoint + ) +/*++ + +Routine Description: + + This routine dequantizes the input buffer using the supplied quantization + parameters. + +Arguments: + + Input - Supplies the input buffer with quantized data. + + Output - Supplies the output buffer. + + N - Supplies the number of elements to process. + + Scale - Supplies the quantization scale. + + ZeroPoint - Supplies the quantization zero point value. + +Return Value: + + None. + +--*/ +{ + int32_t ZeroPointS32 = static_cast(ZeroPoint); + + for (size_t n = 0; n < N; n++) { + Output[n] = static_cast(static_cast(Input[n]) - ZeroPointS32) * Scale; + } +} + +#if defined(MLAS_SSE2_INTRINSICS) +// Implementation for Intel SSE 2. Refer to the Intel Intrinsics Guide: +// https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html + +void +MLASCALL +MlasDequantizeLinearS8Kernel( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + const __m128 ScaleVector = MlasBroadcastFloat32x4(Scale); + const __m128i ZeroPointS16Vector = _mm_set1_epi16(static_cast(ZeroPoint)); // Broadcast zp to 8 int16s + const __m128i Zeros = _mm_setzero_si128(); + + while (N >= 16) { + // Load a vector of 16 int8s: [0 ... 15] + __m128i VectorS8 = _mm_loadu_si128(reinterpret_cast(Input)); + + // Sign-extend into 2 vectors of 8 int16s + __m128i SignMaskS8 = _mm_cmpgt_epi8(Zeros, VectorS8); // 0xFF for every negative byte in VectorS8 + __m128i VectorS16_0 = _mm_unpacklo_epi8(VectorS8, SignMaskS8); // [0 ... 7] + __m128i VectorS16_1 = _mm_unpackhi_epi8(VectorS8, SignMaskS8); // [8 ... 15] + + // Subtract the zero-points in int16 domain. 
+ VectorS16_0 = _mm_sub_epi16(VectorS16_0, ZeroPointS16Vector); + VectorS16_1 = _mm_sub_epi16(VectorS16_1, ZeroPointS16Vector); + + // Sign-extend into 4 vectors of 4 int32s + __m128i SignMaskS16_0 = _mm_cmpgt_epi16(Zeros, VectorS16_0); + __m128i VectorS32_0 = _mm_unpacklo_epi16(VectorS16_0, SignMaskS16_0); // [0 ... 3] + __m128i VectorS32_1 = _mm_unpackhi_epi16(VectorS16_0, SignMaskS16_0); // [4 ... 7] + + __m128i SignMaskS16_1 = _mm_cmpgt_epi16(Zeros, VectorS16_1); + __m128i VectorS32_2 = _mm_unpacklo_epi16(VectorS16_1, SignMaskS16_1); // [8 ... 11] + __m128i VectorS32_3 = _mm_unpackhi_epi16(VectorS16_1, SignMaskS16_1); // [12 ... 15] + + // Cast each int32x4 to float and multiply by the scale vector. + __m128 VectorF32_0 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_0), ScaleVector); + __m128 VectorF32_1 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_1), ScaleVector); + __m128 VectorF32_2 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_2), ScaleVector); + __m128 VectorF32_3 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_3), ScaleVector); + + // Store each int32x4 into the output. + _mm_storeu_ps(Output + 0, VectorF32_0); + _mm_storeu_ps(Output + 4, VectorF32_1); + _mm_storeu_ps(Output + 8, VectorF32_2); + _mm_storeu_ps(Output + 12, VectorF32_3); + + Input += 16; + Output += 16; + N -= 16; + } + + // Handle leftover elements (< 16) with the scalar reference implementation. + MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasDequantizeLinearU8Kernel( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + const __m128 ScaleVector = MlasBroadcastFloat32x4(Scale); + const __m128i ZeroPointS16Vector = _mm_set1_epi16(static_cast(ZeroPoint)); // Broadcast zp to 8 int16s + const __m128i Zeros = _mm_setzero_si128(); + + while (N >= 16) { + // Load a vector of 16 uint8s: [0 ... 
15] + __m128i VectorU8 = _mm_loadu_si128(reinterpret_cast(Input)); + + // Zero-extend into 2 vectors of 8 uint16s + __m128i VectorU16_0 = _mm_unpacklo_epi8(VectorU8, Zeros); // [0 ... 7] + __m128i VectorU16_1 = _mm_unpackhi_epi8(VectorU8, Zeros); // [8 ... 15] + + // Subtract the zero-points as uint16s. Due to two's complement, negative results can be reinterpreted as int16 + __m128i VectorS16_0 = _mm_sub_epi16(VectorU16_0, ZeroPointS16Vector); + __m128i VectorS16_1 = _mm_sub_epi16(VectorU16_1, ZeroPointS16Vector); + + // Sign-extend into 4 vectors of 4 int32s + __m128i SignMaskS16_0 = _mm_cmpgt_epi16(Zeros, VectorS16_0); + __m128i VectorS32_0 = _mm_unpacklo_epi16(VectorS16_0, SignMaskS16_0); // [0 ... 3] + __m128i VectorS32_1 = _mm_unpackhi_epi16(VectorS16_0, SignMaskS16_0); // [4 ... 7] + + __m128i SignMaskS16_1 = _mm_cmpgt_epi16(Zeros, VectorS16_1); + __m128i VectorS32_2 = _mm_unpacklo_epi16(VectorS16_1, SignMaskS16_1); // [8 ... 11] + __m128i VectorS32_3 = _mm_unpackhi_epi16(VectorS16_1, SignMaskS16_1); // [12 ... 15] + + // Cast each int32x4 to float and multiply by the scale vector. + __m128 VectorF32_0 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_0), ScaleVector); + __m128 VectorF32_1 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_1), ScaleVector); + __m128 VectorF32_2 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_2), ScaleVector); + __m128 VectorF32_3 = _mm_mul_ps(_mm_cvtepi32_ps(VectorS32_3), ScaleVector); + + // Store each int32x4 into the output. + _mm_storeu_ps(Output + 0, VectorF32_0); + _mm_storeu_ps(Output + 4, VectorF32_1); + _mm_storeu_ps(Output + 8, VectorF32_2); + _mm_storeu_ps(Output + 12, VectorF32_3); + + Input += 16; + Output += 16; + N -= 16; + } + + // Handle leftover elements (< 16) with the scalar reference implementation. 
+ MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasDequantizeLinear( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().DequantizeLinearS8Kernel( +#else + MlasDequantizeLinearS8Kernel( +#endif + Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasDequantizeLinear( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ +#if defined(MLAS_TARGET_AMD64) + GetMlasPlatform().DequantizeLinearU8Kernel( +#else + MlasDequantizeLinearU8Kernel( +#endif + Input, Output, N, Scale, ZeroPoint); +} +#elif defined(MLAS_NEON64_INTRINSICS) +// Implementation for ARM64 NEON. Refer to the ARM intrinsics guide: +// https://developer.arm.com/architectures/instruction-sets/intrinsics/ + +void +MLASCALL +MlasDequantizeLinearS8Kernel( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + const float32x4_t ScaleVector = MlasBroadcastFloat32x4(Scale); + const int16x8_t ZeroPointVector = vdupq_n_s16(ZeroPoint); // Broadcast ZeroPoint (sign-extended to 16bits) + + while (N >= 16) { + // Load a vector of 16 int8s: [0 ... 15] + int8x16_t VectorS8 = vld1q_s8(Input); + + // Sign-extend into 2 vectors of 8 int16s + int16x8_t VectorS16_0 = vmovl_s8(vget_low_s8(VectorS8)); // [0 ... 7] + int16x8_t VectorS16_1 = vmovl_s8(vget_high_s8(VectorS8)); // [8 ... 15] + + // Subtract the zero-points in int16 domain. + VectorS16_0 = vsubq_s16(VectorS16_0, ZeroPointVector); + VectorS16_1 = vsubq_s16(VectorS16_1, ZeroPointVector); + + // Sign-extend into 4 vectors of 4 int32s + int32x4_t VectorS32_0 = vmovl_s16(vget_low_s16(VectorS16_0)); // [0 ... 3] + int32x4_t VectorS32_1 = vmovl_s16(vget_high_s16(VectorS16_0)); // [4 ... 7] + int32x4_t VectorS32_2 = vmovl_s16(vget_low_s16(VectorS16_1)); // [8 ... 
11] + int32x4_t VectorS32_3 = vmovl_s16(vget_high_s16(VectorS16_1)); // [12 ... 15] + + // Cast each int32x4 to float and multiply by the scale vector. + float32x4_t VectorF32_0 = vmulq_f32(vcvtq_f32_s32(VectorS32_0), ScaleVector); + float32x4_t VectorF32_1 = vmulq_f32(vcvtq_f32_s32(VectorS32_1), ScaleVector); + float32x4_t VectorF32_2 = vmulq_f32(vcvtq_f32_s32(VectorS32_2), ScaleVector); + float32x4_t VectorF32_3 = vmulq_f32(vcvtq_f32_s32(VectorS32_3), ScaleVector); + + // Store each int32x4 into the output. + vst1q_f32(Output + 0, VectorF32_0); + vst1q_f32(Output + 4, VectorF32_1); + vst1q_f32(Output + 8, VectorF32_2); + vst1q_f32(Output + 12, VectorF32_3); + + N -= 16; + Input += 16; + Output += 16; + } + + // Handle leftover elements (< 16) with the scalar reference implementation. + MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasDequantizeLinearU8Kernel( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + const float32x4_t ScaleVector = MlasBroadcastFloat32x4(Scale); + const uint8x8_t ZeroPointVector = vdup_n_u8(ZeroPoint); // Broadcast ZeroPoint to 8 uint8s + + while (N >= 16) { + // Load a vector of 16 uint8s: [0 ... 15] + uint8x16_t VectorU8 = vld1q_u8(Input); + + // Subtract zero-point. The vsubl_u8 instruction zero-extends its arguments to uint16 first. + // The reinterpret from uint16x8 to int16x8 is actually a NOP. + int16x8_t VectorS16_0 = vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(VectorU8), ZeroPointVector)); // [0 ... 7] + int16x8_t VectorS16_1 = vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(VectorU8), ZeroPointVector)); // [8 ... 15] + + // Sign-extend into 4 vectors of 4 int32s + int32x4_t VectorS32_0 = vmovl_s16(vget_low_s16(VectorS16_0)); // [0 ... 3] + int32x4_t VectorS32_1 = vmovl_s16(vget_high_s16(VectorS16_0)); // [4 ... 7] + int32x4_t VectorS32_2 = vmovl_s16(vget_low_s16(VectorS16_1)); // [8 ... 
11] + int32x4_t VectorS32_3 = vmovl_s16(vget_high_s16(VectorS16_1)); // [12 ... 15] + + // Cast each int32x4 to float and multiply by the scale vector. + float32x4_t VectorF32_0 = vmulq_f32(vcvtq_f32_s32(VectorS32_0), ScaleVector); + float32x4_t VectorF32_1 = vmulq_f32(vcvtq_f32_s32(VectorS32_1), ScaleVector); + float32x4_t VectorF32_2 = vmulq_f32(vcvtq_f32_s32(VectorS32_2), ScaleVector); + float32x4_t VectorF32_3 = vmulq_f32(vcvtq_f32_s32(VectorS32_3), ScaleVector); + + // Store each int32x4 into the output. + vst1q_f32(Output + 0, VectorF32_0); + vst1q_f32(Output + 4, VectorF32_1); + vst1q_f32(Output + 8, VectorF32_2); + vst1q_f32(Output + 12, VectorF32_3); + + N -= 16; + Input += 16; + Output += 16; + } + + // Handle leftover elements (< 16) with the scalar reference implementation. + MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasDequantizeLinear( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasDequantizeLinearS8Kernel(Input, Output, N, Scale, ZeroPoint); +} + +template<> +void +MLASCALL +MlasDequantizeLinear( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + MlasDequantizeLinearU8Kernel(Input, Output, N, Scale, ZeroPoint); +} +#else +// Implementation that uses the scalar reference implementation. 
+ +template +void +MLASCALL +MlasDequantizeLinear( + const InputType* Input, + float* Output, + size_t N, + float Scale, + InputType ZeroPoint + ) +{ + MlasDequantizeLinearRefImpl(Input, Output, N, Scale, ZeroPoint); +} + +template +void +MLASCALL +MlasDequantizeLinear( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ); + +template +void +MLASCALL +MlasDequantizeLinear( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ); + +#endif diff --git a/src/lib/dgemm.cpp b/src/lib/dgemm.cpp index 50c6274..bc341e6 100644 --- a/src/lib/dgemm.cpp +++ b/src/lib/dgemm.cpp @@ -26,7 +26,7 @@ Module Name: #define MLAS_DGEMM_TRANSA_ROWS 12 -#if defined (MLAS_TARGET_AMD64) || defined (MLAS_TARGET_POWER) +#if defined (MLAS_TARGET_AMD64) || defined (MLAS_TARGET_POWER) || defined (MLAS_TARGET_S390X) void MlasDgemmMultiplyBeta( @@ -530,7 +530,7 @@ Return Value: size_t RowsHandled; -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_S390X) RowsHandled = GetMlasPlatform().GemmDoubleKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode); #else if (ZeroMode) { diff --git a/src/lib/erf.cpp b/src/lib/erf.cpp index b45bd51..f972406 100644 --- a/src/lib/erf.cpp +++ b/src/lib/erf.cpp @@ -22,7 +22,6 @@ Module Name: --*/ #include "mlasi.h" - // // Bundles the constants for use by kernels written in assembly. 
// @@ -261,7 +260,7 @@ Return Value: --*/ { -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_USE_SVE) GetMlasPlatform().ErfKernelRoutine(Input, Output, N); #else MlasErfKernel(Input, Output, N); diff --git a/src/lib/kleidiai/convolve_kleidiai.cpp b/src/lib/kleidiai/convolve_kleidiai.cpp new file mode 100644 index 0000000..9eaf490 --- /dev/null +++ b/src/lib/kleidiai/convolve_kleidiai.cpp @@ -0,0 +1,720 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#include +#include +#include +#include +#include "mlasi_kleidiai.h" +#include +#include + +#include "kai/ukernels/matmul/imatmul_clamp_f32_f32p_f32p/kai_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa.h" +#include "kai/ukernels/matmul/pack/kai_lhs_imatmul_pack_x32p2vlx1_x32p_sme.h" +#include "kai/ukernels/matmul/pack/kai_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme.h" + +// Right-hand-side (weights) cache key +struct RhsCacheKey { + size_t co, ci, kh, kw, dilationh, dilationw; + size_t weights_hash; + + bool operator==(const RhsCacheKey& other) const { + return co == other.co && ci == other.ci && + kh == other.kh && kw == other.kw && + dilationh == other.dilationh && dilationw == other.dilationw && + weights_hash == other.weights_hash; + } +}; + + +// Left-hand-side (input indirection) cache key +struct LhsCacheKey { + size_t ci, ih, iw; + size_t padding, sh, sw; + size_t kh, kw; + size_t dilationh, dilationw; + size_t data_hash; + + bool operator==(const LhsCacheKey& other) const { + return ci == other.ci && ih == other.ih && iw == other.iw && + padding == other.padding && sh == other.sh && sw == other.sw && + kh == other.kh && kw == other.kw && + dilationh == other.dilationh && dilationw == other.dilationw && + data_hash == other.data_hash; + } +}; + +// Derived from 2^32 * (sqrt(5) - 1) / 2 ≈ 0.6180339887 (reciprocal of the golden ratio) +// Based on Knuth's multiplicative hashing method +constexpr 
size_t HASH_GOLDEN_RATIO_CONST = 0x9e3779b9; + +size_t HashWeights(const float* data, size_t count = 16) { + size_t h = 0; + for (size_t i = 0; i < count; ++i) { + h ^= std::hash()(data[i]) + HASH_GOLDEN_RATIO_CONST + (h << 6) + (h >> 2); + } + return h; +} + +namespace std { + // Specialize hash type for cache keys and do it within namespace std. + // Doing this allows standard containers like std::unordered_map to find + // the appropriate hash function via template specialization, as ADL + // (argument-dependent lookup) does not apply to std::hash. + template<> + struct hash { + size_t operator()(const RhsCacheKey& k) const { + return k.weights_hash ^ + (std::hash()(k.co) << 1) ^ + (std::hash()(k.ci) << 2) ^ + (std::hash()(k.kh) << 3) ^ + (std::hash()(k.kw) << 4) ^ + (std::hash()(k.dilationh) << 5) ^ + (std::hash()(k.dilationw) << 6); + } + }; + + template<> + struct hash { + size_t operator()(const LhsCacheKey& k) const { + return k.data_hash ^ + (std::hash()(k.ci) << 1) ^ + (std::hash()(k.ih) << 2) ^ + (std::hash()(k.iw) << 3) ^ + (std::hash()(k.padding) << 4) ^ + (std::hash()(k.sh) << 5) ^ + (std::hash()(k.sw) << 6) ^ + (std::hash()(k.kh) << 7) ^ + (std::hash()(k.kw) << 8) ^ + (std::hash()(k.dilationh) << 9) ^ + (std::hash()(k.dilationw) << 10); + } + }; + +} + + +static constexpr size_t ComputeKernelSize(const size_t D, const size_t K) { + // D - dilation size + // K - kernel size + + // D*S scale 1D kernel dimension by dilation factor + // (D-1) remove affect of dilation scaling at kernel end + return (D*K) - (D - 1); +} + +static constexpr size_t ComputeConvOutSize(const size_t L, const size_t K, const size_t P, const size_t S) { + + //With start + end padding + + //L - input size + //K - kernel size + //P - Padding size + //S - stride size + + //Does the convolution compute one value or less ? 
+ if ( S > 0 && (L + 2*P) >= K) { + // L-(K-1) standard convolution output size is L-(K-1) for a step size of 1 with no padding + // (2*P) 1D start and end padding + // (L+2*P)-(K-1) the 1D length of convolution result for a kernel step size of 1 + // /S apply the kernel step + return (((L - K) + (2 * P)) / S) + 1; + } + return 0; +} + +static size_t ComputeMlasWorkingBufferSize(const size_t co, + const size_t ih, const size_t iw, + const size_t kh, const size_t kw, + const size_t dilationh, const size_t dilationw, + const size_t sh, const size_t sw, + const size_t padding) { + // dimensions of dilated kernel + const auto d_kh = ComputeKernelSize(dilationh, kh); + const auto d_kw = ComputeKernelSize(dilationw, kw); + + const auto m = ComputeConvOutSize(ih, d_kh, padding, sh) * + ComputeConvOutSize(iw, d_kw, padding, sw); + + return m * co; +} + +static bool CheckCapabilitiesSme(const MLAS_CONV_PARAMETERS* Parameters) { + + //functional checks - logically can the conv be performed + if ((Parameters->Dimensions != 2) || + (Parameters->BatchCount != 1) || + (Parameters->Beta != 0.f) || + (Parameters->Padding[0] != Parameters->Padding[1]) || + (Parameters->Padding[0] != Parameters->Padding[2]) || + (Parameters->Padding[0] != Parameters->Padding[3]) || + (ComputeConvOutSize(Parameters->InputShape[0], + ComputeKernelSize(Parameters->DilationShape[0],Parameters->KernelShape[0]), + Parameters->Padding[0], Parameters->StrideShape[0]) * + ComputeConvOutSize(Parameters->InputShape[1], + ComputeKernelSize(Parameters->DilationShape[1],Parameters->KernelShape[1]), + Parameters->Padding[1], Parameters->StrideShape[1]) == 0)) { + return false; + } + + //optimization checks - is the implementation optimal for the conv request + + const auto n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); + const auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); + + auto M = ComputeConvOutSize(Parameters->InputShape[0], 
ComputeKernelSize(Parameters->DilationShape[0], + Parameters->KernelShape[0]), Parameters->Padding[0], Parameters->StrideShape[0]) * + ComputeConvOutSize(Parameters->InputShape[1], ComputeKernelSize(Parameters->DilationShape[1], + Parameters->KernelShape[1]), Parameters->Padding[1], Parameters->StrideShape[1]); + auto N = Parameters->FilterCount; + auto K = Parameters->InputChannels * Parameters->KernelShape[0] * Parameters->KernelShape[1]; + + //Can use these variables to add other conditions as required + MLAS_UNREFERENCED_PARAMETER(M); + MLAS_UNREFERENCED_PARAMETER(K); + MLAS_UNREFERENCED_PARAMETER(m_step); + MLAS_UNREFERENCED_PARAMETER(n_step); + + if (N == 1 || Parameters->KernelShape[0] < 3 || Parameters->KernelShape[1] < 3) { + return false; + } + return true; +} + +//General purpose axis swapping +static auto Transpose4D(std::array shape_in, + const float* in, + std::array permute) { + + std::array shape_out{shape_in[permute[0]], + shape_in[permute[1]], + shape_in[permute[2]], + shape_in[permute[3]]}; + + assert((shape_in[0] * shape_in[1] * shape_in[2] * shape_in[3]) == + (shape_out[0] * shape_out[1] * shape_out[2] * shape_out[3])); + assert(permute[0] < 4 && permute[1] < 4 && permute[2] < 4 && permute[3] < 4); + + const size_t get_stride[] {shape_in[1] * shape_in[2] * shape_in[3], shape_in[2] * shape_in[3], shape_in[3]}; + auto get = [get_stride,in](const std::array& el) { + return in[el[0] * get_stride[0] + + el[1] * get_stride[1] + + el[2] * get_stride[2] + + el[3]]; + }; + + auto out_ = std::make_unique(shape_in[0] * shape_in[1] * shape_in[2] * shape_in[3]); + auto out = out_.get(); + + const size_t set_stride[]{shape_out[1] * shape_out[2] * shape_out[3], shape_out[2] * shape_out[3], shape_out[3]}; + auto set = [set_stride,out](const std::array& el, float v) { + out[el[0] * set_stride[0] + + el[1] * set_stride[1] + + el[2] * set_stride[2] + + el[3]] = v; + }; + + std::array shape; + for (shape[0] = 0; shape[0] < shape_in[0]; ++shape[0]) { + for 
(shape[1] = 0; shape[1] < shape_in[1]; ++shape[1]) { + for (shape[2] = 0; shape[2] < shape_in[2]; ++shape[2]) { + for (shape[3] = 0; shape[3] < shape_in[3]; ++shape[3]) { + set({shape[permute[0]], shape[permute[1]], shape[permute[2]], shape[permute[3]]}, get(shape)); + } + } + } + } + + return out_; +} + +//nchw to nhwc specific axis swapping +static std::unique_ptr NChwToNhwc(const size_t n, + const size_t c, + const size_t h, + const size_t w, + const float* RESTRICT in, + const size_t dilationh=1, + const size_t dilationw=1, + const bool zero_fill=false, + MLAS_THREADPOOL* ThreadPool=nullptr) { + + const auto d_h = ComputeKernelSize(dilationh, h); + const auto d_w = ComputeKernelSize(dilationw, w); + + auto t = std::make_unique(n*d_h*d_w*c); + if (zero_fill) { + std::fill(&t.get()[0], &t.get()[n*d_h*d_w*c], 0.f); + } + + if (dilationh > 1 || dilationw > 1 || n > 1) { + const size_t get_strides[] {c*h*w,h*w,w}; + auto get = [get_strides,in](const std::array& el) { + return in[el[0]*get_strides[0] + + el[1]*get_strides[1] + + el[2]*get_strides[2] + + el[3]]; + }; + + const size_t set_strides[] {d_h*d_w*c,dilationh*d_w*c,dilationw*c}; + auto set = [set_strides](const std::array& el, float v, float* out) { + out[el[0]*set_strides[0] + + el[1]*set_strides[1] + + el[2]*set_strides[2] + + el[3]] = v; + }; + + MLAS_UNREFERENCED_PARAMETER(set); + MLAS_UNREFERENCED_PARAMETER(get); + + auto out0 = t.get(); + for (size_t s0 = n; s0 > 0; --s0) { + auto out1 = out0; + for (size_t s1 = c; s1 > 0; --s1) { + auto out2 = out1; + for (size_t s2 = h; s2 > 0; --s2) { + float* RESTRICT out3 = out2; + size_t s3 = w; + for (; s3 > 4; s3 -= 4) { + auto vf32 = MlasLoadFloat32x4(in); + in += 4; + MlasStoreLaneFloat32x4<0>(out3,vf32); + out3 += set_strides[2]; + MlasStoreLaneFloat32x4<1>(out3,vf32); + out3 += set_strides[2]; + MlasStoreLaneFloat32x4<2>(out3,vf32); + out3 += set_strides[2]; + MlasStoreLaneFloat32x4<3>(out3, vf32); + out3 += set_strides[2]; + } + for (; s3 > 0; --s3) { + 
//set({s0,s2,s3,s1}, get({s0,s1,s2,s3}),t.get()); + *out3 = *in++; + out3 += set_strides[2]; + } + out2 += set_strides[1]; + } + out1++; + } + out0 += set_strides[0]; + } + } else { + MlasTranspose(in, t.get(), c, d_h*d_w, ThreadPool); + } + + return t; +} + +static void MultiThreadedLHSPackSme(MLAS_THREADPOOL* ThreadPool, const size_t ci, const size_t m, const size_t kh, + const size_t kw, const void * const* lhs_ptrs, std::byte* lhs_data, + const float* in_data, + const float* pad_ptr) { + + auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(); + + // Minimize the kernel call count for the number of available threads + auto RequiredTiles = MlasDivRoundup(m, m_step); + auto MaxTiles = std::min(static_cast(MlasGetMaximumThreadCount(ThreadPool)), RequiredTiles); + m_step *= MlasDivRoundup(RequiredTiles, MaxTiles); + RequiredTiles = MlasDivRoundup(m, m_step); + + MlasTrySimpleParallel(ThreadPool, static_cast(RequiredTiles), [&](ptrdiff_t tid) { + + auto m_idx = static_cast(tid) * m_step; + auto offset = kai_get_lhs_packed_offset_lhs_imatmul_pack_x32p2vlx1_x32p_sme(m_idx,kh*kw,ci); + + kai_run_lhs_imatmul_pack_x32p2vlx1_x32p_sme( + m < (m_idx + m_step) ? 
// Packs (and caches) the RHS weights + bias in the layout required by the SME2
// imatmul kernel.  Dilation is materialized here: the filter is expanded to
// its dilated footprint (zeros in the gap positions), then transposed from
// [co][d_kh][d_kw][ci] to [d_kh][d_kw][ci][co] for kxn packing.
static std::shared_ptr<std::byte[]> RhsPackWeightsBiasSme(const size_t co, const size_t ci,
                                                          const size_t kh, const size_t kw,
                                                          const size_t dilationh, const size_t dilationw,
                                                          const float* weights, const float* bias,
                                                          MLAS_THREADPOOL* ThreadPool)
{
    //cache of prepacked kai rhs weights and biases
    // NOTE(review): function-local static cache — unbounded growth and no
    // synchronization; confirm callers never race here, and that a
    // HashWeights() collision (stale packed weights returned) is acceptable.
    static std::unordered_map<RhsCacheKey, std::shared_ptr<std::byte[]>> rhs_cache;

    RhsCacheKey key = { co, ci, kh, kw, dilationh, dilationw, HashWeights(weights) };

    auto found = rhs_cache.find(key);
    if (found != rhs_cache.end()) {
        return found->second;
    } else {
        // prepare mlas filter weights for kai rhs packing
        // dilated nhwc format
        auto nhwc = NChwToNhwc(co, ci, kh, kw, weights, dilationh, dilationw, true, ThreadPool);


        //dilation, axis swap (n x k -> k x n) where n == co, k == d_kh x d_kw x ci
        const auto d_kh = ComputeKernelSize(dilationh,kh);
        const auto d_kw = ComputeKernelSize(dilationw,kw);

        //t_weights[d_kh][d_kw][ci][co] = nhwc[co][d_kh][d_kw][ci]
        auto t_weights = Transpose4D({co,d_kh,d_kw,ci},&nhwc[0],{1,2,3,0});

        const auto packed_size = kai_get_rhs_packed_size_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(co,d_kh*d_kw,ci);
        auto packed = std::shared_ptr<std::byte[]>(new std::byte[packed_size], std::default_delete<std::byte[]>());

        // Cache before filling: the buffer is written below, before any other
        // call can observe it (single-threaded assumption above).
        rhs_cache[key] = packed;

        // The packer always reads co bias values; substitute zeros when the
        // caller supplied none.
        std::vector<float> bias_copy;
        if (bias) {
            bias_copy.assign(bias, bias + co);
        } else {
            bias_copy.resize(co, 0.0f);
        }

        kai_run_rhs_imatmul_pack_kxn_x32p2vlx1b_x32_x32_sme(
            co, d_kh*d_kw, ci, co * sizeof(float), &t_weights[0], bias_copy.data(), packed.get()
        );

        return packed;
    }
}
// Builds the per-output-pixel pointer table consumed by the kai LHS imatmul
// packer.  Entry [k][m] identifies the input pixel that kernel tap k
// contributes to output pixel m; taps that fall outside the image point at
// the shared zero ("pad") row instead.
static std::shared_ptr<const void*[]> LhsPtrFill(const size_t ci, const size_t ih, const size_t iw,
                                                 const size_t kh, const size_t kw, size_t sh, size_t sw,
                                                 const size_t padding,
                                                 const float* pad_ptr) {
    size_t check_filled{0};

    const auto m = ComputeConvOutSize(ih, kh, padding, sh) *
                   ComputeConvOutSize(iw, kw, padding, sw);

    const auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
    const auto lhs_ptrs_k = kh * kw;
    // Round M up to whole m_step blocks; the packer reads full blocks.
    const auto lhs_ptrs_m = m_step * MlasDivRoundup(m, m_step);
    auto lhs_ptrs = std::shared_ptr<const void*[]>(new const void*[lhs_ptrs_k * lhs_ptrs_m],
                                                   std::default_delete<const void*[]>());


    // Stride-1 output extents: the loops below step by sh/sw themselves.
    auto ih_out_size = ComputeConvOutSize(ih, kh, padding, 1);
    auto iw_out_size = ComputeConvOutSize(iw, kw, padding, 1);

    // Table layout: blocked by m_step pixels, transposed to [k][m_step]
    // within each block.
    auto ptrs_offset = [lhs_ptrs_m,lhs_ptrs_k, m_step](size_t k, size_t m) {
        //(m/m_step,transpose(m_step,k)
        auto offset {((m/m_step) * lhs_ptrs_k * m_step) + (k*m_step) + (m%m_step)};
        assert(offset < (lhs_ptrs_k * lhs_ptrs_m));

        MLAS_UNREFERENCED_PARAMETER(lhs_ptrs_m);  // only used by the assert

        return offset;
    };

    // Returns either the absolute address of the zero pad row (out-of-image
    // taps) or the BYTE offset of the pixel within the NHWC image.
    // NOTE(review): mixing absolute addresses and relative offsets in one
    // table relies on the kai packer treating entries equal to pad_ptr
    // specially — confirm against the kai_run_lhs_imatmul_pack contract.
    auto pixel_offset = [ih, iw, ci, pad_ptr, padding](size_t h, size_t w) {
        if (h < padding) {
            return reinterpret_cast<size_t>(&pad_ptr[0]);
        }
        h -= padding;

        if (w < padding) {
            return reinterpret_cast<size_t>(&pad_ptr[0]);
        }
        w -= padding;

        if ((h >= ih) || (w >= iw)) {
            return reinterpret_cast<size_t>(&pad_ptr[0]);
        }

        auto offset{h * iw * ci + w * ci};
        assert(offset < (ih*iw*ci));
        return offset*sizeof(float);
    };

    size_t m_{0};
    auto lhs_ptrs_ = lhs_ptrs.get();
    for (size_t ih_ = 0; ih_ < ih_out_size; ih_ += sh) {
        for (size_t iw_ = 0; iw_ < iw_out_size; iw_ += sw, ++m_) {
            size_t k_{0};
            for (size_t kh_ = 0; kh_ < kh; ++kh_) {
                for (size_t kw_ = 0; kw_ < kw; ++kw_) {
                    lhs_ptrs_[ptrs_offset(k_, m_)] = reinterpret_cast<const void*>(pixel_offset(ih_+kh_, iw_+kw_));
                    k_++; check_filled++;
                }
            }
        }
    }

    // Every (tap, pixel) pair must have been written exactly once.
    assert(check_filled == (lhs_ptrs_k * m));
    MLAS_UNREFERENCED_PARAMETER(check_filled);

    return lhs_ptrs;
}
// Converts one NCHW image to NHWC and packs it into the SME2 imatmul LHS
// layout, reusing a cached pointer table when the same geometry/input hash
// has been seen before.
static std::unique_ptr<std::byte[]> LhsPackImageDataSme(const size_t ci, const size_t ih, const size_t iw,
                                                        const size_t kh, const size_t kw, const size_t sh,
                                                        const size_t sw, const size_t padding, const float* in,
                                                        MLAS_THREADPOOL* ThreadPool)
{
    // Shared zero row used for padding pixels; must hold at least ci floats.
    // NOTE(review): the static below is sized from the FIRST call's ci only; a
    // later call with a larger channel count would read past the end of the
    // zero row (and simply resizing would dangle pad pointers already cached
    // in lhs_ptrs_cache) — confirm ci is bounded, or rework the pad buffer.
    size_t padsize = 256;
    if(ci > padsize)
    {
        // figure out how many blocks needed to correctly fill padding
        padsize = ((ci + padsize - 1) / padsize) * padsize;
    }
    static std::vector<float> pad_ptr(padsize, 0.f);

    LhsCacheKey key = {
        ci, ih, iw,
        padding, sh, sw,
        kh, kw,
        1, 1,
        HashWeights(in)
    };

    //create lhs in format required for imatmul
    const auto m = ComputeConvOutSize(ih, kh, padding, sh) * ComputeConvOutSize(iw, kw, padding, sw);

    const auto lhs_size = kai_get_lhs_packed_size_lhs_imatmul_pack_x32p2vlx1_x32p_sme(m,kh*kw,ci);
    auto lhs = std::make_unique<std::byte[]>(lhs_size);

    // Channel-last copy of the image, as the packer expects.
    auto nhwc = NChwToNhwc(1, ci, ih, iw, in, 1, 1, false, ThreadPool);

    //cache of computed lhs ptr offsets
    // NOTE(review): unsynchronized static cache, keyed partly by
    // HashWeights(in) — confirm single-threaded use and collision tolerance.
    static std::unordered_map<LhsCacheKey, std::shared_ptr<const void*[]>> lhs_ptrs_cache;

    std::shared_ptr<const void*[]> lhs_ptrs;
    if (auto found = lhs_ptrs_cache.find(key); found != lhs_ptrs_cache.end()) {
        lhs_ptrs = found->second;
    } else {
        lhs_ptrs = LhsPtrFill(ci, ih, iw, kh, kw, sh, sw, padding, &pad_ptr[0]);
        lhs_ptrs_cache[key] = lhs_ptrs;
    }

    MultiThreadedLHSPackSme(ThreadPool, ci, m, kh, kw, &lhs_ptrs[0], &lhs[0], &nhwc[0], &pad_ptr[0]);

    return lhs;
}
// Performs a complete (grouped) 2-D convolution with the SME2 imatmul
// kernels:
//   1. pack the image (LHS) and the dilated weights + bias (RHS),
//   2. run the tiled indirect matmul across the thread pool,
//   3. transpose the [m x co] result back to MLAS's channel-major layout.
static void ConvolveSme(const size_t co,         //channels out
                        const size_t ci,         //channels in
                        const size_t ih,         //image height
                        const size_t iw,         //image width
                        const size_t kh,         //kernel height
                        const size_t kw,         //kernel width
                        const size_t sh,         //kernel stride height
                        const size_t sw,         //kernel stride width
                        const size_t dilationh,  //kernel dilation stride
                        const size_t dilationw,  //kernel dilation stride
                        const size_t padding,    //padding size
                        const size_t groups,     //number of filter groups
                        const float* weights,    //kernel weights [co,ci,ih,iw]
                        const float* bias,       //kernel biases
                        const float* in,         //in image data
                        float* out,              //out image data
                        float* tmp_mlas_aligned, //intermediate buffer if we need to perform a transpose
                        MLAS_THREADPOOL* ThreadPool) {

    //RhsPackWeightsBiasSme() - to perform dilation increases kernel size and masks unused weights
    //compute corrected dimensions of dilated kernel
    const auto d_kh = ComputeKernelSize(dilationh, kh);
    const auto d_kw = ComputeKernelSize(dilationw, kw);

    //run igemm based convolution
    const auto m = ComputeConvOutSize(ih, d_kh, padding, sh) *
                   ComputeConvOutSize(iw, d_kw, padding, sw);

    auto n_step = kai_get_n_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();
    auto m_step = kai_get_m_step_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa();

    //tile iteration dimensions
    std::array<size_t, 3> dim;
    dim[0] = 1;                           // B
    dim[1] = MlasDivRoundup(m, m_step);   // M
    dim[2] = MlasDivRoundup(co, n_step);  // N

    //Minimize the kernel call count for the number of available threads
    auto RequiredTiles = std::min(static_cast<size_t>(MlasGetMaximumThreadCount(ThreadPool)), dim[0]*dim[1]*dim[2]);

    //scale required tiles over available tile processors
    // NOTE(review): dim[2]'s rescale reads the just-updated dim[1] — mirrors
    // the sgemm path; confirm this ordering is intended.
    dim[1] = MlasDivRoundup(RequiredTiles * dim[1], dim[1] * dim[2]);
    dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]);

    //compute new step sizes
    m_step *= MlasDivRoundup(MlasDivRoundup(m, dim[1]), m_step);
    n_step *= MlasDivRoundup(MlasDivRoundup(co, dim[2]), n_step);

    //update tile iterations
    dim[1] = MlasDivRoundup(m, m_step);
    dim[2] = MlasDivRoundup(co, n_step);

    for (size_t g = 0; g < groups; ++g) {

        auto result{out};
        //do we require a post matmul transpose ?
        //output is m x n or image_data x co or hw x co
        //MLAS require it as n x m (or co x hw), transpose required
        if (co > 1) {
            //intermediate buffer required, pre-transpose
            //Note: because we are calling MlasTranspose() need to ensure we use a MLAS aligned buffer
            result = tmp_mlas_aligned;
        }

        auto lhs = LhsPackImageDataSme(ci, ih, iw, d_kh, d_kw, sh, sw, padding, in, ThreadPool);
        auto rhs = RhsPackWeightsBiasSme(co, ci, kh, kw, dilationh, dilationw, weights, bias, ThreadPool);


        MlasTrySimpleParallel(ThreadPool,
                              static_cast<ptrdiff_t>(dim[0]*dim[1]*dim[2]),
                              [&](ptrdiff_t tid)
        {
            //compute B,M,N index from iteration index
            //ptrdiff_t BIdx = tid / (dim[1] * dim[2]);
            ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2];
            ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];

            // Get rhs tile, B
            const size_t rhs_packed_offset =
                kai_get_rhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(NIdx*n_step,
                                                                                                   d_kh*d_kw,ci);

            auto BTile = reinterpret_cast<const void*>(
                reinterpret_cast<const std::byte*>(rhs.get()) + rhs_packed_offset
            );

            // Get lhs tile, A
            const size_t lhs_packed_offset =
                kai_get_lhs_packed_offset_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(MIdx*m_step,
                                                                                                   d_kh*d_kw,ci);

            auto ATile = reinterpret_cast<const void*>(
                reinterpret_cast<const std::byte*>(lhs.get()) + lhs_packed_offset
            );

            // Clamp partial tiles at the M/N edges.
            auto TileSizeM = (MIdx + 1) * m_step > m ? (m - MIdx * m_step) : m_step;
            auto TileSizeN = (NIdx + 1) * n_step > co ? (co - NIdx * n_step) : n_step;

            // Get result tile, C
            auto CTile = &reinterpret_cast<std::byte*>(result)[
                MIdx * m_step * co * sizeof(float) +
                NIdx * n_step * sizeof(float)];

            // Clamp bounds are +-FLT_MAX, i.e. effectively no activation here;
            // activation is applied by the caller (MlasConv).
            kai_run_imatmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme2_mopa(
                TileSizeM, TileSizeN, d_kh*d_kw, ci, ATile, BTile, CTile, co * sizeof(float),
                -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
            );
        });

        if (result == tmp_mlas_aligned) {
            //Note: this could be absorbed into post conv activation
            MlasTranspose(tmp_mlas_aligned, out, m, co, ThreadPool);
        }

        // Advance to the next group's slice of inputs, outputs and weights.
        in += ci * ih * iw;
        out += m * co;
        weights += co * ci * kh * kw;
        if(bias){
            bias += co;
        }
    }
}
// Validates and fills a MLAS_CONV_PARAMETERS structure for the KleidiAI SME
// convolution path and reports the scratch-buffer size the caller must
// allocate.  Returns false when this path cannot handle the configuration
// (fewer than 2 dimensions, or CheckCapabilitiesSme() rejects it) so the
// caller can fall back to the default MLAS implementation.
bool MLASCALL
ArmKleidiAI::MlasConvPrepare(MLAS_CONV_PARAMETERS* Parameters,
                             size_t Dimensions,
                             size_t BatchCount,
                             size_t GroupCount,
                             size_t InputChannels,
                             const int64_t* InputShape,
                             const int64_t* KernelShape,
                             const int64_t* DilationShape,
                             const int64_t* Padding,
                             const int64_t* StrideShape,
                             const int64_t* OutputShape,
                             size_t FilterCount,
                             const MLAS_ACTIVATION* Activation,
                             size_t* WorkingBufferSize,
                             float Beta,
                             MLAS_THREADPOOL* ThreadPool)
{
    //Check dimensions before accessing
    if (Dimensions < 2) {
        return false;
    }

    Parameters->Activation = Activation;
    Parameters->Dimensions = Dimensions;
    Parameters->BatchCount = BatchCount;
    Parameters->GroupCount = GroupCount;
    Parameters->InputChannels = InputChannels;
    Parameters->FilterCount = FilterCount;
    Parameters->Beta = Beta;

    size_t InputSize = 1;
    size_t OutputSize = 1;
    size_t K = InputChannels;

    // Copy per-dimension geometry, accumulating the flattened sizes.
    for (size_t dim = 0; dim < Dimensions; dim++) {

        Parameters->InputShape[dim] = size_t(InputShape[dim]);
        Parameters->OutputShape[dim] = size_t(OutputShape[dim]);
        Parameters->KernelShape[dim] = size_t(KernelShape[dim]);
        Parameters->DilationShape[dim] = size_t(DilationShape[dim]);
        // Padding holds leading then trailing pads per dimension.
        Parameters->Padding[dim] = size_t(Padding[dim]);
        Parameters->Padding[dim + Dimensions] = size_t(Padding[dim + Dimensions]);
        Parameters->StrideShape[dim] = size_t(StrideShape[dim]);

        InputSize *= Parameters->InputShape[dim];
        OutputSize *= Parameters->OutputShape[dim];
        K *= Parameters->KernelShape[dim];
    }

    Parameters->InputSize = InputSize;
    Parameters->OutputSize = OutputSize;
    Parameters->K = K;

    Parameters->ThreadCount = MlasGetMaximumThreadCount(ThreadPool);

    if(!CheckCapabilitiesSme(Parameters)){
        return false;
    }

    //Allocate an aligned buffer for MlasTranspose()
    // NOTE(review): only Padding[0] is forwarded here — confirm symmetric
    // padding is guaranteed by CheckCapabilitiesSme().
    *WorkingBufferSize = ComputeMlasWorkingBufferSize(Parameters->FilterCount,
                                                      Parameters->InputShape[0], Parameters->InputShape[1],
                                                      Parameters->KernelShape[0], Parameters->KernelShape[1],
                                                      Parameters->DilationShape[0], Parameters->DilationShape[1],
                                                      Parameters->StrideShape[0], Parameters->StrideShape[1],
                                                      Parameters->Padding[0]);
    return true;
}

// Executes the convolution prepared above via ConvolveSme() and applies the
// requested activation over the [FilterCount x OutputSize] result.  Returns
// false (no work done) when the SME path cannot handle the configuration.
bool
MLASCALL
ArmKleidiAI::MlasConv(
    const MLAS_CONV_PARAMETERS* Parameters,
    const float* Input,
    const float* Filter,
    const float* Bias,
    float* WorkingBuffer,
    float* Output,
    MLAS_THREADPOOL* ThreadPool
    )
{
    if(!CheckCapabilitiesSme(Parameters)){
        //Fallback to Default Mlas
        return false;
    };
    ConvolveSme(Parameters->FilterCount, Parameters->InputChannels,        // channel out, in
                Parameters->InputShape[0], Parameters->InputShape[1],      // image dimensions
                Parameters->KernelShape[0], Parameters->KernelShape[1],    // kernel dimensions
                Parameters->StrideShape[0], Parameters->StrideShape[1],    // kernel stride dimensions
                Parameters->DilationShape[0], Parameters->DilationShape[1],// kernel dilation
                Parameters->Padding[0],                                    // image padding
                Parameters->GroupCount,                                    // filter groups
                Filter, Bias, Input, Output, WorkingBuffer, ThreadPool);

    MlasActivation(Parameters->Activation, Output, nullptr, Parameters->FilterCount, Parameters->OutputSize,
                   Parameters->OutputSize);
    return true;
}
//
// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates
//
// SPDX-License-Identifier: MIT
//

#pragma once

#include "mlasi.h"

// Fix to ensure compatibility with MSVC build
#if defined(_MSC_VER)
    #define RESTRICT __restrict
#else
    #define RESTRICT __restrict__
#endif

// KleidiAI-backed overrides of the MLAS entry points.  Each function returns
// a size or a bool "handled" flag; a false return means the caller must fall
// back to the default MLAS implementation.
namespace ArmKleidiAI {
// By default we should try for SME2 first before falling back to SME.
inline const bool UseSME2 = MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME2();

//
// Buffer packing routines.
//

// Size in bytes of the packed B buffer for the SME sgemm kernels.
size_t
MLASCALL
MlasGemmPackBSize(
    CBLAS_TRANSPOSE TransA,
    CBLAS_TRANSPOSE TransB,
    size_t N,
    size_t K
    );

// Packs B for the SME sgemm kernels; returns false if unsupported.
bool
MLASCALL
MlasGemmPackB(
    CBLAS_TRANSPOSE TransA,
    CBLAS_TRANSPOSE TransB,
    size_t N,
    size_t K,
    const float* B,
    size_t ldb,
    void* PackedB
    );

// Batched float GEMM; returns false to request the default MLAS path.
bool
MLASCALL
MlasGemmBatch(
    CBLAS_TRANSPOSE TransA,
    CBLAS_TRANSPOSE TransB,
    size_t M,
    size_t N,
    size_t K,
    const MLAS_SGEMM_DATA_PARAMS* Data,
    size_t BatchSize,
    MLAS_THREADPOOL* ThreadPool
    );

// Size in bytes of the packed B buffer for the dynamic-quantized GEMM.
size_t
MLASCALL
MlasDynamicQgemmPackBSize(
    size_t N,
    size_t K
);

// Packs symmetric-quantized B (with scales and optional bias) for the
// dynamic-quantized GEMM.
void
MLASCALL
MlasDynamicQgemmPackB(
    size_t N,
    size_t K,
    const int8_t* B,
    const float* Scales,
    const float* Bias,
    void* PackedB
);

//pack symmetric quantized B and dynamic quantized A
void
MLASCALL
MlasDynamicQGemmBatch(
    const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape,
    const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams,
    const size_t BatchN,
    MLAS_THREADPOOL* ThreadPool
    );

// Fills MLAS_CONV_PARAMETERS for the SME convolution path; returns false
// when the configuration is unsupported.
bool
MLASCALL
MlasConvPrepare(MLAS_CONV_PARAMETERS* Parameters,
                size_t Dimensions,
                size_t BatchCount,
                size_t GroupCount,
                size_t InputChannels,
                const int64_t* InputShape,
                const int64_t* KernelShape,
                const int64_t* DilationShape,
                const int64_t* Padding,
                const int64_t* StrideShape,
                const int64_t* OutputShape,
                size_t FilterCount,
                const MLAS_ACTIVATION* Activation,
                size_t* WorkingBufferSize,
                float Beta,
                MLAS_THREADPOOL* ThreadPool);

// Runs the SME convolution; returns false to request the default MLAS path.
bool
MLASCALL
MlasConv(
    const MLAS_CONV_PARAMETERS* Parameters,
    const float* Input,
    const float* Filter,
    const float* Bias,
    float* WorkingBuffer,
    float* Output,
    MLAS_THREADPOOL* ThreadPool
    );
}
MLAS_THREADPOOL* ThreadPool); + +bool +MLASCALL +MlasConv( + const MLAS_CONV_PARAMETERS* Parameters, + const float* Input, + const float* Filter, + const float* Bias, + float* WorkingBuffer, + float* Output, + MLAS_THREADPOOL* ThreadPool + ); +} + +/*++ + +Routine Description: + + This routine determines if a wraparound will occur when multiplying two size_t variables + Uses __builtin_mul_overflow if available on the current system and if not falls back + to a default implementation to check this wraparound. + +Arguments: + + a - Supplies the first number to be muliplied. + + b - Supplies the second number to be muliplied. + + out - pointer to a size_t which acts as the return value in success cases. + +Return Value: + + Returns false if the operation was successful + Returns true if wraparound of size_t was detected + +--*/ +inline bool mul_overflow_size_t_builtin(size_t a, size_t b, size_t* out) { +#if defined(__has_builtin) +# if __has_builtin(__builtin_mul_overflow) + return __builtin_mul_overflow(a, b, out); +# endif +#endif + // Fallback to manual check if builtin not available + if (b != 0 && a > SIZE_MAX / b) return true; + if (out) *out = a * b; + return false; +} diff --git a/src/lib/kleidiai/qgemm_kleidiai.cpp b/src/lib/kleidiai/qgemm_kleidiai.cpp new file mode 100644 index 0000000..fb38f2c --- /dev/null +++ b/src/lib/kleidiai/qgemm_kleidiai.cpp @@ -0,0 +1,116 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +#include + +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h" +#include "kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h" +#include "kai/ukernels/matmul/pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon.h" + +#include "kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h" +#include 
"kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h" + +#include "mlasi_kleidiai.h" + +//Matmul with float output of dynamic quantized A and symmetric quantized B. + +size_t +MLASCALL +ArmKleidiAI::MlasDynamicQgemmPackBSize( + size_t N, + size_t K +) { + //Default to sme2_mopa but this may not awalys be the most optimal kernel variant to use + auto nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + + //regardless of kernel variant use neon packing variant + return kai_get_rhs_packed_size_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(N, K, nr, kr, sr); +} + +void +MLASCALL +ArmKleidiAI::MlasDynamicQgemmPackB( + size_t N, + size_t K, + const int8_t* B, + const float* Scales, + const float* Bias, + void* PackedB +) { + // Default to sme2_mopa but this may not awalys be the most optimal kernel variant to use + auto nr = kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + + // y - float output + // scale_factor_lhs - lhs scaling factor + // scale_factor_rhs - rhs scaling factor + // lhs_q - lhs quantized (asymmetric, so has zero point) + // rhs_q - rhs quantized (symmetric so no zero point) + // lhs_zp - lhs zero point + // y = (1/(scale_factor_lhs * scale_factor_rhs) * sum( (lhs_q + lhs_zp)*rhs_q )) + bias + + // rhs packing requires lhs_zp because it will perform lhs_zp*rhs_q during rhs packing + // because lhs quantization is hidden from us, by lhs quant packing, we don't have a value for lhs_zp it is + // lhs dynamic quantization + + kai_rhs_pack_qsi8cx_params params{ + 1, // lhs_zp - set to 1 so it becomes sum((lhs_q + 
1)*rhs_q )), + // the actual lhs_zp is applied during the matmul + 1.f // it is not used + }; + + //regardless of kernel variant use neon packing variant + kai_run_rhs_pack_kxn_qsi8cxp_qsi8cx_neon(1, N, K, nr, kr, sr, B, + // N bias values + Bias, + // N scale values + Scales, PackedB, 0, ¶ms); +} + +void +MLASCALL +ArmKleidiAI::MlasDynamicQGemmBatch( + const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, + const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, + const size_t BatchN, + MLAS_THREADPOOL* ThreadPool +) { + for (auto b = BatchN; b > 0; --b,++DataParams) { + auto mr = kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + auto kr = kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + auto sr = kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa(); + + + //TODO enable multi-threading for lhs packing and matmul + MLAS_UNREFERENCED_PARAMETER(ThreadPool); + + //Dynamic Quantize A - lhs + auto lhs_size = kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr); + std::byte* lhs = nullptr; + std::unique_ptr fallback; + + if (DataParams->Workspace && DataParams->WorkspaceSize >= lhs_size) { + lhs = static_cast(DataParams->Workspace); + } else { + fallback = std::make_unique(lhs_size); + lhs = fallback.get(); + } + + kai_run_lhs_quant_pack_qai8dxp_f32(Shape.M, Shape.K, mr, kr, sr, 0, DataParams->A, + Shape.K*sizeof(float), lhs); + + kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa( + Shape.M, Shape.N, Shape.K, lhs, DataParams->PackedB, + DataParams->C, + Shape.N * sizeof(float), + sizeof(float), + -std::numeric_limits::max(), std::numeric_limits::max() + ); + } +} diff --git a/src/lib/kleidiai/sgemm_kleidiai.cpp b/src/lib/kleidiai/sgemm_kleidiai.cpp new file mode 100644 index 0000000..435ff1f --- /dev/null +++ b/src/lib/kleidiai/sgemm_kleidiai.cpp @@ -0,0 +1,432 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// 
// Thread-local reusable buffers to reduce allocation overhead across tiles.
// Each thread keeps its own scratch vectors and their capacity is retained
// between calls, so steady-state GEMMs avoid repeated heap allocation.
// NOTE(review): capacity is never released; long-lived worker threads retain
// the peak working-set size — confirm this trade-off is acceptable.
struct KaiTlsBuffers {
    std::vector<float> output_tile;     // raw A*B tile before alpha/beta blending
    std::vector<float> bias_zero;       // zero bias row handed to the rhs packers
    std::vector<std::byte> rhs_packed;  // packed RHS panels for a whole batch
    std::vector<std::byte> lhs_packed;  // packed LHS panels for a whole batch
};
static thread_local KaiTlsBuffers g_kai_tls;
+ // + size_t bytes = 0; + if (TransA == CblasNoTrans) { + switch (TransB) { + case CblasNoTrans: + bytes = kai_get_rhs_packed_size_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(N, K); + break; + case CblasTrans: + bytes = kai_get_rhs_packed_size_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(N, K); + break; + default: + return 0; + } + } else { + return 0; + } + + return bytes; +} + +bool +MLASCALL +ArmKleidiAI::MlasGemmPackB( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB +) +/*++ + +Routine Description: + + This routine packs the contents of matrix B to the destination buffer. The + destination buffer should be sized based on MlasGemmPackBSize(). For best + performance, the destination buffer should be aligned to the value returned + from MlasGetPreferredBufferAlignment(). + +Arguments: + + TransA - Supplies the transpose operation for matrix A. + + TransB - Supplies the transpose operation for matrix B. + + N - Supplies the number of columns of matrix B. + + K - Supplies the number of rows of matrix B. + + B - Supplies the address of matrix B. + + ldb - Supplies the first dimension of matrix B. + + PackedB - Supplies the address of packed matrix B. + +Return Value: + + Returns true if the packing operation was handled by KleidiAI. + Returns false if the configuration requires a fallback to the default MLAS implementation. + +--*/ +{ + if (N == 0 || K == 0) { + return false; + } + + if (TransA == CblasNoTrans) { + const size_t nr = UseSME2 ? kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() + : kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(); + const size_t kr = UseSME2 ? kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa() + : kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(); + const size_t sr = UseSME2 ? 
bool
MLASCALL
ArmKleidiAI::MlasGemmPackB(
    CBLAS_TRANSPOSE TransA,
    CBLAS_TRANSPOSE TransB,
    size_t N,
    size_t K,
    const float* B,
    size_t ldb,
    void* PackedB
)
/*++

Routine Description:

    This routine packs the contents of matrix B to the destination buffer. The
    destination buffer should be sized based on MlasGemmPackBSize(). For best
    performance, the destination buffer should be aligned to the value returned
    from MlasGetPreferredBufferAlignment().

Arguments:

    TransA - Supplies the transpose operation for matrix A.

    TransB - Supplies the transpose operation for matrix B.

    N - Supplies the number of columns of matrix B.

    K - Supplies the number of rows of matrix B.

    B - Supplies the address of matrix B.

    ldb - Supplies the first dimension of matrix B.

    PackedB - Supplies the address of packed matrix B.

Return Value:

    Returns true if the packing operation was handled by KleidiAI.
    Returns false if the configuration requires a fallback to the default MLAS implementation.

--*/
{
    if (N == 0 || K == 0) {
        return false;
    }

    if (TransA == CblasNoTrans) {
        // The packing geometry must match the matmul variant that will be run
        // later, so query the SME2 kernel when available, the SME one otherwise.
        const size_t nr = UseSME2 ? kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
                                  : kai_get_nr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
        const size_t kr = UseSME2 ? kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
                                  : kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
        const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
                                  : kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();

        // Ensure size and zero the used span.  (resize() only zero-fills NEW
        // elements, but this thread-local buffer is only ever written with
        // zeros, so the first N elements are always zero.)
        g_kai_tls.bias_zero.resize(N, 0.0f);

        // The kai packers fold a bias row into the packed buffer; GEMM has no
        // bias, so a zero row is supplied.
        switch (TransB) {
            case CblasNoTrans:
                kai_run_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B, g_kai_tls.bias_zero.data(), nullptr, PackedB, 0, nullptr);
                break;
            case CblasTrans:
                kai_run_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme(1, N, K, nr, kr, sr, ldb * sizeof(float), B, g_kai_tls.bias_zero.data(), nullptr, PackedB, 0, nullptr);
                break;
            default:
                return false;
        }
        return true;
    }
    else{
        return false;
    }
}
bool
MLASCALL
ArmKleidiAI::MlasGemmBatch(
    CBLAS_TRANSPOSE TransA,
    CBLAS_TRANSPOSE TransB,
    size_t M,
    size_t N,
    size_t K,
    const MLAS_SGEMM_DATA_PARAMS* Data,
    size_t BatchSize,
    MLAS_THREADPOOL* ThreadPool
)
/*++

Routine Description:

    This routine performs a batched matrix multiplication (GEMM) operation using KleidiAI kernels.
    It handles both packed and unpacked inputs and manages tiling and kernel selection depending on
    SME2 availability. If packing is needed, it prepares the required buffers and invokes the
    appropriate left-hand side (LHS) and right-hand side (RHS) pack functions.

    The function also applies alpha and beta scaling to the result, supports efficient memcpy
    paths where possible, and dispatches tile-level GEMM work using multithreading.

Arguments:

    TransA - Supplies the transpose operation for matrix A.

    TransB - Supplies the transpose operation for matrix B.

    M - Supplies the number of rows of matrix A and matrix C.

    N - Supplies the number of columns of matrix B and matrix C.

    K - Supplies the number of columns of matrix A and rows of matrix B.

    Data - Supplies a pointer to the MLAS_SGEMM_DATA_PARAMS array containing per-batch input/output pointers and parameters.

    BatchSize - Supplies the number of independent GEMM computations to perform in the batch.

    ThreadPool - Supplies the thread pool to parallelize computation across batches and tiles.

Return Value:

    Returns true if the GEMM operation was handled by KleidiAI.
    Returns false if the configuration requires a fallback to the default MLAS implementation.

--*/
{
    if (M == 0 || N == 0) {
        return true;
    }

    // Degenerate GEMM (alpha == 0 or K == 0) reduces to C = beta * C.
    // NOTE(review): only Data[0]'s alpha/beta/C are consulted here, while the
    // main path honours per-batch values — confirm all batch entries share
    // them (and that a single-entry scale is intended for this early-out).
    if (Data->alpha == 0.0f || K == 0) {
        if (Data->beta == 0.0f) {
            for (size_t i = 0; i < M; ++i) {
                std::fill_n(Data->C + i * Data->ldc, N, 0.0f);
            }
        } else if (Data->beta != 1.0f) {
            for (size_t i = 0; i < M; ++i) {
                for (size_t j = 0; j < N; ++j) {
                    Data->C[i * Data->ldc + j] *= Data->beta;
                }
            }
        }
        return true;
    }

    // Packing geometry and tile steps of the selected kernel variant.
    const size_t mr = UseSME2 ? kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
                              : kai_get_mr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
    const size_t kr = UseSME2 ? kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
                              : kai_get_kr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
    const size_t sr = UseSME2 ? kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
                              : kai_get_sr_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();

    size_t m_step = UseSME2 ? kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
                            : kai_get_m_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();
    size_t n_step = UseSME2 ? kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa()
                            : kai_get_n_step_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa();

    if ((M < m_step || N < n_step) && !Data->BIsPacked) {
        // Fallback to MLAS
        return false;
    }

    size_t LhsPackedStride = 0;
    std::byte* LhsPackedData = nullptr;

    LhsPackedStride = kai_get_lhs_packed_size_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr);

    size_t lhs_resize = 0;
    if(mul_overflow_size_t_builtin(LhsPackedStride, BatchSize, &lhs_resize))
    {
        // size_t wraparound detected for LhsPackedStride, fallback to MLAS
        return false;
    }

    // The TLS vectors are resized on this (calling) thread before the
    // parallel region; workers only write through the raw pointer into
    // disjoint per-batch strides.
    g_kai_tls.lhs_packed.resize(lhs_resize);
    LhsPackedData = g_kai_tls.lhs_packed.data();

    // RHS packed buffer: use TLS reusable vector to minimize allocations
    size_t RhsPackedStride = 0;
    std::byte* RhsPackedData = nullptr;

    // It is assumed all B batches require packing or not
    if (Data[0].BIsPacked) {
        // We have already decided the matmul variant we are using, before having values for M,N,K
        MlasTrySimpleParallel(ThreadPool, BatchSize, [&](ptrdiff_t batch_idx) {
            std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]);
            kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A, Data[batch_idx].lda * sizeof(float), LhsPackedPtr);
        });
    } else {
        // Multithread pack lhs and rhs
        RhsPackedStride = ArmKleidiAI::MlasGemmPackBSize(TransA, TransB, N, K);
        size_t rhs_resize = 0;
        if (mul_overflow_size_t_builtin(RhsPackedStride, BatchSize, &rhs_resize))
        {
            // size_t wraparound detected for RhsPackedStride, fallback to MLAS
            return false;
        }

        g_kai_tls.rhs_packed.resize(rhs_resize);
        RhsPackedData = g_kai_tls.rhs_packed.data();

        // Even task ids pack RHS, odd ids pack LHS for the same batch entry.
        MlasTrySimpleParallel(ThreadPool, BatchSize * 2, [&](ptrdiff_t batch_idx) {
            if (batch_idx & 0x1) {
                batch_idx >>= 1;
                std::byte* LhsPackedPtr = &(LhsPackedData[LhsPackedStride * batch_idx]);
                kai_run_lhs_pack_f32p2vlx1_f32_sme(M, K, mr, kr, sr, 0, Data[batch_idx].A,
                                                   Data[batch_idx].lda * sizeof(float), LhsPackedPtr);
            } else {
                batch_idx >>= 1;
                std::byte* RhsPackedPtr = &(RhsPackedData[RhsPackedStride * batch_idx]);
                ArmKleidiAI::MlasGemmPackB(TransA, TransB, N, K,
                                           reinterpret_cast<const float*>(Data[batch_idx].B),
                                           Data[batch_idx].ldb, RhsPackedPtr);
            }
        });
    }

    // tile iteration dimensions
    std::array<size_t, 3> dim;
    dim[0] = BatchSize;                  // B
    dim[1] = MlasDivRoundup(M, m_step);  // M
    dim[2] = MlasDivRoundup(N, n_step);  // N

    // Minimize the kernel call count for the number of available threads
    auto RequiredTiles = std::min(static_cast<size_t>(MlasGetMaximumThreadCount(ThreadPool)), dim[0] * dim[1] * dim[2]);

    // scale required tiles over available tile processors
    // NOTE(review): dim[2]'s rescale reads the just-updated dim[1] — confirm
    // this ordering is intended (the conv path does the same).
    dim[1] = MlasDivRoundup(RequiredTiles * dim[1], dim[1] * dim[2]);
    dim[2] = MlasDivRoundup(RequiredTiles * dim[2], dim[1] * dim[2]);

    // compute new step sizes
    m_step *= MlasDivRoundup(MlasDivRoundup(M, dim[1]), m_step);
    n_step *= MlasDivRoundup(MlasDivRoundup(N, dim[2]), n_step);

    // update tile iterations
    dim[1] = MlasDivRoundup(M, m_step);
    dim[2] = MlasDivRoundup(N, n_step);

    // Pre-check maximum tile size to avoid per-iteration overflow inside the parallel loop.
    // Any TileSizeM/TileSizeN used below will be <= m_step/n_step respectively.
    size_t max_tile_elems = 0;
    if (mul_overflow_size_t_builtin(m_step, n_step, &max_tile_elems)) {
        // size_t wraparound detected for tile size, fallback to MLAS
        return false;
    }

    // Capture by value: each worker gets its own copy of the pointers/steps;
    // g_kai_tls inside the lambda refers to the WORKER thread's TLS scratch.
    MlasTrySimpleParallel(ThreadPool, static_cast<ptrdiff_t>(dim[0] * dim[1] * dim[2]), [=](ptrdiff_t tid) {
        // compute B,M,N index from iteration index
        ptrdiff_t BIdx = tid / (dim[1] * dim[2]);
        ptrdiff_t MIdx = (tid % (dim[1] * dim[2])) / dim[2];
        ptrdiff_t NIdx = (tid % (dim[1] * dim[2])) % dim[2];

        // Get rhs tile, B
        const size_t rhs_packed_offset =
            UseSME2 ? kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(NIdx * n_step, K)
                    : kai_get_rhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(NIdx * n_step, K);

        const std::byte* B_base = Data[0].BIsPacked
                                      ? reinterpret_cast<const std::byte*>(Data[BIdx].B)
                                      : (RhsPackedData + RhsPackedStride * BIdx);
        auto BTile = reinterpret_cast<const void*>(B_base + rhs_packed_offset);

        // Get lhs tile, A
        const size_t lhs_packed_offset =
            UseSME2 ? kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(MIdx * m_step, K)
                    : kai_get_lhs_packed_offset_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(MIdx * m_step, K);

        const std::byte* A_base = LhsPackedData + LhsPackedStride * BIdx;
        auto ATile = reinterpret_cast<const void*>(A_base + lhs_packed_offset);

        // Clamp partial tiles at the M/N edges.
        auto TileSizeM = (MIdx + 1) * m_step > M ? (M - MIdx * m_step) : m_step;
        auto TileSizeN = (NIdx + 1) * n_step > N ? (N - NIdx * n_step) : n_step;

        // Get result tile, C
        auto CTile = reinterpret_cast<void*>(
            reinterpret_cast<std::byte*>(Data[BIdx].C) +
            MIdx * m_step * Data[BIdx].ldc * sizeof(float) +
            NIdx * n_step * sizeof(float)
        );
        // Allocate temporary buffer for raw A*B result (TLS reusable buffer)
        size_t tile_elems = TileSizeM * TileSizeN;

        // resize the tile to the required size
        g_kai_tls.output_tile.resize(tile_elems);

        float* temp_tile = g_kai_tls.output_tile.data();
        std::fill_n(temp_tile, tile_elems, 0.0f);

        // Clamp bounds are +-FLT_MAX, i.e. effectively no clamp.
        if (UseSME2) {
            kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa(
                TileSizeM,
                TileSizeN,
                K,
                ATile, BTile, temp_tile,
                TileSizeN * sizeof(float), sizeof(float),
                -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
            );
        } else {
            kai_run_matmul_clamp_f32_f32p2vlx1_f32p2vlx1b_2vlx2vl_sme_mopa(
                TileSizeM,
                TileSizeN,
                K,
                ATile, BTile, temp_tile,
                TileSizeN * sizeof(float), sizeof(float),
                -std::numeric_limits<float>::max(), std::numeric_limits<float>::max()
            );
        }

        // Final output tile pointer
        float* dst_tile = reinterpret_cast<float*>(CTile);

        // quick copy of data in cases where we are not scaling or accumulating anything
        // with bounds checking on tile sizing to ensure the data fits in the memory block
        bool can_memcpy = (
            Data[BIdx].alpha == 1.0f &&
            Data[BIdx].beta == 0.0f &&
            Data[BIdx].ldc == TileSizeN &&
            MIdx * m_step + TileSizeM <= M &&
            NIdx * n_step + TileSizeN <= N &&
            TileSizeM != 0 &&
            TileSizeN != 0);

        if (can_memcpy) {
            std::memcpy(dst_tile, temp_tile, TileSizeM * TileSizeN * sizeof(float));
            return;
        }

        // General path: C = alpha * (A*B) + beta * C, specialised to skip
        // multiplications when alpha/beta take their neutral values.
        float alpha = Data[BIdx].alpha;
        float beta = Data[BIdx].beta;
        size_t ldc = Data[BIdx].ldc;

        for (size_t i = 0; i < TileSizeM; ++i) {
            for (size_t j = 0; j < TileSizeN; ++j) {
                const size_t temp_idx = i * TileSizeN + j;
                const size_t dst_idx = i * ldc + j;

                float ab = temp_tile[temp_idx];
                float c_orig = dst_tile[dst_idx];

                if (alpha == 1.0f && beta == 0.0f) {
                    dst_tile[dst_idx] = ab;
                } else if (alpha == 1.0f) {
                    dst_tile[dst_idx] = ab + beta * c_orig;
                } else if (beta == 0.0f) {
                    dst_tile[dst_idx] = alpha * ab;
                } else {
                    dst_tile[dst_idx] = alpha * ab + beta * c_orig;
                }
            }
        }
        return;
    });
    return true;
}
0.5f; + *Output++ = std::clamp((p / q) + 0.5f, 0.0f, 1.0f); N -= 1; } @@ -178,7 +181,7 @@ Return Value: --*/ { -#if defined(MLAS_TARGET_AMD64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_USE_SVE) GetMlasPlatform().LogisticKernelRoutine(Input, Output, N); #else MlasLogisticKernel(Input, Output, N); diff --git a/src/lib/mlasi.h b/src/lib/mlasi.h index 0dd0165..d25f1f2 100644 --- a/src/lib/mlasi.h +++ b/src/lib/mlasi.h @@ -70,6 +70,9 @@ Module Name: #undef pixel #undef bool #endif +#if defined(__s390x__) +#include +#endif #if defined(__loongarch64) #include #endif @@ -118,9 +121,8 @@ Module Name: #ifdef MLAS_NO_EXCEPTION -MLAS_FORCEINLINE -void -MlasPrintFinalMessage(const std::string& msg) +MLAS_FORCEINLINE void + MlasPrintFinalMessage(const std::string& msg) { #if defined(__ANDROID__) __android_log_print(ANDROID_LOG_ERROR, "mlas", "%s", msg.c_str()); @@ -134,6 +136,7 @@ MlasPrintFinalMessage(const std::string& msg) #endif } + #define MLAS_THROW_EX(ex, what) \ do { \ std::string msg = #ex; \ @@ -162,7 +165,7 @@ MlasPrintFinalMessage(const std::string& msg) #include "core/common/cpuid_info.h" using MLAS_CPUIDINFO = onnxruntime::CPUIDInfo; -#include "core/framework/float16.h" +#include "core/common/float16.h" #else // BUILD_MLAS_NO_ONNXRUNTIME @@ -192,6 +195,8 @@ class MLASCPUIDInfo bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } + bool HasArmSVE() const { return has_arm_sve_; } + bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; } @@ -202,6 +207,7 @@ class MLASCPUIDInfo bool has_arm_neon_dot_{false}; bool has_fp16_{false}; bool has_arm_neon_i8mm_{false}; + bool has_arm_sve_{false}; bool has_arm_sve_i8mm_{false}; bool has_arm_neon_bf16_{false}; }; @@ -252,66 +258,41 @@ enum MlasUArch { // Define MLAS_FP16 // #include "mlas_float16.h" -#include "../ort_include/core/session/onnxruntime_float16.h" namespace onnxruntime { -// MLFloat16 -struct MLFloat16 : onnxruntime_float16::Float16Impl { - 
private: - explicit constexpr MLFloat16(uint16_t x) noexcept { val = x; } - - public: - using Base = onnxruntime_float16::Float16Impl; - - MLFloat16() = default; - - constexpr static MLFloat16 FromBits(uint16_t x) noexcept { return MLFloat16(x); } - - // Using inherited implementation instead of math floatToHalf allows us to use this - // in other shared providers without having to implement the bridge - explicit MLFloat16(float v) noexcept { val = Base::ToUint16Impl(v); } - - static const MLFloat16 NaN; - static const MLFloat16 NegativeNaN; - static const MLFloat16 Infinity; - static const MLFloat16 NegativeInfinity; - static const MLFloat16 MaxValue; - static const MLFloat16 Zero; - static const MLFloat16 One; - static const MLFloat16 MinusOne; - - // Using inherited implementation instead of math halfToFloat allows us to use this - // in other shared providers without having to implement the bridge - float ToFloat() const noexcept { return Base::ToFloatImpl(); } - - using Base::IsNegative; - - using Base::IsNaN; - - using Base::IsFinite; - - using Base::IsPositiveInfinity; - - using Base::IsNegativeInfinity; - - using Base::IsInfinity; - - using Base::IsNaNOrZero; - - using Base::IsNormal; +struct MLFloat16 { + uint16_t val{0}; - using Base::IsSubnormal; + MLFloat16() = default; + explicit constexpr MLFloat16(uint16_t x) : val(x) {} + explicit MLFloat16(float ff) : val(MLAS_Float2Half(ff)) {} + constexpr static MLFloat16 FromBits(uint16_t x) noexcept { return MLFloat16(x); } - using Base::Abs; + MLFloat16 Abs() const noexcept { + return MLFloat16(static_cast(val & ~kSignMask)); + } + bool IsNaN() const noexcept { + return Abs().val > kPositiveInfinityBits; + } + bool IsNegative() const noexcept { + return static_cast(val) < 0; + } + MLFloat16 Negate() const { + return MLFloat16(IsNaN() ? 
val : static_cast(val ^ kSignMask)); + } + static constexpr uint16_t kSignMask = 0x8000U; + static constexpr uint16_t kPositiveInfinityBits = 0x7C00U; - using Base::Negate; + float ToFloat() const { return MLAS_Half2Float(val); } - operator float() const noexcept { return ToFloat(); } + operator float() const { return ToFloat(); } - using Base::operator==; - using Base::operator!=; - using Base::operator<; + MLFloat16& operator=(float ff) + { + val = MLAS_Float2Half(ff); + return *this; + } }; inline bool @@ -371,7 +352,7 @@ static_assert(sizeof(MLAS_FP16) == FP16_SIZE); // #if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || \ - defined(MLAS_TARGET_LARCH64) + defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_S390X) typedef size_t @@ -772,6 +753,24 @@ void float Scale, int8_t ZeroPoint); +typedef +void +(MLASCALL MLAS_DEQUANTIZE_LINEAR_U8_KERNEL)( + const uint8_t* Input, + float* Output, + size_t N, + float Scale, + uint8_t ZeroPoint); + +typedef +void +(MLASCALL MLAS_DEQUANTIZE_LINEAR_S8_KERNEL)( + const int8_t* Input, + float* Output, + size_t N, + float Scale, + int8_t ZeroPoint); + template struct MLAS_QUANT_KERNEL { @@ -788,6 +787,119 @@ struct MLAS_QUANT_KERNEL size_t KernelSize ); }; +typedef +void +(MLASCALL MLAS_CONV_FLOAT_FN)( + const MLAS_CONV_PARAMETERS* Parameters, + const float* Input, + const float* Filter, + const float* Bias, + float* WorkingBuffer, + float* Output, + MLAS_THREADPOOL* ThreadPool + ); +typedef +bool +(MLASCALL MLAS_CONV_FLOAT_OVERRIDE)( + const MLAS_CONV_PARAMETERS* Parameters, + const float* Input, + const float* Filter, + const float* Bias, + float* WorkingBuffer, + float* Output, + MLAS_THREADPOOL* ThreadPool + ); +// TODO: Investigate if overridden typedefs can be removed +typedef +void +(MLASCALL MLAS_CONV_PREPARE_FLOAT_FN)( + MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + 
const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* WorkingBufferSize, + float Beta, + MLAS_THREADPOOL* ThreadPool + ); +typedef +bool +(MLASCALL MLAS_CONV_PREPARE_FLOAT_OVERRIDE)( + MLAS_CONV_PARAMETERS* Parameters, + size_t Dimensions, + size_t BatchCount, + size_t GroupCount, + size_t InputChannels, + const int64_t* InputShape, + const int64_t* KernelShape, + const int64_t* DilationShape, + const int64_t* Padding, + const int64_t* StrideShape, + const int64_t* OutputShape, + size_t FilterCount, + const MLAS_ACTIVATION* Activation, + size_t* WorkingBufferSize, + float Beta, + MLAS_THREADPOOL* ThreadPool + ); + +typedef void (MLASCALL MLAS_GEMM_BATCH)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool); + +typedef bool (MLASCALL MLAS_GEMM_BATCH_OVERRIDE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t M, + size_t N, + size_t K, + const MLAS_SGEMM_DATA_PARAMS* Data, + size_t BatchSize, + MLAS_THREADPOOL* ThreadPool); + +typedef size_t (MLASCALL MLAS_GEMM_PACK_B_SIZE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K); + +typedef size_t (MLASCALL MLAS_GEMM_PACK_B_SIZE_OVERRIDE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K); + +typedef void (MLASCALL MLAS_GEMM_PACK_B_KERNEL)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB); + +typedef bool (MLASCALL MLAS_GEMM_PACK_B_KERNEL_OVERRIDE)( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, + size_t N, + size_t K, + const float* B, + size_t ldb, + void* PackedB); extern "C" { @@ -813,6 +925,12 @@ extern "C" { MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelPOWER10; MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8KernelVSX; 
MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8KernelVSX; +#elif defined(MLAS_TARGET_S390X) + MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernel; + MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelZVECTOR; + MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernel; + MLAS_QUANTIZE_LINEAR_S8_KERNEL MlasQuantizeLinearS8KernelZVECTOR; + MLAS_QUANTIZE_LINEAR_U8_KERNEL MlasQuantizeLinearU8KernelZVECTOR; #elif defined(MLAS_TARGET_LARCH64) MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelLSX; MLAS_GEMM_FLOAT_KERNEL MlasGemmFloatKernelLasx; @@ -843,6 +961,15 @@ extern "C" { #if defined(__aarch64__) && defined(__linux__) MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelZero; MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelAdd; +#endif +#if defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC) + MLAS_CONV_FLOAT_KERNEL MlasConvNchwFloatKernelNeon; + MLAS_CONV_FLOAT_KERNEL MlasConvNchwcFloatKernelNeon; + MLAS_CONV_DEPTHWISE_FLOAT_KERNEL MlasConvDepthwiseFloatKernelNeon; + MLAS_CONV_POINTWISE_FLOAT_KERNEL MlasConvPointwiseFloatKernelNeon; + MLAS_POOL_FLOAT_KERNEL MlasPoolMaximumFloatKernelNeon; + MLAS_POOL_FLOAT_KERNEL MlasPoolAverageExcludePadFloatKernelNeon; + MLAS_POOL_FLOAT_KERNEL MlasPoolAverageIncludePadFloatKernelNeon; #endif MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelZero; MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelAdd; @@ -928,6 +1055,8 @@ extern "C" { MLAS_QUANTIZE_LINEAR_S4_KERNEL MlasQuantizeLinearS4Kernel; MLAS_QUANTIZE_LINEAR_U4_KERNEL MlasQuantizeLinearU4Kernel; #if defined(MLAS_TARGET_AMD64) + MLAS_DEQUANTIZE_LINEAR_S8_KERNEL MlasDequantizeLinearS8Kernel; + MLAS_DEQUANTIZE_LINEAR_U8_KERNEL MlasDequantizeLinearU8Kernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasErfKernelFma3; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelFma3; MLAS_COMPUTE_UNARY_FLOAT_KERNEL MlasComputeExpF32KernelAvx512F; @@ -1041,6 +1170,7 @@ extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmSimd; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmRelaxedSimd; extern const MLAS_GEMM_QUANT_DISPATCH 
MlasGemmQuantDispatchDefault; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemm8X8DispatchPOWER10; +extern const MLAS_GEMM_QUANT_DISPATCH MlasGemm8X8DispatchZVECTOR; #if defined(MLAS_TARGET_WASM_RELAXED_SIMD) extern bool HasUSDot(); @@ -1092,7 +1222,8 @@ struct MLAS_QNBIT_GEMM_DISPATCH; const MLAS_QNBIT_GEMM_DISPATCH& GetMlasQNBitGemmDispatchNeon( - bool InitializeWithDotSupport + bool InitializeWithDotSupport, + bool InitializeWithI8MMSupport ); extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2; @@ -1103,6 +1234,8 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512; extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni; +extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchLasx; + // // Rotary embedding dispatch structure. // @@ -1189,8 +1322,16 @@ struct MLAS_PLATFORM { // TODO: move to cpuinfo bool Avx2Supported_ = false; bool Avx512Supported_ = false; + bool ArmNeonIsQuantActivationsUnsigned = false; -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) + // Mlas overrides initialisation + MLAS_GEMM_BATCH_OVERRIDE* MlasGemmBatchOverride = nullptr; + MLAS_GEMM_PACK_B_SIZE_OVERRIDE* MlasGemmPackBSizeOverride = nullptr; + MLAS_GEMM_PACK_B_KERNEL_OVERRIDE* MlasGemmPackBOverride = nullptr; + MLAS_CONV_PREPARE_FLOAT_OVERRIDE* MlasConvPrepareOverride = nullptr; + MLAS_CONV_FLOAT_OVERRIDE* MlasConvOverride = nullptr; + +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) MLAS_GEMM_FLOAT_KERNEL* GemmFloatKernel; #endif #if defined(MLAS_TARGET_LARCH64) @@ -1218,6 +1359,14 @@ struct MLAS_PLATFORM { const MLAS_GEMM_QUANT_DISPATCH* GemmU8U8Dispatch; const MLAS_GEMM_QUANT_DISPATCH* GemmU8S8Dispatch; const MLAS_GEMM_QUANT_DISPATCH* GemmS8S8Dispatch; +#if defined(MLAS_USE_ARM_NEON_NCHWC) + MLAS_CONV_FLOAT_KERNEL* ConvNchwFloatKernel; + MLAS_CONV_FLOAT_KERNEL* ConvNchwcFloatKernel; + MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* ConvDepthwiseFloatKernel; + 
MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseFloatKernel; + MLAS_POOL_FLOAT_KERNEL* PoolFloatKernel[MlasPoolingKindCount]; + uint32_t NchwcBlockSize; +#endif #endif const MLAS_SYMM_QGEMM_DISPATCH* SymmQgemmDispatch{nullptr}; @@ -1229,7 +1378,7 @@ struct MLAS_PLATFORM { MLAS_QUANT_KERNEL::DepthwiseKernel* ConvDepthwiseS8S8Kernel; MLAS_QUANT_KERNEL::DepthwiseKernel* ConvDepthwiseS8U8Kernel; -#if defined(MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) MLAS_GEMM_DOUBLE_KERNEL* GemmDoubleKernel; const MLAS_GEMM_QUANT_DISPATCH* GemmU8X8Dispatch; MLAS_QUANTIZE_LINEAR_S8_KERNEL* QuantizeLinearS8Kernel; @@ -1239,6 +1388,15 @@ struct MLAS_PLATFORM { MLAS_QUANTIZE_LINEAR_S4_KERNEL* QuantizeLinearS4Kernel; MLAS_QUANTIZE_LINEAR_U4_KERNEL* QuantizeLinearU4Kernel; #endif + +#if defined(MLAS_USE_SVE) || defined(MLAS_TARGET_AMD64) + MLAS_COMPUTE_UNARY_FLOAT_KERNEL* ErfKernelRoutine; + MLAS_COMPUTE_UNARY_FLOAT_KERNEL* LogisticKernelRoutine; + MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL* ReduceMaximumF32Kernel; + MLAS_COMPUTE_SUMEXP_FLOAT_KERNEL* ComputeSumExpF32Kernel; + MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeLogSoftmaxOutputF32Kernel; + MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeSoftmaxOutputF32Kernel; +#endif #if defined(MLAS_TARGET_AMD64) MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1Routine; MLAS_SGEMM_KERNEL_M1_ROUTINE* KernelM1TransposeBRoutine; @@ -1254,16 +1412,10 @@ struct MLAS_PLATFORM { MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* ConvDepthwiseFloatKernel; MLAS_CONV_POINTWISE_FLOAT_KERNEL* ConvPointwiseFloatKernel; MLAS_POOL_FLOAT_KERNEL* PoolFloatKernel[MlasPoolingKindCount]; - MLAS_COMPUTE_UNARY_FLOAT_KERNEL* ErfKernelRoutine; MLAS_QLINEAR_BINARY_OP_S8_KERNEL* QLinearAddS8Kernel; MLAS_QLINEAR_BINARY_OP_U8_KERNEL* QLinearAddU8Kernel; MLAS_COMPUTE_UNARY_FLOAT_KERNEL* ComputeExpF32Kernel; - MLAS_COMPUTE_UNARY_FLOAT_KERNEL* LogisticKernelRoutine; MLAS_COMPUTE_UNARY_FLOAT_KERNEL* TanhKernelRoutine; - MLAS_COMPUTE_SUMEXP_FLOAT_KERNEL* 
ComputeSumExpF32Kernel; - MLAS_COMPUTE_SOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeSoftmaxOutputF32Kernel; - MLAS_COMPUTE_LOGSOFTMAX_OUTPUT_FLOAT_KERNEL* ComputeLogSoftmaxOutputF32Kernel; - MLAS_REDUCE_MAXIMUM_FLOAT_KERNEL* ReduceMaximumF32Kernel; MLAS_REDUCE_MINIMUM_MAXIMUM_FLOAT_KERNEL* ReduceMinimumMaximumF32Kernel; MLAS_QUANTIZE_LINEAR_S8_KERNEL* QuantizeLinearS8Kernel; MLAS_QUANTIZE_LINEAR_U8_KERNEL* QuantizeLinearU8Kernel; @@ -1271,11 +1423,14 @@ struct MLAS_PLATFORM { MLAS_QUANTIZE_LINEAR_U16_KERNEL* QuantizeLinearU16Kernel; MLAS_QUANTIZE_LINEAR_S4_KERNEL* QuantizeLinearS4Kernel; MLAS_QUANTIZE_LINEAR_U4_KERNEL* QuantizeLinearU4Kernel; + MLAS_DEQUANTIZE_LINEAR_S8_KERNEL* DequantizeLinearS8Kernel; + MLAS_DEQUANTIZE_LINEAR_U8_KERNEL* DequantizeLinearU8Kernel; uint32_t NchwcBlockSize; uint32_t PreferredBufferAlignment; int32_t MaximumThreadCount; #elif defined(MLAS_TARGET_ARM64) static constexpr int32_t MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT * 4; + static constexpr size_t MLAS_NEON_NCHWC_BLOCK_SIZE = 16; #else static constexpr int32_t MaximumThreadCount = MLAS_MAXIMUM_THREAD_COUNT; #endif @@ -1485,6 +1640,8 @@ MlasConvDepthwiseFloat_CHW( #define MLAS_NEON64_INTRINSICS #elif defined(MLAS_TARGET_POWER) #define MLAS_VSX_INTRINSICS +#elif defined(MLAS_TARGET_S390X) +#define MLAS_ZVECTOR_INTRINSICS #elif defined(MLAS_TARGET_AMD64_IX86) #define MLAS_SSE2_INTRINSICS #if defined(__SSE4_1__) || (defined(_MSC_VER) && defined(__AVX__)) @@ -1554,6 +1711,8 @@ MlasCastToInt32x4(MLAS_FLOAT32X4 Vector) return _mm_cvttps_epi32(Vector); #elif defined(MLAS_VSX_INTRINSICS) return vec_cts(Vector, 0); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + return vec_signed(Vector); #elif defined(MLAS_LSX_INTRINSICS) return __lsx_vftint_w_s(Vector); #elif defined(MLAS_WASM_SIMD_INTRINSICS) @@ -1573,6 +1732,8 @@ MlasCastToFloat32x4(MLAS_INT32X4 Vector) return _mm_cvtepi32_ps(Vector); #elif defined(MLAS_VSX_INTRINSICS) return vec_ctf(Vector, 0); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + return 
vec_float(Vector); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_convert_i32x4(Vector); #elif defined(MLAS_LSX_INTRINSICS) @@ -1592,7 +1753,7 @@ MlasBroadcastInt32x4(int32_t Value) return _mm_set1_epi32(Value); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_i32x4_splat(Value); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) return vec_splats(Value); #elif defined(MLAS_LSX_INTRINSICS) return __lsx_vreplgr2vr_w(Value); @@ -1611,6 +1772,8 @@ MlasLoadInt32x4(const int32_t* Buffer) return _mm_loadu_si128((const __m128i*)Buffer); #elif defined(MLAS_VSX_INTRINSICS) return vec_vsx_ld(0, Buffer); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + return vec_xl(0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_load(Buffer); #elif defined(MLAS_LSX_INTRINSICS) @@ -1630,6 +1793,8 @@ MlasStoreInt32x4(int32_t* Buffer, MLAS_INT32X4 Vector) _mm_storeu_si128((__m128i*)Buffer, Vector); #elif defined(MLAS_VSX_INTRINSICS) vec_vsx_st(Vector, 0, Buffer); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + vec_xst(Vector, 0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) wasm_v128_store(Buffer, Vector); #elif defined(MLAS_LSX_INTRINSICS) @@ -1736,7 +1901,7 @@ MlasXorInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return _mm_xor_si128(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_xor(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) return vec_xor(Vector1, Vector2); #elif defined(MLAS_LSX_INTRINSICS) return __lsx_vxor_v(Vector1, Vector2); @@ -1782,6 +1947,8 @@ MlasMaximumInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return MlasBlendInt32x4(Vector2, Vector1, _mm_cmpgt_epi32(Vector1, Vector2)); #elif defined(MLAS_VSX_INTRINSICS) return vec_vmaxsw(Vector1, Vector2); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + return vec_max(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return 
wasm_i32x4_max(Vector1, Vector2); #elif defined(MLAS_LSX_INTRINSICS) @@ -1803,6 +1970,8 @@ MlasMinimumInt32x4(MLAS_INT32X4 Vector1, MLAS_INT32X4 Vector2) return MlasBlendInt32x4(Vector2, Vector1, _mm_cmpgt_epi32(Vector2, Vector1)); #elif defined(MLAS_VSX_INTRINSICS) return vec_vminsw(Vector1, Vector2); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + return vec_min(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_i32x4_min(Vector1, Vector2); #elif defined(MLAS_LSX_INTRINSICS) @@ -1837,7 +2006,7 @@ MlasBroadcastFloat32x4(float Value) return _mm_set1_ps(Value); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_splat(Value); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) // Suppress wrong GCC warnings MLAS_UNREFERENCED_PARAMETER(Value); return vec_splats(Value); @@ -1858,7 +2027,7 @@ MlasBroadcastFloat32x4(const float* Value) return _mm_load_ps1(Value); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_load32_splat(Value); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) return vec_splats(*Value); #elif defined(MLAS_LSX_INTRINSICS) return MLAS_FLOAT32X4{*Value, *Value, *Value, *Value}; @@ -1894,6 +2063,8 @@ MlasLoadFloat32x4(const float* Buffer) return _mm_loadu_ps(Buffer); #elif defined(MLAS_VSX_INTRINSICS) return vec_vsx_ld(0, Buffer); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + return vec_xl(0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_v128_load(Buffer); #elif defined(MLAS_LSX_INTRINSICS) @@ -1914,6 +2085,8 @@ MlasStoreFloat32x4(float* Buffer, MLAS_FLOAT32X4 Vector) _mm_storeu_ps(Buffer, Vector); #elif defined(MLAS_VSX_INTRINSICS) vec_vsx_st(Vector, 0, Buffer); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + vec_xst(Vector, 0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) wasm_v128_store(Buffer, Vector); #elif defined(MLAS_LSX_INTRINSICS) @@ -1936,6 +2109,8 @@ MlasStoreAlignedFloat32x4(float* 
Buffer, MLAS_FLOAT32X4 Vector) MLAS_UNREFERENCED_PARAMETER(Buffer); MLAS_UNREFERENCED_PARAMETER(Vector); vec_st(Vector, 0, Buffer); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + vec_xst(Vector, 0, Buffer); #elif defined(MLAS_WASM_SIMD_INTRINSICS) wasm_v128_store(Buffer, Vector); #elif defined(MLAS_LSX_INTRINSICS) @@ -2070,7 +2245,7 @@ MlasInterleaveLowFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return zipped.val[0]; #elif defined(MLAS_SSE2_INTRINSICS) return _mm_unpacklo_ps(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) return vec_mergeh(Vector1, Vector2); #elif defined(MLAS_LSX_INTRINSICS) return (MLAS_FLOAT32X4)__lsx_vilvl_w(MlasReinterpretAsInt32x4(Vector2), MlasReinterpretAsInt32x4(Vector1)); @@ -2090,7 +2265,7 @@ MlasInterleaveHighFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return zipped.val[1]; #elif defined(MLAS_SSE2_INTRINSICS) return _mm_unpackhi_ps(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) return vec_mergel(Vector1, Vector2); #elif defined(MLAS_LSX_INTRINSICS) return (MLAS_FLOAT32X4)__lsx_vilvh_w(MlasReinterpretAsInt32x4(Vector2), MlasReinterpretAsInt32x4(Vector1)); @@ -2164,13 +2339,20 @@ MLAS_FLOAT32X4 MlasMultiplyAddFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2, MLAS_FLOAT32X4 Vector3) { #if defined(MLAS_NEON_INTRINSICS) +#if defined(MLAS_TARGET_ARM) + // ARMv7 NEON doesn't have vfmaq_f32() return vmlaq_f32(Vector3, Vector1, Vector2); +#else + return vfmaq_f32(Vector3, Vector1, Vector2); +#endif #elif defined(MLAS_FMA3_INTRINSICS) return _mm_fmadd_ps(Vector1, Vector2, Vector3); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_add_ps(_mm_mul_ps(Vector1, Vector2), Vector3); #elif defined(MLAS_VSX_INTRINSICS) return vec_madd(Vector1, Vector2, Vector3); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + return __builtin_s390_vfmasb(Vector1, Vector2, Vector3); #elif 
defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_add(wasm_f32x4_mul(Vector1, Vector2), Vector3); #elif defined(MLAS_LSX_INTRINSICS) @@ -2227,7 +2409,7 @@ MlasGreaterThanFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return _mm_cmpgt_ps(Vector1, Vector2); #elif defined(MLAS_WASM_SIMD_INTRINSICS) return wasm_f32x4_gt(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) return MLAS_FLOAT32X4(vec_cmpgt(Vector1, Vector2)); #elif defined(MLAS_LSX_INTRINSICS) return (MLAS_FLOAT32X4)__lsx_vfcmp_clt_s(Vector2, Vector1); @@ -2311,7 +2493,7 @@ MlasMaximumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return vmaxq_f32(Vector1, Vector2); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_max_ps(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) // Don't use vec_max to avoid undefined behavior if NAN return vec_sel(Vector2, Vector1, vec_cmpgt(Vector1, Vector2)); #elif defined(MLAS_WASM_RELAXED_SIMD_INTRINSICS) @@ -2333,7 +2515,7 @@ MlasMinimumFloat32x4(MLAS_FLOAT32X4 Vector1, MLAS_FLOAT32X4 Vector2) return vminq_f32(Vector1, Vector2); #elif defined(MLAS_SSE2_INTRINSICS) return _mm_min_ps(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) // Don't use vec_min to avoid undefined behavior if NAN return vec_sel(Vector2, Vector1, vec_cmpgt(Vector2, Vector1)); #elif defined(MLAS_WASM_RELAXED_SIMD_INTRINSICS) @@ -2374,7 +2556,7 @@ MlasReduceAddFloat32x4(MLAS_FLOAT32X4 Vector) VectorLow = vpadd_f32(VectorLow, VectorHigh); VectorLow = vpadd_f32(VectorLow, VectorHigh); return vget_lane_f32(VectorLow, 0); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) Vector = MlasAddFloat32x4(Vector, MLAS_FLOAT32X4(vec_splat((__vector long long)Vector, 1))); Vector = MlasAddFloat32x4(Vector, 
vec_splat(Vector, 1)); return Vector[0]; @@ -2397,7 +2579,7 @@ MlasReduceMaximumFloat32x4(MLAS_FLOAT32X4 Vector) VectorLow = vpmax_f32(VectorLow, VectorHigh); VectorLow = vpmax_f32(VectorLow, VectorHigh); return vget_lane_f32(VectorLow, 0); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) Vector = MlasMaximumFloat32x4(Vector, MLAS_FLOAT32X4(vec_splat((__vector long long)Vector, 1))); Vector = MlasMaximumFloat32x4(Vector, vec_splat(Vector, 1)); return Vector[0]; @@ -2420,7 +2602,7 @@ MlasReduceMinimumFloat32x4(MLAS_FLOAT32X4 Vector) VectorLow = vpmin_f32(VectorLow, VectorHigh); VectorLow = vpmin_f32(VectorLow, VectorHigh); return vget_lane_f32(VectorLow, 0); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) Vector = MlasMinimumFloat32x4(Vector, MLAS_FLOAT32X4(vec_splat((__vector long long)Vector, 1))); Vector = MlasMinimumFloat32x4(Vector, vec_splat(Vector, 1)); return Vector[0]; @@ -2446,7 +2628,7 @@ MlasPowerOf2Float32x4(MLAS_FLOAT32X4 Vector) #if defined(MLAS_SSE2_INTRINSICS) typedef __m128d MLAS_FLOAT64X2; -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) typedef __vector double MLAS_FLOAT64X2; #elif defined(MLAS_LSX_INTRINSICS) typedef __m128d MLAS_FLOAT64X2; @@ -2456,7 +2638,7 @@ typedef __m128d MLAS_FLOAT64X2; #ifndef MLAS_FLOAT64X2_UNSUPPORTED -#if defined(MLAS_VSX_INTRINSICS) +#if defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) template MLAS_FORCEINLINE double @@ -2505,7 +2687,7 @@ MlasBroadcastFloat64x2(double Value) { #if defined(MLAS_SSE2_INTRINSICS) return _mm_set1_pd(Value); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) return MLAS_FLOAT64X2{Value, Value}; #elif defined(MLAS_LSX_INTRINSICS) return MLAS_FLOAT64X2{Value, Value}; @@ -2518,7 +2700,7 @@ MlasZeroFloat64x2(void) { #if 
defined(MLAS_SSE2_INTRINSICS) return _mm_setzero_pd(); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) return MlasBroadcastFloat64x2(0.0f); #elif defined(MLAS_LSX_INTRINSICS) return MlasBroadcastFloat64x2(0.0f); @@ -2533,6 +2715,8 @@ MlasLoadFloat64x2(const double* Buffer) return _mm_loadu_pd(Buffer); #elif defined(MLAS_VSX_INTRINSICS) return vec_vsx_ld(0, Buffer); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + return vec_xl(0, Buffer); #elif defined(MLAS_LSX_INTRINSICS) return MLAS_FLOAT64X2(__lsx_vld((const MLAS_INT32X4 *)Buffer, 0)); #endif @@ -2546,6 +2730,8 @@ MlasStoreFloat64x2(double* Buffer, MLAS_FLOAT64X2 Vector) _mm_storeu_pd(Buffer, Vector); #elif defined(MLAS_VSX_INTRINSICS) vec_vsx_st(Vector, 0, Buffer); +#elif defined(MLAS_ZVECTOR_INTRINSICS) + vec_xst(Vector, 0, Buffer); #elif defined(MLAS_LSX_INTRINSICS) (__lsx_vst(MLAS_INT32X4(Vector), Buffer, 0)); #endif @@ -2557,7 +2743,7 @@ MlasStoreAlignedFloat64x2(double* Buffer, MLAS_FLOAT64X2 Vector) { #if defined(MLAS_SSE2_INTRINSICS) _mm_store_pd(Buffer, Vector); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) *((MLAS_FLOAT64X2*)Buffer) = Vector; #elif defined(MLAS_LSX_INTRINSICS) (__lsx_vst(MLAS_INT32X4(Vector), Buffer, 0)); @@ -2570,7 +2756,7 @@ MlasMultiplyFloat64x2(MLAS_FLOAT64X2 Vector1, MLAS_FLOAT64X2 Vector2) { #if defined(MLAS_SSE2_INTRINSICS) return _mm_mul_pd(Vector1, Vector2); -#elif defined(MLAS_VSX_INTRINSICS) +#elif defined(MLAS_VSX_INTRINSICS) || defined(MLAS_ZVECTOR_INTRINSICS) return Vector1 * Vector2; #elif defined(MLAS_LSX_INTRINSICS) return __lsx_vfmul_d(Vector1, Vector2); diff --git a/src/lib/platform.cpp b/src/lib/platform.cpp index 7e875c2..796bfd1 100644 --- a/src/lib/platform.cpp +++ b/src/lib/platform.cpp @@ -16,6 +16,12 @@ Module Name: --*/ #include "mlasi.h" +#ifdef MLAS_USE_SVE +#include "sve/mlasi_sve.h" +#endif +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) 
+#include "kleidiai/mlasi_kleidiai.h" +#endif #include #include @@ -31,6 +37,11 @@ Module Name: #endif #endif + +#if defined(MLAS_TARGET_S390X) +#include +#endif + #if defined(MLAS_TARGET_ARM64) #if defined(_WIN32) @@ -168,7 +179,6 @@ MlasReadExtendedControlRegister( #if defined(__linux__) #include -#include #endif bool @@ -286,6 +296,8 @@ Return Value: this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel; this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel; this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel; + this->DequantizeLinearS8Kernel = MlasDequantizeLinearS8Kernel; + this->DequantizeLinearU8Kernel = MlasDequantizeLinearU8Kernel; #ifndef __APPLE__ #ifndef FORCE_GENERIC_ALGORITHMS this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelSse; @@ -553,6 +565,17 @@ Return Value: this->SoftmaxDispatch = &MlasSoftmaxDispatchNeon; this->EltwiseDispatch = &MlasEltwiseDispatchNeon; +#if defined(MLAS_USE_ARM_NEON_NCHWC) + this->ConvNchwFloatKernel = MlasConvNchwFloatKernelNeon; + this->ConvNchwcFloatKernel = MlasConvNchwcFloatKernelNeon; + this->ConvDepthwiseFloatKernel = MlasConvDepthwiseFloatKernelNeon; + this->ConvPointwiseFloatKernel = MlasConvPointwiseFloatKernelNeon; + this->PoolFloatKernel[MlasMaximumPooling] = MlasPoolMaximumFloatKernelNeon; + this->PoolFloatKernel[MlasAveragePoolingExcludePad] = MlasPoolAverageExcludePadFloatKernelNeon; + this->PoolFloatKernel[MlasAveragePoolingIncludePad] = MlasPoolAverageIncludePadFloatKernelNeon; + this->NchwcBlockSize = MLAS_NEON_NCHWC_BLOCK_SIZE; +#endif + // // Check if the processor supports ASIMD dot product instructions. 
// @@ -577,18 +600,51 @@ Return Value: this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; } - this->QNBitGemmDispatch = &GetMlasQNBitGemmDispatchNeon(HasDotProductInstructions); +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){ + this->MlasGemmBatchOverride = ArmKleidiAI::MlasGemmBatch; + this->MlasGemmPackBSizeOverride = ArmKleidiAI::MlasGemmPackBSize; + this->MlasGemmPackBOverride = ArmKleidiAI::MlasGemmPackB; + this->MlasConvPrepareOverride = ArmKleidiAI::MlasConvPrepare; + this->MlasConvOverride = ArmKleidiAI::MlasConv; + } +#endif + +#if defined(MLAS_USE_SVE) + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmSve()) { + this->ErfKernelRoutine = MlasSveErfKernel; + this->LogisticKernelRoutine = MlasSveLogisticKernel; + this->ReduceMaximumF32Kernel = MlasSveReduceMaximumF32Kernel; + this->ComputeSumExpF32Kernel = MlasSveComputeSumExpF32Kernel; + this->ComputeLogSoftmaxOutputF32Kernel = MlasSveComputeLogSoftmaxOutputF32Kernel; + this->ComputeSoftmaxOutputF32Kernel = MlasSveComputeSoftmaxOutputF32Kernel; + } + else{ + this->ErfKernelRoutine = MlasErfKernel; + this->LogisticKernelRoutine = MlasLogisticKernel; + this->ReduceMaximumF32Kernel = MlasReduceMaximumF32Kernel; + this->ComputeSumExpF32Kernel = MlasComputeSumExpF32Kernel; + this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32Kernel; + this->ComputeSoftmaxOutputF32Kernel = MlasComputeSoftmaxOutputF32Kernel; + } +#endif -#if defined(__linux__) // // Check if the processor supports ASIMD I8MM instructions. // - if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) { + + const bool HasI8MMInstructions = MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM(); + if (HasI8MMInstructions) { +#if defined(__linux__) + this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchUmmla; this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchUmmla; this->GemmS8S8Dispatch = &MlasGemmS8S8DispatchSmmla; - } #endif + } + + this->ArmNeonIsQuantActivationsUnsigned = HasI8MMInstructions ? 
false : true; + this->QNBitGemmDispatch = &GetMlasQNBitGemmDispatchNeon(HasDotProductInstructions, HasI8MMInstructions); #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) this->CastF16ToF32Kernel = &MlasCastF16ToF32KernelNeon; @@ -612,6 +668,11 @@ Return Value: bool HasP9Instructions = hwcap2 & PPC_FEATURE2_ARCH_3_00; #elif defined(_AIX) bool HasP9Instructions = __power_9_andup(); +#elif defined(__FreeBSD__) + unsigned long hwcap2; + elf_aux_info(AT_HWCAP2, &hwcap2, sizeof(hwcap2)); + + bool HasP9Instructions = hwcap2 & PPC_FEATURE2_ARCH_3_00; #endif // __linux__ if (HasP9Instructions) { this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8KernelVSX; @@ -621,7 +682,7 @@ Return Value: #if defined(POWER10) #if (defined(__GNUC__) && ((__GNUC__ > 10) || (__GNUC__== 10 && __GNUC_MINOR__ >= 2))) || \ (defined(__clang__) && (__clang_major__ >= 12)) -#if defined(__linux__) +#if defined(__linux__) || defined(__FreeBSD__) bool HasP10Instructions = ((hwcap2 & PPC_FEATURE2_MMA) && (hwcap2 & PPC_FEATURE2_ARCH_3_1)); #elif defined(_AIX) bool HasP10Instructions = (__power_10_andup() && __power_mma_version() == MMA_V31); @@ -636,6 +697,26 @@ Return Value: #endif // MLAS_TARGET_POWER +#if defined(MLAS_TARGET_S390X) + this->GemmFloatKernel = MlasSgemmKernel; + this->GemmDoubleKernel = MlasDgemmKernel; + this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8Kernel; + this->QuantizeLinearU8Kernel = MlasQuantizeLinearU8Kernel; + this->QuantizeLinearS16Kernel = MlasQuantizeLinearS16Kernel; + this->QuantizeLinearU16Kernel = MlasQuantizeLinearU16Kernel; + this->QuantizeLinearS4Kernel = MlasQuantizeLinearS4Kernel; + this->QuantizeLinearU4Kernel = MlasQuantizeLinearU4Kernel; + + bool HasVXEInstructions = getauxval(AT_HWCAP) & HWCAP_S390_VXE; + if (HasVXEInstructions) { + this->GemmFloatKernel = MlasSgemmKernelZVECTOR; + this->GemmU8X8Dispatch = &MlasGemm8X8DispatchZVECTOR; + + this->QuantizeLinearS8Kernel = MlasQuantizeLinearS8KernelZVECTOR; + this->QuantizeLinearU8Kernel = 
MlasQuantizeLinearU8KernelZVECTOR; + } +#endif // MLAS_TARGET_S390X + #if defined(MLAS_TARGET_LARCH64) // @@ -661,6 +742,9 @@ Return Value: this->ComputeLogSoftmaxOutputF32Kernel = MlasComputeLogSoftmaxOutputF32KernelLasx; this->TransposePackB16x4Routine = MlasSgemmTransposePackB16x4Lasx; + // add new sqn-lasx kernel + this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchLasx; + this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchLSX; this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchLSX; }else if( cap_lsx ){ diff --git a/src/lib/pooling.cpp b/src/lib/pooling.cpp index 50dcf19..6bb23df 100644 --- a/src/lib/pooling.cpp +++ b/src/lib/pooling.cpp @@ -1533,7 +1533,7 @@ Return Value: c -= 8; } -#elif defined(MLAS_TARGET_POWER) +#elif defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) while (c >= 32) { auto MaximumVector0 = vec_splats(std::numeric_limits::lowest()); diff --git a/src/lib/power/qgemm_kernel_power10.cpp b/src/lib/power/qgemm_kernel_power10.cpp index 0f3bc1d..b00e37b 100644 --- a/src/lib/power/qgemm_kernel_power10.cpp +++ b/src/lib/power/qgemm_kernel_power10.cpp @@ -437,49 +437,51 @@ MlasGemmQuantCopyPackA8x8( Vtype a2 = vmask; Vtype a3 = vmask; Vtype a1 = *reinterpret_cast(&a[0]); - if (CountM == 3) { - a3 = *reinterpret_cast(&a[lda * 2]); - } - if (CountM >= 2) { + if (CountM == 1) { + vec_t va1 = AIsSigned ? 
reinterpret_cast(a1) : reinterpret_cast(vec_sub(a1, vmask)); + *reinterpret_cast(&D[0]) = (vec_t)va1; + vsum = vec_sum4s(va1, vsum); + } else { a2 = *reinterpret_cast(&a[lda]); + if (CountM == 3) { + a3 = *reinterpret_cast(&a[lda * 2]); + } + Vtype vx = + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); + Vtype vx1 = + reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); + Vtype vx2 = + reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a1), reinterpret_cast<__vector int>(a2))); + Vtype vx3 = + reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a3), reinterpret_cast<__vector int>(a4))); + Vtype vx4 = vec_xxpermdi(vx, vx1, 0); + Vtype vx5 = vec_xxpermdi(vx2, vx3, 0); + Vtype vx6 = vec_xxpermdi(vx, vx1, 3); + Vtype vx7 = vec_xxpermdi(vx2, vx3, 3); + vec_t vx0 = AIsSigned ? reinterpret_cast(vx4) : reinterpret_cast(vec_sub(vx4, vmask)); + *reinterpret_cast(&D[0]) = vx0; + vsum = vec_sum4s(vx0, vsum); + vx0 = AIsSigned ? reinterpret_cast(vx5) : reinterpret_cast(vec_sub(vx5, vmask)); + *reinterpret_cast(&D[16]) = vx0; + vsum = vec_sum4s(vx0, vsum); + vx0 = AIsSigned ? reinterpret_cast(vx6) : reinterpret_cast(vec_sub(vx6, vmask)); + *reinterpret_cast(&D[32]) = vx0; + vsum = vec_sum4s(vx0, vsum); + vx0 = AIsSigned ? 
reinterpret_cast(vx7) : reinterpret_cast(vec_sub(vx7, vmask)); + *reinterpret_cast(&D[48]) = vx0; + vsum = vec_sum4s(vx0, vsum); + } + if (CountM == 1) { + D += 16; + } else { + D += 16 * 4; } - Vtype vx = - reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a1), - reinterpret_cast<__vector int>(a2))); - Vtype vx1 = - reinterpret_cast(vec_mergee(reinterpret_cast<__vector int>(a3), - reinterpret_cast<__vector int>(a4))); - Vtype vx2 = - reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a1), - reinterpret_cast<__vector int>(a2))); - Vtype vx3 = - reinterpret_cast(vec_mergeo(reinterpret_cast<__vector int>(a3), - reinterpret_cast<__vector int>(a4))); - Vtype vx4 = vec_xxpermdi(vx, vx1, 0); - Vtype vx5 = vec_xxpermdi(vx2, vx3, 0); - Vtype vx6 = vec_xxpermdi(vx, vx1, 3); - Vtype vx7 = vec_xxpermdi(vx2, vx3, 3); - vec_t vx0 = - AIsSigned ? reinterpret_cast(vx4) : - reinterpret_cast(vec_sub(vx4, vmask)); - *reinterpret_cast(&D[0]) = vx0; - vsum = vec_sum4s(vx0, vsum); - vx0 = AIsSigned ? reinterpret_cast(vx5) : - reinterpret_cast(vec_sub(vx5, vmask)); - *reinterpret_cast(&D[16]) = vx0; - vsum = vec_sum4s(vx0, vsum); - vx0 = AIsSigned ? reinterpret_cast(vx6) : - reinterpret_cast(vec_sub(vx6, vmask)); - *reinterpret_cast(&D[32]) = vx0; - vsum = vec_sum4s(vx0, vsum); - vx0 = AIsSigned ? 
reinterpret_cast(vx7) : - reinterpret_cast(vec_sub(vx7, vmask)); - *reinterpret_cast(&D[48]) = vx0; - vsum = vec_sum4s(vx0, vsum); - D += 16 * 4; a += 16; y -= 16; } + if (CountM == 1) { + vsum[0] += (vsum[1] + vsum[2] + vsum[3]); + } while (y >= 4) { Vtype vb = vmask; @@ -496,7 +498,11 @@ MlasGemmQuantCopyPackA8x8( reinterpret_cast(vec_sub(reinterpret_cast(vx1), vmask)); *reinterpret_cast(&D[0]) = vx; vsum = vec_sum4s(vx, vsum); - D += 16; + if (CountM == 1) { + D += 4; + } + else + D += 16; a += 4; y -= 4; } @@ -1059,6 +1065,186 @@ MlasQgemmComputeMMA( } } }; + +MLAS_FORCEINLINE +void +MlasGemmQuantKernel_M1( + const MLAS_GEMM_QUANT_KERNEL_POWER10::PackedAType *A, + const MLAS_GEMM_QUANT_KERNEL_POWER10::PackedBType *B, + int32_t *C, + size_t PackedCountK, + size_t CountN, + size_t ldc, + const int32_t *RowSumBuffer, + const int32_t *ColumnSumBuffer, + const int32_t *ZeroPointB, + bool ZeroMode +) +{ + size_t Mval = 1; + while (CountN > 0) { + const int8_t *a = A; + typedef __vector unsigned char vec_t; + typedef __vector signed char svec_t; + const uint8_t *b = B; + MLAS_INT32X4 result = {0}; + __vector signed int VecC = {0, 0, 0, 0}; + __vector signed int VecC2 = {0, 0, 0, 0}; + __vector signed int VecC3 = {0, 0, 0, 0}; + __vector signed int VecC4 = {0, 0, 0, 0}; + size_t k = PackedCountK * MLAS_GEMM_QUANT_KERNEL_POWER10::PackedK; + size_t k1 = PackedCountK; + __vector unsigned char va[4]; + __vector unsigned char pat = {0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}; + __vector unsigned char pat2 = {4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7}; + __vector unsigned char pat3 = {8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11}; + __vector unsigned char pat4 = {12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15}; + while (k >= 16) { + const vec_t *vecA = reinterpret_cast(a); + const vec_t *vb = reinterpret_cast(b); + va[0] = vec_perm(vecA[0], vecA[0], pat); + va[1] = vec_perm(vecA[0], vecA[0], pat2); + va[2] = vec_perm(vecA[0], vecA[0], 
pat3); + va[3] = vec_perm(vecA[0], vecA[0], pat4); + VecC = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC); + VecC = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC); + VecC = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC); + VecC = vec_msum((svec_t)va[3], (vec_t)vb[3], VecC); + vb = reinterpret_cast(&b[k1 * 16]); + VecC2 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC2); + VecC2 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC2); + VecC2 = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC2); + VecC2 = vec_msum((svec_t)va[3], (vec_t)vb[3], VecC2); + vb = reinterpret_cast(&b[k1 * 32]); + VecC3 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC3); + VecC3 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC3); + VecC3 = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC3); + VecC3 = vec_msum((svec_t)va[3], (vec_t)vb[3], VecC3); + vb = reinterpret_cast(&b[k1 * 48]); + VecC4 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC4); + VecC4 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC4); + VecC4 = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC4); + VecC4 = vec_msum((svec_t)va[3], (vec_t)vb[3], VecC4); + b += 64; + a += 16; + k -= 16; + } + if (k >= 12) { + const vec_t *vecA = reinterpret_cast(a); + const vec_t *vb = reinterpret_cast(b); + va[0] = vec_perm(vecA[0], vecA[0], pat); + va[1] = vec_perm(vecA[0], vecA[0], pat2); + va[2] = vec_perm(vecA[0], vecA[0], pat3); + VecC = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC); + VecC = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC); + VecC = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC); + vb = reinterpret_cast(&b[k1 * 16]); + VecC2 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC2); + VecC2 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC2); + VecC2 = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC2); + vb = reinterpret_cast(&b[k1 * 32]); + VecC3 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC3); + VecC3 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC3); + VecC3 = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC3); + vb = reinterpret_cast(&b[k1 * 48]); + VecC4 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC4); + VecC4 = 
vec_msum((svec_t)va[1], (vec_t)vb[1], VecC4); + VecC4 = vec_msum((svec_t)va[2], (vec_t)vb[2], VecC4); + a += 12; + b += 48; + k -= 12; + } + if (k >= 8) { + const vec_t *vecA = reinterpret_cast(a); + const vec_t *vb = reinterpret_cast(b); + va[0] = vec_perm(vecA[0], vecA[0], pat); + va[1] = vec_perm(vecA[0], vecA[0], pat2); + VecC = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC); + VecC = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC); + vb = reinterpret_cast(&b[k1 * 16]); + VecC2 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC2); + VecC2 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC2); + vb = reinterpret_cast(&b[k1 * 32]); + VecC3 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC3); + VecC3 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC3); + vb = reinterpret_cast(&b[k1 * 48]); + VecC4 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC4); + VecC4 = vec_msum((svec_t)va[1], (vec_t)vb[1], VecC4); + a += 8; + b += 32; + k -= 8; + } + if (k >= 4) { + const vec_t *vecA = reinterpret_cast(a); + const vec_t *vb = reinterpret_cast(b); + va[0] = vec_perm(vecA[0], vecA[0], pat); + VecC = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC); + vb = reinterpret_cast(&b[k1 * 16]); + VecC2 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC2); + vb = reinterpret_cast(&b[k1 * 32]); + VecC3 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC3); + vb = reinterpret_cast(&b[k1 * 48]); + VecC4 = vec_msum((svec_t)va[0], (vec_t)vb[0], VecC4); + a += 4; + b += 16; + k -= 4; + } + if (CountN >= 16) { + MlasQgemmStoreVectorMMA<0>(&VecC, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); + MlasQgemmStoreVectorMMA<4>(&VecC2, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4); + MlasQgemmStoreVectorMMA<8>(&VecC3, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 8); + MlasQgemmStoreVectorMMA<12>(&VecC4, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 12); + INC_BUFFER(16); + CountN -= 16; + B += 16 * 4 * PackedCountK; + C += 16; + } else { + if (CountN >= 12) { + 
MlasQgemmStoreVectorMMA<0>(&VecC, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); + MlasQgemmStoreVectorMMA<4>(&VecC2, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4); + MlasQgemmStoreVectorMMA<8>(&VecC3, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 8); + INC_BUFFER(12); + if (CountN - 12 > 0) + result = VecC4; + CountN -= 12; + C += 12; + } else if (CountN >= 8) { + MlasQgemmStoreVectorMMA<0>(&VecC, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); + MlasQgemmStoreVectorMMA<4>(&VecC2, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4); + INC_BUFFER(8); + if (CountN - 8 > 0) + result = VecC3; + CountN -= 8; + C += 8; + } else if (CountN >= 4) { + MlasQgemmStoreVectorMMA<0>(&VecC, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); + INC_BUFFER(4); + if (CountN - 4 > 0) + result = VecC2; + CountN -= 4; + C += 4; + } else + result = VecC; + CountN &= 3; + + // Output the remaining partial output block. 
+ if (CountN > 0) { + MlasQgemmStoreScalarMMA<0>(&result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB); + INC_BUFFER(1); + } + if (CountN >= 2) { + MlasQgemmStoreScalarMMA<1>(&result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB); + INC_BUFFER(1); + } + if (CountN >= 3) { + MlasQgemmStoreScalarMMA<2>(&result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB); + INC_BUFFER(1); + } + CountN = 0; + } + } +} + template<> size_t MlasGemmQuantKernel( @@ -1075,6 +1261,10 @@ MlasGemmQuantKernel( bool ZeroMode ) { + if (CountM == 1) { + MlasGemmQuantKernel_M1(A, B, C, PackedCountK, CountN, ldc, RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); + return 1; + } if (CountM < 8 && CountM >= 4) { CountM = 4; } diff --git a/src/lib/q4gemm.h b/src/lib/q4gemm.h index d16798e..89528fd 100644 --- a/src/lib/q4gemm.h +++ b/src/lib/q4gemm.h @@ -126,7 +126,7 @@ MlasQ4GemmOperation( size_t RowsRemaining = RangeCountM; while (RowsRemaining > 0) { -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) || defined(MLAS_TARGET_S390X) auto RowsHandled = GetMlasPlatform().GemmFloatKernel( a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true); #else diff --git a/src/lib/q4gemm_avx512.cpp b/src/lib/q4gemm_avx512.cpp index f7af82e..72d2a62 100644 --- a/src/lib/q4gemm_avx512.cpp +++ b/src/lib/q4gemm_avx512.cpp @@ -18,6 +18,7 @@ Module Name: --*/ #include "q4gemm.h" +#include #include #include diff --git a/src/lib/qgemm.cpp b/src/lib/qgemm.cpp index f5b33d2..62a23c8 100644 --- a/src/lib/qgemm.cpp +++ b/src/lib/qgemm.cpp @@ -14,10 +14,16 @@ Module Name: operation (QGEMM). 
--*/ - +#include #include "mlasi.h" #include "qgemm.h" +// TODO: When overrides are implemented, remove this +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) +#include "kleidiai/mlasi_kleidiai.h" +#endif + + // // Define the parameters to execute segments of a QGEMM operation on worker // threads. @@ -195,6 +201,26 @@ MlasGemmBatch( }); } +void +MLASCALL +MlasDynamicQGemmBatch ( + const MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS& Shape, + const MLAS_GEMM_DYN_QUANT_DATA_PARAMS* DataParams, + const size_t BatchN, + MLAS_THREADPOOL* ThreadPool +) { +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + //No fallback and putting in guards + if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){ + ArmKleidiAI::MlasDynamicQGemmBatch(Shape, DataParams, BatchN, ThreadPool); + } +#endif + + MLAS_UNREFERENCED_PARAMETER(Shape); + MLAS_UNREFERENCED_PARAMETER(DataParams); + MLAS_UNREFERENCED_PARAMETER(BatchN); + MLAS_UNREFERENCED_PARAMETER(ThreadPool); +} int32_t MlasSymmQgemmGetKernelOutputCnt() @@ -293,10 +319,35 @@ MlasSymmQgemmBatch( }); } + + #if defined(_MSC_VER) && !defined(__clang__) #pragma warning(pop) #endif +size_t +MLASCALL +MlasDynamicQgemmPackBSize( + size_t N, + size_t K +) +{ + size_t bytes = 0; +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + //No fallback available + //TODO: Insert Override + if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){//Still require this since no override + bytes = ArmKleidiAI::MlasDynamicQgemmPackBSize(N, K); + } +#endif + + MLAS_UNREFERENCED_PARAMETER(N); + MLAS_UNREFERENCED_PARAMETER(K); + + return bytes; +} + + size_t MLASCALL MlasGemmPackBSize( @@ -354,10 +405,38 @@ Return Value: const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); const size_t AlignedBytesRequired = (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1); + //If this gemm B argument is used in a dynamically quantization gemm operation we can optimize for + //this use case. Concat both packed representations for later decision. 
+ return AlignedBytesRequired + MlasDynamicQgemmPackBSize(N, K); +} - return AlignedBytesRequired; +void +MLASCALL +MlasDynamicQgemmPackB( + size_t N, + size_t K, + const int8_t* B, + const float* Scales, + const float* Bias, + void* PackedB +) +{ +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + //No fallback + if(MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()){//Still require this since no override + ArmKleidiAI::MlasDynamicQgemmPackB(N, K, B, Scales, Bias, PackedB); + } +#endif + + MLAS_UNREFERENCED_PARAMETER(N); + MLAS_UNREFERENCED_PARAMETER(K); + MLAS_UNREFERENCED_PARAMETER(B); + MLAS_UNREFERENCED_PARAMETER(Scales); + MLAS_UNREFERENCED_PARAMETER(Bias); + MLAS_UNREFERENCED_PARAMETER(PackedB); } + void MLASCALL MlasGemmPackB( @@ -400,7 +479,6 @@ Return Value: // // Retrieve the packing parameters. // - const auto* GemmQuantDispatch = MlasGemmQuantGetDispatch(AIsSigned, BIsSigned); size_t PackedK = GemmQuantDispatch->PackedK; @@ -515,7 +593,6 @@ MlasSymmQgemmPackBSize( #pragma warning(pop) #endif - void MLASCALL MlasSymmQgemmPackB( diff --git a/src/lib/qgemm.h b/src/lib/qgemm.h index 596267c..2730b1a 100644 --- a/src/lib/qgemm.h +++ b/src/lib/qgemm.h @@ -909,6 +909,10 @@ MlasGemmQuantGetDispatch( GemmQuantDispatch = BIsSigned ? 
GetMlasPlatform().GemmU8S8Dispatch : GetMlasPlatform().GemmU8U8Dispatch; } +#elif defined(MLAS_TARGET_S390X) + if (GetMlasPlatform().GemmU8X8Dispatch == &MlasGemm8X8DispatchZVECTOR) { + GemmQuantDispatch = GetMlasPlatform().GemmU8X8Dispatch; + } #endif #endif // !defined(FORCE_GENERIC_ALGORITHMS) diff --git a/src/lib/qgemm_kernel_amx.cpp b/src/lib/qgemm_kernel_amx.cpp index 5d4f1fd..04f1762 100644 --- a/src/lib/qgemm_kernel_amx.cpp +++ b/src/lib/qgemm_kernel_amx.cpp @@ -19,6 +19,7 @@ Module Name: #include "amx_common.h" #include + #define TMM0 0 #define TMM1 1 #define TMM2 2 diff --git a/src/lib/qgemm_kernel_wasmrelaxedsimd.cpp b/src/lib/qgemm_kernel_wasmrelaxedsimd.cpp index a3a0fa7..56c67aa 100644 --- a/src/lib/qgemm_kernel_wasmrelaxedsimd.cpp +++ b/src/lib/qgemm_kernel_wasmrelaxedsimd.cpp @@ -89,7 +89,7 @@ MlasGemmQuantCopyPackA( { MLAS_UNREFERENCED_PARAMETER(AIsSigned); const v128_t ZeroVector = wasm_i64x2_const(0, 0); - const v128_t OnesWordBroadcast = wasm_i16x8_splat(1); + const v128_t OnesByteBroadcast = wasm_i8x16_splat(1); uint8_t PaddedMatrixAData[8] = { 0 }; // @@ -109,19 +109,23 @@ MlasGemmQuantCopyPackA( // but CountK is aligned up to a multiple of 4 to maintain 32-bit // alignment. All extra bytes are zero-padded. // - // Zero extend the source bytes to 16-bits and accumulate - // into an intermediate per-row - // accumulator. CountK cannot be greater than 128 to avoid overflowing - // these signed 16-bit accumulators. - // + // Accumulate into an intermediate per-row accumulator. 
- while (k >= 8) { + while (k >= 16) { - v128_t Bytes = wasm_v128_load64_zero(&a[0]); - v128_t Words = wasm_i8x16_unpacklo_relaxed(Bytes, ZeroVector); + v128_t Bytes = wasm_v128_load(&a[0]); + ReductionVector = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(OnesByteBroadcast, Bytes, ReductionVector); + + wasm_v128_store(&D[0], Bytes); - ReductionVector = wasm_i16x8_add(ReductionVector, Words); + a += 16; + D += 16; + k -= 16; + } + if (k >= 8) { + v128_t Bytes = wasm_v128_load64_zero(&a[0]); + ReductionVector = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(OnesByteBroadcast, Bytes, ReductionVector); wasm_v128_store64_lane(&D[0], Bytes, 0); a += 8; @@ -145,9 +149,7 @@ MlasGemmQuantCopyPackA( } while (padded < padded_end); v128_t Bytes = wasm_v128_load64_zero(PaddedMatrixAData); - v128_t Words = wasm_i8x16_unpacklo_relaxed(Bytes, ZeroVector); - - ReductionVector = wasm_i16x8_add(ReductionVector, Words); + ReductionVector = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(OnesByteBroadcast, Bytes, ReductionVector); // // Copy quads of 8-bit values from the vector to the packed @@ -165,7 +167,6 @@ MlasGemmQuantCopyPackA( // Reduce the partial accumulators. 
// - ReductionVector = wasm_i32x4_dot_i16x8(ReductionVector, OnesWordBroadcast); ReductionVector = wasm_i32x4_add(ReductionVector, wasm_i32x4_shuffle(ReductionVector, wasm_i32x4_splat(0), 2, 3, 2, 3)); ReductionVector = wasm_i32x4_add(ReductionVector, @@ -376,181 +377,215 @@ MlasGemmQuantCopyPackB( } } -MLAS_FORCEINLINE -void -MlasGemmU8X8MultiplyAccumulateRowWasmRelaxedSimd( - v128_t ABroadcast, - const uint8_t* B, - v128_t Accumulators[2] -) -{ - v128_t BElements0 = wasm_v128_load(&B[0]); - v128_t BElements1 = wasm_v128_load(&B[16]); - Accumulators[0] = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(BElements0, ABroadcast, Accumulators[0]); - Accumulators[1] = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(BElements1, ABroadcast, Accumulators[1]); +//-------------------------------------------------------------------------- +// Small helper that performs one (A row) × (8 B columns) FMA step. +//-------------------------------------------------------------------------- +MLAS_FORCEINLINE void DotPairAdd(v128_t ABroadcast, + v128_t BVec0, + v128_t BVec1, + v128_t Acc[2]) { + Acc[0] = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(BVec0, ABroadcast, Acc[0]); + Acc[1] = wasm_i32x4_relaxed_dot_i8x16_i7x16_add(BVec1, ABroadcast, Acc[1]); } +//-------------------------------------------------------------------------- +// Generic RowCount×8 kernel implementation (RowCount is 6 or 1 at compile‑time) +//-------------------------------------------------------------------------- -template<> -size_t -MlasGemmQuantKernel( +template +static size_t GemmQuantKernelNx8Impl( const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedAType* A, const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedBType* B, int32_t* C, size_t PackedCountK, - size_t CountM, + size_t /*CountM — ignored*/, size_t CountN, size_t ldc, const int32_t* RowSumBuffer, const int32_t* ColumnSumBuffer, const int32_t* ZeroPointB, - bool ZeroMode - ) + bool ZeroMode) { - MLAS_UNREFERENCED_PARAMETER(CountM); - MLAS_UNREFERENCED_PARAMETER(ldc); + 
constexpr size_t ColBlock = 8; + const auto PackedK = MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedK; + + // Build row‑wise pointer tables (a[r] & c[r]). + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedAType* a[RowCount]; + int32_t* c[RowCount]; + for (size_t r = 0; r < RowCount; ++r) { + a[r] = (const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedAType*)(A + r * PackedK * PackedCountK); + c[r] = (int32_t*)(C + r * ldc); + } while (CountN > 0) { - - v128_t Accumulators[2]; - - // - // Initialize the accumulators with the row and column sums. - // - - int32_t RowSumValue = RowSumBuffer[0]; - - if (ZeroPointB != nullptr) { - - int32_t ScaledRowSumBuffer[8]; - - for (size_t i = 0; i < 8; i++) { - ScaledRowSumBuffer[i] = RowSumValue * ZeroPointB[i]; - } - + // ------------------------------------------------------------------ + // 1) Initialize accumulators with row & column sums (and zero‑points) + // ------------------------------------------------------------------ + v128_t Acc[RowCount][2]; + + if (ZeroPointB) { + v128_t zp0 = wasm_v128_load(ZeroPointB + 0); + v128_t zp1 = wasm_v128_load(ZeroPointB + 4); ZeroPointB += 8; - Accumulators[0] = wasm_v128_load(&ScaledRowSumBuffer[0]); - Accumulators[1] = wasm_v128_load(&ScaledRowSumBuffer[4]); - + for (size_t r = 0; r < RowCount; ++r) { + v128_t RowSumValues = wasm_v128_load32_splat(RowSumBuffer + r); + Acc[r][0] = wasm_i32x4_mul(RowSumValues, zp0); + Acc[r][1] = wasm_i32x4_mul(RowSumValues, zp1); + } + } else { + for (size_t r = 0; r < RowCount; ++r) { + Acc[r][0] = wasm_v128_load32_splat(RowSumBuffer + r); + Acc[r][1] = Acc[r][0]; + } } - else { - Accumulators[0] = wasm_i32x4_splat(RowSumValue); - Accumulators[1] = Accumulators[0]; + v128_t col0 = wasm_v128_load(ColumnSumBuffer + 0); // first 4 col sums + v128_t col1 = wasm_v128_load(ColumnSumBuffer + 4); // next 4 col sums + for (size_t r = 0; r < RowCount; ++r) { + Acc[r][0] = wasm_i32x4_add(Acc[r][0], col0); + Acc[r][1] = wasm_i32x4_add(Acc[r][1], col1); } - - 
Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load(&ColumnSumBuffer[0])); - Accumulators[1] = wasm_i32x4_add(Accumulators[1], wasm_v128_load(&ColumnSumBuffer[4])); ColumnSumBuffer += 8; - // - // Broadcast each pair of 16-bit values from the matrix A and multiply - // with the pair of 16-bit values from matrix B, and add the 32-bit + // ------------------------------------------------------------------ + // 2) Broadcast each pair of 8-bit values from the matrix A and multiply + // with the pair of 8-bit values from matrix B, and add the 32-bit // intermediate into the accumulator registers. - // - - const uint8_t* a = A; + // ------------------------------------------------------------------ size_t k = PackedCountK; - - while (k >= 4) { - - v128_t AElements = wasm_v128_load((v128_t*)a); - v128_t ABroadcast; - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 0, 0, 0, 0); - MlasGemmU8X8MultiplyAccumulateRowWasmRelaxedSimd(ABroadcast, &B[0], Accumulators); - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 1, 1, 1, 1); - MlasGemmU8X8MultiplyAccumulateRowWasmRelaxedSimd(ABroadcast, &B[32], Accumulators); - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 2, 2, 2, 2); - MlasGemmU8X8MultiplyAccumulateRowWasmRelaxedSimd(ABroadcast, &B[64], Accumulators); - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 3, 3, 3, 3); - MlasGemmU8X8MultiplyAccumulateRowWasmRelaxedSimd(ABroadcast, &B[96], Accumulators); - - a += 4 * 4; - B += 4 * 32; - k -= 4; - } - while (k > 0) { + v128_t ABroadcast[RowCount]; + for (size_t r = 0; r < RowCount; ++r) { + ABroadcast[r] = wasm_v128_load32_splat(a[r]); // broadcast 4 × u8 + a[r] += 4; + } - v128_t ABroadcast = wasm_i32x4_splat(*((int32_t*)a)); - MlasGemmU8X8MultiplyAccumulateRowWasmRelaxedSimd(ABroadcast, &B[0], Accumulators); - - a += 4; + v128_t B0 = wasm_v128_load(&B[0]); + v128_t B1 = wasm_v128_load(&B[16]); + for (size_t r = 0; r < RowCount; ++r) { + 
DotPairAdd(ABroadcast[r], B0, B1, Acc[r]); + } B += 32; k -= 1; } - // - // Output the accumulator block after optionally accumulating the values + // ------------------------------------------------------------------ + // 3) Output the accumulator block after optionally accumulating the values // from matrix C. - // + // ------------------------------------------------------------------ if (CountN >= 8) { - - if (!ZeroMode) { - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load(&C[0])); - Accumulators[1] = wasm_i32x4_add(Accumulators[1], wasm_v128_load(&C[4])); - } - - wasm_v128_store(&C[0], Accumulators[0]); - wasm_v128_store(&C[4], Accumulators[1]); - - C += 8; - CountN -= 8; - - } - else { - - // - // Output the remaining partial output block. - // - - if ((CountN & 4) != 0) { - + // ---- Full 8‑column tile ---- + for (size_t r = 0; r < RowCount; ++r) { if (!ZeroMode) { - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load(&C[0])); + Acc[r][0] = wasm_i32x4_add(Acc[r][0], wasm_v128_load(c[r] + 0)); + Acc[r][1] = wasm_i32x4_add(Acc[r][1], wasm_v128_load(c[r] + 4)); } - - wasm_v128_store(&C[0], Accumulators[0]); - C += 4; - - Accumulators[0] = Accumulators[1]; + wasm_v128_store(c[r] + 0, Acc[r][0]); + wasm_v128_store(c[r] + 4, Acc[r][1]); + a[r] -= PackedCountK * 4; // Rewind a[r] for next N‑tile (PackedCountK * 4 bytes each). 
+ c[r] += ColBlock; } - - if ((CountN & 2) != 0) { - - if (!ZeroMode) { - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load64_zero(&C[0])); + CountN -= ColBlock; + } else { + // ---- 4/2/1‑column tails ---- + auto Tail = [&](size_t cols, auto load_c, auto store_c) { + for (size_t r = 0; r < RowCount; ++r) { + if (!ZeroMode) Acc[r][0] = wasm_i32x4_add(Acc[r][0], load_c(c[r])); } - - wasm_v128_store64_lane(&C[0], Accumulators[0], 0); - C += 2; - - Accumulators[0] = wasm_i32x4_shuffle(Accumulators[0], wasm_i32x4_splat(0), 2, 3, 2, 3); + for (size_t r = 0; r < RowCount; ++r) store_c(c[r], Acc[r][0]); + for (size_t r = 0; r < RowCount; ++r) c[r] += cols; + }; + + if (CountN & 4) { + Tail(4, + [](int32_t* p) { return wasm_v128_load(p); }, + [](int32_t* p, v128_t v) { wasm_v128_store(p, v); }); + for (size_t r = 0; r < RowCount; ++r) Acc[r][0] = Acc[r][1]; } - - if ((CountN & 1) != 0) { - - int32_t AccumulatorValue = wasm_i32x4_extract_lane(Accumulators[0], 0); - - if (!ZeroMode) { - AccumulatorValue += C[0]; + if (CountN & 2) { + Tail(2, + [](int32_t* p) { return wasm_v128_load64_zero(p); }, + [](int32_t* p, v128_t v) { wasm_v128_store64_lane(p, v, 0); }); + for (size_t r = 0; r < RowCount; ++r) + Acc[r][0] = wasm_i32x4_shuffle(Acc[r][0], wasm_i32x4_splat(0), 2, 3, 2, 3); + } + if (CountN & 1) { + for (size_t r = 0; r < RowCount; ++r) { + int32_t v = wasm_i32x4_extract_lane(Acc[r][0], 0); + if (!ZeroMode) v += *c[r]; + *c[r] = v; } - - C[0] = AccumulatorValue; } - CountN = 0; } } + return RowCount; +} + + +size_t MlasGemmQuantKernel6x8( + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode) { + MLAS_UNREFERENCED_PARAMETER(CountM); + return GemmQuantKernelNx8Impl<6>(A, B, C, PackedCountK, 0, 
CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); +} - return 1; +size_t MlasGemmQuantKernel1x8( + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode) { + MLAS_UNREFERENCED_PARAMETER(CountM); + return GemmQuantKernelNx8Impl<1>(A, B, C, PackedCountK, 0, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); +} + + +template <> +size_t +MlasGemmQuantKernel( + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode +) +{ + size_t RowsHandled = 0; + if (CountM >= 6) { + RowsHandled = MlasGemmQuantKernel6x8(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); + } else { + RowsHandled = MlasGemmQuantKernel1x8(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); + } + return RowsHandled; } const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmRelaxedSimd = { @@ -559,5 +594,5 @@ const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmRelaxedSimd = { nullptr, MLAS_GEMM_U8X8_KERNEL_WASMRELAXEDSIMD::PackedK, 0, - 4 // multiple of kernel stride M + 6 // multiple of kernel stride M }; diff --git a/src/lib/qgemm_kernel_wasmsimd.cpp b/src/lib/qgemm_kernel_wasmsimd.cpp index 1f33d77..84f6c6b 100644 --- a/src/lib/qgemm_kernel_wasmsimd.cpp +++ b/src/lib/qgemm_kernel_wasmsimd.cpp @@ -322,181 +322,209 @@ MlasGemmQuantCopyPackB( } } -MLAS_FORCEINLINE -void -MlasGemmU8X8MultiplyAccumulateRowWasmSimd( - v128_t ABroadcast, - const int16_t* B, 
- v128_t Accumulators[2] -) -{ - v128_t BElements0 = wasm_v128_load(&B[0]); - v128_t BElements1 = wasm_v128_load(&B[8]); - - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_i32x4_dot_i16x8(BElements0, ABroadcast)); - Accumulators[1] = wasm_i32x4_add(Accumulators[1], wasm_i32x4_dot_i16x8(BElements1, ABroadcast)); +//------------------------------------------------------------------ +// Helper – dot‑product add for i16×i16 → i32 pairs. +//------------------------------------------------------------------ +MLAS_FORCEINLINE void DotPairAddI16(v128_t ABroadcast, + v128_t BVec0, + v128_t BVec1, + v128_t Acc[2]) { + Acc[0] = wasm_i32x4_add(Acc[0], wasm_i32x4_dot_i16x8(BVec0, ABroadcast)); + Acc[1] = wasm_i32x4_add(Acc[1], wasm_i32x4_dot_i16x8(BVec1, ABroadcast)); } +//------------------------------------------------------------------ +// Generic RowCount×8 kernel (RowCount = 4 or 1) for WASM SIMD. +//------------------------------------------------------------------ -template<> -size_t -MlasGemmQuantKernel( +template +static size_t GemmQuantKernelNx8Impl( const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedAType* A, const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedBType* B, int32_t* C, size_t PackedCountK, - size_t CountM, + size_t /*CountM — ignored*/, size_t CountN, size_t ldc, const int32_t* RowSumBuffer, const int32_t* ColumnSumBuffer, const int32_t* ZeroPointB, - bool ZeroMode - ) + bool ZeroMode) { - MLAS_UNREFERENCED_PARAMETER(CountM); - MLAS_UNREFERENCED_PARAMETER(ldc); - - while (CountN > 0) { - - v128_t Accumulators[2]; - // - // Initialize the accumulators with the row and column sums. 
- // - - int32_t RowSumValue = RowSumBuffer[0]; - - if (ZeroPointB != nullptr) { + constexpr size_t ColBlock = 8; + const auto PackedK = MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedK; // ==2 - int32_t ScaledRowSumBuffer[8]; - - for (size_t i = 0; i < 8; i++) { - ScaledRowSumBuffer[i] = RowSumValue * ZeroPointB[i]; - } + const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedAType* a[RowCount]; + int32_t* c[RowCount]; + for (size_t r = 0; r < RowCount; ++r) { + a[r] = (const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedAType*)(A + r * PackedK * PackedCountK); + c[r] = (int32_t*)(C + r * ldc); + } + while (CountN > 0) { + // ------------------------------------------------------------------ + // 1) Initialize accumulators with row & column sums (and zero‑points) + // ------------------------------------------------------------------ + v128_t Acc[RowCount][2]; + v128_t col0 = wasm_v128_load(ColumnSumBuffer + 0); + v128_t col1 = wasm_v128_load(ColumnSumBuffer + 4); + + if (ZeroPointB) { + v128_t zp0 = wasm_v128_load(ZeroPointB + 0); + v128_t zp1 = wasm_v128_load(ZeroPointB + 4); ZeroPointB += 8; - Accumulators[0] = wasm_v128_load(&ScaledRowSumBuffer[0]); - Accumulators[1] = wasm_v128_load(&ScaledRowSumBuffer[4]); - - } - else { - - Accumulators[0] = wasm_i32x4_splat(RowSumValue); - Accumulators[1] = Accumulators[0]; + for (size_t r = 0; r < RowCount; ++r) { + v128_t RowSumValues = wasm_v128_load32_splat(RowSumBuffer + r); + Acc[r][0] = wasm_i32x4_add(wasm_i32x4_mul(RowSumValues, zp0), col0); + Acc[r][1] = wasm_i32x4_add(wasm_i32x4_mul(RowSumValues, zp1), col1); + } + } else { + for (size_t r = 0; r < RowCount; ++r) { + v128_t RowSumValues = wasm_v128_load32_splat(RowSumBuffer + r); + Acc[r][0] = wasm_i32x4_add(RowSumValues, col0); + Acc[r][1] = wasm_i32x4_add(RowSumValues, col1); + } } - - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load(&ColumnSumBuffer[0])); - Accumulators[1] = wasm_i32x4_add(Accumulators[1], wasm_v128_load(&ColumnSumBuffer[4])); ColumnSumBuffer += 8; - // - 
// Broadcast each pair of 16-bit values from the matrix A and multiply + // ---------------------------------------------------------------------- + // 2) Broadcast each pair of 16-bit values from the matrix A and multiply // with the pair of 16-bit values from matrix B, and add the 32-bit // intermediate into the accumulator registers. - // - - const int16_t* a = A; + // ---------------------------------------------------------------------- size_t k = PackedCountK; + while (k > 0) { + v128_t ABroadcast[RowCount]; + for (size_t r = 0; r < RowCount; ++r) { + ABroadcast[r] = wasm_v128_load32_splat(a[r]); + a[r] += 2; + } - while (k >= 4) { - - v128_t AElements = wasm_v128_load((v128_t*)a); - v128_t ABroadcast; - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 0, 0, 0, 0); - MlasGemmU8X8MultiplyAccumulateRowWasmSimd(ABroadcast, &B[0], Accumulators); - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 1, 1, 1, 1); - MlasGemmU8X8MultiplyAccumulateRowWasmSimd(ABroadcast, &B[16], Accumulators); - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 2, 2, 2, 2); - MlasGemmU8X8MultiplyAccumulateRowWasmSimd(ABroadcast, &B[32], Accumulators); - - ABroadcast = wasm_i32x4_shuffle(AElements, wasm_i32x4_splat(0), 3, 3, 3, 3); - MlasGemmU8X8MultiplyAccumulateRowWasmSimd(ABroadcast, &B[48], Accumulators); - - a += 4 * 2; - B += 4 * 16; - k -= 4; - } + v128_t B0 = wasm_v128_load(B + 0); // cols 0‑3 (8 i16) + v128_t B1 = wasm_v128_load(B + 8); // cols 4‑7 (8 i16) - while (k > 0) { - v128_t ABroadcast = wasm_i32x4_splat(*((int32_t*)a)); - MlasGemmU8X8MultiplyAccumulateRowWasmSimd(ABroadcast, &B[0], Accumulators); + for (size_t r = 0; r < RowCount; ++r) { + DotPairAddI16(ABroadcast[r], B0, B1, Acc[r]); + } - a += 2; B += 16; k -= 1; } - // - // Output the accumulator block after optionally accumulating the values + // ------------------------------------------------------------------ + // 3) Output the accumulator block after optionally 
accumulating the values // from matrix C. - // - + // ------------------------------------------------------------------ if (CountN >= 8) { - - if (!ZeroMode) { - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load(&C[0])); - Accumulators[1] = wasm_i32x4_add(Accumulators[1], wasm_v128_load(&C[4])); - } - - wasm_v128_store(&C[0], Accumulators[0]); - wasm_v128_store(&C[4], Accumulators[1]); - - C += 8; - CountN -= 8; - - } - else { - - // - // Output the remaining partial output block. - // - - if ((CountN & 4) != 0) { - + for (size_t r = 0; r < RowCount; ++r) { if (!ZeroMode) { - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load(&C[0])); + Acc[r][0] = wasm_i32x4_add(Acc[r][0], wasm_v128_load(c[r] + 0)); + Acc[r][1] = wasm_i32x4_add(Acc[r][1], wasm_v128_load(c[r] + 4)); } - - wasm_v128_store(&C[0], Accumulators[0]); - C += 4; - - Accumulators[0] = Accumulators[1]; + wasm_v128_store(c[r] + 0, Acc[r][0]); + wasm_v128_store(c[r] + 4, Acc[r][1]); + c[r] += ColBlock; + a[r] -= PackedCountK * 2; // Rewind a[r] for next N-tile (PackedCountK * 2 elements, 16-bit each). 
} - - if ((CountN & 2) != 0) { - - if (!ZeroMode) { - Accumulators[0] = wasm_i32x4_add(Accumulators[0], wasm_v128_load64_zero(&C[0])); + CountN -= 8; + } else { + // ---- 4/2/1‑column tails ---- + auto Tail = [&](size_t cols, auto load_c, auto store_c) { + for (size_t r = 0; r < RowCount; ++r) { + if (!ZeroMode) Acc[r][0] = wasm_i32x4_add(Acc[r][0], load_c(c[r])); } - - wasm_v128_store64_lane(&C[0], Accumulators[0], 0); - C += 2; - - Accumulators[0] = wasm_i32x4_shuffle(Accumulators[0], wasm_i32x4_splat(0), 2, 3, 2, 3); + for (size_t r = 0; r < RowCount; ++r) store_c(c[r], Acc[r][0]); + for (size_t r = 0; r < RowCount; ++r) c[r] += cols; + }; + + if (CountN & 4) { + Tail(4, + [](int32_t* p) { return wasm_v128_load(p); }, + [](int32_t* p, v128_t v) { wasm_v128_store(p, v); }); + for (size_t r = 0; r < RowCount; ++r) Acc[r][0] = Acc[r][1]; } - - if ((CountN & 1) != 0) { - - int32_t AccumulatorValue = wasm_i32x4_extract_lane(Accumulators[0], 0); - - if (!ZeroMode) { - AccumulatorValue += C[0]; + if (CountN & 2) { + Tail(2, + [](int32_t* p) { return wasm_v128_load64_zero(p); }, + [](int32_t* p, v128_t v) { wasm_v128_store64_lane(p, v, 0); }); + for (size_t r = 0; r < RowCount; ++r) + Acc[r][0] = wasm_i32x4_shuffle(Acc[r][0], wasm_i32x4_splat(0), 2, 3, 2, 3); + } + if (CountN & 1) { + for (size_t r = 0; r < RowCount; ++r) { + int32_t v = wasm_i32x4_extract_lane(Acc[r][0], 0); + if (!ZeroMode) v += *c[r]; + *c[r] = v; } - - C[0] = AccumulatorValue; } - CountN = 0; } } + return RowCount; +} + +size_t MlasGemmQuantKernel4x8( + const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode) { + MLAS_UNREFERENCED_PARAMETER(CountM); + return GemmQuantKernelNx8Impl<4>(A, B, C, PackedCountK, 0, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, 
ZeroPointB, ZeroMode); +} + +size_t MlasGemmQuantKernel1x8( + const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode) { + MLAS_UNREFERENCED_PARAMETER(CountM); + return GemmQuantKernelNx8Impl<1>(A, B, C, PackedCountK, 0, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); +} - return 1; +template<> +size_t +MlasGemmQuantKernel( + const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_WASMSIMD::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode + ) +{ + size_t RowsHandled = 0; + if (CountM >= 4) { + RowsHandled = MlasGemmQuantKernel4x8(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); + } else { + RowsHandled = MlasGemmQuantKernel1x8(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB, ZeroMode); + } + return RowsHandled; } const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmSimd = { diff --git a/src/lib/qladd.cpp b/src/lib/qladd.cpp index 5dafa17..4bfa074 100644 --- a/src/lib/qladd.cpp +++ b/src/lib/qladd.cpp @@ -552,6 +552,113 @@ MlasQLinearAddKernelHelper( InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); } } +#elif defined(MLAS_TARGET_S390X) +template +static +void +MlasQLinearAddKernelHelper( + const DataType* InputA, + float ScaleA, + int32_t ZeroPointA, + const DataType* InputB, + float ScaleB, + int32_t ZeroPointB, + float ScaleC, + int32_t ZeroPointC, + DataType* OutputC, + size_t N + ) +{ + if (N >= 16) { + float ScaleRatio_AC = ScaleA / ScaleC; + float ScaleRatio_BC = ScaleB / 
ScaleC; + MLAS_FLOAT32X4 VectorScaleRatio_AC = MlasBroadcastFloat32x4(ScaleRatio_AC); + MLAS_FLOAT32X4 VectorScaleRatio_BC = MlasBroadcastFloat32x4(ScaleRatio_BC); + MLAS_FLOAT32X4 VectorFixedPart = MlasBroadcastFloat32x4((float)ZeroPointC - (ScaleRatio_AC * ZeroPointA + ScaleRatio_BC * ZeroPointB)); + MLAS_FLOAT32X4 vb0_lo, vb0_hi, vb1_lo, vb1_hi; + const uint8_t flip = 128; + MLAS_UNREFERENCED_PARAMETER(flip); + __vector unsigned char vmask = reinterpret_cast<__vector unsigned char>(vec_splats(flip)); + __vector signed short vmask1 = reinterpret_cast<__vector signed short>(vec_splats((short)flip)); + + if (IsScalarB) { + vb0_lo = MlasBroadcastFloat32x4((float)*InputB); + VectorFixedPart = __builtin_s390_vfmasb(vb0_lo, VectorScaleRatio_BC, VectorFixedPart); + } + while (N >= 16) { + MLAS_INT32X4 r_lo, r_hi; + MLAS_FLOAT32X4 va_lo, va_hi; + MLAS_UNREFERENCED_PARAMETER(VectorScaleRatio_AC); + MLAS_UNREFERENCED_PARAMETER(VectorScaleRatio_BC); + auto va = MlasPackL8(InputA, vmask); + auto vshort = vec_unpackh(va); + vshort = MlasPackS16(vshort, vmask1); + auto va1 = vec_unpackl(vshort); + auto va0 = vec_unpackh(vshort); + va_lo = vec_float(va0); + va_hi = vec_float(va1); + if (!IsScalarB) { + auto vb = MlasPackL8(InputB, vmask); + vshort = vec_unpackh(vb); + vshort = MlasPackS16(vshort, vmask1); + auto vb1 = vec_unpackl(vshort); + auto vb0 = vec_unpackh(vshort); + vb0_lo = vec_float(vb0); + vb0_hi = vec_float(vb1); + vshort = vec_unpackl(vb); + vshort = MlasPackS16(vshort, vmask1); + vb1 = vec_unpackl(vshort); + vb0 = vec_unpackh(vshort); + vb1_lo = vec_float(vb0); + vb1_hi = vec_float(vb1); + InputB += 16; + } + va_lo = va_lo * VectorScaleRatio_AC; + va_hi = va_hi * VectorScaleRatio_AC; + if (IsScalarB) { + r_lo = vec_signed(vec_round(va_lo + VectorFixedPart)); + r_hi = vec_signed(vec_round(va_hi + VectorFixedPart)); + } else { + vb0_lo = vb0_lo * VectorScaleRatio_BC; + vb0_hi = vb0_hi * VectorScaleRatio_BC; + r_lo = vec_signed(vec_round(VectorFixedPart + va_lo + 
vb0_lo)); + r_hi = vec_signed(vec_round(VectorFixedPart + va_hi + vb0_hi)); + } + const auto vc0 = vec_packs(r_lo, r_hi); + vshort = vec_unpackl(va); + vshort = MlasPackS16(vshort, vmask1); + va1 = vec_unpackl(vshort); + va0 = vec_unpackh(vshort); + va_lo = vec_float(va0); + va_hi = vec_float(va1); + va_lo = va_lo * VectorScaleRatio_AC; + va_hi = va_hi * VectorScaleRatio_AC; + if (IsScalarB) { + r_lo = vec_signed(vec_round(VectorFixedPart + va_lo)); + r_hi = vec_signed(vec_round(VectorFixedPart + va_hi)); + } else { + vb1_lo = vb1_lo * VectorScaleRatio_BC; + vb1_hi = vb1_hi * VectorScaleRatio_BC; + r_lo = vec_signed(vec_round(VectorFixedPart + va_lo + vb1_lo)); + r_hi = vec_signed(vec_round(VectorFixedPart + va_hi + vb1_hi)); + } + const auto vc1 = vec_packs(r_lo, r_hi); + MLAS_INT32X4 vc = MlasPackS16_128(vc0, vc1); + vec_xst(vc, 0, reinterpret_cast(OutputC)); + + // Workaround for bad GCC warning that variable is set but not used. + MLAS_UNREFERENCED_PARAMETER(vc); + + N -= 16; + InputA += 16; + OutputC += 16; + } + } + if (N > 0) { + MlasQLinearAddKernelRawHelper( + InputA, ScaleA, ZeroPointA, InputB, ScaleB, ZeroPointB, ScaleC, ZeroPointC, OutputC, N); + } +} #elif defined(MLAS_LSX_INTRINSICS) template diff --git a/src/lib/qladd.h b/src/lib/qladd.h index 9456894..369708e 100644 --- a/src/lib/qladd.h +++ b/src/lib/qladd.h @@ -453,6 +453,101 @@ MlasPackS16_128( return reinterpret_cast(vec_packsu(a, b)); } +template <> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasPackS16_128( + __vector short a, + __vector short b + ) +{ + return reinterpret_cast(vec_packs(a, b)); +} +#elif defined(MLAS_TARGET_S390X) +typedef __vector signed char MLAS_INT8; +typedef __vector short MLAS_SHORT; +template +MLAS_FORCEINLINE +MLAS_INT8 +MlasPackL8( + const DataType* Input, + __vector unsigned char vmask + ); + +template <> +MLAS_FORCEINLINE +MLAS_INT8 +MlasPackL8( + const uint8_t* Input, + __vector unsigned char vmask + ) +{ + __vector unsigned char va = vec_xl(0,Input); + return 
reinterpret_cast(reinterpret_cast<__vector unsigned char>(va) - vmask); +} + +template <> +MLAS_FORCEINLINE +MLAS_INT8 +MlasPackL8( + const int8_t* Input, + __vector unsigned char vmask + ) +{ + MLAS_UNREFERENCED_PARAMETER(vmask); + return reinterpret_cast(vec_xl(0,Input)); +} + +template +MLAS_FORCEINLINE +MLAS_SHORT +MlasPackS16( + __vector short a, + __vector short b + ); + +template <> +MLAS_FORCEINLINE +MLAS_SHORT +MlasPackS16( + __vector short a, + __vector short b + ) +{ + return a + b; +} + +template <> +MLAS_FORCEINLINE +MLAS_SHORT +MlasPackS16( + __vector short a, + __vector short b + ) +{ + MLAS_UNREFERENCED_PARAMETER(b); + return a; +} + +template +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasPackS16_128( + __vector short a, + __vector short b + ); + +template <> +MLAS_FORCEINLINE +MLAS_INT32X4 +MlasPackS16_128( + __vector short a, + __vector short b + ) +{ + return reinterpret_cast(vec_packsu(a, b)); +} + template <> MLAS_FORCEINLINE MLAS_INT32X4 diff --git a/src/lib/qlgavgpool.cpp b/src/lib/qlgavgpool.cpp index f0d2b48..746d722 100644 --- a/src/lib/qlgavgpool.cpp +++ b/src/lib/qlgavgpool.cpp @@ -15,7 +15,7 @@ Module Name: --*/ #include "mlasi.h" -#include +#include size_t MLASCALL diff --git a/src/lib/qlmul.cpp b/src/lib/qlmul.cpp index 4a6d57d..e518483 100644 --- a/src/lib/qlmul.cpp +++ b/src/lib/qlmul.cpp @@ -384,6 +384,98 @@ MlasQLinearMulKernel( MLAS_UNREFERENCED_PARAMETER(ScaleBVector); MLAS_UNREFERENCED_PARAMETER(ValueBVector); } +#elif defined(MLAS_ZVECTOR_INTRINSICS) + +template +static +void +MlasQLinearMulKernel( + const DataType* InputA, + float ScaleA, + int32_t ZeroPointA, + const DataType* InputB, + float ScaleB, + int32_t ZeroPointB, + float ScaleC, + int32_t ZeroPointC, + DataType* OutputC, + size_t N + ) +{ + const float MinimumValue = (float)((int)std::numeric_limits::min() - ZeroPointC); + const float MaximumValue = (float)((int)std::numeric_limits::max() - ZeroPointC); + + auto ZeroPointAVector = vec_splats(int32_t(ZeroPointA)); + auto 
ZeroPointBVector = vec_splats(int32_t(ZeroPointB)); + auto ZeroPointCVector = vec_splats(float(ZeroPointC)); + + auto ScaleAVector = vec_splats(ScaleA); + auto ScaleBVector = vec_splats(ScaleB); + auto ScaleCVector = vec_splats(ScaleC); + + auto MinimumVector = vec_splats(MinimumValue); + auto MaximumVector = vec_splats(MaximumValue); + + float ValueB; + __vector float ValueBVector; + + if (IsScalarB) { + ValueB = ScaleB * (int32_t(InputB[0]) - ZeroPointB); + ValueBVector = vec_splats(ValueB); + } + + while (N >= 4) { + __vector int32_t IntegerAVector {InputA[0], InputA[1], InputA[2], InputA[3]}; + auto IntegerVector = IntegerAVector - ZeroPointAVector; + auto ValueAVector = ScaleAVector * vec_float(IntegerVector); + + if (!IsScalarB) { + __vector int32_t IntegerBVector {InputB[0], InputB[1], InputB[2], InputB[3]}; + IntegerVector = IntegerBVector - ZeroPointBVector; + ValueBVector = ScaleBVector * vec_float(IntegerVector); + } + + auto ValueCVector = ValueAVector * ValueBVector / ScaleCVector; + ValueCVector = vec_min(vec_max(ValueCVector, MinimumVector), MaximumVector); + ValueCVector = vec_round(ValueCVector + ZeroPointCVector); + + auto IntegerValueCVector = vec_signed(ValueCVector); + OutputC[0] = (DataType) IntegerValueCVector[0]; + OutputC[1] = (DataType) IntegerValueCVector[1]; + OutputC[2] = (DataType) IntegerValueCVector[2]; + OutputC[3] = (DataType) IntegerValueCVector[3]; + + OutputC += 4; + InputA += 4; + InputB += 4; + + N -= 4; + + // Suppress wrong GCC warnings + MLAS_UNREFERENCED_PARAMETER(ValueAVector); + } + + while (N > 0) { + float ValueA = ScaleA * (int32_t(*InputA) - ZeroPointA); + if (!IsScalarB) { + ValueB = ScaleB * (int32_t(*InputB) - ZeroPointB); + } + float ValueC = (ValueA * ValueB) / ScaleC; + ValueC = std::min(std::max(ValueC, MinimumValue), MaximumValue); + + *OutputC = (DataType)(int32_t)std::nearbyintf(ValueC + ZeroPointC); + + InputA++; + InputB++; + OutputC++; + N--; + } + + // Suppress wrong GCC warnings + 
MLAS_UNREFERENCED_PARAMETER(ScaleAVector); + MLAS_UNREFERENCED_PARAMETER(ScaleBVector); + MLAS_UNREFERENCED_PARAMETER(ValueBVector); +} #elif defined(MLAS_LSX_INTRINSICS) diff --git a/src/lib/qnbitgemm.cpp b/src/lib/qnbitgemm.cpp index 19d11a6..f34128d 100644 --- a/src/lib/qnbitgemm.cpp +++ b/src/lib/qnbitgemm.cpp @@ -132,7 +132,7 @@ QNBitGemmPerGemmWorkspaceSize( } if (BlkBitWidth == 4 || BlkBitWidth == 8) { - return Dispatch->QNBitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, HasZeroPoint, ComputeType); + return Dispatch->QNBitGemmPerGemmWorkspaceSize(M, N, K, BlkLen, HasZeroPoint, ComputeType, BlkBitWidth); } return 0; @@ -266,7 +266,7 @@ MlasQNBitGemmPackQuantBData( if (BlkBitWidth == 4) { if (ComputeType == SQNBIT_CompInt8 && Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum != nullptr) { const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen, false); Dispatch->SQ4BitGemmPackQuantBDataAndBlkSum( N, K, @@ -307,7 +307,8 @@ MlasQNBitGemmPackQuantBData( } else if (BlkBitWidth == 8) { if (ComputeType == SQNBIT_CompInt8 && Dispatch->SQ8BitGemmPackQuantBDataAndBlkSum != nullptr) { const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(PackedQuantBDataAndOrBlkSumWorkspace, N, BlockCountK, + BlkLen, GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned); Dispatch->SQ8BitGemmPackQuantBDataAndBlkSum( N, K, @@ -470,7 +471,7 @@ SQ4BitGemm_CompFp32( size_t RowsRemaining = RangeCountM; while (RowsRemaining > 0) { -#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) || defined(MLAS_TARGET_LARCH64) auto 
RowsHandled = GetMlasPlatform().GemmFloatKernel( a_row, dequant_b, c_blk, K, RowsRemaining, CountN, lda, ldc, 1.f, true ); @@ -742,6 +743,8 @@ SQ8BitGemm_CompInt8( : static_cast(DataParams->QuantBZeroPoint) + RangeStartN * k_blks_zp_bytes; const float* ABlockSum = per_gemm_quant_a_workspace->BlockSum + RangeStartM * k_blks; const float* QuantBBlkSum = DataParams->QuantBBlkSum + RangeStartN * k_blks; + const float* BlkUnsignedQuantAZeroPointCorrection = + DataParams->BlkUnsignedQuantAZeroPointCorrection ? DataParams->BlkUnsignedQuantAZeroPointCorrection + RangeStartN * k_blks : nullptr; float* C = DataParams->C + RangeStartM * ldc + RangeStartN; const float* Bias = (DataParams->Bias == nullptr) ? nullptr : DataParams->Bias + RangeStartN; @@ -759,6 +762,8 @@ SQ8BitGemm_CompInt8( if (GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr) { const float* b_blk_sum = QuantBBlkSum + n * k_blks; + const float* blk_unsigned_quant_A_zp_correction = BlkUnsignedQuantAZeroPointCorrection ? 
+ BlkUnsignedQuantAZeroPointCorrection + n * k_blks : nullptr; GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8( BlkLen, QuantA, @@ -774,7 +779,8 @@ SQ8BitGemm_CompInt8( bias, ldc, ABlockSum, - b_blk_sum + b_blk_sum, + blk_unsigned_quant_A_zp_correction ); if (DataParams->PostProcessor != nullptr) { @@ -798,7 +804,8 @@ InitializeWorkspace_CompInt8( const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, size_t PerGemmWorkspaceStride, - MLAS_THREADPOOL* ThreadPool + MLAS_THREADPOOL* ThreadPool, + size_t BlkBitWidth ); template <> @@ -812,7 +819,8 @@ InitializeWorkspace_CompInt8( const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, size_t PerGemmWorkspaceStride, - MLAS_THREADPOOL* ThreadPool + MLAS_THREADPOOL* ThreadPool, + size_t BlkBitWidth ) { MLAS_UNREFERENCED_PARAMETER(N); @@ -826,7 +834,7 @@ InitializeWorkspace_CompInt8( const size_t QuantAStride = BlockCountK * Q8BlkSize(BlkLen); // TODO: try parallel on BatchN * M threads because BatchN is usually 1. 
- if (UsePacked && QuantizeA_Packed && UsePacked(K, BlkLen, DataParams->QuantBZeroPoint)) { + if (BlkBitWidth == 4 && UsePacked && QuantizeA_Packed && UsePacked(K, BlkLen, DataParams->QuantBZeroPoint)) { MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { const auto& data = DataParams[gemm_idx]; @@ -834,38 +842,63 @@ InitializeWorkspace_CompInt8( std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; QuantizeA_Packed(BlkLen, ARowPtr, M, K, QuantARowPtr); }); - } else if (QuantizeARow) { - MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { - const auto& data = DataParams[gemm_idx]; - - const float* ARowPtr = data.A; - std::byte* QuantARowPtr = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; - for (size_t m = 0; m < M; ++m) { - QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); - - ARowPtr += data.lda; - QuantARowPtr += QuantAStride; - } - }); } else { - MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { - const auto& data = DataParams[gemm_idx]; - const float* ARowPtr = data.A; - - void* PerGemmWorkspace = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; - PerGemmQuantAWorkspace quant_a_data(PerGemmWorkspace, M, BlockCountK, BlkLen); - std::byte* QuantARowPtr = quant_a_data.QuantData; - float* QuantARowScalePtr = quant_a_data.QuantScale; - float* QuantARowBlkSum = quant_a_data.BlockSum; - for (size_t m = 0; m < M; ++m) { - QuantizeARow2(BlkLen, ARowPtr, K, QuantARowPtr, QuantARowScalePtr, QuantARowBlkSum); - ARowPtr += data.lda; - QuantARowPtr += BlockCountK * BlkLen; - QuantARowScalePtr += BlockCountK; - QuantARowBlkSum += BlockCountK; + // TODO(hasesh): Clean-up the following logic so that it is clean AND it works as expected on all platforms + if (BlkBitWidth == 4) { + if (QuantizeARow) { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; + + const float* ARowPtr = data.A; + std::byte* QuantARowPtr = 
static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + for (size_t m = 0; m < M; ++m) { + QuantizeARow(BlkLen, ARowPtr, K, QuantARowPtr); + + ARowPtr += data.lda; + QuantARowPtr += QuantAStride; + } + }); + } else if (QuantizeARow2) { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; + const float* ARowPtr = data.A; + + void* PerGemmWorkspace = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + PerGemmQuantAWorkspace quant_a_data(PerGemmWorkspace, M, BlockCountK, BlkLen); + std::byte* QuantARowPtr = quant_a_data.QuantData; + float* QuantARowScalePtr = quant_a_data.QuantScale; + float* QuantARowBlkSum = quant_a_data.BlockSum; + for (size_t m = 0; m < M; ++m) { + QuantizeARow2(BlkLen, ARowPtr, K, QuantARowPtr, QuantARowScalePtr, QuantARowBlkSum); + ARowPtr += data.lda; + QuantARowPtr += BlockCountK * BlkLen; + QuantARowScalePtr += BlockCountK; + QuantARowBlkSum += BlockCountK; + } + }); } - }); - } + } else if (BlkBitWidth == 8) { + if (QuantizeARow2) { + MlasTrySimpleParallel(ThreadPool, BatchN, [&](ptrdiff_t gemm_idx) { + const auto& data = DataParams[gemm_idx]; + const float* ARowPtr = data.A; + + void* PerGemmWorkspace = static_cast(Workspace) + gemm_idx * PerGemmWorkspaceStride; + PerGemmQuantAWorkspace quant_a_data(PerGemmWorkspace, M, BlockCountK, BlkLen); + std::byte* QuantARowPtr = quant_a_data.QuantData; + float* QuantARowScalePtr = quant_a_data.QuantScale; + float* QuantARowBlkSum = quant_a_data.BlockSum; + for (size_t m = 0; m < M; ++m) { + QuantizeARow2(BlkLen, ARowPtr, K, QuantARowPtr, QuantARowScalePtr, QuantARowBlkSum); + ARowPtr += data.lda; + QuantARowPtr += BlockCountK * BlkLen; + QuantARowScalePtr += BlockCountK; + QuantARowBlkSum += BlockCountK; + } + }); + } + } + } } template <> @@ -879,7 +912,8 @@ InitializeWorkspace_CompInt8( const MLAS_QNBIT_GEMM_DATA_PARAMS* DataParams, void* Workspace, size_t PerGemmWorkspaceStride, - MLAS_THREADPOOL* ThreadPool + 
MLAS_THREADPOOL* ThreadPool, + size_t BlkBitWidth ) { MLAS_UNREFERENCED_PARAMETER(M); MLAS_UNREFERENCED_PARAMETER(N); @@ -890,6 +924,7 @@ InitializeWorkspace_CompInt8( MLAS_UNREFERENCED_PARAMETER(Workspace); MLAS_UNREFERENCED_PARAMETER(PerGemmWorkspaceStride); MLAS_UNREFERENCED_PARAMETER(ThreadPool); + MLAS_UNREFERENCED_PARAMETER(BlkBitWidth); } template @@ -902,7 +937,8 @@ using InitializeWorkspaceFn = std::function* DataParams, void* Workspace, size_t PerGemmWorkspaceStride, - MLAS_THREADPOOL* ThreadPool + MLAS_THREADPOOL* ThreadPool, + size_t BlkBitWidth )>; template @@ -1015,7 +1051,7 @@ MlasQNBitGemmBatch( if (const auto InitializeWorkspaceOperation = GetInitializeWorkspace(Variant); InitializeWorkspaceOperation != nullptr) { InitializeWorkspaceOperation( - M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool + M, N, K, BatchN, BlkLen, DataParams, Workspace, PerGemmWorkspaceStride, ThreadPool, BlkBitWidth ); } @@ -1029,17 +1065,19 @@ MlasQNBitGemmBatch( void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; if (Variant == SQ4BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen, false); const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); } else if (Variant == SQ8BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr) { - PackedQuantBDataStruct 
packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen, GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned); const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + const_cast*>(Data)->BlkUnsignedQuantAZeroPointCorrection = packed_quant_b.BlkUnsignedQuantAZeroPointCorrection; + PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, 0, M, 0, N); } else { @@ -1107,7 +1145,7 @@ MlasQNBitGemmBatch( void* PerGemmWorkspace = reinterpret_cast(Workspace) + gemm_i * PerGemmWorkspaceStride; if (Variant == SQ4BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ4BitGemmKernel_BlkSum_CompInt8 != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen, false); const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; @@ -1115,10 +1153,11 @@ MlasQNBitGemmBatch( PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); } else if (Variant == SQ8BitGemmVariant_CompInt8 && GetMlasPlatform().QNBitGemmDispatch->SQ8BitGemmKernel_BlkSum_CompInt8 != nullptr) { - PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, BlkLen); + PackedQuantBDataStruct packed_quant_b(const_cast(Data->QuantBDataWorkspace), N, BlockCountK, 
BlkLen, GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned); const_cast*>(Data)->PackedQuantBData = packed_quant_b.PackedQuantBData; const_cast*>(Data)->QuantBBlkSum = packed_quant_b.QuantBBlkSum; const_cast*>(Data)->QuantBScale = packed_quant_b.PackedQuantBScale; + const_cast*>(Data)->BlkUnsignedQuantAZeroPointCorrection = packed_quant_b.BlkUnsignedQuantAZeroPointCorrection; PerGemmQuantAWorkspace per_gemm_quant_a_workspace(PerGemmWorkspace, M, BlockCountK, BlkLen); ComputeOperation(BlkLen, K, Data, &per_gemm_quant_a_workspace, RangeStartM, RangeCountM, RangeStartN, RangeCountN); diff --git a/src/lib/qnbitgemm.h b/src/lib/qnbitgemm.h new file mode 100644 index 0000000..7ec80c6 --- /dev/null +++ b/src/lib/qnbitgemm.h @@ -0,0 +1,566 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + qnbitgemm.h + +Abstract: + + This module includes kernel function prototypes and helper functions for + implementing SQNBitGemm. + + SQNBitGemm is a matrix/matrix multiplication, A*B, where A is a float + matrix and B is a n-bit quantized integer matrix. B is block quantized, + meaning values of B are divided into blocks and each block has its own + scale and optional zero point. + +--*/ + +#pragma once + +#include "mlas_qnbit.h" +#include "mlasi.h" + +constexpr MLAS_FORCEINLINE size_t +MlasQNBitQuantBBlkSumAlignment() +{ + // 16 floats. 
this alignment is required by GemmFloatKernel + return 16 * sizeof(float); +} + +constexpr MLAS_FORCEINLINE size_t +MlasQNBitBlkDataSizeInBytes(size_t BlkBitWidth, size_t BlkLen) +{ + return BlkLen * BlkBitWidth / 8; +} + +MLAS_FORCEINLINE void* +MlasAlignAddress(void* addr, const size_t alignment) +{ + const uintptr_t QuantBBlkSumAddr = reinterpret_cast(addr); + addr = (void*)((QuantBBlkSumAddr + alignment - 1) & (~(alignment - 1))); + return addr; +} + +template +struct PackedQuantBDataStruct { + PackedQuantBDataStruct(void* PackedQuantBWorkspace, size_t N, size_t BlockCountK, size_t BlkLen, bool QuantAUnsigned) + : QuantBWorkspace_(PackedQuantBWorkspace), N_(N), BlockCountK_(BlockCountK), BlkLen_(BlkLen) + { + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(T); +#if defined(MLAS_TARGET_AMD64_IX86) + // avx512 requires alignment on a 64-byte boundary + PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 64); +#elif defined (MLAS_TARGET_ARM64) + // Only for 8-bit Gemms is the `PackedQuantBData` is to be 32-byte aligned and + // there is enough memory allocated to support this alignment. + // See QNBitGemmPackQuantBDataSize(). + // When bit width is 4, there is no alignment guarantee. + // TODO(hasesh): Can we unify the alignment for 4-bit and 8-bit ARM64 Gemms so as to + // simpify this logic and make code here cleaner ? 
+ if constexpr (BlkBitWidth == 8) { + PackedQuantBData = (std::byte*)MlasAlignAddress(PackedQuantBWorkspace, 32); + } + else { + PackedQuantBData = (std::byte*)PackedQuantBWorkspace; + } +#else + PackedQuantBData = (std::byte*)PackedQuantBWorkspace; +#endif + + QuantBBlkSum = (T*)(PackedQuantBData + PackedQuantBDataSize); + QuantBBlkSum = (T*)MlasAlignAddress(QuantBBlkSum, MlasQNBitQuantBBlkSumAlignment()); + + if (QuantAUnsigned) { + BlkUnsignedQuantAZeroPointCorrection = (T*)((std::byte*)QuantBBlkSum + BlkSumSize); + BlkUnsignedQuantAZeroPointCorrection = (T*)MlasAlignAddress(BlkUnsignedQuantAZeroPointCorrection, MlasQNBitQuantBBlkSumAlignment()); + PackedQuantBScale = (T*)((std::byte*)BlkUnsignedQuantAZeroPointCorrection + BlkSumSize); + } else { + BlkUnsignedQuantAZeroPointCorrection = nullptr; + PackedQuantBScale = (T*)((std::byte*)QuantBBlkSum + BlkSumSize); + } + } + + std::byte* PackedQuantBData; + T* PackedQuantBScale; + T* QuantBBlkSum; + T* BlkUnsignedQuantAZeroPointCorrection; + + void* QuantBWorkspace_; + size_t N_, BlockCountK_, BlkLen_; +}; + +template +constexpr MLAS_FORCEINLINE size_t +MlasQNBitZeroPointsForBlksSizeInBytes(size_t BlkCount) +{ + if constexpr (BlkBitWidth <= 4) { + return MlasDivRoundup(BlkCount, 2); // 2 blocks per byte + } else { + return BlkCount; + } +} + +// +// Kernel dispatch structure. +// + +struct MLAS_QNBIT_GEMM_DISPATCH { + // + // Quantized B data packing function prototypes. + // + + /** Gets size of packed quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBDataSize(). */ + typedef size_t(Q4BitGemmPackQuantBDataSize_Fn)( + size_t N, + size_t K, + size_t BlkLen, + bool HasZeroPoint, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType + ); + + Q4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr; + + /** Gets size of packed quantized B data containing 8-bit integers. See MlasQNBitGemmPackQuantBDataSize(). 
*/ + typedef size_t(Q8BitGemmPackQuantBDataSize_Fn)( + size_t N, + size_t K, + size_t BlkLen, + bool HasZeroPoint, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType + ); + + Q8BitGemmPackQuantBDataSize_Fn* Q8BitGemmPackQuantBDataSize = nullptr; + + /** Packs quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBData(). */ + typedef void(Q4BitGemmPackQuantBData_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool + ); + + Q4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr; + Q4BitGemmPackQuantBData_Fn* HQ4BitGemmPackQuantBData = nullptr; + + typedef void(SQ4BitGemmPackQuantBDataAndSumBlk_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool + ); + + SQ4BitGemmPackQuantBDataAndSumBlk_Fn* SQ4BitGemmPackQuantBDataAndBlkSum = nullptr; + + typedef void(SQ8BitGemmPackQuantBDataAndSumBlk_Fn)( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool + ); + + SQ8BitGemmPackQuantBDataAndSumBlk_Fn* SQ8BitGemmPackQuantBDataAndBlkSum = nullptr; + + // + // Workspace size calculation function prototypes. + // + + /** + * @brief Gets the required size in bytes of the per-GEMM intermediate workspace. + * Returns a size of zero if no intermediate workspace is needed. 
+ * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BlkLen number of quantized values per block + * @param[in] HasZeroPoint whether zero points are provided + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + */ + typedef size_t(QNBitGemmPerGemmWorkspaceSize_Fn)( + size_t M, + size_t N, + size_t K, + size_t BlkLen, + bool HasZeroPoint, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + size_t BlkBitWidth + ); + + QNBitGemmPerGemmWorkspaceSize_Fn* QNBitGemmPerGemmWorkspaceSize = nullptr; + + /** + * @brief Gets the required byte alignment of the per-GEMM intermediate workspace. + * + * @param[in] BlkLen number of quantized values per block + * @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values) + */ + typedef size_t(QNBitGemmPerGemmWorkspaceAlignment_Fn)( + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType + ); + + QNBitGemmPerGemmWorkspaceAlignment_Fn* QNBitGemmPerGemmWorkspaceAlignment = nullptr; + + // + // SQNBIT_CompFp32 kernel function prototypes. + // + + /** + * @brief Multiply float matrix A with quantized 4-bit integer matrix B. + * B is block quantized and column major. + * This kernel handles the special case where M, the number of rows of A and C, is 1. + * + * @param BlkLen Number of values in a block. + * @param A Supplies the A matrix. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. 
+ */ + typedef void(SQ4BitGemmM1Kernel_CompFp32_Fn)( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias + ); + + SQ4BitGemmM1Kernel_CompFp32_Fn* SQ4BitGemmM1Kernel_CompFp32 = nullptr; + + /** + * @brief Dequantize B into the format expected by the Sgemm kernel. + * B is a quantized 4-bit integer matrix that is block quantized and column major. + * This is equivalent to dequantizing B and then running MlasSgemmCopyPackB. + * + * @param BlkLen Number of values in a block. + * @param[out] FpData Supplies the output buffer for the dequantized B float data. + * It should have enough space for + * (CountN + 16 - 1) / 16 * 16 * (CountK + BlkLen - 1) / BlkLen * BlkLen + * elements. Only the first (CountN + 16 - 1) / 16 * 16 * CountK elements are + * useful, but the kernel implementation can be simplified with the extra space. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param CountN Number of columns of B. + * @param CountK Number of rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + */ + typedef void(Q4BitBlkDequantBForSgemm_CompFp32_Fn)( + size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB + ); + + Q4BitBlkDequantBForSgemm_CompFp32_Fn* SQ4BitBlkDequantBForSgemm_CompFp32 = nullptr; + + /** + * @brief Dequantize B into the format expected by the Sgemm kernel. + * B is a quantized 4-bit integer matrix that is block quantized and column major. + * This is equivalent to dequantizing B and then running MlasSgemmCopyPackB. 
+ * + * @param BlkLen Number of values in a block. + * @param[out] FpData Supplies the output buffer for the dequantized B float data. + * It should have enough space for + * (CountN + 16 - 1) / 16 * 16 * (CountK + BlkLen - 1) / BlkLen * BlkLen + * elements. Only the first (CountN + 16 - 1) / 16 * 16 * CountK elements are + * useful, but the kernel implementation can be simplified with the extra space. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param CountN Number of columns of B. + * @param CountK Number of rows of B. + * @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix. + */ + typedef void(Q4BitBlkDequantBForSgemm_CompFp16_Fn)( + size_t BlkLen, + MLAS_FP16* FpData, + const std::byte* QuantBData, + const MLAS_FP16* QuantBScale, + const std::byte* QuantBZeroPoint, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB + ); + + Q4BitBlkDequantBForSgemm_CompFp16_Fn* HQ4BitBlkDequantBForHgemm_CompFp16 = nullptr; + + // + // SQNBIT_CompInt8 kernel function prototypes. + // + + /** + * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. + * A and B are block quantized and B is column major. + * A should be packed using QuantizeA_Packed_CompInt8. + * + * @param BlkLen Number of values in a block. + * @param QuantA Supplies the quantized A matrix. + Binary data containing block quantized int8 data and scale values. + * @param PackedQuantBData Supplies the packed quantized B matrix data. + * @param[out] C Supplies the output C matrix. + * @param RangeStartM Start of M range. + * @param RangeCountM Number of rows of A and C. + * @param RangeStartN Start of N range. + * @param RangeCountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. 
+ * @param ldc Number of elements between adjacent rows of C. + */ + typedef void(SQ4BitGemmKernel_Packed_CompInt8_Fn)( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* PackedQuantBData, + float* C, + const size_t RangeStartM, + const size_t RangeCountM, + const size_t RangeStartN, + const size_t RangeCountN, + size_t CountK, + size_t ldc, + const float* Bias + ); + + SQ4BitGemmKernel_Packed_CompInt8_Fn* SQ4BitGemmKernel_Packed_CompInt8 = nullptr; + + /** + * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. + * A and B are block quantized and B is column major. + * + * @param BlkLen Number of values in a block. + * @param QuantA Supplies the quantized A matrix. + Binary data containing block quantized int8 data and scale values. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockCountK Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. + * @param ldc Number of elements between adjacent rows of C.. + * @param ABlockSum Supplies the blksum of A. + * @param QuantBBlkSum Supplies the blksum of B. 
+ */ + typedef size_t(SQ4BitGemmKernel_BlkSum_CompInt8_Fn)( + size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum + ); + + SQ4BitGemmKernel_BlkSum_CompInt8_Fn* SQ4BitGemmKernel_BlkSum_CompInt8 = nullptr; + + /** + * @brief Multiply quantized 8-bit integer matrix A with quantized 8-bit integer matrix B. + * A and B are block quantized and B is column major. + * + * @param BlkLen Number of values in a block. + * @param QuantA Supplies the quantized A matrix. + Binary data containing block quantized int8 data and scale values. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountN Number of columns of B and C. + * @param CountK Number of columns of A and rows of B. + * @param BlockCountK Number of blocks between adjacent columns of the quantized B matrix. + * @param Bias Bias vector of length N. + * @param ldc Number of elements between adjacent rows of C.. + * @param ABlockSum Supplies the blksum of A. + * @param QuantBBlkSum Supplies the blksum of B. + * @param BlkUnsignedQuantAZeroPointCorrection Supplies the optional input to de-bias the Gemm output to account for the +128 bias + addition when the activation input A is quantized to uint8. 
+ */ + typedef size_t(SQ8BitGemmKernel_BlkSum_CompInt8_Fn)( + size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum, + const float* BlkUnsignedQuantAZeroPointCorrection + ); + + SQ8BitGemmKernel_BlkSum_CompInt8_Fn* SQ8BitGemmKernel_BlkSum_CompInt8 = nullptr; + + /** + * @brief Multiply quantized 8-bit integer matrix A with quantized 4-bit integer matrix B. + * A and B are block quantized and B is column major. + * + * @param BlkLen Number of values in a block. + * @param QuantA Supplies the quantized A matrix. + Binary data containing block quantized int8 data and scale values. + * @param QuantBData Supplies the quantized B matrix block data. + * @param QuantBScale Supplies the quantized B matrix block scale values. + * @param QuantBZeroPoint Supplies the quantized B matrix block zero point values. Optional. + * @param[out] C Supplies the output C matrix. + * @param CountM Number of rows of A and C to process, an upper bound. + * @param CountN Number of columns of B and C to process. + * @param CountK Number of columns of A and rows of B. + * @param BlockCountK Number of blocks in one row of A and one column of B. + * @param ldc Number of elements between adjacent rows of C. + * @param Bias Bias vector of length N. + * + * @return The number of rows of A and C that were processed, at most CountM. 
+ */ + typedef size_t(SQ4BitGemmKernel_CompInt8_Fn)( + size_t BlkLen, + const std::byte* QuantA, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + size_t ldc, + const float* Bias + ); + + SQ4BitGemmKernel_CompInt8_Fn* SQ4BitGemmKernel_CompInt8 = nullptr; + + /** + * @brief Whether to use SQ4BitGemmKernel_Packed_CompInt8 for this problem. + */ + typedef bool(UsePacked_CompInt8_Fn)( + size_t K, + size_t BlkLen, + bool HasZp + ); + + UsePacked_CompInt8_Fn* UsePacked_CompInt8 = nullptr; + + /** + * @brief Block quantize values from matrix A from floats to quantized 8-bit integers. + * Used in conjunction with SQ4BitGemmKernel_Packed_CompInt8. + * + * @param BlkLen Number of values in a block. + * @param A Supplies the A matrix. + * @param CountM Number of rows of A. + * @param CountK Number of columns of A. + * @param[out] QuantA Supplies the output quantized A matrix. + * Binary data containing block quantized int8 data and scale values. + */ + typedef void(QuantizeA_Packed_CompInt8_Fn)( + size_t BlkLen, + const float* A, + size_t CountM, + size_t CountK, + std::byte* QuantA + ); + + QuantizeA_Packed_CompInt8_Fn* QuantizeA_Packed_CompInt8 = nullptr; + + /** + * @brief Block quantize values from one row of matrix A from floats to quantized 8-bit integers. + * + * @param BlkLen Number of values in a block. + * @param A Supplies the A matrix. + * @param CountK Number of columns of A. + * @param[out] QuantA Supplies the output quantized A matrix. + * Binary data containing block quantized int8 data and scale values. 
+ */ + typedef void(QuantizeARow_CompInt8_Fn)( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA + ); + + QuantizeARow_CompInt8_Fn* QuantizeARow_CompInt8 = nullptr; + + typedef void(QuantizeARowComputeBlkSum_CompInt8_Fn)( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA, + float* QuantAScale, + float* AScaledGroupSum // scale_k * Sum_blklen(a_i) + ); + QuantizeARowComputeBlkSum_CompInt8_Fn* QuantizeARowComputeBlkSum_CompInt8 = nullptr; + + /** + * @brief Multiply fp16 matrix A rows with fp16 matrix B columns. + * Results are written to fp16 matrix C. + * If bias is provided, the bias are added to the result. + * + * @param A first row of the A matrix segment. Row major. + * @param B first column of the B matrix segment. Column major. + * @param Bias the bias at the target column. Optional. + * @param[out] C first element of the output matrix segment. Row major. + * @param CountM the number of rows of A chunk. + * @param CountN the number of columns of B chunk. + * @param K the number of columns of A matrix and rows of B matrix. + * @param lda the leading dimension of A. + * @param ldb the leading dimension of B. + * @param ldc the leading dimension of C. + */ + typedef void(HQ4BitGemmKernel_CompFp16_Fn)( + const MLAS_FP16* A, + const MLAS_FP16* B, + const MLAS_FP16* Bias, + MLAS_FP16* C, + size_t CountM, + size_t CountN, + size_t K, + size_t lda, + size_t ldb, + size_t ldc + ); + + HQ4BitGemmKernel_CompFp16_Fn* HQ4BitGemmKernel_CompFp16 = nullptr; +}; diff --git a/src/lib/qnbitgemm_kernel_neon.cpp b/src/lib/qnbitgemm_kernel_neon.cpp index 0d06eb0..ba2b68e 100644 --- a/src/lib/qnbitgemm_kernel_neon.cpp +++ b/src/lib/qnbitgemm_kernel_neon.cpp @@ -21,6 +21,7 @@ Module Name: #include #include +#include #include "qnbitgemm.h" #include "sqnbitgemm_q8_block.h" @@ -42,8 +43,9 @@ namespace // Quantized B data packing function implementation. 
// +template size_t -Q4BitGemmPackQuantBDataSize( +QNBitGemmPackQuantBDataSize( size_t N, size_t K, size_t BlkLen, @@ -51,26 +53,49 @@ Q4BitGemmPackQuantBDataSize( MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType ) { + if constexpr (BlkBitWidth == 4) { #ifndef USE_KLEIDIAI - MLAS_UNREFERENCED_PARAMETER(HasZeroPoint); - MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType + MLAS_UNREFERENCED_PARAMETER(HasZeroPoint); + MLAS_UNREFERENCED_PARAMETER(ComputeType); // same size regardless of ComputeType #endif #ifdef USE_KLEIDIAI - if (ComputeType == SQNBIT_CompInt8 && UseKleidiAI(K, BlkLen, HasZeroPoint)) { - const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel = GetKleidiAIGemmUKernel(); - const size_t nr = ukernel.get_nr(); - const size_t kr = ukernel.get_kr(); - const size_t sr = ukernel.get_sr(); - return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, BlkLen, kai_dt_bf16); - } else + if (ComputeType == SQNBIT_CompInt8 && UseKleidiAI(K, BlkLen, HasZeroPoint)) { + const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel = GetKleidiAIGemmUKernel(); + const size_t nr = ukernel.get_nr(); + const size_t kr = ukernel.get_kr(); + const size_t sr = ukernel.get_sr(); + return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(N, K, nr, kr, sr, BlkLen, kai_dt_bf16); + } else #endif - { - constexpr size_t BlkBitWidth = 4; - + { + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize; + } + } else { const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); - return PackedQuantBDataSize; + size_t PackedQuantBDataSize = N * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + + if (ComputeType == SQNBIT_CompInt8) { + const size_t ScaleSize = N * BlockCountK * 
sizeof(float); + size_t BlkSumSize = MlasDivRoundup(N, 16) * BlockCountK * 16 * sizeof(float); + + // align on a 32-byte boundary + constexpr size_t PackedQuantBDataAlignment = 32; + PackedQuantBDataSize += PackedQuantBDataAlignment - 1; + constexpr size_t BlkSumAlignment = MlasQNBitQuantBBlkSumAlignment(); + BlkSumSize += BlkSumAlignment - 1; + + if constexpr (QuantAUnsigned) { + // 2 block sum + return PackedQuantBDataSize + ScaleSize + BlkSumSize + BlkSumSize; + } else { + return PackedQuantBDataSize + ScaleSize + BlkSumSize; + } + } else { + return PackedQuantBDataSize; + } } } @@ -199,6 +224,167 @@ SQ4BitGemmPackQuantBDataAndBlkSum( } } +void +Q8PackQuantB( + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + float* BlkUnsignedQuantAZeroPointCorrectionBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t N, + const size_t K, + const size_t BlkLen) +{ + constexpr size_t SubBlkLen = 4; + const size_t BlkCountK = MlasDivRoundup(K, BlkLen); + const size_t SubBlkPerBlk = BlkLen / SubBlkLen; + const size_t StrideN = BlkCountK * BlkLen; + const size_t Iterations = N * BlkCountK; + + // 4 rows x 8 columns pack together, then 4 rows x 4 columns, then per column. 
+ MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t c = tid / BlkCountK; + const size_t c8 = c & (~7), c8_res = c & 7; + const size_t c4 = c & (~3), c4_res = c & 3; + const size_t r_blk = tid % BlkCountK; + size_t r_subblk = r_blk * SubBlkPerBlk; + + const std::byte* src = QuantBDataBegin + c * StrideN + r_blk * BlkLen; + const uint8_t* src8 = reinterpret_cast(src); + + for (size_t i = 0; i < SubBlkPerBlk; ++i, src += SubBlkLen, ++r_subblk) { + if (c8 + 8 <= N) { // full 8 cols + std::byte* dest = + PackedQuantBDataBegin + c8 * StrideN + r_subblk * SubBlkLen * 8 + c8_res * SubBlkLen; + std::copy(src, src + SubBlkLen, dest); + } else if (c4 + 4 <= N) { // full 4 cols + std::byte* dest = + PackedQuantBDataBegin + c4 * StrideN + r_subblk * SubBlkLen * 4 + c4_res * SubBlkLen; + std::copy(src, src + SubBlkLen, dest); + } else { // remainder cols + std::byte* dest = + PackedQuantBDataBegin + c * StrideN + r_subblk * SubBlkLen; + std::copy(src, src + SubBlkLen, dest); + } + } + + if (BlkUnsignedQuantAZeroPointCorrectionBegin) { + const int accu = std::accumulate(src8, src8 + std::min(BlkLen, K - r_blk * BlkLen), 0); + + // for sgemmc + const size_t dst_offset = ((c / 16) * BlkCountK + r_blk) * 16 + c % 16; + BlkUnsignedQuantAZeroPointCorrectionBegin[dst_offset] = static_cast(accu); + } + } + ); +} + +void +Q8ComputePackBlkSum( + const size_t BlkLen, + const size_t N, + const size_t K, + float* QuantBScaleBegin, + const std::byte* QuantBZPBegin, + float* BlockSumBegin, + float* BlockSum2Begin, + MLAS_THREADPOOL* ThreadPool) +{ + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + std::vector QuantBScaleBeginCopy(N * BlockCountK); + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, QuantBScaleBeginCopy.begin()); + + MlasTrySimpleParallel(ThreadPool, N * BlockCountK, [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t n8 = n & (~7), n8_res = n & 7; + const size_t n4 = n & (~3), n4_res = n & 3; + 
const size_t k_blk = tid % BlockCountK; + + const size_t src_blk_offset = n * BlockCountK + k_blk; + const float QuantBScale = QuantBScaleBeginCopy[src_blk_offset]; + uint8_t zp = 128; + if (QuantBZPBegin) { + const std::byte* QuantBZP = QuantBZPBegin + src_blk_offset; + zp = (uint8_t)(*QuantBZP); + } + + // BlockSum is a width 16 row major matrix + const size_t dst_offset = ((n / 16) * BlockCountK + k_blk) * 16 + n % 16; + *(BlockSumBegin + dst_offset) = -QuantBScale * zp; + if (BlockSum2Begin) { + BlockSum2Begin[dst_offset] = QuantBScale * (static_cast(zp) * std::min(BlkLen, K - k_blk * BlkLen) - BlockSum2Begin[dst_offset]); + } + + // re-arrange scale to the same order as packed data + if (n4 + 4 > N) { // remainder cols + *(QuantBScaleBegin + n * BlockCountK + k_blk) = QuantBScale; + } else if (n8 + 8 > N) { // full 4 cols + *(QuantBScaleBegin + n4 * BlockCountK + k_blk * 4 + n4_res) = QuantBScale; + } else { // full 8 cols + *(QuantBScaleBegin + n8 * BlockCountK + k_blk * 8 + n8_res) = QuantBScale; + } + }); +} + +/** + * 4 rows x 8 cols pack together, along all K. Then 4 rows x 4 cols, along all K. + * When rows < 4, keep original layout. + * + * dotprod: vdotq_laneq_u32. + * convert quant a from int8 to uint8. zp is 128. + * + * i8mm: vusdotq_laneq_s32. 
+ */ +void +SQ8BitGemmPackQuantBDataAndBlkSum( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType */, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + // Pack the quantized weights + if (QuantBDataBegin) { + Q8PackQuantB(QuantBDataBegin, PackedQuantB.PackedQuantBData, PackedQuantB.BlkUnsignedQuantAZeroPointCorrection, ThreadPool, N, K, BlkLen); + } else { + // We ignore the scales and zero points if they are provided when pre-packing the weights as there is + // some "state" associated with 'BlkUnsignedQuantAZeroPointCorrection'. + + // We accumulate the block sum into 'BlkUnsignedQuantAZeroPointCorrection' while packing the weights + // in the previous step. If we were to use 'scales' while pre-packing the weights and if there were no + // zero points, then we would enter 'Q8ComputePackBlkSum' twice - once while pre-packing the weights + // and once while pre-packing the scales which would lead to erroneous 'BlkUnsignedQuantAZeroPointCorrection' + // computation as the buffer is "used" in-place for the "block sum" temporary values (obtained while pre-packing + // the weights) and the actual 'BlkUnsignedQuantAZeroPointCorrection' which will use the scales. + // Hence, to ensure that the piece of logic to calculate 'BlkUnsignedQuantAZeroPointCorrection' is only invoked + // once, we do it while we are pre-packing the scales and ignore any provided 'scales' and 'zero points' while + // pre-packing the weights. + // The flip side is that the user has to ensure that this function is called once each for 'weights', + // 'scales', and 'zero points'. This is a reasonable expectation and hence we go with that design. 
+ + // Pack the block scales + if (QuantBScaleBegin) { + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, PackedQuantB.PackedQuantBScale); + } + + // Pack the blksum (and BlkUnsignedQuantAZeroPointCorrection if applicable) + if ((QuantBScaleBegin && !HasZeroPoint) || QuantBZPBegin) { + Q8ComputePackBlkSum(BlkLen, N, K, PackedQuantB.PackedQuantBScale, QuantBZPBegin, PackedQuantB.QuantBBlkSum, PackedQuantB.BlkUnsignedQuantAZeroPointCorrection, ThreadPool); + } + } +} + // // Workspace size calculation function implementation. // @@ -210,19 +396,21 @@ QNBitGemmPerGemmWorkspaceSize( size_t K, size_t BlkLen, bool HasZeroPoint, - MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + size_t BlkBitWidth ) { MLAS_UNREFERENCED_PARAMETER(N); #ifndef USE_KLEIDIAI MLAS_UNREFERENCED_PARAMETER(HasZeroPoint); + MLAS_UNREFERENCED_PARAMETER(BlkBitWidth); #endif switch (ComputeType) { case SQNBIT_CompInt8: { // workspace buffer is used for block quantization of A to int8 #ifdef USE_KLEIDIAI - if (UseKleidiAI(K, BlkLen, HasZeroPoint)) { + if (BlkBitWidth == 4 && UseKleidiAI(K, BlkLen, HasZeroPoint)) { const kai_matmul_clamp_f32_qai8dxp_qsi4c32p_ukernel& ukernel = M == 1? 
GetKleidiAIGemvUKernel() : GetKleidiAIGemmUKernel(); @@ -233,8 +421,10 @@ QNBitGemmPerGemmWorkspaceSize( } else #endif { + // workspace buffer is used for block quantization of A to int8 const size_t BlockCountK = MlasDivRoundup(K, BlkLen); - const size_t PerGemmWorkspaceSize = M * BlockCountK * Q8BlkSize(BlkLen); + // QuantData + Scale + BlkSum + const size_t PerGemmWorkspaceSize = M * BlockCountK * (Q8BlkSize(BlkLen) + sizeof(float)); return PerGemmWorkspaceSize; } } @@ -278,6 +468,77 @@ UseKleidiAI(size_t K, size_t BlkLen, bool HasZp) #endif } +template +size_t +SQ8BitGemmKernel_BlkSum_CompInt8( + const size_t BlkLen, + const std::byte* QuantA, + const float* QuantAScale, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* /*QuantBZeroPoint*/, + float* C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + const float* Bias, + size_t ldc, + const float* ABlockSum, + const float* QuantBBlkSum, + const float* BlkUnsignedQuantAZeroPointCorrection +) +{ + MlasQ8Int8GemmKernelNeon( + BlkLen, + reinterpret_cast*>(QuantA), + QuantAScale, + reinterpret_cast(QuantBData), + QuantBScale, + C, + CountM, + CountN, + CountK, + Bias, + ldc + ); + + { + float* c_blk = C; + const float* b_blk_sum = QuantBBlkSum; + + size_t RowsRemaining = CountM; + const float* a_blksum_row = ABlockSum; + while (RowsRemaining > 0) { + auto RowsHandled = MlasSgemmKernelAdd(a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f); + + c_blk += ldc * RowsHandled; + a_blksum_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + } + + if constexpr (QuantAUnsigned) { + { + assert(BlkUnsignedQuantAZeroPointCorrection != nullptr); + float* c_blk = C; + const float* b_blk_sum2 = BlkUnsignedQuantAZeroPointCorrection; + + size_t RowsRemaining = CountM; + const float* a_scale_row = QuantAScale; + while (RowsRemaining > 0) { + auto RowsHandled = MlasSgemmKernelAdd(a_scale_row, b_blk_sum2, c_blk, BlockCountK, 
RowsRemaining, CountN, BlockCountK, ldc, 128.f); + + c_blk += ldc * RowsHandled; + a_scale_row += BlockCountK * RowsHandled; + RowsRemaining -= RowsHandled; + } + } + } + + return CountM; +} + } // namespace sqnbitgemm_neon // @@ -286,7 +547,8 @@ UseKleidiAI(size_t K, size_t BlkLen, bool HasZp) const MLAS_QNBIT_GEMM_DISPATCH& GetMlasQNBitGemmDispatchNeon( - bool InitializeWithDotSupport + bool InitializeWithDotSupport, + bool InitializeWithI8MMSupport ) { // Note: The InitializeWithX parameters are only used in the invocation of this method that initializes the static @@ -295,9 +557,11 @@ GetMlasQNBitGemmDispatchNeon( static const MLAS_QNBIT_GEMM_DISPATCH MlasQNBitGemmDispatchNeon = [&]() { MLAS_QNBIT_GEMM_DISPATCH d; - d.Q4BitGemmPackQuantBDataSize = sqnbitgemm_neon::Q4BitGemmPackQuantBDataSize; + d.Q4BitGemmPackQuantBDataSize = sqnbitgemm_neon::QNBitGemmPackQuantBDataSize<4, false>; + d.Q8BitGemmPackQuantBDataSize = sqnbitgemm_neon::QNBitGemmPackQuantBDataSize<8, true>; d.SQ4BitGemmPackQuantBData = sqnbitgemm_neon::SQ4BitGemmPackQuantBData; d.SQ4BitGemmPackQuantBDataAndBlkSum = sqnbitgemm_neon::SQ4BitGemmPackQuantBDataAndBlkSum; + d.SQ8BitGemmPackQuantBDataAndBlkSum = sqnbitgemm_neon::SQ8BitGemmPackQuantBDataAndBlkSum; d.QNBitGemmPerGemmWorkspaceSize = sqnbitgemm_neon::QNBitGemmPerGemmWorkspaceSize; d.QNBitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::QNBitGemmPerGemmWorkspaceAlignment; @@ -310,12 +574,21 @@ GetMlasQNBitGemmDispatchNeon( d.QuantizeARow_CompInt8 = sqnbitgemm_neon::QuantizeARow_CompInt8; d.UsePacked_CompInt8 = sqnbitgemm_neon::UsePacked_CompInt8; + d.QuantizeARowComputeBlkSum_CompInt8 = sqnbitgemm_neon::QuantizeARowComputeBlkSum_CompInt8; + d.SQ8BitGemmKernel_BlkSum_CompInt8 = sqnbitgemm_neon::SQ8BitGemmKernel_BlkSum_CompInt8; + #ifdef USE_KLEIDIAI d.SQ4BitGemmKernel_Packed_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_Packed_CompInt8; d.QuantizeA_Packed_CompInt8 = sqnbitgemm_neon::QuantizeA_Packed_CompInt8; #endif } + if 
(InitializeWithI8MMSupport) { + d.Q8BitGemmPackQuantBDataSize = sqnbitgemm_neon::QNBitGemmPackQuantBDataSize<8, false>; + d.QuantizeARowComputeBlkSum_CompInt8 = sqnbitgemm_neon::QuantizeARowComputeBlkSum_CompInt8; + d.SQ8BitGemmKernel_BlkSum_CompInt8 = sqnbitgemm_neon::SQ8BitGemmKernel_BlkSum_CompInt8; + } + #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) d.HQ4BitGemmPackQuantBData = sqnbitgemm_neon::HQ4BitGemmPackQuantBData_CompFp16; d.HQ4BitBlkDequantBForHgemm_CompFp16 = sqnbitgemm_neon::HQ4BitBlkDequantBForHgemm_CompFp16; diff --git a/src/lib/qnbitgemm_kernel_neon.h b/src/lib/qnbitgemm_kernel_neon.h index a254ec9..c8be42b 100644 --- a/src/lib/qnbitgemm_kernel_neon.h +++ b/src/lib/qnbitgemm_kernel_neon.h @@ -123,6 +123,36 @@ QuantizeARow_CompInt8( std::byte* QuantA ); +template +void MLASCALL +QuantizeARowComputeBlkSum_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) +); + +template +using QuantAType = typename std::conditional::type; + +template +size_t +MlasQ8Int8GemmKernelNeon( + const size_t BlkLen, + const QuantAType* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float * QuantBScale, + float* C, + const size_t CountM, + const size_t CountN, + const size_t CountK, + const float* Bias, + const size_t ldc +); + size_t SQ4BitGemmKernel_CompInt8( size_t BlkLen, diff --git a/src/lib/quantize.cpp b/src/lib/quantize.cpp index fad174f..c5bf2b4 100644 --- a/src/lib/quantize.cpp +++ b/src/lib/quantize.cpp @@ -766,7 +766,7 @@ MlasQuantizeLinear( #else -#if defined(MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) template<> void @@ -902,7 +902,7 @@ Return Value: } } -#if !defined(MLAS_TARGET_POWER) +#if !defined(MLAS_TARGET_POWER) && !defined(MLAS_TARGET_S390X) template void MLASCALL @@ -1681,6 +1681,204 @@ MlasRequantizeOutput( } } +#elif defined(MLAS_TARGET_S390X) + +template 
+void +MLASCALL +MlasRequantizeOutput( + const int32_t* Input, + size_t InputLeadingDimension, + OutputType* Output, + size_t OutputLeadingDimension, + const int32_t* Bias, + const float* Scale, + bool PerColumnScale, + OutputType ZeroPoint, + size_t StartM, + size_t StartN, + size_t CountM, + size_t CountN + ) +{ + float PerMatrixScaleValue = PerColumnScale ? 0.0f : *Scale; + float MinimumValue = float(std::numeric_limits::lowest() - ZeroPoint); + float MaximumValue = float(std::numeric_limits::max() - ZeroPoint); + + auto PerMatrixScaleVector = vec_splats(PerMatrixScaleValue); + auto MinimumVector = vec_splats(MinimumValue); + auto MaximumVector = vec_splats(MaximumValue); + auto ZeroPointVector = vec_splats(int32_t(ZeroPoint)); + + // Workaround to avoid 'variable set but not used' message + MLAS_UNREFERENCED_PARAMETER(PerMatrixScaleVector); + + if (nullptr != Bias) { + Bias += StartN; + } + if (PerColumnScale) { + Scale += StartN; + } + + Input += StartM * InputLeadingDimension + StartN; + Output += StartM * OutputLeadingDimension + StartN; + + // + // Step through each row of the output matrix. + // + + while (CountM-- > 0) { + + const int32_t* bias = Bias; + const float* scale = PerColumnScale ? 
Scale : nullptr; + size_t n = CountN; + + auto* RowInput = Input; + auto* RowOutput = Output; + + // Process 16 cols at a time + + while (n >= 16) { + + auto IntegerVector0 = vec_xl(0, &RowInput[0]); + auto IntegerVector1 = vec_xl(0, &RowInput[4]); + auto IntegerVector2 = vec_xl(0, &RowInput[8]); + auto IntegerVector3 = vec_xl(0, &RowInput[12]); + RowInput += 16; + + if (bias != nullptr) { + IntegerVector0 = IntegerVector0 + vec_xl(0, &bias[0]); + IntegerVector1 = IntegerVector1 + vec_xl(0, &bias[4]); + IntegerVector2 = IntegerVector2 + vec_xl(0, &bias[8]); + IntegerVector3 = IntegerVector3 + vec_xl(0, &bias[12]); + bias += 16; + } + + auto FloatVector0 = vec_float(IntegerVector0); + auto FloatVector1 = vec_float(IntegerVector1); + auto FloatVector2 = vec_float(IntegerVector2); + auto FloatVector3 = vec_float(IntegerVector3); + + if (scale != nullptr) { + FloatVector0 = FloatVector0 * vec_xl(0, &scale[0]); + FloatVector1 = FloatVector1 * vec_xl(0, &scale[4]); + FloatVector2 = FloatVector2 * vec_xl(0, &scale[8]); + FloatVector3 = FloatVector3 * vec_xl(0, &scale[12]); + scale += 16; + } else { + FloatVector0 = FloatVector0 * PerMatrixScaleVector; + FloatVector1 = FloatVector1 * PerMatrixScaleVector; + FloatVector2 = FloatVector2 * PerMatrixScaleVector; + FloatVector3 = FloatVector3 * PerMatrixScaleVector; + } + + FloatVector0 = vec_max(FloatVector0, MinimumVector); + FloatVector1 = vec_max(FloatVector1, MinimumVector); + FloatVector2 = vec_max(FloatVector2, MinimumVector); + FloatVector3 = vec_max(FloatVector3, MinimumVector); + + FloatVector0 = vec_min(FloatVector0, MaximumVector); + FloatVector1 = vec_min(FloatVector1, MaximumVector); + FloatVector2 = vec_min(FloatVector2, MaximumVector); + FloatVector3 = vec_min(FloatVector3, MaximumVector); + + FloatVector0 = vec_round(FloatVector0); + FloatVector1 = vec_round(FloatVector1); + FloatVector2 = vec_round(FloatVector2); + FloatVector3 = vec_round(FloatVector3); + + auto IntegerOutVector0 = vec_signed(FloatVector0); + 
auto IntegerOutVector1 = vec_signed(FloatVector1); + auto IntegerOutVector2 = vec_signed(FloatVector2); + auto IntegerOutVector3 = vec_signed(FloatVector3); + + IntegerOutVector0 = IntegerOutVector0 + ZeroPointVector; + IntegerOutVector1 = IntegerOutVector1 + ZeroPointVector; + IntegerOutVector2 = IntegerOutVector2 + ZeroPointVector; + IntegerOutVector3 = IntegerOutVector3 + ZeroPointVector; + + auto ShortVector0 = vec_pack(IntegerOutVector0, IntegerOutVector1); + auto ShortVector1 = vec_pack(IntegerOutVector2, IntegerOutVector3); + auto CharVector = vec_pack(ShortVector0, ShortVector1); + + // Workaround to avoid 'variable set but not used' message + MLAS_UNREFERENCED_PARAMETER(CharVector); + + vec_xst(CharVector, 0, (int8_t *) RowOutput); + RowOutput += 16; + n -= 16; + } + + while (n >= 4) { + int8_t OutputBuffer[16]; + + auto IntegerVector = vec_xl(0, &RowInput[0]); + RowInput += 4; + + if (bias != nullptr) { + IntegerVector = IntegerVector + vec_xl(0, &bias[0]); + bias += 4; + } + + auto FloatVector = vec_float(IntegerVector); + + if (scale != nullptr) { + FloatVector = FloatVector * vec_xl(0, scale); + scale += 4; + } else { + FloatVector = FloatVector * PerMatrixScaleVector; + } + + FloatVector = vec_max(FloatVector, MinimumVector); + FloatVector = vec_min(FloatVector, MaximumVector); + FloatVector = vec_round(FloatVector); + + auto IntegerOutVector = vec_signed(FloatVector); + IntegerOutVector = IntegerOutVector + ZeroPointVector; + + auto ShortVector = vec_pack(IntegerOutVector, vec_splats((int32_t) 0)); + auto CharVector = vec_pack(ShortVector, vec_splats((int16_t) 0)); + + // Workaround to avoid 'variable set but not used' message + MLAS_UNREFERENCED_PARAMETER(CharVector); + + vec_xst(CharVector, 0, &(OutputBuffer[0])); + memcpy(RowOutput, OutputBuffer, 4); + + RowOutput += 4; + n -= 4; + } + + while (n > 0) { + auto IntegerValue = RowInput[0]; + RowInput += 1; + + if (bias != nullptr) { + IntegerValue += bias[0]; + bias += 1; + } + + float FloatValue = 
float(IntegerValue); + float ScaleValue = PerColumnScale ? *scale++ : PerMatrixScaleValue; + + FloatValue *= ScaleValue; + FloatValue = std::max(FloatValue, MinimumValue); + FloatValue = std::min(FloatValue, MaximumValue); + + IntegerValue = int32_t(MlasBitsOfFp32(FloatValue + MLAS_ROUNDING_BIAS_MAGIC)) - + MLAS_ROUNDING_BIAS_MAGIC_BITS; + + *RowOutput++ = OutputType(IntegerValue + ZeroPoint); + + n -= 1; + } + + // Next Row + Input += InputLeadingDimension; + Output += OutputLeadingDimension; + } +} + #elif defined(MLAS_LSX_INTRINSICS) template diff --git a/src/lib/rotary_embedding_kernel_avx2.cpp b/src/lib/rotary_embedding_kernel_avx2.cpp index 024e67d..e5d1327 100644 --- a/src/lib/rotary_embedding_kernel_avx2.cpp +++ b/src/lib/rotary_embedding_kernel_avx2.cpp @@ -235,8 +235,11 @@ RopeKernel_Avx2_fp32_Impl( __m256i in_mask_vec = _mm256_set_epi32(7, 6, 3, 2, 5, 4, 1, 0); float32x8_t real = _mm256_permutevar8x32_ps(real_s, in_mask_vec); float32x8_t imag = _mm256_permutevar8x32_ps(imag_s, in_mask_vec); - float32x8_t sin_val = _mm256_loadu_ps(sin_data+ i / 2); - float32x8_t cos_val = _mm256_loadu_ps(cos_data + i / 2); + // Use masked loads for sin/cos data to avoid reading beyond buffer bounds + size_t cos_sin_rem = rem / 2; + const __m256i cos_sin_mask = _mm256_loadu_si256((const __m256i*)(mask_buffer + 8 - cos_sin_rem)); + float32x8_t sin_val = _mm256_maskload_ps(sin_data + i / 2, cos_sin_mask); + float32x8_t cos_val = _mm256_maskload_ps(cos_data + i / 2, cos_sin_mask); //Compute Real and Imaginary output values float32x8_t real_out = _mm256_fmsub_ps(real, cos_val, _mm256_mul_ps(imag, sin_val)); float32x8_t imag_out = _mm256_fmadd_ps(real, sin_val, _mm256_mul_ps(imag, cos_val)); diff --git a/src/lib/s390x/DgemmKernel.cpp b/src/lib/s390x/DgemmKernel.cpp new file mode 100644 index 0000000..f7b7f7c --- /dev/null +++ b/src/lib/s390x/DgemmKernel.cpp @@ -0,0 +1,87 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. 
+ +Module Name: + + DgemmKernel.cpp + +Abstract: + + This module implements the kernels for the double precision matrix/matrix + multiply operation (DGEMM). + +--*/ +#include "DgemmKernelZVECTOR.h" + +size_t +MLASCALL +MlasDgemmKernel( + const double* A, + const double* B, + double* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + double alpha, + bool ZeroMode + ) +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A - Supplies the address of matrix A. + + B - Supplies the address of matrix B. The matrix data has been packed using + MlasDgemmCopyPackB or MlasDgemmTransposePackB. + + C - Supplies the address of matrix C. + + CountK - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. + + CountM - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN - Supplies the number of columns from matrix B and matrix C to + iterate over. + + lda - Supplies the first dimension of matrix A. + + ldc - Supplies the first dimension of matrix C. + + alpha - Supplies the scalar multiplier (see DGEMM definition). + + ZeroMode - Supplies true if the output matrix must be zero initialized, + else false if the output matrix is accumulated into. + +Return Value: + + Returns the number of rows handled. 
+ +--*/ +{ + size_t RowsHandled; + + MLAS_FLOAT64X2 AlphaBroadcast = MlasBroadcastFloat64x2(alpha); + + if (CountM >= 4) { + RowsHandled = MlasDgemmProcessCount<4>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } else if (CountM >= 2) { + RowsHandled = MlasDgemmProcessCount<2>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } else { + RowsHandled = MlasDgemmProcessCount<1>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } + + return RowsHandled; +} diff --git a/src/lib/s390x/DgemmKernelZVECTOR.h b/src/lib/s390x/DgemmKernelZVECTOR.h new file mode 100644 index 0000000..48dc7de --- /dev/null +++ b/src/lib/s390x/DgemmKernelZVECTOR.h @@ -0,0 +1,122 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + DgemmKernelZVECTOR.h + +Abstract: + + This module implements the kernels for the double precision matrix/matrix + multiply operation (DGEMM). + +--*/ + +#include "FgemmKernelZVECTOR.h" + +template +MLAS_FORCEINLINE +size_t +MlasDgemmProcessCount( + const double* A, + const double* B, + double* C, + size_t CountK, + size_t CountN, + size_t lda, + size_t ldc, + MLAS_FLOAT64X2 AlphaBroadcast, + bool ZeroMode + ) +{ + do { + + const double* a = A; + size_t k = CountK; + + MLAS_FLOAT64X2 Accumulators[RowCount][4]; + MLAS_FLOAT64X2 AElements[RowCount]; + MLAS_FLOAT64X2 ABroadcast[RowCount]; + + // + // Clear the block accumulators. + // + + MlasLoopUnroll()(Accumulators); + + // + // Compute the output block. 
+ // + while (k >= 2) { + + MlasLoopUnroll()(AElements, a, lda); + + MlasLoopUnroll>()(AElements, ABroadcast); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B); + + MlasLoopUnroll>()(AElements, ABroadcast); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B + 8); + + a += 2; + B += 8 * 2; + k -= 2; + } + if (k > 0) { + + MlasLoopUnroll()(ABroadcast, a, lda); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B); + + a += 1; + B += 8; + k -= 1; + } + + if (CountN >= 8) { + + // + // Store the entire output block. + // + + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + + } else { + + // + // Store the partial output block. + // + + // + if (CountN >= 6) { + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + } else if (CountN >= 4) { + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + } else if (CountN >= 2) { + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + } + // + // Store the remaining unaligned columns. + // + C += (CountN & ~1); + CountN &= 1; + + if (CountN > 0) { + + MlasLoopUnroll()(Accumulators, AlphaBroadcast); + + MlasLoopUnroll>()(Accumulators, C, ldc, ZeroMode); + } + + break; + } + + C += 8; + CountN -= 8; + + } while (CountN > 0); + + return RowCount; +} + diff --git a/src/lib/s390x/FgemmKernelZVECTOR.h b/src/lib/s390x/FgemmKernelZVECTOR.h new file mode 100644 index 0000000..af77974 --- /dev/null +++ b/src/lib/s390x/FgemmKernelZVECTOR.h @@ -0,0 +1,333 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + FgemmKernelZVECTOR.h + +Abstract: + + This module implements the kernels for the single/double precision matrix/matrix + multiply operation (DGEMM/SGEMM). 
+ +--*/ + +#include "mlasi.h" +#if defined(SINGLE) +#define MLAS_FLOATTYPE MLAS_FLOAT32X4 +#define MLAS_GEMMTYPE float +#define MLAS_LOAD_FLOAT MlasLoadFloat32x4 +#define MLAS_ZERO_FLOAT MlasZeroFloat32x4 +#define MLAS_STORE_FLOAT MlasStoreFloat32x4 +#define MLAS_EXTRACT_FLOAT MlasExtractLaneFloat32x4 +#define MLAS_MUL_FLOAT MlasMultiplyFloat32x4 +#define MLAS_MULADD_FLOAT MlasMultiplyAddFloat32x4 +#define MLAS_BROADCAST_FLOAT MlasBroadcastFloat32x4 +#else +#define MLAS_FLOATTYPE MLAS_FLOAT64X2 +#define MLAS_GEMMTYPE double +#define MLAS_LOAD_FLOAT MlasLoadFloat64x2 +#define MLAS_ZERO_FLOAT MlasZeroFloat64x2 +#define MLAS_STORE_FLOAT MlasStoreFloat64x2 +#define MLAS_EXTRACT_FLOAT MlasExtractLaneFloat64x2 +#define MLAS_MUL_FLOAT MlasMultiplyFloat64x2 +#define MLAS_MULADD_FLOAT MlasMultiplyAddFloat64x2 +#define MLAS_BROADCAST_FLOAT MlasBroadcastFloat64x2 +#endif +// +// Templates to ensure that a loop is unrolled. +// + +template +struct MlasLoopUnrollStep +{ + template + MLAS_FORCEINLINE + static + void + Step( + IterationArgs&&... Arguments + ) + { + IterationType::template Iteration(Arguments...); + MlasLoopUnrollStep::template Step(Arguments...); + } +}; + +template +struct MlasLoopUnrollStep +{ + template + MLAS_FORCEINLINE + static + void + Step( + IterationArgs&&... + ) + { + // Terminate the loop. + } +}; + +template +struct MlasLoopUnroll +{ + template + MLAS_FORCEINLINE + void + operator()( + IterationArgs&&... Arguments + ) + { + MlasLoopUnrollStep::template Step(Arguments...); + } +}; + +// +// Templates used with loop unrolling to perform an action on one row of the +// output. 
+// + +struct MlasFgemmZeroAccumulators +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[RowCount][4] + ) + { + Accumulators[Row][0] = MLAS_ZERO_FLOAT(); + Accumulators[Row][1] = MLAS_ZERO_FLOAT(); + Accumulators[Row][2] = MLAS_ZERO_FLOAT(); + Accumulators[Row][3] = MLAS_ZERO_FLOAT(); + } +}; + +struct MlasFgemmLoadAElements +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE AElements[RowCount], + const MLAS_GEMMTYPE* A, + size_t lda + ) + { + AElements[Row] = MLAS_LOAD_FLOAT(A + Row * lda); + } +}; + +struct MlasFgemmBroadcastAElements +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE ABroadcast[RowCount], + const MLAS_GEMMTYPE* A, + size_t lda + ) + { + ABroadcast[Row] = MLAS_BROADCAST_FLOAT(A + Row * lda); + } +}; + +template +struct MlasFgemmSplatAElements +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE AElements[RowCount], + MLAS_FLOATTYPE ABroadcast[RowCount] + ) + { + ABroadcast[Row] = vec_splat(AElements[Row], Lane); + } +}; + +struct MlasFgemmMultiplyAddRow +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[RowCount][4], + MLAS_FLOATTYPE ABroadcast[RowCount], + MLAS_FLOATTYPE BElements[4] + ) + { + Accumulators[Row][0] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[0], Accumulators[Row][0]); + Accumulators[Row][1] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[1], Accumulators[Row][1]); + Accumulators[Row][2] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[2], Accumulators[Row][2]); + Accumulators[Row][3] = MLAS_MULADD_FLOAT(ABroadcast[Row], BElements[3], Accumulators[Row][3]); + } +}; + +template +MLAS_FORCEINLINE +void +MlasFgemmComputeBlock( + MLAS_FLOATTYPE Accumulators[RowCount][4], + MLAS_FLOATTYPE ABroadcast[RowCount], + const MLAS_GEMMTYPE* B + ) +{ + MLAS_FLOATTYPE BElements[4]; +#if defined(SINGLE) + BElements[0] = MLAS_LOAD_FLOAT(B); + BElements[1] = MLAS_LOAD_FLOAT(B + 4); + 
BElements[2] = MLAS_LOAD_FLOAT(B + 8); + BElements[3] = MLAS_LOAD_FLOAT(B + 12); +#else + BElements[0] = MLAS_LOAD_FLOAT(B); + BElements[1] = MLAS_LOAD_FLOAT(B + 2); + BElements[2] = MLAS_LOAD_FLOAT(B + 4); + BElements[3] = MLAS_LOAD_FLOAT(B + 6); +#endif + + MlasLoopUnroll()(Accumulators, ABroadcast, BElements); +} + +struct MlasFgemmMultiplyAlphaRow +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[4], + MLAS_FLOATTYPE AlphaBroadcast + ) + { + Accumulators[Index] = MLAS_MUL_FLOAT(Accumulators[Index], AlphaBroadcast); + } +}; + +struct MlasFgemmMultiplyAlphaAddRow +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[4], + MLAS_FLOATTYPE AlphaBroadcast, + const MLAS_GEMMTYPE* C + ) + { +#if defined(SINGLE) + Accumulators[Index] = MLAS_MULADD_FLOAT(Accumulators[Index], + AlphaBroadcast, MLAS_LOAD_FLOAT(C + Index * 4)); +#else + Accumulators[Index] = MLAS_MULADD_FLOAT(Accumulators[Index], + AlphaBroadcast, MLAS_LOAD_FLOAT(C + Index * 2)); +#endif + } +}; + +struct MlasFgemmStoreRow +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[4], + MLAS_GEMMTYPE* C + ) + { +#if defined(SINGLE) + MLAS_STORE_FLOAT(C + Index * 4, Accumulators[Index]); +#else + MLAS_STORE_FLOAT(C + Index * 2, Accumulators[Index]); +#endif + } +}; + +template +struct MlasFgemmStoreVector +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[RowCount][4], + MLAS_GEMMTYPE* C, + size_t ldc, + MLAS_FLOATTYPE AlphaBroadcast, + bool ZeroMode + ) + { + MLAS_GEMMTYPE* c = C + Row * ldc; + + if (ZeroMode) { + MlasLoopUnroll()(Accumulators[Row], AlphaBroadcast); + } else { + MlasLoopUnroll()(Accumulators[Row], AlphaBroadcast, c); + } + + MlasLoopUnroll()(Accumulators[Row], c); + + // + // Shift down any unaligned elements to the bottom for further processing. 
+ // + + if (VectorCount < 4) { + Accumulators[Row][0] = Accumulators[Row][VectorCount]; + } + } +}; + +struct MlasFgemmMultiplyAlphaTrailing +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[RowCount][4], + MLAS_FLOATTYPE AlphaBroadcast + ) + { + Accumulators[Row][0] = MLAS_MUL_FLOAT(Accumulators[Row][0], AlphaBroadcast); + } +}; + +template +struct MlasFgemmStoreScalar +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOATTYPE Accumulators[RowCount][4], + MLAS_GEMMTYPE* C, + size_t ldc, + bool ZeroMode + ) + { + MLAS_GEMMTYPE* c = C + Row * ldc + Lane; + MLAS_GEMMTYPE Value = MLAS_EXTRACT_FLOAT(Accumulators[Row][0]); + + if (!ZeroMode) { + Value += *c; + } + + *c = Value; + } +}; + diff --git a/src/lib/s390x/Quantize.cpp b/src/lib/s390x/Quantize.cpp new file mode 100644 index 0000000..6bb4475 --- /dev/null +++ b/src/lib/s390x/Quantize.cpp @@ -0,0 +1,300 @@ +#include +#include "mlasi.h" +#include + +// NOTE: Vector commands (e.g., vec_xst) need C-style casting to support various compiler versions. +// ONNX Runtime CI pipelines do not build with all compiler versions. + +template +void +MLASCALL +MlasQuantizeLinearKernel( + const float* Input, + OutputType* Output, + size_t N, + float Scale, + OutputType ZeroPoint + ) +/*++ + +Routine Description: + + This routine quantizes the input buffer using the supplied quantization + parameters. + +Arguments: + + Input - Supplies the input buffer. + + Output - Supplies the output buffer. + + N - Supplies the number of elements to process. + + Scale - Supplies the quantization scale. + + ZeroPoint - Supplies the quantization zero point value. + +Return Value: + + None. 
+ +--*/ +{ + constexpr int32_t MinimumValue = std::numeric_limits::lowest(); + constexpr int32_t MaximumValue = std::numeric_limits::max(); + + auto ScaleVector = vec_splats(Scale); + auto MinimumValueVector = vec_splats(float(MinimumValue)); + auto MaximumValueVector = vec_splats(float(MaximumValue)); + auto ZeroPointVector = vec_splats(float(ZeroPoint)); + + while (N >= 16) { + auto FloatVector0 = vec_xl(0, Input); + auto FloatVector1 = vec_xl(0, Input + 4); + auto FloatVector2 = vec_xl(0, Input + 8); + auto FloatVector3 = vec_xl(0, Input + 12); + + FloatVector0 = FloatVector0 / ScaleVector; + FloatVector1 = FloatVector1 / ScaleVector; + FloatVector2 = FloatVector2 / ScaleVector; + FloatVector3 = FloatVector3 / ScaleVector; + + FloatVector0 = vec_round(FloatVector0); + FloatVector1 = vec_round(FloatVector1); + FloatVector2 = vec_round(FloatVector2); + FloatVector3 = vec_round(FloatVector3); + + FloatVector0 = FloatVector0 + ZeroPointVector; + FloatVector1 = FloatVector1 + ZeroPointVector; + FloatVector2 = FloatVector2 + ZeroPointVector; + FloatVector3 = FloatVector3 + ZeroPointVector; + + FloatVector0 = vec_max(FloatVector0, MinimumValueVector); + FloatVector1 = vec_max(FloatVector1, MinimumValueVector); + FloatVector2 = vec_max(FloatVector2, MinimumValueVector); + FloatVector3 = vec_max(FloatVector3, MinimumValueVector); + + FloatVector0 = vec_min(FloatVector0, MaximumValueVector); + FloatVector1 = vec_min(FloatVector1, MaximumValueVector); + FloatVector2 = vec_min(FloatVector2, MaximumValueVector); + FloatVector3 = vec_min(FloatVector3, MaximumValueVector); + + auto IntegerVector0 = vec_signed(FloatVector0); + auto IntegerVector1 = vec_signed(FloatVector1); + auto IntegerVector2 = vec_signed(FloatVector2); + auto IntegerVector3 = vec_signed(FloatVector3); + + auto ShortVector0 = vec_pack(IntegerVector0, IntegerVector1); + auto ShortVector1 = vec_pack(IntegerVector2, IntegerVector3); + + if constexpr (std::is_same_v || std::is_same_v) { + auto CharVector = 
vec_pack(ShortVector0, ShortVector1); + vec_xst(CharVector, 0, (int8_t *)Output); + } else { + static_assert(std::is_same_v || std::is_same_v); + vec_xst(ShortVector0, 0, (int16_t *)Output); + vec_xst(ShortVector1, 0, (int16_t *)&Output[8]); + } + + Output += 16; + Input += 16; + N -= 16; + } + + for (size_t n = 0; n < N; n++) { + float FloatValue = std::nearbyintf(Input[n] / Scale) + float(ZeroPoint); + FloatValue = std::max(FloatValue, float(MinimumValue)); + FloatValue = std::min(FloatValue, float(MaximumValue)); + Output[n] = (OutputType)(int32_t)FloatValue; + } +} + +template +void +MLASCALL +MlasQuantizeLinearInt4Kernel( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +/*++ + +Routine Description: + + This routine quantizes the input buffer as int4 using the supplied quantization + parameters. + +Arguments: + + Input - Supplies the input buffer. + + Output - Supplies the output buffer. Contains packed 4-bit elements. + + N - Supplies the number of elements to process. + + Scale - Supplies the quantization scale. + + ZeroPoint - Supplies the quantization zero point value. + +Return Value: + + None. + +--*/ +{ + constexpr int32_t MinimumValue = Int4Traits::Min; + constexpr int32_t MaximumValue = Int4Traits::Max; + using UnpackedType = typename Int4Traits::UnpackedType; + + auto ScaleVector = vec_splats(Scale); + auto MinimumValueVector = vec_splats(float(MinimumValue)); + auto MaximumValueVector = vec_splats(float(MaximumValue)); + auto ZeroPointVector = vec_splats(float(ZeroPoint)); + + // Holds 16 quantized 8-bit values that will be packed into the output as packed 4-bit values. 
+ UnpackedType TmpOutput[16] = {}; + + while (N >= 16) { + auto FloatVector0 = vec_xl(0, Input); + auto FloatVector1 = vec_xl(0, Input + 4); + auto FloatVector2 = vec_xl(0, Input + 8); + auto FloatVector3 = vec_xl(0, Input + 12); + + FloatVector0 = FloatVector0 / ScaleVector; + FloatVector1 = FloatVector1 / ScaleVector; + FloatVector2 = FloatVector2 / ScaleVector; + FloatVector3 = FloatVector3 / ScaleVector; + + FloatVector0 = vec_round(FloatVector0); + FloatVector1 = vec_round(FloatVector1); + FloatVector2 = vec_round(FloatVector2); + FloatVector3 = vec_round(FloatVector3); + + FloatVector0 = FloatVector0 + ZeroPointVector; + FloatVector1 = FloatVector1 + ZeroPointVector; + FloatVector2 = FloatVector2 + ZeroPointVector; + FloatVector3 = FloatVector3 + ZeroPointVector; + + FloatVector0 = vec_max(FloatVector0, MinimumValueVector); + FloatVector1 = vec_max(FloatVector1, MinimumValueVector); + FloatVector2 = vec_max(FloatVector2, MinimumValueVector); + FloatVector3 = vec_max(FloatVector3, MinimumValueVector); + + FloatVector0 = vec_min(FloatVector0, MaximumValueVector); + FloatVector1 = vec_min(FloatVector1, MaximumValueVector); + FloatVector2 = vec_min(FloatVector2, MaximumValueVector); + FloatVector3 = vec_min(FloatVector3, MaximumValueVector); + + auto IntegerVector0 = vec_signed(FloatVector0); + auto IntegerVector1 = vec_signed(FloatVector1); + auto IntegerVector2 = vec_signed(FloatVector2); + auto IntegerVector3 = vec_signed(FloatVector3); + + auto ShortVector0 = vec_pack(IntegerVector0, IntegerVector1); + auto ShortVector1 = vec_pack(IntegerVector2, IntegerVector3); + + auto CharVector = vec_pack(ShortVector0, ShortVector1); + vec_xst(CharVector, 0, (int8_t *)(&TmpOutput[0])); + + MlasPackInt4Elements(Output++, TmpOutput[0], TmpOutput[1]); + MlasPackInt4Elements(Output++, TmpOutput[2], TmpOutput[3]); + MlasPackInt4Elements(Output++, TmpOutput[4], TmpOutput[5]); + MlasPackInt4Elements(Output++, TmpOutput[6], TmpOutput[7]); + MlasPackInt4Elements(Output++, 
TmpOutput[8], TmpOutput[9]); + MlasPackInt4Elements(Output++, TmpOutput[10], TmpOutput[11]); + MlasPackInt4Elements(Output++, TmpOutput[12], TmpOutput[13]); + MlasPackInt4Elements(Output++, TmpOutput[14], TmpOutput[15]); + + Input += 16; + N -= 16; + } + + for (size_t n = 0; n < N; n++) { + float FloatValue = std::nearbyintf(Input[n] / Scale) + static_cast(ZeroPoint); + FloatValue = std::max(FloatValue, static_cast(MinimumValue)); + FloatValue = std::min(FloatValue, static_cast(MaximumValue)); + UnpackedType IntValue = static_cast(FloatValue); + + MlasSetInt4Element(Output, n, IntValue); + } +} + +void +MLASCALL +MlasQuantizeLinearU8Kernel( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearS8Kernel( + const float* Input, + int8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearU16Kernel( + const float* Input, + uint16_t* Output, + size_t N, + float Scale, + uint16_t ZeroPoint + ) +{ + MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearS16Kernel( + const float* Input, + int16_t* Output, + size_t N, + float Scale, + int16_t ZeroPoint + ) +{ + MlasQuantizeLinearKernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearU4Kernel( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearS4Kernel( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasQuantizeLinearInt4Kernel(Input, Output, N, Scale, ZeroPoint); +} diff --git a/src/lib/s390x/QuantizeZVECTOR.cpp b/src/lib/s390x/QuantizeZVECTOR.cpp new file mode 100644 index 0000000..d6d86d0 --- 
/dev/null +++ b/src/lib/s390x/QuantizeZVECTOR.cpp @@ -0,0 +1,149 @@ +#include "mlasi.h" +#include + +template +void +MLASCALL +MlasQuantizeLinearZVECTOR( + const float* Input, + OutputType* Output, + size_t N, + float Scale, + OutputType ZeroPoint + ) +{ + // Workaround for bad GCC warning that Scale is set but not used. + MLAS_UNREFERENCED_PARAMETER(Scale); + + constexpr int32_t MinimumValue = std::numeric_limits::min(); + constexpr int32_t MaximumValue = std::numeric_limits::max(); + + auto ScaleVector = vec_splats(Scale); + auto MinimumValueVector = vec_splats(float(MinimumValue)); + auto MaximumValueVector = vec_splats(float(MaximumValue)); + auto ZeroPointVector = vec_splats(float(ZeroPoint)); + + while (N >= 16) { + auto FloatVector0 = vec_xl(0, Input); + auto FloatVector1 = vec_xl(0, Input + 4); + auto FloatVector2 = vec_xl(0, Input + 8); + auto FloatVector3 = vec_xl(0, Input + 12); + + FloatVector0 /= ScaleVector; + FloatVector1 /= ScaleVector; + FloatVector2 /= ScaleVector; + FloatVector3 /= ScaleVector; + + FloatVector0 = vec_round(FloatVector0); + FloatVector1 = vec_round(FloatVector1); + FloatVector2 = vec_round(FloatVector2); + FloatVector3 = vec_round(FloatVector3); + + FloatVector0 += ZeroPointVector; + FloatVector1 += ZeroPointVector; + FloatVector2 += ZeroPointVector; + FloatVector3 += ZeroPointVector; + + FloatVector0 = vec_max(FloatVector0, MinimumValueVector); + FloatVector1 = vec_max(FloatVector1, MinimumValueVector); + FloatVector2 = vec_max(FloatVector2, MinimumValueVector); + FloatVector3 = vec_max(FloatVector3, MinimumValueVector); + + FloatVector0 = vec_min(FloatVector0, MaximumValueVector); + FloatVector1 = vec_min(FloatVector1, MaximumValueVector); + FloatVector2 = vec_min(FloatVector2, MaximumValueVector); + FloatVector3 = vec_min(FloatVector3, MaximumValueVector); + + auto IntegerVector0 = vec_signed(FloatVector0); + auto IntegerVector1 = vec_signed(FloatVector1); + auto IntegerVector2 = vec_signed(FloatVector2); + auto IntegerVector3 
= vec_signed(FloatVector3); + + auto ShortVector0 = vec_pack(IntegerVector0, IntegerVector1); + auto ShortVector1 = vec_pack(IntegerVector2, IntegerVector3); + auto CharVector = vec_pack(ShortVector0, ShortVector1); + vec_xst(CharVector, 0, (int8_t *) Output); + + // Workaround for bad GCC warning that variable is set but not used. + MLAS_UNREFERENCED_PARAMETER(CharVector); + + Output += 16; + Input += 16; + N -= 16; + } + + while (N >= 4) { + auto FloatVector = vec_xl(0, Input); + FloatVector /= ScaleVector; + FloatVector = vec_round(FloatVector); + FloatVector += ZeroPointVector; + + FloatVector = vec_max(FloatVector, MinimumValueVector); + FloatVector = vec_min(FloatVector, MaximumValueVector); + auto IntegerVector = vec_signed(FloatVector); + + auto ShortVector = vec_pack(IntegerVector, vec_splats((int32_t) 0)); + auto CharVector = vec_pack(ShortVector, vec_splats((int16_t) 0)); + + OutputType tmp_output[sizeof(__vector float)/sizeof(OutputType)]; + vec_xst(CharVector, 0, (int8_t *) tmp_output); + memcpy(Output, tmp_output, N); + + // Workaround for bad GCC warning that variable is set but not used. 
+ MLAS_UNREFERENCED_PARAMETER(CharVector); + + Output += 4; + Input += 4; + N -= 4; + } + + if (N > 0) { + float tmp_input[sizeof(__vector float) / sizeof(float)] = {}; + memcpy(tmp_input, Input, 4*N); + auto FloatVector = vec_xl(0, &(tmp_input[0])); + + FloatVector /= ScaleVector; + FloatVector = vec_round(FloatVector); + FloatVector += ZeroPointVector; + + FloatVector = vec_max(FloatVector, MinimumValueVector); + FloatVector = vec_min(FloatVector, MaximumValueVector); + auto IntegerVector = vec_signed(FloatVector); + + auto ShortVector = vec_pack(IntegerVector, vec_splats((int32_t) 0)); + auto CharVector = vec_pack(ShortVector, vec_splats((int16_t) 0)); + + OutputType tmp_output[sizeof(__vector float)/sizeof(OutputType)]; + vec_xst(CharVector, 0, (int8_t *) tmp_output); + memcpy(Output, tmp_output, N); + + // Workaround for bad GCC warning that variable is set but not used. + MLAS_UNREFERENCED_PARAMETER(CharVector); + } +} + +void +MLASCALL +MlasQuantizeLinearU8KernelZVECTOR( + const float* Input, + uint8_t* Output, + size_t N, + float Scale, + uint8_t ZeroPoint + ) +{ + MlasQuantizeLinearZVECTOR(Input, Output, N, Scale, ZeroPoint); +} + +void +MLASCALL +MlasQuantizeLinearS8KernelZVECTOR( + const float* Input, + int8_t* Output, + size_t N, + float Scale, + int8_t ZeroPoint + ) +{ + MlasQuantizeLinearZVECTOR(Input, Output, N, Scale, ZeroPoint); +} diff --git a/src/lib/s390x/SgemmKernel.cpp b/src/lib/s390x/SgemmKernel.cpp new file mode 100644 index 0000000..e7d8cde --- /dev/null +++ b/src/lib/s390x/SgemmKernel.cpp @@ -0,0 +1,87 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SgemmKernel.cpp + +Abstract: + + This module implements the kernels for the single precision matrix/matrix + multiply operation (SGEMM). 
+ +--*/ +#include "SgemmKernelZVECTOR.h" + +size_t +MLASCALL +MlasSgemmKernel( + const float* A, + const float* B, + float* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + float alpha, + bool ZeroMode + ) +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A - Supplies the address of matrix A. + + B - Supplies the address of matrix B. The matrix data has been packed using + MlasSgemmCopyPackB or MlasSgemmTransposePackB. + + C - Supplies the address of matrix C. + + CountK - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. + + CountM - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN - Supplies the number of columns from matrix B and matrix C to + iterate over. + + lda - Supplies the first dimension of matrix A. + + ldc - Supplies the first dimension of matrix C. + + alpha - Supplies the scalar multiplier (see SGEMM definition). + + ZeroMode - Supplies true if the output matrix must be zero initialized, + else false if the output matrix is accumulated into. + +Return Value: + + Returns the number of rows handled. 
+ +--*/ +{ + size_t RowsHandled; + + MLAS_FLOAT32X4 AlphaBroadcast = MlasBroadcastFloat32x4(alpha); + + if (CountM >= 4) { + RowsHandled = MlasSgemmProcessCount<4>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } else if (CountM >= 2) { + RowsHandled = MlasSgemmProcessCount<2>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } else { + RowsHandled = MlasSgemmProcessCount<1>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } + + return RowsHandled; +} diff --git a/src/lib/s390x/SgemmKernelZVECTOR.cpp b/src/lib/s390x/SgemmKernelZVECTOR.cpp new file mode 100644 index 0000000..e87913a --- /dev/null +++ b/src/lib/s390x/SgemmKernelZVECTOR.cpp @@ -0,0 +1,451 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SgemmKernelZVECTOR.cpp + +Abstract: + + This module implements the kernels for the single precision matrix/matrix + multiply operation (SGEMM). + +--*/ + +#include "SgemmKernelZVECTOR.h" + +#include + +struct MlasSgemmBroadcastAElementsZVECTOR +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOAT32X4 ABroadcast[RowCount], + const float* A, + size_t lda + ) + { + ABroadcast[0][Row] = A [Row * lda]; + } +}; + +template +MLAS_FORCEINLINE +void +MlasSgemmComputeAElements( + MLAS_FLOAT32X4 AElements[RowCount], + MLAS_FLOAT32X4 ABroadcast[RowCount] + ) +{ + const __vector unsigned char mask0 = { 0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23 }; + const __vector unsigned char mask3 = { 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31 }; + const __vector unsigned char mask_even = { 0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 24, 25, 26, 27 }; + const __vector unsigned char mask_odd = { 4, 5, 6, 7, 20, 21, 22, 23, 12, 13, 14, 15, 28, 29, 30, 31 }; + + __vector float a1,a2; + + a1 = vec_perm(AElements[0], AElements[1], mask_even); + a2 = vec_perm(AElements[2], AElements[3], mask_even); + ABroadcast[0] = vec_perm(a1, a2, 
mask0); + ABroadcast[2] = vec_perm(a1, a2, mask3); + a1 = vec_perm(AElements[0], AElements[1], mask_odd); + a2 = vec_perm(AElements[2], AElements[3], mask_odd); + ABroadcast[1] = vec_perm(a1, a2, mask0); + ABroadcast[3] = vec_perm(a1, a2, mask3); +} +template +MLAS_FORCEINLINE +void +MlasSgemmComputeBlockZVECTOR( + MLAS_FLOAT32X4 acc[32], + MLAS_FLOAT32X4 ABroadcast, + MLAS_FLOAT32X4 A2Broadcast, + const float* B, + size_t CountM + ) +{ + + MLAS_FLOAT32X4 AElements[8]; + + AElements[0] = vec_splats(ABroadcast[0]); + AElements[1] = vec_splats(ABroadcast[1]); + AElements[2] = vec_splats(ABroadcast[2]); + AElements[3] = vec_splats(ABroadcast[3]); + + if (CountM == 8) { + AElements[4] = vec_splats(A2Broadcast[0]); + AElements[5] = vec_splats(A2Broadcast[1]); + AElements[6] = vec_splats(A2Broadcast[2]); + AElements[7] = vec_splats(A2Broadcast[3]); + } + + MLAS_FLOAT32X4 BElements[4]; + + BElements[0] = MlasLoadFloat32x4(B); + BElements[1] = MlasLoadFloat32x4(B + 4); + BElements[2] = MlasLoadFloat32x4(B + 8); + BElements[3] = MlasLoadFloat32x4(B + 12); + + acc[0] = __builtin_s390_vfmasb(AElements[0], BElements[0], acc[0]); + acc[1] = __builtin_s390_vfmasb(AElements[1], BElements[0], acc[1]); + acc[2] = __builtin_s390_vfmasb(AElements[2], BElements[0], acc[2]); + acc[3] = __builtin_s390_vfmasb(AElements[3], BElements[0], acc[3]); + + acc[4] = __builtin_s390_vfmasb(AElements[0], BElements[1], acc[4]); + acc[5] = __builtin_s390_vfmasb(AElements[1], BElements[1], acc[5]); + acc[6] = __builtin_s390_vfmasb(AElements[2], BElements[1], acc[6]); + acc[7] = __builtin_s390_vfmasb(AElements[3], BElements[1], acc[7]); + + acc[8] = __builtin_s390_vfmasb(AElements[0], BElements[2], acc[8]); + acc[9] = __builtin_s390_vfmasb(AElements[1], BElements[2], acc[9]); + acc[10] = __builtin_s390_vfmasb(AElements[2], BElements[2], acc[10]); + acc[11] = __builtin_s390_vfmasb(AElements[3], BElements[2], acc[11]); + + acc[12] = __builtin_s390_vfmasb(AElements[0], BElements[3], acc[12]); + acc[13] = 
__builtin_s390_vfmasb(AElements[1], BElements[3], acc[13]); + acc[14] = __builtin_s390_vfmasb(AElements[2], BElements[3], acc[14]); + acc[15] = __builtin_s390_vfmasb(AElements[3], BElements[3], acc[15]); + + if (CountM == 8) { + acc[16] = __builtin_s390_vfmasb(AElements[4], BElements[0], acc[16]); + acc[17] = __builtin_s390_vfmasb(AElements[5], BElements[0], acc[17]); + acc[18] = __builtin_s390_vfmasb(AElements[6], BElements[0], acc[18]); + acc[19] = __builtin_s390_vfmasb(AElements[7], BElements[0], acc[19]); + + acc[20] = __builtin_s390_vfmasb(AElements[4], BElements[1], acc[20]); + acc[21] = __builtin_s390_vfmasb(AElements[5], BElements[1], acc[21]); + acc[22] = __builtin_s390_vfmasb(AElements[6], BElements[1], acc[22]); + acc[23] = __builtin_s390_vfmasb(AElements[7], BElements[1], acc[23]); + + acc[24] = __builtin_s390_vfmasb(AElements[4], BElements[2], acc[24]); + acc[25] = __builtin_s390_vfmasb(AElements[5], BElements[2], acc[25]); + acc[26] = __builtin_s390_vfmasb(AElements[6], BElements[2], acc[26]); + acc[27] = __builtin_s390_vfmasb(AElements[7], BElements[2], acc[27]); + + acc[28] = __builtin_s390_vfmasb(AElements[4], BElements[3], acc[28]); + acc[29] = __builtin_s390_vfmasb(AElements[5], BElements[3], acc[29]); + acc[30] = __builtin_s390_vfmasb(AElements[6], BElements[3], acc[30]); + acc[31] = __builtin_s390_vfmasb(AElements[7], BElements[3], acc[31]); + } +} +template +struct MlasSgemmStoreVectorZVECTOR +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOAT32X4 Result[4], + float* C, + size_t ldc, + MLAS_FLOAT32X4 AlphaBroadcast, + bool ZeroMode + ) + { + MLAS_FLOAT32X4 *rowC; + if (ZeroMode) { + rowC = reinterpret_cast(&C[Row * ldc + VectorCount]); + rowC[0] = Result[Row] * AlphaBroadcast; + } else { + rowC = reinterpret_cast(&C[Row * ldc + VectorCount]); + rowC[0] += Result[Row] * AlphaBroadcast; + } + } +}; + +struct MlasSgemmMultiplyAlphaTrailingZVECTOR +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOAT32X4 
Accumulators[RowCount], + MLAS_FLOAT32X4 AlphaBroadcast + ) + { + Accumulators[Row] = MlasMultiplyFloat32x4(Accumulators[Row], AlphaBroadcast); + } +}; +template +struct MlasSgemmStoreScalarZVECTOR +{ + template + MLAS_FORCEINLINE + static + void + Iteration( + MLAS_FLOAT32X4 Accumulators[RowCount], + float* C, + size_t ldc, + bool ZeroMode + ) + { + float* c = C + Row * ldc + Lane; + float Value = Accumulators[Row][Lane]; + if (!ZeroMode) { + Value += *c; + } + + *c = Value; + } +}; + +template +MLAS_FORCEINLINE +size_t +MlasSgemmZVECTORProcessCount( + const float* A, + const float* B, + float* C, + size_t CountM, + size_t CountK, + size_t CountN, + size_t lda, + size_t ldc, + MLAS_FLOAT32X4 AlphaBroadcast, + bool ZeroMode + ) +{ + do { + + const float* a = A; + size_t k = CountK; + + MLAS_FLOAT32X4 AElements[RowCount]; + MLAS_FLOAT32X4 ABroadcast[RowCount] = { 0 }; + MLAS_FLOAT32X4 A2Broadcast[RowCount] = { 0 }; + MLAS_FLOAT32X4 acc[32] = { 0 }; + MLAS_FLOAT32X4 Accumulators[2][RowCount] = {{0}}; + + // + // Compute the output block. + // + while (k >= 4) { + + MlasLoopUnroll()(AElements, a, lda); + MlasSgemmComputeAElements(AElements, ABroadcast); + if (CountM == 8) { + MlasLoopUnroll()(AElements, a + ( lda * 4), lda); + MlasSgemmComputeAElements(AElements, A2Broadcast); + } + MlasSgemmComputeBlockZVECTOR(&acc[0], ABroadcast[0], A2Broadcast[0], B, CountM); + MlasSgemmComputeBlockZVECTOR(&acc[0], ABroadcast[1], A2Broadcast[1], B+16, CountM); + MlasSgemmComputeBlockZVECTOR(&acc[0], ABroadcast[2], A2Broadcast[2], B+32, CountM); + MlasSgemmComputeBlockZVECTOR(&acc[0], ABroadcast[3], A2Broadcast[3], B+48, CountM); + B += 16 * 4; + a += 4; + k -= 4; + } + + while (k > 0) { + MlasLoopUnroll()(ABroadcast, a, lda); + if (CountM == 8) { + MlasLoopUnroll()(A2Broadcast, a + (lda * 4), lda); + } + MlasSgemmComputeBlockZVECTOR(&acc[0], ABroadcast[0], A2Broadcast[0], B, CountM); + a += 1; + B += 16; + k -= 1; + } + if (CountN >= 16) { + + // + // Store the entire output block. 
+ // + MlasLoopUnroll>()(acc, C, ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 4, C, ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 8, C, ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 12, C, ldc, AlphaBroadcast, ZeroMode); + if (CountM == 8) { + MlasLoopUnroll>()(acc + 16, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 20, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 24, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 28, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + } + } else { + + // + // Store the partial output block. + // + + if (CountN >= 12) { + MlasLoopUnroll>()(acc, C, ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 4, C, ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 8, C, ldc, AlphaBroadcast, ZeroMode); + if (CountM == 8) { + MlasLoopUnroll>()(acc + 16, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 20, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 24, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + if (CountN - 12 > 0) { + for (size_t i = 0; i < 4; ++i) { + Accumulators[1][i] = acc[i + 28]; + } + } + } + if (CountN - 12 > 0) { + for (size_t i = 0; i < 4; ++i) { + Accumulators[0][i] = acc[i + 12]; + } + } + } else if (CountN >= 8) { + MlasLoopUnroll>()(acc, C, ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 4, C, ldc, AlphaBroadcast, ZeroMode); + if (CountM == 8) { + MlasLoopUnroll>()(acc + 16, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + MlasLoopUnroll>()(acc + 20, C + (ldc*4), ldc, AlphaBroadcast, ZeroMode); + if (CountN - 8 > 0) { + for (size_t i = 0; i < 4; ++i) { + Accumulators[1][i] = acc[i + 24]; + } + } + } + if (CountN - 8 > 0) { + for (size_t i = 0; i < 4; ++i) { + Accumulators[0][i] = acc[i + 8]; + } + } + } else if (CountN >= 4) { + MlasLoopUnroll>()(acc, C, ldc, AlphaBroadcast, ZeroMode); + if (CountM == 8) { + MlasLoopUnroll>()(acc + 16, C + (ldc*4), 
ldc, AlphaBroadcast, ZeroMode); + if (CountN - 4 > 0) { + for (size_t i = 0; i < 4; ++i) { + Accumulators[1][i] = acc[i + 20]; + } + } + } + if (CountN - 4 > 0) { + for (size_t i = 0; i < 4; ++i) { + Accumulators[0][i] = acc[i + 4]; + } + } + } else { + for (size_t i = 0; i < 4; ++i) { + Accumulators[0][i] = acc[i]; + } + + if (CountM == 8) { + for (size_t i = 0; i < 4; ++i) { + Accumulators[1][i] = acc[i + 16]; + } + } + } + + // + // Store the remaining unaligned columns. + // + + C += (CountN & ~3); + CountN &= 3; + + if (CountN > 0) { + + MlasLoopUnroll()(Accumulators[0], AlphaBroadcast); + MlasLoopUnroll>()(Accumulators[0], C, ldc, ZeroMode); + if (CountM == 8) { + MlasLoopUnroll()(Accumulators[1], AlphaBroadcast); + MlasLoopUnroll>()(Accumulators[1], C + (ldc*4), ldc, ZeroMode); + } + if (CountN >= 2) { + MlasLoopUnroll>()(Accumulators[0], C, ldc, ZeroMode); + if (CountM == 8) { + MlasLoopUnroll>()(Accumulators[1], C + (ldc*4), ldc, ZeroMode); + } + } + if (CountN >= 3) { + MlasLoopUnroll>()(Accumulators[0], C, ldc, ZeroMode); + if (CountM == 8) { + MlasLoopUnroll>()(Accumulators[1], C + (ldc*4), ldc, ZeroMode); + } + } + } + + break; + } + + C += 16; + CountN -= 16; + + } while (CountN > 0); + + return CountM; +} + +size_t +MLASCALL +MlasSgemmKernelZVECTOR( + const float* A, + const float* B, + float* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + float alpha, + bool ZeroMode + ) +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A - Supplies the address of matrix A. + + B - Supplies the address of matrix B. The matrix data has been packed using + MlasSgemmCopyPackB or MlasSgemmTransposePackB. + + C - Supplies the address of matrix C. + + CountK - Supplies the number of columns from matrix A and the number of rows + from matrix B to iterate over. 
+ + CountM - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN - Supplies the number of columns from matrix B and matrix C to + iterate over. + + lda - Supplies the first dimension of matrix A. + + ldc - Supplies the first dimension of matrix C. + + alpha - Supplies the scalar multiplier (see SGEMM definition). + + ZeroMode - Supplies true if the output matrix must be zero initialized, + else false if the output matrix is accumulated into. + +Return Value: + + Returns the number of rows handled. + +--*/ +{ + size_t RowsHandled; + MLAS_FLOAT32X4 AlphaBroadcast = MlasBroadcastFloat32x4(alpha); + + if (CountM >= 8) { + RowsHandled = MlasSgemmZVECTORProcessCount<4>(A, B, C, 8 ,CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } else if (CountM >= 4) { + RowsHandled = MlasSgemmZVECTORProcessCount<4>(A, B, C, 4, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } else if (CountM >= 2) { + RowsHandled = MlasSgemmProcessCount<2>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } else { + RowsHandled = MlasSgemmProcessCount<1>(A, B, C, CountK, CountN, lda, ldc, AlphaBroadcast, ZeroMode); + } + + return RowsHandled; +} diff --git a/src/lib/s390x/SgemmKernelZVECTOR.h b/src/lib/s390x/SgemmKernelZVECTOR.h new file mode 100644 index 0000000..83c66eb --- /dev/null +++ b/src/lib/s390x/SgemmKernelZVECTOR.h @@ -0,0 +1,139 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + SgemmKernelZVECTOR.h + +Abstract: + + This module implements the kernels for the single precision matrix/matrix + multiply operation (SGEMM). 
+ +--*/ + +#include "FgemmKernelZVECTOR.h" + +template +MLAS_FORCEINLINE +size_t +MlasSgemmProcessCount( + const float* A, + const float* B, + float* C, + size_t CountK, + size_t CountN, + size_t lda, + size_t ldc, + MLAS_FLOAT32X4 AlphaBroadcast, + bool ZeroMode + ) +{ + do { + + const float* a = A; + size_t k = CountK; + + MLAS_FLOAT32X4 Accumulators[RowCount][4]; + MLAS_FLOAT32X4 AElements[RowCount]; + MLAS_FLOAT32X4 ABroadcast[RowCount]; + + // + // Clear the block accumulators. + // + + MlasLoopUnroll()(Accumulators); + + // + // Compute the output block. + // + + while (k >= 4) { + + MlasLoopUnroll()(AElements, a, lda); + + MlasLoopUnroll>()(AElements, ABroadcast); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B); + + MlasLoopUnroll>()(AElements, ABroadcast); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B + 16); + + MlasLoopUnroll>()(AElements, ABroadcast); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B + 32); + + MlasLoopUnroll>()(AElements, ABroadcast); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B + 48); + + a += 4; + B += 16 * 4; + k -= 4; + } + + while (k > 0) { + + MlasLoopUnroll()(ABroadcast, a, lda); + MlasFgemmComputeBlock(Accumulators, ABroadcast, B); + + a += 1; + B += 16; + k -= 1; + } + + if (CountN >= 16) { + + // + // Store the entire output block. + // + + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + + } else { + + // + // Store the partial output block. + // + + if (CountN >= 12) { + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + } else if (CountN >= 8) { + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + } else if (CountN >= 4) { + MlasLoopUnroll>()(Accumulators, C, ldc, AlphaBroadcast, ZeroMode); + } + + // + // Store the remaining unaligned columns. 
+ // + + C += (CountN & ~3); + CountN &= 3; + + if (CountN > 0) { + + MlasLoopUnroll()(Accumulators, AlphaBroadcast); + + MlasLoopUnroll>()(Accumulators, C, ldc, ZeroMode); + + if (CountN >= 2) { + MlasLoopUnroll>()(Accumulators, C, ldc, ZeroMode); + } + + if (CountN >= 3) { + MlasLoopUnroll>()(Accumulators, C, ldc, ZeroMode); + } + } + + break; + } + + C += 16; + CountN -= 16; + + } while (CountN > 0); + + return RowCount; +} + diff --git a/src/lib/s390x/qgemm_kernel_zvector.cpp b/src/lib/s390x/qgemm_kernel_zvector.cpp new file mode 100644 index 0000000..92aa72b --- /dev/null +++ b/src/lib/s390x/qgemm_kernel_zvector.cpp @@ -0,0 +1,1409 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + qgemm_kernel_zvector.cpp + +Abstract: + + This module implements QGEMM kernel for S390X. + +--*/ + +#include "mlasi.h" +#include "qgemm.h" +#include + +struct MLAS_GEMM_QUANT_KERNEL_ZVECTOR +{ + typedef int8_t PackedAType; + typedef uint8_t PackedBType; + typedef int8_t OffsetAType; + typedef uint8_t OffsetBType; + static constexpr size_t PackedK = 4; + static constexpr MLAS_GEMM_QUANT_STRIDES Strides{ 16, 256, 384 }; + static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{ 16, 128, 128 }; +}; + +constexpr size_t MLAS_GEMM_QUANT_KERNEL_ZVECTOR::PackedK; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_QUANT_KERNEL_ZVECTOR::Strides; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_QUANT_KERNEL_ZVECTOR::PackedStrides; + +using vector_int = __attribute__((vector_size(16))) int; + +template +MLAS_FORCEINLINE +static +vector_int vec_sum4s_impl(Vtype value) +{ + const __vector unsigned char mask_interleave = { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 }; + + __vector signed char signed_value = (__vector signed char) vec_perm(value, value, mask_interleave); + + auto tmp1 = vec_unpackh(vec_unpackh(signed_value)); + auto tmp2 = vec_unpackl(vec_unpackh(signed_value)); + auto tmp3 = 
vec_unpackh(vec_unpackl(signed_value)); + auto tmp4 = vec_unpackl(vec_unpackl(signed_value)); + + return (__vector int) (tmp1 + tmp2 + tmp3 + tmp4); +} + +#define INC_BUFFER(cnt) \ + ColumnSumBuffer += cnt; \ + if (ZeroPointB != nullptr) { \ + ZeroPointB += cnt; \ + } +template<> +MLAS_FORCEINLINE constexpr +int32_t +MlasGemmQuantFixupZeroPointA( + int32_t ZeroPointA, + bool AIsSigned + ) +{ + if (!AIsSigned) { + ZeroPointA = MLAS_GEMM_QUANT_KERNEL_ZVECTOR::OffsetAType(ZeroPointA ^ 0x80); + } + return ZeroPointA; +} + +template<> +MLAS_FORCEINLINE +int32_t +MlasGemmQuantFixupZeroPointB( + int32_t ZeroPointB, + bool BIsSigned + ) +{ + if (BIsSigned) { + ZeroPointB = MLAS_GEMM_QUANT_KERNEL_ZVECTOR::OffsetBType(ZeroPointB ^ 0x80); + } + return ZeroPointB; + +} + +template +void +MlasGemmQuantCopyPackA8x8( + MLAS_GEMM_QUANT_KERNEL_ZVECTOR::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer + ) +{ + constexpr uint8_t Flip = (AIsSigned ? 0 : 0x80); + Vtype vmask = reinterpret_cast(vec_splats(Flip)); + + const __vector unsigned char mask0 = { 0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23 }; + const __vector unsigned char mask3 = { 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31 }; + const __vector unsigned char mask_even = { 0, 1, 2, 3, 16, 17, 18, 19, 8, 9, 10, 11, 24, 25, 26, 27 }; + const __vector unsigned char mask_odd = { 4, 5, 6, 7, 20, 21, 22, 23, 12, 13, 14, 15, 28, 29, 30, 31 }; + + // Process eight rows of matrix A in a loop. + // + // The buffer is packed as a series of 4x4 byte vectors to help + // in getting into MMA loop. + // + // Unsigned buffers are converted to signed buffers in order to + // share a common kernel. + // This pattern is repeated (CountK / 4) times. + // + // If CountK is not aligned to a multiple of four, then the vector is padded + // with zeroes. 
+ // + while (CountM >= 8) { + const uint8_t *a = A; + __vector int vsum = { 0 }; + __vector int vsum2 = { 0 }; + size_t y = CountK; + while (y >= 16) { + Vtype a1 = *reinterpret_cast(&a[0]); + Vtype a2 = *reinterpret_cast(&a[lda]); + Vtype a3 = *reinterpret_cast(&a[lda * 2]); + Vtype a4 = *reinterpret_cast(&a[lda * 3]); + Vtype vx = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2), + mask_even)); + Vtype vx1 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4), + mask_even)); + Vtype vx2 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2), + mask_odd)); + Vtype vx3 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4), + mask_odd)); + Vtype vx4 = vec_perm(vx, vx1, mask0); + Vtype vx5 = vec_perm(vx2, vx3, mask0); + Vtype vx6 = vec_perm(vx, vx1, mask3); + Vtype vx7 = vec_perm(vx2, vx3, mask3); + a1 = *reinterpret_cast(&a[lda*4]); + a2 = *reinterpret_cast(&a[lda*5]); + a3 = *reinterpret_cast(&a[lda*6]); + a4 = *reinterpret_cast(&a[lda*7]); + vx = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2), + mask_even)); + vx1 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4), + mask_even)); + vx2 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2), + mask_odd)); + vx3 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4), + mask_odd)); + Vtype vx8 = vec_perm(vx, vx1, mask0); + Vtype vx9 = vec_perm(vx2, vx3, mask0); + Vtype vx10 = vec_perm(vx, vx1, mask3); + Vtype vx11 = vec_perm(vx2, vx3, mask3); + Vtype vxx = AIsSigned ? vx4 : vx4 - vmask; + vsum += vec_sum4s_impl(vxx); + *reinterpret_cast(&D[0]) = vxx; + vxx = AIsSigned ? 
vx5 : vx5 - vmask; + vsum += vec_sum4s_impl(vxx); + *reinterpret_cast(&D[16]) = vxx; + vxx = AIsSigned ? vx6 : vx6 - vmask; + vsum += vec_sum4s_impl(vxx); + *reinterpret_cast(&D[32]) = vxx; + vxx = AIsSigned ? vx7 : vx7 - vmask; + vsum += vec_sum4s_impl(vxx); + *reinterpret_cast(&D[48]) = vxx; + vxx = AIsSigned ? vx8 : vx8 - vmask; + *reinterpret_cast(&D[64]) = vxx; + vsum2 += vec_sum4s_impl(vxx); + vxx = AIsSigned ? vx9 : vx9 - vmask; + *reinterpret_cast(&D[80]) = vxx; + vsum2 += vec_sum4s_impl(vxx); + vxx = AIsSigned ? vx10 : vx10 - vmask; + *reinterpret_cast(&D[96]) = vxx; + vsum2 += vec_sum4s_impl(vxx); + vxx = AIsSigned ? vx11 : vx11 - vmask; + *reinterpret_cast(&D[112]) = vxx; + vsum2 += vec_sum4s_impl(vxx); + D += 16 * 8; + a += 16; + y -= 16; + } + size_t yval = y; + while (y >= 4) + { + int a1 = *reinterpret_cast(&a[0]); + int a2 = *reinterpret_cast(&a[lda]); + int a3 = *reinterpret_cast(&a[lda*2]); + int a4 = *reinterpret_cast(&a[lda*3]); + __vector int vx1 = { a1, a2, a3, a4}; + Vtype vx = AIsSigned ? reinterpret_cast(vx1) : reinterpret_cast(vx1) - vmask; + vsum += vec_sum4s_impl(vx); + *reinterpret_cast(&D[0]) = vx; + a1 = *reinterpret_cast(&a[lda*4]); + a2 = *reinterpret_cast(&a[lda*5]); + a3 = *reinterpret_cast(&a[lda*6]); + a4 = *reinterpret_cast(&a[lda*7]); + __vector int vx2 = { a1, a2, a3, a4}; + vx = AIsSigned ? 
reinterpret_cast(vx2) : reinterpret_cast(vx2) - vmask; + vsum2 += vec_sum4s_impl(vx); + if (CountK & 3) { + if (yval >= 12) { + *reinterpret_cast(&D[64]) = vx; + } else if (yval >= 8) { + *reinterpret_cast(&D[48]) = vx; + } else { + *reinterpret_cast(&D[32]) = vx; + } + } else { + if (yval >= 12) { + *reinterpret_cast(&D[48]) = vx; + } else if (yval >= 8) { + *reinterpret_cast(&D[32]) = vx; + } else { + *reinterpret_cast(&D[16]) = vx; + } + } + D += 16; + a += 4; + y -= 4; + } + if (yval >= 12) { + if (!(CountK & 3)) { + D += 48; + } + } else if (yval >= 8) { + if (!(CountK & 3)) { + D += 32; + } + } else if (yval >= 4) { + if (!(CountK & 3)) { + D += 16; + } + } + if (y >= 1) + { + Vtype a1 = vmask; + Vtype a2 = vmask; + Vtype a3 = vmask; + Vtype a4 = vmask; + a1[0] = a[0]; + a2[0] = a[lda]; + a3[0] = a[lda * 2]; + a4[0] = a[lda * 3]; + if (y >= 2) { + a1[1] = a[1]; + a2[1] = a[lda + 1]; + a3[1] = a[lda * 2 + 1]; + a4[1] = a[lda * 3 + 1]; + } + if (y >= 3) { + a1[2] = a[2]; + a2[2] = a[lda + 2]; + a3[2] = a[lda * 2 + 2]; + a4[2] = a[lda * 3 + 2]; + } + Vtype vx = reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2), + mask_even)); + Vtype vx1 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4), + mask_even)); + Vtype vx2 = vec_perm(vx, vx1, mask0); + Vtype vx3 = AIsSigned ? 
vx2 : vx2 - vmask; + vsum += vec_sum4s_impl(vx3); + + *reinterpret_cast(&D[0]) = vx3; + a1 = vmask; + a2 = vmask; + a3 = vmask; + a4 = vmask; + a1[0] = a[lda * 4]; + a2[0] = a[lda * 5]; + a3[0] = a[lda * 6]; + a4[0] = a[lda * 7]; + if (y >= 2) { + a1[1] = a[lda * 4 + 1]; + a2[1] = a[lda * 5 + 1]; + a3[1] = a[lda * 6 + 1]; + a4[1] = a[lda * 7 + 1]; + } + if (y >= 3) { + a1[2] = a[lda * 4 + 2]; + a2[2] = a[lda * 5 + 2]; + a3[2] = a[lda * 6 + 2]; + a4[2] = a[lda * 7 + 2]; + } + vx = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2), + mask_even)); + vx1 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4), + mask_even)); + vx2 = vec_perm(vx, vx1, mask0); + vx3 = AIsSigned ? vx2 : vx2 - vmask; + vsum2 += vec_sum4s_impl(vx3); + if (CountK % 16 >= 12) { + *reinterpret_cast(&D[64]) = vx3; + D += 80; + } else if (CountK % 16 >= 8) { + *reinterpret_cast(&D[48]) = vx3; + D += 64; + } else if (CountK % 16 >= 4) { + *reinterpret_cast(&D[32]) = vx3; + D += 48; + } else { + *reinterpret_cast(&D[16]) = vx3; + D += 16 * 2; + } + a += 16; + } + A += lda * 8; + + vec_xst(vsum, 0, &(RowSumBuffer[0])); + vec_xst(vsum2, 16, &(RowSumBuffer[0])); + + RowSumBuffer += 8; + CountM -= 8; + } + + // Process four rows of matrix A in a loop. 
+ // + if (CountM >= 4) + { + const uint8_t *a = A; + __vector int vsum = { 0 }; + size_t y = CountK; + + while (y >= 16) + { + Vtype a1 = *reinterpret_cast(&a[0]); + Vtype a2 = *reinterpret_cast(&a[lda]); + Vtype a3 = *reinterpret_cast(&a[lda * 2]); + Vtype a4 = *reinterpret_cast(&a[lda * 3]); + Vtype vx = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2), + mask_even)); + Vtype vx1 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4), + mask_even)); + Vtype vx2 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2), + mask_odd)); + Vtype vx3 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4), + mask_odd)); + Vtype vx4 = vec_perm(vx, vx1, mask0); + Vtype vx5 = vec_perm(vx2, vx3, mask0); + Vtype vx6 = vec_perm(vx, vx1, mask3); + Vtype vx7 = vec_perm(vx2, vx3, mask3); + Vtype vx0 = AIsSigned ? vx4 : vx4 - vmask; + *reinterpret_cast(&D[0]) = vx0; + vsum += vec_sum4s_impl(vx0); + vx0 = AIsSigned ? vx5 : vx5 - vmask; + *reinterpret_cast(&D[16]) = vx0; + vsum += vec_sum4s_impl(vx0); + vx0 = AIsSigned ? vx6 : vx6 - vmask; + *reinterpret_cast(&D[32]) = vx0; + vsum += vec_sum4s_impl(vx0); + vx0 = AIsSigned ? vx7 : vx7 - vmask; + *reinterpret_cast(&D[48]) = vx0; + vsum += vec_sum4s_impl(vx0); + D += 16 * 4; + a += 16; + y -= 16; + } + while (y >= 4) + { + int a1 = *reinterpret_cast(&a[0]); + int a2 = *reinterpret_cast(&a[lda]); + int a3 = *reinterpret_cast(&a[lda*2]); + int a4 = *reinterpret_cast(&a[lda*3]); + __vector int vx1 = { a1, a2, a3, a4}; + Vtype vx = AIsSigned ? 
reinterpret_cast(vx1) : reinterpret_cast(vx1) - vmask; + *reinterpret_cast(&D[0]) = vx; + vsum += vec_sum4s_impl(vx); + D += 16; + a += 4; + y -= 4; + } + if (y >= 1) + { + Vtype vx = vmask; + vx[0] = a[0]; + vx[4] = a[lda]; + vx[8] = a[lda * 2]; + vx[12] = a[lda * 3]; + if (y >= 2) { + vx[1] = a[1]; + vx[5] = a[lda + 1]; + vx[9] = a[lda * 2 + 1]; + vx[13] = a[lda * 3 + 1]; + } + if (y >= 3) { + vx[2] = a[2]; + vx[6] = a[lda + 2]; + vx[10] = a[lda * 2 + 2]; + vx[14] = a[lda * 3 + 2]; + } + Vtype vx1 = AIsSigned ? vx : vx - vmask; + *reinterpret_cast(&D[0]) = vx1; + vsum += vec_sum4s_impl(vx1); + D += 16; + a += 16; + } + A += lda * 4; + + vec_xst(vsum, 0, &(RowSumBuffer[0])); + + RowSumBuffer += 4; + CountM -= 4; + } + + // Process remaining rows of matrix A in a loop. + // + if (CountM <= 3 && CountM > 0) { + const uint8_t *a = A; + size_t y = CountK; + __vector int vsum = { 0 }; + + while (y >= 16) { + Vtype a4 = vmask; + Vtype a2 = vmask; + Vtype a3 = vmask; + Vtype a1 = *reinterpret_cast(&a[0]); + if (CountM == 3) { + a3 = *reinterpret_cast(&a[lda * 2]); + } + if (CountM >= 2) { + a2 = *reinterpret_cast(&a[lda]); + } + Vtype vx = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2), + mask_even)); + Vtype vx1 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4), + mask_even)); + Vtype vx2 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a1), + reinterpret_cast<__vector int>(a2), + mask_odd)); + Vtype vx3 = + reinterpret_cast(vec_perm(reinterpret_cast<__vector int>(a3), + reinterpret_cast<__vector int>(a4), + mask_odd)); + Vtype vx4 = vec_perm(vx, vx1, mask0); + Vtype vx5 = vec_perm(vx2, vx3, mask0); + Vtype vx6 = vec_perm(vx, vx1, mask3); + Vtype vx7 = vec_perm(vx2, vx3, mask3); + Vtype vx0 = AIsSigned ? vx4 : vx4 - vmask; + *reinterpret_cast(&D[0]) = vx0; + vsum += vec_sum4s_impl(vx0); + vx0 = AIsSigned ? 
vx5 : vx5 - vmask; + *reinterpret_cast(&D[16]) = vx0; + vsum += vec_sum4s_impl(vx0); + vx0 = AIsSigned ? vx6 : vx6 - vmask; + *reinterpret_cast(&D[32]) = vx0; + vsum += vec_sum4s_impl(vx0); + vx0 = AIsSigned ? vx7 : vx7 - vmask; + *reinterpret_cast(&D[48]) = vx0; + vsum += vec_sum4s_impl(vx0); + D += 16 * 4; + a += 16; + y -= 16; + } + while (y >= 4) + { + Vtype vb = vmask; + __vector int vx1 = reinterpret_cast<__vector int>(vb); + vx1[0] = *reinterpret_cast(&a[0]); + if (CountM >= 2) { + vx1[1] = *reinterpret_cast(&a[lda]); + } + if (CountM >= 3) { + vx1[2] = *reinterpret_cast(&a[lda*2]); + } + Vtype vx = AIsSigned ? reinterpret_cast(vx1) : reinterpret_cast(vx1) - vmask; + *reinterpret_cast(&D[0]) = vx; + vsum += vec_sum4s_impl(vx); + D += 16; + a += 4; + y -= 4; + } + if (y >= 1) + { + Vtype vx = (Vtype) vec_splats(0); + vx[0] = a[0] ^ Flip; + if (y >= 2) { + vx[1] = a[1] ^ Flip; + } + if (y >= 3) { + vx[2] = a[2] ^ Flip; + } + if (CountM >= 2) { + vx[4] = a[lda] ^ Flip; + if (y >= 2) { + vx[5] = a[lda + 1] ^ Flip; + } + if (y >= 3) { + vx[6] = a[lda + 2] ^ Flip; + } + } + if (CountM == 3) { + vx[8] = a[lda * 2] ^ Flip; + if (y >= 2) { + vx[9] = a[lda * 2 + 1] ^ Flip; + } + if (y >= 3) { + vx[10] = a[lda * 2 + 2] ^ Flip; + } + } + *reinterpret_cast(&D[0]) = vx; + vsum += vec_sum4s_impl(vx); + D += 16; + } + *RowSumBuffer++ = vsum[0]; + if (CountM >= 2) { + *RowSumBuffer++ = vsum[1]; + } + if (CountM >= 3) { + *RowSumBuffer++ = vsum[2]; + } + } +} + +template +void +MlasGemmQuantCopyPackB8x8( + MLAS_GEMM_QUANT_KERNEL_ZVECTOR::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer + ) +{ + [[maybe_unused]] constexpr uint8_t BitFlipValue = (BIsSigned ? 
0x80 : 0); + typedef __vector unsigned char vec_t; + Vtype vmask = reinterpret_cast(vec_splats(BitFlipValue)); + vec_t mask = {0,4,8,12,1,5,9,13,2,6,10,14,3,7,11,15}; + + const __vector unsigned char vec_zero = { 0 }; + + // Copy columns from matrix B to the packed buffer. Signed buffers are + // converted to unsigned buffers in order to share a common kernel. + // + // If CountK is not aligned to a multiple of four, then the packed buffer + // is padded with zero vectors. + + // Process 16 columns of matrix B in a loop. + // + size_t PackedK = ((CountK + 4 - 1) / 4) * 16; + size_t k2 = PackedK; + size_t k3 = PackedK*2; + size_t k4 = PackedK*3; + + while (CountN >= 16) { + const uint8_t* b = B; + __vector unsigned int vsum = {0}; + __vector unsigned int vsum2 = {0}; + __vector unsigned int vsum3 = {0}; + __vector unsigned int vsum4 = {0}; + size_t y = CountK; + if (y >= 4) { + do { + Vtype b1 = *reinterpret_cast(&b[0]); + Vtype b2 = *reinterpret_cast(&b[ldb]); + Vtype b3 = *reinterpret_cast(&b[ldb*2]); + Vtype b4 = *reinterpret_cast(&b[ldb*3]); + Vtype t1 = vec_mergeh(b1, b3); + Vtype t2 = vec_mergel(b1, b3); + Vtype t3 = vec_mergeh(b2, b4); + Vtype t4 = vec_mergel(b2, b4); + b1 = vec_mergeh(t1, t3); + b2 = vec_mergel(t1, t3); + b3 = vec_mergeh(t2, t4); + b4 = vec_mergel(t2, t4); + vec_t vx1 = BIsSigned ? reinterpret_cast(b1 + vmask) : + reinterpret_cast(b1); + vec_t vx2 = BIsSigned ? reinterpret_cast(b2 + vmask) : + reinterpret_cast(b2); + vec_t vx3 = BIsSigned ? reinterpret_cast(b3 + vmask) : + reinterpret_cast(b3); + vec_t vx4 = BIsSigned ? 
reinterpret_cast(b4 + vmask) : + reinterpret_cast(b4); + *reinterpret_cast(&D[0]) = vx1; + *reinterpret_cast(&D[k2]) = vx2; + *reinterpret_cast(&D[k3]) = vx3; + *reinterpret_cast(&D[k4]) = vx4; + vsum += vec_sum4(vx1, vec_zero); + vsum2 += vec_sum4(vx2, vec_zero); + vsum3 += vec_sum4(vx3, vec_zero); + vsum4 += vec_sum4(vx4, vec_zero); + D += 16; + b += ldb*4; + y -= 4; + } while (y >= 4); + } + if (y >= 1) { + Vtype b1 = *reinterpret_cast(&b[0]); + Vtype b2 = (y >= 2) ? *reinterpret_cast(&b[ldb]) : vmask; + Vtype b3 = (y >= 3) ? *reinterpret_cast(&b[ldb*2]) : vmask; + Vtype b4 = vmask; + Vtype t1 = vec_mergeh(b1, b3); + Vtype t2 = vec_mergel(b1, b3); + Vtype t3 = vec_mergeh(b2, b4); + Vtype t4 = vec_mergel(b2, b4); + b1 = vec_mergeh(t1, t3); + b2 = vec_mergel(t1, t3); + b3 = vec_mergeh(t2, t4); + b4 = vec_mergel(t2, t4); + vec_t vx1 = BIsSigned ? reinterpret_cast(b1 + vmask) : + reinterpret_cast(b1); + vec_t vx2 = BIsSigned ? reinterpret_cast(b2 + vmask) : + reinterpret_cast(b2); + vec_t vx3 = BIsSigned ? reinterpret_cast(b3 + vmask) : + reinterpret_cast(b3); + vec_t vx4 = BIsSigned ? reinterpret_cast(b4 + vmask) : + reinterpret_cast(b4); + *reinterpret_cast(&D[0]) = vx1; + *reinterpret_cast(&D[k2]) = vx2; + *reinterpret_cast(&D[k3]) = vx3; + *reinterpret_cast(&D[k4]) = vx4; + vsum += vec_sum4(vx1, vec_zero); + vsum2 += vec_sum4(vx2, vec_zero); + vsum3 += vec_sum4(vx3, vec_zero); + vsum4 += vec_sum4(vx4, vec_zero); + D += 16; + } + + vec_xst(vsum, 0, (unsigned int*) ColumnSumBuffer); + vec_xst(vsum2, 16, (unsigned int*) ColumnSumBuffer); + vec_xst(vsum3, 32, (unsigned int*) ColumnSumBuffer); + vec_xst(vsum4, 48, (unsigned int*) ColumnSumBuffer); + + ColumnSumBuffer += 16; + B += 16; + CountN -= 16; + D += k4; + } + + // Process four columns of matrix B in a loop. 
+ // + while (CountN >= 4) { + const uint8_t* b = B; + __vector unsigned int vsum = {0}; + size_t y = CountK; + if (y >= 4) { + do { + int b1 = *reinterpret_cast(&b[0]); + int b2 = *reinterpret_cast(&b[ldb]); + int b3 = *reinterpret_cast(&b[ldb*2]); + int b4 = *reinterpret_cast(&b[ldb*3]); + __vector int vb = {b1, b2, b3, b4}; + Vtype vx = vec_perm(reinterpret_cast(vb), reinterpret_cast(vb), mask); + vec_t vx1 = BIsSigned ? reinterpret_cast(vx + vmask) : + reinterpret_cast(vx); + *reinterpret_cast(&D[0]) = vx1; + vsum += vec_sum4(vx1, vec_zero); + D += 16; + b += ldb*4; + y -= 4; + } while (y >= 4); + } + if (y >= 1) { + Vtype vb = vmask; + __vector int vb1 = reinterpret_cast<__vector int>(vb); + vb1[0] = *reinterpret_cast(&b[0]); + if (y >= 2) { + vb1[1] = *reinterpret_cast(&b[ldb]); + } + if (y >= 3) { + vb1[2] = *reinterpret_cast(&b[ldb*2]); + } + Vtype vx = vec_perm(reinterpret_cast(vb1), reinterpret_cast(vb1), mask); + vec_t vx1 = BIsSigned ? reinterpret_cast(vx + vmask) : + reinterpret_cast(vx); + *reinterpret_cast(&D[0]) = vx1; + vsum += vec_sum4(vx1, vec_zero); + D += 16; + } + + vec_xst(vsum, 0, (unsigned int*) ColumnSumBuffer); + + ColumnSumBuffer += 4; + B += 4; + CountN -= 4; + } + + // + // Process the remaining columns of matrix B. 
+ // + if (CountN > 0) { + __vector unsigned int vsum = {0}; + const uint8_t* b = B; + size_t y = CountK; + if (y >= 4) { + do { + Vtype vb = vmask; + if (CountN == 1) { + vb[0] = b[0]; + vb[4] = b[ldb]; + vb[8] = b[ldb*2]; + vb[12] = b[ldb*3]; + } + if (CountN == 2) { + vb[0] = b[0]; + vb[1] = b[1]; + vb[4] = b[ldb]; + vb[5] = b[ldb+1]; + vb[8] = b[ldb*2]; + vb[9] = b[ldb*2+1]; + vb[12] = b[ldb*3]; + vb[13] = b[ldb*3+1]; + } + if (CountN == 3) { + vb[0] = b[0]; + vb[1] = b[1]; + vb[2] = b[2]; + vb[4] = b[ldb]; + vb[5] = b[ldb+1]; + vb[6] = b[ldb+2]; + vb[8] = b[ldb*2]; + vb[9] = b[ldb*2+1]; + vb[10] = b[ldb*2+2]; + vb[12] = b[ldb*3]; + vb[13] = b[ldb*3+1]; + vb[14] = b[ldb*3+2]; + } + Vtype vx = vec_perm(reinterpret_cast(vb), reinterpret_cast(vb), mask); + vec_t vx1 = BIsSigned ? reinterpret_cast(vx + vmask) : + reinterpret_cast(vx); + *reinterpret_cast(&D[0]) = vx1; + vsum += vec_sum4(vx1, vec_zero); + D += 16; + b += ldb*4; + y -= 4; + } while (y >= 4); + } + if (y >= 1) { + Vtype vb = vmask; + if (CountN == 1) { + vb[0]= b[0]; + if (y >= 2) { + vb[4] = b[ldb]; + } + if (y >= 3) { + vb[8] = b[ldb*2]; + } + } + if (CountN == 2) { + vb[0] = b[0]; + vb[1] = b[1]; + if (y >= 2) { + vb[4] = b[ldb]; + vb[5] = b[ldb+1]; + } + if (y >= 3) { + vb[8] = b[ldb*2]; + vb[9] = b[ldb*2+1]; + } + } + if (CountN == 3) { + vb[0] = b[0]; + vb[1] = b[1]; + vb[2] = b[2]; + if (y >= 2) { + vb[4] = b[ldb]; + vb[5] = b[ldb+1]; + vb[6] = b[ldb+2]; + } + if (y >= 3) { + vb[8] = b[ldb*2]; + vb[9] = b[ldb*2+1]; + vb[10] = b[ldb*2+2]; + } + } + Vtype vx = vec_perm(reinterpret_cast(vb), reinterpret_cast(vb), mask); + vec_t vx1 = BIsSigned ? 
reinterpret_cast(vx + vmask) : + reinterpret_cast(vx); + *reinterpret_cast(&D[0]) = vx1; + vsum += vec_sum4(vx1, vec_zero); + D += 16; + } + *ColumnSumBuffer++ = vsum[0]; + if (CountN >= 2) { + *ColumnSumBuffer++ = vsum[1]; + } + if (CountN >= 3) { + *ColumnSumBuffer++ = vsum[2]; + } + } +} + +template<> +void +MlasGemmQuantCopyPackA( + MLAS_GEMM_QUANT_KERNEL_ZVECTOR::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer, + bool AIsSigned + ) +{ + if (AIsSigned) { + MlasGemmQuantCopyPackA8x8<__vector signed char, true> (D, A, lda, CountM, CountK, RowSumBuffer); + } else { + MlasGemmQuantCopyPackA8x8<__vector unsigned char, false>(D, A, lda, CountM, CountK, RowSumBuffer); + } +} +template<> +void +MlasGemmQuantCopyPackB( + MLAS_GEMM_QUANT_KERNEL_ZVECTOR::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned + ) +{ + if (BIsSigned) { + MlasGemmQuantCopyPackB8x8<__vector signed char, true>(D, B, ldb, CountN, CountK, ColumnSumBuffer); + } else { + MlasGemmQuantCopyPackB8x8< __vector unsigned char, false>(D, B, ldb, CountN, CountK, ColumnSumBuffer); + } +} + +template +MLAS_FORCEINLINE +void +MlasQgemmStoreVectorZVECTOR + ( + MLAS_INT32X4 result[4], + int32_t* C, + size_t ldc, + size_t row, + bool ZeroMode, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + int pos + ) +{ + size_t RowCount; + __vector signed int vsum0, vsum1, vsum2, vsum3; + __vector signed int columnsum = *reinterpret_cast(&ColumnSumBuffer[pos]); + C += VectorCount; + if (ZeroPointB != nullptr) { + __vector signed int zeropoint = *reinterpret_cast(&ZeroPointB[pos]); + if (ZeroMode) { + for (RowCount = 0; RowCount + 4 <= row; RowCount += 4, C += ldc*4) { + vsum0 = vec_splats(RowSumBuffer[RowCount + 0]) * zeropoint + columnsum; + vsum1 = vec_splats(RowSumBuffer[RowCount + 1]) * zeropoint + columnsum; + vsum2 = 
vec_splats(RowSumBuffer[RowCount + 2]) * zeropoint + columnsum; + vsum3 = vec_splats(RowSumBuffer[RowCount + 3]) * zeropoint + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + *reinterpret_cast<__vector int *>(&C[ldc]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 1]) + vsum1; + *reinterpret_cast<__vector int *>(&C[ldc*2]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 2]) + vsum2; + *reinterpret_cast<__vector int *>(&C[ldc*3]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 3]) + vsum3; + } + for (; RowCount < row; RowCount++, C += ldc) { + vsum0 = vec_splats(RowSumBuffer[RowCount]) * zeropoint + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + } + } else { + for (RowCount = 0; RowCount + 4 <= row; RowCount += 4, C += ldc*4) { + vsum0 = vec_splats(RowSumBuffer[RowCount + 0]) * zeropoint + columnsum; + vsum1 = vec_splats(RowSumBuffer[RowCount + 1]) * zeropoint + columnsum; + vsum2 = vec_splats(RowSumBuffer[RowCount + 2]) * zeropoint + columnsum; + vsum3 = vec_splats(RowSumBuffer[RowCount + 3]) * zeropoint + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + *reinterpret_cast<__vector int *>(&C[ldc]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 1]) + vsum1; + *reinterpret_cast<__vector int *>(&C[ldc*2]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 2]) + vsum2; + *reinterpret_cast<__vector int *>(&C[ldc*3]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 3]) + vsum3; + } + for (; RowCount < row; RowCount++, C += ldc) { + vsum0 = vec_splats(RowSumBuffer[RowCount]) * zeropoint + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + } + } + } else { + if (ZeroMode) { + for (RowCount = 0; RowCount + 4 <= row; 
RowCount += 4, C += ldc*4) { + vsum0 = vec_splats(RowSumBuffer[RowCount + 0]) + columnsum; + vsum1 = vec_splats(RowSumBuffer[RowCount + 1]) + columnsum; + vsum2 = vec_splats(RowSumBuffer[RowCount + 2]) + columnsum; + vsum3 = vec_splats(RowSumBuffer[RowCount + 3]) + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + *reinterpret_cast<__vector int *>(&C[ldc]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 1]) + vsum1; + *reinterpret_cast<__vector int *>(&C[ldc*2]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 2]) + vsum2; + *reinterpret_cast<__vector int *>(&C[ldc*3]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 3]) + vsum3; + } + for (; RowCount < row; RowCount++, C += ldc) { + vsum0 = vec_splats(RowSumBuffer[RowCount]) + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) = + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + } + } else { + for (RowCount = 0; RowCount + 4 <= row; RowCount += 4, C += ldc*4) { + vsum0 = vec_splats(RowSumBuffer[RowCount + 0]) + columnsum; + vsum1 = vec_splats(RowSumBuffer[RowCount + 1]) + columnsum; + vsum2 = vec_splats(RowSumBuffer[RowCount + 2]) + columnsum; + vsum3 = vec_splats(RowSumBuffer[RowCount + 3]) + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + *reinterpret_cast<__vector int *>(&C[ldc]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 1]) + vsum1; + *reinterpret_cast<__vector int *>(&C[ldc*2]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 2]) + vsum2; + *reinterpret_cast<__vector int *>(&C[ldc*3]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 3]) + vsum3; + } + for (; RowCount < row; RowCount++, C += ldc) { + vsum0 = vec_splats(RowSumBuffer[RowCount]) + columnsum; + *reinterpret_cast<__vector int *>(&C[0]) += + *reinterpret_cast<__vector int *>(&result[RowCount + 0]) + vsum0; + } + } + } 
+}; +template +MLAS_FORCEINLINE +void +MlasQgemmStoreScalarZVECTOR( + MLAS_INT32X4 result[4], + int32_t* C, + size_t ldc, + size_t row, + bool ZeroMode, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB + ) +{ + if (ZeroPointB != nullptr) { + if (ZeroMode) { + for (size_t RowCount = 0;RowCount < row; RowCount++){ + int sum = RowSumBuffer[RowCount]; + sum *= ZeroPointB[0]; + sum += ColumnSumBuffer[0]; + C[RowCount*ldc+Lane] = result[RowCount][Lane] + sum; + } + } else { + for (size_t RowCount = 0;RowCount < row; RowCount++){ + int sum = RowSumBuffer[RowCount]; + sum *= ZeroPointB[0]; + sum += ColumnSumBuffer[0]; + C[RowCount*ldc+Lane] += result[RowCount][Lane] + sum; + } + } + } else { + if (ZeroMode) { + for (size_t RowCount = 0;RowCount < row; RowCount++){ + int sum = RowSumBuffer[RowCount] + ColumnSumBuffer[0]; + C[RowCount*ldc+Lane] = result[RowCount][Lane] + sum; + } + } else { + for (size_t RowCount = 0;RowCount < row; RowCount++){ + int sum = RowSumBuffer[RowCount] + ColumnSumBuffer[0]; + C[RowCount*ldc+Lane] += result[RowCount][Lane] + sum; + } + } + } +}; + +MLAS_FORCEINLINE +void +xvi8ger4pp_impl( + MLAS_INT32X4 acc[4], + __vector unsigned char va, + __vector unsigned char vb + ) +{ + const __vector unsigned char maska[4] = { + { 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2 ,3, 0, 1, 2, 3 }, + { 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7, 4, 5, 6, 7 }, + { 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11, 8, 9, 10, 11 }, + { 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15, 12, 13, 14, 15 } + }; + + const __vector unsigned char maskb = { 0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15 }; + + __vector int va_interim[4]; + __vector unsigned char vb_interm = vec_perm(vb, vb, maskb); + + __vector int va_prep; + __vector int vb_prep[4]; + + va_interim[0] = (__vector int) vec_unpackh(vec_unpackh((__vector signed char) va)); + va_interim[1] = (__vector int) vec_unpackl(vec_unpackh((__vector signed char) va)); + va_interim[2] = (__vector int) 
vec_unpackh(vec_unpackl((__vector signed char) va)); + va_interim[3] = (__vector int) vec_unpackl(vec_unpackl((__vector signed char) va)); + + vb_prep[0] = (__vector int) vec_unpackh(vec_unpackh(vb_interm)); + vb_prep[1] = (__vector int) vec_unpackl(vec_unpackh(vb_interm)); + vb_prep[2] = (__vector int) vec_unpackh(vec_unpackl(vb_interm)); + vb_prep[3] = (__vector int) vec_unpackl(vec_unpackl(vb_interm)); + + for (size_t i = 0; i < 4; ++i) + { + for (size_t k = 0; k < 4; ++k) + { + va_prep = vec_perm(va_interim[i], va_interim[i], maska[k]); + + acc[i] += va_prep * vb_prep[k]; + } + } +} + +template +MLAS_FORCEINLINE +void +MlasQgemmComputeZVECTOR( + MLAS_INT32X4 acc0[4], + MLAS_INT32X4 acc1[4], + __vector unsigned char *va, + __vector unsigned char *vb + ) +{ + if (CountK == 16) { + xvi8ger4pp_impl(acc0, va[0], vb[0]); + xvi8ger4pp_impl(acc0, va[1], vb[1]); + xvi8ger4pp_impl(acc0, va[2], vb[2]); + xvi8ger4pp_impl(acc0, va[3], vb[3]); + if (CountM) { + xvi8ger4pp_impl(acc1, va[4], vb[0]); + xvi8ger4pp_impl(acc1, va[5], vb[1]); + xvi8ger4pp_impl(acc1, va[6], vb[2]); + xvi8ger4pp_impl(acc1, va[7], vb[3]); + } + } else if (CountK == 12) { + xvi8ger4pp_impl(acc0, va[0], vb[0]); + xvi8ger4pp_impl(acc0, va[1], vb[1]); + xvi8ger4pp_impl(acc0, va[2], vb[2]); + if (CountM) { + xvi8ger4pp_impl(acc1, va[3], vb[0]); + xvi8ger4pp_impl(acc1, va[4], vb[1]); + xvi8ger4pp_impl(acc1, va[5], vb[2]); + } + } else if (CountK == 8) { + xvi8ger4pp_impl(acc0, va[0], vb[0]); + xvi8ger4pp_impl(acc0, va[1], vb[1]); + if (CountM) { + xvi8ger4pp_impl(acc1, va[2], vb[0]); + xvi8ger4pp_impl(acc1, va[3], vb[1]); + } + } else { + xvi8ger4pp_impl(acc0, va[0], vb[0]); + if (CountM) { + xvi8ger4pp_impl(acc1, va[1], vb[0]); + } + } +}; +template<> +size_t +MlasGemmQuantKernel( + const MLAS_GEMM_QUANT_KERNEL_ZVECTOR::PackedAType* A, + const MLAS_GEMM_QUANT_KERNEL_ZVECTOR::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + 
const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode + ) +{ + if (CountM < 8 && CountM >= 4) { + CountM = 4; + } + size_t Mval = CountM; + if (Mval >= 8) { + Mval = 4; + } + while (CountN > 0) { + const int8_t *a = A; + typedef __vector unsigned char vec_t; + const uint8_t *b = B; + int32_t *C1; + MLAS_INT32X4 acc0[4] = {0}; + MLAS_INT32X4 acc1[4] = {0}; + MLAS_INT32X4 acc2[4] = {0}; + MLAS_INT32X4 acc3[4] = {0}; + MLAS_INT32X4 acc4[4] = {0}; + MLAS_INT32X4 acc5[4] = {0}; + MLAS_INT32X4 acc6[4] = {0}; + MLAS_INT32X4 acc7[4] = {0}; + MLAS_INT32X4 result[4] = {0}; + MLAS_INT32X4 result1[4] = {0}; + size_t k = PackedCountK * MLAS_GEMM_QUANT_KERNEL_ZVECTOR::PackedK; + size_t k1 = PackedCountK; + // + // Compute the output block using POWER10 MMA builtins. + // + while (k >= 16) { + vec_t *va = const_cast(reinterpret_cast(a)); + vec_t *vb = const_cast(reinterpret_cast(b)); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc0, acc4, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc0, acc4, va, vb); + } + vb = const_cast(reinterpret_cast(&b[k1*16])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc1, acc5, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc1, acc5, va, vb); + } + vb = const_cast(reinterpret_cast(&b[k1*32])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc2, acc6, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc2, acc6, va, vb); + } + vb = const_cast(reinterpret_cast(&b[k1*48])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc3, acc7, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc3, acc7, va, vb); + } + b += 64; + if (CountM >= 8) { + a += 128; + } else { + a += 64; + } + k -= 16; + } + if (k >= 12) { + vec_t *va = const_cast(reinterpret_cast(a)); + vec_t *vb = const_cast(reinterpret_cast(b)); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc0, acc4, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc0, acc4, va, vb); + } + vb = const_cast(reinterpret_cast(&b[k1*16])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc1, 
acc5, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc1, acc5, va, vb); + } + vb = const_cast(reinterpret_cast(&b[k1*32])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc2, acc6, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc2, acc6, va, vb); + } + vb = const_cast(reinterpret_cast(&b[k1*48])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc3, acc7, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc3, acc7, va, vb); + } + if (CountM >= 8) { + a += 96; + } else { + a += 48; + } + b += 48; + k -= 12; + } + if (k >= 8) { + vec_t *va = const_cast(reinterpret_cast(a)); + vec_t *vb = const_cast(reinterpret_cast(b)); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc0, acc4, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc0, acc4, va, vb); + } + vb = const_cast(reinterpret_cast(&b[k1*16])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc1, acc5, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc1, acc5, va, vb); + } + vb = const_cast(reinterpret_cast(&b[k1*32])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc2, acc6, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc2, acc6, va, vb); + } + vb = const_cast(reinterpret_cast(&b[k1*48])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc3, acc7, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc3, acc7, va, vb); + } + if (CountM >= 8) { + a += 64; + } else { + a += 32; + } + b += 32; + k -= 8; + } + if (k >= 4) { + vec_t *va = const_cast(reinterpret_cast(a)); + vec_t *vb = const_cast(reinterpret_cast(b)); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc0, acc4, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc0, acc4, va, vb); + } + vb = const_cast(reinterpret_cast(&b[k1*16])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc1, acc5, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc1, acc5, va, vb); + } + vb = const_cast(reinterpret_cast(&b[k1*32])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc2, acc6, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc2, acc6, va, vb); + } + vb = 
const_cast(reinterpret_cast(&b[k1*48])); + if (CountM >= 8) { + MlasQgemmComputeZVECTOR(acc3, acc7, va, vb); + } else { + MlasQgemmComputeZVECTOR(acc3, acc7, va, vb); + } + } + // Store matrix C with accumulator result. + if (CountN >= 16) { + MlasQgemmStoreVectorZVECTOR<0>(acc0, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); + MlasQgemmStoreVectorZVECTOR<4>(acc1, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4); + MlasQgemmStoreVectorZVECTOR<8>(acc2, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 8); + MlasQgemmStoreVectorZVECTOR<12>(acc3, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 12); + + if (CountM >= 8) { + C1 = C+ldc*4; + + MlasQgemmStoreVectorZVECTOR<0>(acc4, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 0); + MlasQgemmStoreVectorZVECTOR<4>(acc5, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 4); + MlasQgemmStoreVectorZVECTOR<8>(acc6, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 8); + MlasQgemmStoreVectorZVECTOR<12>(acc7, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 12); + } + INC_BUFFER(16); + CountN -= 16; + B += 16 * 4 *PackedCountK; + C += 16; + } else { + if (CountN >=12 ) { + MlasQgemmStoreVectorZVECTOR<0>(acc0, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); + MlasQgemmStoreVectorZVECTOR<4>(acc1, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4); + MlasQgemmStoreVectorZVECTOR<8>(acc2, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 8); + if (CountM >= 8) { + C1 = C+ldc*4; + + MlasQgemmStoreVectorZVECTOR<0>(acc4, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 0); + MlasQgemmStoreVectorZVECTOR<4>(acc5, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 4); + MlasQgemmStoreVectorZVECTOR<8>(acc6, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 8); + } + INC_BUFFER(12); 
+ if (CountN - 12 > 0) { + for (size_t i = 0; i < 4; ++i) { + result[i] = acc3[i]; + } + if (CountM >= 8) { + for (size_t i = 0; i < 4; ++i) { + result1[i] = acc7[i]; + } + } + } + CountN -= 12; + C += 12; + } else if (CountN >= 8) { + MlasQgemmStoreVectorZVECTOR<0>(acc0, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); + MlasQgemmStoreVectorZVECTOR<4>(acc1, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 4); + if (CountM >= 8) { + C1 = C+ldc*4; + + MlasQgemmStoreVectorZVECTOR<0>(acc4, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 0); + MlasQgemmStoreVectorZVECTOR<4>(acc5, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 4); + } + INC_BUFFER(8); + if (CountN - 8 > 0) { + for (size_t i = 0; i < 4; ++i) { + result[i] = acc2[i]; + } + if (CountM >= 8) { + for (size_t i = 0; i < 4; ++i) { + result1[i] = acc6[i]; + } + } + } + CountN -= 8; + C += 8; + } else if (CountN >= 4) { + MlasQgemmStoreVectorZVECTOR<0>(acc0, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB, 0); + if (CountM >= 8) { + C1 = C+ldc*4; + + MlasQgemmStoreVectorZVECTOR<0>(acc4, C1, ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB, 0); + } + INC_BUFFER(4); + if (CountN - 4 > 0) { + for (size_t i = 0; i < 4; ++i) { + result[i] = acc1[i]; + } + if (CountM >= 8) { + for (size_t i = 0; i < 4; ++i) { + result1[i] = acc5[i]; + } + } + } + CountN -= 4; + C += 4; + } else { + for (size_t i = 0; i < 4; ++i) { + result[i] = acc0[i]; + } + if (CountM >= 8) { + for (size_t i = 0; i < 4; ++i) { + result1[i] = acc4[i]; + } + } + } + CountN &= 3; + // + // Output the remaining partial output block. 
+ // + if (CountN > 0) { + MlasQgemmStoreScalarZVECTOR<0>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB); + if (CountM >= 8) { + MlasQgemmStoreScalarZVECTOR<0>(result1, C + (ldc*4), ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB); + } + INC_BUFFER(1); + if (CountN >= 2) { + MlasQgemmStoreScalarZVECTOR<1>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB); + if (CountM >= 8) { + MlasQgemmStoreScalarZVECTOR<1>(result1, C + (ldc*4), ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB); + } + INC_BUFFER(1); + } + if (CountN >= 3) { + MlasQgemmStoreScalarZVECTOR<2>(result, C, ldc, Mval, ZeroMode, RowSumBuffer, ColumnSumBuffer, ZeroPointB); + if (CountM >= 8) { + MlasQgemmStoreScalarZVECTOR<2>(result1, C + (ldc*4), ldc, 4, ZeroMode, RowSumBuffer+4, ColumnSumBuffer, ZeroPointB); + } + INC_BUFFER(1); + } + } + CountN = 0; + } + } + if (CountM >= 8) { + return 8; + } + return CountM; +} + +const MLAS_GEMM_QUANT_DISPATCH MlasGemm8X8DispatchZVECTOR = { + MlasGemmQuantOperation, + MlasGemmQuantPackedOperation, + MlasGemmQuantCopyPackB, + MLAS_GEMM_QUANT_KERNEL_ZVECTOR::PackedK, + MLAS_GEMM_QUANT_KERNEL_ZVECTOR::PackedStrides.K, + 8 // Kernel M stride +}; diff --git a/src/lib/sconv.h b/src/lib/sconv.h new file mode 100644 index 0000000..12ccff2 --- /dev/null +++ b/src/lib/sconv.h @@ -0,0 +1,29 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sconv.h + +Abstract: + + This module defines convolution kernel flags for configuring convolution + operations including output accumulation, bias addition, and activations. + +--*/ + +// +// Define the convolution kernel flags. 
+// + +#if defined(MLAS_USE_ARM_NEON_NCHWC) + +#define MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT 0x00000001 +#define MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION 0x00000002 +#define MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION 0x00000004 +#define MLAS_CONV_KERNEL_FLAG_OTHER_ACTIVATION 0x00000008 + +#endif diff --git a/src/lib/sconv_kernel_neon.cpp b/src/lib/sconv_kernel_neon.cpp new file mode 100644 index 0000000..4c5f50a --- /dev/null +++ b/src/lib/sconv_kernel_neon.cpp @@ -0,0 +1,524 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. + +Module Name: + + sconv_kernel_neon.cpp + +Abstract: + + This module implements the single precision convolution kernels for ARM NEON. + +--*/ + +#if defined(MLAS_USE_ARM_NEON_NCHWC) + +#include "mlasi.h" +#include "sconv.h" + +constexpr size_t BlockSize = MLAS_PLATFORM::MLAS_NEON_NCHWC_BLOCK_SIZE; + +// Common implementation for NCHW and NCHWC convolution kernels +template +void + MLASCALL + MlasConvFloatKernelNeonImpl( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t FilterCount, + size_t InputStride, + size_t FilterStride, + size_t OutputStride, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad, + const float* Bias, + unsigned KernelFlags + ) +{ + const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0; + const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0; + const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0; + + const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f); + + const size_t StrideWidthElements = StrideWidth / sizeof(float); + const size_t DilationWidthElements = DilationWidth / sizeof(float); + const size_t FilterStrideElements = FilterStride / sizeof(float); + 
const size_t OutputStrideElements = OutputStride / sizeof(float); + const size_t InputWidthElements = InputWidth / sizeof(float); + const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float); + + MLAS_UNREFERENCED_PARAMETER(InputStride); + + const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad; + + for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) { + bool is_main_region = (output_idx >= OutputCountLeftPad && output_idx < OutputCountLeftPad + OutputCount); + + for (size_t filterSetBlock = 0; filterSetBlock < FilterCount; filterSetBlock++) { + const float* filter = Filter + filterSetBlock * FilterStrideElements; + float* output = Output + filterSetBlock * OutputStrideElements; + + float32x4_t Accumulator0, Accumulator1, Accumulator2, Accumulator3; + + if (AccumulateOutput) { + Accumulator0 = MlasLoadFloat32x4(&output[output_idx * BlockSize]); + Accumulator1 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 4]); + Accumulator2 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 8]); + Accumulator3 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 12]); + } else { + Accumulator0 = MlasBroadcastFloat32x4(0.0f); + Accumulator1 = MlasBroadcastFloat32x4(0.0f); + Accumulator2 = MlasBroadcastFloat32x4(0.0f); + Accumulator3 = MlasBroadcastFloat32x4(0.0f); + } + + if (BiasAddition) { + const float32x4_t BiasVector0 = MlasLoadFloat32x4(&Bias[filterSetBlock * BlockSize]); + const float32x4_t BiasVector1 = MlasLoadFloat32x4(&Bias[filterSetBlock * BlockSize + 4]); + const float32x4_t BiasVector2 = MlasLoadFloat32x4(&Bias[filterSetBlock * BlockSize + 8]); + const float32x4_t BiasVector3 = MlasLoadFloat32x4(&Bias[filterSetBlock * BlockSize + 12]); + + Accumulator0 = MlasAddFloat32x4(Accumulator0, BiasVector0); + Accumulator1 = MlasAddFloat32x4(Accumulator1, BiasVector1); + Accumulator2 = MlasAddFloat32x4(Accumulator2, BiasVector2); + Accumulator3 = MlasAddFloat32x4(Accumulator3, BiasVector3); + } + + 
for (size_t kh = 0; kh < KernelHeight; kh++) { + for (size_t kw = 0; kw < KernelWidth; kw++) { + const float* input_base = Input + output_idx * StrideWidthElements + + kh * DilatedInputWidthElements + kw * DilationWidthElements; + + if constexpr (IsNchwcFormat) { + for (size_t filterBlock = 0; filterBlock < BlockSize; filterBlock++) { + const float* input_element = input_base + filterBlock; + const float* input_row_start = InputBase + kh * DilatedInputWidthElements; + const float* input_row_end = input_row_start + InputWidthElements; + + float input_value; + if (is_main_region || (input_element >= input_row_start && input_element < input_row_end)) { + input_value = *input_element; + } else { + input_value = 0.0f; + } + + const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value); + + size_t kernel_base_pos = kh * (KernelWidth * BlockSize * BlockSize) + + kw * (BlockSize * BlockSize) + + filterBlock * BlockSize; + + const float32x4_t FilterVector0 = MlasLoadFloat32x4(&filter[kernel_base_pos]); + const float32x4_t FilterVector1 = MlasLoadFloat32x4(&filter[kernel_base_pos + 4]); + const float32x4_t FilterVector2 = MlasLoadFloat32x4(&filter[kernel_base_pos + 8]); + const float32x4_t FilterVector3 = MlasLoadFloat32x4(&filter[kernel_base_pos + 12]); + + Accumulator0 = MlasMultiplyAddFloat32x4(InputVector, FilterVector0, Accumulator0); + Accumulator1 = MlasMultiplyAddFloat32x4(InputVector, FilterVector1, Accumulator1); + Accumulator2 = MlasMultiplyAddFloat32x4(InputVector, FilterVector2, Accumulator2); + Accumulator3 = MlasMultiplyAddFloat32x4(InputVector, FilterVector3, Accumulator3); + } + } else { + const float* input_row_start = InputBase + kh * DilatedInputWidthElements; + const float* input_row_end = input_row_start + InputWidthElements; + + float input_value; + if (is_main_region || (input_base >= input_row_start && input_base < input_row_end)) { + input_value = *input_base; + } else { + input_value = 0.0f; + } + + const float32x4_t InputVector = 
MlasBroadcastFloat32x4(input_value);
+
+                        size_t kernel_base_pos = kh * KernelWidth + kw;
+
+                        const float32x4_t FilterVector0 = MlasLoadFloat32x4(&filter[kernel_base_pos * BlockSize]);
+                        const float32x4_t FilterVector1 = MlasLoadFloat32x4(&filter[kernel_base_pos * BlockSize + 4]);
+                        const float32x4_t FilterVector2 = MlasLoadFloat32x4(&filter[kernel_base_pos * BlockSize + 8]);
+                        const float32x4_t FilterVector3 = MlasLoadFloat32x4(&filter[kernel_base_pos * BlockSize + 12]);
+
+                        Accumulator0 = MlasMultiplyAddFloat32x4(InputVector, FilterVector0, Accumulator0);
+                        Accumulator1 = MlasMultiplyAddFloat32x4(InputVector, FilterVector1, Accumulator1);
+                        Accumulator2 = MlasMultiplyAddFloat32x4(InputVector, FilterVector2, Accumulator2);
+                        Accumulator3 = MlasMultiplyAddFloat32x4(InputVector, FilterVector3, Accumulator3);
+                    }
+                }
+            }
+
+            if (ReluActivation) {
+                Accumulator0 = MlasMaximumFloat32x4(Accumulator0, ZeroVector);
+                Accumulator1 = MlasMaximumFloat32x4(Accumulator1, ZeroVector);
+                Accumulator2 = MlasMaximumFloat32x4(Accumulator2, ZeroVector);
+                Accumulator3 = MlasMaximumFloat32x4(Accumulator3, ZeroVector);
+            }
+
+            MlasStoreFloat32x4(&output[output_idx * BlockSize], Accumulator0);
+            MlasStoreFloat32x4(&output[output_idx * BlockSize + 4], Accumulator1);
+            MlasStoreFloat32x4(&output[output_idx * BlockSize + 8], Accumulator2);
+            MlasStoreFloat32x4(&output[output_idx * BlockSize + 12], Accumulator3);
+        }
+    }
+}
+
+void
+    MLASCALL
+    MlasConvNchwFloatKernelNeon(
+        const float* Input,
+        const float* Filter,
+        float* Output,
+        size_t StrideWidth,
+        size_t DilationWidth,
+        size_t FilterCount,
+        size_t InputStride,
+        size_t FilterStride,
+        size_t OutputStride,
+        size_t KernelHeight,
+        size_t KernelWidth,
+        const float* InputBase,
+        size_t InputWidth,
+        size_t DilatedInputWidth,
+        size_t OutputCountLeftPad,
+        size_t OutputCount,
+        size_t OutputCountRightPad,
+        const float* Bias,
+        unsigned KernelFlags
+    )
+{
+    MlasConvFloatKernelNeonImpl<false>(
+        Input,
+        Filter,
+        Output,
+        StrideWidth,
+        
DilationWidth,
+        FilterCount,
+        InputStride,
+        FilterStride,
+        OutputStride,
+        KernelHeight,
+        KernelWidth,
+        InputBase,
+        InputWidth,
+        DilatedInputWidth,
+        OutputCountLeftPad,
+        OutputCount,
+        OutputCountRightPad,
+        Bias,
+        KernelFlags
+    );
+}
+
+//
+// Implementation of MlasConvNchwcFloatKernelNeon
+//
+
+void
+    MLASCALL
+    MlasConvNchwcFloatKernelNeon(
+        const float* Input,
+        const float* Filter,
+        float* Output,
+        size_t StrideWidth,
+        size_t DilationWidth,
+        size_t FilterCount,
+        size_t InputStride,
+        size_t FilterStride,
+        size_t OutputStride,
+        size_t KernelHeight,
+        size_t KernelWidth,
+        const float* InputBase,
+        size_t InputWidth,
+        size_t DilatedInputWidth,
+        size_t OutputCountLeftPad,
+        size_t OutputCount,
+        size_t OutputCountRightPad,
+        const float* Bias,
+        unsigned KernelFlags
+    )
+{
+    MlasConvFloatKernelNeonImpl<true>(
+        Input,
+        Filter,
+        Output,
+        StrideWidth,
+        DilationWidth,
+        FilterCount,
+        InputStride,
+        FilterStride,
+        OutputStride,
+        KernelHeight,
+        KernelWidth,
+        InputBase,
+        InputWidth,
+        DilatedInputWidth,
+        OutputCountLeftPad,
+        OutputCount,
+        OutputCountRightPad,
+        Bias,
+        KernelFlags
+    );
+}
+
+//
+// Helper function to load input vector with bounds checking
+//
+static inline float32x4_t
+LoadInputVectorWithBounds(
+    const float* input_base,
+    size_t offset,
+    bool is_main_region,
+    const float* InputBase,
+    size_t kh,
+    size_t DilatedInputWidthElements,
+    size_t InputWidthElements
+)
+{
+    if (is_main_region) {
+        return MlasLoadFloat32x4(input_base + offset);
+    } else {
+        float input_values[4];
+        for (size_t i = 0; i < 4; i++) {
+            const float* input_element = input_base + offset + i;
+            const float* input_row_start = InputBase + kh * DilatedInputWidthElements;
+            const float* input_row_end = input_row_start + InputWidthElements;
+
+            if (input_element >= input_row_start && input_element < input_row_end) {
+                input_values[i] = *input_element;
+            } else {
+                input_values[i] = 0.0f;
+            }
+        }
+        return MlasLoadFloat32x4(input_values);
+    }
+}
+
+//
+// 
Implementation of MlasConvDepthwiseFloatKernelNeon +// +// This kernel performs depthwise separable convolution where each input channel +// is convolved with its own filter. This is more efficient than standard convolution +// for certain network architectures like MobileNets. +// + +void + MLASCALL + MlasConvDepthwiseFloatKernelNeon( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t InputStride, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad, + const float* Bias, + unsigned KernelFlags + ) +{ + const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0; + const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0; + const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0; + + const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f); + + const size_t StrideWidthElements = StrideWidth / sizeof(float); + const size_t DilationWidthElements = DilationWidth / sizeof(float); + const size_t InputStrideElements = InputStride / sizeof(float); + const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float); + + MLAS_UNREFERENCED_PARAMETER(InputStrideElements); + + const size_t InputWidthElements = InputWidth / sizeof(float); + + const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad; + + for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) { + bool is_main_region = (output_idx >= OutputCountLeftPad && output_idx < OutputCountLeftPad + OutputCount); + + float32x4_t Accumulator0, Accumulator1, Accumulator2, Accumulator3; + + if (AccumulateOutput) { + Accumulator0 = MlasLoadFloat32x4(&Output[output_idx * BlockSize]); + Accumulator1 = MlasLoadFloat32x4(&Output[output_idx * BlockSize + 4]); + Accumulator2 = 
MlasLoadFloat32x4(&Output[output_idx * BlockSize + 8]); + Accumulator3 = MlasLoadFloat32x4(&Output[output_idx * BlockSize + 12]); + } else { + Accumulator0 = MlasBroadcastFloat32x4(0.0f); + Accumulator1 = MlasBroadcastFloat32x4(0.0f); + Accumulator2 = MlasBroadcastFloat32x4(0.0f); + Accumulator3 = MlasBroadcastFloat32x4(0.0f); + } + + if (BiasAddition) { + const float32x4_t BiasVector0 = MlasLoadFloat32x4(Bias); + const float32x4_t BiasVector1 = MlasLoadFloat32x4(Bias + 4); + const float32x4_t BiasVector2 = MlasLoadFloat32x4(Bias + 8); + const float32x4_t BiasVector3 = MlasLoadFloat32x4(Bias + 12); + + Accumulator0 = MlasAddFloat32x4(Accumulator0, BiasVector0); + Accumulator1 = MlasAddFloat32x4(Accumulator1, BiasVector1); + Accumulator2 = MlasAddFloat32x4(Accumulator2, BiasVector2); + Accumulator3 = MlasAddFloat32x4(Accumulator3, BiasVector3); + } + + for (size_t kh = 0; kh < KernelHeight; kh++) { + for (size_t kw = 0; kw < KernelWidth; kw++) { + size_t kernel_pos = kh * KernelWidth + kw; + + const float* input_base = Input + output_idx * StrideWidthElements + + kh * DilatedInputWidthElements + kw * DilationWidthElements; + + float32x4_t InputVector0 = LoadInputVectorWithBounds(input_base, 0, is_main_region, InputBase, kh, DilatedInputWidthElements, InputWidthElements); + float32x4_t InputVector1 = LoadInputVectorWithBounds(input_base, 4, is_main_region, InputBase, kh, DilatedInputWidthElements, InputWidthElements); + float32x4_t InputVector2 = LoadInputVectorWithBounds(input_base, 8, is_main_region, InputBase, kh, DilatedInputWidthElements, InputWidthElements); + float32x4_t InputVector3 = LoadInputVectorWithBounds(input_base, 12, is_main_region, InputBase, kh, DilatedInputWidthElements, InputWidthElements); + + const float32x4_t FilterVector0 = MlasLoadFloat32x4(&Filter[kernel_pos * BlockSize]); + const float32x4_t FilterVector1 = MlasLoadFloat32x4(&Filter[kernel_pos * BlockSize + 4]); + const float32x4_t FilterVector2 = MlasLoadFloat32x4(&Filter[kernel_pos * 
BlockSize + 8]); + const float32x4_t FilterVector3 = MlasLoadFloat32x4(&Filter[kernel_pos * BlockSize + 12]); + + Accumulator0 = MlasMultiplyAddFloat32x4(InputVector0, FilterVector0, Accumulator0); + Accumulator1 = MlasMultiplyAddFloat32x4(InputVector1, FilterVector1, Accumulator1); + Accumulator2 = MlasMultiplyAddFloat32x4(InputVector2, FilterVector2, Accumulator2); + Accumulator3 = MlasMultiplyAddFloat32x4(InputVector3, FilterVector3, Accumulator3); + } + } + + if (ReluActivation) { + Accumulator0 = MlasMaximumFloat32x4(Accumulator0, ZeroVector); + Accumulator1 = MlasMaximumFloat32x4(Accumulator1, ZeroVector); + Accumulator2 = MlasMaximumFloat32x4(Accumulator2, ZeroVector); + Accumulator3 = MlasMaximumFloat32x4(Accumulator3, ZeroVector); + } + + MlasStoreFloat32x4(&Output[output_idx * BlockSize], Accumulator0); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 4], Accumulator1); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 8], Accumulator2); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 12], Accumulator3); + } +} + +// +// Implementation of MlasConvPointwiseFloatKernelNeon +// +// This kernel performs pointwise (1x1) convolution which is essentially +// a matrix multiplication across the channel dimension. It's optimized +// for cases where the kernel size is 1x1. 
+// + +void + MLASCALL + MlasConvPointwiseFloatKernelNeon( + const float* Input, + const float* Filter, + float* Output, + size_t StrideWidth, + size_t InputChannels, + size_t FilterCount, + size_t InputStride, + size_t FilterStride, + size_t OutputStride, + size_t OutputCount, + const float* Bias, + unsigned KernelFlags + ) +{ + const bool AccumulateOutput = (KernelFlags & MLAS_CONV_KERNEL_FLAG_ACCUMULATE_OUTPUT) != 0; + const bool BiasAddition = (KernelFlags & MLAS_CONV_KERNEL_FLAG_BIAS_ADDITION) != 0; + const bool ReluActivation = (KernelFlags & MLAS_CONV_KERNEL_FLAG_RELU_ACTIVATION) != 0; + + const size_t StrideWidthElements = StrideWidth / sizeof(float); + const size_t InputStrideElements = InputStride / sizeof(float); + const size_t FilterStrideElements = FilterStride / sizeof(float); + const size_t OutputStrideElements = OutputStride / sizeof(float); + + const float32x4_t ZeroVector = MlasBroadcastFloat32x4(0.0f); + + for (size_t output_idx = 0; output_idx < OutputCount; output_idx++) { + for (size_t f = 0; f < FilterCount; f++) { + const float* filter = Filter + f * FilterStrideElements; + float* output = Output + f * OutputStrideElements; + + float32x4_t Accumulator0, Accumulator1, Accumulator2, Accumulator3; + + if (AccumulateOutput) { + Accumulator0 = MlasLoadFloat32x4(&output[output_idx * BlockSize]); + Accumulator1 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 4]); + Accumulator2 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 8]); + Accumulator3 = MlasLoadFloat32x4(&output[output_idx * BlockSize + 12]); + } else { + Accumulator0 = MlasBroadcastFloat32x4(0.0f); + Accumulator1 = MlasBroadcastFloat32x4(0.0f); + Accumulator2 = MlasBroadcastFloat32x4(0.0f); + Accumulator3 = MlasBroadcastFloat32x4(0.0f); + } + + if (BiasAddition) { + const float32x4_t BiasVector0 = MlasLoadFloat32x4(&Bias[f * BlockSize]); + const float32x4_t BiasVector1 = MlasLoadFloat32x4(&Bias[f * BlockSize + 4]); + const float32x4_t BiasVector2 = MlasLoadFloat32x4(&Bias[f * 
BlockSize + 8]); + const float32x4_t BiasVector3 = MlasLoadFloat32x4(&Bias[f * BlockSize + 12]); + + Accumulator0 = MlasAddFloat32x4(Accumulator0, BiasVector0); + Accumulator1 = MlasAddFloat32x4(Accumulator1, BiasVector1); + Accumulator2 = MlasAddFloat32x4(Accumulator2, BiasVector2); + Accumulator3 = MlasAddFloat32x4(Accumulator3, BiasVector3); + } + + for (size_t c = 0; c < InputChannels; c++) { + const float* input_ptr = Input + c * InputStrideElements + output_idx * StrideWidthElements; + + for (size_t input_b = 0; input_b < BlockSize; input_b++) { + const float input_value = input_ptr[input_b]; + const float32x4_t InputVector = MlasBroadcastFloat32x4(input_value); + + const float* filter_ptr = filter + (c * BlockSize + input_b) * BlockSize; + + const float32x4_t FilterVector0 = MlasLoadFloat32x4(filter_ptr); + const float32x4_t FilterVector1 = MlasLoadFloat32x4(filter_ptr + 4); + const float32x4_t FilterVector2 = MlasLoadFloat32x4(filter_ptr + 8); + const float32x4_t FilterVector3 = MlasLoadFloat32x4(filter_ptr + 12); + + Accumulator0 = MlasMultiplyAddFloat32x4(InputVector, FilterVector0, Accumulator0); + Accumulator1 = MlasMultiplyAddFloat32x4(InputVector, FilterVector1, Accumulator1); + Accumulator2 = MlasMultiplyAddFloat32x4(InputVector, FilterVector2, Accumulator2); + Accumulator3 = MlasMultiplyAddFloat32x4(InputVector, FilterVector3, Accumulator3); + } + } + + if (ReluActivation) { + Accumulator0 = MlasMaximumFloat32x4(Accumulator0, ZeroVector); + Accumulator1 = MlasMaximumFloat32x4(Accumulator1, ZeroVector); + Accumulator2 = MlasMaximumFloat32x4(Accumulator2, ZeroVector); + Accumulator3 = MlasMaximumFloat32x4(Accumulator3, ZeroVector); + } + + MlasStoreFloat32x4(&output[output_idx * BlockSize], Accumulator0); + MlasStoreFloat32x4(&output[output_idx * BlockSize + 4], Accumulator1); + MlasStoreFloat32x4(&output[output_idx * BlockSize + 8], Accumulator2); + MlasStoreFloat32x4(&output[output_idx * BlockSize + 12], Accumulator3); + } + } +} + +#endif diff 
--git a/src/lib/sgemm.cpp b/src/lib/sgemm.cpp index 616622a..84c26eb 100644 --- a/src/lib/sgemm.cpp +++ b/src/lib/sgemm.cpp @@ -1061,7 +1061,7 @@ Return Value: size_t RowsHandled; -#if (defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_LARCH64)) && !defined(FORCE_GENERIC_ALGORITHMS) +#if (defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) || defined(MLAS_TARGET_LARCH64)) && !defined(FORCE_GENERIC_ALGORITHMS) RowsHandled = GetMlasPlatform().GemmFloatKernel(A, B, C, CountK, CountM, CountN, lda, ldc, alpha, ZeroMode); #else if (ZeroMode) { @@ -1572,7 +1572,13 @@ MlasGemmBatch( MLAS_THREADPOOL* ThreadPool ) { - + // Override + if(GetMlasPlatform().MlasGemmBatchOverride != nullptr && + // TODO: Remove once KAI supports transposing for A + TransA != CBLAS_TRANSPOSE::CblasTrans && + GetMlasPlatform().MlasGemmBatchOverride(TransA, TransB, M, N, K, Data, BatchSize, ThreadPool)){ + return; + } // // Compute the number of target threads given the complexity of the SGEMM // operation. Small requests should run using the single threaded path. @@ -1637,6 +1643,8 @@ MlasGemmBatch( size_t MLASCALL MlasGemmPackBSize( + CBLAS_TRANSPOSE TransA, + CBLAS_TRANSPOSE TransB, size_t N, size_t K ) @@ -1661,6 +1669,22 @@ Return Value: // // Compute the number of bytes required to hold the packed buffer. 
// + // KleidiAI or other override + #if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + if (GetMlasPlatform().MlasGemmPackBSizeOverride != nullptr && + // TODO: Remove once KAI supports transposing for A + TransA != CBLAS_TRANSPOSE::CblasTrans) { + size_t bytes_required; + //TODO pass status by reference to indicate success/fail + bytes_required = GetMlasPlatform().MlasGemmPackBSizeOverride(TransA, TransB, N, K); + if (bytes_required != 0){// If ArmKleidiAI::MlasGemmPackBSize ran to completion + return bytes_required; + } + } + #endif + MLAS_UNREFERENCED_PARAMETER(TransA); + MLAS_UNREFERENCED_PARAMETER(TransB); + const size_t AlignedN = (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1); @@ -1676,6 +1700,7 @@ Return Value: void MLASCALL MlasGemmPackB( + CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB, size_t N, size_t K, @@ -1712,6 +1737,17 @@ Return Value: --*/ { +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + if (GetMlasPlatform().MlasGemmPackBOverride != nullptr && + // TODO: Remove once KAI supports transposing for A + TransA != CBLAS_TRANSPOSE::CblasTrans && + GetMlasPlatform().MlasGemmPackBOverride(TransA, TransB, N, K, B, ldb, PackedB)){ + return; + } +#endif + MLAS_UNREFERENCED_PARAMETER(TransA); + + const size_t AlignedN = (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) & ~(MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1); diff --git a/src/lib/snchwc.cpp b/src/lib/snchwc.cpp index f9cf160..6f3423a 100644 --- a/src/lib/snchwc.cpp +++ b/src/lib/snchwc.cpp @@ -101,7 +101,7 @@ Return Value: --*/ { -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) return GetMlasPlatform().NchwcBlockSize; #else return 1; @@ -674,7 +674,7 @@ struct MLAS_NCHWC_CONV_NCHWC_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if defined(MLAS_TARGET_AMD64) || 
defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) MLAS_CONV_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvNchwcFloatKernel; #else MLAS_CONV_FLOAT_KERNEL* Kernel = MlasConvNchwcFloatKernel; @@ -784,7 +784,7 @@ struct MLAS_NCHWC_CONV_NCHW_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) MLAS_CONV_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvNchwFloatKernel; #else MLAS_CONV_FLOAT_KERNEL* Kernel = MlasConvNchwFloatKernel; @@ -879,7 +879,7 @@ struct MLAS_NCHWC_CONV_POINTWISE_ALGORITHM : MLAS_NCHWC_GROUPED_CONV_ALGORITHM const size_t FilterStrideBytes = BlockSize * InputChannels * sizeof(float); const size_t OutputStrideBytes = BlockSize * OutputSize * sizeof(float); -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvPointwiseFloatKernel; #else MLAS_CONV_POINTWISE_FLOAT_KERNEL* Kernel = MlasConvPointwiseFloatKernel; @@ -1016,7 +1016,7 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : MLAS_NCHWC_CONV_ALGORITHM const size_t BlockedOutputWidth = BlockSize * OutputWidth; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = GetMlasPlatform().ConvDepthwiseFloatKernel; #else MLAS_CONV_DEPTHWISE_FLOAT_KERNEL* Kernel = MlasConvDepthwiseFloatKernel; @@ -1093,7 +1093,7 @@ struct MLAS_NCHWC_CONV_DEPTHWISE_ALGORITHM : 
MLAS_NCHWC_CONV_ALGORITHM struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM { -#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) && !(defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) static MLAS_POOL_FLOAT_KERNEL* const PoolKernels[]; #endif @@ -1131,7 +1131,7 @@ struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM const size_t DilatedInputWidthBytes = BlockSize * DilationHeight * InputWidth * sizeof(float); const size_t InputStrideBytes = DilatedInputWidthBytes - KernelWidth * DilationWidthBytes; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_LARCH64) || (defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) MLAS_POOL_FLOAT_KERNEL* Kernel = GetMlasPlatform().PoolFloatKernel[WorkBlock->PoolingKind]; #else MLAS_POOL_FLOAT_KERNEL* Kernel = PoolKernels[WorkBlock->PoolingKind]; @@ -1197,7 +1197,7 @@ struct MLAS_NCHWC_POOL_ALGORITHM : MLAS_NCHWC_NN_ALGORITHM } }; -#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) && !(defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) MLAS_POOL_FLOAT_KERNEL* const MLAS_NCHWC_POOL_ALGORITHM::PoolKernels[] = { @@ -1621,7 +1621,7 @@ Return Value: } } -#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) +#if !defined(MLAS_TARGET_AMD64) && !defined(MLAS_TARGET_LARCH64) && !(defined(MLAS_TARGET_ARM64) && defined(MLAS_USE_ARM_NEON_NCHWC)) // // Convolution and pooling kernel stubs for architectures that do not yet have diff --git a/src/lib/spool_kernel_neon.cpp b/src/lib/spool_kernel_neon.cpp new file mode 100644 index 0000000..5883625 --- /dev/null +++ b/src/lib/spool_kernel_neon.cpp @@ -0,0 +1,293 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. + +Licensed under the MIT License. 
+
+Module Name:
+
+    spool_kernel_neon.cpp
+
+Abstract:
+
+    This module implements the single precision pooling kernels for ARM NEON.
+
+--*/
+
+#if defined(MLAS_USE_ARM_NEON_NCHWC)
+
+#include "mlasi.h"
+
+constexpr size_t BlockSize = MLAS_PLATFORM::MLAS_NEON_NCHWC_BLOCK_SIZE;
+
+void
+    MLASCALL
+    MlasPoolMaximumFloatKernelNeon(
+        const float* Input,
+        float* Output,
+        size_t StrideWidth,
+        size_t DilationWidth,
+        size_t InputStride,
+        size_t ActualKernelSize,
+        size_t KernelHeight,
+        size_t KernelWidth,
+        const float* InputBase,
+        size_t InputWidth,
+        size_t DilatedInputWidth,
+        size_t OutputCountLeftPad,
+        size_t OutputCount,
+        size_t OutputCountRightPad
+    )
+{
+    MLAS_UNREFERENCED_PARAMETER(ActualKernelSize);
+    MLAS_UNREFERENCED_PARAMETER(InputStride);
+    const size_t StrideWidthElements = StrideWidth / sizeof(float);
+    const size_t DilationWidthElements = DilationWidth / sizeof(float);
+    const size_t InputWidthElements = InputWidth / sizeof(float);
+    const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float);
+    const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad;
+
+    const float MaxPaddingValue = std::numeric_limits<float>::lowest();
+
+    const MLAS_FLOAT32X4 MaxPaddingVector = MlasBroadcastFloat32x4(MaxPaddingValue);
+
+    for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) {
+        MLAS_FLOAT32X4 MaxVector0 = MaxPaddingVector;
+        MLAS_FLOAT32X4 MaxVector1 = MaxPaddingVector;
+        MLAS_FLOAT32X4 MaxVector2 = MaxPaddingVector;
+        MLAS_FLOAT32X4 MaxVector3 = MaxPaddingVector;
+
+        for (size_t kh = 0; kh < KernelHeight; kh++) {
+            const float* row_start = InputBase + kh * DilatedInputWidthElements;
+            const float* row_end = row_start + InputWidthElements;
+
+            for (size_t kw = 0; kw < KernelWidth; kw++) {
+                const float* input_ptr = Input + output_idx * StrideWidthElements +
+                    kh * DilatedInputWidthElements + kw * DilationWidthElements;
+
+                if (input_ptr >= row_start && (input_ptr + BlockSize) <= row_end) {
+                    
MLAS_FLOAT32X4 InputVector0 = MlasLoadFloat32x4(input_ptr); + MLAS_FLOAT32X4 InputVector1 = MlasLoadFloat32x4(input_ptr + 4); + MLAS_FLOAT32X4 InputVector2 = MlasLoadFloat32x4(input_ptr + 8); + MLAS_FLOAT32X4 InputVector3 = MlasLoadFloat32x4(input_ptr + 12); + + MaxVector0 = MlasMaximumFloat32x4(MaxVector0, InputVector0); + MaxVector1 = MlasMaximumFloat32x4(MaxVector1, InputVector1); + MaxVector2 = MlasMaximumFloat32x4(MaxVector2, InputVector2); + MaxVector3 = MlasMaximumFloat32x4(MaxVector3, InputVector3); + } else { + float values[BlockSize]; + for (size_t i = 0; i < BlockSize; i++) { + const float* element_ptr = input_ptr + i; + if (element_ptr >= row_start && element_ptr < row_end) { + values[i] = *element_ptr; + } else { + values[i] = MaxPaddingValue; + } + } + + MLAS_FLOAT32X4 InputVector0 = MlasLoadFloat32x4(&values[0]); + MLAS_FLOAT32X4 InputVector1 = MlasLoadFloat32x4(&values[4]); + MLAS_FLOAT32X4 InputVector2 = MlasLoadFloat32x4(&values[8]); + MLAS_FLOAT32X4 InputVector3 = MlasLoadFloat32x4(&values[12]); + + MaxVector0 = MlasMaximumFloat32x4(MaxVector0, InputVector0); + MaxVector1 = MlasMaximumFloat32x4(MaxVector1, InputVector1); + MaxVector2 = MlasMaximumFloat32x4(MaxVector2, InputVector2); + MaxVector3 = MlasMaximumFloat32x4(MaxVector3, InputVector3); + } + } + } + + MlasStoreFloat32x4(&Output[output_idx * BlockSize], MaxVector0); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 4], MaxVector1); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 8], MaxVector2); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 12], MaxVector3); + } +} + +static void +MlasPoolAverageFloatKernelNeonImpl( + const float* Input, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t ActualKernelSize, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad, + bool ExcludePad +) +{ + const size_t 
StrideWidthElements = StrideWidth / sizeof(float);
+    const size_t DilationWidthElements = DilationWidth / sizeof(float);
+    const size_t InputWidthElements = InputWidth / sizeof(float);
+    const size_t DilatedInputWidthElements = DilatedInputWidth / sizeof(float);
+    const size_t TotalOutputCount = OutputCountLeftPad + OutputCount + OutputCountRightPad;
+
+    const MLAS_FLOAT32X4 ZeroVector = MlasZeroFloat32x4();
+
+    for (size_t output_idx = 0; output_idx < TotalOutputCount; output_idx++) {
+        MLAS_FLOAT32X4 SumVector0 = ZeroVector;
+        MLAS_FLOAT32X4 SumVector1 = ZeroVector;
+        MLAS_FLOAT32X4 SumVector2 = ZeroVector;
+        MLAS_FLOAT32X4 SumVector3 = ZeroVector;
+
+        std::vector<size_t> valid_count;
+        if (ExcludePad) {
+            valid_count.resize(BlockSize, 0);
+        }
+
+        for (size_t kh = 0; kh < KernelHeight; kh++) {
+            const float* row_start = InputBase + kh * DilatedInputWidthElements;
+            const float* row_end = row_start + InputWidthElements;
+
+            for (size_t kw = 0; kw < KernelWidth; kw++) {
+                const float* input_ptr = Input + output_idx * StrideWidthElements +
+                    kh * DilatedInputWidthElements + kw * DilationWidthElements;
+
+                if (input_ptr >= row_start && (input_ptr + BlockSize) <= row_end) {
+                    MLAS_FLOAT32X4 InputVector0 = MlasLoadFloat32x4(input_ptr);
+                    MLAS_FLOAT32X4 InputVector1 = MlasLoadFloat32x4(input_ptr + 4);
+                    MLAS_FLOAT32X4 InputVector2 = MlasLoadFloat32x4(input_ptr + 8);
+                    MLAS_FLOAT32X4 InputVector3 = MlasLoadFloat32x4(input_ptr + 12);
+
+                    SumVector0 = MlasAddFloat32x4(SumVector0, InputVector0);
+                    SumVector1 = MlasAddFloat32x4(SumVector1, InputVector1);
+                    SumVector2 = MlasAddFloat32x4(SumVector2, InputVector2);
+                    SumVector3 = MlasAddFloat32x4(SumVector3, InputVector3);
+
+                    if (ExcludePad) {
+                        for (size_t i = 0; i < BlockSize; i++) {
+                            valid_count[i]++;
+                        }
+                    }
+                } else {
+                    float values[BlockSize];
+                    for (size_t i = 0; i < BlockSize; i++) {
+                        const float* element_ptr = input_ptr + i;
+                        if (element_ptr >= row_start && element_ptr < row_end) {
+                            values[i] = *element_ptr;
+                            if 
(ExcludePad) {
+                                valid_count[i]++;
+                            }
+                        } else {
+                            values[i] = 0.0f;
+                        }
+                    }
+
+                    MLAS_FLOAT32X4 InputVector0 = MlasLoadFloat32x4(&values[0]);
+                    MLAS_FLOAT32X4 InputVector1 = MlasLoadFloat32x4(&values[4]);
+                    MLAS_FLOAT32X4 InputVector2 = MlasLoadFloat32x4(&values[8]);
+                    MLAS_FLOAT32X4 InputVector3 = MlasLoadFloat32x4(&values[12]);
+
+                    SumVector0 = MlasAddFloat32x4(SumVector0, InputVector0);
+                    SumVector1 = MlasAddFloat32x4(SumVector1, InputVector1);
+                    SumVector2 = MlasAddFloat32x4(SumVector2, InputVector2);
+                    SumVector3 = MlasAddFloat32x4(SumVector3, InputVector3);
+                }
+            }
+        }
+
+        if (ExcludePad) {
+            float results[BlockSize];
+
+            MlasStoreFloat32x4(&results[0], SumVector0);
+            MlasStoreFloat32x4(&results[4], SumVector1);
+            MlasStoreFloat32x4(&results[8], SumVector2);
+            MlasStoreFloat32x4(&results[12], SumVector3);
+
+            for (size_t i = 0; i < BlockSize; i++) {
+                results[i] = results[i] / static_cast<float>(valid_count[i]);
+            }
+
+            MLAS_FLOAT32X4 ResultVector0 = MlasLoadFloat32x4(&results[0]);
+            MLAS_FLOAT32X4 ResultVector1 = MlasLoadFloat32x4(&results[4]);
+            MLAS_FLOAT32X4 ResultVector2 = MlasLoadFloat32x4(&results[8]);
+            MLAS_FLOAT32X4 ResultVector3 = MlasLoadFloat32x4(&results[12]);
+
+            MlasStoreFloat32x4(&Output[output_idx * BlockSize], ResultVector0);
+            MlasStoreFloat32x4(&Output[output_idx * BlockSize + 4], ResultVector1);
+            MlasStoreFloat32x4(&Output[output_idx * BlockSize + 8], ResultVector2);
+            MlasStoreFloat32x4(&Output[output_idx * BlockSize + 12], ResultVector3);
+        } else {
+            const float KernelSize = static_cast<float>(ActualKernelSize);
+            const MLAS_FLOAT32X4 KernelSizeVector = MlasBroadcastFloat32x4(KernelSize);
+
+            MLAS_FLOAT32X4 ResultVector0 = MlasDivideFloat32x4(SumVector0, KernelSizeVector);
+            MLAS_FLOAT32X4 ResultVector1 = MlasDivideFloat32x4(SumVector1, KernelSizeVector);
+            MLAS_FLOAT32X4 ResultVector2 = MlasDivideFloat32x4(SumVector2, KernelSizeVector);
+            MLAS_FLOAT32X4 ResultVector3 = MlasDivideFloat32x4(SumVector3, KernelSizeVector);
+
+            
MlasStoreFloat32x4(&Output[output_idx * BlockSize], ResultVector0); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 4], ResultVector1); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 8], ResultVector2); + MlasStoreFloat32x4(&Output[output_idx * BlockSize + 12], ResultVector3); + } + } +} + +void + MLASCALL + MlasPoolAverageExcludePadFloatKernelNeon( + const float* Input, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t InputStride, + size_t ActualKernelSize, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad + ) +{ + MLAS_UNREFERENCED_PARAMETER(InputStride); + + MlasPoolAverageFloatKernelNeonImpl( + Input, Output, StrideWidth, DilationWidth, ActualKernelSize, + KernelHeight, KernelWidth, InputBase, InputWidth, DilatedInputWidth, + OutputCountLeftPad, OutputCount, OutputCountRightPad, + true // ExcludePad = true + ); +} + +void + MLASCALL + MlasPoolAverageIncludePadFloatKernelNeon( + const float* Input, + float* Output, + size_t StrideWidth, + size_t DilationWidth, + size_t InputStride, + size_t ActualKernelSize, + size_t KernelHeight, + size_t KernelWidth, + const float* InputBase, + size_t InputWidth, + size_t DilatedInputWidth, + size_t OutputCountLeftPad, + size_t OutputCount, + size_t OutputCountRightPad + ) +{ + MLAS_UNREFERENCED_PARAMETER(InputStride); + + MlasPoolAverageFloatKernelNeonImpl( + Input, Output, StrideWidth, DilationWidth, ActualKernelSize, + KernelHeight, KernelWidth, InputBase, InputWidth, DilatedInputWidth, + OutputCountLeftPad, OutputCount, OutputCountRightPad, + false // ExcludePad = false + ); +} + +#endif diff --git a/src/lib/sqnbitgemm_kernel_avx2.cpp b/src/lib/sqnbitgemm_kernel_avx2.cpp index 5f80a81..6416d25 100644 --- a/src/lib/sqnbitgemm_kernel_avx2.cpp +++ b/src/lib/sqnbitgemm_kernel_avx2.cpp @@ -18,7 +18,6 @@ Module Name: #include #include #include 
-#include #include "qnbitgemm.h" #include "sqnbitgemm_kernel_avx_common.h" @@ -29,6 +28,7 @@ Module Name: #include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen32.h" #include "sqnbitgemm_m1_sym_kernel_avx2_int8_blklen64.h" +#include void MlasCastF16ToF32KernelAvx2(const unsigned short* src_fp16, float* dst_fp32, size_t size) @@ -603,7 +603,8 @@ SQ8BitGemmKernel_BlkSum_CompInt8_avx2( const float* Bias, size_t ldc, const float* ABlockSum, - const float* QuantBBlkSum + const float* QuantBBlkSum, + const float* /*QuantBBlkSum2*/ ) { if (BlkLen == 16) { diff --git a/src/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h b/src/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h index d2d9886..a745dd9 100644 --- a/src/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h +++ b/src/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h @@ -1660,7 +1660,24 @@ MlasQ4Int8TileGemmKernelBlkLen32Avx2( if constexpr (NCols4 == 8) { __m128 acc_0 = FoldAccumulators(acc[0], acc[1], acc[2], acc[3]); + + // Clang is not happy with the code here, even if constexpr `NCols4 == 8` is always false in this context: + // + // In file included from .../onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp:26: + // .../onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h:1663:49: error: array index 4 is past the end of the array (that has type '__m256[4]') [-Werror,-Warray-bounds] + // 1663 | __m128 acc_1 = FoldAccumulators(acc[4], acc[5], acc[6], acc[7]); + // | ^ ~ + // .../onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2_int8_blklen32.h:1531:13: note: array 'acc' declared here + // 1531 | __m256 acc[NCols4]; + // | ^ +#if defined(__clang__) && defined(HAS_ARRAY_BOUNDS) +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Warray-bounds" +#endif __m128 acc_1 = FoldAccumulators(acc[4], acc[5], acc[6], acc[7]); +#if defined(__clang__) && defined(HAS_ARRAY_BOUNDS) +#pragma clang diagnostic pop +#endif if (BiasPtr != nullptr) { acc_0 = _mm_add_ps(acc_0, _mm_loadu_ps(BiasPtr)); acc_1 = _mm_add_ps(acc_1, 
_mm_loadu_ps(BiasPtr + 4)); diff --git a/src/lib/sqnbitgemm_kernel_avx512.cpp b/src/lib/sqnbitgemm_kernel_avx512.cpp index f917cf7..bfb2959 100644 --- a/src/lib/sqnbitgemm_kernel_avx512.cpp +++ b/src/lib/sqnbitgemm_kernel_avx512.cpp @@ -18,7 +18,6 @@ Module Name: #include #include #include -#include #include #include "qnbitgemm.h" @@ -266,7 +265,8 @@ SQ8BitGemmKernel_BlkSum_CompInt8_avx512( const float* Bias, size_t ldc, const float* ABlockSum, - const float* QuantBBlkSum + const float* QuantBBlkSum, + const float* /*QuantBBlkSum2*/ ) { if (BlkLen == 16) { diff --git a/src/lib/sqnbitgemm_kernel_avx512vnni.cpp b/src/lib/sqnbitgemm_kernel_avx512vnni.cpp index ea5eebd..e172308 100644 --- a/src/lib/sqnbitgemm_kernel_avx512vnni.cpp +++ b/src/lib/sqnbitgemm_kernel_avx512vnni.cpp @@ -316,7 +316,8 @@ SQ8BitGemmKernel_BlkSum_CompInt8_avx512vnni( const float* Bias, size_t ldc, const float* ABlockSum, - const float* QuantBBlkSum + const float* QuantBBlkSum, + const float* /*QuantBBlkSum2*/ ) { if (BlkLen == 16) { diff --git a/src/lib/sqnbitgemm_kernel_avx_common.h b/src/lib/sqnbitgemm_kernel_avx_common.h index bb38f37..36c15cd 100644 --- a/src/lib/sqnbitgemm_kernel_avx_common.h +++ b/src/lib/sqnbitgemm_kernel_avx_common.h @@ -469,7 +469,8 @@ QNBitGemmPerGemmWorkspaceSize( size_t K, size_t BlkLen, bool /* HasZeroPoint */, - MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + size_t /* BlkBitWidth */ ) { MLAS_UNREFERENCED_PARAMETER(N); diff --git a/src/lib/sqnbitgemm_kernel_lasx.cpp b/src/lib/sqnbitgemm_kernel_lasx.cpp new file mode 100644 index 0000000..04c6540 --- /dev/null +++ b/src/lib/sqnbitgemm_kernel_lasx.cpp @@ -0,0 +1,1089 @@ +/*++ + +Module Name: + + sqnbitgemm_kernel_lasx.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for loongarch64. Accelerate inference + optimization using lasx/lsx vector instruction sets. 
+ +--*/ + +#include +#include + +#include +#include +#include +#include +#include "core/common/safeint.h" + +#include "qnbitgemm.h" +#include "sqnbitgemm_kernel_lasx_common.h" + +// 1. qnbitgemm.h->Q4BitGemmPackQuantBDataSize +template +static size_t +QNBitGemmPackQuantBDataSize_Lasx( + size_t N, + size_t K, + size_t BlkLen, + bool /* HasZeroPoint */, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType +) +{ + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + if (ComputeType == SQNBIT_CompInt8) { + SafeInt PackedQuantBDataSize = SafeInt(N) * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const SafeInt ScaleSize = SafeInt(N) * BlockCountK * sizeof(float); + SafeInt BlkSumSize = SafeInt(BlockCountK) * MlasDivRoundup(N, 16) * 16 * sizeof(float); + + // _mm256_load_si256 requires alignment on a 32-byte boundary + constexpr size_t PackedQuantBDataAlignment = 32; + PackedQuantBDataSize += SafeInt(PackedQuantBDataAlignment) - 1; + constexpr size_t BlkSumAlignment = MlasQNBitQuantBBlkSumAlignment(); + BlkSumSize += SafeInt(BlkSumAlignment) - 1; + + PackedQuantBDataSize += ScaleSize + BlkSumSize; + return PackedQuantBDataSize.Value(); + } else { + SafeInt PackedQuantBDataSize = SafeInt(N) * BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + return PackedQuantBDataSize.Value(); + } +} + +// 2. qnbitgemm.h->SQ4BitGemmPackQuantBData +static void +SQ4BitGemmPackQuantBData_Lasx( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE /* ComputeType*/, + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool +) +{ + constexpr size_t BlkBitWidth = 4; + + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const SafeInt Iterations = SafeInt(N) * BlockCountK; // one iteration per block + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 
32 : 64); + + const size_t SubBlkDataSize = SubBlkLen / 2; + const size_t SubBlkBytePairCount = SubBlkLen / 4; + + // + // For SubBlkLen == 16, pack 16 4-bit values (8 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | v4 v5 | v6 v7 | v8 v9 | vA vB | vC vD | vE vF | + // => + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + // + + // + // For SubBlkLen == 32, pack 32 4-bit values (16 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | + // => + // dst: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + // + + // + // For SubBlkLen == 64, pack 32 4-bit values (16 bytes) at a time like this: + // + // src: | v0 v1 | v2 v3 | ... | v28 v29 | v30 v31 | v32 v33 | v34 v33 | + // => + // dst: | v0 v32 | v1 v33 | ... | v30 v62 | v31 v63 | + // + + MlasTrySimpleParallel( + ThreadPool, Iterations.Value(), + [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t k_blk = tid % BlockCountK; + + const SafeInt data_offset = SafeInt(n) * BlockCountK * BlkDataSize + k_blk * BlkDataSize; + const std::byte* QuantBData = QuantBDataBegin + data_offset.Value(); + std::byte* PackedQuantBData = PackedQuantBDataBegin + data_offset.Value(); + + for (size_t kk = 0; kk < BlkLen; kk += SubBlkLen) { + for (size_t byte_pair_idx = 0; byte_pair_idx < SubBlkBytePairCount; ++byte_pair_idx) { + const std::byte src0 = QuantBData[byte_pair_idx]; + const std::byte src1 = QuantBData[byte_pair_idx + SubBlkDataSize / 2]; + + std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; + std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; + + dst0 = (src0 & std::byte{0x0F}) | ((src1 & std::byte{0x0F}) << 4); + dst1 = (src0 >> 4) | ((src1 >> 4) << 4); + } + + QuantBData += SubBlkDataSize; + PackedQuantBData += SubBlkDataSize; + } + } + ); +} + +// 3. 
qnbitgemm.h->SQ4BitGemmPackQuantBDataAndBlkSum +static void +SQ4BitGemmPackQuantBDataAndBlkSum_Lasx( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool has_zp_input, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& packed_quant_b, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + // TODO: always use SubBlkLen = 64 in SQNBIT_CompInt8 + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 32 : 64); + if (BlkLen == 32 && ComputeType == SQNBIT_CompInt8) { + SubBlkLen = 64; + } + + if (QuantBDataBegin) { + PackQuantB(QuantBDataBegin, packed_quant_b.PackedQuantBData, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); + } + + if (QuantBScaleBegin) { + SafeInt offset = SafeInt(N) * BlockCountK; + std::copy(QuantBScaleBegin, QuantBScaleBegin + offset.Value(), packed_quant_b.PackedQuantBScale); + } + + if ((QuantBScaleBegin && !has_zp_input) || QuantBZPBegin) { + ComputePackBlkSum_Lasx( + BlkLen, SubBlkLen, N, + packed_quant_b.PackedQuantBScale, + QuantBZPBegin, + packed_quant_b.QuantBBlkSum, + ThreadPool, + BlockCountK + ); + } +} + +// 3. qnbitgemm.h->SQ8BitGemmPackQuantBDataAndBlkSum +static void +SQ8BitGemmPackQuantBDataAndBlkSum_Lasx( + size_t N, + size_t K, + size_t BlkLen, + MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool +) +{ + assert(BlkLen >= 16 && BlkLen % 16 == 0); + + const size_t BlockCountK = MlasDivRoundup(K, BlkLen); + + size_t SubBlkLen = (BlkLen == 16) ? 16 : (BlkLen == 32 ? 
32 : 64); + if (ComputeType == SQNBIT_CompInt8) { + SubBlkLen = 64; + } + Q8PackQuantBDataAndBlkSum_lasx(N, BlockCountK, BlkLen, SubBlkLen, QuantBDataBegin, QuantBScaleBegin, HasZeroPoint, QuantBZPBegin, PackedQuantB, ThreadPool); +} + +MLAS_FORCEINLINE +__m256 +load_float_n_lasx(const float* data, int n) +{ + if (n <= 0) { + alignas(32) float zero_array[8] = {0}; + return (__m256)__lasx_xvld((void*)&zero_array, 0); + } + alignas(32) float buf[8] = {0}; + if (n > 0 && n <= 8) { + for (int i = 0; i < n; ++i) { + buf[i] = data[i]; + } + } + return (__m256)__lasx_xvld((void*)&buf, 0); +} + +// ComputeDotProducts_BlkLen32Plus_CompFp32_lasx +template +MLAS_FORCEINLINE void +ComputeDotProducts_BlkLen32Plus_CompFp32_lasx( + size_t BlkLen, + const float* ARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + float* sum_ptr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* bias_ptr +) +{ + if constexpr (!HasZeroPoint) { + (void)QuantBZeroPointColPtr; + (void)StrideQuantBZeroPoint; + } + + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t SubBlkLen32 = 32; + constexpr size_t SubBlkStep16 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, SubBlkLen32); + static_assert(SubBlkStep16 == 16); + + __m256 acc[NCols]; + + alignas(32) static const float zero_array[8] = {0}; + UnrolledLoop([&](size_t i) { + acc[i] = (__m256)__lasx_xvld((void*)&zero_array, 0); + }); + + const std::byte* b_blk_data_ptr = QuantBDataColPtr; + const float* s = QuantBScaleColPtr; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; + [[maybe_unused]] int count_half_4 = 0; + [[maybe_unused]] uint8_t offset[NCols]; + + // TODO: Improve Memory Access Performance with Prefetching Matrix Operations + // alignas(32) float a_buf[2][32] = {0.0}; + //__m256 a_buf[8]; + + for (size_t k = 0; k < CountK; k += BlkLen) { + size_t ck = std::min(CountK - k, BlkLen); + + float 
scale_v[NCols]; + UnrolledLoop([&](size_t i) { + SafeInt scale_offset = SafeInt(StrideQuantBScale) * i; + scale_v[i] = *(s + scale_offset.Value()); + }); + + std::byte* b_blk_data_col_ptr[NCols]; + UnrolledLoop([&](size_t i) { + SafeInt data_offset = SafeInt(StrideQuantBData) * i; + b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + data_offset.Value()); + }); + + // not ready for "Manual conversion to float" in neon yet. + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const std::byte zp_packed = + QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[i] = std::to_integer(zp); + }); + } + + for (size_t kk = 0; kk < ck; kk += SubBlkLen32) { + size_t kklen = std::min((int)SubBlkLen32, (int)(ck - kk)); + + __m256 av0_8_ps = load_float_n_lasx(ARowPtr + k + kk, std::min(kklen, 8)); + __m256 av1_8_ps = load_float_n_lasx(ARowPtr + k + kk + 8, std::min(kklen > 8 ? kklen - 8 : 0, 8)); + __m256 av2_8_ps = load_float_n_lasx(ARowPtr + k + kk + 16, std::min(kklen > 16 ? kklen - 16 : 0, 8)); + __m256 av3_8_ps = load_float_n_lasx(ARowPtr + k + kk + 24, std::min(kklen > 24 ? kklen - 24 : 0, 8)); + + if constexpr (IsBlkLen64Layout) { + count_half_4 = 4 * (int)((kk % (2 * SubBlkLen32)) / SubBlkLen32); + } + + UnrolledLoop([&](size_t i) { + __m256i bv_0_32; + + if constexpr (IsBlkLen64Layout) { + __m256i bv_32_4bit_tmp = __lasx_xvld(b_blk_data_col_ptr[i], 0); + if (!count_half_4) + bv_0_32 = __lasx_xvandi_b(bv_32_4bit_tmp, 0x0F); + else + bv_0_32 = __lasx_xvsrli_b(bv_32_4bit_tmp, 4); + b_blk_data_col_ptr[i] += count_half_4 / 2 * SubBlkStep16; + } else { + // SubBlkLen = 32: | v0 v16 | v1 v17 | ... 
| v14 v30 | v15 v31 | + alignas(32) uint8_t packed_bytes[32] = {0}; + // Previously, boundary padding was performed on b_blk_data_col_ptr to ensure that it could be read in 16 units + std::memcpy(packed_bytes, b_blk_data_col_ptr[i], 16); + __m256i bv_32_4bit_tmp = __lasx_xvld((void*)&packed_bytes, 0); + __m256i bv_0_15_tmp = __lasx_xvpermi_d(__lasx_xvandi_b(bv_32_4bit_tmp, 0x0F), 0x36); + __m256i bv_16_31_tmp = __lasx_xvpermi_d(__lasx_xvsrli_b(bv_32_4bit_tmp, 4), 0x36); + bv_0_32 = __lasx_xvpermi_d(__lasx_xvpermi_w(bv_16_31_tmp, bv_0_15_tmp, 0xEE), 0x72); + b_blk_data_col_ptr[i] += SubBlkStep16; + } + + __m256i zp = HasZeroPoint ? __lasx_xvldrepl_b((void*)&offset[i], 0) : __lasx_xvrepli_b(0x08); + bv_0_32 = __lasx_xvsub_b(bv_0_32, zp); + + // (1)8bit -> 16bit + __m256i bv_0_15 = __lasx_xvexth_h_b(__lasx_xvpermi_d(bv_0_32, 0x72)); + __m256i bv_16_31 = __lasx_xvexth_h_b(__lasx_xvpermi_d(bv_0_32, 0xD8)); + + // (2)16bit -> int32 + __m256i bv_0_7 = __lasx_xvexth_w_h(__lasx_xvpermi_d(bv_0_15, 0x72)); + __m256i bv_8_15 = __lasx_xvexth_w_h(__lasx_xvpermi_d(bv_0_15, 0xD8)); + __m256i bv_16_23 = __lasx_xvexth_w_h(__lasx_xvpermi_d(bv_16_31, 0x72)); + __m256i bv_24_31 = __lasx_xvexth_w_h(__lasx_xvpermi_d(bv_16_31, 0xD8)); + + // (3)int32 -> fp32 + __m256 fbv_0_7 = __lasx_xvffint_s_w(bv_0_7); + __m256 fbv_8_15 = __lasx_xvffint_s_w(bv_8_15); + __m256 fbv_16_23 = __lasx_xvffint_s_w(bv_16_23); + __m256 fbv_24_31 = __lasx_xvffint_s_w(bv_24_31); + + __m256 scale_ps = (__m256)__lasx_xvldrepl_w(&scale_v[i], 0); + + fbv_0_7 = __lasx_xvfmul_s(fbv_0_7, scale_ps); + fbv_8_15 = __lasx_xvfmul_s(fbv_8_15, scale_ps); + fbv_16_23 = __lasx_xvfmul_s(fbv_16_23, scale_ps); + fbv_24_31 = __lasx_xvfmul_s(fbv_24_31, scale_ps); + + acc[i] = __lasx_xvfmadd_s(fbv_0_7, av0_8_ps, acc[i]); + acc[i] = __lasx_xvfmadd_s(fbv_8_15, av1_8_ps, acc[i]); + acc[i] = __lasx_xvfmadd_s(fbv_16_23, av2_8_ps, acc[i]); + acc[i] = __lasx_xvfmadd_s(fbv_24_31, av3_8_ps, acc[i]); + }); + } + + b_blk_data_ptr += 
MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + ++s; + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + } + + if constexpr (NCols == 4) { + __m128 acc_x = FoldAccumulators_Lasx(acc[0], acc[1], acc[2], acc[3]); + if (bias_ptr != nullptr) { + acc_x = __lsx_vfadd_s(acc_x, (__m128)__lsx_vld((void*)bias_ptr, 0)); + } + __lsx_vst(acc_x, sum_ptr, 0); + } else { + UnrolledLoop([&](size_t i) { + float sum = hsum_float_8_lasx(acc[i]); + float bias_tmp = bias_ptr == nullptr ? 0.0f : bias_ptr[i]; + sum_ptr[i] = sum + bias_tmp; + }); + } +} + +// ComputeDotProducts_BlkLen16_CompFp32_lasx +template +MLAS_FORCEINLINE void +ComputeDotProducts_BlkLen16_CompFp32_lasx( + size_t BlkLen, + const float* ARowPtr, + const std::byte* QuantBDataColPtr, + const float* QuantBScaleColPtr, + const std::byte* QuantBZeroPointColPtr, + float* sum_ptr, + size_t CountK, + size_t StrideQuantBData, + size_t StrideQuantBScale, + size_t StrideQuantBZeroPoint, + const float* bias_ptr +) +{ + if constexpr (!HasZeroPoint) { + (void)QuantBZeroPointColPtr; + (void)StrideQuantBZeroPoint; + } + + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t SubBlkLen16 = 16; + constexpr size_t SubBlkStep8 = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, SubBlkLen16); + static_assert(SubBlkStep8 == 8); + + __m256 acc[NCols]; + alignas(32) int zero_array[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + UnrolledLoop([&](size_t i) { + acc[i] = (__m256)__lasx_xvld((void*)&zero_array, 0); + }); + + const std::byte* b_blk_data_ptr = QuantBDataColPtr; + const float* s = QuantBScaleColPtr; + + [[maybe_unused]] size_t QuantBZeroPointIdx = 0; + [[maybe_unused]] uint8_t offset[NCols]; + + for (size_t k = 0; k < CountK; k += BlkLen) { + size_t ck = std::min(CountK - k, BlkLen); + + float scale_v[NCols]; + UnrolledLoop([&](size_t i) { + SafeInt scale_offset = SafeInt(StrideQuantBScale) * i; + scale_v[i] = *(s + scale_offset.Value()); + }); + + std::byte* b_blk_data_col_ptr[NCols]; + UnrolledLoop([&](size_t i) { + SafeInt data_offset 
= SafeInt(StrideQuantBData) * i; + b_blk_data_col_ptr[i] = (std::byte*)(b_blk_data_ptr + data_offset.Value()); + }); + + if constexpr (HasZeroPoint) { + UnrolledLoop([&](size_t i) { + const std::byte zp_packed = + QuantBZeroPointColPtr[i * StrideQuantBZeroPoint + QuantBZeroPointIdx / 2]; + const std::byte zp = ((QuantBZeroPointIdx & 1) == 1) + ? (zp_packed >> 4) + : (zp_packed & std::byte{0x0F}); + offset[i] = std::to_integer(zp); + }); + } + + for (size_t kk = 0; kk < ck; kk += SubBlkLen16) { + size_t kklen = std::min((int)SubBlkLen16, (int)(ck - kk)); + + __m256 av_lo = load_float_n_lasx(ARowPtr + k + kk, std::min(kklen, 8)); + __m256 av_hi = load_float_n_lasx(ARowPtr + k + kk + 8, std::min(kklen > 8 ? kklen - 8 : 0, 8)); + + UnrolledLoop([&](size_t i) { + alignas(32) uint8_t packed_bytes[32] = {0}; + // Previously, boundary padding was performed on b_blk_data_col_ptr to ensure that it could be read in 8 units + std::memcpy(packed_bytes + 24, b_blk_data_col_ptr[i], 8); + __m256i B_16val = __lasx_xvld((void*)&packed_bytes, 0); + + /* + low->high + | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | x 3 + | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | 24-31 + */ + + b_blk_data_col_ptr[i] += SubBlkStep8; + __m256i lower = __lasx_xvandi_b(B_16val, 0x0F); + __m256i upper = __lasx_xvsrli_b(B_16val, 4); + __m256i packb = __lasx_xvpermi_d(__lasx_xvpackod_d(upper, lower), 0xD8); + + __m256i zp = HasZeroPoint ? 
__lasx_xvldrepl_b((void*)&offset[i], 0) : __lasx_xvrepli_b(0x08); + packb = __lasx_xvsub_b(packb, zp); + __m256i bv0_15 = __lasx_xvexth_h_b(packb); + + __m256i bv0_7 = __lasx_xvexth_w_h(__lasx_xvpermi_d(bv0_15, 0x72)); + __m256i bv8_15 = __lasx_xvexth_w_h(__lasx_xvpermi_d(bv0_15, 0xD8)); + + __m256 fbv0_7 = __lasx_xvffint_s_w(bv0_7); + __m256 fbv8_15 = __lasx_xvffint_s_w(bv8_15); + __m256 scale = (__m256)__lasx_xvldrepl_w((void*)&scale_v[i], 0); + fbv0_7 = __lasx_xvfmul_s(fbv0_7, scale); + fbv8_15 = __lasx_xvfmul_s(fbv8_15, scale); + + acc[i] = __lasx_xvfmadd_s(av_lo, fbv0_7, acc[i]); + acc[i] = __lasx_xvfmadd_s(av_hi, fbv8_15, acc[i]); + }); + } + + b_blk_data_ptr += MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + ++s; + + if constexpr (HasZeroPoint) { + QuantBZeroPointIdx += 1; + } + } + + if constexpr (NCols == 4) { + __m128 acc_x = FoldAccumulators_Lasx(acc[0], acc[1], acc[2], acc[3]); + if (bias_ptr != nullptr) { + acc_x = __lsx_vfadd_s(acc_x, (__m128)__lsx_vld((void*)bias_ptr, 0)); + } + __lsx_vst(acc_x, sum_ptr, 0); + } else { + UnrolledLoop([&](size_t i) { + float sum = 0.0f; + alignas(32) float acc_buf[8]; + __lasx_xvst(acc[i], (void*)&acc_buf, 0); + UnrolledLoop<8>([&](size_t j) { sum += acc_buf[j]; }); + float bias_tmp = bias_ptr == nullptr ? 
0.0f : bias_ptr[i]; + sum_ptr[i] = sum + bias_tmp; + }); + } +} + +// SQ4BitGemmM1Kernel_BlkLen16_CompFp32_lasx +template +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_BlkLen16_CompFp32_lasx( + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + + const float* ARowPtr = A; + float* CRowPtr = C; + + const size_t BlockCountK = BlockStrideQuantB; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + int64_t nblk = (CountN) - NCols4; + while (nblk >= 0) { + ComputeDotProducts_BlkLen16_CompFp32_lasx( + BlkLen16, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + SafeInt data_offset = SafeInt(StrideQuantBData) * NCols4; + SafeInt scale_offset = SafeInt(StrideQuantBScale) * NCols4; + QuantBDataColPtr += data_offset.Value(); + QuantBScaleColPtr += scale_offset.Value(); + if constexpr (HasZeroPoint) { + SafeInt zeropoint_offset = SafeInt(StrideQuantBZeroPoint) * NCols4; + QuantBZeroPointColPtr += zeropoint_offset.Value(); + } + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + + nblk -= NCols4; + } + + nblk += NCols4; + for (int64_t n = 0; n < nblk; ++n) { + ComputeDotProducts_BlkLen16_CompFp32_lasx<1, HasZeroPoint>( + BlkLen16, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 1 : 0; + SumPtr += 1; + } +} + +// SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_lasx +template +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_lasx( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols4 = 4; + + const float* ARowPtr = A; + float* CRowPtr = C; + + const size_t BlockCountK = BlockStrideQuantB; + + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + const size_t StrideQuantBZeroPoint = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + const float* BiasPtr = Bias; + + const std::byte* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const std::byte* QuantBZeroPointColPtr = QuantBZeroPoint; + + float* SumPtr = CRowPtr; + + int64_t nblk = static_cast(CountN) - NCols4; + + while (nblk >= 0) { + if (BlkLen >= 64) { + ComputeDotProducts_BlkLen32Plus_CompFp32_lasx( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + } else { + ComputeDotProducts_BlkLen32Plus_CompFp32_lasx( + BlkLen, + ARowPtr, QuantBDataColPtr, 
QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + } + + SafeInt data_offset = SafeInt(StrideQuantBData) * NCols4; + SafeInt scale_offset = SafeInt(StrideQuantBScale) * NCols4; + QuantBDataColPtr += data_offset.Value(); + QuantBScaleColPtr += scale_offset.Value(); + if constexpr (HasZeroPoint) { + SafeInt zeropoint_offset = SafeInt(StrideQuantBZeroPoint) * NCols4; + QuantBZeroPointColPtr += zeropoint_offset.Value(); + } + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + + nblk -= NCols4; + } + + // left over columns less than NCols + nblk += NCols4; + for (int64_t n = 0; n < nblk; ++n) { + if (BlkLen >= 64) { + ComputeDotProducts_BlkLen32Plus_CompFp32_lasx<1, HasZeroPoint, true>( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + } else { + ComputeDotProducts_BlkLen32Plus_CompFp32_lasx<1, HasZeroPoint, false>( + BlkLen, + ARowPtr, QuantBDataColPtr, QuantBScaleColPtr, QuantBZeroPointColPtr, SumPtr, CountK, + StrideQuantBData, StrideQuantBScale, StrideQuantBZeroPoint, + BiasPtr + ); + } + + QuantBDataColPtr += StrideQuantBData; + QuantBScaleColPtr += StrideQuantBScale; + if constexpr (HasZeroPoint) { + QuantBZeroPointColPtr += StrideQuantBZeroPoint; + } + + BiasPtr += BiasPtr != nullptr ? 
1 : 0; + SumPtr += 1; + } +} + +MLAS_FORCEINLINE void +SQ4BitGemmM1Kernel_CompFp32_Lasx( + size_t BlkLen, + const float* A, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + float* C, + size_t CountN, + size_t CountK, + size_t BlockStrideQuantB, + const float* Bias +) +{ + if (BlkLen == 16) { + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_BlkLen16_CompFp32_lasx( + A, QuantBData, QuantBScale, QuantBZeroPoint, + C, CountN, CountK, BlockStrideQuantB, Bias + ); + } else { + SQ4BitGemmM1Kernel_BlkLen16_CompFp32_lasx( + A, QuantBData, QuantBScale, QuantBZeroPoint, + C, CountN, CountK, BlockStrideQuantB, Bias + ); + } + } else { + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_lasx( + BlkLen, A, QuantBData, QuantBScale, QuantBZeroPoint, + C, CountN, CountK, BlockStrideQuantB, Bias + ); + } else { + SQ4BitGemmM1Kernel_BlkLen32Plus_CompFp32_lasx( + BlkLen, A, QuantBData, QuantBScale, QuantBZeroPoint, + C, CountN, CountK, BlockStrideQuantB, Bias + ); + } + } +} + +MLAS_FORCEINLINE void +Q4BitBlkDequantBForSgemmBlkLen16_CompFp32_lasx( + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + const size_t CountN, + const size_t CountK, + const size_t BlockCountK +) +{ + constexpr size_t BlkLen16 = 16; + constexpr size_t BlkBitWidth4 = 4; + + constexpr size_t blk_data_size_in_bytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen16); + const size_t b_data_col_stride_in_bytes = BlockCountK * blk_data_size_in_bytes; + /* + TODO: constexpr use template parameter + Since QuantBZeroPoint is a model parameter and cannot be determined at compile time, constexpr cannot be used + and comments are required, However, when the usage scenario can be determined, constexpr can be used to enhance + performance. 
+ */ + /*constexpr*/ const bool HasZeroPoint = QuantBZeroPoint != nullptr; + const size_t zp_col_stride_in_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + constexpr size_t NCols8 = 8; // process NCols8 columns of QuantB at a time + constexpr size_t GemmFloatKernelWidth16 = 16; // mlas GemmFloatKernel requires B with width 16 + for (size_t col = 0; col < CountN; col += NCols8) { + const int cols = std::min((int)NCols8, (int)CountN - (int)col); + for (size_t k = 0; k < BlockCountK; k++) { + // count # of tiles plus blks of the current tile from top + const size_t tile_count = col / GemmFloatKernelWidth16; + SafeInt offset = SafeInt(tile_count * CountK + k * BlkLen16) * GemmFloatKernelWidth16; + float* dst_ptr = FpData + offset.Value(); + if (col % GemmFloatKernelWidth16 >= NCols8) { + // for the second half to 16 width tile + dst_ptr += NCols8; + } + SafeInt b_data_offset = SafeInt(col) * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes; + SafeInt b_scale_offset = SafeInt(col) * BlockCountK + k; + SafeInt b_zp_offset = SafeInt(col) * zp_col_stride_in_bytes + k / 2; + const std::byte* b_data_ptr = QuantBData + b_data_offset.Value(); + const float* scale_ptr = QuantBScale + b_scale_offset.Value(); + const std::byte* zp_ptr = QuantBZeroPoint + b_zp_offset.Value(); + bool is_lower = (k % 2) == 0; + + __m256i weight_16_epi16[NCols8]; + __m256 scale_8_ps[NCols8]; + UnrolledLoop([&](size_t col_) { + if ((int)col_ < cols) { + // dst: | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | + alignas(32) uint8_t packed_bytes[32] = {0}; + // Previously, boundary padding was performed on QuantBData to ensure that it could be read in 8 units + std::memcpy(packed_bytes + 24, b_data_ptr + col_ * b_data_col_stride_in_bytes, 8); + __m256i B_16val = __lasx_xvld((void*)&packed_bytes, 0); + // low->high + // | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | 0 0 | x 3 + // | v0 v8 | v1 v9 | v2 vA | v3 vB | v4 vC | v5 vD | v6 vE | v7 vF | 24-31 + + __m256i lower 
= __lasx_xvandi_b(B_16val, 0x0F); + __m256i upper = __lasx_xvsrli_b(B_16val, 4); + __m256i packb = __lasx_xvpermi_d(__lasx_xvpackod_d(upper, lower), 0xD8); + + if (HasZeroPoint) { + std::byte zp_packed = *(zp_ptr + col_ * zp_col_stride_in_bytes); + uint8_t zp = std::to_integer(is_lower ? (zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + __m256i zero_point = __lasx_xvreplgr2vr_b(static_cast(zp)); + packb = __lasx_xvsub_b(packb, zero_point); + } else { + __m256i zero_point = __lasx_xvrepli_b(0x08); + packb = __lasx_xvsub_b(packb, zero_point); + } + weight_16_epi16[col_] = __lasx_xvexth_h_b(packb); + scale_8_ps[col_] = (__m256)__lasx_xvldrepl_w((void*)(scale_ptr + col_ * BlockCountK), 0); + } else { + weight_16_epi16[col_] = __lasx_xvrepli_d(0); + scale_8_ps[col_] = (__m256)__lasx_xvrepli_d(0); + } + }); + + for (int i_of_2 = 0; i_of_2 < 2; i_of_2++) { + __m256 weight_8_ps[8]; + for (size_t col_ = 0; col_ < 8; col_++) { + if ((int)col_ < cols) { + if (i_of_2 == 0) { + __m256i weight_i_8_epi32 = __lasx_xvexth_w_h(__lasx_xvpermi_d(weight_16_epi16[col_], 0x72)); + weight_8_ps[col_] = __lasx_xvfmul_s(__lasx_xvffint_s_w(weight_i_8_epi32), scale_8_ps[col_]); + } else { + __m256i weight_i_8_epi32 = __lasx_xvexth_w_h(__lasx_xvpermi_d(weight_16_epi16[col_], 0xD8)); + weight_8_ps[col_] = __lasx_xvfmul_s(__lasx_xvffint_s_w(weight_i_8_epi32), scale_8_ps[col_]); + } + } else { + weight_8_ps[col_] = (__m256)__lasx_xvrepli_d(0); + } + } + // transpose and store + __m256 a0 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[1], (__m256i)weight_8_ps[0], 0x44); // a1, a2, b1, b2, a5, a6, b5, b6 + __m256 a1 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[1], (__m256i)weight_8_ps[0], 0xEE); // a3, a4, b3, b4, a7, a8, b7, b8 + __m256 a2 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[3], (__m256i)weight_8_ps[2], 0x44); // c1, c2, d1, d2, c5, c6, d5, d6 + __m256 a3 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[3], (__m256i)weight_8_ps[2], 0xEE); // c3, c4, d3, d4, c7, c8, d7, d8 + __m256 
a4 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[5], (__m256i)weight_8_ps[4], 0x44); // e1, e2, f1, f2, e5, e6, f5, f6 + __m256 a5 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[5], (__m256i)weight_8_ps[4], 0xEE); // e3, e4, f3, f4, e7, e8, f7, f8 + __m256 a6 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[7], (__m256i)weight_8_ps[6], 0x44); // g1, g2, h1, h2, g5, g6, h5, h6 + __m256 a7 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[7], (__m256i)weight_8_ps[6], 0xEE); // g3, g4, h3, h4, g7, g8, h7, h8 + + __m256 b0 = (__m256)__lasx_xvpermi_w((__m256i)a2, (__m256i)a0, 0x88); // a1, b1, c1, d1, a5, b5, c5, d5 + __m256 b1 = (__m256)__lasx_xvpermi_w((__m256i)a2, (__m256i)a0, 0xDD); // a2, b2, c2, d2, a6, b6, c6, d6 + __m256 b2 = (__m256)__lasx_xvpermi_w((__m256i)a3, (__m256i)a1, 0x88); // a3, b3, c3, d3, a7, b7, c7, d7 + __m256 b3 = (__m256)__lasx_xvpermi_w((__m256i)a3, (__m256i)a1, 0xDD); // a4, b4, c4, d4, a8, b8, c8, d8 + __m256 b4 = (__m256)__lasx_xvpermi_w((__m256i)a6, (__m256i)a4, 0x88); // e1, f1, g1, h1, e5, f5, g5, h5 + __m256 b5 = (__m256)__lasx_xvpermi_w((__m256i)a6, (__m256i)a4, 0xDD); // e2, f2, g2, h2, e6, f6, g6, h6 + __m256 b6 = (__m256)__lasx_xvpermi_w((__m256i)a7, (__m256i)a5, 0x88); // e3, f3, g3, h3, e7, f7, g7, h7 + __m256 b7 = (__m256)__lasx_xvpermi_w((__m256i)a7, (__m256i)a5, 0xDD); // e4, f4, g4, h4, e8, f8, g8, h8 + + // next i_of_2th row + const size_t ij_offset_in_k = i_of_2 * 8 * GemmFloatKernelWidth16; + __m256 weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b0, (__m256i)b4, 0x02); // a1, b1, c1, d1, e1, f1, g1, h1 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 0 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b1, (__m256i)b5, 0x02); // a2, b2, c2, d2, e2, f2, g2, h2 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 1 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b2, (__m256i)b6, 0x02); // a3, b3, c3, d3, 
e3, f3, g3, h3 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 2 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b3, (__m256i)b7, 0x02); // a4, b4, c4, d4, e4, f4, g4, h4 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 3 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b0, (__m256i)b4, 0x13); // a5, b5, c5, d5, e5, f5, g5, h5 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 4 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b1, (__m256i)b5, 0x13); // a6, b6, c6, d6, e6, f6, g6, h6 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 5 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b2, (__m256i)b6, 0x13); // a7, b7, c7, d7, e7, f7, g7, h7 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 6 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b3, (__m256i)b7, 0x13); // a8, b8, c8, d8, e8, f8, g8, h8 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 7 * GemmFloatKernelWidth16, 0); + } + } + } +} + +template +MLAS_FORCEINLINE void +Q4BitBlkDequantBForSgemmBlkLen32AndMore_CompFp32_lasx( + const size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + const size_t CountN, + const size_t CountK, + const size_t BlockCountK +) +{ + constexpr size_t BlkBitWidth4 = 4; + constexpr size_t NCols8 = 8; + constexpr size_t GemmFloatKernelWidth16 = 16; + constexpr size_t SubblkLen32 = 32; + + const size_t blk_data_size_in_bytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, BlkLen); + const size_t subblk_data_size_in_bytes = MlasQNBitBlkDataSizeInBytes(BlkBitWidth4, SubblkLen32); + const size_t b_data_col_stride_in_bytes = BlockCountK * blk_data_size_in_bytes; + /* + TODO: constexpr use template parameter + Since 
QuantBZeroPoint is a model parameter and cannot be determined at compile time, constexpr cannot be used + and comments are required, However, when the usage scenario can be determined, constexpr can be used to enhance + performance. + */ + /*constexpr*/ const bool HasZeroPoint = QuantBZeroPoint != nullptr; + const size_t zp_col_stride_in_bytes = MlasQNBitZeroPointsForBlksSizeInBytes(BlockCountK); + + [[maybe_unused]] int count_half_4 = 0; + + for (size_t col = 0; col < CountN; col += NCols8) { + // TODO: handle last tile with cols < NCols8 + const size_t cols = std::min(NCols8, CountN - col); + for (size_t k = 0; k < BlockCountK; k++) { + // count # of tiles plus blks of the current tile from top + const size_t tile_count = col / GemmFloatKernelWidth16; + SafeInt offset = SafeInt(tile_count * CountK + k * BlkLen) * GemmFloatKernelWidth16; + float* dst_ptr = FpData + offset.Value(); + if (col % GemmFloatKernelWidth16 >= NCols8) { + // for the second half to 16 width tile + dst_ptr += NCols8; + } + SafeInt b_data_offset = SafeInt(col) * b_data_col_stride_in_bytes + k * blk_data_size_in_bytes; + SafeInt b_scale_offset = SafeInt(col) * BlockCountK + k; + SafeInt b_zp_offset = SafeInt(col) * zp_col_stride_in_bytes + k / 2; + const std::byte* b_data_ptr = QuantBData + b_data_offset.Value(); + const float* scale_ptr = QuantBScale + b_scale_offset.Value(); + const std::byte* zp_ptr = QuantBZeroPoint + b_zp_offset.Value(); + bool is_lower = (k % 2) == 0; + + for (size_t subblk = 0; subblk < BlkLen / SubblkLen32; subblk++) { + __m256i weight_32_epi8[NCols8]; + __m256 scale_8_ps[NCols8]; + if constexpr (IsBlkLen64Layout) { + count_half_4 = 4 * (subblk % 2); + } + UnrolledLoop([&](size_t col_) { + // 1. load 32 4-bit data + if (col_ < cols) { + if constexpr (IsBlkLen64Layout) { + // dst: | v0 v32 | v1 v33 | ... 
| v30 v62 | v31 v63 | + // load 64 weights at once, parse to get v0 - v31 if subblk % 2 == 0, otherwise get v32 - v63 + // at the end of subblk loop, increment b_data_ptr by 2 * subblk_data_size_in_bytes if subblk % 2 == 1 + // so that all v0-64 of the pack are dequantized. + __m256i bv_32_4bit_tmp = __lasx_xvld(b_data_ptr + col_ * b_data_col_stride_in_bytes, 0); + if (!count_half_4) + weight_32_epi8[col_] = __lasx_xvandi_b(bv_32_4bit_tmp, 0x0F); + else + weight_32_epi8[col_] = __lasx_xvsrli_b(bv_32_4bit_tmp, 4); + } else { + // dst: | v0 v16 | v1 v17 | ... | v14 v30 | v15 v31 | + alignas(32) uint8_t packed_bytes[32] = {0}; + // Previously, boundary padding was performed on QuantBData to ensure that it could be read in 16 units + std::memcpy(packed_bytes, b_data_ptr + col_ * b_data_col_stride_in_bytes, 16); + __m256i bv_32_4bit_tmp = __lasx_xvld((void*)&packed_bytes, 0); + __m256i bv_0_15_tmp = __lasx_xvpermi_d(__lasx_xvandi_b(bv_32_4bit_tmp, 0x0F), 0x36); + __m256i bv_16_31_tmp = __lasx_xvpermi_d(__lasx_xvsrli_b(bv_32_4bit_tmp, 4), 0x36); + weight_32_epi8[col_] = __lasx_xvpermi_d(__lasx_xvpermi_w(bv_16_31_tmp, bv_0_15_tmp, 0xEE), 0x72); + } + + // 2. load zeropoint and scale + if (HasZeroPoint) { + std::byte zp_packed = *(zp_ptr + col_ * zp_col_stride_in_bytes); + uint8_t zp = std::to_integer(is_lower ? 
(zp_packed & std::byte{0x0F}) : (zp_packed >> 4)); + __m256i zero_point = __lasx_xvreplgr2vr_b(static_cast(zp)); + weight_32_epi8[col_] = __lasx_xvsub_b(weight_32_epi8[col_], zero_point); + } else { + __m256i zero_point = __lasx_xvrepli_b(0x08); + weight_32_epi8[col_] = __lasx_xvsub_b(weight_32_epi8[col_], zero_point); + } + + scale_8_ps[col_] = (__m256)__lasx_xvldrepl_w((void*)(scale_ptr + col_ * BlockCountK), 0); + } else { + weight_32_epi8[col_] = __lasx_xvrepli_d(0); + scale_8_ps[col_] = (__m256)__lasx_xvrepli_d(0); + } + }); + + for (int i_of_4 = 0; i_of_4 < 4; i_of_4++) { + __m256 weight_8_ps[8]; + for (size_t col_ = 0; col_ < 8; col_++) { + if (col_ < cols) { + if (i_of_4 == 0) { + __m256i weight_i_16_epi16 = __lasx_xvexth_h_b(__lasx_xvpermi_d(weight_32_epi8[col_], 0xE1)); + __m256i weight_i_j_8_epi32 = __lasx_xvexth_w_h(__lasx_xvpermi_d(weight_i_16_epi16, 0x72)); + weight_8_ps[col_] = __lasx_xvfmul_s(__lasx_xvffint_s_w(weight_i_j_8_epi32), scale_8_ps[col_]); + } else if (i_of_4 == 1) { + __m256i weight_i_16_epi16 = __lasx_xvexth_h_b(weight_32_epi8[col_]); + __m256i weight_i_j_8_epi32 = __lasx_xvexth_w_h(__lasx_xvpermi_d(weight_i_16_epi16, 0x72)); + weight_8_ps[col_] = __lasx_xvfmul_s(__lasx_xvffint_s_w(weight_i_j_8_epi32), scale_8_ps[col_]); + } else if (i_of_4 == 2) { + __m256i weight_i_16_epi16 = __lasx_xvexth_h_b(__lasx_xvpermi_d(weight_32_epi8[col_], 0xD8)); + __m256i weight_i_j_8_epi32 = __lasx_xvexth_w_h(__lasx_xvpermi_d(weight_i_16_epi16, 0x72)); + weight_8_ps[col_] = __lasx_xvfmul_s(__lasx_xvffint_s_w(weight_i_j_8_epi32), scale_8_ps[col_]); + } else if (i_of_4 == 3) { + __m256i weight_i_16_epi16 = __lasx_xvexth_h_b(weight_32_epi8[col_]); + __m256i weight_i_j_8_epi32 = __lasx_xvexth_w_h(__lasx_xvpermi_d(weight_i_16_epi16, 0xD8)); + weight_8_ps[col_] = __lasx_xvfmul_s(__lasx_xvffint_s_w(weight_i_j_8_epi32), scale_8_ps[col_]); + } + } else { + weight_8_ps[col_] = (__m256)__lasx_xvrepli_d(0); + } + } + // transpose and store + __m256 a0 = 
(__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[1], (__m256i)weight_8_ps[0], 0x44); // a1, a2, b1, b2, a5, a6, b5, b6 + __m256 a1 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[1], (__m256i)weight_8_ps[0], 0xEE); // a3, a4, b3, b4, a7, a8, b7, b8 + __m256 a2 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[3], (__m256i)weight_8_ps[2], 0x44); // c1, c2, d1, d2, c5, c6, d5, d6 + __m256 a3 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[3], (__m256i)weight_8_ps[2], 0xEE); // c3, c4, d3, d4, c7, c8, d7, d8 + __m256 a4 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[5], (__m256i)weight_8_ps[4], 0x44); // e1, e2, f1, f2, e5, e6, f5, f6 + __m256 a5 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[5], (__m256i)weight_8_ps[4], 0xEE); // e3, e4, f3, f4, e7, e8, f7, f8 + __m256 a6 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[7], (__m256i)weight_8_ps[6], 0x44); // g1, g2, h1, h2, g5, g6, h5, h6 + __m256 a7 = (__m256)__lasx_xvpermi_w((__m256i)weight_8_ps[7], (__m256i)weight_8_ps[6], 0xEE); // g3, g4, h3, h4, g7, g8, h7, h8 + + __m256 b0 = (__m256)__lasx_xvpermi_w((__m256i)a2, (__m256i)a0, 0x88); // a1, b1, c1, d1, a5, b5, c5, d5 + __m256 b1 = (__m256)__lasx_xvpermi_w((__m256i)a2, (__m256i)a0, 0xDD); // a2, b2, c2, d2, a6, b6, c6, d6 + __m256 b2 = (__m256)__lasx_xvpermi_w((__m256i)a3, (__m256i)a1, 0x88); // a3, b3, c3, d3, a7, b7, c7, d7 + __m256 b3 = (__m256)__lasx_xvpermi_w((__m256i)a3, (__m256i)a1, 0xDD); // a4, b4, c4, d4, a8, b8, c8, d8 + __m256 b4 = (__m256)__lasx_xvpermi_w((__m256i)a6, (__m256i)a4, 0x88); // e1, f1, g1, h1, e5, f5, g5, h5 + __m256 b5 = (__m256)__lasx_xvpermi_w((__m256i)a6, (__m256i)a4, 0xDD); // e2, f2, g2, h2, e6, f6, g6, h6 + __m256 b6 = (__m256)__lasx_xvpermi_w((__m256i)a7, (__m256i)a5, 0x88); // e3, f3, g3, h3, e7, f7, g7, h7 + __m256 b7 = (__m256)__lasx_xvpermi_w((__m256i)a7, (__m256i)a5, 0xDD); // e4, f4, g4, h4, e8, f8, g8, h8 + + // next i_of_2th row + const size_t ij_offset_in_k = i_of_4 * 8 * GemmFloatKernelWidth16; + __m256 
weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b0, (__m256i)b4, 0x02); // a1, b1, c1, d1, e1, f1, g1, h1 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 0 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b1, (__m256i)b5, 0x02); // a2, b2, c2, d2, e2, f2, g2, h2 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 1 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b2, (__m256i)b6, 0x02); // a3, b3, c3, d3, e3, f3, g3, h3 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 2 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b3, (__m256i)b7, 0x02); // a4, b4, c4, d4, e4, f4, g4, h4 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 3 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b0, (__m256i)b4, 0x13); // a5, b5, c5, d5, e5, f5, g5, h5 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 4 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b1, (__m256i)b5, 0x13); // a6, b6, c6, d6, e6, f6, g6, h6 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 5 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b2, (__m256i)b6, 0x13); // a7, b7, c7, d7, e7, f7, g7, h7 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 6 * GemmFloatKernelWidth16, 0); + weight_transposed_8_ps = (__m256)__lasx_xvpermi_q((__m256i)b3, (__m256i)b7, 0x13); // a8, b8, c8, d8, e8, f8, g8, h8 + __lasx_xvst(weight_transposed_8_ps, dst_ptr + ij_offset_in_k + 7 * GemmFloatKernelWidth16, 0); + } + dst_ptr += SubblkLen32 * GemmFloatKernelWidth16; + if constexpr (IsBlkLen64Layout) { + b_data_ptr += (subblk % 2) * 2 * subblk_data_size_in_bytes; + } else { + b_data_ptr += subblk_data_size_in_bytes; + } + } // subblk + } + } +} + +MLAS_FORCEINLINE void 
+Q4BitBlkDequantBForSgemm_CompFp32_Lasx( + const size_t BlkLen, + float* FpData, + const std::byte* QuantBData, + const float* QuantBScale, + const std::byte* QuantBZeroPoint, + const size_t CountN, + const size_t CountK, + const size_t BlockStrideQuantB +) +{ + if (BlkLen == 16) { + Q4BitBlkDequantBForSgemmBlkLen16_CompFp32_lasx( + FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB + ); + } else if (BlkLen == 32) { + Q4BitBlkDequantBForSgemmBlkLen32AndMore_CompFp32_lasx( + BlkLen, FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB + ); + } else { + Q4BitBlkDequantBForSgemmBlkLen32AndMore_CompFp32_lasx( + BlkLen, FpData, QuantBData, QuantBScale, QuantBZeroPoint, CountN, CountK, BlockStrideQuantB + ); + } +} + +const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchLasx = []() { + MLAS_QNBIT_GEMM_DISPATCH d; + + d.Q4BitGemmPackQuantBDataSize = QNBitGemmPackQuantBDataSize_Lasx<4>; + d.SQ4BitGemmPackQuantBData = SQ4BitGemmPackQuantBData_Lasx; + d.SQ4BitGemmPackQuantBDataAndBlkSum = SQ4BitGemmPackQuantBDataAndBlkSum_Lasx; + d.SQ8BitGemmPackQuantBDataAndBlkSum = SQ8BitGemmPackQuantBDataAndBlkSum_Lasx; + + d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_Lasx; + d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_Lasx; + + return d; +}(); diff --git a/src/lib/sqnbitgemm_kernel_lasx_common.h b/src/lib/sqnbitgemm_kernel_lasx_common.h new file mode 100644 index 0000000..508bcba --- /dev/null +++ b/src/lib/sqnbitgemm_kernel_lasx_common.h @@ -0,0 +1,514 @@ +/*++ + Abstract: + + Lasx/Lsx tool function, Auxiliary functions for inference required by + 4-bit/8-bit quantization models. 
+--*/ +#pragma once +#include "qnbitgemm.h" +#include "core/common/safeint.h" +#include + +template +struct MlasAlignedAllocator { + using value_type = T; + + MlasAlignedAllocator() = default; + + template + MlasAlignedAllocator(const MlasAlignedAllocator&) {} + + T* allocate(size_t n) { + // If RequiredAlignment > 0, use the required value + // Otherwise, use the value of MlasGetPreferredBufferAlignment() + size_t alignment = RequiredAlignment > 0 ? + RequiredAlignment : + MlasGetPreferredBufferAlignment(); + + size_t size = n * sizeof(T); + if (size % alignment != 0) // check the size + size = ((size + alignment - 1) / alignment) * alignment; + #if defined(_MSC_VER) + void* ptr = _aligned_malloc(size, alignment); + #else + void* ptr = aligned_alloc(alignment, size); + #endif + if (!ptr) throw std::bad_alloc(); + return static_cast(ptr); + } + + void deallocate(T* ptr, size_t) { + #if defined(_MSC_VER) + _aligned_free(ptr); + #else + free(ptr); + #endif + } + + template + struct rebind { + using other = MlasAlignedAllocator; + }; +}; + +static MLAS_FORCEINLINE __m256 +__lasx_xvzero() +{ + return (__m256)__lasx_xvldi(0); +} + +static size_t +GetContinueLayoutOffsetSubBlk(size_t N, const size_t n, const size_t SubOrBlkCountK, const size_t k_sub_or_blk) +{ + size_t T = n / 4, t = n % 4; + bool te = T == N / 4; + SafeInt scale_dst_offset = SafeInt(T) * 4 * SubOrBlkCountK; + if (te) { + scale_dst_offset += SafeInt(t) * SubOrBlkCountK + k_sub_or_blk; + } else { + scale_dst_offset += SafeInt(k_sub_or_blk) * 4 + t; + } + return scale_dst_offset.Value(); +} + +static size_t +GetContinueLayoutOffsetBlkInSubBlk(size_t N, const size_t n, const size_t BlockCountK, const size_t k_blk, const int blks_per_sub) +{ + size_t T = n / 4, t = n % 4, k_subblk = k_blk / blks_per_sub, b = k_blk % blks_per_sub; + bool te = T == N / 4, be = k_subblk == BlockCountK / blks_per_sub; + SafeInt scale_dst_offset = SafeInt(T) * 4 * BlockCountK; + if (te) { + scale_dst_offset += SafeInt(t) * 
BlockCountK + k_blk; + } else { + scale_dst_offset += SafeInt(k_subblk) * blks_per_sub * 4; + if (be) { + scale_dst_offset += SafeInt(b) * 4 + t; + } else { + scale_dst_offset += SafeInt(t) * blks_per_sub + b; + } + } + return scale_dst_offset.Value(); +} + +static void +ComputePackBlkSum_Lasx( + size_t BlkLen, + size_t SubBlkLen, + size_t N, + float* QuantBScaleBegin, + const std::byte* QuantBZPBegin, + float* BlockSumBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t BlockCountK +) +{ + MlasTrySimpleParallel(ThreadPool, N * BlockCountK, [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t k_blk = tid % BlockCountK; + + const SafeInt src_blk_offset = SafeInt(n) * BlockCountK + k_blk; + float QuantBScale = QuantBScaleBegin[src_blk_offset.Value()]; + uint8_t zp = 8; + + if (QuantBZPBegin) { + size_t ZPCountK = MlasDivRoundup(BlockCountK, 2); + SafeInt src_zp_offset = SafeInt(ZPCountK) * n + k_blk / 2; + bool low_zp = k_blk % 2 == 0; + const std::byte* QuantBZP = QuantBZPBegin + src_zp_offset.Value(); + const std::byte low_mask{0X0f}; + zp = (uint8_t)(low_zp ? 
((*QuantBZP) & low_mask) : ((*QuantBZP) >> 4)); + } + + float result = -QuantBScale * zp; + + const SafeInt dst_offset = ( SafeInt(n / 16) * BlockCountK + k_blk) * 16 + n % 16; + BlockSumBegin[dst_offset.Value()] = result; + + if (BlkLen == 16) { + } else if (BlkLen >= SubBlkLen) { + const size_t scale_dst_offset = GetContinueLayoutOffsetSubBlk(N, n, BlockCountK, k_blk); + QuantBScaleBegin[scale_dst_offset] = QuantBScale; + } else { + int blks_per_sub = (int)(SubBlkLen / BlkLen); + size_t scale_dst_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk, blks_per_sub); + QuantBScaleBegin[scale_dst_offset] = QuantBScale; + } + }); +} + +static void +PackQuantB( + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t N, + const size_t BlockCountK, + const size_t BlkLen, + const size_t SubBlkLen +) +{ + constexpr size_t BlkBitWidth = 4; + const size_t BlkBytePairCount = BlkLen / 4; + const size_t BlkDataSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + + const size_t SubBlkDataSize = SubBlkLen / 2; + const size_t SubBlkBytePairCount = SubBlkLen / 4; + const size_t SubBlkCountK = MlasDivRoundup(BlockCountK * BlkLen, SubBlkLen); + const size_t Iterations = N * SubBlkCountK; // one iteration per sub block + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t n = tid / SubBlkCountK; + const size_t k_subblk = tid % SubBlkCountK; + + const SafeInt src_data_offset = SafeInt(n) * BlockCountK * BlkDataSize + k_subblk * SubBlkDataSize; + const std::byte* QuantBData = QuantBDataBegin + src_data_offset.Value(); + + size_t PackBytePairCount = SubBlkBytePairCount; + size_t PackDataSize = SubBlkDataSize; + + auto pack_subblk = []( + const std::byte* QuantBData, std::byte* PackedQuantBData, + size_t pack_byte_pair_count, size_t pack_data_size + ) { + for (size_t byte_pair_idx = 0; byte_pair_idx < pack_byte_pair_count; ++byte_pair_idx) { + const std::byte src0 = 
QuantBData[byte_pair_idx]; + const std::byte src1 = QuantBData[byte_pair_idx + pack_data_size / 2]; + + std::byte& dst0 = PackedQuantBData[2 * byte_pair_idx]; + std::byte& dst1 = PackedQuantBData[2 * byte_pair_idx + 1]; + + dst0 = (src0 & std::byte{0x0f}) | ((src1 & std::byte{0x0f}) << 4); + dst1 = (src0 >> 4) | ((src1 >> 4) << 4); + } }; + + if (SubBlkLen > BlkLen && k_subblk == SubBlkCountK - 1 && + SubBlkLen * SubBlkCountK > BlkLen * BlockCountK) { + // this is the last subblk of the column. check if it extends out of the + // BlockCountK. If it does, we shall pack per blocks so that can compute + // on each block instead of each subblk. + PackBytePairCount = BlkBytePairCount; + PackDataSize = BlkDataSize; + const size_t k_blks_remaining = BlockCountK - (SubBlkCountK - 1) * SubBlkLen / BlkLen; + for (size_t k = 0; k < k_blks_remaining; k++) { + const SafeInt k_blk = SafeInt(k_subblk) * SubBlkLen / BlkLen + k; + if (BlkLen == 16) { + // not to do the compute order layout yet + std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset; + pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData + k * BlkLen / 2, PackBytePairCount, PackDataSize); + } else if (BlkLen >= SubBlkLen) { + // shall not reach here with avx2 + assert(SubBlkLen == 128); + } else { + int blks_per_sub = (int)(SubBlkLen / BlkLen); + const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk.Value(), blks_per_sub); + std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2; + pack_subblk(QuantBData + k * BlkLen / 2, PackedQuantBData, PackBytePairCount, PackDataSize); + } + } + } else { + if (BlkLen == 16) { + // not to do the compute order layout yet + std::byte* PackedQuantBData = PackedQuantBDataBegin + src_data_offset; + pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); + } else if (BlkLen >= SubBlkLen) { + const size_t dst_data_offset = GetContinueLayoutOffsetSubBlk(N, n, SubBlkCountK, k_subblk); 
+ std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * SubBlkDataSize; + pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); + } else { + int blks_per_sub = (int)(SubBlkLen / BlkLen); + const SafeInt k_blk = SafeInt(k_subblk) * blks_per_sub; + const size_t dst_data_offset = GetContinueLayoutOffsetBlkInSubBlk(N, n, BlockCountK, k_blk.Value(), blks_per_sub); + std::byte* PackedQuantBData = PackedQuantBDataBegin + dst_data_offset * BlkLen / 2; + pack_subblk(QuantBData, PackedQuantBData, PackBytePairCount, PackDataSize); + } + } + } + ); +} + +template +MLAS_FORCEINLINE void +UnrolledLoopIterations(IterationFn&& f, std::index_sequence /* indices */) +{ + (f(Indices), ...); +} + +template +MLAS_FORCEINLINE void +UnrolledLoop(IterationFn&& f) +{ + UnrolledLoopIterations(std::forward(f), std::make_index_sequence()); +} + +static MLAS_FORCEINLINE __m128 +FoldAccumulators_Lasx(const __m256& acc0, const __m256& acc1, const __m256& acc2, const __m256& acc3) +{ + /* + acc0 = [A0, A1, A2, A3, A4, A5, A6, A7] + acc1 = [B0, B1, B2, B3, B4, B5, B6, B7] + */ + + __m256 tmpAB_lo = (__m256)__lasx_xvpermi_d(__lasx_xvpermi_w(acc1, acc0, 0x44), 0xD8); // a1,a2,a5,a6,b1,b2,b5,b6 + __m256 tmpAB_hi = (__m256)__lasx_xvpermi_d(__lasx_xvpermi_w(acc1, acc0, 0xEE), 0xD8); // a3,a4,a7,a8,b3,b4,b7,b8 + __m256 tmpCD_lo = (__m256)__lasx_xvpermi_d(__lasx_xvpermi_w(acc3, acc2, 0x44), 0xD8); // c1,c2,c5,c6,d1,d2,d5,d6 + __m256 tmpCD_hi = (__m256)__lasx_xvpermi_d(__lasx_xvpermi_w(acc3, acc2, 0xEE), 0xD8); // c3,c4,c7,c8,d3,d4,d7,d8 + + __m256 tmpABCD_lo1 = (__m256)__lasx_xvpermi_w(tmpCD_lo, tmpAB_lo, 0x44); // a1,a2,c1,c2,b1,b2,d1,d2 + __m256 tmpABCD_lo2 = (__m256)__lasx_xvpermi_w(tmpCD_hi, tmpAB_hi, 0x44); // a3,a4,c3,c4,b3,b4,d3,d4 + __m256 tmpABCD_hi1 = (__m256)__lasx_xvpermi_w(tmpCD_lo, tmpAB_lo, 0xEE); // a5,a6,c5,c6,b5,b6,d5,d6 + __m256 tmpABCD_hi2 = (__m256)__lasx_xvpermi_w(tmpCD_hi, tmpAB_hi, 0xEE); // a7,a8,c7,c8,b7,b8,d7,d8 + + __m256 sumABCD = 
__lasx_xvfadd_s(__lasx_xvfadd_s(tmpABCD_lo1, tmpABCD_lo2), __lasx_xvfadd_s(tmpABCD_hi1, tmpABCD_hi2)); + + __m256 sum0 = (__m256)__lasx_xvpermi_w(sumABCD, sumABCD, 0xB1); + sumABCD = (__m256)__lasx_xvpermi_d(__lasx_xvfadd_s(sumABCD, sum0), 0xD8); + + sumABCD = (__m256)__lasx_xvpermi_d(__lasx_xvpermi_w(sumABCD, sumABCD, 0x88), 0xD8); + + alignas(32) float tmp[8]; + __lasx_xvst(sumABCD, (void*)&tmp, 0); + __m128 result = (__m128)__lsx_vld((void*)&tmp, 0); + return result; +} + +__m256 +permutevar_ps_lasx(__m256 vec, __m256i idx_mask) +{ + __m256i veci = (__m256i)vec; + __m256i shuffled = __lasx_xvshuf_w(veci, veci, idx_mask); + return (__m256)shuffled; +} + +static void +Q8PackQuantB( + const std::byte* QuantBDataBegin, + std::byte* PackedQuantBDataBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t N, + const size_t BlockCountK, + const size_t BlkLen, + const size_t SubBlkLen +) +{ + constexpr size_t BlkBitWidth = 8; + const size_t StrideN = BlockCountK * BlkLen; + const size_t BlkSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t SubBlkSize = MlasQNBitBlkDataSizeInBytes(BlkBitWidth, SubBlkLen); + const size_t SubBlkCountK = MlasDivRoundup(StrideN, SubBlkLen); + const size_t RemainderBlockCountK = BlockCountK % (SubBlkLen > BlkLen ? SubBlkLen / BlkLen : 1); + const size_t Iterations = N * SubBlkCountK; // one iteration per sub block + + // SubBlkLen rows x 4 columns pack together, then remainder BlkLen x 4 columns if SubBlkLen > BlkLen. + // remainder columns keep the original order. 
+ // SubBlkLen >= 16 and is multiple of 16 + + MlasTrySimpleParallel( + ThreadPool, Iterations, + [&](ptrdiff_t tid) { + const size_t c = tid / SubBlkCountK; + const size_t c_4 = c & (~3), c_res = c & 3; + const size_t r_subblk = tid % SubBlkCountK; + + const SafeInt data_offset = SafeInt(c) * StrideN + r_subblk * SubBlkLen; + const std::byte* src = QuantBDataBegin + data_offset.Value(); + + if (c_4 + 4 <= N) { // full 4 cols + if (RemainderBlockCountK && r_subblk == SubBlkCountK - 1) { // remainder blocks + const SafeInt subblk_data_offset = SafeInt(c_4) * StrideN + r_subblk * SubBlkSize * 4 + c_res * BlkSize; + std::byte* dest = + PackedQuantBDataBegin + subblk_data_offset.Value(); + for (size_t i = 0; i < RemainderBlockCountK; i++) { + std::copy(src, src + BlkSize, dest); + src += BlkSize; + dest += BlkSize * 4; + } + } else { // full subblock + const SafeInt subblk_data_offset = SafeInt(c_4) * StrideN + r_subblk * SubBlkSize * 4 + c_res * SubBlkSize; + std::byte* dest = + PackedQuantBDataBegin + subblk_data_offset.Value(); + std::copy(src, src + SubBlkSize, dest); + } + } else { // remainder cols + const SafeInt remain_data_offset = SafeInt(c) * StrideN + r_subblk * SubBlkSize; + std::byte* dest = + PackedQuantBDataBegin + remain_data_offset.Value(); + std::copy(src, src + std::min(SubBlkSize, StrideN - r_subblk * SubBlkSize), dest); + } + } + ); +} + +static void +Q8ComputePackBlkSum( + size_t BlkLen, + size_t SubBlkLen, + size_t N, + float* QuantBScaleBegin, + const std::byte* QuantBZPBegin, + float* BlockSumBegin, + MLAS_THREADPOOL* ThreadPool, + const size_t BlockCountK +) +{ + SafeInt size = SafeInt(N) * BlockCountK; + std::vector> QuantBScaleBeginCopy(size.Value()); + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, QuantBScaleBeginCopy.begin()); + + MlasTrySimpleParallel(ThreadPool, N * BlockCountK, [&](ptrdiff_t tid) { + const size_t n = tid / BlockCountK; + const size_t n_4 = n & (~3), n_res = n & 3; + const size_t k_blk = tid % 
BlockCountK; + + const SafeInt src_blk_offset = SafeInt(n) * BlockCountK + k_blk; + const float& QuantBScale = QuantBScaleBeginCopy[src_blk_offset.Value()]; + uint8_t zp = 128; + if (QuantBZPBegin) { + const std::byte* QuantBZP = QuantBZPBegin + src_blk_offset.Value(); + zp = (uint8_t)(*QuantBZP); + } + + const SafeInt dst_offset = ( SafeInt(n / 16) * BlockCountK + k_blk) * 16 + n % 16; + *(BlockSumBegin + dst_offset.Value()) = -QuantBScale * zp; + + if (n_4 + 4 > N) { + SafeInt ptr_offset = SafeInt(n) * BlockCountK + k_blk; + *(QuantBScaleBegin + ptr_offset.Value()) = QuantBScale; + } else if (BlkLen >= SubBlkLen) { + SafeInt ptr_offset = SafeInt(n_4) * BlockCountK + k_blk * 4 + n_res; + *(QuantBScaleBegin + ptr_offset.Value()) = QuantBScale; + } else { + size_t blks_per_sub = SubBlkLen / BlkLen; + size_t remainder_blk = BlockCountK % blks_per_sub; + size_t sub_blk_count_k = MlasDivRoundup(BlockCountK, blks_per_sub); + size_t k_subblk = k_blk / blks_per_sub; + size_t k_blk_res = k_blk % blks_per_sub; + SafeInt dest_offset; + + if (remainder_blk && k_subblk == sub_blk_count_k - 1) { // remainder blocks + dest_offset = SafeInt(n_4) * BlockCountK + k_blk * 4 + n_res; + } else { // full subblock + dest_offset = SafeInt(n_4) * BlockCountK + k_subblk * blks_per_sub * 4 + n_res * blks_per_sub + k_blk_res; + } + + *(QuantBScaleBegin + dest_offset.Value()) = QuantBScale; + } + }); +} + +static void +Q8PackQuantBDataAndBlkSum_lasx( + size_t N, + size_t BlockCountK, + size_t BlkLen, + size_t SubBlkLen, + const std::byte* QuantBDataBegin, + const float* QuantBScaleBegin, + bool HasZeroPoint, + const std::byte* QuantBZPBegin, + PackedQuantBDataStruct& PackedQuantB, + MLAS_THREADPOOL* ThreadPool +) +{ + if (QuantBDataBegin) { + Q8PackQuantB(QuantBDataBegin, PackedQuantB.PackedQuantBData, ThreadPool, N, BlockCountK, BlkLen, SubBlkLen); + } + + if (QuantBScaleBegin) { + std::copy(QuantBScaleBegin, QuantBScaleBegin + N * BlockCountK, PackedQuantB.PackedQuantBScale); + } + + if 
((QuantBScaleBegin && !HasZeroPoint) || QuantBZPBegin) { + Q8ComputePackBlkSum(BlkLen, SubBlkLen, N, PackedQuantB.PackedQuantBScale, QuantBZPBegin, PackedQuantB.QuantBBlkSum, ThreadPool, BlockCountK); + } +} + +static MLAS_FORCEINLINE __m128i +convert_2_ps_to_epi8_lasx(__m256 v0, __m256 v1) +{ + // fp32->int32 + __m256i v0_8_epi32 = __lasx_xvftint_w_s(__lasx_xvfrint_s(v0)); + __m256i v1_8_epi32 = __lasx_xvftint_w_s(__lasx_xvfrint_s(v1)); + + alignas(32) int val_0_15_i32[16] = {0}; + alignas(32) int8_t val_0_15_i8[16] = {0}; + + __lasx_xvst(v0_8_epi32, (void*)&val_0_15_i32, 0); + __lasx_xvst(v1_8_epi32, (void*)&val_0_15_i32, 32); + + UnrolledLoop<16>([&](size_t i) { + if (val_0_15_i32[i] > 127) + val_0_15_i8[i] = 127; + else if (val_0_15_i32[i] < -128) + val_0_15_i8[i] = -128; + else + val_0_15_i8[i] = static_cast(val_0_15_i32[i]); + }); + + __m128i result = __lsx_vld((void*)&val_0_15_i8, 0); + return result; +} + +static inline __m256i +lasx_maddubs_epi16_sat(__m256i a, __m256i b) +{ + // a: bytes interpreted as unsigned + // b: bytes interpreted as signed + __m256i zero_h = __lasx_xvldi(0); // 256-bit zeros + + __m256i even_prod16 = __lasx_xvmaddwev_h_bu_b(zero_h, a, b); + __m256i odd_prod16 = __lasx_xvmaddwod_h_bu_b(zero_h, a, b); + + __m256i sum16_sat = __lasx_xvsadd_h(even_prod16, odd_prod16); + + return sum16_sat; // 16-bit signed saturated results (16 lanes) +} + +static inline __m256i +lasx_madd_epi16(__m256i a, __m256i b) +{ + __m256i zero = __lasx_xvldi(0); + __m256i even_acc = __lasx_xvmaddwev_w_h(zero, a, b); + __m256i result = __lasx_xvmaddwod_w_h(even_acc, a, b); + + return result; // 32-bit signed sums, matches _mm256_madd_epi16 semantics (no saturation) +} + +static inline __m256i +lasx_hadd_epi32(__m256i a, __m256i b) +{ + __m256i a_swapped = __lasx_xvshuf4i_w(a, 0xB1); // 0xB1 = binary 10110001 + __m256i b_swapped = __lasx_xvshuf4i_w(b, 0xB1); + + __m256i a_sum = __lasx_xvadd_w(a, a_swapped); + __m256i b_sum = __lasx_xvadd_w(b, b_swapped); + + 
__m256i a_even = __lasx_xvpermi_w(a_sum, a_sum, 0x88); + __m256i b_even = __lasx_xvpermi_w(b_sum, b_sum, 0x88); + + __m256i result = __lasx_xvpermi_q(a_even, b_even, 0x20); + + return result; +} + +static inline __m256i +lasx_cvtepu8_epi16_emul_from_m128(const __m128i v128) +{ + alignas(32) int8_t num[32] = {0}; + __lsx_vst(v128, (void*)&num, 0); + __m256i result = __lasx_xvld((void*)&num, 0); + result = __lasx_xvexth_hu_bu(__lasx_xvpermi_d(result, 0x72)); + return result; +} + +static MLAS_FORCEINLINE float +hsum_float_8_lasx(__m256 v) +{ + v = __lasx_xvfadd_s(v, (__m256)__lasx_xvpermi_d(v, 0xB1)); + v = __lasx_xvfadd_s(v, (__m256)__lasx_xvpermi_d(v, 0x4E)); + alignas(32) float num[8] = {0.0f}; + __lasx_xvst(v, (void*)num, 0); + + return num[0] + num[1]; +} diff --git a/src/lib/sqnbitgemm_kernel_neon_int8.cpp b/src/lib/sqnbitgemm_kernel_neon_int8.cpp index 8dbd339..b03b812 100644 --- a/src/lib/sqnbitgemm_kernel_neon_int8.cpp +++ b/src/lib/sqnbitgemm_kernel_neon_int8.cpp @@ -187,6 +187,230 @@ QuantizeARow_CompInt8( } } +MLAS_FORCEINLINE +float32x4_t LoadFloat32x4(const float* src, size_t count) +{ + if (count == 4) { + return vld1q_f32(src); + } else if (count == 3) { + float32x4_t v = vdupq_n_f32(0.0f); + v = vld1q_lane_f32(src, v, 0); + v = vld1q_lane_f32(src + 1, v, 1); + v = vld1q_lane_f32(src + 2, v, 2); + return v; + } else if (count == 2) { + float32x4_t v = vdupq_n_f32(0.0f); + v = vld1q_lane_f32(src, v, 0); + v = vld1q_lane_f32(src + 1, v, 1); + return v; + } else { + assert(count == 1); + float32x4_t v = vdupq_n_f32(0.0f); + v = vld1q_lane_f32(src, v, 0); + return v; + } +} + +template +using I16VecType = typename std::conditional::type; + +template +I16VecType MLAS_FORCEINLINE +PrepareZeroI16() +{ + if constexpr (IsQuantAUnsigned) { + return vdupq_n_u16(0); + } else { + return vdupq_n_s16(0); + } +} + +template +void MLASCALL +QuantizeARowComputeBlkSum_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA, + float* 
QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) +) +{ + // First use i8 to quantize A. range [-128, 127] + // If convert to u8, +128. Range [0, 255] + assert(BlkLen % 16 == 0); + assert(BlkLen <= 256); + MLAS_DECLSPEC_ALIGN(static const uint8_t MASK[16], 16) = { + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }; + const int16x8_t v128 = vdupq_n_s16(128); + QuantAType* blob = reinterpret_cast*>(QuantA); + float* scale_ptr = QuantAScale; + size_t k = 0; + for (; k + BlkLen <= CountK; k += BlkLen) { + float32x4_t absMax0 = vdupq_n_f32(0.0f); + float32x4_t absMax1 = vdupq_n_f32(0.0f); + float32x4_t absMax2 = vdupq_n_f32(0.0f); + float32x4_t absMax3 = vdupq_n_f32(0.0f); + + for (size_t kk = 0; kk < BlkLen; kk += 16) { + const float32x4x4_t v0 = vld4q_f32(A + k + kk); + absMax0 = vmaxq_f32(absMax0, vabsq_f32(v0.val[0])); + absMax1 = vmaxq_f32(absMax1, vabsq_f32(v0.val[1])); + absMax2 = vmaxq_f32(absMax2, vabsq_f32(v0.val[2])); + absMax3 = vmaxq_f32(absMax3, vabsq_f32(v0.val[3])); + } + + const float32x4_t max01 = vmaxq_f32(absMax0, absMax1); + const float32x4_t max23 = vmaxq_f32(absMax2, absMax3); + const float32x4_t max0123 = vmaxq_f32(max01, max23); + const float maxScalar = vmaxvq_f32(max0123); + + // Quantize these floats + const float scale = maxScalar / 127.f; + *scale_ptr = scale; + scale_ptr++; + + const float inverse_scale = (maxScalar != 0.0f) ? 
127.f / maxScalar : 0.0f; + const float32x4_t mul = vdupq_n_f32(inverse_scale); + + I16VecType sum_8_i16_0 = PrepareZeroI16(); + I16VecType sum_8_i16_1 = PrepareZeroI16(); + + for (size_t kk = 0; kk < BlkLen; kk += 16) { + const float32x4_t vfp32_0 = LoadFloat32x4(A + k + kk, 4); + const float32x4_t vfp32_1 = LoadFloat32x4(A + k + kk + 4, 4); + const float32x4_t vfp32_2 = LoadFloat32x4(A + k + kk + 8, 4); + const float32x4_t vfp32_3 = LoadFloat32x4(A + k + kk + 12, 4); + + const float32x4_t v0 = vmulq_f32(vfp32_0, mul); + const float32x4_t v1 = vmulq_f32(vfp32_1, mul); + const float32x4_t v2 = vmulq_f32(vfp32_2, mul); + const float32x4_t v3 = vmulq_f32(vfp32_3, mul); + + const int32x4_t i0 = vcvtnq_s32_f32(v0); + const int32x4_t i1 = vcvtnq_s32_f32(v1); + const int32x4_t i2 = vcvtnq_s32_f32(v2); + const int32x4_t i3 = vcvtnq_s32_f32(v3); + + const int16x8_t v_8_i16_0 = vcombine_s16(vqmovn_s32(i0), vqmovn_s32(i1)); + const int16x8_t v_8_i16_1 = vcombine_s16(vqmovn_s32(i2), vqmovn_s32(i3)); + + if constexpr (IsQuantAUnsigned) { + const uint16x8_t v_8_u16_0 = vreinterpretq_u16_s16(vaddq_s16(v_8_i16_0, v128)); + const uint16x8_t v_8_u16_1 = vreinterpretq_u16_s16(vaddq_s16(v_8_i16_1, v128)); + const uint8x16_t v_16_u8 = vcombine_u8(vqmovn_u16(v_8_u16_0), vqmovn_u16(v_8_u16_1)); + vst1q_u8(blob + k + kk, v_16_u8); + + // accumulate Sum(a_i) + const uint16x8_t i_8_u16_0 = vmovl_u8(vget_low_u8(v_16_u8)); + const uint16x8_t i_8_u16_1 = vmovl_high_u8(v_16_u8); + sum_8_i16_0 = vaddq_u16(sum_8_i16_0, i_8_u16_0); + sum_8_i16_1 = vaddq_u16(sum_8_i16_1, i_8_u16_1); + } else { + const int8x16_t v_16_i8 = vcombine_s8(vqmovn_s16(v_8_i16_0), vqmovn_s16(v_8_i16_1)); + vst1q_s8(blob + k + kk, v_16_i8); + + // accumulate Sum(a_i) + sum_8_i16_0 = vaddq_s16(sum_8_i16_0, v_8_i16_0); + sum_8_i16_1 = vaddq_s16(sum_8_i16_1, v_8_i16_1); + } + } + + float qsum; + + if constexpr (IsQuantAUnsigned) { + const uint16x8_t sum_8_u16 = vaddq_u16(sum_8_i16_0, sum_8_i16_1); + qsum = 
static_cast(vaddvq_u16(sum_8_u16)); + } else { + const int16x8_t sum_8_i16 = vaddq_s16(sum_8_i16_0, sum_8_i16_1); + qsum = static_cast(vaddvq_s16(sum_8_i16)); + } + + *AScaledBlkSum = scale * qsum; + AScaledBlkSum++; + } + + if (k < CountK) { + float32x4_t absMax = vdupq_n_f32(0.0f); + + for (size_t kk = k; kk < CountK; kk += 4) { + size_t step = std::min(static_cast(4), CountK - kk); + const float32x4_t v0 = LoadFloat32x4(A + kk, step); + absMax = vmaxq_f32(absMax, vabsq_f32(v0)); + } + + const float maxScalar = vmaxvq_f32(absMax); + const float scale = maxScalar / 127.f; + *scale_ptr = scale; + + const float inverse_scale = (maxScalar != 0.0f) ? 127.f / maxScalar : 0.0f; + const float32x4_t mul = vdupq_n_f32(inverse_scale); + + I16VecType sum_8_i16 = PrepareZeroI16(); + + for (size_t kk = k; kk < CountK; kk += 4) { + size_t step = std::min(static_cast(4), CountK - kk); + const float32x4_t vfp32 = LoadFloat32x4(A + kk, step); + const float32x4_t v_f32 = vmulq_f32(vfp32, mul); + const int32x4_t v_i32 = vcvtnq_s32_f32(v_f32); + const int16x8_t v_8_i16 = vcombine_s16(vqmovn_s32(v_i32), vdup_n_s16(0)); + + if constexpr (IsQuantAUnsigned) { + const uint16x8_t v_8_u16 = vreinterpretq_u16_s16(vaddq_s16(v_8_i16, v128)); + uint8x8_t v_8_u8 = vqmovn_u16(v_8_u16); + vst1_lane_s32(reinterpret_cast(blob + kk), vreinterpret_s32_u8(v_8_u8), 0); + + // accumulate Sum(a_i) + v_8_u8 = vand_u8(v_8_u8, vld1_u8(MASK + 8 - step)); + const uint16x8_t i_8_u16 = vmovl_u8(v_8_u8); + sum_8_i16 = vaddq_u16(sum_8_i16, i_8_u16); + } else { + const int8x8_t v_8_i8 = vqmovn_s16(v_8_i16); + vst1_lane_s32(reinterpret_cast(blob + kk), vreinterpret_s32_s8(v_8_i8), 0); + + // accumulate Sum(a_i) + sum_8_i16 = vaddq_s16(sum_8_i16, v_8_i16); + } + } + + float qsum; + + if constexpr (IsQuantAUnsigned) { + qsum = static_cast(vaddvq_u16(sum_8_i16)); + } else { + qsum = static_cast(vaddvq_s16(sum_8_i16)); + } + + *AScaledBlkSum = scale * qsum; + + memset(blob + CountK, 0, BlkLen - (CountK % BlkLen)); + } 
+} + +template +void MLASCALL +QuantizeARowComputeBlkSum_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) +); + +template +void MLASCALL +QuantizeARowComputeBlkSum_CompInt8( + size_t BlkLen, + const float* A, + size_t CountK, + std::byte* QuantA, + float* QuantAScale, + float* AScaledBlkSum // scale_k * Sum_blklen(a_i) +); + namespace { @@ -1439,6 +1663,723 @@ SQ4BitGemmKernel_CompInt8( return CountM; } +MLAS_FORCEINLINE void +Q8Int8GemmR2xC8DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t NCols8 = 8; + constexpr size_t MRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol8 = BlockCountK * BlkLen * NCols8; + + assert(CountM % MRows2 == 0); + assert(CountN % NCols8 == 0); + + for (size_t m = 0; m < CountM; m += MRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols8) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf0_47 = vdupq_n_f32(0.0f); + float32x4_t accf1_03 = vdupq_n_f32(0.0f); + float32x4_t accf1_47 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float32x4_t scaleB03 = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleB47 = 
vld1q_f32(QuantBScalePtr + NCols4); + + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB03, scaleA0); + const float32x4_t scaleA0B47 = vmulq_n_f32(scaleB47, scaleA0); + const float32x4_t scaleA1B03 = vmulq_n_f32(scaleB03, scaleA1); + const float32x4_t scaleA1B47 = vmulq_n_f32(scaleB47, scaleA1); + + uint32x4_t acc0_03 = vdupq_n_u32(0U); + uint32x4_t acc0_47 = vdupq_n_u32(0U); + uint32x4_t acc1_03 = vdupq_n_u32(0U); + uint32x4_t acc1_47 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + const uint8x16_t av1_16_i8 = vld1q_u8(QuantAPtr + lda); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_0_47 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_1_47 = vld1q_u8(QuantBDataPtr + 48); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 64); + uint8x16_t bv_packed_2_47 = vld1q_u8(QuantBDataPtr + 80); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 96); + uint8x16_t bv_packed_3_47 = vld1q_u8(QuantBDataPtr + 112); + + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_0_47, av0_16_i8, 0); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_1_47, av0_16_i8, 1); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_2_47, av0_16_i8, 2); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_3_47, av0_16_i8, 3); + + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_0_03, av1_16_i8, 0); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_1_03, av1_16_i8, 1); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_2_03, av1_16_i8, 2); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_3_03, av1_16_i8, 3); + + acc1_47 = vdotq_laneq_u32(acc1_47, bv_packed_0_47, av1_16_i8, 0); + 
acc1_47 = vdotq_laneq_u32(acc1_47, bv_packed_1_47, av1_16_i8, 1); + acc1_47 = vdotq_laneq_u32(acc1_47, bv_packed_2_47, av1_16_i8, 2); + acc1_47 = vdotq_laneq_u32(acc1_47, bv_packed_3_47, av1_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols8 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_u32(acc0_03)); + accf0_47 = vfmaq_f32(accf0_47, scaleA0B47, vcvtq_f32_u32(acc0_47)); + accf1_03 = vfmaq_f32(accf1_03, scaleA1B03, vcvtq_f32_u32(acc1_03)); + accf1_47 = vfmaq_f32(accf1_47, scaleA1B47, vcvtq_f32_u32(acc1_47)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols8; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32_03 = vld1q_f32(BiasPtr); + const float32x4_t bias_4_f32_47 = vld1q_f32(BiasPtr + 4); + + accf0_03 = vaddq_f32(accf0_03, bias_4_f32_03); + accf0_47 = vaddq_f32(accf0_47, bias_4_f32_47); + accf1_03 = vaddq_f32(accf1_03, bias_4_f32_03); + accf1_47 = vaddq_f32(accf1_47, bias_4_f32_47); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + 4, accf0_47); + vst1q_f32(SumPtr + ldc, accf1_03); + vst1q_f32(SumPtr + ldc + 4, accf1_47); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol8; + QuantBScaleColPtr += NCols8 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
NCols8 : 0; + SumPtr += NCols8; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC8DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t NCols8 = 8; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol8 = BlockCountK * BlkLen * NCols8; + + assert(CountN % NCols8 == 0); + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols8) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf0_47 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float32x4_t scaleB03 = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleB47 = vld1q_f32(QuantBScalePtr + NCols4); + + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB03, scaleA0); + const float32x4_t scaleA0B47 = vmulq_n_f32(scaleB47, scaleA0); + + uint32x4_t acc0_03 = vdupq_n_u32(0U); + uint32x4_t acc0_47 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_0_47 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_1_47 = vld1q_u8(QuantBDataPtr + 48); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 64); + uint8x16_t bv_packed_2_47 = 
vld1q_u8(QuantBDataPtr + 80); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 96); + uint8x16_t bv_packed_3_47 = vld1q_u8(QuantBDataPtr + 112); + + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_0_47, av0_16_i8, 0); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_1_47, av0_16_i8, 1); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_2_47, av0_16_i8, 2); + acc0_47 = vdotq_laneq_u32(acc0_47, bv_packed_3_47, av0_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols8 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_u32(acc0_03)); + accf0_47 = vfmaq_f32(accf0_47, scaleA0B47, vcvtq_f32_u32(acc0_47)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols8; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32_03 = vld1q_f32(BiasPtr); + const float32x4_t bias_4_f32_47 = vld1q_f32(BiasPtr + 4); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32_03); + accf0_47 = vaddq_f32(accf0_47, bias_4_f32_47); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + 4, accf0_47); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol8; + QuantBScaleColPtr += NCols8 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
NCols8 : 0; + SumPtr += NCols8; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t MRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol4 = BlockCountK * BlkLen * NCols4; + + assert(CountM % MRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += MRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf1_03 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float32x4_t scaleB = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB, scaleA0); + const float32x4_t scaleA1B03 = vmulq_n_f32(scaleB, scaleA1); + + uint32x4_t acc0_03 = vdupq_n_u32(0U); + uint32x4_t acc1_03 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + const uint8x16_t av1_16_i8 = vld1q_u8(QuantAPtr + lda); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 48); + + acc0_03 = 
vdotq_laneq_u32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_0_03, av1_16_i8, 0); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_1_03, av1_16_i8, 1); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_2_03, av1_16_i8, 2); + acc1_03 = vdotq_laneq_u32(acc1_03, bv_packed_3_03, av1_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols4 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_u32(acc0_03)); + accf1_03 = vfmaq_f32(accf1_03, scaleA1B03, vcvtq_f32_u32(acc1_03)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols4; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32 = vld1q_f32(BiasPtr); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32); + accf1_03 = vaddq_f32(accf1_03, bias_4_f32); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + ldc, accf1_03); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol4; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol4 = BlockCountK * BlkLen * NCols4; + + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float32x4_t scaleB = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB, scaleA0); + + uint32x4_t acc0_03 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 48); + + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vdotq_laneq_u32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols4 * KStep16; + } + + accf0_03 = 
vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_u32(acc0_03)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols4; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32 = vld1q_f32(BiasPtr); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32); + } + + vst1q_f32(SumPtr, accf0_03); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol4; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR2xC1DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t MRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol = BlockCountK * BlkLen; + + assert(CountM % MRows2 == 0); + + for (size_t m = 0; m < CountM; m += MRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; ++n) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0 = vdupq_n_f32(0.0f); + float32x4_t accf1 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float scaleB = *QuantBScalePtr; + const float scaleA0B = scaleB * scaleA0; + const float scaleA1B = scaleB * scaleA1; + + uint32x4_t acc0 = vdupq_n_u32(0U); + uint32x4_t acc1 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + const uint8x16_t av1_16_i8 = 
vld1q_u8(QuantAPtr + lda); + + uint8x16_t bv_packed = vld1q_u8(QuantBDataPtr); + + acc0 = vdotq_u32(acc0, bv_packed, av0_16_i8); + acc1 = vdotq_u32(acc1, bv_packed, av1_16_i8); + + QuantAPtr += KStep16; + QuantBDataPtr += KStep16; + } + + accf0 = vfmaq_n_f32(accf0, vcvtq_f32_u32(acc0), scaleA0B); + accf1 = vfmaq_n_f32(accf1, vcvtq_f32_u32(acc1), scaleA1B); + + ++QuantAScalePtr; + ++QuantBScalePtr; + } + + float32_t accf0v = vaddvq_f32(accf0); + float32_t accf1v = vaddvq_f32(accf1); + + if (BiasPtr != nullptr) { + const float bias = *BiasPtr; + accf0v += bias; + accf1v += bias; + } + + *SumPtr = accf0v; + *(SumPtr + ldc) = accf1v; + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr ? 1 : 0; + ++SumPtr; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1DotProd( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol = BlockCountK * BlkLen; + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; ++n) { + const uint8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleB = *QuantBScalePtr; + const float scaleA0B = scaleB * scaleA0; + + uint32x4_t acc0 = vdupq_n_u32(0U); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + 
const uint8x16_t av0_16_i8 = vld1q_u8(QuantAPtr); + + uint8x16_t bv_packed = vld1q_u8(QuantBDataPtr); + + acc0 = vdotq_u32(acc0, bv_packed, av0_16_i8); + + QuantAPtr += KStep16; + QuantBDataPtr += KStep16; + } + + accf0 = vfmaq_n_f32(accf0, vcvtq_f32_u32(acc0), scaleA0B); + + ++QuantAScalePtr; + ++QuantBScalePtr; + } + + float32_t accf0v = vaddvq_f32(accf0); + + if (BiasPtr != nullptr) { + const float bias = *BiasPtr; + accf0v += bias; + } + + *SumPtr = accf0v; + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr ? 1 : 0; + ++SumPtr; + } + } +} + +template <> +size_t +MlasQ8Int8GemmKernelNeon( + const size_t BlkLen, + const uint8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float * QuantBScale, + float* C, + const size_t CountM, + const size_t CountN, + const size_t CountK, + const float* Bias, + const size_t ldc +) { + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols8 = 8; + constexpr size_t NCols4 = 4; + constexpr size_t MRows2 = 2; + const size_t BlockCountK = MlasDivRoundup(CountK, BlkLen); + + const size_t lda = BlockCountK * BlkLen; + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % MRows2; + size_t multipleRows = CountM - remainingRows; + size_t multipleCols8 = CountN & (~(NCols8 - 1)); + size_t multipleCols4 = CountN & (~(NCols4 - 1)); + size_t remainingCols4 = CountN % NCols4; + + if (multipleRows > 0 && multipleCols8 > 0) { + Q8Int8GemmR2xC8DotProd( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols8, + BlockCountK, + Bias, + ldc + ); + } + + if (multipleRows > 0 && multipleCols4 > multipleCols8) { + Q8Int8GemmR2xC4DotProd( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols8 * StrideQuantBData, + 
QuantBScale + multipleCols8 * StrideQuantBScale, + C + multipleCols8, + multipleRows, + multipleCols4 - multipleCols8, + BlockCountK, + Bias ? Bias + multipleCols8 : nullptr, + ldc + ); + } + + if (multipleRows > 0 && remainingCols4 > 0) { + Q8Int8GemmR2xC1DotProd( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols4 * StrideQuantBData, + QuantBScale + multipleCols4 * StrideQuantBScale, + C + multipleCols4, + multipleRows, + remainingCols4, + BlockCountK, + Bias ? Bias + multipleCols4 : nullptr, + ldc + ); + } + + if (remainingRows > 0 && multipleCols8 > 0) { + Q8Int8GemmR1xC8DotProd( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols8, + BlockCountK, + Bias, + ldc); + } + + if (remainingRows > 0 && multipleCols4 > multipleCols8) { + Q8Int8GemmR1xC4DotProd( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols8 * StrideQuantBData, + QuantBScale + multipleCols8 * StrideQuantBScale, + C + multipleRows * ldc + multipleCols8, + remainingRows, + multipleCols4 - multipleCols8, + BlockCountK, + Bias ? Bias + multipleCols8 : nullptr, + ldc); + } + + if (remainingRows > 0 && remainingCols4 > 0) { + Q8Int8GemmR1xC1DotProd( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols4 * StrideQuantBData, + QuantBScale + multipleCols4 * StrideQuantBScale, + C + multipleRows * ldc + multipleCols4, + remainingRows, + remainingCols4, + BlockCountK, + Bias ? 
Bias + multipleCols4 : nullptr, + ldc); + } + + return CountM; +} + #ifdef USE_KLEIDIAI void SQ4BitGemmKernel_Packed_CompInt8( diff --git a/src/lib/sqnbitgemm_kernel_neon_int8_i8mm.cpp b/src/lib/sqnbitgemm_kernel_neon_int8_i8mm.cpp new file mode 100644 index 0000000..db040db --- /dev/null +++ b/src/lib/sqnbitgemm_kernel_neon_int8_i8mm.cpp @@ -0,0 +1,743 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Licensed under the MIT License. + +Module Name: + + sqnbitgemm_kernel_neon_int8_i8mm.cpp + +Abstract: + + This module implements the float/quantized n-bit integer matrix + multiplication kernels for ARM NEON specific to + input type T1 as float32 and + MLAS_QNBIT_GEMM_COMPUTE_TYPE SQNBIT_CompInt8 + using i8mm instructions. + +--*/ + +#include "qnbitgemm.h" +#include "qnbitgemm_kernel_neon.h" + +namespace sqnbitgemm_neon +{ + +MLAS_FORCEINLINE void +Q8Int8GemmR2xC8I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t NCols8 = 8; + constexpr size_t NRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol8 = BlockCountK * BlkLen * NCols8; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols8 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols8) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf0_47 = 
vdupq_n_f32(0.0f); + float32x4_t accf1_03 = vdupq_n_f32(0.0f); + float32x4_t accf1_47 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float32x4_t scaleB03 = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleB47 = vld1q_f32(QuantBScalePtr + NCols4); + + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB03, scaleA0); + const float32x4_t scaleA0B47 = vmulq_n_f32(scaleB47, scaleA0); + const float32x4_t scaleA1B03 = vmulq_n_f32(scaleB03, scaleA1); + const float32x4_t scaleA1B47 = vmulq_n_f32(scaleB47, scaleA1); + + int32x4_t acc0_03 = vdupq_n_s32(0); + int32x4_t acc0_47 = vdupq_n_s32(0); + int32x4_t acc1_03 = vdupq_n_s32(0); + int32x4_t acc1_47 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + const int8x16_t av1_16_i8 = vld1q_s8(QuantAPtr + lda); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_0_47 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_1_47 = vld1q_u8(QuantBDataPtr + 48); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 64); + uint8x16_t bv_packed_2_47 = vld1q_u8(QuantBDataPtr + 80); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 96); + uint8x16_t bv_packed_3_47 = vld1q_u8(QuantBDataPtr + 112); + + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_0_47, av0_16_i8, 0); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_1_47, av0_16_i8, 1); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_2_47, av0_16_i8, 2); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_3_47, av0_16_i8, 3); + + 
acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_0_03, av1_16_i8, 0); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_1_03, av1_16_i8, 1); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_2_03, av1_16_i8, 2); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_3_03, av1_16_i8, 3); + + acc1_47 = vusdotq_laneq_s32(acc1_47, bv_packed_0_47, av1_16_i8, 0); + acc1_47 = vusdotq_laneq_s32(acc1_47, bv_packed_1_47, av1_16_i8, 1); + acc1_47 = vusdotq_laneq_s32(acc1_47, bv_packed_2_47, av1_16_i8, 2); + acc1_47 = vusdotq_laneq_s32(acc1_47, bv_packed_3_47, av1_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols8 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_s32(acc0_03)); + accf0_47 = vfmaq_f32(accf0_47, scaleA0B47, vcvtq_f32_s32(acc0_47)); + accf1_03 = vfmaq_f32(accf1_03, scaleA1B03, vcvtq_f32_s32(acc1_03)); + accf1_47 = vfmaq_f32(accf1_47, scaleA1B47, vcvtq_f32_s32(acc1_47)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols8; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32_03 = vld1q_f32(BiasPtr); + const float32x4_t bias_4_f32_47 = vld1q_f32(BiasPtr + 4); + + accf0_03 = vaddq_f32(accf0_03, bias_4_f32_03); + accf0_47 = vaddq_f32(accf0_47, bias_4_f32_47); + accf1_03 = vaddq_f32(accf1_03, bias_4_f32_03); + accf1_47 = vaddq_f32(accf1_47, bias_4_f32_47); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + 4, accf0_47); + vst1q_f32(SumPtr + ldc, accf1_03); + vst1q_f32(SumPtr + ldc + 4, accf1_47); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol8; + QuantBScaleColPtr += NCols8 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
NCols8 : 0; + SumPtr += NCols8; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC8I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t NCols8 = 8; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol8 = BlockCountK * BlkLen * NCols8; + + assert(CountN % NCols8 == 0); + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols8) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf0_47 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float32x4_t scaleB03 = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleB47 = vld1q_f32(QuantBScalePtr + NCols4); + + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB03, scaleA0); + const float32x4_t scaleA0B47 = vmulq_n_f32(scaleB47, scaleA0); + + int32x4_t acc0_03 = vdupq_n_s32(0); + int32x4_t acc0_47 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_0_47 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_1_47 = vld1q_u8(QuantBDataPtr + 48); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 64); + uint8x16_t bv_packed_2_47 = vld1q_u8(QuantBDataPtr 
+ 80); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 96); + uint8x16_t bv_packed_3_47 = vld1q_u8(QuantBDataPtr + 112); + + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_0_47, av0_16_i8, 0); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_1_47, av0_16_i8, 1); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_2_47, av0_16_i8, 2); + acc0_47 = vusdotq_laneq_s32(acc0_47, bv_packed_3_47, av0_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols8 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_s32(acc0_03)); + accf0_47 = vfmaq_f32(accf0_47, scaleA0B47, vcvtq_f32_s32(acc0_47)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols8; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32_03 = vld1q_f32(BiasPtr); + const float32x4_t bias_4_f32_47 = vld1q_f32(BiasPtr + 4); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32_03); + accf0_47 = vaddq_f32(accf0_47, bias_4_f32_47); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + 4, accf0_47); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol8; + QuantBScaleColPtr += NCols8 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
NCols8 : 0; + SumPtr += NCols8; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR2xC4I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol4 = BlockCountK * BlkLen * NCols4; + + assert(CountM % NRows2 == 0); + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + float32x4_t accf1_03 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float32x4_t scaleB = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB, scaleA0); + const float32x4_t scaleA1B03 = vmulq_n_f32(scaleB, scaleA1); + + int32x4_t acc0_03 = vdupq_n_s32(0); + int32x4_t acc1_03 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + const int8x16_t av1_16_i8 = vld1q_s8(QuantAPtr + lda); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 48); + + acc0_03 = 
vusdotq_laneq_s32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_0_03, av1_16_i8, 0); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_1_03, av1_16_i8, 1); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_2_03, av1_16_i8, 2); + acc1_03 = vusdotq_laneq_s32(acc1_03, bv_packed_3_03, av1_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols4 * KStep16; + } + + accf0_03 = vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_s32(acc0_03)); + accf1_03 = vfmaq_f32(accf1_03, scaleA1B03, vcvtq_f32_s32(acc1_03)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols4; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32 = vld1q_f32(BiasPtr); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32); + accf1_03 = vaddq_f32(accf1_03, bias_4_f32); + } + + vst1q_f32(SumPtr, accf0_03); + vst1q_f32(SumPtr + ldc, accf1_03); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol4; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? 
NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC4I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NCols4 = 4; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol4 = BlockCountK * BlkLen * NCols4; + + assert(CountN % NCols4 == 0); + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; n += NCols4) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0_03 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float32x4_t scaleB = vld1q_f32(QuantBScalePtr); + const float32x4_t scaleA0B03 = vmulq_n_f32(scaleB, scaleA0); + + int32x4_t acc0_03 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + + uint8x16_t bv_packed_0_03 = vld1q_u8(QuantBDataPtr); + uint8x16_t bv_packed_1_03 = vld1q_u8(QuantBDataPtr + 16); + uint8x16_t bv_packed_2_03 = vld1q_u8(QuantBDataPtr + 32); + uint8x16_t bv_packed_3_03 = vld1q_u8(QuantBDataPtr + 48); + + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_0_03, av0_16_i8, 0); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_1_03, av0_16_i8, 1); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_2_03, av0_16_i8, 2); + acc0_03 = vusdotq_laneq_s32(acc0_03, bv_packed_3_03, av0_16_i8, 3); + + QuantAPtr += KStep16; + QuantBDataPtr += NCols4 * KStep16; + } + + accf0_03 = 
vfmaq_f32(accf0_03, scaleA0B03, vcvtq_f32_s32(acc0_03)); + + ++QuantAScalePtr; + QuantBScalePtr += NCols4; + } + + if (BiasPtr != nullptr) { + const float32x4_t bias_4_f32 = vld1q_f32(BiasPtr); + accf0_03 = vaddq_f32(accf0_03, bias_4_f32); + } + + vst1q_f32(SumPtr, accf0_03); + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol4; + QuantBScaleColPtr += NCols4 * BlockCountK; + + BiasPtr += BiasPtr != nullptr ? NCols4 : 0; + SumPtr += NCols4; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR2xC1I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t NRows2 = 2; + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol = BlockCountK * BlkLen; + + assert(CountM % NRows2 == 0); + + for (size_t m = 0; m < CountM; m += NRows2) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; ++n) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0 = vdupq_n_f32(0.0f); + float32x4_t accf1 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleA1 = *(QuantAScalePtr + BlockCountK); + const float scaleB = *QuantBScalePtr; + const float scaleA0B = scaleB * scaleA0; + const float scaleA1B = scaleB * scaleA1; + + int32x4_t acc0 = vdupq_n_s32(0); + int32x4_t acc1 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + const int8x16_t av1_16_i8 = 
vld1q_s8(QuantAPtr + lda); + + uint8x16_t bv_packed = vld1q_u8(QuantBDataPtr); + + acc0 = vusdotq_s32(acc0, bv_packed, av0_16_i8); + acc1 = vusdotq_s32(acc1, bv_packed, av1_16_i8); + + QuantAPtr += KStep16; + QuantBDataPtr += KStep16; + } + + accf0 = vfmaq_n_f32(accf0, vcvtq_f32_s32(acc0), scaleA0B); + accf1 = vfmaq_n_f32(accf1, vcvtq_f32_s32(acc1), scaleA1B); + + ++QuantAScalePtr; + ++QuantBScalePtr; + } + + float32_t accf0v = vaddvq_f32(accf0); + float32_t accf1v = vaddvq_f32(accf1); + + if (BiasPtr != nullptr) { + const float bias = *BiasPtr; + accf0v += bias; + accf1v += bias; + } + + *SumPtr = accf0v; + *(SumPtr + ldc) = accf1v; + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr ? 1 : 0; + ++SumPtr; + } + } +} + +MLAS_FORCEINLINE void +Q8Int8GemmR1xC1I8MM( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float* QuantBScale, + float* C, + size_t CountM, + size_t CountN, + size_t BlockCountK, + const float* Bias, + size_t ldc +) +{ + constexpr size_t KStep16 = 16; + + const size_t lda = BlockCountK * BlkLen; + const size_t StrideQuantBDataCol = BlockCountK * BlkLen; + + for (size_t m = 0; m < CountM; ++m) { + const uint8_t* QuantBDataColPtr = QuantBData; + const float* QuantBScaleColPtr = QuantBScale; + const float* BiasPtr = Bias; + auto* SumPtr = C + m * ldc; + + for (size_t n = 0; n < CountN; ++n) { + const int8_t* QuantAPtr = QuantA + m * lda; + const float* QuantAScalePtr = QuantAScale + m * BlockCountK; + + const uint8_t* QuantBDataPtr = QuantBDataColPtr; + const float* QuantBScalePtr = QuantBScaleColPtr; + + float32x4_t accf0 = vdupq_n_f32(0.0f); + + for (size_t i = 0; i < BlockCountK; ++i) { + const float scaleA0 = *QuantAScalePtr; + const float scaleB = *QuantBScalePtr; + const float scaleA0B = scaleB * scaleA0; + + int32x4_t acc0 = vdupq_n_s32(0); + + for (size_t k = 0; k < BlkLen; k += KStep16) { + const 
int8x16_t av0_16_i8 = vld1q_s8(QuantAPtr); + + uint8x16_t bv_packed = vld1q_u8(QuantBDataPtr); + + acc0 = vusdotq_s32(acc0, bv_packed, av0_16_i8); + + QuantAPtr += KStep16; + QuantBDataPtr += KStep16; + } + + accf0 = vfmaq_n_f32(accf0, vcvtq_f32_s32(acc0), scaleA0B); + + ++QuantAScalePtr; + ++QuantBScalePtr; + } + + float32_t accf0v = vaddvq_f32(accf0); + + if (BiasPtr != nullptr) { + const float bias = *BiasPtr; + accf0v += bias; + } + + *SumPtr = accf0v; + + // move to next NCols columns + QuantBDataColPtr += StrideQuantBDataCol; + QuantBScaleColPtr += BlockCountK; + + BiasPtr += BiasPtr ? 1 : 0; + ++SumPtr; + } + } +} + +template <> +size_t +MlasQ8Int8GemmKernelNeon( + const size_t BlkLen, + const int8_t* QuantA, + const float* QuantAScale, + const uint8_t* QuantBData, + const float * QuantBScale, + float* C, + const size_t CountM, + const size_t CountN, + const size_t CountK, + const float* Bias, + const size_t ldc +) { + constexpr size_t BlkBitWidth = 8; + constexpr size_t NCols8 = 8; + constexpr size_t NCols4 = 4; + constexpr size_t NRows2 = 2; + const size_t BlockCountK = MlasDivRoundup(CountK, BlkLen); + + const size_t lda = BlockCountK * BlkLen; + const size_t lda_scale = BlockCountK; + const size_t StrideQuantBData = BlockCountK * MlasQNBitBlkDataSizeInBytes(BlkBitWidth, BlkLen); + const size_t StrideQuantBScale = BlockCountK; + + size_t remainingRows = CountM % NRows2; + size_t multipleRows = CountM - remainingRows; + size_t multipleCols8 = CountN & (~(NCols8 - 1)); + size_t multipleCols4 = CountN & (~(NCols4 - 1)); + size_t remainingCols4 = CountN % NCols4; + + if (multipleRows > 0 && multipleCols8 > 0) { + Q8Int8GemmR2xC8I8MM( + BlkLen, + QuantA, + QuantAScale, + QuantBData, + QuantBScale, + C, + multipleRows, + multipleCols8, + BlockCountK, + Bias, + ldc + ); + } + + if (multipleRows > 0 && multipleCols4 > multipleCols8) { + Q8Int8GemmR2xC4I8MM( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols8 * StrideQuantBData, + QuantBScale + 
multipleCols8 * StrideQuantBScale, + C + multipleCols8, + multipleRows, + multipleCols4 - multipleCols8, + BlockCountK, + Bias ? Bias + multipleCols8 : nullptr, + ldc + ); + } + + if (multipleRows > 0 && remainingCols4 > 0) { + Q8Int8GemmR2xC1I8MM( + BlkLen, + QuantA, + QuantAScale, + QuantBData + multipleCols4 * StrideQuantBData, + QuantBScale + multipleCols4 * StrideQuantBScale, + C + multipleCols4, + multipleRows, + remainingCols4, + BlockCountK, + Bias ? Bias + multipleCols4 : nullptr, + ldc + ); + } + + if (remainingRows > 0 && multipleCols8 > 0) { + Q8Int8GemmR1xC8I8MM( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData, + QuantBScale, + C + multipleRows * ldc, + remainingRows, + multipleCols8, + BlockCountK, + Bias, + ldc); + } + + if (remainingRows > 0 && multipleCols4 > multipleCols8) { + Q8Int8GemmR1xC4I8MM( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols8 * StrideQuantBData, + QuantBScale + multipleCols8 * StrideQuantBScale, + C + multipleRows * ldc + multipleCols8, + remainingRows, + multipleCols4 - multipleCols8, + BlockCountK, + Bias ? Bias + multipleCols8 : nullptr, + ldc); + } + + if (remainingRows > 0 && remainingCols4 > 0) { + Q8Int8GemmR1xC1I8MM( + BlkLen, + QuantA + multipleRows * lda, + QuantAScale + multipleRows * lda_scale, + QuantBData + multipleCols4 * StrideQuantBData, + QuantBScale + multipleCols4 * StrideQuantBScale, + C + multipleRows * ldc + multipleCols4, + remainingRows, + remainingCols4, + BlockCountK, + Bias ? 
Bias + multipleCols4 : nullptr, + ldc); + } + + return CountM; +} + +} // namespace sqnbitgemm_neon diff --git a/src/lib/sve/elementwise_sve.cpp b/src/lib/sve/elementwise_sve.cpp new file mode 100644 index 0000000..fb2972a --- /dev/null +++ b/src/lib/sve/elementwise_sve.cpp @@ -0,0 +1,683 @@ +/*++ + +Copyright 2025 FUJITSU LIMITED + +Module Name: + + elementwise_sve.cpp + +Abstract: + + This module contains the implementation of SVE-based elementwise operations +--*/ + +#include "mlasi_sve.h" +#include + +// +// Bundles the constants for use by kernels written in assembly. +// + +MLAS_INTERNAL_DATA const struct { + float ErfUpperAbsRange; + float ErfSplitBoundary; + float ErfSMALL_P0; + float ErfSMALL_P1; + float ErfSMALL_P2; + float ErfSMALL_P3; + float ErfSMALL_P4; + float ErfSMALL_P5_Minus_One; + float ErfReserved0; + float ErfBIG_P0; + float ErfBIG_P1; + float ErfBIG_P2; + float ErfBIG_P3; + float ErfBIG_P4; + float ErfBIG_P5; + float ErfBIG_P6_Minus_One; + float ErfNegZero; + float ErfOne; + + float Exp_UpperRange; + float Exp_LowerRange; + float Exp_Log2Reciprocal; + float Exp_log2_hi; + float Exp_log2_lo; + float Exp_P0; + float Exp_P1; + float Exp_P2; + float Exp_P3; + float Exp_P4; + float Exp_P5; + float Exp_P6; + float Exp_C; + int32_t Exp_X7F; +} MlasSveErfConstants = { + 3.925f, + 0.921875f, + -5.99104969e-4f, + 4.99339588e-3f, + -2.67667342e-2f, + 1.12818025e-1f, + -3.76124859e-1f, + 1.28379151e-1f, + 0.0f, + 1.72948930e-5f, + -3.83208680e-4f, + 3.88393435e-3f, + -2.42545605e-2f, + 1.06777847e-1f, + 6.34846687e-1f, + 1.28717512e-1f, + -0.0f, + 1.0f, + + // Independent parameters to calculate Exp for Erff() + 88.3762626647950f, + -88.3762626647949f, + 1.44269504088896341f, + -6.93145752e-1f, + -1.42860677e-6f, + 1.38319808e-3f, + 8.37550033e-3f, + 4.16689515e-2f, + 1.66664466e-1f, + 4.99999851e-1f, + 1.00000000e+0f, + 1.00000000e+0f, + 1.25829120e+7f, + 127, +}; + +MLAS_INTERNAL_DATA const struct { + float LowerRange; + float UpperRange; + float 
alpha_9; + float alpha_7; + float alpha_5; + float alpha_3; + float alpha_1; + float beta_10; + float beta_8; + float beta_6; + float beta_4; + float beta_2; + float beta_0; + float one_half; +} MlasSveLogisticConstants = { + -18.0f, + 18.0f, + 4.37031012579801e-11f, + 1.15627324459942e-07f, + 6.08574864600143e-05f, + 8.51377133304701e-03f, + 2.48287947061529e-01f, + 6.10247389755681e-13f, + 5.76102136993427e-09f, + 6.29106785017040e-06f, + 1.70198817374094e-03f, + 1.16817656904453e-01f, + 9.93151921023180e-01f, + 0.5f, +}; + +MLAS_INTERNAL_DATA const struct { + float LowerRange; + float UpperRange; + float LowerRangeSumExp; + float UpperRangeSumExp; + float RoundingBias; + float Log2Reciprocal; + float Log2High; + float Log2Low; + float poly_0; + float poly_1; + float poly_2; + float poly_3; + float poly_4; + float poly_56; + int32_t MinimumExponent; + int32_t MaximumExponent; +} MlasSveExpConstants = { + -103.9720840454f, + 88.7762626647950f, + -88.3762626647949f, + 88.3762626647949f, + MLAS_ROUNDING_BIAS_MAGIC, + 1.44269504088896341f, + -6.93145752e-1f, + -1.42860677e-6f, + 0x1.694000p-10, + 0x1.125edcp-7, + 0x1.555b5ap-5, + 0x1.555450p-3, + 0x1.fffff6p-2, + 0x1.000000p+0, + int32_t(0xC1000000), + int32_t(0x3F800000), +}; + +MLAS_INTERNAL_DATA const float MlasSveMinimumF32Value = std::numeric_limits::lowest(); + +void +MLASCALL +MlasSveErfKernel( + const float* Input, + float* Output, + size_t N + ) +/*++ + +Routine Description: + + This routine implements the generic kernel for the error function. + +Arguments: + + Input - Supplies the input buffer. + + Output - Supplies the output buffer. + + N - Supplies the number of elements to process. + +Return Value: + + None. 
+ +--*/ +{ + MLAS_SVBOOL Pred = svptrue_b32(); + size_t sve_veclen = svcntw(); + size_t stride = sve_veclen; + + while (N > 0) { + // If fewer that SVE vector length elements are remaining, adjust the predicate + if (N < sve_veclen) { + Pred = svwhilelt_b32(0, (int32_t)N); + stride = N; + } + MLAS_SVFLOAT32 Value = MlasSveLoadFloat32(Pred, Input); + MLAS_SVFLOAT32 NegZero = MlasSveBroadcastFloat32(MlasSveErfConstants.ErfNegZero); + MLAS_SVFLOAT32 SignMask = MlasSveAndFloat32(Pred, Value, NegZero); + MLAS_SVFLOAT32 AbsValue = MlasSveAndNotFloat32(Pred, NegZero, Value); + AbsValue = MlasSveMinimumFloat32(Pred, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfUpperAbsRange), AbsValue); + MLAS_SVFLOAT32 SquareValue = MlasSveMultiplyFloat32(Pred, AbsValue, AbsValue); + + MLAS_SVFLOAT32 r_small = MlasSveBroadcastFloat32(MlasSveErfConstants.ErfSMALL_P0); + r_small = MlasSveMultiplyAddFloat32(Pred, r_small, SquareValue, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfSMALL_P1)); + r_small = MlasSveMultiplyAddFloat32(Pred, r_small, SquareValue, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfSMALL_P2)); + r_small = MlasSveMultiplyAddFloat32(Pred, r_small, SquareValue, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfSMALL_P3)); + r_small = MlasSveMultiplyAddFloat32(Pred, r_small, SquareValue, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfSMALL_P4)); + r_small = MlasSveMultiplyAddFloat32(Pred, r_small, SquareValue, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfSMALL_P5_Minus_One)); + r_small = MlasSveMultiplyAddFloat32(Pred, r_small, AbsValue, AbsValue); + MLAS_SVFLOAT32 split_mask = MlasSveGreaterThanFloat32(Pred, AbsValue, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfSplitBoundary)); + r_small = MlasSveAndNotFloat32(Pred, split_mask, r_small); + + AbsValue = MlasSveAndFloat32(Pred, split_mask, AbsValue); + MLAS_SVFLOAT32 r_big = MlasSveBroadcastFloat32(MlasSveErfConstants.ErfBIG_P0); + r_big = MlasSveMultiplyAddFloat32(Pred, r_big, AbsValue, 
MlasSveBroadcastFloat32(MlasSveErfConstants.ErfBIG_P1)); + r_big = MlasSveMultiplyAddFloat32(Pred, r_big, AbsValue, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfBIG_P2)); + r_big = MlasSveMultiplyAddFloat32(Pred, r_big, AbsValue, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfBIG_P3)); + r_big = MlasSveMultiplyAddFloat32(Pred, r_big, AbsValue, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfBIG_P4)); + r_big = MlasSveMultiplyAddFloat32(Pred, r_big, AbsValue, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfBIG_P5)); + r_big = MlasSveMultiplyAddFloat32(Pred, r_big, AbsValue, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfBIG_P6_Minus_One)); + r_big = MlasSveMultiplyAddFloat32(Pred, r_big, AbsValue, AbsValue); + + r_big = MlasSveXorFloat32(Pred, r_big, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfNegZero)); + r_big = MlasSveMaximumFloat32(Pred, MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_LowerRange), r_big); + MLAS_SVFLOAT32 exp_c = MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_C); + MLAS_SVFLOAT32 r = MlasSveMultiplyAddFloat32(Pred, MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_Log2Reciprocal), r_big, exp_c); + r = MlasSveSubtractFloat32(Pred, r, exp_c); + + MLAS_SVFLOAT32 fx = MlasSveMultiplyAddFloat32(Pred, r, MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_log2_hi), r_big); + fx = MlasSveMultiplyAddFloat32(Pred, r, MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_log2_lo), fx); + + MLAS_SVFLOAT32 y = MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_P0); + y = MlasSveMultiplyAddFloat32(Pred, y, fx, MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_P1)); + y = MlasSveMultiplyAddFloat32(Pred, y, fx, MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_P2)); + y = MlasSveMultiplyAddFloat32(Pred, y, fx, MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_P3)); + y = MlasSveMultiplyAddFloat32(Pred, y, fx, MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_P4)); + y = MlasSveMultiplyAddFloat32(Pred, y, fx, MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_P5)); + y = 
MlasSveMultiplyAddFloat32(Pred, y, fx, MlasSveBroadcastFloat32(MlasSveErfConstants.Exp_P6)); + + y = MlasSveMultiplyFloat32(Pred, y, MlasSvePowerOf2Float32(Pred, r)); + y = MlasSveSubtractFloat32(Pred, MlasSveBroadcastFloat32(MlasSveErfConstants.ErfOne), y); + + y = MlasSveOrFloat32(Pred, r_small, y); + y = MlasSveOrFloat32(Pred, y, SignMask); + MlasSveStoreFloat32(Pred, Output, y); + + Input += stride; + Output += stride; + N -= stride; + } +} + +void +MLASCALL +MlasSveLogisticKernel( + const float* Input, + float* Output, + size_t N + ) +/*++ + +Routine Description: + + This routine implements the generic kernel for the logistic function. + +Arguments: + + Input - Supplies the input buffer. + + Output - Supplies the output buffer. + + N - Supplies the number of elements to process. + +Return Value: + + None. + +--*/ +{ + MLAS_SVBOOL Pred = svptrue_b32(); + size_t sve_veclen = svcntw(); + size_t stride = sve_veclen; + + while (N > 0) { + // If fewer that SVE vector length elements are remaining, adjust the predicate + if (N < sve_veclen) { + Pred = svwhilelt_b32(0, (int32_t)N); + stride = N; + } + MLAS_SVFLOAT32 Value = MlasSveLoadFloat32(Pred, Input); + + Value = MlasSveMaximumFloat32(Pred, MlasSveBroadcastFloat32(MlasSveLogisticConstants.LowerRange), Value); + Value = MlasSveMinimumFloat32(Pred, MlasSveBroadcastFloat32(MlasSveLogisticConstants.UpperRange), Value); + + MLAS_SVFLOAT32 ValueSquared = MlasSveMultiplyFloat32(Pred, Value, Value); + + MLAS_SVFLOAT32 p; + p = MlasSveMultiplyAddFloat32( + Pred, + ValueSquared, + MlasSveBroadcastFloat32(MlasSveLogisticConstants.alpha_9), + MlasSveBroadcastFloat32(MlasSveLogisticConstants.alpha_7) + ); + p = MlasSveMultiplyAddFloat32(Pred, p, ValueSquared, MlasSveBroadcastFloat32(MlasSveLogisticConstants.alpha_5)); + p = MlasSveMultiplyAddFloat32(Pred, p, ValueSquared, MlasSveBroadcastFloat32(MlasSveLogisticConstants.alpha_3)); + p = MlasSveMultiplyAddFloat32(Pred, p, ValueSquared, 
MlasSveBroadcastFloat32(MlasSveLogisticConstants.alpha_1)); + p = MlasSveMultiplyFloat32(Pred, p, Value); + + MLAS_SVFLOAT32 q; + q = MlasSveMultiplyAddFloat32( + Pred, + ValueSquared, + MlasSveBroadcastFloat32(MlasSveLogisticConstants.beta_10), + MlasSveBroadcastFloat32(MlasSveLogisticConstants.beta_8) + ); + q = MlasSveMultiplyAddFloat32(Pred, q, ValueSquared, MlasSveBroadcastFloat32(MlasSveLogisticConstants.beta_6)); + q = MlasSveMultiplyAddFloat32(Pred, q, ValueSquared, MlasSveBroadcastFloat32(MlasSveLogisticConstants.beta_4)); + q = MlasSveMultiplyAddFloat32(Pred, q, ValueSquared, MlasSveBroadcastFloat32(MlasSveLogisticConstants.beta_2)); + q = MlasSveMultiplyAddFloat32(Pred, q, ValueSquared, MlasSveBroadcastFloat32(MlasSveLogisticConstants.beta_0)); + + MlasSveStoreFloat32( + Pred, + Output, + MlasSveClampFloat32(Pred,MlasSveAddFloat32( + Pred, + MlasSveDivideFloat32(Pred, p, q), + MlasSveBroadcastFloat32(0.5f) + ), 0.0f, 1.0f) + ); + + Input += stride; + Output += stride; + N -= stride; + } +} + +/* +SVE implementation of expf() using a polynomial approximation is taken from ARM Compute Library Repository. 
+https://github.com/ARM-software/ComputeLibrary/blob/9f7a1fb06bc0435d989a9a6a3c0fd2cebfedbf5f/src/core/NEON/SVEMath.inl#L105 +*/ +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveComputeExpVector( + MLAS_SVBOOL Pred, + MLAS_SVFLOAT32 Vector +) +{ + const uint32_t svexp_f32_coeff[] = { + 0x3f7ffff6, // x^1: 0x1.ffffecp-1f + 0x3efffedb, // x^2: 0x1.fffdb6p-2f + 0x3e2aaf33, // x^3: 0x1.555e66p-3f + 0x3d2b9f17, // x^4: 0x1.573e2ep-5f + 0x3c072010, // x^5: 0x1.0e4020p-7f + }; + + const auto c1 = MlasSveReinterpretAsFLOAT32(MlasSveBroadcastUINT32(svexp_f32_coeff[0])); + const auto c2 = MlasSveReinterpretAsFLOAT32(MlasSveBroadcastUINT32(svexp_f32_coeff[1])); + const auto c3 = MlasSveReinterpretAsFLOAT32(MlasSveBroadcastUINT32(svexp_f32_coeff[2])); + const auto c4 = MlasSveReinterpretAsFLOAT32(MlasSveBroadcastUINT32(svexp_f32_coeff[3])); + const auto c5 = MlasSveReinterpretAsFLOAT32(MlasSveBroadcastUINT32(svexp_f32_coeff[4])); + + const auto shift = MlasSveReinterpretAsFLOAT32(MlasSveBroadcastUINT32(0x4b00007f)); // 2^23 + 127 = 0x1.0000fep23f + const auto inv_ln2 = MlasSveReinterpretAsFLOAT32(MlasSveBroadcastUINT32(0x3fb8aa3b)); // 1 / ln(2) = 0x1.715476p+0f + const auto neg_ln2_hi = + MlasSveReinterpretAsFLOAT32(MlasSveBroadcastUINT32(0xbf317200)); // -ln(2) from bits -1 to -19: -0x1.62e400p-1f + const auto neg_ln2_lo = + MlasSveReinterpretAsFLOAT32(MlasSveBroadcastUINT32(0xb5bfbe8e)); // -ln(2) from bits -20 to -42: -0x1.7f7d1cp-20f + + const auto inf = MlasSveBroadcastFloat32(std::numeric_limits::infinity()); + const auto max_input = MlasSveBroadcastFloat32(88.37f); // Approximately ln(2^127.5) + const auto zero = MlasSveZeroFloat32(); + const auto min_input = MlasSveBroadcastFloat32(-86.64f); // Approximately ln(2^-125) + + // Range reduction: + // e^x = 2^n * e^r + // where: + // n = floor(x / ln(2)) + // r = x - n * ln(2) + // + // By adding x / ln(2) with 2^23 + 127 (shift): + // * As FP32 fraction part only has 23-bits, the addition of 2^23 + 127 forces decimal part + // 
of x / ln(2) out of the result. The integer part of x / ln(2) (i.e. n) + 127 will occupy + // the whole fraction part of z in FP32 format. + // Subtracting 2^23 + 127 (shift) from z will result in the integer part of x / ln(2) + // (i.e. n) because the decimal part has been pushed out and lost. + // * The addition of 127 makes the FP32 fraction part of z ready to be used as the exponent + // in FP32 format. Left shifting z by 23 bits will result in 2^n. + const auto z = MlasSveMultiplyAddFloat32(Pred, Vector, inv_ln2, shift); + const auto n = MlasSveSubtractFloat32(Pred, z, shift); + const auto scale = MlasSveReinterpretAsFLOAT32(MlasSveShiftLeftUInt32<23>(Pred, MlasSveReinterpretAsUInt32(z))); // 2^n + + // The calculation of n * ln(2) is done using 2 steps to achieve accuracy beyond FP32. + // This outperforms longer Taylor series (3-4 tabs) both in term of accuracy and performance. + const auto r_hi = MlasSveMultiplyAddFloat32(Pred, n, neg_ln2_hi, Vector); + const auto r = MlasSveMultiplyAddFloat32(Pred, n, neg_ln2_lo, r_hi); + + // Compute the truncated Taylor series of e^r. + // poly = scale * (1 + c1 * r + c2 * r^2 + c3 * r^3 + c4 * r^4 + c5 * r^5) + const auto r2 = MlasSveMultiplyFloat32(Pred, r, r); + + const auto p1 = MlasSveMultiplyFloat32(Pred, c1, r); + const auto p23 = MlasSveMultiplyAddFloat32(Pred, c3, r, c2); + const auto p45 = MlasSveMultiplyAddFloat32(Pred, c5, r, c4); + const auto p2345 = MlasSveMultiplyAddFloat32(Pred, p45, r2, p23); + const auto p12345 = MlasSveMultiplyAddFloat32(Pred, p2345, r2, p1); + + auto poly = MlasSveMultiplyAddFloat32(Pred, p12345, scale, scale); + + // Handle underflow and overflow. 
+ poly = MlasSveSelect(MlasSveCompareLessThan(Pred, Vector, min_input), zero, poly); + poly = MlasSveSelect(MlasSveCompareGreaterThan(Pred, Vector, max_input), inf, poly); + + return poly; +} + +void +MLASCALL +MlasSveComputeExpF32Kernel( + const float* Input, + float* Output, + size_t N +) +{ + const size_t veclen = svcntw(); + + // Fast path: Use scalar loop when N is 1 + if (N == 1) { + Output[0] = expf(Input[0]); + return; + } + + // Vectorized path + MLAS_SVBOOL Pred = svptrue_b32(); + size_t stride = veclen; + + while (N > 0) { + if (N < veclen) { + Pred = svwhilelt_b32(0, (int32_t)N); + stride = N; + } + + MLAS_SVFLOAT32 Vector = MlasSveLoadFloat32(Pred, Input); + Vector = MlasSveComputeExpVector(Pred, Vector); + MlasSveStoreFloat32(Pred, Output, Vector); + + Input += stride; + Output += stride; + N -= stride; + } +} + +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveComputeSumExpVector( + MLAS_SVBOOL Pred, + MLAS_SVFLOAT32 Vector, + MLAS_SVFLOAT32 NegativeMaximumVector +) +{ + Vector = MlasSveAddFloat32(Pred, Vector, NegativeMaximumVector); + Vector = MlasSveMaximumFloat32(Pred, MlasSveBroadcastFloat32(MlasSveExpConstants.LowerRangeSumExp), Vector); + + const auto RoundingBias = MlasSveBroadcastFloat32(MlasSveExpConstants.RoundingBias); + auto biased = MlasSveMultiplyAddFloat32(Pred, Vector, MlasSveExpConstants.Log2Reciprocal, RoundingBias); + auto m = MlasSveSubtractFloat32(Pred, biased, RoundingBias); + + Vector = MlasSveMultiplyAddFloat32(Pred, m, MlasSveExpConstants.Log2High, Vector); + Vector = MlasSveMultiplyAddFloat32(Pred, m, MlasSveExpConstants.Log2Low, Vector); + + auto normal = MlasSveShiftLeftInt32<23>(Pred, MlasSveReinterpretAsInt32(biased)); + normal = MlasSveAddInt32(Pred, normal, MlasSveBroadcastInt32(MlasSveExpConstants.MaximumExponent)); + + auto p = MlasSveBroadcastFloat32(MlasSveExpConstants.poly_0); + p = MlasSveMultiplyAddFloat32(Pred, p, Vector, MlasSveExpConstants.poly_1); + p = MlasSveMultiplyAddFloat32(Pred, p, Vector, 
MlasSveExpConstants.poly_2);
+    p = MlasSveMultiplyAddFloat32(Pred, p, Vector, MlasSveExpConstants.poly_3);
+    p = MlasSveMultiplyAddFloat32(Pred, p, Vector, MlasSveExpConstants.poly_4);
+    // poly_56 is intentionally applied twice: the 5th and 6th Horner-step
+    // coefficients share the same value in the MLAS reference polynomial.
+    p = MlasSveMultiplyAddFloat32(Pred, p, Vector, MlasSveExpConstants.poly_56);
+    p = MlasSveMultiplyAddFloat32(Pred, p, Vector, MlasSveExpConstants.poly_56);
+
+    p = MlasSveMultiplyFloat32(Pred, p, MlasSveReinterpretAsFloat32(normal));
+    return p;
+}
+
+float
+MLASCALL
+MlasSveComputeSumExpF32Kernel(
+    const float* Input,
+    float* Output,
+    size_t N,
+    const float* NegativeMaximum
+)
+/**
+ * Potential optimization: Consider applying loop unrolling to improve instruction-level
+ * parallelism (ILP) in this kernel. Evaluate the performance impact using benchmarks
+ * before and after implementing the optimization.
+ */
+{
+    if (N == 1) {
+        float result = expf(Input[0] + *NegativeMaximum);
+        if (Output != nullptr) {
+            Output[0] = result;
+        }
+        return result;
+    }
+
+    MLAS_SVBOOL Pred = svptrue_b32();
+    size_t veclen = svcntw();
+    size_t stride = veclen;
+    float sum = 0.0f;
+
+    MLAS_SVFLOAT32 NegativeMaximumVector = MlasSveBroadcastFloat32(*NegativeMaximum);
+
+    while (N > 0) {
+        if (N < veclen) {
+            Pred = svwhilelt_b32(0, (int32_t)N);
+            stride = N;
+        }
+
+        MLAS_SVFLOAT32 Vector = MlasSveLoadFloat32(Pred, Input);
+        Vector = MlasSveComputeSumExpVector(Pred, Vector, NegativeMaximumVector);
+
+        if (Output != nullptr) {
+            MlasSveStoreFloat32(Pred, Output, Vector);
+            Output += stride;
+        }
+
+        sum += MlasSveReduceAddFloat32(Pred, Vector);
+
+        Input += stride;
+        N -= stride;
+    }
+    return sum;
+}
+
+float MLASCALL
+MlasSveReduceMaximumF32Kernel(
+    const float* Input,
+    size_t N
+)
+{
+    size_t veclen = svcntw();
+    MLAS_SVBOOL Pred = svptrue_b32();
+
+    float Maximum;
+    MLAS_SVFLOAT32 MaximumVector0 = MlasSveBroadcastFloat32(MlasSveMinimumF32Value);
+
+    if (N >= veclen * 4) {
+        MLAS_SVFLOAT32 MaximumVector1 = MaximumVector0;
+        MLAS_SVFLOAT32 MaximumVector2 = MaximumVector0;
+        
MLAS_SVFLOAT32 MaximumVector3 = MaximumVector0;
+
+        while (N >= veclen * 4) {
+            MaximumVector0 = MlasSveMaximumFloat32(Pred, MaximumVector0, MlasSveLoadFloat32(Pred, Input));
+            MaximumVector1 = MlasSveMaximumFloat32(Pred, MaximumVector1, MlasSveLoadFloat32(Pred, Input + veclen));
+            MaximumVector2 = MlasSveMaximumFloat32(Pred, MaximumVector2, MlasSveLoadFloat32(Pred, Input + 2 * veclen));
+            MaximumVector3 = MlasSveMaximumFloat32(Pred, MaximumVector3, MlasSveLoadFloat32(Pred, Input + 3 * veclen));
+
+            Input += veclen * 4;
+            N -= veclen * 4;
+        }
+
+        MaximumVector0 = MlasSveMaximumFloat32(Pred, MaximumVector0, MaximumVector1);
+        MaximumVector2 = MlasSveMaximumFloat32(Pred, MaximumVector2, MaximumVector3);
+        MaximumVector0 = MlasSveMaximumFloat32(Pred, MaximumVector0, MaximumVector2);
+    }
+    size_t stride = veclen;
+
+    while (N > 0) {
+        if (N < veclen) {
+            Pred = svwhilelt_b32(0, (int32_t)N);
+            stride = N;
+        }
+        MLAS_SVFLOAT32 Vector = MlasSveLoadFloat32(Pred, Input);
+        MaximumVector0 = MlasSveMaximumFloat32(Pred, MaximumVector0, Vector);
+
+        Input += stride;
+        N -= stride;
+    }
+
+    Maximum = MlasSveReduceMaximumFloat32(svptrue_b32(), MaximumVector0);
+    return Maximum;
+}
+
+void
+MLASCALL
+MlasSveReduceMinimumMaximumF32Kernel(
+    const float* Input,
+    float* Min,
+    float* Max,
+    size_t N
+)
+{
+    MLAS_SVBOOL Pred = svptrue_b32();
+    size_t veclen = svcntw();
+    size_t stride = veclen;
+
+    float tmp_min = std::numeric_limits<float>::max();
+    float tmp_max = std::numeric_limits<float>::lowest();
+
+    MLAS_SVFLOAT32 MaximumVector = MlasSveBroadcastFloat32(tmp_max);
+    MLAS_SVFLOAT32 MinimumVector = MlasSveBroadcastFloat32(tmp_min);
+
+    while (N > 0) {
+        if (N < veclen) {
+            Pred = svwhilelt_b32(0, (int32_t)N);
+            stride = N;
+        }
+        MLAS_SVFLOAT32 Vector = MlasSveLoadFloat32(Pred, Input);
+        MaximumVector = MlasSveMaximumFloat32(Pred, MaximumVector, Vector);
+        MinimumVector = MlasSveMinimumFloat32(Pred, MinimumVector, Vector);
+
+        Input += stride;
+        N -= stride;
+    }
+    *Min = 
MlasSveReduceMinimumFloat32(svptrue_b32(), MinimumVector); + *Max = MlasSveReduceMaximumFloat32(svptrue_b32(), MaximumVector); +} + +void +MLASCALL +MlasSveComputeSoftmaxOutputF32Kernel( + float* Output, + size_t N, + const float* Parameters +) +{ + MLAS_SVBOOL Pred = svptrue_b32(); + size_t veclen = svcntw(); + size_t stride = veclen; + + const float Scale = Parameters[0]; + const MLAS_SVFLOAT32 ScaleVector = MlasSveBroadcastFloat32(Scale); + while (N > 0) { + if (N < veclen) { + Pred = svwhilelt_b32(0, (int32_t)N); + stride = N; + } + MLAS_SVFLOAT32 Vector = MlasSveMultiplyFloat32(Pred, ScaleVector, MlasSveLoadFloat32(Pred, Output)); + MlasSveStoreFloat32(Pred, Output, Vector); + + Output += stride; + N -= stride; + } +} + +void +MLASCALL +MlasSveComputeLogSoftmaxOutputF32Kernel( + const float* Input, + float* Output, + size_t N, + const float* Parameters +) +{ + MLAS_SVBOOL Pred = svptrue_b32(); + size_t veclen = svcntw(); + size_t stride = veclen; + + const float NegativeMaximum = Parameters[0]; + const float Logarithm = Parameters[1]; + MLAS_SVFLOAT32 NegativeMaximumVector = MlasSveBroadcastFloat32(NegativeMaximum); + MLAS_SVFLOAT32 LogarithmVector = MlasSveBroadcastFloat32(Logarithm); + + while (N > 0) { + if (N < veclen) { + Pred = svwhilelt_b32(0, (int32_t)N); + stride = N; + } + MLAS_SVFLOAT32 Vector = MlasSveLoadFloat32(Pred, Input); + Vector = MlasSveAddFloat32(Pred, Vector, NegativeMaximumVector); + Vector = MlasSveSubtractFloat32(Pred, Vector, LogarithmVector); + MlasSveStoreFloat32(Pred, Output, Vector); + + Input += stride; + Output += stride; + N -= stride; + } + +} diff --git a/src/lib/sve/mlasi_sve.h b/src/lib/sve/mlasi_sve.h new file mode 100644 index 0000000..67a4bf4 --- /dev/null +++ b/src/lib/sve/mlasi_sve.h @@ -0,0 +1,653 @@ +/*++ + +Copyright 2025 FUJITSU LIMITED + +Module Name: + + mlasi_sve.h + +Abstract: + + This module contains the procedure prototypes for the SVE intrinsics. 
+
+--*/
+
+#pragma once
+
+#include "../mlasi.h"
+#include <arm_sve.h> // SVE intrinsic header
+
+#ifndef __clang__
+#pragma GCC push_options
+#pragma GCC target("arch=armv8.2-a+sve")
+
+// Use Clang-specific per-function attribute
+#ifdef __clang__
+#define MLAS_SVE_TARGET __attribute__((target("arch=armv8.2-a+sve")))
+#else
+#define MLAS_SVE_TARGET
+#endif
+
+typedef svfloat32_t MLAS_SVFLOAT32;
+typedef svint32_t MLAS_SVINT32;
+typedef svuint32_t MLAS_SVUINT32;
+typedef svbool_t MLAS_SVBOOL;
+
+// function declarations
+MLAS_FORCEINLINE
+MLAS_SVFLOAT32
+MlasSveComputeExpVector(
+    MLAS_SVBOOL Pred,
+    MLAS_SVFLOAT32 Vector
+);
+
+void
+MLASCALL
+MlasSveComputeExpF32Kernel(
+    const float* Input,
+    float* Output,
+    size_t N
+);
+
+MLAS_FORCEINLINE
+MLAS_SVFLOAT32
+MlasSveComputeSumExpVector(
+    MLAS_SVBOOL Pred,
+    MLAS_SVFLOAT32 Vector,
+    MLAS_SVFLOAT32 NegativeMaximumVector
+);
+
+float
+MLASCALL
+MlasSveComputeSumExpF32Kernel(
+    const float* Input,
+    float* Output,
+    size_t N,
+    const float* NegativeMaximum
+);
+
+float MLASCALL
+MlasSveReduceMaximumF32Kernel(
+    const float* Input,
+    size_t N
+);
+
+void
+MLASCALL
+MlasSveReduceMinimumMaximumF32Kernel(
+    const float* Input,
+    float* Min,
+    float* Max,
+    size_t N
+);
+
+void
+MLASCALL
+MlasSveComputeSoftmaxOutputF32Kernel(
+    float* Output,
+    size_t N,
+    const float* Parameters
+);
+
+void
+MLASCALL
+MlasSveComputeLogSoftmaxOutputF32Kernel(
+    const float* Input,
+    float* Output,
+    size_t N,
+    const float* Parameters
+);
+
+void
+MLASCALL
+MlasSveErfKernel(
+    const float* Input,
+    float* Output,
+    size_t N
+);
+
+void
+MLASCALL
+MlasSveLogisticKernel(
+    const float* Input,
+    float* Output,
+    size_t N
+);
+
+//MLAS API for SVE intrinsics
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVINT32
+MlasSveReinterpretAsInt32(MLAS_SVFLOAT32 Vector)
+{
+    return svreinterpret_s32_f32(Vector);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVUINT32
+MlasSveReinterpretAsUInt32(MLAS_SVFLOAT32 Vector)
+{
+    return svreinterpret_u32_f32(Vector);
+}
+
+// Reinterprets an unsigned 32-bit vector as a 32-bit floating-point vector. +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveReinterpretAsFLOAT32(MLAS_SVUINT32 Vector) +{ + return svreinterpret_f32_u32(Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVINT32 +MlasSveCastToInt32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector) +{ + return svcvt_s32_f32_z(Pred, Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveCastToFloat32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector) +{ + return svcvt_f32_s32_z(Pred, Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVINT32 +MlasSveBroadcastInt32(int32_t Value) +{ + return svdup_n_s32(Value); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVINT32 +MlasSveLoadInt32(MLAS_SVBOOL Pred, const int32_t* Buffer) +{ + return svld1_s32(Pred, Buffer); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +void +MlasSveStoreInt32(MLAS_SVBOOL Pred, int32_t* Buffer, MLAS_SVINT32 Vector) +{ + svst1_s32(Pred, Buffer, Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVINT32 +MlasSveAddInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector1, MLAS_SVINT32 Vector2) +{ + return svadd_s32_m(Pred, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVINT32 +MlasSveSubtractInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector1, MLAS_SVINT32 Vector2) +{ + return svsub_s32_m(Pred, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVINT32 +MlasSveAndInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector1, MLAS_SVINT32 Vector2) +{ + return svand_s32_m(Pred, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVUINT32 +MlasSveAndUInt32(MLAS_SVBOOL Pred, MLAS_SVUINT32 Vector1, MLAS_SVUINT32 Vector2) +{ + return svand_u32_m(Pred, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVINT32 +MlasSveOrInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector1, MLAS_SVINT32 Vector2) +{ + return svorr_s32_m(Pred, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVINT32 +MlasSveAndNotInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 
VectorNot, MLAS_SVINT32 Vector)
+{
+    return svand_s32_m(Pred, svnot_s32_z(Pred, VectorNot), Vector);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVINT32
+MlasSveXorInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector1, MLAS_SVINT32 Vector2)
+{
+    return sveor_s32_m(Pred, Vector1, Vector2);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVINT32
+MlasSveBlendInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector1, MLAS_SVINT32 Vector2, MLAS_SVINT32 Selection)
+{
+    return MlasSveOrInt32(
+        Pred,
+        MlasSveAndInt32(Pred, Vector2, Selection),
+        MlasSveAndNotInt32(Pred, Selection, Vector1)
+    );
+}
+
+template <unsigned ShiftCount>
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVUINT32
+MlasSveShiftLeftUInt32(MLAS_SVBOOL Pred, MLAS_SVUINT32 Vector)
+{
+    return svlsl_n_u32_z(Pred, Vector, ShiftCount);
+}
+
+template <unsigned ShiftCount>
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVINT32
+MlasSveShiftLeftInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector)
+{
+    return svlsl_n_s32_z(Pred, Vector, ShiftCount);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVUINT32
+MlasSveShiftRightInt32(MLAS_SVBOOL Pred, MLAS_SVUINT32 Vector, uint32_t ShiftCount)
+{
+    return svlsr_n_u32_m(Pred, Vector, ShiftCount);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVINT32
+MlasSveMaximumInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector1, MLAS_SVINT32 Vector2)
+{
+    return svmax_s32_m(Pred, Vector1, Vector2);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVINT32
+MlasSveMinimumInt32(MLAS_SVBOOL Pred, MLAS_SVINT32 Vector1, MLAS_SVINT32 Vector2)
+{
+    return svmin_s32_m(Pred, Vector1, Vector2);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT32
+MlasSveReinterpretAsFloat32(MLAS_SVINT32 Vector)
+{
+    return svreinterpret_f32_s32(Vector);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT32
+MlasSveBroadcastFloat32(float Value)
+{
+    return svdup_n_f32(Value);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVUINT32
+MlasSveBroadcastUINT32(uint32_t Value)
+{
+    return svdup_n_u32(Value);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT32
+MlasSveBroadcastFloat32(const float* Value)
+{
+    return svld1_f32(svptrue_b32(), Value);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT32
+MlasSveZeroFloat32(void)
+{
+    return svdup_n_f32(0.0f);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT32
+MlasSveLoadFloat32(MLAS_SVBOOL Pred, const float* Buffer)
+{
+    return svld1_f32(Pred, Buffer);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+void
+MlasSveStoreFloat32(MLAS_SVBOOL Pred, float* Buffer, MLAS_SVFLOAT32 Vector)
+{
+    svst1_f32(Pred, Buffer, Vector);
+}
+
+template <int32_t Lane>
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+void
+MlasSveStoreLaneFloat32(float* Buffer, MLAS_SVFLOAT32 Vector)
+{
+    svbool_t Pred = svwhilelt_b32(Lane, Lane + 1);
+    svst1_f32(Pred, Buffer, Vector);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+void
+MlasSveStoreLowHalfFloat32(float* Buffer, MLAS_SVFLOAT32 Vector)
+{
+    svbool_t Pred = svwhilelt_b32(0, (int32_t)svcntw() / 2);
+    svst1_f32(Pred, Buffer, Vector);
+}
+
+template <int32_t Lane>
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+float
+MlasSveExtractLaneFloat32(MLAS_SVFLOAT32 Vector)
+{
+    float TmpBuffer[1];
+    svbool_t Pred = svwhilelt_b32(Lane, Lane + 1);
+    svst1_f32(Pred, TmpBuffer, Vector);
+    return TmpBuffer[0];
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT32
+MlasSveInterleaveLowFloat32(MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2)
+{
+    return svzip1_f32(Vector1, Vector2);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT32
+MlasSveInterleaveHighFloat32(MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2)
+{
+    return svzip2_f32(Vector1, Vector2);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT32
+MlasSveAddFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2)
+{
+    return svadd_f32_m(Pred, Vector1, Vector2);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT32
+MlasSveSubtractFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2)
+{
+    return svsub_f32_m(Pred, Vector1, Vector2);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT32
+MlasSveMultiplyFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 
Vector2) +{ + return svmul_f32_m(Pred, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveExpFloat32(MLAS_SVUINT32 Vector) +{ + return svexpa_f32(Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveScaleFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVINT32 Vector2) +{ + return svscale_f32_m(Pred, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveRoundINTFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector) +{ + return svrintm_f32_z(Pred, Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveMultiplyAddFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2, MLAS_SVFLOAT32 Vector3) +{ + return svmla_f32_m(Pred, Vector3, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveMultiplyAddFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, float Scalar2, MLAS_SVFLOAT32 Vector3) +{ + return MlasSveMultiplyAddFloat32(Pred, Vector1, MlasSveBroadcastFloat32(Scalar2), Vector3); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveMultiplyAddFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2, float Scalar3) +{ + return MlasSveMultiplyAddFloat32(Pred, Vector1, Vector2, MlasSveBroadcastFloat32(Scalar3)); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveDivideFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2) +{ + return svdiv_f32_m(Pred, Vector1, Vector2); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveGreaterThanFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2) +{ + // Compare Vector1 and Vector2, return a predicate vector + svbool_t cmp_mask = svcmpgt_f32(Pred, Vector1, Vector2); + + //Convert predicate to uint32_t mask + svuint32_t mask_bits = svdup_u32_z(cmp_mask, 0xFFFFFFFF); + + //Reinterpret to float32 + return svreinterpret_f32_u32(mask_bits); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 
+MlasSveAndFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2)
+{
+    return MlasSveReinterpretAsFloat32(
+        MlasSveAndInt32(
+            Pred,
+            MlasSveReinterpretAsInt32(Vector1),
+            MlasSveReinterpretAsInt32(Vector2)
+        )
+    );
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT32
+MlasSveOrFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2)
+{
+    return MlasSveReinterpretAsFloat32(
+        MlasSveOrInt32(
+            Pred,
+            MlasSveReinterpretAsInt32(Vector1),
+            MlasSveReinterpretAsInt32(Vector2)
+        )
+    );
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT32
+MlasSveAndNotFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2)
+{
+    return MlasSveReinterpretAsFloat32(
+        MlasSveAndNotInt32(
+            Pred,
+            MlasSveReinterpretAsInt32(Vector1),
+            MlasSveReinterpretAsInt32(Vector2)
+        )
+    );
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT32
+MlasSveXorFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2)
+{
+    return MlasSveReinterpretAsFloat32(
+        MlasSveXorInt32(
+            Pred,
+            MlasSveReinterpretAsInt32(Vector1),
+            MlasSveReinterpretAsInt32(Vector2)
+        )
+    );
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT32
+MlasSveBlendFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2, MLAS_SVFLOAT32 Selection)
+{
+    // (Vector2 & Selection) | (Vector1 & ~Selection), mirroring MlasSveBlendInt32.
+    return MlasSveOrFloat32(
+        Pred,
+        MlasSveAndFloat32(Pred, Vector2, Selection),
+        MlasSveAndNotFloat32(Pred, Selection, Vector1)
+    );
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT32
+MlasSveMaximumFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2)
+{
+    return svmax_f32_m(Pred, Vector1, Vector2);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT32
+MlasSveMinimumFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector1, MLAS_SVFLOAT32 Vector2)
+{
+    return svmin_f32_m(Pred, Vector1, Vector2);
+}
+
+MLAS_SVE_TARGET
+MLAS_FORCEINLINE
+MLAS_SVFLOAT32
+MlasSveClampFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Value, float LowerRange, float UpperRange)
+{
+    Value = 
MlasSveMaximumFloat32(Pred, MlasSveBroadcastFloat32(LowerRange), Value); + Value = MlasSveMinimumFloat32(Pred, MlasSveBroadcastFloat32(UpperRange), Value); + return Value; +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +float +MlasSveReduceAddFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector) +{ + return svaddv_f32(Pred, Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +float +MlasSveReduceMaximumFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector) +{ + return svmaxv_f32(Pred, Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +float +MlasSveReduceMinimumFloat32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector) +{ + return svminv_f32(Pred, Vector); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSvePowerOf2Float32(MLAS_SVBOOL Pred, MLAS_SVFLOAT32 Vector) +{ + MLAS_SVINT32 emm0 = MlasSveAddInt32( + Pred, + MlasSveCastToInt32(Pred, Vector), + MlasSveBroadcastInt32(127) + ); + return MlasSveReinterpretAsFloat32(MlasSveShiftLeftInt32<23>(Pred, emm0)); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVFLOAT32 +MlasSveSelect(svbool_t Pred, MLAS_SVFLOAT32 TrueValue, MLAS_SVFLOAT32 FalseValue) +{ + return svsel_f32(Pred, TrueValue, FalseValue); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVBOOL +MlasSveCompareLessThan(svbool_t Pred, MLAS_SVFLOAT32 A, MLAS_SVFLOAT32 B) +{ + return svcmplt_f32(Pred, A, B); +} + +MLAS_SVE_TARGET +MLAS_FORCEINLINE +MLAS_SVBOOL +MlasSveCompareGreaterThan(svbool_t Pred, MLAS_SVFLOAT32 A, MLAS_SVFLOAT32 B) +{ + return svcmpgt_f32(Pred, A, B); +} + +// GCC: Pop options after SVE-specific functions +#ifndef __clang__ +#pragma GCC pop_options +#endif + +#endif + diff --git a/src/lib/transpose.cpp b/src/lib/transpose.cpp index 61c3796..0c471e8 100644 --- a/src/lib/transpose.cpp +++ b/src/lib/transpose.cpp @@ -385,6 +385,180 @@ MlasTranspose16x16Block( vec_vsx_st(e0, 0, &Output[OutputStride * 14]); vec_vsx_st(e1, 0, &Output[OutputStride * 15]); } +#elif defined(MLAS_TARGET_S390X) + +MLAS_FORCEINLINE +void +MlasTranspose4x4Block( + const uint32_t* Input, + size_t 
InputStride, + uint32_t* Output, + size_t OutputStride + ) +{ + const __vector unsigned char mask0 = { 0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23 }; + const __vector unsigned char mask3 = { 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31 }; + + __vector unsigned int a0 = vec_xl(0, Input); + __vector unsigned int a1 = vec_xl(0, &Input[InputStride]); + __vector unsigned int a2 = vec_xl(0, &Input[InputStride * 2]); + __vector unsigned int a3 = vec_xl(0, &Input[InputStride * 3]); + + __vector unsigned int b0 = vec_mergeh(a0, a1); + __vector unsigned int b1 = vec_mergeh(a2, a3); + __vector unsigned int b2 = vec_mergel(a0, a1); + __vector unsigned int b3 = vec_mergel(a2, a3); + + __vector unsigned int c0 = vec_perm(b0, b1, mask0); + __vector unsigned int c1 = vec_perm(b0, b1, mask3); + __vector unsigned int c2 = vec_perm(b2, b3, mask0); + __vector unsigned int c3 = vec_perm(b2, b3, mask3); + + // Workaround to avoid 'variable set but not used' message + MLAS_UNREFERENCED_PARAMETER(c0); + MLAS_UNREFERENCED_PARAMETER(c1); + MLAS_UNREFERENCED_PARAMETER(c2); + MLAS_UNREFERENCED_PARAMETER(c3); + + vec_xst(c0, 0, Output); + vec_xst(c1, 0, &Output[OutputStride]); + vec_xst(c2, 0, &Output[OutputStride * 2]); + vec_xst(c3, 0, &Output[OutputStride * 3]); +} + +MLAS_FORCEINLINE +void +MlasTranspose16x16Block( + const uint8_t* Input, + size_t InputStride, + uint8_t* Output, + size_t OutputStride + ) +{ + const __vector unsigned char mask0 = { 0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23 }; + const __vector unsigned char mask3 = { 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31 }; + + __vector unsigned char a0 = vec_xl(0, Input); + __vector unsigned char a1 = vec_xl(0, &Input[InputStride]); + __vector unsigned char a2 = vec_xl(0, &Input[InputStride * 2]); + __vector unsigned char a3 = vec_xl(0, &Input[InputStride * 3]); + __vector unsigned char a4 = vec_xl(0, &Input[InputStride * 4]); + __vector unsigned char a5 = vec_xl(0, 
&Input[InputStride * 5]); + __vector unsigned char a6 = vec_xl(0, &Input[InputStride * 6]); + __vector unsigned char a7 = vec_xl(0, &Input[InputStride * 7]); + __vector unsigned char a8 = vec_xl(0, &Input[InputStride * 8]); + __vector unsigned char a9 = vec_xl(0, &Input[InputStride * 9]); + __vector unsigned char a10 = vec_xl(0, &Input[InputStride * 10]); + __vector unsigned char a11 = vec_xl(0, &Input[InputStride * 11]); + __vector unsigned char a12 = vec_xl(0, &Input[InputStride * 12]); + __vector unsigned char a13 = vec_xl(0, &Input[InputStride * 13]); + __vector unsigned char a14 = vec_xl(0, &Input[InputStride * 14]); + __vector unsigned char a15 = vec_xl(0, &Input[InputStride * 15]); + + __vector unsigned char b0 = vec_mergeh(a0, a1); + __vector unsigned char b1 = vec_mergeh(a2, a3); + __vector unsigned char b2 = vec_mergeh(a4, a5); + __vector unsigned char b3 = vec_mergeh(a6, a7); + __vector unsigned char b4 = vec_mergeh(a8, a9); + __vector unsigned char b5 = vec_mergeh(a10, a11); + __vector unsigned char b6 = vec_mergeh(a12, a13); + __vector unsigned char b7 = vec_mergeh(a14, a15); + __vector unsigned char c0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1))); + __vector unsigned char c1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3))); + __vector unsigned char c2 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5))); + __vector unsigned char c3 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7))); + + // Workaround to avoid 'variable set but not used' message + MLAS_UNREFERENCED_PARAMETER(c0); + MLAS_UNREFERENCED_PARAMETER(c1); + MLAS_UNREFERENCED_PARAMETER(c2); + 
MLAS_UNREFERENCED_PARAMETER(c3); + + __vector unsigned char d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); + __vector unsigned char d1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); + __vector unsigned char e0 = vec_perm(d0, d1, mask0); + __vector unsigned char e1 = vec_perm(d0, d1, mask3); + + // Workaround to avoid 'variable set but not used' message + MLAS_UNREFERENCED_PARAMETER(e0); + MLAS_UNREFERENCED_PARAMETER(e1); + + vec_xst(e0, 0, &Output[0]); + vec_xst(e1, 0, &Output[OutputStride]); + + d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); + d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); + e0 = vec_perm(d0, d1, mask0); + e1 = vec_perm(d0, d1, mask3); + vec_xst(e0, 0, &Output[OutputStride * 2]); + vec_xst(e1, 0, &Output[OutputStride * 3]); + + c0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1))); + c1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3))); + c2 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5))); + c3 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7))); + + d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); + d1 = reinterpret_cast<__vector unsigned 
char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); + e0 = vec_perm(d0, d1, mask0); + e1 = vec_perm(d0, d1, mask3); + vec_xst(e0, 0, &Output[OutputStride * 4]); + vec_xst(e1, 0, &Output[OutputStride * 5]); + + d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); + d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); + e0 = vec_perm(d0, d1, mask0); + e1 = vec_perm(d0, d1, mask3); + vec_xst(e0, 0, &Output[OutputStride * 6]); + vec_xst(e1, 0, &Output[OutputStride * 7]); + + b0 = vec_mergel(a0, a1); + b1 = vec_mergel(a2, a3); + b2 = vec_mergel(a4, a5); + b3 = vec_mergel(a6, a7); + b4 = vec_mergel(a8, a9); + b5 = vec_mergel(a10, a11); + b6 = vec_mergel(a12, a13); + b7 = vec_mergel(a14, a15); + + c0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1))); + c1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3))); + c2 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5))); + c3 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7))); + + d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); + d1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); + e0 = vec_perm(d0, d1, mask0); + e1 = vec_perm(d0, d1, mask3); + vec_xst(e0, 0, &Output[OutputStride * 8]); + vec_xst(e1, 0, 
&Output[OutputStride * 9]); + + d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); + d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); + e0 = vec_perm(d0, d1, mask0); + e1 = vec_perm(d0, d1, mask3); + vec_xst(e0, 0, &Output[OutputStride * 10]); + vec_xst(e1, 0, &Output[OutputStride * 11]); + + c0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b0), reinterpret_cast<__vector unsigned short>(b1))); + c1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b2), reinterpret_cast<__vector unsigned short>(b3))); + c2 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b4), reinterpret_cast<__vector unsigned short>(b5))); + c3 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned short>(b6), reinterpret_cast<__vector unsigned short>(b7))); + + d0 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); + d1 = reinterpret_cast<__vector unsigned char>(vec_mergeh(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); + e0 = vec_perm(d0, d1, mask0); + e1 = vec_perm(d0, d1, mask3); + vec_xst(e0, 0, &Output[OutputStride * 12]); + vec_xst(e1, 0, &Output[OutputStride * 13]); + + d0 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c0), reinterpret_cast<__vector unsigned int>(c1))); + d1 = reinterpret_cast<__vector unsigned char>(vec_mergel(reinterpret_cast<__vector unsigned int>(c2), reinterpret_cast<__vector unsigned int>(c3))); + e0 = vec_perm(d0, d1, mask0); + e1 = vec_perm(d0, d1, mask3); + vec_xst(e0, 0, &Output[OutputStride * 14]); + vec_xst(e1, 
0, &Output[OutputStride * 15]); +} #elif defined(MLAS_LSX_INTRINSICS) @@ -523,7 +697,7 @@ MlasTranspose4xNVector( Output[OutputStride * 3] = a3; } -#if defined(MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) template MLAS_FORCEINLINE void @@ -620,7 +794,7 @@ MlasTransposeThreaded( size_t m = CountM; #if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS) || defined(MLAS_TARGET_POWER) || \ - defined(MLAS_LSX_INTRINSICS) + defined(MLAS_TARGET_S390X) || defined(MLAS_LSX_INTRINSICS) while (m >= 4) { @@ -818,7 +992,7 @@ MlasTransposeThreaded( size_t n = N; -#if defined(MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) while (n >= 16) { const uint8_t* s = Input; diff --git a/src/ort_include/core/common/common.h b/src/ort_include/core/common/common.h index adfd341..820d140 100644 --- a/src/ort_include/core/common/common.h +++ b/src/ort_include/core/common/common.h @@ -294,12 +294,26 @@ inline std::string ToUTF8String(const std::string& s) { return s; } /** * Convert a wide character string to a UTF-8 string */ -std::string ToUTF8String(const std::wstring& s); - -std::wstring ToWideString(const std::string& s); +std::string ToUTF8String(std::wstring_view s); +inline std::string ToUTF8String(const wchar_t* s) { + return ToUTF8String(std::wstring_view{s}); +} +inline std::string ToUTF8String(const std::wstring& s) { + return ToUTF8String(std::wstring_view{s}); +} +std::wstring ToWideString(std::string_view s); +inline std::wstring ToWideString(const char* s) { + return ToWideString(std::string_view{s}); +} +inline std::wstring ToWideString(const std::string& s) { + return ToWideString(std::string_view{s}); +} inline std::wstring ToWideString(const std::wstring& s) { return s; } +inline std::wstring ToWideString(std::wstring_view s) { return std::wstring{s}; } #else inline std::string ToWideString(const std::string& s) { return s; } +inline std::string ToWideString(const char* s) { return s; } +inline 
std::string ToWideString(std::string_view s) { return std::string{s}; } #endif constexpr size_t kMaxStrLen = 4096; diff --git a/src/ort_include/core/common/const_pointer_container.h b/src/ort_include/core/common/const_pointer_container.h index 1d821ba..80343b4 100644 --- a/src/ort_include/core/common/const_pointer_container.h +++ b/src/ort_include/core/common/const_pointer_container.h @@ -79,6 +79,10 @@ class ConstPointerContainer { return data_[index]; } + const T* const* data() const { + return data_.data(); + } + private: const Container& data_; }; diff --git a/src/ort_include/core/common/cpuid_arch_definition.h b/src/ort_include/core/common/cpuid_arch_definition.h index a541eb6..5946b8c 100644 --- a/src/ort_include/core/common/cpuid_arch_definition.h +++ b/src/ort_include/core/common/cpuid_arch_definition.h @@ -9,6 +9,6 @@ #define CPUIDINFO_ARCH_X86 #endif -#if defined(_M_ARM64) || defined(__aarch64__) || defined(_M_ARM) || defined(__arm__) +#if defined(_M_ARM64) || defined(_M_ARM64EC) || defined(__aarch64__) || defined(_M_ARM) || defined(__arm__) #define CPUIDINFO_ARCH_ARM #endif // ARM or ARM64 diff --git a/src/ort_include/core/common/cpuid_info.h b/src/ort_include/core/common/cpuid_info.h index 9c67ebb..9c40627 100644 --- a/src/ort_include/core/common/cpuid_info.h +++ b/src/ort_include/core/common/cpuid_info.h @@ -38,8 +38,11 @@ class CPUIDInfo { // ARM bool HasArmNeonDot() const { return has_arm_neon_dot_; } bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } + bool HasArmSve() const { return has_arm_sve_; } bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; } + bool HasArm_SME() const { return has_arm_sme_; } + bool HasArm_SME2() const { return has_arm_sme2_; } uint32_t GetCurrentCoreIdx() const; @@ -102,7 +105,40 @@ class CPUIDInfo { } private: + // Log function that uses ORT logging if available or writes to stderr. 
+ // This enables us to log even before ORT logging has been initialized. + static void LogEarlyWarning(std::string_view message); + CPUIDInfo(); + + void VendorInfoInit(); + +#if defined(CPUIDINFO_ARCH_X86) + + void X86Init(); + +#elif defined(CPUIDINFO_ARCH_ARM) + +#if defined(__linux__) + + void ArmLinuxInit(); + +#elif defined(_WIN32) + + void ArmWindowsInit(); + +#elif defined(__APPLE__) + + void ArmAppleInit(); + +#endif + +#endif // defined(CPUIDINFO_ARCH_ARM) + +#if defined(CPUINFO_SUPPORTED) + bool pytorch_cpuinfo_init_{false}; +#endif // defined(CPUINFO_SUPPORTED) + bool has_amx_bf16_{false}; bool has_avx_{false}; bool has_avx2_{false}; @@ -125,42 +161,14 @@ class CPUIDInfo { bool has_arm_neon_dot_{false}; bool has_fp16_{false}; bool has_arm_neon_i8mm_{false}; + bool has_arm_sve_{false}; bool has_arm_sve_i8mm_{false}; bool has_arm_neon_bf16_{false}; + bool has_arm_sme_{false}; + bool has_arm_sme2_{false}; std::string vendor_; uint32_t vendor_id_; - - uint32_t GetVendorId(const std::string& vendor); - -#if defined(CPUIDINFO_ARCH_X86) - - void X86Init(); - std::string GetX86Vendor(int32_t* data); - -#elif defined(CPUIDINFO_ARCH_ARM) - -#if defined(CPUINFO_SUPPORTED) - // Now the following var is only used in ARM build, but later on we may expand the usage. 
- bool pytorch_cpuinfo_init_{false}; -#endif // defined(CPUINFO_SUPPORTED) - -#if defined(__linux__) - - void ArmLinuxInit(); - -#elif defined(_WIN32) - - void ArmWindowsInit(); - std::string GetArmWindowsVendor(); - -#elif defined(__APPLE__) - - void ArmAppleInit(); - -#endif - -#endif // defined(CPUIDINFO_ARCH_ARM) }; -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/src/ort_include/core/framework/endian.h b/src/ort_include/core/common/endian.h similarity index 100% rename from src/ort_include/core/framework/endian.h rename to src/ort_include/core/common/endian.h diff --git a/src/ort_include/core/framework/float16.h b/src/ort_include/core/common/float16.h similarity index 99% rename from src/ort_include/core/framework/float16.h rename to src/ort_include/core/common/float16.h index 97420ff..7cca9be 100644 --- a/src/ort_include/core/framework/float16.h +++ b/src/ort_include/core/common/float16.h @@ -4,9 +4,9 @@ #include -#include "endian.h" +#include "core/common/endian.h" #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000 -#include "cuda_bf16.h" +#include "cuda_bf16.h" // from CUDA SDK #endif #if !defined(__CUDACC__) && !defined(__HIPCC__) diff --git a/src/ort_include/core/common/float8.h b/src/ort_include/core/common/float8.h new file mode 100644 index 0000000..7dde1d0 --- /dev/null +++ b/src/ort_include/core/common/float8.h @@ -0,0 +1,937 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// IMPORTANT NOTE: Users of this file MUST include "cuda.h" before including this header +// if they would like to leverage the CUDA implementation for the conversion routines +// in their HOST code (code compiled by MSVC/GCC). +// This is because there is a check on CUDA_VERSION which is a macro defined in cuda.h. +// We can't include cuda.h in this header unconditionally because this header is also +// included in core framework files which are CUDA-agnostic. 
+// Not including "cuda.h" in GCC/MSVC will fall-back to the CPU conversion routines +// implemented in this file. +// For code compiled by NVCC which includes this header, this file will automatically +// include cuda.h (based on the CUDA_CC macro). + +#pragma once + +#if !defined(DISABLE_FLOAT8_TYPES) + +#include "core/common/endian.h" + +#if defined(__CUDACC__) +// Needed for CUDA_VERSION check below +#include +#endif + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 +#include "cuda_fp8.h" +#endif + +#if !defined(__CUDACC__) && !defined(__HIPCC__) +#include "core/common/narrow.h" +#endif + +#include "core/common/common.h" + +namespace onnxruntime { + +#if defined(__CUDACC__) || defined(__HIPCC__) +#define ORT_HOST_DEVICE __host__ __device__ +#else +#define ORT_HOST_DEVICE +#endif + +// Float8E4M3FN +struct Float8E4M3FN { + uint8_t val{0}; +#if defined(__HIP__) + ORT_HOST_DEVICE Float8E4M3FN() = default; +#else + Float8E4M3FN() = default; +#endif + struct FromBitsT {}; + static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); } + constexpr ORT_HOST_DEVICE Float8E4M3FN(unsigned char bits, FromBitsT) : val(bits) {} + + inline explicit ORT_HOST_DEVICE Float8E4M3FN(float v, bool saturate = true) { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 + val = __nv_cvt_float_to_fp8(v, saturate ? 
__NV_SATFINITE : __NV_NOSAT, __NV_E4M3); +#else + uint32_t b; + std::memcpy(&b, &v, sizeof(b)); + + val = static_cast((b & 0x80000000) >> 24); // sign + if ((b & 0x7fffffff) == 0x7f800000) { // infinity + if (saturate) { + val |= 126; + } else { + val |= 0x7f; + } + } else if ((b & 0x7F800000) == 0x7F800000) { // NaN + val |= 0x7f; + } else { + uint8_t e = static_cast((b & 0x7F800000) >> 23); // exponent + uint32_t m = static_cast(b & 0x007FFFFF); // mantissa + if (e != 0) { + if (e < 117) { + } else if (e < 121) { + // denormalized number + auto d = 120 - e; + if (d < 3) { + val |= 1 << (2 - d); + val |= m >> (21 + d); + } else if (m > 0) { + val |= 1; + } + auto mask = 1 << (20 + d); + if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { + // rounding + val += 1; + } + } else if (e < 136) { + // normalized number + auto ex = e - 120; + if (ex == 0) { + val |= 0x4; + val |= m >> 21; + } else { + val |= ex << 3; + val |= m >> 20; + if ((val & 0x7F) == 0x7F) { + val &= 0xFE; + } + } + if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) { + if ((val & 0x7F) < 0x7E) { + // rounding + val += 1; + } else if (!saturate) { + val |= 0x7F; + } + } + } else if (saturate) { + val |= 126; // 0b01111110 + } else { + val |= 0x7F; + } + } + } +#endif + } + + inline ORT_HOST_DEVICE bool IsNaN() const { + return (val & 0b01111111) == 0b01111111; + } + + inline ORT_HOST_DEVICE float ToFloat() const { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 + return __half2float(__nv_cvt_fp8_to_halfraw(val, __NV_E4M3)); +#else + uint32_t res; + if (val == 255) { + res = 0xffc00000; + } else if (val == 127) { + res = 0x7fc00000; + } else { + uint32_t expo = (val & 0x78) >> 3; + uint32_t mant = val & 0x07; + uint32_t sign = val & 0x80; + res = sign << 24; + if (expo == 0) { + if (mant > 0) { + expo = 0x7F - 7; + if ((mant & 0x4) == 0) { + mant &= 0x3; + mant <<= 1; + expo -= 1; + } + if ((mant & 0x4) == 0) { + mant &= 
0x3; + mant <<= 1; + expo -= 1; + } + res |= (mant & 0x3) << 21; + res |= expo << 23; + } + } else { + res |= mant << 20; + expo -= 0x7; + expo += 0x7F; + res |= expo << 23; + } + } + float float_res; + std::memcpy(&float_res, &res, sizeof(float)); + return float_res; +#endif + } + + inline ORT_HOST_DEVICE operator float() const { return ToFloat(); } + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 + explicit ORT_HOST_DEVICE Float8E4M3FN(const __nv_fp8_e4m3& value) { val = *reinterpret_cast(&value); } + explicit ORT_HOST_DEVICE operator __nv_fp8_e4m3() const { return *reinterpret_cast(&val); } +#endif +}; + +inline ORT_HOST_DEVICE bool operator==(const Float8E4M3FN& left, const Float8E4M3FN& right) { return left.val == right.val; } +inline ORT_HOST_DEVICE bool operator!=(const Float8E4M3FN& left, const Float8E4M3FN& right) { return left.val != right.val; } +inline ORT_HOST_DEVICE bool operator<(const Float8E4M3FN& left, const Float8E4M3FN& right) { return left.val < right.val; } + +// User defined suffixes to make it easier to declare +// initializers with MLFloat8E4M3FN and Float8E4M3FN from unsigned char +#if !defined(__CUDACC__) && !defined(__HIPCC__) + +inline Float8E4M3FN operator""_f8e4m3fn(unsigned long long int v) { + return Float8E4M3FN(narrow(v), Float8E4M3FN::FromBits()); +} + +inline Float8E4M3FN operator""_f8e4m3fnp8(long double v) { + return Float8E4M3FN(static_cast(v), true); +} + +#endif + +inline void Float8E4M3FNToFloat(const Float8E4M3FN* blf, float* flt, size_t size) { + auto src = blf; + auto d = flt; + for (; size != 0; ++src, ++d, --size) { + *d = src->ToFloat(); + } +} + +inline void FloatToFloat8E4M3FN(const float* flt, Float8E4M3FN* blf, size_t size, bool saturate) { + auto src = flt; + auto d = blf; + for (; size != 0; ++src, ++d, --size) { + new (d) Float8E4M3FN(*src, saturate); + } +} + +// Float8E4M3FNUZ +struct Float8E4M3FNUZ { + uint8_t val{0}; +#if defined(__HIP__) + ORT_HOST_DEVICE Float8E4M3FNUZ() = default; +#else + 
Float8E4M3FNUZ() = default; +#endif + + struct FromBitsT {}; + static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); } + constexpr ORT_HOST_DEVICE Float8E4M3FNUZ(unsigned char bits, FromBitsT) : val(bits) {} + + inline explicit ORT_HOST_DEVICE Float8E4M3FNUZ(float v, bool saturate = true) { + // This type does not exist on CUDA. + uint32_t b; + std::memcpy(&b, &v, sizeof(b)); + + val = static_cast((b & 0x80000000) >> 24); // sign + if ((b & 0x7fffffff) == 0x7f800000) { // infinity + if (saturate) { + // the highest available value + val |= 0x7F; + } else { + // NaN + val = 0x80; + } + } else if ((b & 0x7F800000) == 0x7F800000) { // NaN + val = 0x80; + } else { + uint8_t e = static_cast((b & 0x7F800000) >> 23); // exponent + uint32_t m = static_cast(b & 0x007FFFFF); // mantissa + + if (e < 116) { + // all near-zero numbers round to positive zero: + val = 0; + } else if (e < 120) { + // denormalized number + auto d = 119 - e; + if (d < 3) { + val |= 1 << (2 - d); + val |= m >> (21 + d); + } else if (m > 0) { + val |= 1; + } else { + // round to positive zero: + val = 0; + } + auto mask = 1 << (20 + d); + if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { + // rounding + val += 1; + } + } else if (e < 135) { + // normalized number + auto ex = e - 119; + if (ex == 0) { + val |= 0x4; + val |= m >> 21; + } else { + val |= ex << 3; + val |= m >> 20; + } + if ((m & 0x80000) && ((m & 0x100000) || (m & 0x7FFFF))) { + if ((val & 0x7F) < 0x7F) { + // rounding + val += 1; + } else if (!saturate) { + val = 0x80; + } + } + } else if (saturate) { + val |= 0x7F; + } else { + val = 0x80; + } + } + } + + inline ORT_HOST_DEVICE bool IsNaN() const { + return val == 0b10000000; + } + + inline ORT_HOST_DEVICE float ToFloat() const { + // This type does not exist on CUDA. 
+ uint32_t res; + if (val == 0x80) { + res = 0xffc00000; + } else { + uint32_t expo = (val & 0x78) >> 3; + uint32_t mant = val & 0x07; + uint32_t sign = val & 0x80; + res = sign << 24; + if (expo == 0) { + if (mant > 0) { + expo = 0x7F - 8; + if ((mant & 0x4) == 0) { + mant &= 0x3; + mant <<= 1; + expo -= 1; + } + if ((mant & 0x4) == 0) { + mant &= 0x3; + mant <<= 1; + expo -= 1; + } + res |= (mant & 0x3) << 21; + res |= expo << 23; + } + } else { + res |= mant << 20; + expo -= 8; + expo += 0x7F; + res |= expo << 23; + } + } + float float_res; + std::memcpy(&float_res, &res, sizeof(float)); + return float_res; + } + + inline ORT_HOST_DEVICE operator float() const { return ToFloat(); } +}; + +inline ORT_HOST_DEVICE bool operator==(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { return left.val == right.val; } +inline ORT_HOST_DEVICE bool operator!=(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { return left.val != right.val; } +inline ORT_HOST_DEVICE bool operator<(const Float8E4M3FNUZ& left, const Float8E4M3FNUZ& right) { return left.val < right.val; } + +// User defined suffixes to make it easier to declare +// initializers with MLFloat8E4M3FN and Float8E4M3FN from unsigned char +#if !defined(__CUDACC__) && !defined(__HIPCC__) + +inline Float8E4M3FNUZ operator""_f8e4m3p8fnuz(unsigned long long int v) { + return Float8E4M3FNUZ(narrow(v), Float8E4M3FNUZ::FromBits()); +} + +inline Float8E4M3FNUZ operator""_f8e4m3fnuzp8(long double v) { + return Float8E4M3FNUZ(static_cast(v), true); +} + +#endif + +inline void Float8E4M3FNUZToFloat(const Float8E4M3FNUZ* blf, float* flt, size_t size) { + auto src = blf; + auto d = flt; + for (; size != 0; ++src, ++d, --size) { + *d = src->ToFloat(); + } +} + +inline void FloatToFloat8E4M3FNUZ(const float* flt, Float8E4M3FNUZ* blf, size_t size, bool saturate) { + auto src = flt; + auto d = blf; + for (; size != 0; ++src, ++d, --size) { + new (d) Float8E4M3FNUZ(*src, saturate); + } +} + +// Float8E5M2 +struct 
Float8E5M2 { + uint8_t val{0}; +#if defined(__HIP__) + ORT_HOST_DEVICE Float8E5M2() = default; +#else + Float8E5M2() = default; +#endif + + struct FromBitsT {}; + static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); } + constexpr ORT_HOST_DEVICE Float8E5M2(unsigned char bits, FromBitsT) : val(bits) {} + + inline explicit ORT_HOST_DEVICE Float8E5M2(float v, bool saturate = true) { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 + val = __nv_cvt_float_to_fp8(v, saturate ? __NV_SATFINITE : __NV_NOSAT, __NV_E5M2); +#else + uint32_t b; + std::memcpy(&b, &v, sizeof(b)); + + val = (b & 0x80000000) >> 24; // sign + if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf + if (saturate) { + // the highest available value + val |= 0x7B; + } else { + // the infinity + val |= 0x7C; + } + } else if ((b & 0x7F800000) == 0x7F800000) { // NaN + val |= 0x7f; + } else { + uint32_t e = (b & 0x7F800000) >> 23; // exponent + uint32_t m = b & 0x007FFFFF; // mantissa + + if (e != 0) { + if (e < 110) { + } else if (e < 113) { + // denormalized number + auto d = 112 - e; + if (d < 2) { + val |= 1 << (1 - d); + val |= m >> (22 + d); + } else if (m > 0) { + val |= 1; + } + auto mask = 1 << (21 + d); + if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { + // rounding + val += 1; + } + } else if (e < 143) { // 127 + 15 + 1 + auto ex = e - 112; // 127 - 15 + val |= ex << 2; + val |= m >> 21; + if ((m & 0x100000) && ((m & 0xFFFFF) || (m & 0x200000))) { + if ((val & 0x7F) < 0x7B) { + // rounding + val += 1; + } else if (saturate) { + val |= 0x7B; + } else { + val |= 0x7C; + } + } + } else if (saturate) { + val |= 0x7B; + } else { + val |= 0x7C; + } + } + } +#endif + } + + inline ORT_HOST_DEVICE bool IsNaN() const { + // 7D, 7E, 7F are positive NaNs; FD, FE, FF are negative NaNs + return (val & 0b01111111) > 0b01111100; + } + + inline ORT_HOST_DEVICE bool IsInfinity() const { + // 7C and FC are infinity + return 
(val & 0b01111111) == 0b01111100; + } + + inline ORT_HOST_DEVICE float ToFloat() const { +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 + return __half2float(__nv_cvt_fp8_to_halfraw(val, __NV_E5M2)); +#else + uint32_t res; + if (val >= 253) { + res = 0xffc00000; + } else if (val >= 125 && val <= 127) { + res = 0x7fc00000; + } else if (val == 252) { + res = 0xff800000; + } else if (val == 124) { + res = 0x7f800000; + } else { + uint32_t expo = (val & 0x7C) >> 2; + uint32_t mant = val & 0x03; + uint32_t sign = val & 0x80; + res = sign << 24; + if (expo == 0) { + if (mant > 0) { + expo = 0x7F - 15; + if ((mant & 0x2) == 0) { + mant &= 0x1; + mant <<= 1; + expo -= 1; + } + res |= (mant & 0x1) << 22; + res |= expo << 23; + } + } else { + res |= mant << 21; + expo -= 15; + expo += 0x7F; + res |= expo << 23; + } + } + + float float_res; + std::memcpy(&float_res, &res, sizeof(float)); + return float_res; +#endif + } + + inline ORT_HOST_DEVICE operator float() const { return ToFloat(); } + +#if defined(CUDA_VERSION) && CUDA_VERSION >= 11080 + ORT_HOST_DEVICE Float8E5M2(const __nv_fp8_e5m2& value) { val = *reinterpret_cast(&value); } + explicit ORT_HOST_DEVICE operator __nv_fp8_e5m2() const { return *reinterpret_cast(&val); } +#endif +}; + +inline ORT_HOST_DEVICE bool operator==(const Float8E5M2& left, const Float8E5M2& right) { return left.val == right.val; } +inline ORT_HOST_DEVICE bool operator!=(const Float8E5M2& left, const Float8E5M2& right) { return left.val != right.val; } +inline ORT_HOST_DEVICE bool operator<(const Float8E5M2& left, const Float8E5M2& right) { return left.val < right.val; } + +// User defined suffixes to make it easier to declare +// initializers with MLFloat8E5M2 and Float8E5M2 from unsigned char +#if !defined(__CUDACC__) && !defined(__HIPCC__) + +inline Float8E5M2 operator""_f8e5m2fn(unsigned long long int v) { + return Float8E5M2(narrow(v), Float8E5M2::FromBits()); +} + +inline Float8E5M2 operator""_f8e5m2fnp8(long double v) { + return 
Float8E5M2(static_cast(v), true); +} + +#endif + +inline void Float8E5M2ToFloat(const Float8E5M2* blf, float* flt, size_t size) { + auto src = blf; + auto d = flt; + for (; size != 0; ++src, ++d, --size) { + *d = src->ToFloat(); + } +} + +inline void FloatToFloat8E5M2(const float* flt, Float8E5M2* blf, size_t size, bool saturate) { + auto src = flt; + auto d = blf; + for (; size != 0; ++src, ++d, --size) { + new (d) Float8E5M2(*src, saturate); + } +} + +// Float8E5M2FNUZ +struct Float8E5M2FNUZ { + uint8_t val{0}; +#if defined(__HIP__) + ORT_HOST_DEVICE Float8E5M2FNUZ() = default; +#else + Float8E5M2FNUZ() = default; +#endif + + struct FromBitsT {}; + static constexpr ORT_HOST_DEVICE FromBitsT FromBits() { return FromBitsT(); } + constexpr ORT_HOST_DEVICE Float8E5M2FNUZ(unsigned char bits, FromBitsT) : val(bits) {} + + inline explicit ORT_HOST_DEVICE Float8E5M2FNUZ(float v, bool saturate = true) { + // This type does not exist on CUDA. + uint32_t b; + std::memcpy(&b, &v, sizeof(b)); + + val = (b & 0x80000000) >> 24; // sign + if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf + if (saturate) { + val |= 0x7F; + } else { + val = 0x80; + } + } else if ((b & 0x7F800000) == 0x7F800000) { // NaN + val = 0x80; + } else { + uint32_t e = (b & 0x7F800000) >> 23; // exponent + uint32_t m = b & 0x007FFFFF; // mantissa + + if (e < 109) { + // all near-zero numbers round to positive zero: + val = 0; + } else if (e < 112) { + // denormalized number + auto d = 111 - e; + if (d < 2) { + val |= 1 << (1 - d); + val |= m >> (22 + d); + } else if (m > 0) { + val |= 1; + } else { + // round to positive zero: + val = 0; + } + auto mask = 1 << (21 + d); + if ((m & mask) && ((val & 1) || ((m & (mask - 1)) > 0) || ((m & mask) && (m & (mask << 1)) && ((m & (mask - 1)) == 0)))) { + // rounding + val += 1; + } + } else if (e < 143) { + // normalized number + auto ex = e - 111; + val |= ex << 2; + val |= m >> 21; + if ((m & 0x100000) && ((m & 0xFFFFF) || (m & 0x200000))) { + if ((val & 0x7F) < 0x7F) { 
+ // rounding + val += 1; + } else if (!saturate) { + val = 0x80; + } + } + } else if ((e == 255) && (m == 0)) { + val = 0x80; + } else if (saturate) { + val |= 0x7F; + } else { + val = 0x80; + } + } + } + + inline ORT_HOST_DEVICE bool IsNaN() const { + return val == 0b10000000; + } + + inline ORT_HOST_DEVICE float ToFloat() const { + // This type does not exist on CUDA. + uint32_t res; + if (val == 0x80) { + res = 0xffc00000; + } else { + uint32_t expo = (val & 0x7C) >> 2; + uint32_t mant = val & 0x03; + uint32_t sign = val & 0x80; + res = sign << 24; + if (expo == 0) { + if (mant > 0) { + expo = 0x7F - 16; + if ((mant & 0x2) == 0) { + mant &= 0x1; + mant <<= 1; + expo -= 1; + } + res |= (mant & 0x1) << 22; + res |= expo << 23; + } + } else { + res |= mant << 21; + expo -= 16; + expo += 0x7F; + res |= expo << 23; + } + } + + float float_res; + std::memcpy(&float_res, &res, sizeof(float)); + return float_res; + } + + inline ORT_HOST_DEVICE operator float() const { return ToFloat(); } +}; + +inline ORT_HOST_DEVICE bool operator==(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { return left.val == right.val; } +inline ORT_HOST_DEVICE bool operator!=(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { return left.val != right.val; } +inline ORT_HOST_DEVICE bool operator<(const Float8E5M2FNUZ& left, const Float8E5M2FNUZ& right) { return left.val < right.val; } + +// User defined suffixes to make it easier to declare +// initializers with MLFloat8E5M2 and Float8E5M2 from unsigned char +#if !defined(__CUDACC__) && !defined(__HIPCC__) + +inline Float8E5M2FNUZ operator""_f8e5m2fnuz(unsigned long long int v) { + return Float8E5M2FNUZ(narrow(v), Float8E5M2FNUZ::FromBits()); +} + +inline Float8E5M2FNUZ operator""_f8e5m2fnuzp8(long double v) { + return Float8E5M2FNUZ(static_cast(v), true); +} + +#endif + +inline void Float8E5M2FNUZToFloat(const Float8E5M2FNUZ* blf, float* flt, size_t size) { + auto src = blf; + auto d = flt; + for (; size != 0; ++src, ++d, 
--size) { + *d = src->ToFloat(); + } +} + +inline void FloatToFloat8E5M2FNUZ(const float* flt, Float8E5M2FNUZ* blf, size_t size, bool saturate) { + auto src = flt; + auto d = blf; + for (; size != 0; ++src, ++d, --size) { + new (d) Float8E5M2FNUZ(*src, saturate); + } +} + +} // namespace onnxruntime + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr onnxruntime::Float8E4M3FN lowest() { + return onnxruntime::Float8E4M3FN(0xFE, onnxruntime::Float8E4M3FN::FromBits()); // -448 + } + + static constexpr onnxruntime::Float8E4M3FN max() { + return onnxruntime::Float8E4M3FN(0x7E, onnxruntime::Float8E4M3FN::FromBits()); // 448 + } + + static constexpr onnxruntime::Float8E4M3FN min() { + return onnxruntime::Float8E4M3FN(0x08, onnxruntime::Float8E4M3FN::FromBits()); // 2^-6 = 0.015625 + } + + static constexpr onnxruntime::Float8E4M3FN denorm_min() { + return onnxruntime::Float8E4M3FN(0x01, onnxruntime::Float8E4M3FN::FromBits()); // 2^-9 = 0.001953125 + } + + static constexpr onnxruntime::Float8E4M3FN epsilon() { + return onnxruntime::Float8E4M3FN(0x20, onnxruntime::Float8E4M3FN::FromBits()); + } + + static constexpr onnxruntime::Float8E4M3FN round_error() { + return onnxruntime::Float8E4M3FN(0x30, onnxruntime::Float8E4M3FN::FromBits()); + } + + static constexpr onnxruntime::Float8E4M3FN infinity() { + // no infinity, returns quiet NaN instead + return quiet_NaN(); + } + + static constexpr onnxruntime::Float8E4M3FN quiet_NaN() { + return onnxruntime::Float8E4M3FN(0x7F, onnxruntime::Float8E4M3FN::FromBits()); + } + + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto 
round_style = round_to_nearest; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 4; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 3; + static constexpr int radix = 2; + static constexpr int min_exponent = -5; + static constexpr int min_exponent10 = -1; + static constexpr int max_exponent = 8; + static constexpr int max_exponent10 = 2; + static constexpr auto traps = false; + static constexpr auto tinyness_before = false; +}; + +template <> +class numeric_limits { + public: + static constexpr onnxruntime::Float8E5M2 lowest() { + return onnxruntime::Float8E5M2(0xFB, onnxruntime::Float8E5M2::FromBits()); // -57344.0 + } + + static constexpr onnxruntime::Float8E5M2 max() { + return onnxruntime::Float8E5M2(0x7B, onnxruntime::Float8E5M2::FromBits()); // 57344.0 + } + + static constexpr onnxruntime::Float8E5M2 min() { + return onnxruntime::Float8E5M2(0x4, onnxruntime::Float8E5M2::FromBits()); // 2^-14 = 0.00006103515 + } + + static constexpr onnxruntime::Float8E5M2 denorm_min() { + return onnxruntime::Float8E5M2(0x01, onnxruntime::Float8E5M2::FromBits()); // 2^-16 = 0.00001525878 + } + + static constexpr onnxruntime::Float8E5M2 epsilon() { + return onnxruntime::Float8E5M2(0x34, onnxruntime::Float8E5M2::FromBits()); + } + + static constexpr onnxruntime::Float8E5M2 round_error() { + return onnxruntime::Float8E5M2(0x38, onnxruntime::Float8E5M2::FromBits()); + } + + static constexpr onnxruntime::Float8E5M2 infinity() { + return onnxruntime::Float8E5M2(0x7C, onnxruntime::Float8E5M2::FromBits()); + } + + static constexpr onnxruntime::Float8E5M2 quiet_NaN() { + return onnxruntime::Float8E5M2(0x7F, onnxruntime::Float8E5M2::FromBits()); + } + + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool 
has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = round_to_nearest; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 3; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 2; + static constexpr int radix = 2; + static constexpr int min_exponent = -13; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + static constexpr auto traps = false; + static constexpr auto tinyness_before = false; +}; + +template <> +class numeric_limits { + public: + static constexpr onnxruntime::Float8E4M3FNUZ lowest() { + return onnxruntime::Float8E4M3FNUZ(0xFF, onnxruntime::Float8E4M3FNUZ::FromBits()); // -240.0 + } + + static constexpr onnxruntime::Float8E4M3FNUZ max() { + return onnxruntime::Float8E4M3FNUZ(0x7F, onnxruntime::Float8E4M3FNUZ::FromBits()); // 240.0 + } + + static constexpr onnxruntime::Float8E4M3FNUZ min() { + return onnxruntime::Float8E4M3FNUZ(0x08, onnxruntime::Float8E4M3FNUZ::FromBits()); // 2^-7 = 0.0078125 + } + + static constexpr onnxruntime::Float8E4M3FNUZ denorm_min() { + return onnxruntime::Float8E4M3FNUZ(0x01, onnxruntime::Float8E4M3FNUZ::FromBits()); // 2^-10 = 0.0009765625 + } + + static constexpr onnxruntime::Float8E4M3FNUZ epsilon() { + return onnxruntime::Float8E4M3FNUZ(0x28, onnxruntime::Float8E4M3FNUZ::FromBits()); + } + + static constexpr onnxruntime::Float8E4M3FNUZ round_error() { + return onnxruntime::Float8E4M3FNUZ(0x38, onnxruntime::Float8E4M3FNUZ::FromBits()); + } + + static constexpr onnxruntime::Float8E4M3FNUZ infinity() { + // no infinity, returns quiet NaN instead + return quiet_NaN(); + } + + static constexpr onnxruntime::Float8E4M3FNUZ 
quiet_NaN() { + return onnxruntime::Float8E4M3FNUZ(0x80, onnxruntime::Float8E4M3FNUZ::FromBits()); + } + + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = round_to_nearest; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 4; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 3; + static constexpr int radix = 2; + static constexpr int min_exponent = -6; + static constexpr int min_exponent10 = -1; + static constexpr int max_exponent = 8; + static constexpr int max_exponent10 = 2; + static constexpr auto traps = false; + static constexpr auto tinyness_before = false; +}; + +template <> +class numeric_limits { + public: + static constexpr onnxruntime::Float8E5M2FNUZ lowest() { + return onnxruntime::Float8E5M2FNUZ(0xFF, onnxruntime::Float8E5M2FNUZ::FromBits()); // -57344.0 + } + + static constexpr onnxruntime::Float8E5M2FNUZ max() { + return onnxruntime::Float8E5M2FNUZ(0x7F, onnxruntime::Float8E5M2FNUZ::FromBits()); // 57344.0 + } + + static constexpr onnxruntime::Float8E5M2FNUZ min() { + return onnxruntime::Float8E5M2FNUZ(0x04, onnxruntime::Float8E5M2FNUZ::FromBits()); // 2^-15 = 0.00003051757 + } + + static constexpr onnxruntime::Float8E5M2FNUZ denorm_min() { + return onnxruntime::Float8E5M2FNUZ(0x01, onnxruntime::Float8E5M2FNUZ::FromBits()); // 2^-17 = 0.00000762939 + } + + static constexpr onnxruntime::Float8E5M2FNUZ epsilon() { + return onnxruntime::Float8E5M2FNUZ(0x34, onnxruntime::Float8E5M2FNUZ::FromBits()); + } + + static constexpr 
onnxruntime::Float8E5M2FNUZ round_error() { + return onnxruntime::Float8E5M2FNUZ(0x38, onnxruntime::Float8E5M2FNUZ::FromBits()); + } + + static constexpr onnxruntime::Float8E5M2FNUZ infinity() { + // no infinity, returns quiet NaN instead + return quiet_NaN(); + } + + static constexpr onnxruntime::Float8E5M2FNUZ quiet_NaN() { + return onnxruntime::Float8E5M2FNUZ(0x80, onnxruntime::Float8E5M2FNUZ::FromBits()); + } + + static constexpr bool is_specialized = true; + static constexpr bool is_signed = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = false; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = false; + static constexpr auto has_denorm = true; + static constexpr auto has_denorm_loss = true; + static constexpr auto round_style = round_to_nearest; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 3; + static constexpr int digits10 = 0; + static constexpr int max_digits10 = 2; + static constexpr int radix = 2; + static constexpr int min_exponent = -14; + static constexpr int min_exponent10 = -4; + static constexpr int max_exponent = 16; + static constexpr int max_exponent10 = 4; + static constexpr auto traps = false; + static constexpr auto tinyness_before = false; +}; + +} // namespace std + +#endif // DISABLE_FLOAT8_TYPES diff --git a/src/ort_include/core/common/parse_string.h b/src/ort_include/core/common/parse_string.h index 6345b2a..5f88d49 100644 --- a/src/ort_include/core/common/parse_string.h +++ b/src/ort_include/core/common/parse_string.h @@ -35,13 +35,30 @@ template std::enable_if_t, bool> TryParseStringWithClassicLocale(std::string_view str, T& value) { T parsed_value{}; - const auto [ptr, ec] = std::from_chars(str.data(), str.data() + str.size(), parsed_value); - if (ec != std::errc{}) { + std::from_chars_result 
conversion_result{}; + if constexpr (std::is_integral_v && std::is_unsigned_v) { + // For unsigned integral types, also handle hex values, i.e., those beginning with "0x". + // std::from_chars() does not accept the "0x" prefix. + const bool has_hex_prefix = str.size() >= 2 && + str[0] == '0' && + (str[1] == 'x' || str[1] == 'X'); + + if (has_hex_prefix) { + str = str.substr(2); + } + + const int base = has_hex_prefix ? 16 : 10; + conversion_result = std::from_chars(str.data(), str.data() + str.size(), parsed_value, base); + } else { + conversion_result = std::from_chars(str.data(), str.data() + str.size(), parsed_value); + } + + if (conversion_result.ec != std::errc{}) { return false; } - if (ptr != str.data() + str.size()) { + if (conversion_result.ptr != str.data() + str.size()) { return false; } diff --git a/src/ort_include/core/common/path_string.h b/src/ort_include/core/common/path_string.h index 6cfb327..4ca326d 100644 --- a/src/ort_include/core/common/path_string.h +++ b/src/ort_include/core/common/path_string.h @@ -40,6 +40,12 @@ inline PathString ToPathString(const PathString& s) { static_assert(std::is_same::value, "PathString is not std::wstring!"); +inline PathString ToPathString(std::string_view s) { + return ToWideString(s); +} +inline PathString ToPathString(const char* s) { + return ToWideString(s); +} inline PathString ToPathString(const std::string& s) { return ToWideString(s); } @@ -56,6 +62,14 @@ inline std::string PathToUTF8String(const PathString& s) { static_assert(std::is_same::value, "PathString is not std::string!"); +inline PathString ToPathString(const char* s) { + return s; +} + +inline PathString ToPathString(std::string_view s) { + return PathString{s}; +} + inline PathChar ToLowerPathChar(PathChar c) { return std::tolower(c); } diff --git a/src/ort_include/core/common/status.h b/src/ort_include/core/common/status.h index da9735a..8cf6420 100644 --- a/src/ort_include/core/common/status.h +++ b/src/ort_include/core/common/status.h @@ 
-46,6 +46,7 @@ enum StatusCode { EP_FAIL = 11, MODEL_LOAD_CANCELED = 12, MODEL_REQUIRES_COMPILATION = 13, + NOT_FOUND = 14, }; constexpr const char* StatusCodeToString(StatusCode status) noexcept { @@ -78,6 +79,8 @@ constexpr const char* StatusCodeToString(StatusCode status) noexcept { return "MODEL_LOAD_CANCELED"; case StatusCode::MODEL_REQUIRES_COMPILATION: return "MODEL_REQUIRES_COMPILATION"; + case StatusCode::NOT_FOUND: + return "NOT_FOUND"; default: return "GENERAL ERROR"; } @@ -114,6 +117,8 @@ constexpr HRESULT StatusCodeToHRESULT(StatusCode status) noexcept { return HRESULT_FROM_WIN32(ERROR_CANCELLED); case StatusCode::MODEL_REQUIRES_COMPILATION: return HRESULT_FROM_WIN32(ERROR_NOT_SUPPORTED); + case StatusCode::NOT_FOUND: + return HRESULT_FROM_WIN32(ERROR_NOT_FOUND); default: return E_FAIL; } diff --git a/src/ort_include/core/common/string_helper.h b/src/ort_include/core/common/string_helper.h index 1304303..c0b331c 100644 --- a/src/ort_include/core/common/string_helper.h +++ b/src/ort_include/core/common/string_helper.h @@ -7,5 +7,9 @@ // forward declaration struct OrtAllocator; namespace onnxruntime { -char* StrDup(const std::string& str, OrtAllocator* allocator); +char* StrDup(std::string_view str, OrtAllocator* allocator); +inline char* StrDup(const std::string& str, OrtAllocator* allocator) { + return StrDup(std::string_view{str}, allocator); +} +wchar_t* StrDup(std::wstring_view str, OrtAllocator* allocator); } // namespace onnxruntime diff --git a/src/ort_include/core/framework/callback.h b/src/ort_include/core/framework/callback.h deleted file mode 100644 index 88f14d7..0000000 --- a/src/ort_include/core/framework/callback.h +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
-#pragma once -#include "core/common/common.h" - -namespace onnxruntime { -struct OrtCallback { - void (*f)(void* param) noexcept; - void* param; -}; - -/** - * f will be freed in this call - */ -void OrtRunCallback(OrtCallback* f) noexcept; - -/** - * Invokes the contained OrtCallback with operator()(T). - * Useful for something like a std::unique_ptr<> deleter. - */ -struct OrtCallbackInvoker { - OrtCallbackInvoker() noexcept - : callback{nullptr, nullptr} {} - - OrtCallbackInvoker(OrtCallback callback_to_invoke) noexcept - : callback(callback_to_invoke) {} - - OrtCallback callback; - - template - void operator()(T) noexcept { - if (callback.f) { - callback.f(callback.param); - } - } -}; - -/** - * Invokes the contained OrtCallback upon destruction or being assigned to. - */ -class ScopedOrtCallbackInvoker { - public: - explicit ScopedOrtCallbackInvoker(OrtCallback callback) noexcept - : callback_(callback) {} - - ScopedOrtCallbackInvoker(ScopedOrtCallbackInvoker&& other) noexcept - : callback_(other.callback_) { - other.callback_.f = nullptr; - other.callback_.param = nullptr; - } - - ScopedOrtCallbackInvoker& operator=(ScopedOrtCallbackInvoker&& other) noexcept { - if (callback_.f) { - callback_.f(callback_.param); - } - - callback_ = other.callback_; - other.callback_.f = nullptr; - other.callback_.param = nullptr; - - return *this; - } - - ~ScopedOrtCallbackInvoker() noexcept { - if (callback_.f) { - callback_.f(callback_.param); - } - } - - private: - ORT_DISALLOW_COPY_AND_ASSIGNMENT(ScopedOrtCallbackInvoker); - OrtCallback callback_; -}; -} // namespace onnxruntime diff --git a/src/ort_include/core/platform/EigenNonBlockingThreadPool.h b/src/ort_include/core/platform/EigenNonBlockingThreadPool.h index c313944..45b1751 100644 --- a/src/ort_include/core/platform/EigenNonBlockingThreadPool.h +++ b/src/ort_include/core/platform/EigenNonBlockingThreadPool.h @@ -199,6 +199,100 @@ struct PaddingToAvoidFalseSharing { char padding[ORT_FALSE_SHARING_BYTES]; }; +/* 
Usage: +1. In executor, call Start() before profiling and Stop() to get profiled numbers; +2. Inside thread pool, call LogStart() before interested section and LogEnd... after to log elapsed time; +3. To extend, just add more events in enum Event before "All", and update GetEventName(...) accordingly; +4. Note LogStart must pair with either LogEnd or LogEndAndStart, otherwise ORT_ENFORCE will fail; +5. ThreadPoolProfiler is thread-safe. +*/ +#ifdef ORT_MINIMAL_BUILD +class ThreadPoolProfiler { + public: + enum ThreadPoolEvent { + DISTRIBUTION = 0, + DISTRIBUTION_ENQUEUE, + RUN, + WAIT, + WAIT_REVOKE, + MAX_EVENT + }; + ThreadPoolProfiler(int, const CHAR_TYPE*) {} + ~ThreadPoolProfiler() = default; + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ThreadPoolProfiler); + void Start() {} + std::string Stop() { return "not available for minimal build"; } + void LogStart() {} + void LogEnd(ThreadPoolEvent) {} + void LogEndAndStart(ThreadPoolEvent) {} + void LogStartAndCoreAndBlock(std::ptrdiff_t) {} + void LogCoreAndBlock(std::ptrdiff_t) {} + void LogThreadId(int) {} + void LogRun(int) {} + std::string DumpChildThreadStat() { return {}; } +}; +#else +class ThreadPoolProfiler { + public: + enum ThreadPoolEvent { + DISTRIBUTION = 0, + DISTRIBUTION_ENQUEUE, + RUN, + WAIT, + WAIT_REVOKE, + MAX_EVENT + }; + ThreadPoolProfiler(int num_threads, const CHAR_TYPE* threal_pool_name); + ~ThreadPoolProfiler(); + ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ThreadPoolProfiler); + using Clock = std::chrono::high_resolution_clock; + void Start(); // called by executor to start profiling + std::string Stop(); // called by executor to stop profiling and return collected numbers + void LogStart(); // called in main thread to record the starting time point + void LogEnd(ThreadPoolEvent); // called in main thread to calculate and save the time elapsed from last start point + void LogEndAndStart(ThreadPoolEvent); + void LogStartAndCoreAndBlock(std::ptrdiff_t block_size); + void LogCoreAndBlock(std::ptrdiff_t 
block_size); // called in main thread to log core and block size for task breakdown + void LogThreadId(int thread_idx); // called in child thread to log its id + void LogRun(int thread_idx); // called in child thread to log num of run + std::string DumpChildThreadStat(); // return all child statistics collected so far + + private: + static const char* GetEventName(ThreadPoolEvent); + struct MainThreadStat { + uint64_t events_[MAX_EVENT] = {}; + int32_t core_ = -1; + std::vector blocks_; // block size determined by cost model + std::vector points_; + void LogCore(); + void LogBlockSize(std::ptrdiff_t block_size); + void LogStart(); + void LogEnd(ThreadPoolEvent); + void LogEndAndStart(ThreadPoolEvent); + std::string Reset(); + }; + bool enabled_ = false; + MainThreadStat& GetMainThreadStat(); // return thread local stat + int num_threads_; +#ifdef _MSC_VER +#pragma warning(push) + // C4324: structure was padded due to alignment specifier +#pragma warning(disable : 4324) +#endif // _MSC_VER + struct ORT_ALIGN_TO_AVOID_FALSE_SHARING ChildThreadStat { + std::thread::id thread_id_; + uint64_t num_run_ = 0; + onnxruntime::TimePoint last_logged_point_ = Clock::now(); + int32_t core_ = -1; // core that the child thread is running on + }; +#ifdef _MSC_VER +#pragma warning(pop) +#endif // _MSC_VER + std::vector child_thread_stats_; + std::string thread_pool_name_; +}; +#endif + // Extended Eigen thread pool interface, avoiding the need to modify // the ThreadPoolInterface.h header from the external Eigen // repository. @@ -241,6 +335,8 @@ class ExtendedThreadPoolInterface : public Eigen::ThreadPoolInterface { // two loops execute in series in a parallel section. 
] virtual void RunInParallel(std::function fn, unsigned n, std::ptrdiff_t block_size) = 0; + virtual void StartProfiling() = 0; + virtual std::string StopProfiling() = 0; }; class ThreadPoolParallelSection { @@ -609,6 +705,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter return 0; } + ThreadPoolProfiler profiler_; void SignalAllAndWait() { done_ = true; @@ -623,7 +720,13 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter } public: + void StartProfiling() override { + profiler_.Start(); + } + std::string StopProfiling() override { + return profiler_.Stop(); + } struct Tag { constexpr Tag() : v_(0) { @@ -664,7 +767,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter ThreadPoolTempl(const CHAR_TYPE* name, int num_threads, bool allow_spinning, Environment& env, const ThreadOptions& thread_options) - : + : profiler_(num_threads, name), env_(env), num_threads_(num_threads), allow_spinning_(allow_spinning), @@ -812,6 +915,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter // tasks that were created (if any) for the parallel section. We // revoke tasks still in queues, and then wait for any that are // still running. + profiler_.LogStart(); unsigned tasks_started = static_cast(ps.tasks.size()); while (!ps.tasks.empty()) { const auto& item = ps.tasks.back(); @@ -821,6 +925,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter } ps.tasks.pop_back(); } + profiler_.LogEnd(ThreadPoolProfiler::WAIT_REVOKE); // Wait for the dispatch task's own work... 
if (ps.dispatch_q_idx > -1) { @@ -1099,6 +1204,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter ps.work_done.store(true, std::memory_order_release); }; + profiler_.LogStart(); ps.dispatch_q_idx = preferred_workers[current_dop] % num_threads_; WorkerData& dispatch_td = worker_data_[ps.dispatch_q_idx]; Queue& dispatch_que = dispatch_td.queue; @@ -1116,6 +1222,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter } else { ps.dispatch_q_idx = -1; // failed to enqueue dispatch_task } + profiler_.LogEnd(ThreadPoolProfiler::DISTRIBUTION_ENQUEUE); } else { // Synchronous dispatch ScheduleOnPreferredWorkers(pt, ps, preferred_workers, current_dop, new_dop, std::move(worker_fn)); @@ -1133,6 +1240,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter unsigned n, std::ptrdiff_t block_size) override { ORT_ENFORCE(n <= num_threads_ + 1, "More work items than threads"); + profiler_.LogStartAndCoreAndBlock(block_size); PerThread* pt = GetPerThread(); assert(pt->leading_par_section && "RunInParallel, but not in parallel section"); assert((n > 1) && "Trivial parallel section; should be avoided by caller"); @@ -1162,15 +1270,18 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter }; RunInParallelInternal(*pt, ps, n, false, std::move(worker_fn)); assert(ps.dispatch_q_idx == -1); + profiler_.LogEndAndStart(ThreadPoolProfiler::DISTRIBUTION); // Run work in the main thread loop.fn(0); + profiler_.LogEndAndStart(ThreadPoolProfiler::RUN); // Wait for workers to exit the loop ps.current_loop = 0; while (ps.workers_in_loop) { onnxruntime::concurrency::SpinPause(); } + profiler_.LogEnd(ThreadPoolProfiler::WAIT); } // Run a single parallel loop _without_ a parallel section. This is a @@ -1187,12 +1298,16 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter // 1. 
run fn(...); void RunInParallel(std::function fn, unsigned n, std::ptrdiff_t block_size) override { ORT_ENFORCE(n <= num_threads_ + 1, "More work items than threads"); + profiler_.LogStartAndCoreAndBlock(block_size); PerThread* pt = GetPerThread(); ThreadPoolParallelSection ps; StartParallelSectionInternal(*pt, ps); RunInParallelInternal(*pt, ps, n, true, fn); // select dispatcher and do job distribution; + profiler_.LogEndAndStart(ThreadPoolProfiler::DISTRIBUTION); fn(0); // run fn(0) + profiler_.LogEndAndStart(ThreadPoolProfiler::RUN); EndParallelSectionInternal(*pt, ps); // wait for all + profiler_.LogEnd(ThreadPoolProfiler::WAIT); } int NumThreads() const final { @@ -1424,6 +1539,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter const int steal_count = spin_count / 100; SetDenormalAsZero(set_denormal_as_zero_); + profiler_.LogThreadId(thread_id); while (!should_exit) { Task t = q.PopFront(); @@ -1516,6 +1632,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter if (t) { td.SetActive(); t(); + profiler_.LogRun(thread_id); td.SetSpinning(); } } diff --git a/src/ort_include/core/platform/env.h b/src/ort_include/core/platform/env.h index ff3d78d..c73c64f 100644 --- a/src/ort_include/core/platform/env.h +++ b/src/ort_include/core/platform/env.h @@ -26,7 +26,6 @@ limitations under the License. 
#include "core/common/common.h" #include "core/common/path_string.h" -#include "core/framework/callback.h" #include "core/session/onnxruntime_c_api.h" #ifndef _WIN32 diff --git a/src/ort_include/core/platform/env_var_utils.h b/src/ort_include/core/platform/env_var_utils.h index b7cb6ea..63a2fed 100644 --- a/src/ort_include/core/platform/env_var_utils.h +++ b/src/ort_include/core/platform/env_var_utils.h @@ -4,7 +4,7 @@ #pragma once #include -#include + #include "core/common/common.h" #ifndef SHARED_PROVIDER #include "core/common/logging/logging.h" @@ -83,7 +83,7 @@ std::optional ParseTestOnlyEnvironmentVariable(const std::string& name, std::string default_hint = "End users should opt for provider options or session options."; const std::string& logged_hint = hint.empty() ? default_hint : hint; - std::cout << "Environment variable " << name << " is used. It is reserved for internal testing purpose. " + LOGS_DEFAULT(WARNING) << "Environment variable " << name << " is used. It is reserved for internal testing purpose. " << logged_hint; return env; diff --git a/src/ort_include/core/platform/threadpool.h b/src/ort_include/core/platform/threadpool.h index ad5c7a1..04df6dc 100644 --- a/src/ort_include/core/platform/threadpool.h +++ b/src/ort_include/core/platform/threadpool.h @@ -360,7 +360,11 @@ class ThreadPool { // working in combination with the thread initiating the loop. 
static int DegreeOfParallelism(const ThreadPool* tp); - ORT_DISALLOW_COPY_AND_ASSIGNMENT(ThreadPool); + ORT_DISALLOW_COPY_AND_ASSIGNMENT(ThreadPool); + + // StartProfiling and StopProfiling are not to be consumed as public-facing API + static void StartProfiling(concurrency::ThreadPool* tp); + static std::string StopProfiling(concurrency::ThreadPool* tp); private: friend class LoopCounter; @@ -407,6 +411,10 @@ class ThreadPool { void Schedule(std::function fn); + void StartProfiling(); + + std::string StopProfiling(); + ThreadOptions thread_options_; // If a thread pool is created with degree_of_parallelism != 1 then an underlying diff --git a/src/ort_include/core/session/onnxruntime_c_api.h b/src/ort_include/core/session/onnxruntime_c_api.h index a3ae440..40fab45 100644 --- a/src/ort_include/core/session/onnxruntime_c_api.h +++ b/src/ort_include/core/session/onnxruntime_c_api.h @@ -1,106 +1,34 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -// See docs\c_cxx\README.md on generating the Doxygen documentation from this file - -/** \mainpage ONNX Runtime - * - * ONNX Runtime is a high-performance inference and training graph execution engine for deep learning models. - * - * ONNX Runtime's C, C++ APIs offer an easy to use interface to onboard and execute onnx models. - * - \subpage c_cpp_api "Core C, C++ APIs" - * - \subpage training_c_cpp_api "Training C, C++ APIs for on-device training" - * - * \page c_cpp_api Core C, C++ APIs - *

C

- * - * ::OrtApi - Click here to go to the structure with all C API functions. - * - *

C++

- * - * ::Ort - Click here to go to the namespace holding all of the C++ wrapper classes - * - * It is a set of header only wrapper classes around the C API. The goal is to turn the C style return value error codes into C++ exceptions, and to - * automate memory management through standard C++ RAII principles. - * - * \addtogroup Global - * ONNX Runtime C API - * @{ - */ - -#pragma once -#include -#include -#include -#include - - - -//! @} -// SAL2 Definitions -#ifndef _MSC_VER -#define _In_ -#define _In_z_ -#define _In_opt_ -#define _In_opt_z_ -#define _Out_ -#define _Outptr_ -#define _Out_opt_ -#define _Inout_ -#define _Inout_opt_ -#define _Frees_ptr_opt_ -#define _Ret_maybenull_ -#define _Ret_notnull_ -#define _Check_return_ -#define _Outptr_result_maybenull_ -#define _In_reads_(X) -#define _Inout_updates_(X) -#define _Out_writes_(X) -#define _Inout_updates_all_(X) -#define _Out_writes_bytes_all_(X) -#define _Out_writes_all_(X) -#define _Success_(X) -#define _Outptr_result_buffer_maybenull_(X) -#define ORT_ALL_ARGS_NONNULL __attribute__((nonnull)) -#else -#include -#define ORT_ALL_ARGS_NONNULL -#endif - -#ifdef _WIN32 -// Define ORT_DLL_IMPORT if your program is dynamically linked to Ort. -// dllexport is not used, we use a .def file. -#ifdef ORT_DLL_IMPORT -#define ORT_EXPORT __declspec(dllimport) -#else -#define ORT_EXPORT -#endif -#define ORT_API_CALL _stdcall -#define ORT_MUST_USE_RESULT -#define ORTCHAR_T wchar_t -#else -// To make symbols visible on macOS/iOS -#ifdef __APPLE__ -#define ORT_EXPORT __attribute__((visibility("default"))) -#else -#define ORT_EXPORT -#endif -#define ORT_API_CALL -#define ORT_MUST_USE_RESULT __attribute__((warn_unused_result)) -#define ORTCHAR_T char -#endif - -/// ORTCHAR_T, ORT_TSTR are reserved specifically for path handling. -/// All other strings are UTF-8 encoded, use char and std::string -#ifndef ORT_TSTR -#ifdef _WIN32 -#define ORT_TSTR(X) L##X -// When X is a macro, L##X is not defined. 
In this case, we need to use ORT_TSTR_ON_MACRO. -#define ORT_TSTR_ON_MACRO(X) L"" X -#else -#define ORT_TSTR(X) X -#define ORT_TSTR_ON_MACRO(X) X -#endif -#endif - -/// @} +#pragma once + +#ifndef _WIN32 +#define _In_ +#define _In_z_ +#define _In_opt_ +#define _In_opt_z_ +#define _Out_ +#define _Outptr_ +#define _Out_opt_ +#define _Inout_ +#define _Inout_opt_ +#define _Frees_ptr_opt_ +#define _Ret_maybenull_ +#define _Ret_notnull_ +#define _Check_return_ +#define _Outptr_result_maybenull_ +#define _In_reads_(X) +#define _Inout_updates_(X) +#define _Out_writes_(X) +#define _Inout_updates_all_(X) +#define _Out_writes_bytes_all_(X) +#define _Out_writes_all_(X) +#define _Success_(X) +#define _Outptr_result_buffer_maybenull_(X) +#else +#include +#endif + +#ifdef _WIN32 +#define ORTCHAR_T wchar_t +#else +#define ORTCHAR_T char +#endif \ No newline at end of file diff --git a/src/ort_include/core/util/thread_utils.h b/src/ort_include/core/util/thread_utils.h index b5e2516..c25f789 100644 --- a/src/ort_include/core/util/thread_utils.h +++ b/src/ort_include/core/util/thread_utils.h @@ -7,8 +7,7 @@ #include #include -struct OrtThreadPoolParams -{ +struct OrtThreadPoolParams { // 0: Use default setting. (All the physical cores or half of the logical cores) // 1: Don't create thread pool // n: Create a thread pool with n threads. @@ -20,7 +19,13 @@ struct OrtThreadPoolParams bool auto_set_affinity = false; // If it is true, the thread pool will spin a while after the queue became empty. +#if !defined(ORT_CLIENT_PACKAGE_BUILD) bool allow_spinning = true; +#else + // default allow_spinning to false for ORT builds targeting client/on-device workloads, + // to reduce CPU utilization and improve power efficiency. 
+ bool allow_spinning = false; +#endif // It it is non-negative, thread pool will split a task by a decreasing block size // of remaining_of_total_iterations / (num_of_threads * dynamic_block_base_) @@ -38,16 +43,15 @@ struct OrtThreadPoolParams // meaning ith thread will be attached to first 8 logical processors std::string affinity_str; - const ORTCHAR_T *name = nullptr; + const ORTCHAR_T* name = nullptr; // Set or unset denormal as zero bool set_denormal_as_zero = false; }; -std::ostream &operator<<(std::ostream &os, const OrtThreadPoolParams ¶ms); +std::ostream& operator<<(std::ostream& os, const OrtThreadPoolParams& params); -struct OrtThreadingOptions -{ +struct OrtThreadingOptions { // Params for creating the threads that parallelizes execution of an op OrtThreadPoolParams intra_op_thread_pool_params; @@ -55,17 +59,14 @@ struct OrtThreadingOptions OrtThreadPoolParams inter_op_thread_pool_params; }; -namespace onnxruntime -{ +namespace onnxruntime { - namespace concurrency - { - enum class ThreadPoolType : uint8_t - { - INTRA_OP, - INTER_OP - }; - std::unique_ptr CreateThreadPool(Env *env, OrtThreadPoolParams options, - ThreadPoolType tpool_type); - } // namespace concurrency -} // namespace onnxruntime +namespace concurrency { +enum class ThreadPoolType : uint8_t { + INTRA_OP, + INTER_OP +}; +std::unique_ptr CreateThreadPool(Env* env, OrtThreadPoolParams options, + ThreadPoolType tpool_type); +} // namespace concurrency +} // namespace onnxruntime diff --git a/tests/bench/bench_cast.cpp b/tests/bench/bench_cast.cpp index 1dccbe4..e323346 100644 --- a/tests/bench/bench_cast.cpp +++ b/tests/bench/bench_cast.cpp @@ -1,5 +1,5 @@ #include "bench_util.h" -#include "core/mlas/lib/mlasi.h" +#include "mlasi.h" #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) diff --git a/tests/bench/bench_computesoftmax.cpp b/tests/bench/bench_computesoftmax.cpp index 32135b3..bf30db4 100644 --- a/tests/bench/bench_computesoftmax.cpp +++ 
b/tests/bench/bench_computesoftmax.cpp @@ -58,10 +58,10 @@ void COMPUTESOFTMAXINPLACE(benchmark::State& state) { std::copy(data.begin(), data.end(), input); // Copy the data to the aligned memory // warming up run - MlasComputeSoftmax(input, output, N, D, false, false, tp.get()); + MlasComputeSoftmax(input, output, N, D, false, false, 0.0f, tp.get()); for (auto _ : state) { - MlasComputeSoftmax(input, output, N, D, false, false, tp.get()); + MlasComputeSoftmax(input, output, N, D, false, false, 0.0f, tp.get()); } free(ptr.underlying_buffer); diff --git a/tests/bench/bench_fp16_neon_common.cpp b/tests/bench/bench_fp16_neon_common.cpp index 1dccbe4..e323346 100644 --- a/tests/bench/bench_fp16_neon_common.cpp +++ b/tests/bench/bench_fp16_neon_common.cpp @@ -1,5 +1,5 @@ #include "bench_util.h" -#include "core/mlas/lib/mlasi.h" +#include "mlasi.h" #if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) && defined(MLAS_TARGET_ARM64) diff --git a/tests/bench/bench_rope.cpp b/tests/bench/bench_rope.cpp index b0630b9..216ee79 100644 --- a/tests/bench/bench_rope.cpp +++ b/tests/bench/bench_rope.cpp @@ -4,7 +4,7 @@ #include "mlas.h" #include "benchmark/benchmark.h" #include "bench_util.h" -#include "core/framework/float16.h" +#include "core/common/float16.h" using namespace onnxruntime; diff --git a/tests/bench/bench_sconv.cpp b/tests/bench/bench_sconv.cpp index 39d1352..dc37980 100644 --- a/tests/bench/bench_sconv.cpp +++ b/tests/bench/bench_sconv.cpp @@ -3,6 +3,7 @@ #include "mlas.h" #include "bench_util.h" +#include "core/util/thread_utils.h" #include #include @@ -138,6 +139,113 @@ void SCONV_NCHW(benchmark::State& state, const char* /*dummy*/) { } } +static MLAS_THREADPOOL* GetMlasThreadPoolForConvBenchmark(void) { + static auto threadpool = std::make_unique( + &onnxruntime::Env::Default(), onnxruntime::ThreadOptions(), nullptr, 4, true); + return threadpool.get(); +} + +void SCONV_NCHW_THREADED(benchmark::State& state, const char* /*dummy*/) { + MLAS_THREADPOOL* tp = 
GetMlasThreadPoolForConvBenchmark(); + + const int64_t rank = state.range(0); // Rank + const int64_t batch_size = state.range(1); // N + const int64_t groups = state.range(2); // G + const int64_t input_channels_per_group = state.range(3); // Cpg + const int64_t output_channels_per_group = state.range(4); // Fpg + + if (rank <= 0) throw std::invalid_argument("Kernel rank must greater than 0!"); + if (batch_size <= 0) throw std::invalid_argument("Batch size must greater than 0!"); + if (groups <= 0) throw std::invalid_argument("Group count must greater than 0!"); + if (input_channels_per_group <= 0) throw std::invalid_argument("input_channels_per_group must greater than 0!"); + if (output_channels_per_group <= 0) throw std::invalid_argument("output_channels_per_group must greater than 0!"); + + size_t arg_position = 5; + const auto input_shape = BenchArgsVector(state, arg_position, rank); + const auto kernel_shape = BenchArgsVector(state, arg_position, rank); + const auto paddings = BenchArgsVector(state, arg_position, rank * 2); + const auto strides = BenchArgsVector(state, arg_position, rank); + const auto dilations = BenchArgsVector(state, arg_position, rank); + + // do not check the size of each vector as they are forced from args. 
+ if (std::any_of(input_shape.begin(), input_shape.end(), [](const int64_t& dim) { return dim <= 0; })) { + throw std::invalid_argument("all input image dim must > 0"); + } + + if (std::any_of(kernel_shape.begin(), kernel_shape.end(), [](const int64_t& dim) { return dim <= 0; })) { + throw std::invalid_argument("all kernel dim must > 0"); + } + + if (std::any_of(strides.begin(), strides.end(), [](const int64_t& dim) { return dim <= 0; })) { + throw std::invalid_argument("all strides dim must > 0"); + } + + if (std::any_of(dilations.begin(), dilations.end(), [](const int64_t& dim) { return dim <= 0; })) { + throw std::invalid_argument("all dilations dim must > 0"); + } + + const int64_t GC = groups * input_channels_per_group; + const int64_t GF = groups * output_channels_per_group; + std::vector x_shape = {batch_size, GC}; + x_shape.insert(x_shape.end(), input_shape.begin(), input_shape.end()); + std::vector f_shape = {GF, input_channels_per_group}; + f_shape.insert(f_shape.end(), kernel_shape.begin(), kernel_shape.end()); + + std::vector output_shape((size_t)rank); + for (int64_t i = 0; i < rank; ++i) { + auto km = 1 + dilations[i] * (kernel_shape[i] - 1); + output_shape[i] = (paddings[i] + paddings[i + rank] + input_shape[i] - km) / strides[i] + 1; + } + std::vector y_shape = {batch_size, GF}; + y_shape.insert(y_shape.end(), output_shape.begin(), output_shape.end()); + + MLAS_ACTIVATION activation; + activation.ActivationKind = MlasIdentityActivation; + MLAS_CONV_PARAMETERS Parameters; + size_t WorkingBufferSize = 0; + MlasConvPrepare(&Parameters, + static_cast(rank), + static_cast(batch_size), + static_cast(groups), + static_cast(input_channels_per_group), + input_shape.data(), + kernel_shape.data(), + dilations.data(), + paddings.data(), + strides.data(), + output_shape.data(), + static_cast(output_channels_per_group), + &activation, + &WorkingBufferSize, + 0.0f, + tp); + + auto X = RandomVectorUniform(x_shape, -2.0, 2.0); + auto F = RandomVectorUniform(f_shape, 
-1.0, 1.0); + int64_t y_size = std::accumulate(y_shape.begin(), y_shape.end(), 1LL, std::multiplies()); + std::vector Y(static_cast(y_size)); + std::vector working_buffer(WorkingBufferSize); + + // warm up first round. + MlasConv(&Parameters, + X.data(), + F.data(), + nullptr, + working_buffer.data(), + Y.data(), + tp); + + for (auto _ : state) { + MlasConv(&Parameters, + X.data(), + F.data(), + nullptr, + working_buffer.data(), + Y.data(), + tp); + } +} + static void ResNet50(benchmark::internal::Benchmark* b) { b->ArgNames(ArgNamesForConv(2)); @@ -221,6 +329,7 @@ static void TeamsModel(benchmark::internal::Benchmark* b) { } BENCHMARK_CAPTURE(SCONV_NCHW, TeamsModel, "")->Apply(TeamsModel)->UseRealTime(); +BENCHMARK_CAPTURE(SCONV_NCHW_THREADED, TeamsModel, "")->Apply(TeamsModel)->UseRealTime(); static void General_Conv2d(benchmark::internal::Benchmark* b) { b->ArgNames(ArgNamesForConv(2)); diff --git a/tests/bench/bench_sgemm.cpp b/tests/bench/bench_sgemm.cpp index a94d33c..422fc6f 100644 --- a/tests/bench/bench_sgemm.cpp +++ b/tests/bench/bench_sgemm.cpp @@ -30,9 +30,12 @@ void SGEMM(benchmark::State& state, bool pack_b, bool trans_a, bool trans_b, flo tpo, onnxruntime::concurrency::ThreadPoolType::INTRA_OP)); if (pack_b) { - size_t pack_b_size = MlasGemmPackBSize(N, K); + CBLAS_TRANSPOSE transB_enum = trans_b ? CblasTrans : CblasNoTrans; + CBLAS_TRANSPOSE transA_enum = trans_a ? CblasTrans : CblasNoTrans; + + size_t pack_b_size = MlasGemmPackBSize(transA_enum, transB_enum, N, K); std::vector B_packed(pack_b_size); - MlasGemmPackB(CblasNoTrans, N, K, B.data(), N, B_packed.data()); + MlasGemmPackB(transA_enum, transB_enum, N, K, B.data(), N, B_packed.data()); MlasGemm( trans_a ? 
CblasTrans : CblasNoTrans, diff --git a/tests/bench/bench_util.h b/tests/bench/bench_util.h index e3abd7b..e6eda2b 100644 --- a/tests/bench/bench_util.h +++ b/tests/bench/bench_util.h @@ -8,7 +8,7 @@ #include #include -#include "core/framework/float16.h" +#include "core/common/float16.h" #include "mlas.h" template diff --git a/tests/unittest/test_conv2d.h b/tests/unittest/test_conv2d.h index 4382ae4..20bf0ec 100644 --- a/tests/unittest/test_conv2d.h +++ b/tests/unittest/test_conv2d.h @@ -245,22 +245,19 @@ class MlasConv2DTest : public MlasTestBase { Filter, Bias, OutputReference); - static constexpr float rtol = 1e-4f; - static constexpr float atol = 1e-6f; - for (size_t i = 0; i != OutputElements; ++i) { - float tolerance = atol + rtol * std::abs(OutputReference[i]); - ASSERT_NEAR(Output[i], OutputReference[i], tolerance) << "B" << BatchCount << "/" - << "G" << GroupCount << "/" - << "Cpg" << InputChannels << "/" - << "Fpg" << FilterCount << "/" - << "H" << InputHeight << "/" - << "W" << InputWidth << "/" - << "KH" << KernelHeight << "/" - << "KW" << KernelWidth << "/" - << "Pad" << PaddingLeftHeight << "," << PaddingLeftWidth << "," << PaddingRightHeight << "," << PaddingRightWidth << "/" - << "Dilation" << DilationHeight << "," << DilationWidth << "/" - << "Stride" << StrideHeight << "," << StrideWidth; - } + + ASSERT_EQ(memcmp(Output, OutputReference, OutputElements * sizeof(float)), 0) + << "B" << BatchCount << "/" + << "G" << GroupCount << "/" + << "Cpg" << InputChannels << "/" + << "Fpg" << FilterCount << "/" + << "H" << InputHeight << "/" + << "W" << InputWidth << "/" + << "KH" << KernelHeight << "/" + << "KW" << KernelWidth << "/" + << "Pad" << PaddingLeftHeight << "," << PaddingLeftWidth << "," << PaddingRightHeight << "," << PaddingRightWidth << "/" + << "Dilation" << DilationHeight << "," << DilationWidth << "/" + << "Stride" << StrideHeight << "," << StrideWidth; } void ExecuteLong(void) override { diff --git a/tests/unittest/test_dequantizelinear.cpp 
b/tests/unittest/test_dequantizelinear.cpp new file mode 100644 index 0000000..b994981 --- /dev/null +++ b/tests/unittest/test_dequantizelinear.cpp @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test_util.h" + +template +class MlasDequantizeLinearTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferInput; + MatrixGuardBuffer BufferOutput; + MatrixGuardBuffer BufferOutputReference; + + void GenerateReference(const QuantInt* Input, float* OutputReference, size_t N, float Scale, QuantInt ZeroPoint) { + int32_t ZeroPointS32 = static_cast(ZeroPoint); + + for (size_t n = 0; n < N; n++) { + OutputReference[n] = static_cast(static_cast(Input[n]) - ZeroPointS32) * Scale; + } + } + + void Test(size_t N) { + QuantInt* Input = BufferInput.GetBuffer(N); + float* Output = BufferOutput.GetBuffer(N); + float* OutputReference = BufferOutputReference.GetBuffer(N); + + std::default_random_engine generator(static_cast(N)); + + std::uniform_real_distribution min_gen(-10.f, -10e-3f); + float MinimumValue = min_gen(generator); + + std::uniform_real_distribution max_gen(10e-3f, 10.f); + float MaximumValue = max_gen(generator); + + float Scale = (MaximumValue - MinimumValue) / 512.f; + + std::uniform_int_distribution zp_distribution(std::numeric_limits::min(), + std::numeric_limits::max()); + QuantInt ZeroPoint = static_cast(zp_distribution(generator)); + + for (size_t n = 0; n < N; n++) { + Input[n] = static_cast(zp_distribution(generator)); + } + + GenerateReference(Input, OutputReference, N, Scale, ZeroPoint); + MlasDequantizeLinear(Input, Output, N, Scale, ZeroPoint); + + for (size_t n = 0; n < N; n++) { + ASSERT_EQ(Output[n], OutputReference[n]) << ", size=" << N << ", index=" << n; + } + } + + public: + static const char* GetTestSuiteName() { + if constexpr (std::is_same_v) { + return "DequantizeLinearS8"; + } else { + return "DequantizeLinearU8"; + } + } + + void ExecuteShort(void) override 
{ + for (size_t n = 1; n <= 512; n++) { + Test(n); + } + } +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + size_t count = 0; + if (is_short_execute) { + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + count += MlasDirectShortExecuteTests>::RegisterShortExecute(); + } + return count; +}); diff --git a/tests/unittest/test_dynamic_qgemm.cpp b/tests/unittest/test_dynamic_qgemm.cpp new file mode 100644 index 0000000..2a5d72b --- /dev/null +++ b/tests/unittest/test_dynamic_qgemm.cpp @@ -0,0 +1,172 @@ +// +// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: MIT +// + +// Currently this test only applies to KleidiAI Guard against it running in any other situation +#if defined(USE_KLEIDIAI) && !defined(_MSC_VER) + +#include "test_util.h" +#include "mlasi.h" // for MLAS_CPUIDINFO + +class MlasDynamicQgemmTest { + private: + MatrixGuardBuffer buffer_a; + MatrixGuardBuffer buffer_bf; + MatrixGuardBuffer buffer_bq; + MatrixGuardBuffer buffer_c; + MatrixGuardBuffer buffer_c_ref; + + public: + void Test(size_t M, size_t N, size_t K, size_t BatchSize) { + // Currently, MlasDynamicQGemmBatch() and associated functions require SME or else they are no-ops. + if (!MLAS_CPUIDINFO::GetCPUIDInfo().HasArm_SME()) { + GTEST_SKIP() << "MlasDynamicQGemmBatch() requires ARM64 SME but it was not detected. 
Skipping test."; + } + + // Setup buffers for holding various data + + float* A = buffer_a.GetBuffer(M * K * BatchSize); + // Buffer for holding floating point version of weight matrix + float* Bf = buffer_bf.GetBuffer(K * N * BatchSize); + // Buffer for holding quantized version of weight matrix + int8_t* Bq = buffer_bq.GetBuffer(K * N * BatchSize); + float* C = buffer_c.GetBuffer(M * N * BatchSize); + float* CRef = buffer_c_ref.GetBuffer(M * N * BatchSize); + + // Initialize A and Bf + for (size_t i = 0; i < M * K * BatchSize; ++i) + A[i] = static_cast((rand() % 255 - 128) / 16.0f); + for (size_t i = 0; i < K * N * BatchSize; ++i) + Bf[i] = static_cast((rand() % 255 - 128) / 16.0f); + + // Quantize Bf → Bq and compute per-column scale and bias per batch + std::vector> b_scale_batches(BatchSize, std::vector(N)); + std::vector> b_bias_batches(BatchSize, std::vector(N, 0.0f)); + + for (size_t b = 0; b < BatchSize; ++b) { + for (size_t n = 0; n < N; ++n) { + float min_val = Bf[b * K * N + n]; + float max_val = min_val; + for (size_t k = 1; k < K; ++k) { + float v = Bf[b * K * N + k * N + n]; + min_val = std::min(min_val, v); + max_val = std::max(max_val, v); + } + float scale = (max_val - min_val) / 255.0f; + if (scale < 1e-8f) scale = 1.0f; + b_scale_batches[b][n] = scale; + + for (size_t k = 0; k < K; ++k) { + float v = Bf[b * K * N + k * N + n]; + int q = static_cast(std::round(v / scale)); + Bq[b * K * N + k * N + n] = static_cast(std::clamp(q, -128, 127)); + } + } + } + + // Prepare kernel parameters + MLAS_GEMM_DYN_QUANT_SHAPE_PARAMS shape{M, N, K}; + std::vector packed_b_storage(BatchSize * MlasDynamicQgemmPackBSize(N, K)); + std::vector params(BatchSize); + + for (size_t b = 0; b < BatchSize; ++b) { + params[b].A = A + b * M * K; + params[b].lda = K; + params[b].C = C + b * M * N; + params[b].ldc = N; + // Pack b matrix using MlasDynamicQgemmPackBSize & MlasDynamicQgemmPackB + void* packed_b = packed_b_storage.data() + b * MlasDynamicQgemmPackBSize(N, K); + 
MlasDynamicQgemmPackB(N, K, + Bq + b * K * N, + b_scale_batches[b].data(), + b_bias_batches[b].data(), + packed_b); + params[b].PackedB = packed_b; + } + + // call MlasDynamicQGemmBatch Function + MlasDynamicQGemmBatch(shape, params.data(), BatchSize, nullptr); + + // Compute reference result + for (size_t b = 0; b < BatchSize; ++b) { + for (size_t m = 0; m < M; ++m) { + for (size_t n = 0; n < N; ++n) { + float sum = 0.0f; + for (size_t k = 0; k < K; ++k) { + float a = A[b * M * K + m * K + k]; + float bval = static_cast(Bq[b * K * N + k * N + n]) * b_scale_batches[b][n]; + sum += a * bval; + } + CRef[b * M * N + m * N + n] = sum; + } + } + } + + // Validate results + for (size_t i = 0; i < M * N * BatchSize; ++i) { + float abs_c_ref = std::abs(CRef[i]); + float dynamic_rel_tol = (K <= 4) ? 0.05f : 0.03f; + float rel_tol = dynamic_rel_tol * std::max(abs_c_ref, 1.0f); + float abs_tol = 3.0f; + float allowed = std::max(rel_tol, abs_tol); + float diff = std::abs(C[i] - CRef[i]); + ASSERT_LE(diff, allowed); + } + } + + static const char* GetTestSuiteName() { + return "DynamicQgemm"; + } +}; + +class DynamicQgemmExecuteTest : public MlasTestFixture { + public: + DynamicQgemmExecuteTest(size_t M, size_t N, size_t K, size_t BatchSize) + : M_(M), N_(N), K_(K), BatchSize_(BatchSize) {} + + void TestBody() override { + this->mlas_tester->Test(M_, N_, K_, BatchSize_); + } + static size_t RegisterSingleTest(size_t M, size_t N, size_t K, size_t BatchSize) { + std::stringstream ss; + ss << "M" << M << "_N" << N << "_K" << K << "_B" << BatchSize; + + std::string test_name = ss.str(); + + testing::RegisterTest( + MlasDynamicQgemmTest::GetTestSuiteName(), + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + [=]() -> MlasTestFixture* { + return new DynamicQgemmExecuteTest(M, N, K, BatchSize); + }); + + return 1; + } + + static size_t RegisterAll(bool is_short_execute) { + const std::vector batch_size = is_short_execute ? 
std::vector{1UL, 2UL, 4UL} + : std::vector{1UL, 2UL, 4UL, 8UL, 16UL, 32UL, 64UL}; + size_t count = 0; + const size_t sizes[] = {1, 4, 8, 16, 32, 64}; + for (size_t M : sizes) + for (size_t N : sizes) + for (size_t K : sizes) + for (size_t B : batch_size) + count += RegisterSingleTest(M, N, K, B); + return count; + } + + private: + size_t M_, N_, K_, BatchSize_; +}; + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + return DynamicQgemmExecuteTest::RegisterAll(is_short_execute); +}); +#endif diff --git a/tests/unittest/test_fgemm.h b/tests/unittest/test_fgemm.h index 347e9c6..6d70151 100644 --- a/tests/unittest/test_fgemm.h +++ b/tests/unittest/test_fgemm.h @@ -56,7 +56,7 @@ class FgemmPackedContext { } }; -#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_POWER) +#if defined(MLAS_TARGET_AMD64) || defined(MLAS_TARGET_POWER) || defined(MLAS_TARGET_S390X) template <> class FgemmPackedContext { public: @@ -112,11 +112,11 @@ class FgemmPackedContext { float* C, size_t ldc, MLAS_THREADPOOL* threadpool) { - size_t PackedBSize = MlasGemmPackBSize(N, K); + size_t PackedBSize = MlasGemmPackBSize(TransA, TransB, N, K); void* PackedB = BufferBPacked.GetBuffer(PackedBSize * BatchSize, true); std::vector data(BatchSize); for (size_t i = 0; i < BatchSize; i++) { - MlasGemmPackB(TransB, N, K, B + K * N * i, ldb, (uint8_t*)PackedB + PackedBSize * i); + MlasGemmPackB(TransA, TransB, N, K, B + K * N * i, ldb, (uint8_t*)PackedB + PackedBSize * i); data[i].BIsPacked = true; data[i].A = A + M * K * i; data[i].lda = lda; @@ -195,19 +195,14 @@ class MlasFgemmTest : public MlasTestBase { std::fill_n(C, M * N * BatchSize, -0.5f); std::fill_n(CReference, M * N * BatchSize, -0.5f); - static constexpr float rtol = 1e-5f; - static constexpr float atol = 1e-8f; - PackedContext.TestGemm(TransA, TransB, M, N, K, BatchSize, alpha, A, lda, B, ldb, beta, C, ldc, threadpool_); ReferenceGemm(TransA, TransB, M, N, K, BatchSize, alpha, A, lda, B, ldb, beta, 
CReference, ldc); for (size_t batch = 0, f = 0; batch < BatchSize; batch++) { for (size_t m = 0; m < M; m++) { for (size_t n = 0; n < N; n++, f++) { - T tolerance = atol + rtol * std::abs(CReference[f]); - // Sensitive to comparing positive/negative zero. - ASSERT_NEAR(C[f], CReference[f], tolerance) + ASSERT_EQ(C[f], CReference[f]) << " Diff @[" << batch << ", " << m << ", " << n << "] f=" << f << ", " << (Packed ? "Packed" : "NoPack") << "." << (Threaded ? "SingleThread" : "Threaded") << "/" diff --git a/tests/unittest/test_fgemm_fixture.h b/tests/unittest/test_fgemm_fixture.h index 53b3eda..c832ca6 100644 --- a/tests/unittest/test_fgemm_fixture.h +++ b/tests/unittest/test_fgemm_fixture.h @@ -70,6 +70,7 @@ class FgemmShortExecuteTest : public MlasTestFixture #include +#include "test_fp16.h" /** * @brief Test class for half precision GEMM diff --git a/tests/unittest/test_main.cpp b/tests/unittest/test_main.cpp index 3a11b91..505c0c0 100644 --- a/tests/unittest/test_main.cpp +++ b/tests/unittest/test_main.cpp @@ -57,7 +57,6 @@ bool AddTestRegister(TestRegister test_register) { } int main(int argc, char** argv) { - bool is_short_execute = (argc <= 1 || strcmp("--long", argv[1]) != 0); std::cout << "-------------------------------------------------------" << std::endl; if (is_short_execute) { diff --git a/tests/unittest/test_rope.cpp b/tests/unittest/test_rope.cpp index 9f08970..3dd6e5e 100644 --- a/tests/unittest/test_rope.cpp +++ b/tests/unittest/test_rope.cpp @@ -15,8 +15,7 @@ Module Name: --*/ #include "test_util.h" -#include "mlas.h" -#include "core/framework/float16.h" +#include "mlasi.h" #include "rotary_embedding.h" using namespace onnxruntime; diff --git a/tests/unittest/test_scaleoutput.cpp b/tests/unittest/test_scaleoutput.cpp index 13de844..0b58e48 100644 --- a/tests/unittest/test_scaleoutput.cpp +++ b/tests/unittest/test_scaleoutput.cpp @@ -22,7 +22,7 @@ class MlasScaleOutputTest : public MlasTestBase { std::numeric_limits::max()); for (size_t s = 0; s < 
M * N; s++) { - Input[s] = int_distribution(generator); //It could be zero + Input[s] = int_distribution(generator); // It could be zero Output[s] = OutputRef[s] = real_distribution(generator); } @@ -52,8 +52,8 @@ class MlasScaleOutputTest : public MlasTestBase { constexpr float epsilon = 1e-6f; for (size_t n = 0; n < M * N; n++) { - float outvalue = OutputRef[n]; // When `AccumulateMode` is false, there is a high chance that this value could be zero - float diff = std::fabs(Output[n] - outvalue) ; + float outvalue = OutputRef[n]; // When `AccumulateMode` is false, there is a high chance that this value could be zero + float diff = std::fabs(Output[n] - outvalue); if (outvalue != 0) { diff /= outvalue; } diff --git a/tests/unittest/test_softmax.cpp b/tests/unittest/test_softmax.cpp index df0c7f6..f07a0c3 100644 --- a/tests/unittest/test_softmax.cpp +++ b/tests/unittest/test_softmax.cpp @@ -152,7 +152,7 @@ class MlasSoftmaxTest : public MlasTestBase { } void Test(const float* Input, float* Output, float* OutputReference, size_t N, size_t D, bool LogSoftmax, bool SmoothSoftmax) { - MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, threadpool_); + MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, 0.0f, threadpool_); ReferenceSoftmax(Input, OutputReference, N, D, LogSoftmax, SmoothSoftmax); constexpr float AbsoluteTolerance = 1e-6f; @@ -206,7 +206,7 @@ class MlasSoftmaxTest : public MlasTestBase { InputReference[nd] = Input[nd].ToFloat(); } - MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, threadpool_); + MlasComputeSoftmax(Input, Output, N, D, LogSoftmax, SmoothSoftmax, 0.0f, threadpool_); ReferenceSoftmax(InputReference, OutputReference, N, D, LogSoftmax, SmoothSoftmax); constexpr float AbsoluteTolerance = 5e-3f; diff --git a/tests/unittest/test_sq8bitgemm.cpp b/tests/unittest/test_sq8bitgemm.cpp index 1237cdb..f24ea85 100644 --- a/tests/unittest/test_sq8bitgemm.cpp +++ b/tests/unittest/test_sq8bitgemm.cpp @@ -19,7 
+19,7 @@ Module Name: #include "test_util.h" #include "mlasi.h" -#include "core/mlas/inc/mlas_q4.h" +#include "mlas_q4.h" #include "qnbitgemm.h" #include "mlas_qnbit.h" @@ -31,10 +31,156 @@ class MlasSQ8BitPrepackTest : public MlasTestBase { std::uniform_real_distribution distrib_f32_; MatrixGuardBuffer inputB_, inputZp_, refB_, packedBuffer_; MatrixGuardBuffer inputScale_, refScale_; - MatrixGuardBuffer inputBlkSum_, refBlkSum_; + MatrixGuardBuffer inputBlkSum_, refBlkSum_, refBlkUnsignedQuantAZeroPointCorrection_; +#ifdef MLAS_TARGET_ARM64 template - void PrepackB(const uint8_t* src, uint8_t* dst) { + void PrepackB(const uint8_t* src, uint8_t* dst, float* refBlkUnsignedQuantAZeroPointCorrection) { + constexpr size_t ldb = (K + BlkLen - 1) & (~(BlkLen - 1)); + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + size_t n = 0; + for (; n - n % 8 + 8 <= N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t src_idx = n * ldb + k; + size_t dst_idx = n / 8 * 8 * ldb + k / 4 * 4 * 8 + (n % 8) * 4 + k % 4; + size_t blkSum_idx = n / 16 * 16 * BlkCount + k / BlkLen * 16 + n % 16; + dst[dst_idx] = src[src_idx]; + if (refBlkUnsignedQuantAZeroPointCorrection) { + refBlkUnsignedQuantAZeroPointCorrection[blkSum_idx] += src[src_idx]; + } + } + } + for (; n - n % 4 + 4 <= N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t src_idx = n * ldb + k; + size_t dst_idx = n / 4 * 4 * ldb + k / 4 * 4 * 4 + (n % 4) * 4 + k % 4; + size_t blkSum_idx = n / 16 * 16 * BlkCount + k / BlkLen * 16 + n % 16; + dst[dst_idx] = src[src_idx]; + if (refBlkUnsignedQuantAZeroPointCorrection) { + refBlkUnsignedQuantAZeroPointCorrection[blkSum_idx] += src[src_idx]; + } + } + } + for (; n < N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t src_idx = n * ldb + k; + size_t dst_idx = n * ldb + k; + size_t blkSum_idx = n / 16 * 16 * BlkCount + k / BlkLen * 16 + n % 16; + dst[dst_idx] = src[src_idx]; + if (refBlkUnsignedQuantAZeroPointCorrection) { + refBlkUnsignedQuantAZeroPointCorrection[blkSum_idx] 
+= src[src_idx]; + } + } + } + } + + template + void PrepackBlkSumAndScale(const float* scale, const uint8_t* zp, float* packedScale, float* blkSum, float* refBlkUnsignedQuantAZeroPointCorrection) { + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + size_t n = 0; + for (; n - n % 8 + 8 <= N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t src_idx = n * BlkCount + k; + size_t scale_dst_idx = n / 8 * 8 * BlkCount + k * 8 + n % 8; + size_t sum_dst_idx = n / 16 * 16 * BlkCount + k * 16 + n % 16; + float zp_val = (zp ? static_cast(zp[src_idx]) : 128.f); + float vSum = -scale[src_idx] * zp_val; + packedScale[scale_dst_idx] = scale[src_idx]; + blkSum[sum_dst_idx] = vSum; + if (refBlkUnsignedQuantAZeroPointCorrection) { + float vSum2 = -refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] + zp_val * std::min(BlkLen, K - k * BlkLen); + refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] = vSum2 * scale[src_idx]; + } + } + } + for (; n - n % 4 + 4 <= N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t src_idx = n * BlkCount + k; + size_t scale_dst_idx = n / 4 * 4 * BlkCount + k * 4 + n % 4; + size_t sum_dst_idx = n / 16 * 16 * BlkCount + k * 16 + n % 16; + float zp_val = (zp ? static_cast(zp[src_idx]) : 128.f); + float vSum = -scale[src_idx] * zp_val; + packedScale[scale_dst_idx] = scale[src_idx]; + blkSum[sum_dst_idx] = vSum; + if (refBlkUnsignedQuantAZeroPointCorrection) { + float vSum2 = -refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] + zp_val * std::min(BlkLen, K - k * BlkLen); + refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] = vSum2 * scale[src_idx]; + } + } + } + for (; n < N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t src_idx = n * BlkCount + k; + size_t scale_dst_idx = n * BlkCount + k; + size_t sum_dst_idx = n / 16 * 16 * BlkCount + k * 16 + n % 16; + float zp_val = (zp ? 
static_cast(zp[src_idx]) : 128.f); + float vSum = -scale[src_idx] * zp_val; + packedScale[scale_dst_idx] = scale[src_idx]; + blkSum[sum_dst_idx] = vSum; + if (refBlkUnsignedQuantAZeroPointCorrection) { + float vSum2 = -refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] + zp_val * std::min(BlkLen, K - k * BlkLen); + refBlkUnsignedQuantAZeroPointCorrection[sum_dst_idx] = vSum2 * scale[src_idx]; + } + } + } + } + + template + void CheckB(const uint8_t* packedB, const uint8_t* refB) { + constexpr size_t ldb = (K + BlkLen - 1) & (~(BlkLen - 1)); + size_t n = 0; + for (; n - n % 8 + 8 <= N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t idx = n / 8 * 8 * ldb + k / 4 * 4 * 8 + (n % 8) * 4 + k % 4; + ASSERT_EQ(packedB[idx], refB[idx]) << " at n=" << n << " k=" << k; + } + } + + for (; n - n % 4 + 4 <= N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t idx = n / 4 * 4 * ldb + k / 4 * 4 * 4 + (n % 4) * 4 + k % 4; + ASSERT_EQ(packedB[idx], refB[idx]) << " at n=" << n << " k=" << k; + } + } + + for (; n < N; ++n) { + for (size_t k = 0; k < K; ++k) { + size_t idx = n * ldb + k; + ASSERT_EQ(packedB[idx], refB[idx]) << " at n=" << n << " k=" << k; + } + } + } + + template + void CheckScale(const float* packedScale, const float* refScale) { + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + size_t n = 0; + for (; n - n % 8 + 8 <= N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t idx = n / 8 * 8 * BlkCount + k * 8 + n % 8; + ASSERT_EQ(packedScale[idx], refScale[idx]) << " at n=" << n << " k=" << k; + } + } + + for (; n - n % 4 + 4 <= N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t idx = n / 4 * 4 * BlkCount + k * 4 + n % 4; + ASSERT_EQ(packedScale[idx], refScale[idx]) << " at n=" << n << " k=" << k; + } + } + + for (; n < N; ++n) { + for (size_t k = 0; k < BlkCount; ++k) { + size_t idx = n * BlkCount + k; + ASSERT_EQ(packedScale[idx], refScale[idx]) << " at n=" << n << " k=" << k; + } + } + } +#else // not MLAS_TARGET_ARM64 + template + void 
PrepackB(const uint8_t* src, uint8_t* dst, float* blkUnsignedQuantAZeroPointCorrection) { + MLAS_UNREFERENCED_PARAMETER(blkUnsignedQuantAZeroPointCorrection); + constexpr size_t ldb = (K + BlkLen - 1) & (~(BlkLen - 1)); size_t n = 0; for (; n + 4 <= N; n += 4) { @@ -65,7 +211,9 @@ class MlasSQ8BitPrepackTest : public MlasTestBase { } template - void PrepackBlkSumAndScale(const float* scale, const uint8_t* zp, float* packedScale, float* blkSum) { + void PrepackBlkSumAndScale(const float* scale, const uint8_t* zp, float* packedScale, float* blkSum, float* blkUnsignedQuantAZeroPointCorrection) { + MLAS_UNREFERENCED_PARAMETER(blkUnsignedQuantAZeroPointCorrection); + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; constexpr size_t BlkPerSubBlk = SubBlkLen > BlkLen ? SubBlkLen / BlkLen : 1; @@ -174,10 +322,15 @@ class MlasSQ8BitPrepackTest : public MlasTestBase { } } } +#endif // MLAS_TARGET_ARM64 template void CheckBlkSum(const float* packedBlkSum, const float* refBlkSum) { - size_t BlkCount = (K + BlkLen - 1) / BlkLen; + if (refBlkSum == nullptr) { + return; + } + + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; for (size_t n = 0; n < N; ++n) { for (size_t k = 0; k < BlkCount; ++k) { @@ -198,6 +351,7 @@ class MlasSQ8BitPrepackTest : public MlasTestBase { constexpr size_t PackBCount = N * Ldb; constexpr size_t ScaleCount = BlkCount * N; const size_t BufferSize = MlasQNBitGemmPackQuantBDataSize(N, K, Bits, BlkLen, hasZp, SQNBIT_CompInt8); + const bool isQuantAUnsigned = GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned; const auto* inputB = inputB_.GetFilledBuffer(PackBCount, [this](uint8_t* p, size_t t) { for (size_t i = 0; i < t; i++) { @@ -222,25 +376,36 @@ class MlasSQ8BitPrepackTest : public MlasTestBase { auto* refB = refB_.GetBuffer(PackBCount, true); auto* refScale = refScale_.GetBuffer(ScaleCount, true); auto* refBlkSum = refBlkSum_.GetBuffer(((N + 15) & (~15)) * BlkCount, true); + auto* refBlkUnsignedQuantAZeroPointCorrection = isQuantAUnsigned 
? refBlkUnsignedQuantAZeroPointCorrection_.GetBuffer(((N + 15) & (~15)) * BlkCount, true) : nullptr; + + PackedQuantBDataStruct packedQuantB(packedBuffer, N, BlkCount, BlkLen, isQuantAUnsigned); + + // Models the packing calls from MatmulNBits operator - we will have 3 separate calls + // for 3 different inputs in the Prepack() function + // The first call prepacks the quantized weights (and accumulates necessary metadata for BlkUnsignedQuantAZeroPointCorrection). + // The second call prepacks the scales. + // The third call prepacks the zero points. + // The inputScale and zero points will be ignored while prepacking the weights (if they are provided). MlasQNBitGemmPackQuantBData( N, K, Bits, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, inputB, packedBuffer, - inputScale, hasZp, nullptr, nullptr); + inputScale, hasZp, inputZp, nullptr); + MlasQNBitGemmPackQuantBData( N, K, Bits, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer, inputScale, hasZp, nullptr, nullptr); + MlasQNBitGemmPackQuantBData( N, K, Bits, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer, nullptr, hasZp, inputZp, nullptr); - PackedQuantBDataStruct packedQuantB(packedBuffer, N, BlkCount, BlkLen); + PrepackB(inputB, refB, refBlkUnsignedQuantAZeroPointCorrection); + PrepackBlkSumAndScale(inputScale, inputZp, refScale, refBlkSum, refBlkUnsignedQuantAZeroPointCorrection); - PrepackB(inputB, refB); - PrepackBlkSumAndScale(inputScale, inputZp, refScale, refBlkSum); - - CheckB(refB, reinterpret_cast(packedQuantB.PackedQuantBData)); - CheckScale(refScale, packedQuantB.PackedQuantBScale); - CheckBlkSum(refBlkSum, packedQuantB.QuantBBlkSum); + CheckB(reinterpret_cast(packedQuantB.PackedQuantBData), refB); + CheckScale(packedQuantB.PackedQuantBScale, refScale); + CheckBlkSum(packedQuantB.QuantBBlkSum, refBlkSum); + CheckBlkSum(packedQuantB.BlkUnsignedQuantAZeroPointCorrection, refBlkUnsignedQuantAZeroPointCorrection); } public: @@ -298,31 
+463,203 @@ class MlasSQ8BitPrepackTest : public MlasTestBase { Execute<1, 1, 256, 64>(); Execute<16, 4, 16, 64>(); - Execute<32, 4, 16, 64>(); - Execute<64, 4, 16, 64>(); - Execute<128, 4, 16, 64>(); + Execute<32, 8, 16, 64>(); + Execute<64, 12, 32, 64>(); + Execute<128, 16, 64, 64>(); - Execute<15, 5, 16, 64>(); - Execute<15, 5, 32, 64>(); + Execute<15, 3, 16, 64>(); + Execute<15, 4, 32, 64>(); Execute<15, 5, 64, 64>(); - Execute<15, 5, 128, 64>(); - Execute<15, 5, 256, 64>(); - + Execute<15, 6, 128, 64>(); + Execute<15, 7, 256, 64>(); + Execute<15, 8, 16, 64>(); + Execute<15, 9, 16, 64>(); + + Execute<17, 3, 16, 64>(); + Execute<17, 4, 32, 64>(); + Execute<17, 5, 64, 64>(); + Execute<17, 6, 128, 64>(); + Execute<17, 7, 256, 64>(); Execute<17, 8, 16, 64>(); - Execute<17, 8, 32, 64>(); - Execute<17, 8, 64, 64>(); - Execute<17, 8, 128, 64>(); - Execute<17, 8, 256, 64>(); + Execute<17, 9, 16, 64>(); Execute<159, 16, 16, 64>(); Execute<160, 17, 32, 64>(); Execute<161, 15, 64, 64>(); Execute<160, 17, 128, 64>(); Execute<159, 16, 256, 64>(); + Execute<3072, 128, 16, 64>(); } } }; +class MlasSQ8BitQuantAKernelTest : public MlasTestBase { + private: + unsigned int seed_; + std::mt19937 gen_; // mersenne_twister_engine seeded with rd() + std::uniform_int_distribution distrib_u8_; + std::uniform_real_distribution distrib_f32_; + MatrixGuardBuffer workspace_, refQuantA_; + MatrixGuardBuffer inputA_, refScale_, refBlkSum_; + + template + void QuantA(const float* inputA, uint8_t* quantA, float* scalePtr, float* blkSum, bool quantAUnsigned) { + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + constexpr size_t input_lda = K; + + constexpr size_t Bits = 8; + constexpr size_t output_lda = (((K + BlkLen - 1) & (~(BlkLen - 1))) * Bits + 7) / 8; + + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < BlkCount; ++j) { + float vAbsMax = 0.f; + for (size_t k = 0; k < std::min(BlkLen, K - j * BlkLen); ++k) { + size_t input_idx = i * input_lda + j * BlkLen + k; + vAbsMax = 
std::max(vAbsMax, fabsf(inputA[input_idx])); + } + + float scale = vAbsMax / 127.f; + float invScale = vAbsMax == 0.f ? 0.f : 127.f / vAbsMax; + scalePtr[i * BlkCount + j] = scale; + + float vSum = 0.f; + for (size_t k = 0; k < BlkLen; ++k) { + size_t input_idx = i * input_lda + j * BlkLen + k; + size_t output_idx = i * output_lda + j * BlkLen + k; + if (k < std::min(BlkLen, K - j * BlkLen)) { + const auto input_val = inputA[input_idx]; + // Round to nearest, ties away from zero + // float v = std::clamp(std::roundf(input_val * invScale), -128.f, 127.f); + + // Round to nearest, ties to even + float v = std::clamp(std::nearbyint(input_val * invScale), -128.f, 127.f); + + if (quantAUnsigned) { + quantA[output_idx] = static_cast(v + 128.f); + vSum += v + 128.f; + } else { + reinterpret_cast(quantA)[output_idx] = static_cast(v); + vSum += v; + } + } else { + quantA[output_idx] = 0; + } + } + blkSum[i * BlkCount + j] = vSum * scale; + } + } + } + + template + void CheckQuantA(const uint8_t* quantA, const uint8_t* refQuantA) { + constexpr size_t lda = (K + BlkLen - 1) & (~(BlkLen - 1)); + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < lda; ++j) { + size_t idx = i * lda + j; + ASSERT_EQ(quantA[idx], refQuantA[idx]) << " at i=" << i << " j=" << j; + } + } + } + + template + void CheckScale(const float* scale, const float* refScale) { + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + for (size_t i = 0; i < M; ++i) { + for (size_t j = 0; j < BlkCount; ++j) { + size_t idx = i * BlkCount + j; + ASSERT_EQ(scale[idx], refScale[idx]) << " at i=" << i << " j=" << j; + } + } + } + + template + void TestQuantA() { + if (!MlasIsQNBitGemmAvailable(8, BlkLen, SQNBIT_CompInt8)) return; + + const auto* dispatch = GetMlasPlatform().QNBitGemmDispatch; + constexpr size_t Bits = 8; + constexpr size_t BlkCount = (K + BlkLen - 1) / BlkLen; + constexpr size_t Lda = (((K + BlkLen - 1) & (~(BlkLen - 1))) * Bits + 7) / 8; + constexpr size_t PackACount = M * Lda; + constexpr 
size_t ScaleCount = M * BlkCount; + const size_t BufferSize = MlasQNBitGemmBatchWorkspaceSize(M, 1, K, 1, Bits, BlkLen, true, SQNBIT_CompInt8); + const bool isQuantAUnsigned = GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned; + + const auto* inputA = inputA_.GetFilledBuffer(M * K, [this](float* p, size_t t) { + for (size_t i = 0; i < t; i++) { + p[i] = this->distrib_f32_(this->gen_); + } + }); + + auto* workspace = workspace_.GetBuffer(BufferSize, true); + auto* refQuantA = refQuantA_.GetBuffer(PackACount, true); + auto* refScale = refScale_.GetBuffer(ScaleCount, true); + auto* refBlkSum = refBlkSum_.GetBuffer(ScaleCount, true); + + const size_t Alignment = dispatch->QNBitGemmPerGemmWorkspaceAlignment(BlkLen, SQNBIT_CompInt8); + const uintptr_t WorkspaceAddress = reinterpret_cast(workspace); + auto* quantAPtr = reinterpret_cast((WorkspaceAddress + Alignment - 1) & (~(Alignment - 1))); + auto* scaleAPtr = reinterpret_cast(quantAPtr + PackACount); + auto* blkSumAPtr = scaleAPtr + ScaleCount; + + for (size_t i = 0; i < M; ++i) { + dispatch->QuantizeARowComputeBlkSum_CompInt8(BlkLen, inputA + i * K, K, quantAPtr + i * Lda, scaleAPtr + i * BlkCount, blkSumAPtr + i * BlkCount); + } + + QuantA(inputA, refQuantA, refScale, refBlkSum, isQuantAUnsigned); + CheckQuantA(reinterpret_cast(quantAPtr), refQuantA); + CheckScale(scaleAPtr, refScale); + CheckScale(blkSumAPtr, refBlkSum); + } + + public: + MlasSQ8BitQuantAKernelTest() + : seed_(19287), gen_(seed_), distrib_u8_(0, 255), distrib_f32_(-10.f, 10.f) { + } + + static const char* GetTestSuiteName() { + return "SQ8BitQuantA"; + } + + void ExecuteShort(void) override { + TestQuantA<1, 16, 16>(); + TestQuantA<1, 1, 32>(); + TestQuantA<1, 1, 64>(); + TestQuantA<1, 1, 128>(); + TestQuantA<1, 1, 256>(); + + TestQuantA<4, 16, 16>(); + TestQuantA<8, 32, 16>(); + TestQuantA<12, 64, 32>(); + TestQuantA<16, 128, 64>(); + + TestQuantA<3, 15, 16>(); + TestQuantA<4, 15, 32>(); + TestQuantA<5, 15, 64>(); + TestQuantA<6, 15, 128>(); + 
TestQuantA<7, 15, 256>(); + TestQuantA<8, 15, 16>(); + TestQuantA<9, 15, 16>(); + + TestQuantA<3, 17, 16>(); + TestQuantA<4, 17, 32>(); + TestQuantA<5, 17, 64>(); + TestQuantA<6, 17, 128>(); + TestQuantA<7, 17, 256>(); + TestQuantA<8, 17, 16>(); + TestQuantA<9, 17, 16>(); + + TestQuantA<2, 159, 16>(); + TestQuantA<3, 159, 16>(); + TestQuantA<17, 160, 32>(); + TestQuantA<15, 161, 64>(); + TestQuantA<17, 160, 128>(); + TestQuantA<16, 159, 256>(); + + TestQuantA<1, 3072, 16>(); + } +}; + class MlasSQ8BitGemmKernelTest : public MlasTestBase { private: unsigned int seed_; @@ -383,9 +720,6 @@ class MlasSQ8BitGemmKernelTest : public MlasTestBase { } }); - int q_rows, q_cols; - MlasBlockwiseQuantizedShape((int)BlkLen, true, (int)K, (int)N, q_rows, q_cols); - size_t q_data_size_in_bytes, q_scale_size, q_zp_size_in_bytes; MlasBlockwiseQuantizedBufferSizes<8>((int)(BlkLen), true, (int)K, (int)N, q_data_size_in_bytes, q_scale_size, &q_zp_size_in_bytes); @@ -420,24 +754,34 @@ class MlasSQ8BitGemmKernelTest : public MlasTestBase { size_t bufferSize = MlasQNBitGemmPackQuantBDataSize(N, K, 8, BlkLen, HasZp, SQNBIT_CompInt8); auto* packedBuffer = packedBuffer_.GetBuffer(bufferSize, true); + // Models the packing calls from MatmulNBits operator - we will have 3 separate calls + // for 3 different inputs in the Prepack() function + // The first call prepacks the quantized weights (and accumulates necessary metadata for BlkUnsignedQuantAZeroPointCorrection). + // The second call prepacks the scales. + // The third call prepacks the zero points. + + // The inputScale and zero points will be ignored while prepacking the weights (if they are provided). 
MlasQNBitGemmPackQuantBData( N, K, 8, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, inputB, packedBuffer, - inputScale, HasZp, nullptr, nullptr); + inputScale, HasZp, inputZp, nullptr); + MlasQNBitGemmPackQuantBData( N, K, 8, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer, inputScale, HasZp, nullptr, nullptr); + MlasQNBitGemmPackQuantBData( N, K, 8, BlkLen, MLAS_QNBIT_GEMM_COMPUTE_TYPE::SQNBIT_CompInt8, nullptr, packedBuffer, nullptr, HasZp, inputZp, nullptr); - PackedQuantBDataStruct packedQuantB(packedBuffer, N, BlkCount, BlkLen); + const bool isQuantAUnsigned = GetMlasPlatform().ArmNeonIsQuantActivationsUnsigned; + PackedQuantBDataStruct packedQuantB(packedBuffer, N, BlkCount, BlkLen, isQuantAUnsigned); auto* C = C_.GetBuffer(M * ldc, true); auto* ref = ref_.GetBuffer(M * ldc, true); - auto* bias = HasBias ? bias_.GetFilledBuffer(N, [this](float* p, size_t t) { + auto* bias = HasBias ? bias_.GetFilledBuffer(N, [](float* p, size_t t) { for (size_t i = 0; i < t; i++) { - p[i] = this->distrib_f32_(this->gen_); + p[i] = (float)(5 + i); } }) : nullptr; @@ -473,14 +817,17 @@ class MlasSQ8BitGemmKernelTest : public MlasTestBase { template void Execute(void) { - TestSQ8BitGemmKernel(); TestSQ8BitGemmKernel(); - TestSQ8BitGemmKernel(); TestSQ8BitGemmKernel(); + + TestSQ8BitGemmKernel(); + TestSQ8BitGemmKernel(); } void ExecuteShort(void) override { + Execute<1, 16, 1, 16>(); Execute<1, 1, 1, 16>(); + Execute<7, 2, 4, 16>(); Execute<7, 128, 4, 16>(); Execute<8, 497, 5, 16>(); Execute<1, 3072, 128, 16>(); @@ -515,6 +862,7 @@ static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_exe size_t count = 0; if (is_short_execute) { count += MlasDirectShortExecuteTests::RegisterShortExecute(); + count += MlasDirectShortExecuteTests::RegisterShortExecute(); count += MlasDirectShortExecuteTests::RegisterShortExecute(); } return count; From b01d9d1450b5b5c71ffb4bb4300180e491d06b2e Mon Sep 17 00:00:00 2001 From: Changming Sun 
Date: Fri, 17 Oct 2025 21:24:51 -0700 Subject: [PATCH 32/33] revert --- CMakePresets.json | 4803 +++++++++++++++++++-------------------------- 1 file changed, 2005 insertions(+), 2798 deletions(-) diff --git a/CMakePresets.json b/CMakePresets.json index d3aff74..261050b 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -1,477 +1,5 @@ { - "$schema": "https://cmake.org/cmake/help/latest/_downloads/3e2d73bff478d88a7de0de736ba5e361/schema.json", - "buildPresets": [ - { - "configuration": "Debug", - "configurePreset": "linux_clang_debug", - "name": "linux_clang_debug" - }, - { - "configuration": "Debug", - "configurePreset": "linux_clang_debug_asan", - "name": "linux_clang_debug_asan" - }, - { - "configuration": "Debug", - "configurePreset": "linux_clang_debug_asan_no_ort", - "name": "linux_clang_debug_asan_no_ort" - }, - { - "configuration": "Debug", - "configurePreset": "linux_clang_debug_cov", - "name": "linux_clang_debug_cov" - }, - { - "configuration": "Debug", - "configurePreset": "linux_clang_debug_cov_no_ort", - "name": "linux_clang_debug_cov_no_ort" - }, - { - "configuration": "Debug", - "configurePreset": "linux_clang_debug_no_ort", - "name": "linux_clang_debug_no_ort" - }, - { - "configuration": "Debug", - "configurePreset": "linux_gcc_debug", - "name": "linux_gcc_debug" - }, - { - "configuration": "Debug", - "configurePreset": "linux_gcc_debug_asan", - "name": "linux_gcc_debug_asan" - }, - { - "configuration": "Debug", - "configurePreset": "linux_gcc_debug_asan_no_ort", - "name": "linux_gcc_debug_asan_no_ort" - }, - { - "configuration": "Debug", - "configurePreset": "linux_gcc_debug_no_ort", - "name": "linux_gcc_debug_no_ort" - }, - { - "configuration": "MinSizeRel", - "configurePreset": "linux_gcc_minsizerel", - "name": "linux_gcc_minsizerel" - }, - { - "configuration": "MinSizeRel", - "configurePreset": "linux_gcc_minsizerel_asan", - "name": "linux_gcc_minsizerel_asan" - }, - { - "configuration": "MinSizeRel", - "configurePreset": 
"linux_gcc_minsizerel_asan_no_ort", - "name": "linux_gcc_minsizerel_asan_no_ort" - }, - { - "configuration": "MinSizeRel", - "configurePreset": "linux_gcc_minsizerel_no_ort", - "name": "linux_gcc_minsizerel_no_ort" - }, - { - "configuration": "Release", - "configurePreset": "linux_gcc_release", - "name": "linux_gcc_release" - }, - { - "configuration": "Release", - "configurePreset": "linux_gcc_release_asan", - "name": "linux_gcc_release_asan" - }, - { - "configuration": "Release", - "configurePreset": "linux_gcc_release_asan_no_ort", - "name": "linux_gcc_release_asan_no_ort" - }, - { - "configuration": "Release", - "configurePreset": "linux_gcc_release_no_ort", - "name": "linux_gcc_release_no_ort" - }, - { - "configuration": "RelWithDebInfo", - "configurePreset": "linux_gcc_relwithdebinfo", - "name": "linux_gcc_relwithdebinfo" - }, - { - "configuration": "RelWithDebInfo", - "configurePreset": "linux_gcc_relwithdebinfo_asan", - "name": "linux_gcc_relwithdebinfo_asan" - }, - { - "configuration": "RelWithDebInfo", - "configurePreset": "linux_gcc_relwithdebinfo_asan_no_ort", - "name": "linux_gcc_relwithdebinfo_asan_no_ort" - }, - { - "configuration": "RelWithDebInfo", - "configurePreset": "linux_gcc_relwithdebinfo_no_ort", - "name": "linux_gcc_relwithdebinfo_no_ort" - }, - { - "configuration": "Debug", - "configurePreset": "macos_arm64_debug", - "name": "macos_arm64_debug" - }, - { - "configuration": "Debug", - "configurePreset": "macos_arm64_debug_asan", - "name": "macos_arm64_debug_asan" - }, - { - "configuration": "MinSizeRel", - "configurePreset": "macos_arm64_minsizerel", - "name": "macos_arm64_minsizerel" - }, - { - "configuration": "MinSizeRel", - "configurePreset": "macos_arm64_minsizerel_asan", - "name": "macos_arm64_minsizerel_asan" - }, - { - "configuration": "Release", - "configurePreset": "macos_arm64_release", - "name": "macos_arm64_release" - }, - { - "configuration": "Release", - "configurePreset": "macos_arm64_release_asan", - "name": 
"macos_arm64_release_asan" - }, - { - "configuration": "RelWithDebInfo", - "configurePreset": "macos_arm64_relwithdebinfo", - "name": "macos_arm64_relwithdebinfo" - }, - { - "configuration": "RelWithDebInfo", - "configurePreset": "macos_arm64_relwithdebinfo_asan", - "name": "macos_arm64_relwithdebinfo_asan" - }, - { - "configuration": "Debug", - "configurePreset": "macos_universal2_debug", - "name": "macos_universal2_debug" - }, - { - "configuration": "Debug", - "configurePreset": "macos_universal2_debug_asan", - "name": "macos_universal2_debug_asan" - }, - { - "configuration": "MinSizeRel", - "configurePreset": "macos_universal2_minsizerel", - "name": "macos_universal2_minsizerel" - }, - { - "configuration": "MinSizeRel", - "configurePreset": "macos_universal2_minsizerel_asan", - "name": "macos_universal2_minsizerel_asan" - }, - { - "configuration": "Release", - "configurePreset": "macos_universal2_release", - "name": "macos_universal2_release" - }, - { - "configuration": "Release", - "configurePreset": "macos_universal2_release_asan", - "name": "macos_universal2_release_asan" - }, - { - "configuration": "RelWithDebInfo", - "configurePreset": "macos_universal2_relwithdebinfo", - "name": "macos_universal2_relwithdebinfo" - }, - { - "configuration": "RelWithDebInfo", - "configurePreset": "macos_universal2_relwithdebinfo_asan", - "name": "macos_universal2_relwithdebinfo_asan" - }, - { - "configuration": "Debug", - "configurePreset": "macos_x86_64_debug", - "name": "macos_x86_64_debug" - }, - { - "configuration": "Debug", - "configurePreset": "macos_x86_64_debug_asan", - "name": "macos_x86_64_debug_asan" - }, - { - "configuration": "MinSizeRel", - "configurePreset": "macos_x86_64_minsizerel", - "name": "macos_x86_64_minsizerel" - }, - { - "configuration": "MinSizeRel", - "configurePreset": "macos_x86_64_minsizerel_asan", - "name": "macos_x86_64_minsizerel_asan" - }, - { - "configuration": "Release", - "configurePreset": "macos_x86_64_release", - "name": 
"macos_x86_64_release" - }, - { - "configuration": "Release", - "configurePreset": "macos_x86_64_release_asan", - "name": "macos_x86_64_release_asan" - }, - { - "configuration": "RelWithDebInfo", - "configurePreset": "macos_x86_64_relwithdebinfo", - "name": "macos_x86_64_relwithdebinfo" - }, - { - "configuration": "RelWithDebInfo", - "configurePreset": "macos_x86_64_relwithdebinfo_asan", - "name": "macos_x86_64_relwithdebinfo_asan" - }, - { - "configuration": "Debug", - "configurePreset": "windows_arm64_debug", - "name": "windows_arm64_debug" - }, - { - "configuration": "Debug", - "configurePreset": "windows_arm64_debug_asan", - "name": "windows_arm64_debug_asan" - }, - { - "configuration": "Debug", - "configurePreset": "windows_arm64_debug_asan_no_ort", - "name": "windows_arm64_debug_asan_no_ort" - }, - { - "configuration": "Debug", - "configurePreset": "windows_arm64_debug_no_ort", - "name": "windows_arm64_debug_no_ort" - }, - { - "configuration": "MinSizeRel", - "configurePreset": "windows_arm64_minsizerel", - "name": "windows_arm64_minsizerel" - }, - { - "configuration": "MinSizeRel", - "configurePreset": "windows_arm64_minsizerel_asan", - "name": "windows_arm64_minsizerel_asan" - }, - { - "configuration": "MinSizeRel", - "configurePreset": "windows_arm64_minsizerel_asan_no_ort", - "name": "windows_arm64_minsizerel_asan_no_ort" - }, - { - "configuration": "MinSizeRel", - "configurePreset": "windows_arm64_minsizerel_no_ort", - "name": "windows_arm64_minsizerel_no_ort" - }, - { - "configuration": "Release", - "configurePreset": "windows_arm64_release", - "name": "windows_arm64_release" - }, - { - "configuration": "Release", - "configurePreset": "windows_arm64_release_asan", - "name": "windows_arm64_release_asan" - }, - { - "configuration": "Release", - "configurePreset": "windows_arm64_release_asan_no_ort", - "name": "windows_arm64_release_asan_no_ort" - }, - { - "configuration": "Release", - "configurePreset": "windows_arm64_release_no_ort", - "name": 
"windows_arm64_release_no_ort" - }, - { - "configuration": "RelWithDebInfo", - "configurePreset": "windows_arm64_relwithdebinfo", - "name": "windows_arm64_relwithdebinfo" - }, - { - "configuration": "RelWithDebInfo", - "configurePreset": "windows_arm64_relwithdebinfo_asan", - "name": "windows_arm64_relwithdebinfo_asan" - }, - { - "configuration": "RelWithDebInfo", - "configurePreset": "windows_arm64_relwithdebinfo_asan_no_ort", - "name": "windows_arm64_relwithdebinfo_asan_no_ort" - }, - { - "configuration": "RelWithDebInfo", - "configurePreset": "windows_arm64_relwithdebinfo_no_ort", - "name": "windows_arm64_relwithdebinfo_no_ort" - }, - { - "configuration": "Debug", - "configurePreset": "windows_win32_debug", - "name": "windows_win32_debug" - }, - { - "configuration": "Debug", - "configurePreset": "windows_win32_debug_asan", - "name": "windows_win32_debug_asan" - }, - { - "configuration": "Debug", - "configurePreset": "windows_win32_debug_asan_no_ort", - "name": "windows_win32_debug_asan_no_ort" - }, - { - "configuration": "Debug", - "configurePreset": "windows_win32_debug_no_ort", - "name": "windows_win32_debug_no_ort" - }, - { - "configuration": "MinSizeRel", - "configurePreset": "windows_win32_minsizerel", - "name": "windows_win32_minsizerel" - }, - { - "configuration": "MinSizeRel", - "configurePreset": "windows_win32_minsizerel_asan", - "name": "windows_win32_minsizerel_asan" - }, - { - "configuration": "MinSizeRel", - "configurePreset": "windows_win32_minsizerel_asan_no_ort", - "name": "windows_win32_minsizerel_asan_no_ort" - }, - { - "configuration": "MinSizeRel", - "configurePreset": "windows_win32_minsizerel_no_ort", - "name": "windows_win32_minsizerel_no_ort" - }, - { - "configuration": "Release", - "configurePreset": "windows_win32_release", - "name": "windows_win32_release" - }, - { - "configuration": "Release", - "configurePreset": "windows_win32_release_asan", - "name": "windows_win32_release_asan" - }, - { - "configuration": "Release", - 
"configurePreset": "windows_win32_release_asan_no_ort", - "name": "windows_win32_release_asan_no_ort" - }, - { - "configuration": "Release", - "configurePreset": "windows_win32_release_no_ort", - "name": "windows_win32_release_no_ort" - }, - { - "configuration": "RelWithDebInfo", - "configurePreset": "windows_win32_relwithdebinfo", - "name": "windows_win32_relwithdebinfo" - }, - { - "configuration": "RelWithDebInfo", - "configurePreset": "windows_win32_relwithdebinfo_asan", - "name": "windows_win32_relwithdebinfo_asan" - }, - { - "configuration": "RelWithDebInfo", - "configurePreset": "windows_win32_relwithdebinfo_asan_no_ort", - "name": "windows_win32_relwithdebinfo_asan_no_ort" - }, - { - "configuration": "RelWithDebInfo", - "configurePreset": "windows_win32_relwithdebinfo_no_ort", - "name": "windows_win32_relwithdebinfo_no_ort" - }, - { - "configuration": "Debug", - "configurePreset": "windows_x64_debug", - "name": "windows_x64_debug" - }, - { - "configuration": "Debug", - "configurePreset": "windows_x64_debug_asan", - "name": "windows_x64_debug_asan" - }, - { - "configuration": "Debug", - "configurePreset": "windows_x64_debug_asan_no_ort", - "name": "windows_x64_debug_asan_no_ort" - }, - { - "configuration": "Debug", - "configurePreset": "windows_x64_debug_no_ort", - "name": "windows_x64_debug_no_ort" - }, - { - "configuration": "MinSizeRel", - "configurePreset": "windows_x64_minsizerel", - "name": "windows_x64_minsizerel" - }, - { - "configuration": "MinSizeRel", - "configurePreset": "windows_x64_minsizerel_asan", - "name": "windows_x64_minsizerel_asan" - }, - { - "configuration": "MinSizeRel", - "configurePreset": "windows_x64_minsizerel_asan_no_ort", - "name": "windows_x64_minsizerel_asan_no_ort" - }, - { - "configuration": "MinSizeRel", - "configurePreset": "windows_x64_minsizerel_no_ort", - "name": "windows_x64_minsizerel_no_ort" - }, - { - "configuration": "Release", - "configurePreset": "windows_x64_release", - "name": "windows_x64_release" - }, - { - 
"configuration": "Release", - "configurePreset": "windows_x64_release_asan", - "name": "windows_x64_release_asan" - }, - { - "configuration": "Release", - "configurePreset": "windows_x64_release_asan_no_ort", - "name": "windows_x64_release_asan_no_ort" - }, - { - "configuration": "Release", - "configurePreset": "windows_x64_release_no_ort", - "name": "windows_x64_release_no_ort" - }, - { - "configuration": "RelWithDebInfo", - "configurePreset": "windows_x64_relwithdebinfo", - "name": "windows_x64_relwithdebinfo" - }, - { - "configuration": "RelWithDebInfo", - "configurePreset": "windows_x64_relwithdebinfo_asan", - "name": "windows_x64_relwithdebinfo_asan" - }, - { - "configuration": "RelWithDebInfo", - "configurePreset": "windows_x64_relwithdebinfo_asan_no_ort", - "name": "windows_x64_relwithdebinfo_asan_no_ort" - }, - { - "configuration": "RelWithDebInfo", - "configurePreset": "windows_x64_relwithdebinfo_no_ort", - "name": "windows_x64_relwithdebinfo_no_ort" - } - ], + "version": 8, "cmakeMinimumRequired": { "major": 3, "minor": 28, @@ -479,2581 +7,2532 @@ }, "configurePresets": [ { - "binaryDir": "${sourceDir}/build/default/default", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-ggdb3 -O0", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Linux", - "type": "equals" - }, + "name": "linux_clang_debug", "displayName": "linux clang debug", - "environment": { - "CC": "clang", - "CXX": "clang++" - }, - "generator": "Unix Makefiles", - "name": "linux_clang_debug" - }, - { - "binaryDir": "${sourceDir}/build/asan/default", - "cacheVariables": { - 
"CMAKE_BUILD_TYPE": "Debug", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fsanitize=address", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Linux", - "type": "equals" - }, - "displayName": "linux clang debug asan", - "environment": { - "CC": "clang", - "CXX": "clang++" - }, - "generator": "Unix Makefiles", - "name": "linux_clang_debug_asan" - }, - { - "binaryDir": "${sourceDir}/build/asan/no_ort", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fsanitize=address", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "MLAS_NO_ONNXRUNTIME": "ON" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Linux", - "type": "equals" - }, - "displayName": "linux clang debug asan no_ort", - "environment": { - "CC": "clang", - "CXX": "clang++" - }, - "generator": "Unix Makefiles", - "name": "linux_clang_debug_asan_no_ort" - }, - { - "binaryDir": "${sourceDir}/build/cov/default", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fprofile-instr-generate -fcoverage-mapping", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-ggdb3 -O0 
-fprofile-instr-generate -fcoverage-mapping", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Linux", - "type": "equals" - }, - "displayName": "linux clang debug cov", - "environment": { - "CC": "clang", - "CXX": "clang++" - }, "generator": "Unix Makefiles", - "name": "linux_clang_debug_cov" - }, - { - "binaryDir": "${sourceDir}/build/cov/no_ort", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fprofile-instr-generate -fcoverage-mapping", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-ggdb3 -O0 -fprofile-instr-generate -fcoverage-mapping", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "MLAS_NO_ONNXRUNTIME": "ON" - }, + "binaryDir": "${sourceDir}/build/default/default", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Linux", - "type": "equals" + "rhs": "Linux" }, - "displayName": "linux clang debug cov no_ort", - "environment": { - "CC": "clang", - "CXX": "clang++" - }, - "generator": "Unix Makefiles", - "name": "linux_clang_debug_cov_no_ort" - }, - { - "binaryDir": "${sourceDir}/build/default/no_ort", "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0", - "CMAKE_CXX_STANDARD": "20", "CMAKE_C_FLAGS": "-ggdb3 -O0", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", 
"CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "MLAS_NO_ONNXRUNTIME": "ON" + "CMAKE_CXX_STANDARD": "20" }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Linux", - "type": "equals" - }, - "displayName": "linux clang debug no_ort", "environment": { "CC": "clang", "CXX": "clang++" - }, - "generator": "Unix Makefiles", - "name": "linux_clang_debug_no_ort" - }, - { - "binaryDir": "${sourceDir}/build/default/default", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -D_GLIBCXX_DEBUG", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-ggdb3 -O0", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Linux", - "type": "equals" - }, - "displayName": "linux gcc debug", - "environment": { - "CC": "gcc", - "CXX": "g++" - }, - "generator": "Unix Makefiles", - "name": "linux_gcc_debug" + } }, { - "binaryDir": "${sourceDir}/build/default/no_ort", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -D_GLIBCXX_DEBUG", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-ggdb3 -O0", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "MLAS_NO_ONNXRUNTIME": "ON" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Linux", - "type": "equals" - }, - 
"displayName": "linux gcc debug no_ort", - "environment": { - "CC": "gcc", - "CXX": "g++" - }, + "name": "linux_clang_debug_asan", + "displayName": "linux clang debug asan", "generator": "Unix Makefiles", - "name": "linux_gcc_debug_no_ort" - }, - { "binaryDir": "${sourceDir}/build/asan/default", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -D_GLIBCXX_DEBUG -fsanitize=address", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Linux", - "type": "equals" - }, - "displayName": "linux gcc debug asan", - "environment": { - "CC": "gcc", - "CXX": "g++" - }, - "generator": "Unix Makefiles", - "name": "linux_gcc_debug_asan" - }, - { - "binaryDir": "${sourceDir}/build/asan/no_ort", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -D_GLIBCXX_DEBUG -fsanitize=address", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "MLAS_NO_ONNXRUNTIME": "ON" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Linux", - "type": "equals" - }, - "displayName": "linux gcc debug asan no_ort", - "environment": { - "CC": "gcc", - "CXX": "g++" - }, - 
"generator": "Unix Makefiles", - "name": "linux_gcc_debug_asan_no_ort" - }, - { - "binaryDir": "${sourceDir}/build/default/default", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "MinSizeRel", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Linux", - "type": "equals" - }, - "displayName": "linux gcc minsizerel", - "environment": { - "CC": "gcc", - "CXX": "g++" - }, - "generator": "Unix Makefiles", - "name": "linux_gcc_minsizerel" - }, - { - "binaryDir": "${sourceDir}/build/default/no_ort", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "MinSizeRel", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "MLAS_NO_ONNXRUNTIME": "ON" - }, "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Linux", - "type": "equals" + "rhs": "Linux" }, - "displayName": "linux gcc minsizerel no_ort", - "environment": { - "CC": "gcc", - "CXX": "g++" - }, - "generator": "Unix Makefiles", - "name": "linux_gcc_minsizerel_no_ort" - }, - { 
- "binaryDir": "${sourceDir}/build/asan/default", "cacheVariables": { - "CMAKE_BUILD_TYPE": "MinSizeRel", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Linux", - "type": "equals" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "CMAKE_CXX_STANDARD": "20" }, - "displayName": "linux gcc minsizerel asan", "environment": { - "CC": "gcc", - "CXX": "g++" - }, - "generator": "Unix Makefiles", - "name": "linux_gcc_minsizerel_asan" + "CC": "clang", + "CXX": "clang++" + } }, { + "name": "linux_clang_debug_asan_no_ort", + "displayName": "linux clang debug asan no ort", + "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/asan/no_ort", + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Linux" + }, "cacheVariables": { - "CMAKE_BUILD_TYPE": "MinSizeRel", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_C_FLAGS": "-ggdb3 -O0 
-fsanitize=address", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "CMAKE_CXX_STANDARD": "20", "MLAS_NO_ONNXRUNTIME": "ON" }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Linux", - "type": "equals" - }, - "displayName": "linux gcc minsizerel asan no_ort", "environment": { - "CC": "gcc", - "CXX": "g++" - }, - "generator": "Unix Makefiles", - "name": "linux_gcc_minsizerel_asan_no_ort" + "CC": "clang", + "CXX": "clang++" + } }, { - "binaryDir": "${sourceDir}/build/default/default", + "name": "linux_clang_debug_cov", + "displayName": "linux clang debug cov", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/cov/default", + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Linux" + }, "cacheVariables": { - "CMAKE_BUILD_TYPE": "Release", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_C_FLAGS": "-ggdb3 -O0 -fprofile-instr-generate -fcoverage-mapping", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fprofile-instr-generate -fcoverage-mapping", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", 
+ "CMAKE_CXX_STANDARD": "20" }, + "environment": { + "CC": "clang", + "CXX": "clang++" + } + }, + { + "name": "linux_clang_debug_cov_no_ort", + "displayName": "linux clang debug cov no ort", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/cov/no_ort", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Linux", - "type": "equals" + "rhs": "Linux" }, - "displayName": "linux gcc release", - "environment": { - "CC": "gcc", - "CXX": "g++" + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_C_FLAGS": "-ggdb3 -O0 -fprofile-instr-generate -fcoverage-mapping", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fprofile-instr-generate -fcoverage-mapping", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_CXX_STANDARD": "20", + "MLAS_NO_ONNXRUNTIME": "ON" }, - "generator": "Unix Makefiles", - "name": "linux_gcc_release" + "environment": { + "CC": "clang", + "CXX": "clang++" + } }, { + "name": "linux_clang_debug_no_ort", + "displayName": "linux clang debug no ort", + "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/default/no_ort", + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Linux" + }, "cacheVariables": { - "CMAKE_BUILD_TYPE": "Release", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_C_FLAGS": "-ggdb3 -O0", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro 
-Wl,-z,now -Wl,-z,noexecstack", "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_CXX_STANDARD": "20", "MLAS_NO_ONNXRUNTIME": "ON" }, + "environment": { + "CC": "clang", + "CXX": "clang++" + } + }, + { + "name": "linux_gcc_debug", + "displayName": "linux gcc debug", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/default/default", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Linux", - "type": "equals" + "rhs": "Linux" + }, + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_C_FLAGS": "-ggdb3 -O0", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -D_GLIBCXX_DEBUG", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_CXX_STANDARD": "20" }, - "displayName": "linux gcc release no_ort", "environment": { "CC": "gcc", "CXX": "g++" - }, - "generator": "Unix Makefiles", - "name": "linux_gcc_release_no_ort" + } }, { + "name": "linux_gcc_debug_asan", + "displayName": "linux gcc debug asan", + "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/asan/default", + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Linux" + }, "cacheVariables": { - "CMAKE_BUILD_TYPE": "Release", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fsanitize=address -D_GLIBCXX_DEBUG", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now 
-Wl,-z,noexecstack -fsanitize=address", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Linux", - "type": "equals" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "CMAKE_CXX_STANDARD": "20" }, - "displayName": "linux gcc release asan", "environment": { "CC": "gcc", "CXX": "g++" - }, - "generator": "Unix Makefiles", - "name": "linux_gcc_release_asan" + } }, { + "name": "linux_gcc_debug_asan_no_ort", + "displayName": "linux gcc debug asan no ort", + "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/asan/no_ort", + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Linux" + }, "cacheVariables": { - "CMAKE_BUILD_TYPE": "Release", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fsanitize=address -D_GLIBCXX_DEBUG", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "CMAKE_CXX_STANDARD": "20", "MLAS_NO_ONNXRUNTIME": "ON" }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Linux", - "type": "equals" - }, - "displayName": "linux gcc release asan no_ort", 
"environment": { "CC": "gcc", "CXX": "g++" - }, - "generator": "Unix Makefiles", - "name": "linux_gcc_release_asan_no_ort" + } }, { - "binaryDir": "${sourceDir}/build/default/default", + "name": "linux_gcc_debug_no_ort", + "displayName": "linux gcc debug no ort", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/default/no_ort", + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Linux" + }, "cacheVariables": { - "CMAKE_BUILD_TYPE": "RelWithDebInfo", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_C_FLAGS": "-ggdb3 -O0", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -D_GLIBCXX_DEBUG", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Linux", - "type": "equals" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_CXX_STANDARD": "20", + "MLAS_NO_ONNXRUNTIME": "ON" }, - "displayName": "linux gcc relwithdebinfo", "environment": { "CC": "gcc", "CXX": "g++" - }, - "generator": "Unix Makefiles", - "name": "linux_gcc_relwithdebinfo" + } }, { - "binaryDir": "${sourceDir}/build/default/no_ort", + "name": "linux_gcc_minsizerel", + "displayName": "linux gcc minsizerel", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/default/default", + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Linux" + }, "cacheVariables": { - "CMAKE_BUILD_TYPE": "RelWithDebInfo", - "CMAKE_CXX_FLAGS": "-DNDEBUG 
-Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", + "CMAKE_BUILD_TYPE": "MinSizeRel", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", - "MLAS_NO_ONNXRUNTIME": "ON" + "CMAKE_CXX_STANDARD": "20" }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Linux", - "type": "equals" - }, - "displayName": "linux gcc relwithdebinfo no_ort", "environment": { "CC": "gcc", "CXX": "g++" - }, - "generator": "Unix Makefiles", - "name": "linux_gcc_relwithdebinfo_no_ort" + } }, { + "name": "linux_gcc_minsizerel_asan", + "displayName": "linux gcc minsizerel asan", + "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/asan/default", + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Linux" + }, "cacheVariables": { - "CMAKE_BUILD_TYPE": "RelWithDebInfo", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", + "CMAKE_BUILD_TYPE": "MinSizeRel", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions 
-Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Linux", - "type": "equals" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "CMAKE_CXX_STANDARD": "20" }, - "displayName": "linux gcc relwithdebinfo asan", "environment": { "CC": "gcc", "CXX": "g++" - }, - "generator": "Unix Makefiles", - "name": "linux_gcc_relwithdebinfo_asan" + } }, { + "name": "linux_gcc_minsizerel_asan_no_ort", + "displayName": "linux gcc minsizerel asan no ort", + "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/asan/no_ort", + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Linux" + }, "cacheVariables": { - "CMAKE_BUILD_TYPE": "RelWithDebInfo", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", + "CMAKE_BUILD_TYPE": "MinSizeRel", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + 
"CMAKE_CXX_STANDARD": "20", "MLAS_NO_ONNXRUNTIME": "ON" }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Linux", - "type": "equals" - }, - "displayName": "linux gcc relwithdebinfo asan no_ort", "environment": { "CC": "gcc", "CXX": "g++" - }, - "generator": "Unix Makefiles", - "name": "linux_gcc_relwithdebinfo_asan_no_ort" - }, - { - "binaryDir": "${sourceDir}/build/default", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-ggdb3 -O0", - "CMAKE_OSX_ARCHITECTURES": "arm64" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" - }, - "displayName": "macos arm64 debug", - "generator": "Unix Makefiles", - "name": "macos_arm64_debug" + } }, { - "binaryDir": "${sourceDir}/build/default", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fsanitize=address", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_OSX_ARCHITECTURES": "arm64", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address" - }, + "name": "linux_gcc_minsizerel_no_ort", + "displayName": "linux gcc minsizerel no ort", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/default/no_ort", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" + "rhs": "Linux" }, - "displayName": "macos arm64 debug asan", - "generator": "Unix Makefiles", - "name": "macos_arm64_debug_asan" - }, - { - "binaryDir": "${sourceDir}/build/default", "cacheVariables": { "CMAKE_BUILD_TYPE": "MinSizeRel", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions 
-Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", - "CMAKE_OSX_ARCHITECTURES": "arm64" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" + "MLAS_NO_ONNXRUNTIME": "ON" }, - "displayName": "macos arm64 minsizerel", - "generator": "Unix Makefiles", - "name": "macos_arm64_minsizerel" + "environment": { + "CC": "gcc", + "CXX": "g++" + } }, { - "binaryDir": "${sourceDir}/build/default", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "MinSizeRel", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_OSX_ARCHITECTURES": "arm64", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address" - }, + "name": "linux_gcc_release", + "displayName": "linux gcc release", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/default/default", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" + "rhs": "Linux" }, - "displayName": "macos arm64 minsizerel asan", - "generator": "Unix Makefiles", - "name": "macos_arm64_minsizerel_asan" - }, - { - "binaryDir": "${sourceDir}/build/default", "cacheVariables": { "CMAKE_BUILD_TYPE": "Release", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", - "CMAKE_CXX_STANDARD": "20", "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS 
-fstack-protector-strong -O3 -pipe", - "CMAKE_OSX_ARCHITECTURES": "arm64" + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_CXX_STANDARD": "20" }, + "environment": { + "CC": "gcc", + "CXX": "g++" + } + }, + { + "name": "linux_gcc_release_asan", + "displayName": "linux gcc release asan", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/asan/default", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" + "rhs": "Linux" }, - "displayName": "macos arm64 release", - "generator": "Unix Makefiles", - "name": "macos_arm64_release" - }, - { - "binaryDir": "${sourceDir}/build/default", "cacheVariables": { "CMAKE_BUILD_TYPE": "Release", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", - "CMAKE_CXX_STANDARD": "20", "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_OSX_ARCHITECTURES": "arm64", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + 
"CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "CMAKE_CXX_STANDARD": "20" }, - "displayName": "macos arm64 release asan", - "generator": "Unix Makefiles", - "name": "macos_arm64_release_asan" + "environment": { + "CC": "gcc", + "CXX": "g++" + } }, { - "binaryDir": "${sourceDir}/build/default", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "RelWithDebInfo", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", - "CMAKE_OSX_ARCHITECTURES": "arm64" - }, + "name": "linux_gcc_release_asan_no_ort", + "displayName": "linux gcc release asan no ort", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/asan/no_ort", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" + "rhs": "Linux" }, - "displayName": "macos arm64 relwithdebinfo", - "generator": "Unix Makefiles", - "name": "macos_arm64_relwithdebinfo" - }, - { - "binaryDir": "${sourceDir}/build/default", "cacheVariables": { - "CMAKE_BUILD_TYPE": "RelWithDebInfo", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", 
"CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_OSX_ARCHITECTURES": "arm64", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" + "MLAS_NO_ONNXRUNTIME": "ON" }, - "displayName": "macos arm64 relwithdebinfo asan", - "generator": "Unix Makefiles", - "name": "macos_arm64_relwithdebinfo_asan" + "environment": { + "CC": "gcc", + "CXX": "g++" + } }, { - "binaryDir": "${sourceDir}/build/default", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-ggdb3 -O0", - "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64" - }, + "name": "linux_gcc_release_no_ort", + "displayName": "linux gcc release no ort", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/default/no_ort", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" + "rhs": "Linux" }, - "displayName": "macos universal2 debug", - "generator": "Unix Makefiles", - "name": "macos_universal2_debug" - }, - { - "binaryDir": "${sourceDir}/build/default", "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fsanitize=address", + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now 
-Wl,-z,noexecstack", "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" + "MLAS_NO_ONNXRUNTIME": "ON" }, - "displayName": "macos universal2 debug asan", - "generator": "Unix Makefiles", - "name": "macos_universal2_debug_asan" + "environment": { + "CC": "gcc", + "CXX": "g++" + } }, { - "binaryDir": "${sourceDir}/build/default", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "MinSizeRel", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", - "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64" - }, + "name": "linux_gcc_relwithdebinfo", + "displayName": "linux gcc relwithdebinfo", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/default/default", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" + "rhs": "Linux" }, - "displayName": "macos universal2 minsizerel", - "generator": "Unix Makefiles", - "name": "macos_universal2_minsizerel" - }, - { - "binaryDir": "${sourceDir}/build/default", "cacheVariables": { - "CMAKE_BUILD_TYPE": "MinSizeRel", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", - "CMAKE_SHARED_LINKER_FLAGS_INIT": 
"-fsanitize=address" + "CMAKE_BUILD_TYPE": "RelWithDebInfo", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_CXX_STANDARD": "20" }, + "environment": { + "CC": "gcc", + "CXX": "g++" + } + }, + { + "name": "linux_gcc_relwithdebinfo_asan", + "displayName": "linux gcc relwithdebinfo asan", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/asan/default", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" + "rhs": "Linux" }, - "displayName": "macos universal2 minsizerel asan", - "generator": "Unix Makefiles", - "name": "macos_universal2_minsizerel_asan" - }, - { - "binaryDir": "${sourceDir}/build/default", "cacheVariables": { - "CMAKE_BUILD_TYPE": "Release", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", - "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64" + "CMAKE_BUILD_TYPE": "RelWithDebInfo", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + 
"CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "CMAKE_CXX_STANDARD": "20" }, + "environment": { + "CC": "gcc", + "CXX": "g++" + } + }, + { + "name": "linux_gcc_relwithdebinfo_asan_no_ort", + "displayName": "linux gcc relwithdebinfo asan no ort", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/asan/no_ort", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" + "rhs": "Linux" }, - "displayName": "macos universal2 release", - "generator": "Unix Makefiles", - "name": "macos_universal2_release" - }, - { - "binaryDir": "${sourceDir}/build/default", "cacheVariables": { - "CMAKE_BUILD_TYPE": "Release", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", + "CMAKE_BUILD_TYPE": "RelWithDebInfo", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack -fsanitize=address", "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address" + "MLAS_NO_ONNXRUNTIME": "ON" }, + "environment": { + "CC": "gcc", + "CXX": "g++" + } + }, + { + "name": 
"linux_gcc_relwithdebinfo_no_ort", + "displayName": "linux gcc relwithdebinfo no ort", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/default/no_ort", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" + "rhs": "Linux" }, - "displayName": "macos universal2 release asan", - "generator": "Unix Makefiles", - "name": "macos_universal2_release_asan" - }, - { - "binaryDir": "${sourceDir}/build/default", "cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-z,now -Wl,-z,noexecstack", "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", - "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" + "MLAS_NO_ONNXRUNTIME": "ON" }, - "displayName": "macos universal2 relwithdebinfo", - "generator": "Unix Makefiles", - "name": "macos_universal2_relwithdebinfo" + "environment": { + "CC": "gcc", + "CXX": "g++" + } }, { + "name": "macos_arm64_debug", + "displayName": "macos arm64 debug", + "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/default", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "RelWithDebInfo", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 
-fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address" - }, "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" + "rhs": "Darwin" }, - "displayName": "macos universal2 relwithdebinfo asan", - "generator": "Unix Makefiles", - "name": "macos_universal2_relwithdebinfo_asan" - }, - { - "binaryDir": "${sourceDir}/build/default", "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0", - "CMAKE_CXX_STANDARD": "20", + "CMAKE_OSX_ARCHITECTURES": "arm64", "CMAKE_C_FLAGS": "-ggdb3 -O0", - "CMAKE_OSX_ARCHITECTURES": "x86_64" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" - }, - "displayName": "macos x86_64 debug", - "generator": "Unix Makefiles", - "name": "macos_x86_64_debug" + "CMAKE_CXX_FLAGS": "-ggdb3 -O0", + "CMAKE_CXX_STANDARD": "20" + } }, { + "name": "macos_arm64_debug_asan", + "displayName": "macos arm64 debug asan", + "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/default", + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Darwin" + }, "cacheVariables": { "CMAKE_BUILD_TYPE": "Debug", - "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fsanitize=address", - "CMAKE_CXX_STANDARD": "20", + "CMAKE_OSX_ARCHITECTURES": "arm64", "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_OSX_ARCHITECTURES": "x86_64", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" - }, - "displayName": "macos x86_64 debug asan", - "generator": "Unix Makefiles", - "name": "macos_x86_64_debug_asan" + 
"CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_CXX_STANDARD": "20" + } }, { + "name": "macos_arm64_minsizerel", + "displayName": "macos arm64 minsizerel", + "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/default", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "MinSizeRel", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", - "CMAKE_OSX_ARCHITECTURES": "x86_64" - }, "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" + "rhs": "Darwin" }, - "displayName": "macos x86_64 minsizerel", - "generator": "Unix Makefiles", - "name": "macos_x86_64_minsizerel" + "cacheVariables": { + "CMAKE_BUILD_TYPE": "MinSizeRel", + "CMAKE_OSX_ARCHITECTURES": "arm64", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", + "CMAKE_CXX_STANDARD": "20" + } }, { + "name": "macos_arm64_minsizerel_asan", + "displayName": "macos arm64 minsizerel asan", + "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/default", + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Darwin" + }, "cacheVariables": { "CMAKE_BUILD_TYPE": "MinSizeRel", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", - "CMAKE_CXX_STANDARD": "20", + "CMAKE_OSX_ARCHITECTURES": "arm64", "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - 
"CMAKE_OSX_ARCHITECTURES": "x86_64", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" - }, - "displayName": "macos x86_64 minsizerel asan", - "generator": "Unix Makefiles", - "name": "macos_x86_64_minsizerel_asan" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_CXX_STANDARD": "20" + } }, { + "name": "macos_arm64_release", + "displayName": "macos arm64 release", + "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/default", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Release", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", - "CMAKE_OSX_ARCHITECTURES": "x86_64" - }, "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" + "rhs": "Darwin" }, - "displayName": "macos x86_64 release", - "generator": "Unix Makefiles", - "name": "macos_x86_64_release" + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_OSX_ARCHITECTURES": "arm64", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", + "CMAKE_CXX_STANDARD": "20" + } }, { + "name": "macos_arm64_release_asan", + "displayName": "macos arm64 release asan", + "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/default", + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Darwin" + }, "cacheVariables": { "CMAKE_BUILD_TYPE": "Release", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", - "CMAKE_CXX_STANDARD": "20", + "CMAKE_OSX_ARCHITECTURES": "arm64", "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe 
-fsanitize=address", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_OSX_ARCHITECTURES": "x86_64", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" - }, - "displayName": "macos x86_64 release asan", - "generator": "Unix Makefiles", - "name": "macos_x86_64_release_asan" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_CXX_STANDARD": "20" + } }, { + "name": "macos_arm64_relwithdebinfo", + "displayName": "macos arm64 relwithdebinfo", + "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/default", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "RelWithDebInfo", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", - "CMAKE_CXX_STANDARD": "20", - "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", - "CMAKE_OSX_ARCHITECTURES": "x86_64" - }, "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" + "rhs": "Darwin" }, - "displayName": "macos x86_64 relwithdebinfo", - "generator": "Unix Makefiles", - "name": "macos_x86_64_relwithdebinfo" + "cacheVariables": { + "CMAKE_BUILD_TYPE": "RelWithDebInfo", + "CMAKE_OSX_ARCHITECTURES": "arm64", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", + "CMAKE_CXX_STANDARD": "20" + } }, { + "name": "macos_arm64_relwithdebinfo_asan", + "displayName": "macos arm64 relwithdebinfo asan", + "generator": "Unix Makefiles", "binaryDir": "${sourceDir}/build/default", + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Darwin" + }, 
"cacheVariables": { "CMAKE_BUILD_TYPE": "RelWithDebInfo", - "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", - "CMAKE_CXX_STANDARD": "20", + "CMAKE_OSX_ARCHITECTURES": "arm64", "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", - "CMAKE_OSX_ARCHITECTURES": "x86_64", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address" - }, + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_CXX_STANDARD": "20" + } + }, + { + "name": "macos_universal2_debug", + "displayName": "macos universal2 debug", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/default", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Darwin", - "type": "equals" + "rhs": "Darwin" }, - "displayName": "macos x86_64 relwithdebinfo asan", - "generator": "Unix Makefiles", - "name": "macos_x86_64_relwithdebinfo_asan" + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", + "CMAKE_C_FLAGS": "-ggdb3 -O0", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0", + "CMAKE_CXX_STANDARD": "20" + } }, { - "architecture": "Win32", - "binaryDir": "${sourceDir}/build/debug/default", - "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - 
"CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" - }, + "name": "macos_universal2_debug_asan", + "displayName": "macos universal2 debug asan", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/default", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" + "rhs": "Darwin" }, - "displayName": "windows win32 debug", - "generator": "Visual Studio 17 2022", - "name": "windows_win32_debug" + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", + "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_CXX_STANDARD": "20" + } }, { - "architecture": "Win32", - "binaryDir": "${sourceDir}/build/debug/default", - "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" - }, + "name": "macos_universal2_minsizerel", + "displayName": "macos universal2 minsizerel", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/default", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" + "rhs": "Darwin" }, - "displayName": "windows win32 debug no_ort", - "generator": "Visual Studio 17 2022", - "name": "windows_win32_debug_no_ort" + "cacheVariables": { + "CMAKE_BUILD_TYPE": 
"MinSizeRel", + "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", + "CMAKE_CXX_STANDARD": "20" + } }, { - "architecture": "Win32", - "binaryDir": "${sourceDir}/build/debug/asan", - "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" - }, + "name": "macos_universal2_minsizerel_asan", + "displayName": "macos universal2 minsizerel asan", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/default", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" + "rhs": "Darwin" }, - "displayName": "windows win32 debug asan", - "generator": "Visual Studio 17 2022", - "name": "windows_win32_debug_asan" + "cacheVariables": { + "CMAKE_BUILD_TYPE": "MinSizeRel", + "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_CXX_STANDARD": "20" + } }, { - "architecture": "Win32", - "binaryDir": 
"${sourceDir}/build/debug/asan", - "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" - }, + "name": "macos_universal2_release", + "displayName": "macos universal2 release", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/default", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" + "rhs": "Darwin" }, - "displayName": "windows win32 debug asan no_ort", - "generator": "Visual Studio 17 2022", - "name": "windows_win32_debug_asan_no_ort" + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", + "CMAKE_CXX_STANDARD": "20" + } }, { - "architecture": "Win32", - "binaryDir": "${sourceDir}/build/minsizerel/default", + "name": "macos_universal2_release_asan", + "displayName": "macos universal2 release asan", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/default", + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Darwin" + }, "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 
/D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_CXX_STANDARD": "20" + } + }, + { + "name": "macos_universal2_relwithdebinfo", + "displayName": "macos universal2 relwithdebinfo", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/default", + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Darwin" }, + "cacheVariables": { + "CMAKE_BUILD_TYPE": "RelWithDebInfo", + "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", + "CMAKE_CXX_STANDARD": "20" + } + }, + { + "name": "macos_universal2_relwithdebinfo_asan", + "displayName": "macos universal2 relwithdebinfo asan", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/default", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" + "rhs": "Darwin" }, - "displayName": "windows win32 minsizerel", - "generator": "Visual Studio 17 2022", - "name": "windows_win32_minsizerel" + "cacheVariables": { + "CMAKE_BUILD_TYPE": "RelWithDebInfo", + "CMAKE_OSX_ARCHITECTURES": "arm64;x86_64", + "CMAKE_C_FLAGS": "-DNDEBUG 
-Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_CXX_STANDARD": "20" + } }, { - "architecture": "Win32", - "binaryDir": "${sourceDir}/build/minsizerel/default", - "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" - }, + "name": "macos_x86_64_debug", + "displayName": "macos x86 64 debug", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/default", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" + "rhs": "Darwin" }, - "displayName": "windows win32 minsizerel no_ort", - "generator": "Visual Studio 17 2022", - "name": "windows_win32_minsizerel_no_ort" + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_OSX_ARCHITECTURES": "x86_64", + "CMAKE_C_FLAGS": "-ggdb3 -O0", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0", + "CMAKE_CXX_STANDARD": "20" + } }, { - "architecture": "Win32", - "binaryDir": "${sourceDir}/build/minsizerel/asan", - "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", - 
"CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" - }, + "name": "macos_x86_64_debug_asan", + "displayName": "macos x86 64 debug asan", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/default", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" + "rhs": "Darwin" }, - "displayName": "windows win32 minsizerel asan", - "generator": "Visual Studio 17 2022", - "name": "windows_win32_minsizerel_asan" + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug", + "CMAKE_OSX_ARCHITECTURES": "x86_64", + "CMAKE_C_FLAGS": "-ggdb3 -O0 -fsanitize=address", + "CMAKE_CXX_FLAGS": "-ggdb3 -O0 -fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_CXX_STANDARD": "20" + } }, { - "architecture": "Win32", - "binaryDir": "${sourceDir}/build/minsizerel/asan", - "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" - }, + "name": "macos_x86_64_minsizerel", + "displayName": "macos x86 64 minsizerel", + "generator": 
"Unix Makefiles", + "binaryDir": "${sourceDir}/build/default", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" + "rhs": "Darwin" }, - "displayName": "windows win32 minsizerel asan no_ort", - "generator": "Visual Studio 17 2022", - "name": "windows_win32_minsizerel_asan_no_ort" + "cacheVariables": { + "CMAKE_BUILD_TYPE": "MinSizeRel", + "CMAKE_OSX_ARCHITECTURES": "x86_64", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3", + "CMAKE_CXX_STANDARD": "20" + } }, { - "architecture": "Win32", - "binaryDir": "${sourceDir}/build/release/default", - "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" - }, + "name": "macos_x86_64_minsizerel_asan", + "displayName": "macos x86 64 minsizerel asan", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/default", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" + "rhs": "Darwin" }, - "displayName": "windows win32 release", - "generator": "Visual Studio 17 2022", - "name": "windows_win32_release" + "cacheVariables": { + "CMAKE_BUILD_TYPE": "MinSizeRel", + "CMAKE_OSX_ARCHITECTURES": "x86_64", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS 
-fstack-protector-strong -Os -pipe -ggdb3 -fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_CXX_STANDARD": "20" + } }, { - "architecture": "Win32", - "binaryDir": "${sourceDir}/build/release/default", - "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" - }, + "name": "macos_x86_64_release", + "displayName": "macos x86 64 release", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/default", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" + "rhs": "Darwin" }, - "displayName": "windows win32 release no_ort", - "generator": "Visual Studio 17 2022", - "name": "windows_win32_release_no_ort" + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_OSX_ARCHITECTURES": "x86_64", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe", + "CMAKE_CXX_STANDARD": "20" + } }, { - "architecture": "Win32", - "binaryDir": "${sourceDir}/build/release/asan", - "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP 
/guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" - }, + "name": "macos_x86_64_release_asan", + "displayName": "macos x86 64 release asan", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/default", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" + "rhs": "Darwin" }, - "displayName": "windows win32 release asan", - "generator": "Visual Studio 17 2022", - "name": "windows_win32_release_asan" + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release", + "CMAKE_OSX_ARCHITECTURES": "x86_64", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_CXX_STANDARD": "20" + } }, { - "architecture": "Win32", - "binaryDir": "${sourceDir}/build/release/asan", - "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" - }, + "name": 
"macos_x86_64_relwithdebinfo", + "displayName": "macos x86 64 relwithdebinfo", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/default", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" + "rhs": "Darwin" }, - "displayName": "windows win32 release asan no_ort", - "generator": "Visual Studio 17 2022", - "name": "windows_win32_release_asan_no_ort" + "cacheVariables": { + "CMAKE_BUILD_TYPE": "RelWithDebInfo", + "CMAKE_OSX_ARCHITECTURES": "x86_64", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3", + "CMAKE_CXX_STANDARD": "20" + } }, { - "architecture": "Win32", - "binaryDir": "${sourceDir}/build/relwithdebinfo/default", - "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" - }, + "name": "macos_x86_64_relwithdebinfo_asan", + "displayName": "macos x86 64 relwithdebinfo asan", + "generator": "Unix Makefiles", + "binaryDir": "${sourceDir}/build/default", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" + "rhs": "Darwin" }, - "displayName": "windows win32 relwithdebinfo", - "generator": "Visual Studio 17 2022", - "name": "windows_win32_relwithdebinfo" + "cacheVariables": { + "CMAKE_BUILD_TYPE": "RelWithDebInfo", + "CMAKE_OSX_ARCHITECTURES": "x86_64", + "CMAKE_C_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS 
-fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", + "CMAKE_CXX_FLAGS": "-DNDEBUG -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -O3 -pipe -ggdb3 -fsanitize=address", + "CMAKE_EXE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "-fsanitize=address", + "CMAKE_CXX_STANDARD": "20" + } }, { - "architecture": "Win32", - "binaryDir": "${sourceDir}/build/relwithdebinfo/default", + "name": "windows_win32_debug", + "displayName": "windows win32 debug", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/debug/default", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "Win32", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows win32 relwithdebinfo no_ort", - "generator": "Visual Studio 17 2022", - "name": "windows_win32_relwithdebinfo_no_ort" + "rhs": "Windows" + } }, { - "architecture": "Win32", - "binaryDir": 
"${sourceDir}/build/relwithdebinfo/asan", + "name": "windows_win32_debug_asan", + "displayName": "windows win32 debug asan", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/debug/asan", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "Win32", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows win32 relwithdebinfo asan", - "generator": "Visual Studio 17 2022", - "name": "windows_win32_relwithdebinfo_asan" + "rhs": "Windows" + } }, { - "architecture": "Win32", - "binaryDir": "${sourceDir}/build/relwithdebinfo/asan", + "name": "windows_win32_debug_asan_no_ort", + "displayName": "windows win32 debug asan no ort", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/debug/asan", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_C_FLAGS": "/EHsc 
/Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", + "MLAS_NO_ONNXRUNTIME": "ON", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "Win32", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows win32 relwithdebinfo asan no_ort", - "generator": "Visual Studio 17 2022", - "name": "windows_win32_relwithdebinfo_asan_no_ort" + "rhs": "Windows" + } }, { - "architecture": "x64", + "name": "windows_win32_debug_no_ort", + "displayName": "windows win32 debug no ort", + "generator": "Visual Studio 17 2022", "binaryDir": "${sourceDir}/build/debug/default", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", + "MLAS_NO_ONNXRUNTIME": "ON", "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", 
"CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "Win32", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows x64 debug", - "generator": "Visual Studio 17 2022", - "name": "windows_x64_debug" + "rhs": "Windows" + } }, { - "architecture": "x64", - "binaryDir": "${sourceDir}/build/debug/default", + "name": "windows_win32_minsizerel", + "displayName": "windows win32 minsizerel", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/minsizerel/default", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "Win32", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows x64 debug no_ort", - "generator": "Visual Studio 17 2022", - "name": "windows_x64_debug_no_ort" + "rhs": "Windows" + } }, { - "architecture": "x64", - "binaryDir": "${sourceDir}/build/debug/asan", + "name": 
"windows_win32_minsizerel_asan", + "displayName": "windows win32 minsizerel asan", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/minsizerel/asan", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "Win32", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows x64 debug asan", - "generator": "Visual Studio 17 2022", - "name": "windows_x64_debug_asan" + "rhs": "Windows" + } }, { - "architecture": "x64", - "binaryDir": "${sourceDir}/build/debug/asan", + "name": "windows_win32_minsizerel_asan_no_ort", + "displayName": "windows win32 minsizerel asan no ort", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/minsizerel/asan", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 
/DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", + "MLAS_NO_ONNXRUNTIME": "ON", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "Win32", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows x64 debug asan no_ort", - "generator": "Visual Studio 17 2022", - "name": "windows_x64_debug_asan_no_ort" + "rhs": "Windows" + } }, { - "architecture": "x64", + "name": "windows_win32_minsizerel_no_ort", + "displayName": "windows win32 minsizerel no ort", + "generator": "Visual Studio 17 2022", "binaryDir": "${sourceDir}/build/minsizerel/default", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", + "MLAS_NO_ONNXRUNTIME": "ON", "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", 
"CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "Win32", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows x64 minsizerel", - "generator": "Visual Studio 17 2022", - "name": "windows_x64_minsizerel" + "rhs": "Windows" + } }, { - "architecture": "x64", - "binaryDir": "${sourceDir}/build/minsizerel/default", + "name": "windows_win32_release", + "displayName": "windows win32 release", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/release/default", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "Win32", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows x64 minsizerel no_ort", - "generator": "Visual Studio 17 2022", - "name": "windows_x64_minsizerel_no_ort" + "rhs": "Windows" + } }, { - "architecture": "x64", - "binaryDir": "${sourceDir}/build/minsizerel/asan", + "name": "windows_win32_release_asan", + 
"displayName": "windows win32 release asan", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/release/asan", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "Win32", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows x64 minsizerel asan", - "generator": "Visual Studio 17 2022", - "name": "windows_x64_minsizerel_asan" + "rhs": "Windows" + } }, { - "architecture": "x64", - "binaryDir": "${sourceDir}/build/minsizerel/asan", + "name": "windows_win32_release_asan_no_ort", + "displayName": "windows win32 release asan no ort", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/release/asan", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 
/D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", + "MLAS_NO_ONNXRUNTIME": "ON", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "Win32", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows x64 minsizerel asan no_ort", - "generator": "Visual Studio 17 2022", - "name": "windows_x64_minsizerel_asan_no_ort" + "rhs": "Windows" + } }, { - "architecture": "x64", + "name": "windows_win32_release_no_ort", + "displayName": "windows win32 release no ort", + "generator": "Visual Studio 17 2022", "binaryDir": "${sourceDir}/build/release/default", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", + "MLAS_NO_ONNXRUNTIME": "ON", "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", 
"CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "Win32", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows x64 release", - "generator": "Visual Studio 17 2022", - "name": "windows_x64_release" + "rhs": "Windows" + } }, { - "architecture": "x64", - "binaryDir": "${sourceDir}/build/release/default", + "name": "windows_win32_relwithdebinfo", + "displayName": "windows win32 relwithdebinfo", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/relwithdebinfo/default", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "Win32", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows x64 release no_ort", - "generator": "Visual Studio 17 2022", - "name": "windows_x64_release_no_ort" + "rhs": "Windows" + } }, { - "architecture": "x64", - "binaryDir": "${sourceDir}/build/release/asan", + "name": 
"windows_win32_relwithdebinfo_asan", + "displayName": "windows win32 relwithdebinfo asan", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/relwithdebinfo/asan", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "Win32", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows x64 release asan", - "generator": "Visual Studio 17 2022", - "name": "windows_x64_release_asan" + "rhs": "Windows" + } }, { - "architecture": "x64", - "binaryDir": "${sourceDir}/build/release/asan", + "name": "windows_win32_relwithdebinfo_asan_no_ort", + "displayName": "windows win32 relwithdebinfo asan no ort", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/relwithdebinfo/asan", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf 
/DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", + "MLAS_NO_ONNXRUNTIME": "ON", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "Win32", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows x64 release asan no_ort", - "generator": "Visual Studio 17 2022", - "name": "windows_x64_release_asan_no_ort" + "rhs": "Windows" + } }, { - "architecture": "x64", + "name": "windows_win32_relwithdebinfo_no_ort", + "displayName": "windows win32 relwithdebinfo no ort", + "generator": "Visual Studio 17 2022", "binaryDir": "${sourceDir}/build/relwithdebinfo/default", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", + "MLAS_NO_ONNXRUNTIME": "ON", "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", 
"CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "Win32", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows x64 relwithdebinfo", - "generator": "Visual Studio 17 2022", - "name": "windows_x64_relwithdebinfo" + "rhs": "Windows" + } }, { - "architecture": "x64", - "binaryDir": "${sourceDir}/build/relwithdebinfo/default", + "name": "windows_x64_debug", + "displayName": "windows x64 debug", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/debug/default", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "x64", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows x64 relwithdebinfo no_ort", - "generator": "Visual Studio 17 2022", - "name": "windows_x64_relwithdebinfo_no_ort" + "rhs": "Windows" + } }, { - "architecture": "x64", - "binaryDir": 
"${sourceDir}/build/relwithdebinfo/asan", + "name": "windows_x64_debug_asan", + "displayName": "windows x64 debug asan", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/debug/asan", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "x64", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows x64 relwithdebinfo asan", - "generator": "Visual Studio 17 2022", - "name": "windows_x64_relwithdebinfo_asan" + "rhs": "Windows" + } }, { - "architecture": "x64", - "binaryDir": "${sourceDir}/build/relwithdebinfo/asan", + "name": "windows_x64_debug_asan_no_ort", + "displayName": "windows x64 debug asan no ort", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/debug/asan", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf 
/DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", + "MLAS_NO_ONNXRUNTIME": "ON", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "x64", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows x64 relwithdebinfo asan no_ort", - "generator": "Visual Studio 17 2022", - "name": "windows_x64_relwithdebinfo_asan_no_ort" + "rhs": "Windows" + } }, { - "architecture": "ARM64", + "name": "windows_x64_debug_no_ort", + "displayName": "windows x64 debug no ort", + "generator": "Visual Studio 17 2022", "binaryDir": "${sourceDir}/build/debug/default", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", + "MLAS_NO_ONNXRUNTIME": "ON", "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile 
/DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "x64", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows arm64 debug", - "generator": "Visual Studio 17 2022", - "name": "windows_arm64_debug" + "rhs": "Windows" + } }, { - "architecture": "ARM64", - "binaryDir": "${sourceDir}/build/debug/default", + "name": "windows_x64_minsizerel", + "displayName": "windows x64 minsizerel", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/minsizerel/default", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "x64", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows arm64 debug no_ort", - "generator": "Visual Studio 17 2022", - "name": "windows_arm64_debug_no_ort" + "rhs": "Windows" + } }, { - "architecture": "ARM64", - "binaryDir": "${sourceDir}/build/debug/asan", + "name": "windows_x64_minsizerel_asan", + 
"displayName": "windows x64 minsizerel asan", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/minsizerel/asan", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "x64", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows arm64 debug asan", - "generator": "Visual Studio 17 2022", - "name": "windows_arm64_debug_asan" + "rhs": "Windows" + } }, { - "architecture": "ARM64", - "binaryDir": "${sourceDir}/build/debug/asan", + "name": "windows_x64_minsizerel_asan_no_ort", + "displayName": "windows x64 minsizerel asan no ort", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/minsizerel/asan", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 
/DNTDDI_VERSION=0x0A000000 /Ob0 /Od /RTC1 /fsanitize=address", + "MLAS_NO_ONNXRUNTIME": "ON", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "x64", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows arm64 debug asan no_ort", - "generator": "Visual Studio 17 2022", - "name": "windows_arm64_debug_asan_no_ort" + "rhs": "Windows" + } }, { - "architecture": "ARM64", + "name": "windows_x64_minsizerel_no_ort", + "displayName": "windows x64 minsizerel no ort", + "generator": "Visual Studio 17 2022", "binaryDir": "${sourceDir}/build/minsizerel/default", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", + "MLAS_NO_ONNXRUNTIME": "ON", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 
/Ob1 /DNDEBUG", + "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" + }, + "architecture": "x64", + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Windows" + } + }, + { + "name": "windows_x64_release", + "displayName": "windows x64 release", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/release/default", + "cacheVariables": { + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "x64", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows arm64 minsizerel", - "generator": "Visual Studio 17 2022", - "name": "windows_arm64_minsizerel" + "rhs": "Windows" + } }, { - "architecture": "ARM64", - "binaryDir": "${sourceDir}/build/minsizerel/default", + "name": "windows_x64_release_asan", + "displayName": "windows x64 release asan", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/release/asan", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf 
/DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "x64", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows arm64 minsizerel no_ort", - "generator": "Visual Studio 17 2022", - "name": "windows_arm64_minsizerel_no_ort" + "rhs": "Windows" + } }, { - "architecture": "ARM64", - "binaryDir": "${sourceDir}/build/minsizerel/asan", + "name": "windows_x64_release_asan_no_ort", + "displayName": "windows x64 release asan no ort", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/release/asan", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", + "MLAS_NO_ONNXRUNTIME": "ON", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", 
"CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "x64", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows arm64 minsizerel asan", - "generator": "Visual Studio 17 2022", - "name": "windows_arm64_minsizerel_asan" + "rhs": "Windows" + } }, { - "architecture": "ARM64", - "binaryDir": "${sourceDir}/build/minsizerel/asan", + "name": "windows_x64_release_no_ort", + "displayName": "windows x64 release no ort", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/release/default", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O1 /Ob1 /DNDEBUG /fsanitize=address", + "MLAS_NO_ONNXRUNTIME": "ON", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "x64", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows arm64 minsizerel asan no_ort", - "generator": "Visual 
Studio 17 2022", - "name": "windows_arm64_minsizerel_asan_no_ort" + "rhs": "Windows" + } }, { - "architecture": "ARM64", - "binaryDir": "${sourceDir}/build/release/default", + "name": "windows_x64_relwithdebinfo", + "displayName": "windows x64 relwithdebinfo", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/relwithdebinfo/default", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "x64", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows arm64 release", - "generator": "Visual Studio 17 2022", - "name": "windows_arm64_release" + "rhs": "Windows" + } }, { - "architecture": "ARM64", - "binaryDir": "${sourceDir}/build/release/default", + "name": "windows_x64_relwithdebinfo_asan", + "displayName": "windows x64 relwithdebinfo asan", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/relwithdebinfo/asan", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 
/DNDEBUG", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" + "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "x64", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows arm64 release no_ort", - "generator": "Visual Studio 17 2022", - "name": "windows_arm64_release_no_ort" + "rhs": "Windows" + } }, { - "architecture": "ARM64", - "binaryDir": "${sourceDir}/build/release/asan", + "name": "windows_x64_relwithdebinfo_asan_no_ort", + "displayName": "windows x64 relwithdebinfo asan no ort", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/relwithdebinfo/asan", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", + "MLAS_NO_ONNXRUNTIME": "ON", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 
/DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "x64", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows arm64 release asan", - "generator": "Visual Studio 17 2022", - "name": "windows_arm64_release_asan" + "rhs": "Windows" + } }, { - "architecture": "ARM64", - "binaryDir": "${sourceDir}/build/release/asan", + "name": "windows_x64_relwithdebinfo_no_ort", + "displayName": "windows x64 relwithdebinfo no ort", + "generator": "Visual Studio 17 2022", + "binaryDir": "${sourceDir}/build/relwithdebinfo/default", "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob2 /DNDEBUG /fsanitize=address", + "MLAS_NO_ONNXRUNTIME": "ON", + "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", + "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" + 
"CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" }, + "architecture": "x64", "condition": { + "type": "equals", "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows arm64 release asan no_ort", - "generator": "Visual Studio 17 2022", - "name": "windows_arm64_release_asan_no_ort" + "rhs": "Windows" + } + } + ], + "buildPresets": [ + { + "name": "linux_clang_debug", + "configurePreset": "linux_clang_debug", + "configuration": "Debug" + }, + { + "name": "linux_clang_debug_asan", + "configurePreset": "linux_clang_debug_asan", + "configuration": "Debug" + }, + { + "name": "linux_clang_debug_asan_no_ort", + "configurePreset": "linux_clang_debug_asan_no_ort", + "configuration": "Debug" + }, + { + "name": "linux_clang_debug_cov", + "configurePreset": "linux_clang_debug_cov", + "configuration": "Debug" + }, + { + "name": "linux_clang_debug_cov_no_ort", + "configurePreset": "linux_clang_debug_cov_no_ort", + "configuration": "Debug" + }, + { + "name": "linux_clang_debug_no_ort", + "configurePreset": "linux_clang_debug_no_ort", + "configuration": "Debug" + }, + { + "name": "linux_gcc_debug", + "configurePreset": "linux_gcc_debug", + "configuration": "Debug" + }, + { + "name": "linux_gcc_debug_asan", + "configurePreset": "linux_gcc_debug_asan", + "configuration": "Debug" + }, + { + "name": "linux_gcc_debug_asan_no_ort", + "configurePreset": "linux_gcc_debug_asan_no_ort", + "configuration": "Debug" + }, + { + "name": "linux_gcc_debug_no_ort", + "configurePreset": "linux_gcc_debug_no_ort", + "configuration": "Debug" + }, + { + "name": "linux_gcc_minsizerel", + "configurePreset": "linux_gcc_minsizerel", + "configuration": "MinSizeRel" + }, + { + "name": "linux_gcc_minsizerel_asan", + "configurePreset": "linux_gcc_minsizerel_asan", + "configuration": "MinSizeRel" + }, + { + "name": "linux_gcc_minsizerel_asan_no_ort", + "configurePreset": "linux_gcc_minsizerel_asan_no_ort", + "configuration": "MinSizeRel" + }, + { + "name": 
"linux_gcc_minsizerel_no_ort", + "configurePreset": "linux_gcc_minsizerel_no_ort", + "configuration": "MinSizeRel" + }, + { + "name": "linux_gcc_release", + "configurePreset": "linux_gcc_release", + "configuration": "Release" + }, + { + "name": "linux_gcc_release_asan", + "configurePreset": "linux_gcc_release_asan", + "configuration": "Release" + }, + { + "name": "linux_gcc_release_asan_no_ort", + "configurePreset": "linux_gcc_release_asan_no_ort", + "configuration": "Release" + }, + { + "name": "linux_gcc_release_no_ort", + "configurePreset": "linux_gcc_release_no_ort", + "configuration": "Release" + }, + { + "name": "linux_gcc_relwithdebinfo", + "configurePreset": "linux_gcc_relwithdebinfo", + "configuration": "RelWithDebInfo" + }, + { + "name": "linux_gcc_relwithdebinfo_asan", + "configurePreset": "linux_gcc_relwithdebinfo_asan", + "configuration": "RelWithDebInfo" + }, + { + "name": "linux_gcc_relwithdebinfo_asan_no_ort", + "configurePreset": "linux_gcc_relwithdebinfo_asan_no_ort", + "configuration": "RelWithDebInfo" + }, + { + "name": "linux_gcc_relwithdebinfo_no_ort", + "configurePreset": "linux_gcc_relwithdebinfo_no_ort", + "configuration": "RelWithDebInfo" + }, + { + "name": "macos_arm64_debug", + "configurePreset": "macos_arm64_debug", + "configuration": "Debug" + }, + { + "name": "macos_arm64_debug_asan", + "configurePreset": "macos_arm64_debug_asan", + "configuration": "Debug" + }, + { + "name": "macos_arm64_minsizerel", + "configurePreset": "macos_arm64_minsizerel", + "configuration": "MinSizeRel" + }, + { + "name": "macos_arm64_minsizerel_asan", + "configurePreset": "macos_arm64_minsizerel_asan", + "configuration": "MinSizeRel" + }, + { + "name": "macos_arm64_release", + "configurePreset": "macos_arm64_release", + "configuration": "Release" + }, + { + "name": "macos_arm64_release_asan", + "configurePreset": "macos_arm64_release_asan", + "configuration": "Release" + }, + { + "name": "macos_arm64_relwithdebinfo", + "configurePreset": 
"macos_arm64_relwithdebinfo", + "configuration": "RelWithDebInfo" + }, + { + "name": "macos_arm64_relwithdebinfo_asan", + "configurePreset": "macos_arm64_relwithdebinfo_asan", + "configuration": "RelWithDebInfo" + }, + { + "name": "macos_universal2_debug", + "configurePreset": "macos_universal2_debug", + "configuration": "Debug" + }, + { + "name": "macos_universal2_debug_asan", + "configurePreset": "macos_universal2_debug_asan", + "configuration": "Debug" + }, + { + "name": "macos_universal2_minsizerel", + "configurePreset": "macos_universal2_minsizerel", + "configuration": "MinSizeRel" + }, + { + "name": "macos_universal2_minsizerel_asan", + "configurePreset": "macos_universal2_minsizerel_asan", + "configuration": "MinSizeRel" + }, + { + "name": "macos_universal2_release", + "configurePreset": "macos_universal2_release", + "configuration": "Release" + }, + { + "name": "macos_universal2_release_asan", + "configurePreset": "macos_universal2_release_asan", + "configuration": "Release" + }, + { + "name": "macos_universal2_relwithdebinfo", + "configurePreset": "macos_universal2_relwithdebinfo", + "configuration": "RelWithDebInfo" + }, + { + "name": "macos_universal2_relwithdebinfo_asan", + "configurePreset": "macos_universal2_relwithdebinfo_asan", + "configuration": "RelWithDebInfo" + }, + { + "name": "macos_x86_64_debug", + "configurePreset": "macos_x86_64_debug", + "configuration": "Debug" + }, + { + "name": "macos_x86_64_debug_asan", + "configurePreset": "macos_x86_64_debug_asan", + "configuration": "Debug" + }, + { + "name": "macos_x86_64_minsizerel", + "configurePreset": "macos_x86_64_minsizerel", + "configuration": "MinSizeRel" + }, + { + "name": "macos_x86_64_minsizerel_asan", + "configurePreset": "macos_x86_64_minsizerel_asan", + "configuration": "MinSizeRel" + }, + { + "name": "macos_x86_64_release", + "configurePreset": "macos_x86_64_release", + "configuration": "Release" + }, + { + "name": "macos_x86_64_release_asan", + "configurePreset": 
"macos_x86_64_release_asan", + "configuration": "Release" + }, + { + "name": "macos_x86_64_relwithdebinfo", + "configurePreset": "macos_x86_64_relwithdebinfo", + "configuration": "RelWithDebInfo" + }, + { + "name": "macos_x86_64_relwithdebinfo_asan", + "configurePreset": "macos_x86_64_relwithdebinfo_asan", + "configuration": "RelWithDebInfo" + }, + { + "name": "windows_win32_debug", + "configurePreset": "windows_win32_debug", + "configuration": "Debug" + }, + { + "name": "windows_win32_debug_asan", + "configurePreset": "windows_win32_debug_asan", + "configuration": "Debug" + }, + { + "name": "windows_win32_debug_asan_no_ort", + "configurePreset": "windows_win32_debug_asan_no_ort", + "configuration": "Debug" + }, + { + "name": "windows_win32_debug_no_ort", + "configurePreset": "windows_win32_debug_no_ort", + "configuration": "Debug" }, { - "architecture": "ARM64", - "binaryDir": "${sourceDir}/build/relwithdebinfo/default", - "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows arm64 relwithdebinfo", - "generator": "Visual Studio 17 2022", - "name": "windows_arm64_relwithdebinfo" + "name": "windows_win32_minsizerel", + "configurePreset": "windows_win32_minsizerel", + "configuration": "MinSizeRel" }, { - "architecture": "ARM64", - "binaryDir": "${sourceDir}/build/relwithdebinfo/default", - "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf 
/DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows arm64 relwithdebinfo no_ort", - "generator": "Visual Studio 17 2022", - "name": "windows_arm64_relwithdebinfo_no_ort" + "name": "windows_win32_minsizerel_asan", + "configurePreset": "windows_win32_minsizerel_asan", + "configuration": "MinSizeRel" }, { - "architecture": "ARM64", - "binaryDir": "${sourceDir}/build/relwithdebinfo/asan", - "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows arm64 relwithdebinfo asan", - "generator": "Visual Studio 17 2022", - "name": "windows_arm64_relwithdebinfo_asan" + "name": "windows_win32_minsizerel_asan_no_ort", + "configurePreset": "windows_win32_minsizerel_asan_no_ort", + "configuration": "MinSizeRel" }, { - "architecture": "ARM64", - "binaryDir": 
"${sourceDir}/build/relwithdebinfo/asan", - "cacheVariables": { - "CMAKE_CXX_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_C_FLAGS": "/EHsc /Qspectre /MP /guard:cf /DWIN32 /D_WINDOWS /DWINAPI_FAMILY=100 /DWINVER=0x0A00 /D_WIN32_WINNT=0x0A00 /DNTDDI_VERSION=0x0A000000 /O2 /Ob1 /DNDEBUG /fsanitize=address", - "CMAKE_EXE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_MODULE_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "CMAKE_SHARED_LINKER_FLAGS_INIT": "/profile /DYNAMICBASE", - "MLAS_NO_ONNXRUNTIME": "ON" - }, - "condition": { - "lhs": "${hostSystemName}", - "rhs": "Windows", - "type": "equals" - }, - "displayName": "windows arm64 relwithdebinfo asan no_ort", - "generator": "Visual Studio 17 2022", - "name": "windows_arm64_relwithdebinfo_asan_no_ort" - } - ], - "testPresets": [ + "name": "windows_win32_minsizerel_no_ort", + "configurePreset": "windows_win32_minsizerel_no_ort", + "configuration": "MinSizeRel" + }, { - "configuration": "Debug", - "configurePreset": "linux_clang_debug", - "name": "linux_clang_debug" + "name": "windows_win32_release", + "configurePreset": "windows_win32_release", + "configuration": "Release" }, { - "configuration": "Debug", - "configurePreset": "linux_clang_debug_asan", - "name": "linux_clang_debug_asan" + "name": "windows_win32_release_asan", + "configurePreset": "windows_win32_release_asan", + "configuration": "Release" }, { - "configuration": "Debug", - "configurePreset": "linux_clang_debug_asan_no_ort", - "name": "linux_clang_debug_asan_no_ort" + "name": "windows_win32_release_asan_no_ort", + "configurePreset": "windows_win32_release_asan_no_ort", + "configuration": "Release" }, { - "configuration": "Debug", - "configurePreset": "linux_clang_debug_cov", - "name": "linux_clang_debug_cov" + "name": "windows_win32_release_no_ort", + "configurePreset": "windows_win32_release_no_ort", + 
"configuration": "Release" }, { - "configuration": "Debug", - "configurePreset": "linux_clang_debug_cov_no_ort", - "name": "linux_clang_debug_cov_no_ort" + "name": "windows_win32_relwithdebinfo", + "configurePreset": "windows_win32_relwithdebinfo", + "configuration": "RelWithDebInfo" }, { - "configuration": "Debug", - "configurePreset": "linux_clang_debug_no_ort", - "name": "linux_clang_debug_no_ort" + "name": "windows_win32_relwithdebinfo_asan", + "configurePreset": "windows_win32_relwithdebinfo_asan", + "configuration": "RelWithDebInfo" }, { - "configuration": "Debug", - "configurePreset": "linux_gcc_debug", - "name": "linux_gcc_debug" + "name": "windows_win32_relwithdebinfo_asan_no_ort", + "configurePreset": "windows_win32_relwithdebinfo_asan_no_ort", + "configuration": "RelWithDebInfo" }, { - "configuration": "Debug", - "configurePreset": "linux_gcc_debug_asan", - "name": "linux_gcc_debug_asan" + "name": "windows_win32_relwithdebinfo_no_ort", + "configurePreset": "windows_win32_relwithdebinfo_no_ort", + "configuration": "RelWithDebInfo" }, { - "configuration": "Debug", - "configurePreset": "linux_gcc_debug_asan_no_ort", - "name": "linux_gcc_debug_asan_no_ort" + "name": "windows_x64_debug", + "configurePreset": "windows_x64_debug", + "configuration": "Debug" }, { - "configuration": "Debug", - "configurePreset": "linux_gcc_debug_no_ort", - "name": "linux_gcc_debug_no_ort" + "name": "windows_x64_debug_asan", + "configurePreset": "windows_x64_debug_asan", + "configuration": "Debug" }, { - "configuration": "MinSizeRel", - "configurePreset": "linux_gcc_minsizerel", - "name": "linux_gcc_minsizerel" + "name": "windows_x64_debug_asan_no_ort", + "configurePreset": "windows_x64_debug_asan_no_ort", + "configuration": "Debug" }, { - "configuration": "MinSizeRel", - "configurePreset": "linux_gcc_minsizerel_asan", - "name": "linux_gcc_minsizerel_asan" + "name": "windows_x64_debug_no_ort", + "configurePreset": "windows_x64_debug_no_ort", + "configuration": "Debug" }, { - 
"configuration": "MinSizeRel", - "configurePreset": "linux_gcc_minsizerel_asan_no_ort", - "name": "linux_gcc_minsizerel_asan_no_ort" + "name": "windows_x64_minsizerel", + "configurePreset": "windows_x64_minsizerel", + "configuration": "MinSizeRel" }, { - "configuration": "MinSizeRel", - "configurePreset": "linux_gcc_minsizerel_no_ort", - "name": "linux_gcc_minsizerel_no_ort" + "name": "windows_x64_minsizerel_asan", + "configurePreset": "windows_x64_minsizerel_asan", + "configuration": "MinSizeRel" }, { - "configuration": "Release", - "configurePreset": "linux_gcc_release", - "name": "linux_gcc_release" + "name": "windows_x64_minsizerel_asan_no_ort", + "configurePreset": "windows_x64_minsizerel_asan_no_ort", + "configuration": "MinSizeRel" }, { - "configuration": "Release", - "configurePreset": "linux_gcc_release_asan", - "name": "linux_gcc_release_asan" + "name": "windows_x64_minsizerel_no_ort", + "configurePreset": "windows_x64_minsizerel_no_ort", + "configuration": "MinSizeRel" }, { - "configuration": "Release", - "configurePreset": "linux_gcc_release_asan_no_ort", - "name": "linux_gcc_release_asan_no_ort" + "name": "windows_x64_release", + "configurePreset": "windows_x64_release", + "configuration": "Release" }, { - "configuration": "Release", - "configurePreset": "linux_gcc_release_no_ort", - "name": "linux_gcc_release_no_ort" + "name": "windows_x64_release_asan", + "configurePreset": "windows_x64_release_asan", + "configuration": "Release" }, { - "configuration": "RelWithDebInfo", - "configurePreset": "linux_gcc_relwithdebinfo", - "name": "linux_gcc_relwithdebinfo" + "name": "windows_x64_release_asan_no_ort", + "configurePreset": "windows_x64_release_asan_no_ort", + "configuration": "Release" }, { - "configuration": "RelWithDebInfo", - "configurePreset": "linux_gcc_relwithdebinfo_asan", - "name": "linux_gcc_relwithdebinfo_asan" + "name": "windows_x64_release_no_ort", + "configurePreset": "windows_x64_release_no_ort", + "configuration": "Release" }, { - 
"configuration": "RelWithDebInfo", - "configurePreset": "linux_gcc_relwithdebinfo_asan_no_ort", - "name": "linux_gcc_relwithdebinfo_asan_no_ort" + "name": "windows_x64_relwithdebinfo", + "configurePreset": "windows_x64_relwithdebinfo", + "configuration": "RelWithDebInfo" }, { - "configuration": "RelWithDebInfo", - "configurePreset": "linux_gcc_relwithdebinfo_no_ort", - "name": "linux_gcc_relwithdebinfo_no_ort" + "name": "windows_x64_relwithdebinfo_asan", + "configurePreset": "windows_x64_relwithdebinfo_asan", + "configuration": "RelWithDebInfo" + }, + { + "name": "windows_x64_relwithdebinfo_asan_no_ort", + "configurePreset": "windows_x64_relwithdebinfo_asan_no_ort", + "configuration": "RelWithDebInfo" }, { + "name": "windows_x64_relwithdebinfo_no_ort", + "configurePreset": "windows_x64_relwithdebinfo_no_ort", + "configuration": "RelWithDebInfo" + } + ], + "testPresets": [ + { + "name": "linux_clang_debug", "configuration": "Debug", - "configurePreset": "macos_arm64_debug", - "name": "macos_arm64_debug" + "configurePreset": "linux_clang_debug" }, { + "name": "linux_clang_debug_asan", "configuration": "Debug", - "configurePreset": "macos_arm64_debug_asan", - "name": "macos_arm64_debug_asan" + "configurePreset": "linux_clang_debug_asan" }, { - "configuration": "MinSizeRel", - "configurePreset": "macos_arm64_minsizerel", - "name": "macos_arm64_minsizerel" + "name": "linux_clang_debug_asan_no_ort", + "configuration": "Debug", + "configurePreset": "linux_clang_debug_asan_no_ort" }, { - "configuration": "MinSizeRel", - "configurePreset": "macos_arm64_minsizerel_asan", - "name": "macos_arm64_minsizerel_asan" + "name": "linux_clang_debug_cov", + "configuration": "Debug", + "configurePreset": "linux_clang_debug_cov" }, { - "configuration": "Release", - "configurePreset": "macos_arm64_release", - "name": "macos_arm64_release" + "name": "linux_clang_debug_cov_no_ort", + "configuration": "Debug", + "configurePreset": "linux_clang_debug_cov_no_ort" }, { - "configuration": 
"Release", - "configurePreset": "macos_arm64_release_asan", - "name": "macos_arm64_release_asan" + "name": "linux_clang_debug_no_ort", + "configuration": "Debug", + "configurePreset": "linux_clang_debug_no_ort" }, { - "configuration": "RelWithDebInfo", - "configurePreset": "macos_arm64_relwithdebinfo", - "name": "macos_arm64_relwithdebinfo" + "name": "linux_gcc_debug", + "configuration": "Debug", + "configurePreset": "linux_gcc_debug" }, { - "configuration": "RelWithDebInfo", - "configurePreset": "macos_arm64_relwithdebinfo_asan", - "name": "macos_arm64_relwithdebinfo_asan" + "name": "linux_gcc_debug_asan", + "configuration": "Debug", + "configurePreset": "linux_gcc_debug_asan" }, { + "name": "linux_gcc_debug_asan_no_ort", "configuration": "Debug", - "configurePreset": "macos_universal2_debug", - "name": "macos_universal2_debug" + "configurePreset": "linux_gcc_debug_asan_no_ort" }, { + "name": "linux_gcc_debug_no_ort", "configuration": "Debug", - "configurePreset": "macos_universal2_debug_asan", - "name": "macos_universal2_debug_asan" + "configurePreset": "linux_gcc_debug_no_ort" }, { + "name": "linux_gcc_minsizerel", "configuration": "MinSizeRel", - "configurePreset": "macos_universal2_minsizerel", - "name": "macos_universal2_minsizerel" + "configurePreset": "linux_gcc_minsizerel" }, { + "name": "linux_gcc_minsizerel_asan", "configuration": "MinSizeRel", - "configurePreset": "macos_universal2_minsizerel_asan", - "name": "macos_universal2_minsizerel_asan" + "configurePreset": "linux_gcc_minsizerel_asan" + }, + { + "name": "linux_gcc_minsizerel_asan_no_ort", + "configuration": "MinSizeRel", + "configurePreset": "linux_gcc_minsizerel_asan_no_ort" }, { + "name": "linux_gcc_minsizerel_no_ort", + "configuration": "MinSizeRel", + "configurePreset": "linux_gcc_minsizerel_no_ort" + }, + { + "name": "linux_gcc_release", "configuration": "Release", - "configurePreset": "macos_universal2_release", - "name": "macos_universal2_release" + "configurePreset": "linux_gcc_release" 
}, { + "name": "linux_gcc_release_asan", "configuration": "Release", - "configurePreset": "macos_universal2_release_asan", - "name": "macos_universal2_release_asan" + "configurePreset": "linux_gcc_release_asan" + }, + { + "name": "linux_gcc_release_asan_no_ort", + "configuration": "Release", + "configurePreset": "linux_gcc_release_asan_no_ort" }, { + "name": "linux_gcc_release_no_ort", + "configuration": "Release", + "configurePreset": "linux_gcc_release_no_ort" + }, + { + "name": "linux_gcc_relwithdebinfo", "configuration": "RelWithDebInfo", - "configurePreset": "macos_universal2_relwithdebinfo", - "name": "macos_universal2_relwithdebinfo" + "configurePreset": "linux_gcc_relwithdebinfo" }, { + "name": "linux_gcc_relwithdebinfo_asan", "configuration": "RelWithDebInfo", - "configurePreset": "macos_universal2_relwithdebinfo_asan", - "name": "macos_universal2_relwithdebinfo_asan" + "configurePreset": "linux_gcc_relwithdebinfo_asan" + }, + { + "name": "linux_gcc_relwithdebinfo_asan_no_ort", + "configuration": "RelWithDebInfo", + "configurePreset": "linux_gcc_relwithdebinfo_asan_no_ort" }, { + "name": "linux_gcc_relwithdebinfo_no_ort", + "configuration": "RelWithDebInfo", + "configurePreset": "linux_gcc_relwithdebinfo_no_ort" + }, + { + "name": "macos_arm64_debug", "configuration": "Debug", - "configurePreset": "macos_x86_64_debug", - "name": "macos_x86_64_debug" + "configurePreset": "macos_arm64_debug" }, { + "name": "macos_arm64_debug_asan", "configuration": "Debug", - "configurePreset": "macos_x86_64_debug_asan", - "name": "macos_x86_64_debug_asan" + "configurePreset": "macos_arm64_debug_asan" }, { + "name": "macos_arm64_minsizerel", "configuration": "MinSizeRel", - "configurePreset": "macos_x86_64_minsizerel", - "name": "macos_x86_64_minsizerel" + "configurePreset": "macos_arm64_minsizerel" }, { + "name": "macos_arm64_minsizerel_asan", "configuration": "MinSizeRel", - "configurePreset": "macos_x86_64_minsizerel_asan", - "name": "macos_x86_64_minsizerel_asan" + 
"configurePreset": "macos_arm64_minsizerel_asan" }, { + "name": "macos_arm64_release", "configuration": "Release", - "configurePreset": "macos_x86_64_release", - "name": "macos_x86_64_release" + "configurePreset": "macos_arm64_release" }, { + "name": "macos_arm64_release_asan", "configuration": "Release", - "configurePreset": "macos_x86_64_release_asan", - "name": "macos_x86_64_release_asan" + "configurePreset": "macos_arm64_release_asan" }, { + "name": "macos_arm64_relwithdebinfo", "configuration": "RelWithDebInfo", - "configurePreset": "macos_x86_64_relwithdebinfo", - "name": "macos_x86_64_relwithdebinfo" + "configurePreset": "macos_arm64_relwithdebinfo" }, { + "name": "macos_arm64_relwithdebinfo_asan", "configuration": "RelWithDebInfo", - "configurePreset": "macos_x86_64_relwithdebinfo_asan", - "name": "macos_x86_64_relwithdebinfo_asan" + "configurePreset": "macos_arm64_relwithdebinfo_asan" }, { + "name": "macos_universal2_debug", "configuration": "Debug", - "configurePreset": "windows_arm64_debug", - "name": "windows_arm64_debug", - "output": { - "outputOnFailure": true - } + "configurePreset": "macos_universal2_debug" }, { + "name": "macos_universal2_debug_asan", "configuration": "Debug", - "configurePreset": "windows_arm64_debug_asan", - "name": "windows_arm64_debug_asan", - "output": { - "outputOnFailure": true - } + "configurePreset": "macos_universal2_debug_asan" }, { - "configuration": "Debug", - "configurePreset": "windows_arm64_debug_asan_no_ort", - "name": "windows_arm64_debug_asan_no_ort", - "output": { - "outputOnFailure": true - } + "name": "macos_universal2_minsizerel", + "configuration": "MinSizeRel", + "configurePreset": "macos_universal2_minsizerel" }, { - "configuration": "Debug", - "configurePreset": "windows_arm64_debug_no_ort", - "name": "windows_arm64_debug_no_ort", - "output": { - "outputOnFailure": true - } + "name": "macos_universal2_minsizerel_asan", + "configuration": "MinSizeRel", + "configurePreset": 
"macos_universal2_minsizerel_asan" }, { - "configuration": "MinSizeRel", - "configurePreset": "windows_arm64_minsizerel", - "name": "windows_arm64_minsizerel", - "output": { - "outputOnFailure": true - } + "name": "macos_universal2_release", + "configuration": "Release", + "configurePreset": "macos_universal2_release" }, { - "configuration": "MinSizeRel", - "configurePreset": "windows_arm64_minsizerel_asan", - "name": "windows_arm64_minsizerel_asan", - "output": { - "outputOnFailure": true - } + "name": "macos_universal2_release_asan", + "configuration": "Release", + "configurePreset": "macos_universal2_release_asan" }, { - "configuration": "MinSizeRel", - "configurePreset": "windows_arm64_minsizerel_asan_no_ort", - "name": "windows_arm64_minsizerel_asan_no_ort", - "output": { - "outputOnFailure": true - } + "name": "macos_universal2_relwithdebinfo", + "configuration": "RelWithDebInfo", + "configurePreset": "macos_universal2_relwithdebinfo" }, { - "configuration": "MinSizeRel", - "configurePreset": "windows_arm64_minsizerel_no_ort", - "name": "windows_arm64_minsizerel_no_ort", - "output": { - "outputOnFailure": true - } + "name": "macos_universal2_relwithdebinfo_asan", + "configuration": "RelWithDebInfo", + "configurePreset": "macos_universal2_relwithdebinfo_asan" }, { - "configuration": "Release", - "configurePreset": "windows_arm64_release", - "name": "windows_arm64_release", - "output": { - "outputOnFailure": true - } + "name": "macos_x86_64_debug", + "configuration": "Debug", + "configurePreset": "macos_x86_64_debug" }, { - "configuration": "Release", - "configurePreset": "windows_arm64_release_asan", - "name": "windows_arm64_release_asan", - "output": { - "outputOnFailure": true - } + "name": "macos_x86_64_debug_asan", + "configuration": "Debug", + "configurePreset": "macos_x86_64_debug_asan" }, { - "configuration": "Release", - "configurePreset": "windows_arm64_release_asan_no_ort", - "name": "windows_arm64_release_asan_no_ort", - "output": { - 
"outputOnFailure": true - } + "name": "macos_x86_64_minsizerel", + "configuration": "MinSizeRel", + "configurePreset": "macos_x86_64_minsizerel" }, { - "configuration": "Release", - "configurePreset": "windows_arm64_release_no_ort", - "name": "windows_arm64_release_no_ort", - "output": { - "outputOnFailure": true - } + "name": "macos_x86_64_minsizerel_asan", + "configuration": "MinSizeRel", + "configurePreset": "macos_x86_64_minsizerel_asan" }, { - "configuration": "RelWithDebInfo", - "configurePreset": "windows_arm64_relwithdebinfo", - "name": "windows_arm64_relwithdebinfo", - "output": { - "outputOnFailure": true - } + "name": "macos_x86_64_release", + "configuration": "Release", + "configurePreset": "macos_x86_64_release" }, { - "configuration": "RelWithDebInfo", - "configurePreset": "windows_arm64_relwithdebinfo_asan", - "name": "windows_arm64_relwithdebinfo_asan", - "output": { - "outputOnFailure": true - } + "name": "macos_x86_64_release_asan", + "configuration": "Release", + "configurePreset": "macos_x86_64_release_asan" }, { + "name": "macos_x86_64_relwithdebinfo", "configuration": "RelWithDebInfo", - "configurePreset": "windows_arm64_relwithdebinfo_asan_no_ort", - "name": "windows_arm64_relwithdebinfo_asan_no_ort", - "output": { - "outputOnFailure": true - } + "configurePreset": "macos_x86_64_relwithdebinfo" }, { + "name": "macos_x86_64_relwithdebinfo_asan", "configuration": "RelWithDebInfo", - "configurePreset": "windows_arm64_relwithdebinfo_no_ort", - "name": "windows_arm64_relwithdebinfo_no_ort", - "output": { - "outputOnFailure": true - } + "configurePreset": "macos_x86_64_relwithdebinfo_asan" }, { + "name": "windows_win32_debug", "configuration": "Debug", "configurePreset": "windows_win32_debug", - "name": "windows_win32_debug", "output": { "outputOnFailure": true } }, { + "name": "windows_win32_debug_asan", "configuration": "Debug", "configurePreset": "windows_win32_debug_asan", - "name": "windows_win32_debug_asan", "output": { "outputOnFailure": 
true } }, { + "name": "windows_win32_debug_asan_no_ort", "configuration": "Debug", "configurePreset": "windows_win32_debug_asan_no_ort", - "name": "windows_win32_debug_asan_no_ort", "output": { "outputOnFailure": true } }, { + "name": "windows_win32_debug_no_ort", "configuration": "Debug", "configurePreset": "windows_win32_debug_no_ort", - "name": "windows_win32_debug_no_ort", "output": { "outputOnFailure": true } }, { + "name": "windows_win32_minsizerel", "configuration": "MinSizeRel", "configurePreset": "windows_win32_minsizerel", - "name": "windows_win32_minsizerel", "output": { "outputOnFailure": true } }, { + "name": "windows_win32_minsizerel_asan", "configuration": "MinSizeRel", "configurePreset": "windows_win32_minsizerel_asan", - "name": "windows_win32_minsizerel_asan", "output": { "outputOnFailure": true } }, { + "name": "windows_win32_minsizerel_asan_no_ort", "configuration": "MinSizeRel", "configurePreset": "windows_win32_minsizerel_asan_no_ort", - "name": "windows_win32_minsizerel_asan_no_ort", "output": { "outputOnFailure": true } }, { + "name": "windows_win32_minsizerel_no_ort", "configuration": "MinSizeRel", "configurePreset": "windows_win32_minsizerel_no_ort", - "name": "windows_win32_minsizerel_no_ort", "output": { "outputOnFailure": true } }, { + "name": "windows_win32_release", "configuration": "Release", "configurePreset": "windows_win32_release", - "name": "windows_win32_release", "output": { "outputOnFailure": true } }, { + "name": "windows_win32_release_asan", "configuration": "Release", "configurePreset": "windows_win32_release_asan", - "name": "windows_win32_release_asan", "output": { "outputOnFailure": true } }, { + "name": "windows_win32_release_asan_no_ort", "configuration": "Release", "configurePreset": "windows_win32_release_asan_no_ort", - "name": "windows_win32_release_asan_no_ort", "output": { "outputOnFailure": true } }, { + "name": "windows_win32_release_no_ort", "configuration": "Release", "configurePreset": 
"windows_win32_release_no_ort", - "name": "windows_win32_release_no_ort", "output": { "outputOnFailure": true } }, { + "name": "windows_win32_relwithdebinfo", "configuration": "RelWithDebInfo", "configurePreset": "windows_win32_relwithdebinfo", - "name": "windows_win32_relwithdebinfo", "output": { "outputOnFailure": true } }, { + "name": "windows_win32_relwithdebinfo_asan", "configuration": "RelWithDebInfo", "configurePreset": "windows_win32_relwithdebinfo_asan", - "name": "windows_win32_relwithdebinfo_asan", "output": { "outputOnFailure": true } }, { + "name": "windows_win32_relwithdebinfo_asan_no_ort", "configuration": "RelWithDebInfo", "configurePreset": "windows_win32_relwithdebinfo_asan_no_ort", - "name": "windows_win32_relwithdebinfo_asan_no_ort", "output": { "outputOnFailure": true } }, { + "name": "windows_win32_relwithdebinfo_no_ort", "configuration": "RelWithDebInfo", "configurePreset": "windows_win32_relwithdebinfo_no_ort", - "name": "windows_win32_relwithdebinfo_no_ort", "output": { "outputOnFailure": true } }, { + "name": "windows_x64_debug", "configuration": "Debug", "configurePreset": "windows_x64_debug", - "name": "windows_x64_debug", "output": { "outputOnFailure": true } }, { + "name": "windows_x64_debug_asan", "configuration": "Debug", "configurePreset": "windows_x64_debug_asan", - "name": "windows_x64_debug_asan", "output": { "outputOnFailure": true } }, { + "name": "windows_x64_debug_asan_no_ort", "configuration": "Debug", "configurePreset": "windows_x64_debug_asan_no_ort", - "name": "windows_x64_debug_asan_no_ort", "output": { "outputOnFailure": true } }, { + "name": "windows_x64_debug_no_ort", "configuration": "Debug", "configurePreset": "windows_x64_debug_no_ort", - "name": "windows_x64_debug_no_ort", "output": { "outputOnFailure": true } }, { + "name": "windows_x64_minsizerel", "configuration": "MinSizeRel", "configurePreset": "windows_x64_minsizerel", - "name": "windows_x64_minsizerel", "output": { "outputOnFailure": true } }, { + "name": 
"windows_x64_minsizerel_asan", "configuration": "MinSizeRel", "configurePreset": "windows_x64_minsizerel_asan", - "name": "windows_x64_minsizerel_asan", "output": { "outputOnFailure": true } }, { + "name": "windows_x64_minsizerel_asan_no_ort", "configuration": "MinSizeRel", "configurePreset": "windows_x64_minsizerel_asan_no_ort", - "name": "windows_x64_minsizerel_asan_no_ort", "output": { "outputOnFailure": true } }, { + "name": "windows_x64_minsizerel_no_ort", "configuration": "MinSizeRel", "configurePreset": "windows_x64_minsizerel_no_ort", - "name": "windows_x64_minsizerel_no_ort", "output": { "outputOnFailure": true } }, { + "name": "windows_x64_release", "configuration": "Release", "configurePreset": "windows_x64_release", - "name": "windows_x64_release", "output": { "outputOnFailure": true } }, { + "name": "windows_x64_release_asan", "configuration": "Release", "configurePreset": "windows_x64_release_asan", - "name": "windows_x64_release_asan", "output": { "outputOnFailure": true } }, { + "name": "windows_x64_release_asan_no_ort", "configuration": "Release", "configurePreset": "windows_x64_release_asan_no_ort", - "name": "windows_x64_release_asan_no_ort", "output": { "outputOnFailure": true } }, { + "name": "windows_x64_release_no_ort", "configuration": "Release", "configurePreset": "windows_x64_release_no_ort", - "name": "windows_x64_release_no_ort", "output": { "outputOnFailure": true } }, { + "name": "windows_x64_relwithdebinfo", "configuration": "RelWithDebInfo", "configurePreset": "windows_x64_relwithdebinfo", - "name": "windows_x64_relwithdebinfo", "output": { "outputOnFailure": true } }, { + "name": "windows_x64_relwithdebinfo_asan", "configuration": "RelWithDebInfo", "configurePreset": "windows_x64_relwithdebinfo_asan", - "name": "windows_x64_relwithdebinfo_asan", "output": { "outputOnFailure": true } }, { + "name": "windows_x64_relwithdebinfo_asan_no_ort", "configuration": "RelWithDebInfo", "configurePreset": "windows_x64_relwithdebinfo_asan_no_ort", 
- "name": "windows_x64_relwithdebinfo_asan_no_ort", "output": { "outputOnFailure": true } }, { + "name": "windows_x64_relwithdebinfo_no_ort", "configuration": "RelWithDebInfo", "configurePreset": "windows_x64_relwithdebinfo_no_ort", - "name": "windows_x64_relwithdebinfo_no_ort", "output": { "outputOnFailure": true } } ], - "version": 8, "workflowPresets": [ { "name": "linux_clang_debug_asan_no_ort_workflow", "steps": [ { - "name": "linux_clang_debug_asan_no_ort", - "type": "configure" + "type": "configure", + "name": "linux_clang_debug_asan_no_ort" }, { - "name": "linux_clang_debug_asan_no_ort", - "type": "build" + "type": "build", + "name": "linux_clang_debug_asan_no_ort" }, { - "name": "linux_clang_debug_asan_no_ort", - "type": "test" + "type": "test", + "name": "linux_clang_debug_asan_no_ort" } ] }, @@ -3061,16 +2540,16 @@ "name": "linux_clang_debug_asan_workflow", "steps": [ { - "name": "linux_clang_debug_asan", - "type": "configure" + "type": "configure", + "name": "linux_clang_debug_asan" }, { - "name": "linux_clang_debug_asan", - "type": "build" + "type": "build", + "name": "linux_clang_debug_asan" }, { - "name": "linux_clang_debug_asan", - "type": "test" + "type": "test", + "name": "linux_clang_debug_asan" } ] }, @@ -3078,16 +2557,16 @@ "name": "linux_clang_debug_cov_no_ort_workflow", "steps": [ { - "name": "linux_clang_debug_cov_no_ort", - "type": "configure" + "type": "configure", + "name": "linux_clang_debug_cov_no_ort" }, { - "name": "linux_clang_debug_cov_no_ort", - "type": "build" + "type": "build", + "name": "linux_clang_debug_cov_no_ort" }, { - "name": "linux_clang_debug_cov_no_ort", - "type": "test" + "type": "test", + "name": "linux_clang_debug_cov_no_ort" } ] }, @@ -3095,16 +2574,16 @@ "name": "linux_clang_debug_cov_workflow", "steps": [ { - "name": "linux_clang_debug_cov", - "type": "configure" + "type": "configure", + "name": "linux_clang_debug_cov" }, { - "name": "linux_clang_debug_cov", - "type": "build" + "type": "build", + "name": 
"linux_clang_debug_cov" }, { - "name": "linux_clang_debug_cov", - "type": "test" + "type": "test", + "name": "linux_clang_debug_cov" } ] }, @@ -3112,16 +2591,16 @@ "name": "linux_clang_debug_no_ort_workflow", "steps": [ { - "name": "linux_clang_debug_no_ort", - "type": "configure" + "type": "configure", + "name": "linux_clang_debug_no_ort" }, { - "name": "linux_clang_debug_no_ort", - "type": "build" + "type": "build", + "name": "linux_clang_debug_no_ort" }, { - "name": "linux_clang_debug_no_ort", - "type": "test" + "type": "test", + "name": "linux_clang_debug_no_ort" } ] }, @@ -3129,16 +2608,16 @@ "name": "linux_clang_debug_workflow", "steps": [ { - "name": "linux_clang_debug", - "type": "configure" + "type": "configure", + "name": "linux_clang_debug" }, { - "name": "linux_clang_debug", - "type": "build" + "type": "build", + "name": "linux_clang_debug" }, { - "name": "linux_clang_debug", - "type": "test" + "type": "test", + "name": "linux_clang_debug" } ] }, @@ -3146,16 +2625,16 @@ "name": "linux_gcc_debug_asan_no_ort_workflow", "steps": [ { - "name": "linux_gcc_debug_asan_no_ort", - "type": "configure" + "type": "configure", + "name": "linux_gcc_debug_asan_no_ort" }, { - "name": "linux_gcc_debug_asan_no_ort", - "type": "build" + "type": "build", + "name": "linux_gcc_debug_asan_no_ort" }, { - "name": "linux_gcc_debug_asan_no_ort", - "type": "test" + "type": "test", + "name": "linux_gcc_debug_asan_no_ort" } ] }, @@ -3163,16 +2642,16 @@ "name": "linux_gcc_debug_asan_workflow", "steps": [ { - "name": "linux_gcc_debug_asan", - "type": "configure" + "type": "configure", + "name": "linux_gcc_debug_asan" }, { - "name": "linux_gcc_debug_asan", - "type": "build" + "type": "build", + "name": "linux_gcc_debug_asan" }, { - "name": "linux_gcc_debug_asan", - "type": "test" + "type": "test", + "name": "linux_gcc_debug_asan" } ] }, @@ -3180,16 +2659,16 @@ "name": "linux_gcc_debug_no_ort_workflow", "steps": [ { - "name": "linux_gcc_debug_no_ort", - "type": "configure" + "type": 
"configure", + "name": "linux_gcc_debug_no_ort" }, { - "name": "linux_gcc_debug_no_ort", - "type": "build" + "type": "build", + "name": "linux_gcc_debug_no_ort" }, { - "name": "linux_gcc_debug_no_ort", - "type": "test" + "type": "test", + "name": "linux_gcc_debug_no_ort" } ] }, @@ -3197,16 +2676,16 @@ "name": "linux_gcc_debug_workflow", "steps": [ { - "name": "linux_gcc_debug", - "type": "configure" + "type": "configure", + "name": "linux_gcc_debug" }, { - "name": "linux_gcc_debug", - "type": "build" + "type": "build", + "name": "linux_gcc_debug" }, { - "name": "linux_gcc_debug", - "type": "test" + "type": "test", + "name": "linux_gcc_debug" } ] }, @@ -3214,16 +2693,16 @@ "name": "linux_gcc_minsizerel_asan_no_ort_workflow", "steps": [ { - "name": "linux_gcc_minsizerel_asan_no_ort", - "type": "configure" + "type": "configure", + "name": "linux_gcc_minsizerel_asan_no_ort" }, { - "name": "linux_gcc_minsizerel_asan_no_ort", - "type": "build" + "type": "build", + "name": "linux_gcc_minsizerel_asan_no_ort" }, { - "name": "linux_gcc_minsizerel_asan_no_ort", - "type": "test" + "type": "test", + "name": "linux_gcc_minsizerel_asan_no_ort" } ] }, @@ -3231,16 +2710,16 @@ "name": "linux_gcc_minsizerel_asan_workflow", "steps": [ { - "name": "linux_gcc_minsizerel_asan", - "type": "configure" + "type": "configure", + "name": "linux_gcc_minsizerel_asan" }, { - "name": "linux_gcc_minsizerel_asan", - "type": "build" + "type": "build", + "name": "linux_gcc_minsizerel_asan" }, { - "name": "linux_gcc_minsizerel_asan", - "type": "test" + "type": "test", + "name": "linux_gcc_minsizerel_asan" } ] }, @@ -3248,16 +2727,16 @@ "name": "linux_gcc_minsizerel_no_ort_workflow", "steps": [ { - "name": "linux_gcc_minsizerel_no_ort", - "type": "configure" + "type": "configure", + "name": "linux_gcc_minsizerel_no_ort" }, { - "name": "linux_gcc_minsizerel_no_ort", - "type": "build" + "type": "build", + "name": "linux_gcc_minsizerel_no_ort" }, { - "name": "linux_gcc_minsizerel_no_ort", - "type": "test" 
+ "type": "test", + "name": "linux_gcc_minsizerel_no_ort" } ] }, @@ -3265,16 +2744,16 @@ "name": "linux_gcc_minsizerel_workflow", "steps": [ { - "name": "linux_gcc_minsizerel", - "type": "configure" + "type": "configure", + "name": "linux_gcc_minsizerel" }, { - "name": "linux_gcc_minsizerel", - "type": "build" + "type": "build", + "name": "linux_gcc_minsizerel" }, { - "name": "linux_gcc_minsizerel", - "type": "test" + "type": "test", + "name": "linux_gcc_minsizerel" } ] }, @@ -3282,16 +2761,16 @@ "name": "linux_gcc_release_asan_no_ort_workflow", "steps": [ { - "name": "linux_gcc_release_asan_no_ort", - "type": "configure" + "type": "configure", + "name": "linux_gcc_release_asan_no_ort" }, { - "name": "linux_gcc_release_asan_no_ort", - "type": "build" + "type": "build", + "name": "linux_gcc_release_asan_no_ort" }, { - "name": "linux_gcc_release_asan_no_ort", - "type": "test" + "type": "test", + "name": "linux_gcc_release_asan_no_ort" } ] }, @@ -3299,16 +2778,16 @@ "name": "linux_gcc_release_asan_workflow", "steps": [ { - "name": "linux_gcc_release_asan", - "type": "configure" + "type": "configure", + "name": "linux_gcc_release_asan" }, { - "name": "linux_gcc_release_asan", - "type": "build" + "type": "build", + "name": "linux_gcc_release_asan" }, { - "name": "linux_gcc_release_asan", - "type": "test" + "type": "test", + "name": "linux_gcc_release_asan" } ] }, @@ -3316,16 +2795,16 @@ "name": "linux_gcc_release_no_ort_workflow", "steps": [ { - "name": "linux_gcc_release_no_ort", - "type": "configure" + "type": "configure", + "name": "linux_gcc_release_no_ort" }, { - "name": "linux_gcc_release_no_ort", - "type": "build" + "type": "build", + "name": "linux_gcc_release_no_ort" }, { - "name": "linux_gcc_release_no_ort", - "type": "test" + "type": "test", + "name": "linux_gcc_release_no_ort" } ] }, @@ -3333,16 +2812,16 @@ "name": "linux_gcc_release_workflow", "steps": [ { - "name": "linux_gcc_release", - "type": "configure" + "type": "configure", + "name": 
"linux_gcc_release" }, { - "name": "linux_gcc_release", - "type": "build" + "type": "build", + "name": "linux_gcc_release" }, { - "name": "linux_gcc_release", - "type": "test" + "type": "test", + "name": "linux_gcc_release" } ] }, @@ -3350,16 +2829,16 @@ "name": "linux_gcc_relwithdebinfo_asan_no_ort_workflow", "steps": [ { - "name": "linux_gcc_relwithdebinfo_asan_no_ort", - "type": "configure" + "type": "configure", + "name": "linux_gcc_relwithdebinfo_asan_no_ort" }, { - "name": "linux_gcc_relwithdebinfo_asan_no_ort", - "type": "build" + "type": "build", + "name": "linux_gcc_relwithdebinfo_asan_no_ort" }, { - "name": "linux_gcc_relwithdebinfo_asan_no_ort", - "type": "test" + "type": "test", + "name": "linux_gcc_relwithdebinfo_asan_no_ort" } ] }, @@ -3367,16 +2846,16 @@ "name": "linux_gcc_relwithdebinfo_asan_workflow", "steps": [ { - "name": "linux_gcc_relwithdebinfo_asan", - "type": "configure" + "type": "configure", + "name": "linux_gcc_relwithdebinfo_asan" }, { - "name": "linux_gcc_relwithdebinfo_asan", - "type": "build" + "type": "build", + "name": "linux_gcc_relwithdebinfo_asan" }, { - "name": "linux_gcc_relwithdebinfo_asan", - "type": "test" + "type": "test", + "name": "linux_gcc_relwithdebinfo_asan" } ] }, @@ -3384,16 +2863,16 @@ "name": "linux_gcc_relwithdebinfo_no_ort_workflow", "steps": [ { - "name": "linux_gcc_relwithdebinfo_no_ort", - "type": "configure" + "type": "configure", + "name": "linux_gcc_relwithdebinfo_no_ort" }, { - "name": "linux_gcc_relwithdebinfo_no_ort", - "type": "build" + "type": "build", + "name": "linux_gcc_relwithdebinfo_no_ort" }, { - "name": "linux_gcc_relwithdebinfo_no_ort", - "type": "test" + "type": "test", + "name": "linux_gcc_relwithdebinfo_no_ort" } ] }, @@ -3401,16 +2880,16 @@ "name": "linux_gcc_relwithdebinfo_workflow", "steps": [ { - "name": "linux_gcc_relwithdebinfo", - "type": "configure" + "type": "configure", + "name": "linux_gcc_relwithdebinfo" }, { - "name": "linux_gcc_relwithdebinfo", - "type": "build" + "type": 
"build", + "name": "linux_gcc_relwithdebinfo" }, { - "name": "linux_gcc_relwithdebinfo", - "type": "test" + "type": "test", + "name": "linux_gcc_relwithdebinfo" } ] }, @@ -3418,16 +2897,16 @@ "name": "macos_arm64_debug_asan_workflow", "steps": [ { - "name": "macos_arm64_debug_asan", - "type": "configure" + "type": "configure", + "name": "macos_arm64_debug_asan" }, { - "name": "macos_arm64_debug_asan", - "type": "build" + "type": "build", + "name": "macos_arm64_debug_asan" }, { - "name": "macos_arm64_debug_asan", - "type": "test" + "type": "test", + "name": "macos_arm64_debug_asan" } ] }, @@ -3435,16 +2914,16 @@ "name": "macos_arm64_debug_workflow", "steps": [ { - "name": "macos_arm64_debug", - "type": "configure" + "type": "configure", + "name": "macos_arm64_debug" }, { - "name": "macos_arm64_debug", - "type": "build" + "type": "build", + "name": "macos_arm64_debug" }, { - "name": "macos_arm64_debug", - "type": "test" + "type": "test", + "name": "macos_arm64_debug" } ] }, @@ -3452,16 +2931,16 @@ "name": "macos_arm64_minsizerel_asan_workflow", "steps": [ { - "name": "macos_arm64_minsizerel_asan", - "type": "configure" + "type": "configure", + "name": "macos_arm64_minsizerel_asan" }, { - "name": "macos_arm64_minsizerel_asan", - "type": "build" + "type": "build", + "name": "macos_arm64_minsizerel_asan" }, { - "name": "macos_arm64_minsizerel_asan", - "type": "test" + "type": "test", + "name": "macos_arm64_minsizerel_asan" } ] }, @@ -3469,16 +2948,16 @@ "name": "macos_arm64_minsizerel_workflow", "steps": [ { - "name": "macos_arm64_minsizerel", - "type": "configure" + "type": "configure", + "name": "macos_arm64_minsizerel" }, { - "name": "macos_arm64_minsizerel", - "type": "build" + "type": "build", + "name": "macos_arm64_minsizerel" }, { - "name": "macos_arm64_minsizerel", - "type": "test" + "type": "test", + "name": "macos_arm64_minsizerel" } ] }, @@ -3486,16 +2965,16 @@ "name": "macos_arm64_release_asan_workflow", "steps": [ { - "name": "macos_arm64_release_asan", - 
"type": "configure" + "type": "configure", + "name": "macos_arm64_release_asan" }, { - "name": "macos_arm64_release_asan", - "type": "build" + "type": "build", + "name": "macos_arm64_release_asan" }, { - "name": "macos_arm64_release_asan", - "type": "test" + "type": "test", + "name": "macos_arm64_release_asan" } ] }, @@ -3503,16 +2982,16 @@ "name": "macos_arm64_release_workflow", "steps": [ { - "name": "macos_arm64_release", - "type": "configure" + "type": "configure", + "name": "macos_arm64_release" }, { - "name": "macos_arm64_release", - "type": "build" + "type": "build", + "name": "macos_arm64_release" }, { - "name": "macos_arm64_release", - "type": "test" + "type": "test", + "name": "macos_arm64_release" } ] }, @@ -3520,16 +2999,16 @@ "name": "macos_arm64_relwithdebinfo_asan_workflow", "steps": [ { - "name": "macos_arm64_relwithdebinfo_asan", - "type": "configure" + "type": "configure", + "name": "macos_arm64_relwithdebinfo_asan" }, { - "name": "macos_arm64_relwithdebinfo_asan", - "type": "build" + "type": "build", + "name": "macos_arm64_relwithdebinfo_asan" }, { - "name": "macos_arm64_relwithdebinfo_asan", - "type": "test" + "type": "test", + "name": "macos_arm64_relwithdebinfo_asan" } ] }, @@ -3537,16 +3016,16 @@ "name": "macos_arm64_relwithdebinfo_workflow", "steps": [ { - "name": "macos_arm64_relwithdebinfo", - "type": "configure" + "type": "configure", + "name": "macos_arm64_relwithdebinfo" }, { - "name": "macos_arm64_relwithdebinfo", - "type": "build" + "type": "build", + "name": "macos_arm64_relwithdebinfo" }, { - "name": "macos_arm64_relwithdebinfo", - "type": "test" + "type": "test", + "name": "macos_arm64_relwithdebinfo" } ] }, @@ -3554,16 +3033,16 @@ "name": "macos_universal2_debug_asan_workflow", "steps": [ { - "name": "macos_universal2_debug_asan", - "type": "configure" + "type": "configure", + "name": "macos_universal2_debug_asan" }, { - "name": "macos_universal2_debug_asan", - "type": "build" + "type": "build", + "name": 
"macos_universal2_debug_asan" }, { - "name": "macos_universal2_debug_asan", - "type": "test" + "type": "test", + "name": "macos_universal2_debug_asan" } ] }, @@ -3571,16 +3050,16 @@ "name": "macos_universal2_debug_workflow", "steps": [ { - "name": "macos_universal2_debug", - "type": "configure" + "type": "configure", + "name": "macos_universal2_debug" }, { - "name": "macos_universal2_debug", - "type": "build" + "type": "build", + "name": "macos_universal2_debug" }, { - "name": "macos_universal2_debug", - "type": "test" + "type": "test", + "name": "macos_universal2_debug" } ] }, @@ -3588,16 +3067,16 @@ "name": "macos_universal2_minsizerel_asan_workflow", "steps": [ { - "name": "macos_universal2_minsizerel_asan", - "type": "configure" + "type": "configure", + "name": "macos_universal2_minsizerel_asan" }, { - "name": "macos_universal2_minsizerel_asan", - "type": "build" + "type": "build", + "name": "macos_universal2_minsizerel_asan" }, { - "name": "macos_universal2_minsizerel_asan", - "type": "test" + "type": "test", + "name": "macos_universal2_minsizerel_asan" } ] }, @@ -3605,16 +3084,16 @@ "name": "macos_universal2_minsizerel_workflow", "steps": [ { - "name": "macos_universal2_minsizerel", - "type": "configure" + "type": "configure", + "name": "macos_universal2_minsizerel" }, { - "name": "macos_universal2_minsizerel", - "type": "build" + "type": "build", + "name": "macos_universal2_minsizerel" }, { - "name": "macos_universal2_minsizerel", - "type": "test" + "type": "test", + "name": "macos_universal2_minsizerel" } ] }, @@ -3622,16 +3101,16 @@ "name": "macos_universal2_release_asan_workflow", "steps": [ { - "name": "macos_universal2_release_asan", - "type": "configure" + "type": "configure", + "name": "macos_universal2_release_asan" }, { - "name": "macos_universal2_release_asan", - "type": "build" + "type": "build", + "name": "macos_universal2_release_asan" }, { - "name": "macos_universal2_release_asan", - "type": "test" + "type": "test", + "name": 
"macos_universal2_release_asan" } ] }, @@ -3639,16 +3118,16 @@ "name": "macos_universal2_release_workflow", "steps": [ { - "name": "macos_universal2_release", - "type": "configure" + "type": "configure", + "name": "macos_universal2_release" }, { - "name": "macos_universal2_release", - "type": "build" + "type": "build", + "name": "macos_universal2_release" }, { - "name": "macos_universal2_release", - "type": "test" + "type": "test", + "name": "macos_universal2_release" } ] }, @@ -3656,16 +3135,16 @@ "name": "macos_universal2_relwithdebinfo_asan_workflow", "steps": [ { - "name": "macos_universal2_relwithdebinfo_asan", - "type": "configure" + "type": "configure", + "name": "macos_universal2_relwithdebinfo_asan" }, { - "name": "macos_universal2_relwithdebinfo_asan", - "type": "build" + "type": "build", + "name": "macos_universal2_relwithdebinfo_asan" }, { - "name": "macos_universal2_relwithdebinfo_asan", - "type": "test" + "type": "test", + "name": "macos_universal2_relwithdebinfo_asan" } ] }, @@ -3673,16 +3152,16 @@ "name": "macos_universal2_relwithdebinfo_workflow", "steps": [ { - "name": "macos_universal2_relwithdebinfo", - "type": "configure" + "type": "configure", + "name": "macos_universal2_relwithdebinfo" }, { - "name": "macos_universal2_relwithdebinfo", - "type": "build" + "type": "build", + "name": "macos_universal2_relwithdebinfo" }, { - "name": "macos_universal2_relwithdebinfo", - "type": "test" + "type": "test", + "name": "macos_universal2_relwithdebinfo" } ] }, @@ -3690,16 +3169,16 @@ "name": "macos_x86_64_debug_asan_workflow", "steps": [ { - "name": "macos_x86_64_debug_asan", - "type": "configure" + "type": "configure", + "name": "macos_x86_64_debug_asan" }, { - "name": "macos_x86_64_debug_asan", - "type": "build" + "type": "build", + "name": "macos_x86_64_debug_asan" }, { - "name": "macos_x86_64_debug_asan", - "type": "test" + "type": "test", + "name": "macos_x86_64_debug_asan" } ] }, @@ -3707,16 +3186,16 @@ "name": "macos_x86_64_debug_workflow", 
"steps": [ { - "name": "macos_x86_64_debug", - "type": "configure" + "type": "configure", + "name": "macos_x86_64_debug" }, { - "name": "macos_x86_64_debug", - "type": "build" + "type": "build", + "name": "macos_x86_64_debug" }, { - "name": "macos_x86_64_debug", - "type": "test" + "type": "test", + "name": "macos_x86_64_debug" } ] }, @@ -3724,16 +3203,16 @@ "name": "macos_x86_64_minsizerel_asan_workflow", "steps": [ { - "name": "macos_x86_64_minsizerel_asan", - "type": "configure" + "type": "configure", + "name": "macos_x86_64_minsizerel_asan" }, { - "name": "macos_x86_64_minsizerel_asan", - "type": "build" + "type": "build", + "name": "macos_x86_64_minsizerel_asan" }, { - "name": "macos_x86_64_minsizerel_asan", - "type": "test" + "type": "test", + "name": "macos_x86_64_minsizerel_asan" } ] }, @@ -3741,16 +3220,16 @@ "name": "macos_x86_64_minsizerel_workflow", "steps": [ { - "name": "macos_x86_64_minsizerel", - "type": "configure" + "type": "configure", + "name": "macos_x86_64_minsizerel" }, { - "name": "macos_x86_64_minsizerel", - "type": "build" + "type": "build", + "name": "macos_x86_64_minsizerel" }, { - "name": "macos_x86_64_minsizerel", - "type": "test" + "type": "test", + "name": "macos_x86_64_minsizerel" } ] }, @@ -3758,16 +3237,16 @@ "name": "macos_x86_64_release_asan_workflow", "steps": [ { - "name": "macos_x86_64_release_asan", - "type": "configure" + "type": "configure", + "name": "macos_x86_64_release_asan" }, { - "name": "macos_x86_64_release_asan", - "type": "build" + "type": "build", + "name": "macos_x86_64_release_asan" }, { - "name": "macos_x86_64_release_asan", - "type": "test" + "type": "test", + "name": "macos_x86_64_release_asan" } ] }, @@ -3775,16 +3254,16 @@ "name": "macos_x86_64_release_workflow", "steps": [ { - "name": "macos_x86_64_release", - "type": "configure" + "type": "configure", + "name": "macos_x86_64_release" }, { - "name": "macos_x86_64_release", - "type": "build" + "type": "build", + "name": "macos_x86_64_release" }, { - 
"name": "macos_x86_64_release", - "type": "test" + "type": "test", + "name": "macos_x86_64_release" } ] }, @@ -3792,16 +3271,16 @@ "name": "macos_x86_64_relwithdebinfo_asan_workflow", "steps": [ { - "name": "macos_x86_64_relwithdebinfo_asan", - "type": "configure" + "type": "configure", + "name": "macos_x86_64_relwithdebinfo_asan" }, { - "name": "macos_x86_64_relwithdebinfo_asan", - "type": "build" + "type": "build", + "name": "macos_x86_64_relwithdebinfo_asan" }, { - "name": "macos_x86_64_relwithdebinfo_asan", - "type": "test" + "type": "test", + "name": "macos_x86_64_relwithdebinfo_asan" } ] }, @@ -3809,288 +3288,16 @@ "name": "macos_x86_64_relwithdebinfo_workflow", "steps": [ { - "name": "macos_x86_64_relwithdebinfo", - "type": "configure" - }, - { - "name": "macos_x86_64_relwithdebinfo", - "type": "build" - }, - { - "name": "macos_x86_64_relwithdebinfo", - "type": "test" - } - ] - }, - { - "name": "windows_arm64_debug_asan_no_ort_workflow", - "steps": [ - { - "name": "windows_arm64_debug_asan_no_ort", - "type": "configure" - }, - { - "name": "windows_arm64_debug_asan_no_ort", - "type": "build" - }, - { - "name": "windows_arm64_debug_asan_no_ort", - "type": "test" - } - ] - }, - { - "name": "windows_arm64_debug_asan_workflow", - "steps": [ - { - "name": "windows_arm64_debug_asan", - "type": "configure" - }, - { - "name": "windows_arm64_debug_asan", - "type": "build" - }, - { - "name": "windows_arm64_debug_asan", - "type": "test" - } - ] - }, - { - "name": "windows_arm64_debug_no_ort_workflow", - "steps": [ - { - "name": "windows_arm64_debug_no_ort", - "type": "configure" - }, - { - "name": "windows_arm64_debug_no_ort", - "type": "build" - }, - { - "name": "windows_arm64_debug_no_ort", - "type": "test" - } - ] - }, - { - "name": "windows_arm64_debug_workflow", - "steps": [ - { - "name": "windows_arm64_debug", - "type": "configure" - }, - { - "name": "windows_arm64_debug", - "type": "build" - }, - { - "name": "windows_arm64_debug", - "type": "test" - } - ] - }, - 
{ - "name": "windows_arm64_minsizerel_asan_no_ort_workflow", - "steps": [ - { - "name": "windows_arm64_minsizerel_asan_no_ort", - "type": "configure" - }, - { - "name": "windows_arm64_minsizerel_asan_no_ort", - "type": "build" - }, - { - "name": "windows_arm64_minsizerel_asan_no_ort", - "type": "test" - } - ] - }, - { - "name": "windows_arm64_minsizerel_asan_workflow", - "steps": [ - { - "name": "windows_arm64_minsizerel_asan", - "type": "configure" - }, - { - "name": "windows_arm64_minsizerel_asan", - "type": "build" - }, - { - "name": "windows_arm64_minsizerel_asan", - "type": "test" - } - ] - }, - { - "name": "windows_arm64_minsizerel_no_ort_workflow", - "steps": [ - { - "name": "windows_arm64_minsizerel_no_ort", - "type": "configure" - }, - { - "name": "windows_arm64_minsizerel_no_ort", - "type": "build" - }, - { - "name": "windows_arm64_minsizerel_no_ort", - "type": "test" - } - ] - }, - { - "name": "windows_arm64_minsizerel_workflow", - "steps": [ - { - "name": "windows_arm64_minsizerel", - "type": "configure" - }, - { - "name": "windows_arm64_minsizerel", - "type": "build" - }, - { - "name": "windows_arm64_minsizerel", - "type": "test" - } - ] - }, - { - "name": "windows_arm64_release_asan_no_ort_workflow", - "steps": [ - { - "name": "windows_arm64_release_asan_no_ort", - "type": "configure" - }, - { - "name": "windows_arm64_release_asan_no_ort", - "type": "build" - }, - { - "name": "windows_arm64_release_asan_no_ort", - "type": "test" - } - ] - }, - { - "name": "windows_arm64_release_asan_workflow", - "steps": [ - { - "name": "windows_arm64_release_asan", - "type": "configure" - }, - { - "name": "windows_arm64_release_asan", - "type": "build" - }, - { - "name": "windows_arm64_release_asan", - "type": "test" - } - ] - }, - { - "name": "windows_arm64_release_no_ort_workflow", - "steps": [ - { - "name": "windows_arm64_release_no_ort", - "type": "configure" - }, - { - "name": "windows_arm64_release_no_ort", - "type": "build" - }, - { - "name": 
"windows_arm64_release_no_ort", - "type": "test" - } - ] - }, - { - "name": "windows_arm64_release_workflow", - "steps": [ - { - "name": "windows_arm64_release", - "type": "configure" - }, - { - "name": "windows_arm64_release", - "type": "build" - }, - { - "name": "windows_arm64_release", - "type": "test" - } - ] - }, - { - "name": "windows_arm64_relwithdebinfo_asan_no_ort_workflow", - "steps": [ - { - "name": "windows_arm64_relwithdebinfo_asan_no_ort", - "type": "configure" - }, - { - "name": "windows_arm64_relwithdebinfo_asan_no_ort", - "type": "build" - }, - { - "name": "windows_arm64_relwithdebinfo_asan_no_ort", - "type": "test" - } - ] - }, - { - "name": "windows_arm64_relwithdebinfo_asan_workflow", - "steps": [ - { - "name": "windows_arm64_relwithdebinfo_asan", - "type": "configure" - }, - { - "name": "windows_arm64_relwithdebinfo_asan", - "type": "build" - }, - { - "name": "windows_arm64_relwithdebinfo_asan", - "type": "test" - } - ] - }, - { - "name": "windows_arm64_relwithdebinfo_no_ort_workflow", - "steps": [ - { - "name": "windows_arm64_relwithdebinfo_no_ort", - "type": "configure" - }, - { - "name": "windows_arm64_relwithdebinfo_no_ort", - "type": "build" - }, - { - "name": "windows_arm64_relwithdebinfo_no_ort", - "type": "test" - } - ] - }, - { - "name": "windows_arm64_relwithdebinfo_workflow", - "steps": [ - { - "name": "windows_arm64_relwithdebinfo", - "type": "configure" + "type": "configure", + "name": "macos_x86_64_relwithdebinfo" }, { - "name": "windows_arm64_relwithdebinfo", - "type": "build" + "type": "build", + "name": "macos_x86_64_relwithdebinfo" }, { - "name": "windows_arm64_relwithdebinfo", - "type": "test" + "type": "test", + "name": "macos_x86_64_relwithdebinfo" } ] }, @@ -4098,16 +3305,16 @@ "name": "windows_win32_debug_asan_no_ort_workflow", "steps": [ { - "name": "windows_win32_debug_asan_no_ort", - "type": "configure" + "type": "configure", + "name": "windows_win32_debug_asan_no_ort" }, { - "name": "windows_win32_debug_asan_no_ort", 
- "type": "build" + "type": "build", + "name": "windows_win32_debug_asan_no_ort" }, { - "name": "windows_win32_debug_asan_no_ort", - "type": "test" + "type": "test", + "name": "windows_win32_debug_asan_no_ort" } ] }, @@ -4115,16 +3322,16 @@ "name": "windows_win32_debug_asan_workflow", "steps": [ { - "name": "windows_win32_debug_asan", - "type": "configure" + "type": "configure", + "name": "windows_win32_debug_asan" }, { - "name": "windows_win32_debug_asan", - "type": "build" + "type": "build", + "name": "windows_win32_debug_asan" }, { - "name": "windows_win32_debug_asan", - "type": "test" + "type": "test", + "name": "windows_win32_debug_asan" } ] }, @@ -4132,16 +3339,16 @@ "name": "windows_win32_debug_no_ort_workflow", "steps": [ { - "name": "windows_win32_debug_no_ort", - "type": "configure" + "type": "configure", + "name": "windows_win32_debug_no_ort" }, { - "name": "windows_win32_debug_no_ort", - "type": "build" + "type": "build", + "name": "windows_win32_debug_no_ort" }, { - "name": "windows_win32_debug_no_ort", - "type": "test" + "type": "test", + "name": "windows_win32_debug_no_ort" } ] }, @@ -4149,16 +3356,16 @@ "name": "windows_win32_debug_workflow", "steps": [ { - "name": "windows_win32_debug", - "type": "configure" + "type": "configure", + "name": "windows_win32_debug" }, { - "name": "windows_win32_debug", - "type": "build" + "type": "build", + "name": "windows_win32_debug" }, { - "name": "windows_win32_debug", - "type": "test" + "type": "test", + "name": "windows_win32_debug" } ] }, @@ -4166,16 +3373,16 @@ "name": "windows_win32_minsizerel_asan_no_ort_workflow", "steps": [ { - "name": "windows_win32_minsizerel_asan_no_ort", - "type": "configure" + "type": "configure", + "name": "windows_win32_minsizerel_asan_no_ort" }, { - "name": "windows_win32_minsizerel_asan_no_ort", - "type": "build" + "type": "build", + "name": "windows_win32_minsizerel_asan_no_ort" }, { - "name": "windows_win32_minsizerel_asan_no_ort", - "type": "test" + "type": "test", + "name": 
"windows_win32_minsizerel_asan_no_ort" } ] }, @@ -4183,16 +3390,16 @@ "name": "windows_win32_minsizerel_asan_workflow", "steps": [ { - "name": "windows_win32_minsizerel_asan", - "type": "configure" + "type": "configure", + "name": "windows_win32_minsizerel_asan" }, { - "name": "windows_win32_minsizerel_asan", - "type": "build" + "type": "build", + "name": "windows_win32_minsizerel_asan" }, { - "name": "windows_win32_minsizerel_asan", - "type": "test" + "type": "test", + "name": "windows_win32_minsizerel_asan" } ] }, @@ -4200,16 +3407,16 @@ "name": "windows_win32_minsizerel_no_ort_workflow", "steps": [ { - "name": "windows_win32_minsizerel_no_ort", - "type": "configure" + "type": "configure", + "name": "windows_win32_minsizerel_no_ort" }, { - "name": "windows_win32_minsizerel_no_ort", - "type": "build" + "type": "build", + "name": "windows_win32_minsizerel_no_ort" }, { - "name": "windows_win32_minsizerel_no_ort", - "type": "test" + "type": "test", + "name": "windows_win32_minsizerel_no_ort" } ] }, @@ -4217,16 +3424,16 @@ "name": "windows_win32_minsizerel_workflow", "steps": [ { - "name": "windows_win32_minsizerel", - "type": "configure" + "type": "configure", + "name": "windows_win32_minsizerel" }, { - "name": "windows_win32_minsizerel", - "type": "build" + "type": "build", + "name": "windows_win32_minsizerel" }, { - "name": "windows_win32_minsizerel", - "type": "test" + "type": "test", + "name": "windows_win32_minsizerel" } ] }, @@ -4234,16 +3441,16 @@ "name": "windows_win32_release_asan_no_ort_workflow", "steps": [ { - "name": "windows_win32_release_asan_no_ort", - "type": "configure" + "type": "configure", + "name": "windows_win32_release_asan_no_ort" }, { - "name": "windows_win32_release_asan_no_ort", - "type": "build" + "type": "build", + "name": "windows_win32_release_asan_no_ort" }, { - "name": "windows_win32_release_asan_no_ort", - "type": "test" + "type": "test", + "name": "windows_win32_release_asan_no_ort" } ] }, @@ -4251,16 +3458,16 @@ "name": 
"windows_win32_release_asan_workflow", "steps": [ { - "name": "windows_win32_release_asan", - "type": "configure" + "type": "configure", + "name": "windows_win32_release_asan" }, { - "name": "windows_win32_release_asan", - "type": "build" + "type": "build", + "name": "windows_win32_release_asan" }, { - "name": "windows_win32_release_asan", - "type": "test" + "type": "test", + "name": "windows_win32_release_asan" } ] }, @@ -4268,16 +3475,16 @@ "name": "windows_win32_release_no_ort_workflow", "steps": [ { - "name": "windows_win32_release_no_ort", - "type": "configure" + "type": "configure", + "name": "windows_win32_release_no_ort" }, { - "name": "windows_win32_release_no_ort", - "type": "build" + "type": "build", + "name": "windows_win32_release_no_ort" }, { - "name": "windows_win32_release_no_ort", - "type": "test" + "type": "test", + "name": "windows_win32_release_no_ort" } ] }, @@ -4285,16 +3492,16 @@ "name": "windows_win32_release_workflow", "steps": [ { - "name": "windows_win32_release", - "type": "configure" + "type": "configure", + "name": "windows_win32_release" }, { - "name": "windows_win32_release", - "type": "build" + "type": "build", + "name": "windows_win32_release" }, { - "name": "windows_win32_release", - "type": "test" + "type": "test", + "name": "windows_win32_release" } ] }, @@ -4302,16 +3509,16 @@ "name": "windows_win32_relwithdebinfo_asan_no_ort_workflow", "steps": [ { - "name": "windows_win32_relwithdebinfo_asan_no_ort", - "type": "configure" + "type": "configure", + "name": "windows_win32_relwithdebinfo_asan_no_ort" }, { - "name": "windows_win32_relwithdebinfo_asan_no_ort", - "type": "build" + "type": "build", + "name": "windows_win32_relwithdebinfo_asan_no_ort" }, { - "name": "windows_win32_relwithdebinfo_asan_no_ort", - "type": "test" + "type": "test", + "name": "windows_win32_relwithdebinfo_asan_no_ort" } ] }, @@ -4319,16 +3526,16 @@ "name": "windows_win32_relwithdebinfo_asan_workflow", "steps": [ { - "name": 
"windows_win32_relwithdebinfo_asan", - "type": "configure" + "type": "configure", + "name": "windows_win32_relwithdebinfo_asan" }, { - "name": "windows_win32_relwithdebinfo_asan", - "type": "build" + "type": "build", + "name": "windows_win32_relwithdebinfo_asan" }, { - "name": "windows_win32_relwithdebinfo_asan", - "type": "test" + "type": "test", + "name": "windows_win32_relwithdebinfo_asan" } ] }, @@ -4336,16 +3543,16 @@ "name": "windows_win32_relwithdebinfo_no_ort_workflow", "steps": [ { - "name": "windows_win32_relwithdebinfo_no_ort", - "type": "configure" + "type": "configure", + "name": "windows_win32_relwithdebinfo_no_ort" }, { - "name": "windows_win32_relwithdebinfo_no_ort", - "type": "build" + "type": "build", + "name": "windows_win32_relwithdebinfo_no_ort" }, { - "name": "windows_win32_relwithdebinfo_no_ort", - "type": "test" + "type": "test", + "name": "windows_win32_relwithdebinfo_no_ort" } ] }, @@ -4353,16 +3560,16 @@ "name": "windows_win32_relwithdebinfo_workflow", "steps": [ { - "name": "windows_win32_relwithdebinfo", - "type": "configure" + "type": "configure", + "name": "windows_win32_relwithdebinfo" }, { - "name": "windows_win32_relwithdebinfo", - "type": "build" + "type": "build", + "name": "windows_win32_relwithdebinfo" }, { - "name": "windows_win32_relwithdebinfo", - "type": "test" + "type": "test", + "name": "windows_win32_relwithdebinfo" } ] }, @@ -4370,16 +3577,16 @@ "name": "windows_x64_debug_asan_no_ort_workflow", "steps": [ { - "name": "windows_x64_debug_asan_no_ort", - "type": "configure" + "type": "configure", + "name": "windows_x64_debug_asan_no_ort" }, { - "name": "windows_x64_debug_asan_no_ort", - "type": "build" + "type": "build", + "name": "windows_x64_debug_asan_no_ort" }, { - "name": "windows_x64_debug_asan_no_ort", - "type": "test" + "type": "test", + "name": "windows_x64_debug_asan_no_ort" } ] }, @@ -4387,16 +3594,16 @@ "name": "windows_x64_debug_asan_workflow", "steps": [ { - "name": "windows_x64_debug_asan", - "type": 
"configure" + "type": "configure", + "name": "windows_x64_debug_asan" }, { - "name": "windows_x64_debug_asan", - "type": "build" + "type": "build", + "name": "windows_x64_debug_asan" }, { - "name": "windows_x64_debug_asan", - "type": "test" + "type": "test", + "name": "windows_x64_debug_asan" } ] }, @@ -4404,16 +3611,16 @@ "name": "windows_x64_debug_no_ort_workflow", "steps": [ { - "name": "windows_x64_debug_no_ort", - "type": "configure" + "type": "configure", + "name": "windows_x64_debug_no_ort" }, { - "name": "windows_x64_debug_no_ort", - "type": "build" + "type": "build", + "name": "windows_x64_debug_no_ort" }, { - "name": "windows_x64_debug_no_ort", - "type": "test" + "type": "test", + "name": "windows_x64_debug_no_ort" } ] }, @@ -4421,16 +3628,16 @@ "name": "windows_x64_debug_workflow", "steps": [ { - "name": "windows_x64_debug", - "type": "configure" + "type": "configure", + "name": "windows_x64_debug" }, { - "name": "windows_x64_debug", - "type": "build" + "type": "build", + "name": "windows_x64_debug" }, { - "name": "windows_x64_debug", - "type": "test" + "type": "test", + "name": "windows_x64_debug" } ] }, @@ -4438,16 +3645,16 @@ "name": "windows_x64_minsizerel_asan_no_ort_workflow", "steps": [ { - "name": "windows_x64_minsizerel_asan_no_ort", - "type": "configure" + "type": "configure", + "name": "windows_x64_minsizerel_asan_no_ort" }, { - "name": "windows_x64_minsizerel_asan_no_ort", - "type": "build" + "type": "build", + "name": "windows_x64_minsizerel_asan_no_ort" }, { - "name": "windows_x64_minsizerel_asan_no_ort", - "type": "test" + "type": "test", + "name": "windows_x64_minsizerel_asan_no_ort" } ] }, @@ -4455,16 +3662,16 @@ "name": "windows_x64_minsizerel_asan_workflow", "steps": [ { - "name": "windows_x64_minsizerel_asan", - "type": "configure" + "type": "configure", + "name": "windows_x64_minsizerel_asan" }, { - "name": "windows_x64_minsizerel_asan", - "type": "build" + "type": "build", + "name": "windows_x64_minsizerel_asan" }, { - "name": 
"windows_x64_minsizerel_asan", - "type": "test" + "type": "test", + "name": "windows_x64_minsizerel_asan" } ] }, @@ -4472,16 +3679,16 @@ "name": "windows_x64_minsizerel_no_ort_workflow", "steps": [ { - "name": "windows_x64_minsizerel_no_ort", - "type": "configure" + "type": "configure", + "name": "windows_x64_minsizerel_no_ort" }, { - "name": "windows_x64_minsizerel_no_ort", - "type": "build" + "type": "build", + "name": "windows_x64_minsizerel_no_ort" }, { - "name": "windows_x64_minsizerel_no_ort", - "type": "test" + "type": "test", + "name": "windows_x64_minsizerel_no_ort" } ] }, @@ -4489,16 +3696,16 @@ "name": "windows_x64_minsizerel_workflow", "steps": [ { - "name": "windows_x64_minsizerel", - "type": "configure" + "type": "configure", + "name": "windows_x64_minsizerel" }, { - "name": "windows_x64_minsizerel", - "type": "build" + "type": "build", + "name": "windows_x64_minsizerel" }, { - "name": "windows_x64_minsizerel", - "type": "test" + "type": "test", + "name": "windows_x64_minsizerel" } ] }, @@ -4506,16 +3713,16 @@ "name": "windows_x64_release_asan_no_ort_workflow", "steps": [ { - "name": "windows_x64_release_asan_no_ort", - "type": "configure" + "type": "configure", + "name": "windows_x64_release_asan_no_ort" }, { - "name": "windows_x64_release_asan_no_ort", - "type": "build" + "type": "build", + "name": "windows_x64_release_asan_no_ort" }, { - "name": "windows_x64_release_asan_no_ort", - "type": "test" + "type": "test", + "name": "windows_x64_release_asan_no_ort" } ] }, @@ -4523,16 +3730,16 @@ "name": "windows_x64_release_asan_workflow", "steps": [ { - "name": "windows_x64_release_asan", - "type": "configure" + "type": "configure", + "name": "windows_x64_release_asan" }, { - "name": "windows_x64_release_asan", - "type": "build" + "type": "build", + "name": "windows_x64_release_asan" }, { - "name": "windows_x64_release_asan", - "type": "test" + "type": "test", + "name": "windows_x64_release_asan" } ] }, @@ -4540,16 +3747,16 @@ "name": 
"windows_x64_release_no_ort_workflow", "steps": [ { - "name": "windows_x64_release_no_ort", - "type": "configure" + "type": "configure", + "name": "windows_x64_release_no_ort" }, { - "name": "windows_x64_release_no_ort", - "type": "build" + "type": "build", + "name": "windows_x64_release_no_ort" }, { - "name": "windows_x64_release_no_ort", - "type": "test" + "type": "test", + "name": "windows_x64_release_no_ort" } ] }, @@ -4557,16 +3764,16 @@ "name": "windows_x64_release_workflow", "steps": [ { - "name": "windows_x64_release", - "type": "configure" + "type": "configure", + "name": "windows_x64_release" }, { - "name": "windows_x64_release", - "type": "build" + "type": "build", + "name": "windows_x64_release" }, { - "name": "windows_x64_release", - "type": "test" + "type": "test", + "name": "windows_x64_release" } ] }, @@ -4574,16 +3781,16 @@ "name": "windows_x64_relwithdebinfo_asan_no_ort_workflow", "steps": [ { - "name": "windows_x64_relwithdebinfo_asan_no_ort", - "type": "configure" + "type": "configure", + "name": "windows_x64_relwithdebinfo_asan_no_ort" }, { - "name": "windows_x64_relwithdebinfo_asan_no_ort", - "type": "build" + "type": "build", + "name": "windows_x64_relwithdebinfo_asan_no_ort" }, { - "name": "windows_x64_relwithdebinfo_asan_no_ort", - "type": "test" + "type": "test", + "name": "windows_x64_relwithdebinfo_asan_no_ort" } ] }, @@ -4591,16 +3798,16 @@ "name": "windows_x64_relwithdebinfo_asan_workflow", "steps": [ { - "name": "windows_x64_relwithdebinfo_asan", - "type": "configure" + "type": "configure", + "name": "windows_x64_relwithdebinfo_asan" }, { - "name": "windows_x64_relwithdebinfo_asan", - "type": "build" + "type": "build", + "name": "windows_x64_relwithdebinfo_asan" }, { - "name": "windows_x64_relwithdebinfo_asan", - "type": "test" + "type": "test", + "name": "windows_x64_relwithdebinfo_asan" } ] }, @@ -4608,16 +3815,16 @@ "name": "windows_x64_relwithdebinfo_no_ort_workflow", "steps": [ { - "name": "windows_x64_relwithdebinfo_no_ort", - 
"type": "configure" + "type": "configure", + "name": "windows_x64_relwithdebinfo_no_ort" }, { - "name": "windows_x64_relwithdebinfo_no_ort", - "type": "build" + "type": "build", + "name": "windows_x64_relwithdebinfo_no_ort" }, { - "name": "windows_x64_relwithdebinfo_no_ort", - "type": "test" + "type": "test", + "name": "windows_x64_relwithdebinfo_no_ort" } ] }, @@ -4625,16 +3832,16 @@ "name": "windows_x64_relwithdebinfo_workflow", "steps": [ { - "name": "windows_x64_relwithdebinfo", - "type": "configure" + "type": "configure", + "name": "windows_x64_relwithdebinfo" }, { - "name": "windows_x64_relwithdebinfo", - "type": "build" + "type": "build", + "name": "windows_x64_relwithdebinfo" }, { - "name": "windows_x64_relwithdebinfo", - "type": "test" + "type": "test", + "name": "windows_x64_relwithdebinfo" } ] } From c51224695534fb279a5a1dd20a262e50a00cda40 Mon Sep 17 00:00:00 2001 From: Changming Sun Date: Fri, 17 Oct 2025 21:29:43 -0700 Subject: [PATCH 33/33] revert --- .../platform/EigenNonBlockingThreadPool.h | 119 +----------------- 1 file changed, 1 insertion(+), 118 deletions(-) diff --git a/src/ort_include/core/platform/EigenNonBlockingThreadPool.h b/src/ort_include/core/platform/EigenNonBlockingThreadPool.h index 45b1751..c313944 100644 --- a/src/ort_include/core/platform/EigenNonBlockingThreadPool.h +++ b/src/ort_include/core/platform/EigenNonBlockingThreadPool.h @@ -199,100 +199,6 @@ struct PaddingToAvoidFalseSharing { char padding[ORT_FALSE_SHARING_BYTES]; }; -/* Usage: -1. In executor, call Start() before profiling and Stop() to get profiled numbers; -2. Inside thread pool, call LogStart() before interested section and LogEnd... after to log elapsed time; -3. To extend, just add more events in enum Event before "All", and update GetEventName(...) accordingly; -4. Note LogStart must pair with either LogEnd or LogEndAndStart, otherwise ORT_ENFORCE will fail; -5. ThreadPoolProfiler is thread-safe. 
-*/ -#ifdef ORT_MINIMAL_BUILD -class ThreadPoolProfiler { - public: - enum ThreadPoolEvent { - DISTRIBUTION = 0, - DISTRIBUTION_ENQUEUE, - RUN, - WAIT, - WAIT_REVOKE, - MAX_EVENT - }; - ThreadPoolProfiler(int, const CHAR_TYPE*) {} - ~ThreadPoolProfiler() = default; - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ThreadPoolProfiler); - void Start() {} - std::string Stop() { return "not available for minimal build"; } - void LogStart() {} - void LogEnd(ThreadPoolEvent) {} - void LogEndAndStart(ThreadPoolEvent) {} - void LogStartAndCoreAndBlock(std::ptrdiff_t) {} - void LogCoreAndBlock(std::ptrdiff_t) {} - void LogThreadId(int) {} - void LogRun(int) {} - std::string DumpChildThreadStat() { return {}; } -}; -#else -class ThreadPoolProfiler { - public: - enum ThreadPoolEvent { - DISTRIBUTION = 0, - DISTRIBUTION_ENQUEUE, - RUN, - WAIT, - WAIT_REVOKE, - MAX_EVENT - }; - ThreadPoolProfiler(int num_threads, const CHAR_TYPE* threal_pool_name); - ~ThreadPoolProfiler(); - ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ThreadPoolProfiler); - using Clock = std::chrono::high_resolution_clock; - void Start(); // called by executor to start profiling - std::string Stop(); // called by executor to stop profiling and return collected numbers - void LogStart(); // called in main thread to record the starting time point - void LogEnd(ThreadPoolEvent); // called in main thread to calculate and save the time elapsed from last start point - void LogEndAndStart(ThreadPoolEvent); - void LogStartAndCoreAndBlock(std::ptrdiff_t block_size); - void LogCoreAndBlock(std::ptrdiff_t block_size); // called in main thread to log core and block size for task breakdown - void LogThreadId(int thread_idx); // called in child thread to log its id - void LogRun(int thread_idx); // called in child thread to log num of run - std::string DumpChildThreadStat(); // return all child statistics collected so far - - private: - static const char* GetEventName(ThreadPoolEvent); - struct MainThreadStat { - uint64_t events_[MAX_EVENT] 
= {}; - int32_t core_ = -1; - std::vector blocks_; // block size determined by cost model - std::vector points_; - void LogCore(); - void LogBlockSize(std::ptrdiff_t block_size); - void LogStart(); - void LogEnd(ThreadPoolEvent); - void LogEndAndStart(ThreadPoolEvent); - std::string Reset(); - }; - bool enabled_ = false; - MainThreadStat& GetMainThreadStat(); // return thread local stat - int num_threads_; -#ifdef _MSC_VER -#pragma warning(push) - // C4324: structure was padded due to alignment specifier -#pragma warning(disable : 4324) -#endif // _MSC_VER - struct ORT_ALIGN_TO_AVOID_FALSE_SHARING ChildThreadStat { - std::thread::id thread_id_; - uint64_t num_run_ = 0; - onnxruntime::TimePoint last_logged_point_ = Clock::now(); - int32_t core_ = -1; // core that the child thread is running on - }; -#ifdef _MSC_VER -#pragma warning(pop) -#endif // _MSC_VER - std::vector child_thread_stats_; - std::string thread_pool_name_; -}; -#endif - // Extended Eigen thread pool interface, avoiding the need to modify // the ThreadPoolInterface.h header from the external Eigen // repository. @@ -335,8 +241,6 @@ class ExtendedThreadPoolInterface : public Eigen::ThreadPoolInterface { // two loops execute in series in a parallel section. 
] virtual void RunInParallel(std::function fn, unsigned n, std::ptrdiff_t block_size) = 0; - virtual void StartProfiling() = 0; - virtual std::string StopProfiling() = 0; }; class ThreadPoolParallelSection { @@ -705,7 +609,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter return 0; } - ThreadPoolProfiler profiler_; void SignalAllAndWait() { done_ = true; @@ -720,13 +623,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter } public: - void StartProfiling() override { - profiler_.Start(); - } - std::string StopProfiling() override { - return profiler_.Stop(); - } struct Tag { constexpr Tag() : v_(0) { @@ -767,7 +664,7 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter ThreadPoolTempl(const CHAR_TYPE* name, int num_threads, bool allow_spinning, Environment& env, const ThreadOptions& thread_options) - : profiler_(num_threads, name), + : env_(env), num_threads_(num_threads), allow_spinning_(allow_spinning), @@ -915,7 +812,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter // tasks that were created (if any) for the parallel section. We // revoke tasks still in queues, and then wait for any that are // still running. - profiler_.LogStart(); unsigned tasks_started = static_cast(ps.tasks.size()); while (!ps.tasks.empty()) { const auto& item = ps.tasks.back(); @@ -925,7 +821,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter } ps.tasks.pop_back(); } - profiler_.LogEnd(ThreadPoolProfiler::WAIT_REVOKE); // Wait for the dispatch task's own work... 
if (ps.dispatch_q_idx > -1) { @@ -1204,7 +1099,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter ps.work_done.store(true, std::memory_order_release); }; - profiler_.LogStart(); ps.dispatch_q_idx = preferred_workers[current_dop] % num_threads_; WorkerData& dispatch_td = worker_data_[ps.dispatch_q_idx]; Queue& dispatch_que = dispatch_td.queue; @@ -1222,7 +1116,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter } else { ps.dispatch_q_idx = -1; // failed to enqueue dispatch_task } - profiler_.LogEnd(ThreadPoolProfiler::DISTRIBUTION_ENQUEUE); } else { // Synchronous dispatch ScheduleOnPreferredWorkers(pt, ps, preferred_workers, current_dop, new_dop, std::move(worker_fn)); @@ -1240,7 +1133,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter unsigned n, std::ptrdiff_t block_size) override { ORT_ENFORCE(n <= num_threads_ + 1, "More work items than threads"); - profiler_.LogStartAndCoreAndBlock(block_size); PerThread* pt = GetPerThread(); assert(pt->leading_par_section && "RunInParallel, but not in parallel section"); assert((n > 1) && "Trivial parallel section; should be avoided by caller"); @@ -1270,18 +1162,15 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter }; RunInParallelInternal(*pt, ps, n, false, std::move(worker_fn)); assert(ps.dispatch_q_idx == -1); - profiler_.LogEndAndStart(ThreadPoolProfiler::DISTRIBUTION); // Run work in the main thread loop.fn(0); - profiler_.LogEndAndStart(ThreadPoolProfiler::RUN); // Wait for workers to exit the loop ps.current_loop = 0; while (ps.workers_in_loop) { onnxruntime::concurrency::SpinPause(); } - profiler_.LogEnd(ThreadPoolProfiler::WAIT); } // Run a single parallel loop _without_ a parallel section. This is a @@ -1298,16 +1187,12 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter // 1. 
run fn(...); void RunInParallel(std::function fn, unsigned n, std::ptrdiff_t block_size) override { ORT_ENFORCE(n <= num_threads_ + 1, "More work items than threads"); - profiler_.LogStartAndCoreAndBlock(block_size); PerThread* pt = GetPerThread(); ThreadPoolParallelSection ps; StartParallelSectionInternal(*pt, ps); RunInParallelInternal(*pt, ps, n, true, fn); // select dispatcher and do job distribution; - profiler_.LogEndAndStart(ThreadPoolProfiler::DISTRIBUTION); fn(0); // run fn(0) - profiler_.LogEndAndStart(ThreadPoolProfiler::RUN); EndParallelSectionInternal(*pt, ps); // wait for all - profiler_.LogEnd(ThreadPoolProfiler::WAIT); } int NumThreads() const final { @@ -1539,7 +1424,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter const int steal_count = spin_count / 100; SetDenormalAsZero(set_denormal_as_zero_); - profiler_.LogThreadId(thread_id); while (!should_exit) { Task t = q.PopFront(); @@ -1632,7 +1516,6 @@ class ThreadPoolTempl : public onnxruntime::concurrency::ExtendedThreadPoolInter if (t) { td.SetActive(); t(); - profiler_.LogRun(thread_id); td.SetSpinning(); } }