From da5731977605e31de2039e2057147e940c6706da Mon Sep 17 00:00:00 2001 From: Kali Uday Balleda Date: Thu, 16 Feb 2023 23:20:46 +0530 Subject: [PATCH 01/14] segfault occurs when tests run --- deps/src/onemkl.cpp | 45 +++++++++++++++++++++++++++- deps/src/onemkl.h | 7 +++++ lib/mkl/wrappers.jl | 50 ++++++++++++++++++++++++++++++++ lib/support/liboneapi_support.jl | 12 ++++++++ test/onemkl.jl | 43 +++++++++++++++++++++++++-- 5 files changed, 153 insertions(+), 4 deletions(-) diff --git a/deps/src/onemkl.cpp b/deps/src/onemkl.cpp index f11ce26e..8d726639 100644 --- a/deps/src/onemkl.cpp +++ b/deps/src/onemkl.cpp @@ -1,6 +1,6 @@ #include "onemkl.h" #include "sycl.hpp" - +#include #include // This is a workaround to flush MKL submissions into Level-zero queue, using @@ -119,6 +119,49 @@ extern "C" int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA, return 0; } +extern "C" void onemklSgemmBatched(syclQueue_t device_queue, onemklTranspose transa, + onemklTranspose transb, int64_t m, + int64_t n, int64_t k, float alpha, + const float **a, int64_t lda, const float **b, + int64_t ldb, float beta, float **c, + int64_t ldc, int64_t group_count) { + std::cout << "Group Count " << group_count << std::endl; + std::vector transa_vec(group_count); + std::vector transb_vec(group_count); + std::vector m_vec(group_count); + std::vector n_vec(group_count); + std::vector k_vec(group_count); + std::vector alpha_vec(group_count); + std::vector beta_vec(group_count); + std::vector lda_vec(group_count); + std::vector ldb_vec(group_count); + std::vector ldc_vec(group_count); + std::vector group_size(group_count); + auto t_a = convert(transa); + auto t_b = convert(transb); + for (int i = 0; i < group_count; i++) { + transa_vec[i] = t_a; + transb_vec[i] = t_b; + m_vec[i] = m; + n_vec[i] = n; + k_vec[i] = k; + alpha_vec[i] = alpha; + beta_vec[i] = beta; + lda_vec[i] = lda; + ldb_vec[i] = ldb; + ldc_vec[i] = ldc; + group_size[i] = m * n * k; + } + auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, &transa_vec[0], + &transb_vec[0], &m_vec[0], &n_vec[0], &k_vec[0], + &alpha_vec[0], (const float **)&a[0], + &lda_vec[0], (const float **)&b[0], + &ldb_vec[0], &beta_vec[0], + &c[0], &ldc_vec[0], group_count, &group_size[0]); + __FORCE_MKL_FLUSH__(status); + std::cout << "Done with gemm_batch" << std::endl; +} + extern "C" void onemklSsymm(syclQueue_t device_queue, onemklSide left_right, onemklUplo upper_lower, int64_t m, int64_t n, float alpha, const float *a, int64_t lda, const float *b, diff --git a/deps/src/onemkl.h b/deps/src/onemkl.h index 476369fc..d8dd2cdd 100644 --- a/deps/src/onemkl.h +++ b/deps/src/onemkl.h @@ -54,6 +54,13 @@ int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA, const double _Complex *B, int64_t ldb, double _Complex beta, double _Complex *C, int64_t ldc); +void onemklSgemmBatched(syclQueue_t device_queue, onemklTranspose transa, + onemklTranspose transb, int64_t m, + int64_t n, int64_t k, float alpha, + const float **a, int64_t lda, const float **b, + int64_t ldb, float beta, float **c, + int64_t ldc, int64_t group_count); + void onemklSsymm(syclQueue_t device_queue, onemklSide left_right, onemklUplo upper_lower, int64_t m, int64_t n, float alpha, const float *a, int64_t lda, const float *b, diff --git a/lib/mkl/wrappers.jl b/lib/mkl/wrappers.jl index 7cab47e0..4fca7f61 100644 --- a/lib/mkl/wrappers.jl +++ b/lib/mkl/wrappers.jl @@ -44,6 +44,56 @@ function Base.convert(::Type{onemklDiag}, diag::Char) end end +# create a batch of pointers in device memory from a batch of device arrays +@inline function unsafe_batch(batch::Vector{<:oneArray{T}}) where {T} + ptrs = pointer.(batch) + return oneArray(ptrs) +end + +## (GE) general matrix-matrix multiplication batched +for (fname, elty) in + ((:onemklDgemmBatched,:Float64), + (:onemklSgemmBatched,:Float32)) + @eval begin + function gemm_batched!(transA::Char, + transB::Char, + alpha::Number, + A::Vector{<:oneStridedMatrix{$elty}}, + B::Vector{<:oneStridedMatrix{$elty}}, + beta::Number, + C::Vector{<:oneStridedMatrix{$elty}}) + if length(A) != length(B) || length(A) != length(C) + throw(DimensionMismatch("")) + end + for (As,Bs,Cs) in zip(A,B,C) + m = size(As, transA == 'N' ? 1 : 2) + k = size(As, transA == 'N' ? 2 : 1) + n = size(Bs, transB == 'N' ? 2 : 1) + if m != size(Cs,1) || n != size(Cs,2) || k != size(Bs, transB == 'N' ? 1 : 2) + throw(DimensionMismatch("")) + end + end + + m = size(A[1], transA == 'N' ? 1 : 2) + k = size(A[1], transA == 'N' ? 2 : 1) + n = size(B[1], transB == 'N' ? 2 : 1) + lda = max(1,stride(A[1],2)) + ldb = max(1,stride(B[1],2)) + ldc = max(1,stride(C[1],2)) + Aptrs = unsafe_batch(A) + Bptrs = unsafe_batch(B) + Cptrs = unsafe_batch(C) + queue = global_queue(context(A[1]), device(A[1])) + $fname(sycl_queue(queue), transA, transB, m, n, k, alpha, Aptrs, lda, Bptrs, + ldb, beta, Cptrs, ldc, length(A)) + #unsafe_free!(Cptrs) + #unsafe_free!(Bptrs) + #unsafe_free!(Aptrs) + C + end + end +end + ## (L3: symm) symmetric matrix-matrix and matrix-vector multiplication for (fname, elty) in ((:onemklSsymm, :Float32), (:onemklDsymm, :Float64), diff --git a/lib/support/liboneapi_support.jl b/lib/support/liboneapi_support.jl index b358f5c3..0dfea776 100644 --- a/lib/support/liboneapi_support.jl +++ b/lib/support/liboneapi_support.jl @@ -131,6 +131,18 @@ function onemklZgemm(device_queue, transA, transB, m, n, k, alpha, A, lda, B, ld ldc::Int64)::Cint end +function onemklSgemmBatched(device_queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, group_count) + @ccall liboneapi_support.onemklSgemmBatched(device_queue::syclQueue_t, + transa::onemklTranspose, + transb::onemklTranspose, m::Int64, n::Int64, + k::Int64, alpha::Cfloat, + a::ZePtr{Ptr{Cfloat}}, lda::Int64, + b::ZePtr{Ptr{Cfloat}}, ldb::Int64, + beta::Cfloat, c::ZePtr{Ptr{Cfloat}}, + ldc::Int64, group_count::Int64)::Cvoid +end + function onemklSsymm(device_queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, ldc) @ccall liboneapi_support.onemklSsymm(device_queue::syclQueue_t, left_right::onemklSide, diff --git a/test/onemkl.jl b/test/onemkl.jl index 1c515c60..d17ab703 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -3,10 +3,11 @@ using oneAPI.oneMKL: band, bandex using LinearAlgebra -m = 20 -n = 35 -k = 13 +m = 2 +n = 2 +k = 2 +#= ############################################################################################ @testset "level 1" begin @testset for T in intersect(eltypes, [Float32, Float64, ComplexF32, ComplexF64]) @@ -808,4 +809,40 @@ end end end end +end +=# + +@testset for T in [Float32] + alpha = rand(T) + beta = rand(T) + group_count = 1 + # generate matrices + bA = [rand(T,m,k) for i in 1:group_count] + bB = [rand(T,k,n) for i in 1:group_count] + bC = [rand(T,m,n) for i in 1:group_count] + # move to device + bd_A = oneArray{T, 2}[] + bd_B = oneArray{T, 2}[] + bd_C = oneArray{T, 2}[] + bd_bad = oneArray{T, 2}[] + for i in 1:length(bA) + push!(bd_A,oneArray(bA[i])) + push!(bd_B,oneArray(bB[i])) + push!(bd_C,oneArray(bC[i])) + if i < length(bA) - 2 + push!(bd_bad,oneArray(bC[i])) + end + end + + @testset "gemm_batched!" begin + # C = (alpha*A)*B + beta*C + oneMKL.gemm_batched!('N','N',alpha,bd_A,bd_B,beta,bd_C) + for i in 1:length(bd_C) + bC[i] = (alpha*bA[i])*bB[i] + beta*bC[i] + h_C = Array(bd_C[i]) + #compare + @test bC[i] ≈ h_C + end + @test_throws DimensionMismatch oneMKL.gemm_batched!('N','N',alpha,bd_A,bd_bad,beta,bd_C) + end end \ No newline at end of file From 88d2b2159397bbf9feb83cb134b869e74bcb0196 Mon Sep 17 00:00:00 2001 From: kballeda Date: Fri, 24 Feb 2023 16:13:37 +0530 Subject: [PATCH 02/14] use malloc_shared to create buffers --- deps/src/onemkl.cpp | 47 +++++++++++++++++++++++---------------------- test/onemkl.jl | 2 +- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/deps/src/onemkl.cpp b/deps/src/onemkl.cpp index 8d726639..8ea7fd3c 100644 --- a/deps/src/onemkl.cpp +++ b/deps/src/onemkl.cpp @@ -125,33 +125,33 @@ extern "C" void onemklSgemmBatched(syclQueue_t device_queue, onemklTranspose tra const float **a, int64_t lda, const float **b, int64_t ldb, float beta, float **c, int64_t ldc, int64_t group_count) { - std::cout << "Group Count " << group_count << std::endl; - std::vector transa_vec(group_count); - std::vector transb_vec(group_count); - std::vector m_vec(group_count); - std::vector n_vec(group_count); - std::vector k_vec(group_count); - std::vector alpha_vec(group_count); - std::vector beta_vec(group_count); - std::vector lda_vec(group_count); - std::vector ldb_vec(group_count); - std::vector ldc_vec(group_count); - std::vector group_size(group_count); + auto main_queue = device_queue->val; + auto device = main_queue.get_device(); + auto context = main_queue.get_context(); + + auto m_dev = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); + auto n_dev = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); + auto k_dev = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); + auto lda_dev = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); + auto ldb_dev = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); + auto ldc_dev = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); + auto alpha_dev = (float *) malloc_shared(group_count * sizeof(float), device, context); + auto beta_dev = (float *) malloc_shared(group_count * sizeof(float), device, context); + auto transa_dev = (oneapi::mkl::transpose *) malloc_shared(group_count * sizeof(oneapi::mkl::transpose), device, context); auto t_a = convert(transa); auto t_b = convert(transb); for (int i = 0; i < group_count; i++) { - transa_vec[i] = t_a; - transb_vec[i] = t_b; - m_vec[i] = m; - n_vec[i] = n; - k_vec[i] = k; - alpha_vec[i] = alpha; - beta_vec[i] = beta; - lda_vec[i] = lda; - ldb_vec[i] = ldb; - ldc_vec[i] = ldc; - group_size[i] = m * n * k; + m_dev[i] = m; + n_dev[i] = n; + k_dev[i] = k; + lda_dev[i] = lda; + ldb_dev[i] = ldb; + ldc_dev[i] = ldc; + alpha_dev[i] = alpha; + beta_dev[i] = beta; + transa_dev[i] = t_a; } +#if 0 auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, &transa_vec[0], &transb_vec[0], &m_vec[0], &n_vec[0], &k_vec[0], &alpha_vec[0], (const float **)&a[0], @@ -159,6 +159,7 @@ extern "C" void onemklSgemmBatched(syclQueue_t device_queue, onemklTranspose tra &ldb_vec[0], &beta_vec[0], &c[0], &ldc_vec[0], group_count, &group_size[0]); __FORCE_MKL_FLUSH__(status); +#endif std::cout << "Done with gemm_batch" << std::endl; } diff --git a/test/onemkl.jl b/test/onemkl.jl index d17ab703..df07879b 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -815,7 +815,7 @@ end @testset for T in [Float32] alpha = rand(T) beta = rand(T) - group_count = 1 + group_count = 10 # generate matrices bA = [rand(T,m,k) for i in 1:group_count] bB = [rand(T,k,n) for i in 1:group_count] From 2d46cacf07b7fcd73261b7b6d76c59f238139148 Mon Sep 17 00:00:00 2001 From: kballeda Date: Fri, 24 Feb 2023 16:38:35 +0530 Subject: [PATCH 03/14] single matrix gemm_batch passes --- deps/src/onemkl.cpp | 22 +++++++++++++--------- test/onemkl.jl | 8 ++++---- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/deps/src/onemkl.cpp b/deps/src/onemkl.cpp index 8ea7fd3c..d19d144d 100644 --- a/deps/src/onemkl.cpp +++ b/deps/src/onemkl.cpp @@ -138,8 +138,10 @@ extern "C" void onemklSgemmBatched(syclQueue_t device_queue, onemklTranspose tra auto alpha_dev = (float *) malloc_shared(group_count * sizeof(float), device, context); auto beta_dev = (float *) malloc_shared(group_count * sizeof(float), device, context); auto transa_dev = (oneapi::mkl::transpose *) malloc_shared(group_count * sizeof(oneapi::mkl::transpose), device, context); + auto transb_dev = (oneapi::mkl::transpose *) malloc_shared(group_count * sizeof(oneapi::mkl::transpose), device, context); auto t_a = convert(transa); auto t_b = convert(transb); + auto group_size = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); for (int i = 0; i < group_count; i++) { m_dev[i] = m; n_dev[i] = n; @@ -150,16 +152,18 @@ extern "C" void onemklSgemmBatched(syclQueue_t device_queue, onemklTranspose tra alpha_dev[i] = alpha; beta_dev[i] = beta; transa_dev[i] = t_a; + transb_dev[i] = t_b; + group_size[i] = group_count; } -#if 0 - auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, &transa_vec[0], - &transb_vec[0], &m_vec[0], &n_vec[0], &k_vec[0], - &alpha_vec[0], (const float **)&a[0], - &lda_vec[0], (const float **)&b[0], - &ldb_vec[0], &beta_vec[0], - &c[0], &ldc_vec[0], group_count, &group_size[0]); - __FORCE_MKL_FLUSH__(status); -#endif + + auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, &transa_dev[0], + &transb_dev[0], &m_dev[0], &n_dev[0], &k_dev[0], + &alpha_dev[0], (const float **)&a[0], + &lda_dev[0], (const float **)&b[0], + &ldb_dev[0], &beta_dev[0], + &c[0], &ldc_dev[0], group_count, &group_size[0]); + __FORCE_MKL_FLUSH__(status); + std::cout << "Done with gemm_batch" << std::endl; } diff --git a/test/onemkl.jl b/test/onemkl.jl index df07879b..3b58b7e8 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -3,9 +3,9 @@ using oneAPI.oneMKL: band, bandex using LinearAlgebra -m = 2 -n = 2 -k = 2 +m = 20 +n = 35 +k = 13 #= ############################################################################################ @@ -815,7 +815,7 @@ end @testset for T in [Float32] alpha = rand(T) beta = rand(T) - group_count = 10 + group_count = 1 # generate matrices bA = [rand(T,m,k) for i in 1:group_count] bB = [rand(T,k,n) for i in 1:group_count] From c3f40e22139aec5dcfc10a98872cbd50786c080d Mon Sep 17 00:00:00 2001 From: kballeda Date: Sun, 26 Feb 2023 19:24:36 +0530 Subject: [PATCH 04/14] multiple matrices go through gemm_batch --- deps/src/onemkl.cpp | 6 +++--- lib/support/liboneapi_support.jl | 6 +++--- test/onemkl.jl | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/deps/src/onemkl.cpp b/deps/src/onemkl.cpp index d19d144d..00111007 100644 --- a/deps/src/onemkl.cpp +++ b/deps/src/onemkl.cpp @@ -1,6 +1,7 @@ #include "onemkl.h" #include "sycl.hpp" #include +#include #include // This is a workaround to flush MKL submissions into Level-zero queue, using @@ -153,9 +154,9 @@ extern "C" void onemklSgemmBatched(syclQueue_t device_queue, onemklTranspose tra beta_dev[i] = beta; transa_dev[i] = t_a; transb_dev[i] = t_b; - group_size[i] = group_count; + group_size[i] = 1; } - + auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, &transa_dev[0], &transb_dev[0], &m_dev[0], &n_dev[0], &k_dev[0], &alpha_dev[0], (const float **)&a[0], @@ -164,7 +165,6 @@ extern "C" void onemklSgemmBatched(syclQueue_t device_queue, onemklTranspose tra &c[0], &ldc_dev[0], group_count, &group_size[0]); __FORCE_MKL_FLUSH__(status); - std::cout << "Done with gemm_batch" << std::endl; } extern "C" void onemklSsymm(syclQueue_t device_queue, onemklSide left_right, diff --git a/lib/support/liboneapi_support.jl b/lib/support/liboneapi_support.jl index 0dfea776..4d677a1f 100644 --- a/lib/support/liboneapi_support.jl +++ b/lib/support/liboneapi_support.jl @@ -137,9 +137,9 @@ function onemklSgemmBatched(device_queue, transa, transb, m, n, k, alpha, a, lda transa::onemklTranspose, transb::onemklTranspose, m::Int64, n::Int64, k::Int64, alpha::Cfloat, - a::ZePtr{Ptr{Cfloat}}, lda::Int64, - b::ZePtr{Ptr{Cfloat}}, ldb::Int64, - beta::Cfloat, c::ZePtr{Ptr{Cfloat}}, + a::ZePtr{ZePtr{Cfloat}}, lda::Int64, + b::ZePtr{ZePtr{Cfloat}}, ldb::Int64, + beta::Cfloat, c::ZePtr{ZePtr{Cfloat}}, ldc::Int64, group_count::Int64)::Cvoid end diff --git a/test/onemkl.jl b/test/onemkl.jl index 3b58b7e8..1a3e70f1 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -815,7 +815,7 @@ end @testset for T in [Float32] alpha = rand(T) beta = rand(T) - group_count = 1 + group_count = 10 # generate matrices bA = [rand(T,m,k) for i in 1:group_count] bB = [rand(T,k,n) for i in 1:group_count] From 224d87ea2c1bb4e8f705e452cbe10b87a2ad9843 Mon Sep 17 00:00:00 2001 From: kballeda Date: Mon, 27 Feb 2023 11:00:36 +0530 Subject: [PATCH 05/14] support gemm_batched and increase test case count --- lib/mkl/wrappers.jl | 10 ++++++++++ test/onemkl.jl | 12 +++++++++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/lib/mkl/wrappers.jl b/lib/mkl/wrappers.jl index 4fca7f61..45bf67f5 100644 --- a/lib/mkl/wrappers.jl +++ b/lib/mkl/wrappers.jl @@ -94,6 +94,16 @@ for (fname, elty) in end end +function gemm_batched(transA::Char, transB::Char, alpha::Number, + A::Vector{<:oneStridedMatrix{T}}, B::Vector{<:oneStridedMatrix{T}}) where T + C = oneMatrix{T}[similar(B[1], (size(A[1], transA == 'N' ? 1 : 2),size(B[1], transB == 'N' ? 2 : 1))) for i in 1:length(A)] + gemm_batched!(transA, transB, alpha, A, B, zero(T), C ) +end +function gemm_batched(transA::Char, transB::Char, + A::Vector{<:oneStridedMatrix{T}}, B::Vector{<:oneStridedMatrix{T}}) where T + gemm_batched(transA, transB, one(T), A, B) +end + ## (L3: symm) symmetric matrix-matrix and matrix-vector multiplication for (fname, elty) in ((:onemklSsymm, :Float32), (:onemklDsymm, :Float64), diff --git a/test/onemkl.jl b/test/onemkl.jl index 1a3e70f1..9fdb6401 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -815,7 +815,7 @@ end @testset for T in [Float32] alpha = rand(T) beta = rand(T) - group_count = 10 + group_count = 20 # generate matrices bA = [rand(T,m,k) for i in 1:group_count] bB = [rand(T,k,n) for i in 1:group_count] @@ -845,4 +845,14 @@ end end @test_throws DimensionMismatch oneMKL.gemm_batched!('N','N',alpha,bd_A,bd_bad,beta,bd_C) end + + @testset "gemm_batched" begin + bd_C = oneMKL.gemm_batched('N','N',bd_A,bd_B) + for i in 1:length(bA) + bC = bA[i]*bB[i] + h_C = Array(bd_C[i]) + @test bC ≈ h_C + end + @test_throws DimensionMismatch oneMKL.gemm_batched('N','N',alpha,bd_A,bd_bad) + end end \ No newline at end of file From bcbf9c85ca75f0c671bb9e4f1134491567d1557b Mon Sep 17 00:00:00 2001 From: kballeda Date: Mon, 27 Feb 2023 11:15:08 +0530 Subject: [PATCH 06/14] unsafe_free enabled --- lib/mkl/oneMKL.jl | 2 +- lib/mkl/wrappers.jl | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/mkl/oneMKL.jl b/lib/mkl/oneMKL.jl index fd57e3ea..735dce03 100644 --- a/lib/mkl/oneMKL.jl +++ b/lib/mkl/oneMKL.jl @@ -1,7 +1,7 @@ module oneMKL using ..oneAPI - +using ..oneAPI: unsafe_free! using ..oneL0 using ..Support diff --git a/lib/mkl/wrappers.jl b/lib/mkl/wrappers.jl index 45bf67f5..6d7269a5 100644 --- a/lib/mkl/wrappers.jl +++ b/lib/mkl/wrappers.jl @@ -86,9 +86,9 @@ for (fname, elty) in queue = global_queue(context(A[1]), device(A[1])) $fname(sycl_queue(queue), transA, transB, m, n, k, alpha, Aptrs, lda, Bptrs, ldb, beta, Cptrs, ldc, length(A)) - #unsafe_free!(Cptrs) - #unsafe_free!(Bptrs) - #unsafe_free!(Aptrs) + unsafe_free!(Cptrs) + unsafe_free!(Bptrs) + unsafe_free!(Aptrs) C end end From 0e6cd7d8b3cc90bc64893e0736c8c395c695347a Mon Sep 17 00:00:00 2001 From: kballeda Date: Mon, 27 Feb 2023 13:13:20 +0530 Subject: [PATCH 07/14] cleanup mem init code into class and enable all tests --- deps/src/onemkl.cpp | 106 ++++++++++++++++++++++++++++---------------- test/onemkl.jl | 80 ++++++++++++++++----------------- 2 files changed, 109 insertions(+), 77 deletions(-) diff --git a/deps/src/onemkl.cpp b/deps/src/onemkl.cpp index 00111007..80c220b9 100644 --- a/deps/src/onemkl.cpp +++ b/deps/src/onemkl.cpp @@ -52,6 +52,62 @@ oneapi::mkl::side convert(onemklSide val) { } } +template +class gemmBatchInfo { + public: + void memInit(syclQueue_t device_queue, + int64_t group_count, + onemklTranspose transa, + onemklTranspose transb, + int64_t m, int64_t n, int64_t k, + int64_t lda, int64_t ldb, int64_t ldc, + T alpha, T beta + ) { + auto main_queue = device_queue->val; + auto device = main_queue.get_device(); + auto context = main_queue.get_context(); + m_mbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); + m_nbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); + m_kbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); + m_ldabuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); + m_ldbbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); + m_ldcbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); + m_alphabuf = (T *) malloc_shared(group_count * sizeof(float), device, context); + m_betabuf = (T *) malloc_shared(group_count * sizeof(float), device, context); + m_transa = (oneapi::mkl::transpose *) malloc_shared(group_count * sizeof(oneapi::mkl::transpose), device, context); + m_transb = (oneapi::mkl::transpose *) malloc_shared(group_count * sizeof(oneapi::mkl::transpose), device, context); + auto t_a = convert(transa); + auto t_b = convert(transb); + m_group_size = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); + + for (int i = 0; i < group_count; i++) { + m_mbuf[i] = m; + m_nbuf[i] = n; + m_kbuf[i] = k; + m_ldabuf[i] = lda; + m_ldbbuf[i] = ldb; + m_ldcbuf[i] = ldc; + m_alphabuf[i] = alpha; + m_betabuf[i] = beta; + m_transa[i] = t_a; + m_transb[i] = t_b; + m_group_size[i] = 1; + } + }; + + int64_t *m_mbuf = nullptr; + int64_t *m_nbuf = nullptr; + int64_t *m_kbuf = nullptr; + int64_t *m_ldabuf = nullptr; + int64_t *m_ldbbuf = nullptr; + int64_t *m_ldcbuf = nullptr; + oneapi::mkl::transpose *m_transa = nullptr; + oneapi::mkl::transpose *m_transb = nullptr; + T *m_alphabuf = nullptr; + T *m_betabuf = nullptr; + int64_t *m_group_size = nullptr; +}; + extern "C" int onemklHgemm(syclQueue_t device_queue, onemklTranspose transA, onemklTranspose transB, int64_t m, int64_t n, int64_t k, sycl::half alpha, const sycl::half *A, int64_t lda, @@ -126,43 +182,19 @@ extern "C" void onemklSgemmBatched(syclQueue_t device_queue, onemklTranspose tra const float **a, int64_t lda, const float **b, int64_t ldb, float beta, float **c, int64_t ldc, int64_t group_count) { - auto main_queue = device_queue->val; - auto device = main_queue.get_device(); - auto context = main_queue.get_context(); - - auto m_dev = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); - auto n_dev = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); - auto k_dev = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); - auto lda_dev = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); - auto ldb_dev = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); - auto ldc_dev = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); - auto alpha_dev = (float *) malloc_shared(group_count * sizeof(float), device, context); - auto beta_dev = (float *) malloc_shared(group_count * sizeof(float), device, context); - auto transa_dev = (oneapi::mkl::transpose *) malloc_shared(group_count * sizeof(oneapi::mkl::transpose), device, context); - auto transb_dev = (oneapi::mkl::transpose *) malloc_shared(group_count * sizeof(oneapi::mkl::transpose), device, context); - auto t_a = convert(transa); - auto t_b = convert(transb); - auto group_size = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); - for (int i = 0; i < group_count; i++) { - m_dev[i] = m; - n_dev[i] = n; - k_dev[i] = k; - lda_dev[i] = lda; - ldb_dev[i] = ldb; - ldc_dev[i] = ldc; - alpha_dev[i] = alpha; - beta_dev[i] = beta; - transa_dev[i] = t_a; - transb_dev[i] = t_b; - group_size[i] = 1; - } - - auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, &transa_dev[0], - &transb_dev[0], &m_dev[0], &n_dev[0], &k_dev[0], - &alpha_dev[0], (const float **)&a[0], - &lda_dev[0], (const float **)&b[0], - &ldb_dev[0], &beta_dev[0], - &c[0], &ldc_dev[0], group_count, &group_size[0]); + gemmBatchInfo gemmInfo; + gemmInfo.memInit(device_queue, group_count, transa, transb, + m, n, k, lda, ldb, ldc, alpha, beta); + + auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, + &gemmInfo.m_transa[0], &gemmInfo.m_transb[0], + &gemmInfo.m_mbuf[0], &gemmInfo.m_nbuf[0], + &gemmInfo.m_kbuf[0], &gemmInfo.m_alphabuf[0], + (const float **)&a[0], &gemmInfo.m_ldabuf[0], + (const float **)&b[0], &gemmInfo.m_ldbbuf[0], + &gemmInfo.m_betabuf[0], &c[0], &gemmInfo.m_ldcbuf[0], + group_count, &gemmInfo.m_group_size[0]); + __FORCE_MKL_FLUSH__(status); } diff --git a/test/onemkl.jl b/test/onemkl.jl index 9fdb6401..f801eca0 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -7,7 +7,6 @@ m = 20 n = 35 k = 13 -#= ############################################################################################ @testset "level 1" begin @testset for T in intersect(eltypes, [Float32, Float64, ComplexF32, ComplexF64]) @@ -810,49 +809,50 @@ end end end end -=# - -@testset for T in [Float32] - alpha = rand(T) - beta = rand(T) - group_count = 20 - # generate matrices - bA = [rand(T,m,k) for i in 1:group_count] - bB = [rand(T,k,n) for i in 1:group_count] - bC = [rand(T,m,n) for i in 1:group_count] - # move to device - bd_A = oneArray{T, 2}[] - bd_B = oneArray{T, 2}[] - bd_C = oneArray{T, 2}[] - bd_bad = oneArray{T, 2}[] - for i in 1:length(bA) - push!(bd_A,oneArray(bA[i])) - push!(bd_B,oneArray(bB[i])) - push!(bd_C,oneArray(bC[i])) - if i < length(bA) - 2 - push!(bd_bad,oneArray(bC[i])) + +@testset "Batch Primitives" begin + @testset for T in [Float32] + alpha = rand(T) + beta = rand(T) + group_count = 20 + # generate matrices + bA = [rand(T,m,k) for i in 1:group_count] + bB = [rand(T,k,n) for i in 1:group_count] + bC = [rand(T,m,n) for i in 1:group_count] + # move to device + bd_A = oneArray{T, 2}[] + bd_B = oneArray{T, 2}[] + bd_C = oneArray{T, 2}[] + bd_bad = oneArray{T, 2}[] + for i in 1:length(bA) + push!(bd_A,oneArray(bA[i])) + push!(bd_B,oneArray(bB[i])) + push!(bd_C,oneArray(bC[i])) + if i < length(bA) - 2 + push!(bd_bad,oneArray(bC[i])) + end end - end - @testset "gemm_batched!" begin - # C = (alpha*A)*B + beta*C - oneMKL.gemm_batched!('N','N',alpha,bd_A,bd_B,beta,bd_C) - for i in 1:length(bd_C) - bC[i] = (alpha*bA[i])*bB[i] + beta*bC[i] - h_C = Array(bd_C[i]) - #compare - @test bC[i] ≈ h_C + @testset "gemm_batched!" begin + # C = (alpha*A)*B + beta*C + oneMKL.gemm_batched!('N','N',alpha,bd_A,bd_B,beta,bd_C) + for i in 1:length(bd_C) + bC[i] = (alpha*bA[i])*bB[i] + beta*bC[i] + h_C = Array(bd_C[i]) + #compare + @test bC[i] ≈ h_C + end + @test_throws DimensionMismatch oneMKL.gemm_batched!('N','N',alpha,bd_A,bd_bad,beta,bd_C) end - @test_throws DimensionMismatch oneMKL.gemm_batched!('N','N',alpha,bd_A,bd_bad,beta,bd_C) - end - @testset "gemm_batched" begin - bd_C = oneMKL.gemm_batched('N','N',bd_A,bd_B) - for i in 1:length(bA) - bC = bA[i]*bB[i] - h_C = Array(bd_C[i]) - @test bC ≈ h_C + @testset "gemm_batched" begin + bd_C = oneMKL.gemm_batched('N','N',bd_A,bd_B) + for i in 1:length(bA) + bC = bA[i]*bB[i] + h_C = Array(bd_C[i]) + @test bC ≈ h_C + end + @test_throws DimensionMismatch oneMKL.gemm_batched('N','N',alpha,bd_A,bd_bad) end - @test_throws DimensionMismatch oneMKL.gemm_batched('N','N',alpha,bd_A,bd_bad) end end \ No newline at end of file From 7f2975964ccbc0e62fc8b1bb5fb342fabd2382be Mon Sep 17 00:00:00 2001 From: kballeda Date: Mon, 27 Feb 2023 13:38:56 +0530 Subject: [PATCH 08/14] cleanup into const & dest --- deps/src/onemkl.cpp | 86 +++++++++++++++++++++++++++------------------ 1 file changed, 52 insertions(+), 34 deletions(-) diff --git a/deps/src/onemkl.cpp b/deps/src/onemkl.cpp index 80c220b9..2354d94a 100644 --- a/deps/src/onemkl.cpp +++ b/deps/src/onemkl.cpp @@ -55,30 +55,46 @@ oneapi::mkl::side convert(onemklSide val) { template class gemmBatchInfo { public: - void memInit(syclQueue_t device_queue, - int64_t group_count, - onemklTranspose transa, - onemklTranspose transb, - int64_t m, int64_t n, int64_t k, - int64_t lda, int64_t ldb, int64_t ldc, - T alpha, T beta - ) { + int64_t *m_mbuf = nullptr; + int64_t *m_nbuf = nullptr; + int64_t *m_kbuf = nullptr; + int64_t *m_ldabuf = nullptr; + int64_t *m_ldbbuf = nullptr; + int64_t *m_ldcbuf = nullptr; + oneapi::mkl::transpose *m_transa = nullptr; + oneapi::mkl::transpose *m_transb = nullptr; + T *m_alphabuf = nullptr; + T *m_betabuf = nullptr; + int64_t *m_group_size = nullptr; + sycl::device m_device; + sycl::context m_context; + + // Constructor + gemmBatchInfo(syclQueue_t device_queue, + int64_t group_count, + onemklTranspose transa, + onemklTranspose transb, + int64_t m, int64_t n, int64_t k, + int64_t lda, int64_t ldb, int64_t ldc, + T alpha, T beta) { auto main_queue = device_queue->val; - auto device = main_queue.get_device(); - auto context = main_queue.get_context(); - m_mbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); - m_nbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); - m_kbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); - m_ldabuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); - m_ldbbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); - m_ldcbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); - m_alphabuf = (T *) malloc_shared(group_count * sizeof(float), device, context); - m_betabuf = (T *) malloc_shared(group_count * sizeof(float), device, context); - m_transa = (oneapi::mkl::transpose *) malloc_shared(group_count * sizeof(oneapi::mkl::transpose), device, context); - m_transb = (oneapi::mkl::transpose *) malloc_shared(group_count * sizeof(oneapi::mkl::transpose), device, context); + m_device = main_queue.get_device(); + m_context = main_queue.get_context(); + m_mbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); + m_nbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); + m_kbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); + m_ldabuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); + m_ldbbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); + m_ldcbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); + m_alphabuf = (T *) malloc_shared(group_count * sizeof(float), m_device, m_context); + m_betabuf = (T *) malloc_shared(group_count * sizeof(float), m_device, m_context); + m_transa = (oneapi::mkl::transpose *) malloc_shared(group_count * sizeof(oneapi::mkl::transpose), + m_device, m_context); + m_transb = (oneapi::mkl::transpose *) malloc_shared(group_count * sizeof(oneapi::mkl::transpose), + m_device, m_context); auto t_a = convert(transa); auto t_b = convert(transb); - m_group_size = (int64_t *) malloc_shared(group_count * sizeof(int64_t), device, context); + m_group_size = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); for (int i = 0; i < group_count; i++) { m_mbuf[i] = m; @@ -95,17 +111,20 @@ class gemmBatchInfo { } }; - int64_t *m_mbuf = nullptr; - int64_t *m_nbuf = nullptr; - int64_t *m_kbuf = nullptr; - int64_t *m_ldabuf = nullptr; - int64_t *m_ldbbuf = nullptr; - int64_t *m_ldcbuf = nullptr; - oneapi::mkl::transpose *m_transa = nullptr; - oneapi::mkl::transpose *m_transb = nullptr; - T *m_alphabuf = nullptr; - T *m_betabuf = nullptr; - int64_t *m_group_size = nullptr; + // Destructor + ~gemmBatchInfo() { + free(m_mbuf, m_context); + free(m_nbuf, m_context); + free(m_kbuf, m_context); + free(m_ldabuf, m_context); + free(m_ldbbuf, m_context); + free(m_ldcbuf, m_context); + free(m_alphabuf, m_context); + free(m_betabuf, m_context); + free(m_transa, m_context); + free(m_transb, m_context); + free(m_group_size, m_context); + } }; extern "C" int onemklHgemm(syclQueue_t device_queue, onemklTranspose transA, @@ -182,8 +201,7 @@ extern "C" void onemklSgemmBatched(syclQueue_t device_queue, onemklTranspose tra const float **a, int64_t lda, const float **b, int64_t ldb, float beta, float **c, int64_t ldc, int64_t group_count) { - gemmBatchInfo gemmInfo; - gemmInfo.memInit(device_queue, group_count, transa, transb, + gemmBatchInfo gemmInfo(device_queue, group_count, transa, transb, m, n, k, lda, ldb, ldc, alpha, beta); auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, From 185cd4d476b63fa7e6a7d3702550784714a1c752 Mon Sep 17 00:00:00 2001 From: kballeda Date: Mon, 27 Feb 2023 14:20:29 +0530 Subject: [PATCH 09/14] automatically generate liboneapi and cleanup --- deps/src/onemkl.h | 7 +++++++ lib/support/liboneapi_support.jl | 18 +++++++++++++++--- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/deps/src/onemkl.h b/deps/src/onemkl.h index d8dd2cdd..5056d1db 100644 --- a/deps/src/onemkl.h +++ b/deps/src/onemkl.h @@ -61,6 +61,13 @@ void onemklSgemmBatched(syclQueue_t device_queue, onemklTranspose transa, int64_t ldb, float beta, float **c, int64_t ldc, int64_t group_count); +void onemklDgemmBatched(syclQueue_t device_queue, onemklTranspose transa, + onemklTranspose transb, int64_t m, + int64_t n, int64_t k, double alpha, + const double **a, int64_t lda, const double **b, + int64_t ldb, double beta, double **c, + int64_t ldc, int64_t group_count); + void onemklSsymm(syclQueue_t device_queue, onemklSide left_right, onemklUplo upper_lower, int64_t m, int64_t n, float alpha, const float *a, int64_t lda, const float *b, diff --git a/lib/support/liboneapi_support.jl b/lib/support/liboneapi_support.jl index 4d677a1f..b4032ee4 100644 --- a/lib/support/liboneapi_support.jl +++ b/lib/support/liboneapi_support.jl @@ -137,9 +137,21 @@ function onemklSgemmBatched(device_queue, transa, transb, m, n, k, alpha, a, lda transa::onemklTranspose, transb::onemklTranspose, m::Int64, n::Int64, k::Int64, alpha::Cfloat, - a::ZePtr{ZePtr{Cfloat}}, lda::Int64, - b::ZePtr{ZePtr{Cfloat}}, ldb::Int64, - beta::Cfloat, c::ZePtr{ZePtr{Cfloat}}, + a::ZePtr{Ptr{Cfloat}}, lda::Int64, + b::ZePtr{Ptr{Cfloat}}, ldb::Int64, + beta::Cfloat, c::ZePtr{Ptr{Cfloat}}, + ldc::Int64, group_count::Int64)::Cvoid +end + +function onemklDgemmBatched(device_queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, group_count) + @ccall liboneapi_support.onemklDgemmBatched(device_queue::syclQueue_t, + transa::onemklTranspose, + transb::onemklTranspose, m::Int64, n::Int64, + k::Int64, alpha::Cdouble, + a::Ptr{Ptr{Cdouble}}, lda::Int64, + b::Ptr{Ptr{Cdouble}}, ldb::Int64, + beta::Cdouble, c::Ptr{Ptr{Cdouble}}, ldc::Int64, group_count::Int64)::Cvoid end From c07717ccb625641c8ba82edf085ca6736b250eeb Mon Sep 17 00:00:00 2001 From: kballeda Date: Mon, 27 Feb 2023 14:29:52 +0530 Subject: [PATCH 10/14] dgemm supported --- deps/src/onemkl.cpp | 22 ++++++++++++++++++++++ lib/support/liboneapi_support.jl | 6 +++--- test/onemkl.jl | 2 +- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/deps/src/onemkl.cpp b/deps/src/onemkl.cpp index 2354d94a..34a98b6d 100644 --- a/deps/src/onemkl.cpp +++ b/deps/src/onemkl.cpp @@ -217,6 +217,28 @@ extern "C" void onemklSgemmBatched(syclQueue_t device_queue, onemklTranspose tra } +extern "C" void onemklDgemmBatched(syclQueue_t device_queue, onemklTranspose transa, + onemklTranspose transb, int64_t m, + int64_t n, int64_t k, double alpha, + const double **a, int64_t lda, const double **b, + int64_t ldb, double beta, double **c, + int64_t ldc, int64_t group_count) { + gemmBatchInfo gemmInfo(device_queue, group_count, transa, transb, + m, n, k, lda, ldb, ldc, alpha, beta); + + auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, + &gemmInfo.m_transa[0], &gemmInfo.m_transb[0], + &gemmInfo.m_mbuf[0], &gemmInfo.m_nbuf[0], + &gemmInfo.m_kbuf[0], &gemmInfo.m_alphabuf[0], + (const double **)&a[0], &gemmInfo.m_ldabuf[0], + (const double **)&b[0], &gemmInfo.m_ldbbuf[0], + &gemmInfo.m_betabuf[0], &c[0], &gemmInfo.m_ldcbuf[0], + group_count, &gemmInfo.m_group_size[0]); + + __FORCE_MKL_FLUSH__(status); + +} + extern "C" void onemklSsymm(syclQueue_t device_queue, onemklSide left_right, onemklUplo upper_lower, int64_t m, int64_t n, float alpha, const float *a, int64_t lda, const float *b, diff --git a/lib/support/liboneapi_support.jl b/lib/support/liboneapi_support.jl index b4032ee4..84483ecf 100644 --- a/lib/support/liboneapi_support.jl +++ b/lib/support/liboneapi_support.jl @@ -149,9 +149,9 @@ function onemklDgemmBatched(device_queue, transa, transb, m, n, k, alpha, a, lda transa::onemklTranspose, transb::onemklTranspose, m::Int64, n::Int64, k::Int64, alpha::Cdouble, - a::Ptr{Ptr{Cdouble}}, lda::Int64, - b::Ptr{Ptr{Cdouble}}, ldb::Int64, - beta::Cdouble, c::Ptr{Ptr{Cdouble}}, + a::ZePtr{Ptr{Cdouble}}, lda::Int64, + b::ZePtr{Ptr{Cdouble}}, ldb::Int64, + beta::Cdouble, c::ZePtr{Ptr{Cdouble}}, ldc::Int64, group_count::Int64)::Cvoid end diff --git a/test/onemkl.jl b/test/onemkl.jl index f801eca0..2717d52d 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -811,7 +811,7 @@ end end @testset "Batch Primitives" begin - @testset for T in [Float32] + @testset for T in [Float32, Float64] alpha = rand(T) beta = rand(T) group_count = 20 From 8d03b1303e1754a6e71079afef192b722466b95c Mon Sep 17 00:00:00 2001 From: kballeda Date: Mon, 27 Feb 2023 14:53:21 +0530 Subject: [PATCH 11/14] support cgemm_batch --- deps/src/onemkl.cpp | 26 ++++++++++++++++++++++++++ deps/src/onemkl.h | 9 +++++++++ lib/mkl/wrappers.jl | 3 ++- lib/support/liboneapi_support.jl | 12 ++++++++++++ test/onemkl.jl | 2 +- 5 files changed, 50 insertions(+), 2 deletions(-) diff --git a/deps/src/onemkl.cpp b/deps/src/onemkl.cpp index 34a98b6d..ff6104d8 100644 --- a/deps/src/onemkl.cpp +++ b/deps/src/onemkl.cpp @@ -239,6 +239,32 @@ extern "C" void onemklDgemmBatched(syclQueue_t device_queue, onemklTranspose tra } +extern "C" void onemklCgemmBatched(syclQueue_t device_queue, onemklTranspose transa, + onemklTranspose transb, int64_t m, + int64_t n, int64_t k, float _Complex alpha, + const float _Complex **a, int64_t lda, + const float _Complex **b, + int64_t ldb, float _Complex beta, float _Complex **c, + int64_t ldc, int64_t group_count) { + gemmBatchInfo> gemmInfo(device_queue, group_count, transa, transb, + m, n, k, lda, ldb, ldc, alpha, beta); + + auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, + &gemmInfo.m_transa[0], &gemmInfo.m_transb[0], + &gemmInfo.m_mbuf[0], &gemmInfo.m_nbuf[0], + &gemmInfo.m_kbuf[0], &gemmInfo.m_alphabuf[0], + reinterpret_cast **>(&a[0]), + &gemmInfo.m_ldabuf[0], + reinterpret_cast **>(&b[0]), + &gemmInfo.m_ldbbuf[0], + &gemmInfo.m_betabuf[0], + reinterpret_cast **>(&c[0]), &gemmInfo.m_ldcbuf[0], + group_count, &gemmInfo.m_group_size[0]); + + __FORCE_MKL_FLUSH__(status); + +} + extern "C" void onemklSsymm(syclQueue_t device_queue, onemklSide left_right, onemklUplo upper_lower, int64_t m, int64_t n, float alpha, const float *a, int64_t lda, const float *b, diff --git a/deps/src/onemkl.h b/deps/src/onemkl.h index 5056d1db..8e4eb3cf 100644 --- a/deps/src/onemkl.h +++ b/deps/src/onemkl.h @@ -68,6 +68,15 @@ void onemklDgemmBatched(syclQueue_t device_queue, onemklTranspose transa, int64_t ldb, double beta, double **c, int64_t ldc, int64_t group_count); +void onemklCgemmBatched(syclQueue_t device_queue, onemklTranspose transa, + onemklTranspose transb, int64_t m, + int64_t n, int64_t k, float _Complex alpha, + const float _Complex **a, int64_t lda, + const float _Complex **b, + int64_t ldb, float _Complex beta, + float _Complex **c, int64_t ldc, + int64_t group_count); + void onemklSsymm(syclQueue_t device_queue, onemklSide left_right, onemklUplo upper_lower, int64_t m, int64_t n, float alpha, const float *a, int64_t lda, const float *b, diff --git a/lib/mkl/wrappers.jl b/lib/mkl/wrappers.jl index 6d7269a5..6eed4c68 100644 --- a/lib/mkl/wrappers.jl +++ b/lib/mkl/wrappers.jl @@ -53,7 +53,8 @@ end ## (GE) general matrix-matrix multiplication batched for (fname, elty) in ((:onemklDgemmBatched,:Float64), - (:onemklSgemmBatched,:Float32)) + (:onemklSgemmBatched,:Float32), + (:onemklCgemmBatched,:ComplexF32)) @eval begin function gemm_batched!(transA::Char, transB::Char, diff --git a/lib/support/liboneapi_support.jl b/lib/support/liboneapi_support.jl index 84483ecf..4727b57d 100644 --- a/lib/support/liboneapi_support.jl +++ b/lib/support/liboneapi_support.jl @@ -155,6 +155,18 @@ function onemklDgemmBatched(device_queue, transa, transb, m, n, k, alpha, a, lda ldc::Int64, group_count::Int64)::Cvoid end +function onemklCgemmBatched(device_queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, group_count) + @ccall liboneapi_support.onemklCgemmBatched(device_queue::syclQueue_t, + transa::onemklTranspose, + transb::onemklTranspose, m::Int64, n::Int64, + k::Int64, alpha::ComplexF32, + a::ZePtr{Ptr{ComplexF32}}, lda::Int64, + b::ZePtr{Ptr{ComplexF32}}, ldb::Int64, + beta::ComplexF32, c::ZePtr{Ptr{ComplexF32}}, + ldc::Int64, group_count::Int64)::Cvoid +end + function onemklSsymm(device_queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, ldc) @ccall liboneapi_support.onemklSsymm(device_queue::syclQueue_t, left_right::onemklSide, diff --git a/test/onemkl.jl b/test/onemkl.jl index 2717d52d..b8166efe 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -811,7 +811,7 @@ end end @testset "Batch Primitives" begin - @testset for T in [Float32, Float64] + @testset for T in [Float32, Float64, ComplexF32] alpha = rand(T) beta = rand(T) group_count = 20 From ce6e4cef6b3d86412774dd9ed926932e76c1c1d9 Mon Sep 17 00:00:00 2001 From: kballeda Date: Mon, 27 Feb 2023 15:35:26 +0530 Subject: [PATCH 12/14] support zgemm_batch --- deps/src/onemkl.cpp | 30 ++++++++++++++++++++++++++++-- deps/src/onemkl.h | 9 +++++++++ lib/mkl/wrappers.jl | 3 ++- lib/support/liboneapi_support.jl | 12 ++++++++++++ test/onemkl.jl | 2 +- 5 files changed, 52 insertions(+), 4 deletions(-) diff --git a/deps/src/onemkl.cpp b/deps/src/onemkl.cpp index ff6104d8..24f89101 100644 --- a/deps/src/onemkl.cpp +++ b/deps/src/onemkl.cpp @@ -86,8 +86,8 @@ class gemmBatchInfo { m_ldabuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); m_ldbbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); m_ldcbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); - m_alphabuf = (T *) malloc_shared(group_count * sizeof(float), m_device, m_context); - m_betabuf = (T *) malloc_shared(group_count * sizeof(float), m_device, m_context); + m_alphabuf = (T *) malloc_shared(group_count * sizeof(T), m_device, m_context); + m_betabuf = (T *) malloc_shared(group_count * sizeof(T), m_device, m_context); m_transa = (oneapi::mkl::transpose *) malloc_shared(group_count * sizeof(oneapi::mkl::transpose), m_device, m_context); m_transb = (oneapi::mkl::transpose *) malloc_shared(group_count * sizeof(oneapi::mkl::transpose), @@ -265,6 +265,32 @@ extern "C" void onemklCgemmBatched(syclQueue_t device_queue, onemklTranspose tra } +extern "C" void onemklZgemmBatched(syclQueue_t device_queue, onemklTranspose transa, + onemklTranspose transb, int64_t m, + int64_t n, int64_t k, double _Complex alpha, + const double _Complex **a, int64_t lda, + const double _Complex **b, + int64_t ldb, double _Complex beta, + double _Complex **c, + int64_t ldc, int64_t group_count) { + gemmBatchInfo> gemmInfo(device_queue, group_count, transa, transb, + m, n, k, lda, ldb, ldc, alpha, beta); + + auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, + &gemmInfo.m_transa[0], &gemmInfo.m_transb[0], + &gemmInfo.m_mbuf[0], &gemmInfo.m_nbuf[0], + &gemmInfo.m_kbuf[0], &gemmInfo.m_alphabuf[0], + reinterpret_cast **>(&a[0]), + &gemmInfo.m_ldabuf[0], + reinterpret_cast **>(&b[0]), + &gemmInfo.m_ldbbuf[0], + &gemmInfo.m_betabuf[0], + reinterpret_cast **>(&c[0]), &gemmInfo.m_ldcbuf[0], + group_count, &gemmInfo.m_group_size[0]); + + __FORCE_MKL_FLUSH__(status); +} + extern "C" void onemklSsymm(syclQueue_t device_queue, onemklSide left_right, onemklUplo upper_lower, int64_t m, int64_t n, float alpha, const float *a, int64_t lda, const float *b, diff --git a/deps/src/onemkl.h b/deps/src/onemkl.h index 8e4eb3cf..5d1031c5 100644 --- a/deps/src/onemkl.h +++ b/deps/src/onemkl.h @@ -77,6 +77,15 @@ void onemklCgemmBatched(syclQueue_t device_queue, onemklTranspose transa, float _Complex **c, int64_t ldc, int64_t group_count); +void onemklZgemmBatched(syclQueue_t device_queue, onemklTranspose transa, + onemklTranspose transb, int64_t m, + int64_t n, int64_t k, double _Complex alpha, + const double _Complex **a, int64_t lda, + const double _Complex **b, + int64_t ldb, double _Complex beta, + double _Complex **c, int64_t ldc, + int64_t group_count); + void onemklSsymm(syclQueue_t device_queue, onemklSide left_right, onemklUplo upper_lower, int64_t m, int64_t n, float alpha, const float *a, int64_t lda, const float *b, diff --git a/lib/mkl/wrappers.jl b/lib/mkl/wrappers.jl index 6eed4c68..39dca8bc 100644 --- a/lib/mkl/wrappers.jl +++ b/lib/mkl/wrappers.jl @@ -54,7 +54,8 @@ end for (fname, elty) in ((:onemklDgemmBatched,:Float64), (:onemklSgemmBatched,:Float32), - (:onemklCgemmBatched,:ComplexF32)) + (:onemklCgemmBatched,:ComplexF32), + (:onemklZgemmBatched,:ComplexF64)) @eval begin function gemm_batched!(transA::Char, transB::Char, diff --git a/lib/support/liboneapi_support.jl b/lib/support/liboneapi_support.jl index 4727b57d..6027350e 100644 --- a/lib/support/liboneapi_support.jl +++ b/lib/support/liboneapi_support.jl @@ -167,6 +167,18 @@ function onemklCgemmBatched(device_queue, transa, transb, m, n, k, alpha, a, lda ldc::Int64, group_count::Int64)::Cvoid end +function onemklZgemmBatched(device_queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, group_count) + @ccall liboneapi_support.onemklZgemmBatched(device_queue::syclQueue_t, + transa::onemklTranspose, + transb::onemklTranspose, m::Int64, n::Int64, + k::Int64, alpha::ComplexF64, + a::ZePtr{Ptr{ComplexF64}}, lda::Int64, + b::ZePtr{Ptr{ComplexF64}}, ldb::Int64, + beta::ComplexF64, c::ZePtr{Ptr{ComplexF64}}, + ldc::Int64, group_count::Int64)::Cvoid +end + function onemklSsymm(device_queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, c, ldc) @ccall liboneapi_support.onemklSsymm(device_queue::syclQueue_t, left_right::onemklSide, diff --git a/test/onemkl.jl b/test/onemkl.jl index b8166efe..a8b872a4 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -811,7 +811,7 @@ end end @testset "Batch Primitives" begin - @testset for T in [Float32, Float64, ComplexF32] + @testset for T in [Float32, Float64, ComplexF32, ComplexF64] alpha = rand(T) beta = rand(T) group_count = 20 From 9b12795ebc56121b5d64840a00a81c5068c1b5ba Mon Sep 17 00:00:00 2001 From: kballeda Date: Mon, 27 Feb 2023 16:14:45 +0530 Subject: [PATCH 13/14] memory alloc failure handling --- deps/src/onemkl.cpp | 48 +++++++++++++++++++++++++++------------------ 1 file changed, 29 insertions(+), 19 deletions(-) diff --git a/deps/src/onemkl.cpp b/deps/src/onemkl.cpp index 24f89101..26eb0f51 100644 --- a/deps/src/onemkl.cpp +++ b/deps/src/onemkl.cpp @@ -1,7 +1,8 @@ #include "onemkl.h" #include "sycl.hpp" #include -#include +#include +#include #include // This is a workaround to flush MKL submissions into Level-zero queue, using @@ -68,6 +69,8 @@ class gemmBatchInfo { int64_t *m_group_size = nullptr; sycl::device m_device; sycl::context m_context; + oneapi::mkl::transpose m_ta; + oneapi::mkl::transpose m_tb; // Constructor gemmBatchInfo(syclQueue_t device_queue, @@ -77,25 +80,32 @@ class gemmBatchInfo { int64_t m, int64_t n, int64_t k, int64_t lda, int64_t ldb, int64_t ldc, T alpha, T beta) { + // Get device and context info from device_queue auto main_queue = device_queue->val; m_device = main_queue.get_device(); m_context = main_queue.get_context(); - m_mbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); - m_nbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); - m_kbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); - m_ldabuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); - m_ldbbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); - m_ldcbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); - m_alphabuf = (T *) malloc_shared(group_count * sizeof(T), m_device, m_context); - m_betabuf = (T *) malloc_shared(group_count * sizeof(T), m_device, m_context); - m_transa = (oneapi::mkl::transpose *) malloc_shared(group_count * sizeof(oneapi::mkl::transpose), - m_device, m_context); - m_transb = (oneapi::mkl::transpose *) malloc_shared(group_count * sizeof(oneapi::mkl::transpose), - m_device, m_context); - auto t_a = convert(transa); - auto t_b = convert(transb); - m_group_size = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); - + try { + // Allocate uniform arrays of m,n,k,lda,ldb,ldc,alpha,beta + // group_size and transpose_a, transpose_b supporting oneMKL + // gemm_batch API + m_mbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); + m_nbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); + m_kbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); + m_ldabuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); + m_ldbbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); + m_ldcbuf = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); + m_alphabuf = (T *) malloc_shared(group_count * sizeof(T), m_device, m_context); + m_betabuf = (T *) malloc_shared(group_count * sizeof(T), m_device, m_context); + m_transa = (oneapi::mkl::transpose *) malloc_shared(group_count * sizeof(oneapi::mkl::transpose), + m_device, m_context); + m_transb = (oneapi::mkl::transpose *) malloc_shared(group_count * sizeof(oneapi::mkl::transpose), + m_device, m_context); + m_ta = convert(transa); + m_tb = convert(transb); + m_group_size = (int64_t *) malloc_shared(group_count * sizeof(int64_t), m_device, m_context); + } catch(const std::bad_alloc& e) { + std::cerr << "Error: " << e.what() << std::endl; + } for (int i = 0; i < group_count; i++) { m_mbuf[i] = m; m_nbuf[i] = n; @@ -105,8 +115,8 @@ class gemmBatchInfo { m_ldcbuf[i] = ldc; m_alphabuf[i] = alpha; m_betabuf[i] = beta; - m_transa[i] = t_a; - m_transb[i] = t_b; + m_transa[i] = m_ta; + m_transb[i] = m_tb; m_group_size[i] = 1; } }; From 20ef11557ca3a503e3592444e340638d66e6cc62 Mon Sep 17 00:00:00 2001 From: kballeda Date: Mon, 27 Feb 2023 19:00:51 +0530 Subject: [PATCH 14/14] support half type gemm_batch --- deps/src/onemkl.cpp | 22 ++++++++++++++++++++++ deps/src/onemkl.h | 7 +++++++ lib/mkl/wrappers.jl | 1 + lib/support/liboneapi_support.jl | 24 ++++++++++++++++++------ test/onemkl.jl | 3 ++- 5 files changed, 50 insertions(+), 7 deletions(-) diff --git a/deps/src/onemkl.cpp b/deps/src/onemkl.cpp index 41416f31..35b0d01e 100644 --- a/deps/src/onemkl.cpp +++ b/deps/src/onemkl.cpp @@ -208,6 +208,28 @@ extern "C" int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA, return 0; } +extern "C" void onemklHgemmBatched(syclQueue_t device_queue, onemklTranspose transa, + onemklTranspose transb, int64_t m, + int64_t n, int64_t k, uint16_t alpha, + const short **a, int64_t lda, const short **b, + int64_t ldb, uint16_t beta, short **c, + int64_t ldc, int64_t group_count) { + gemmBatchInfo gemmInfo(device_queue, group_count, transa, transb, + m, n, k, lda, ldb, ldc, sycl::bit_cast(alpha), + sycl::bit_cast(beta)); + auto status = oneapi::mkl::blas::column_major::gemm_batch(device_queue->val, + &gemmInfo.m_transa[0], &gemmInfo.m_transb[0], + &gemmInfo.m_mbuf[0], &gemmInfo.m_nbuf[0], + &gemmInfo.m_kbuf[0], &gemmInfo.m_alphabuf[0], + reinterpret_cast(&a[0]), &gemmInfo.m_ldabuf[0], + reinterpret_cast(&b[0]), &gemmInfo.m_ldbbuf[0], + &gemmInfo.m_betabuf[0], reinterpret_cast(&c[0]), + &gemmInfo.m_ldcbuf[0], group_count, &gemmInfo.m_group_size[0]); + + __FORCE_MKL_FLUSH__(status); + +} + extern "C" void onemklSgemmBatched(syclQueue_t device_queue, onemklTranspose transa, onemklTranspose transb, int64_t m, int64_t n, int64_t k, float alpha, diff --git a/deps/src/onemkl.h b/deps/src/onemkl.h index 2ae5dadc..2953ec7d 100644 --- a/deps/src/onemkl.h +++ b/deps/src/onemkl.h @@ -59,6 +59,13 @@ int onemklHgemm(syclQueue_t device_queue, onemklTranspose transA, const short *B, int64_t ldb, uint16_t beta, short *C, int64_t ldc); +void onemklHgemmBatched(syclQueue_t device_queue, onemklTranspose transa, + onemklTranspose transb, int64_t m, + int64_t n, int64_t k, uint16_t alpha, + const short **a, int64_t lda, const short **b, + int64_t ldb, uint16_t beta, short **c, + int64_t ldc, int64_t group_count); + void onemklSgemmBatched(syclQueue_t device_queue, onemklTranspose transa, onemklTranspose transb, int64_t m, int64_t n, int64_t k, float alpha, diff --git a/lib/mkl/wrappers.jl b/lib/mkl/wrappers.jl index 45b1b2be..ef9ae284 100644 --- a/lib/mkl/wrappers.jl +++ b/lib/mkl/wrappers.jl @@ -54,6 +54,7 @@ end for (fname, elty) in ((:onemklDgemmBatched,:Float64), (:onemklSgemmBatched,:Float32), + (:onemklHgemmBatched,:Float16), (:onemklCgemmBatched,:ComplexF32), (:onemklZgemmBatched,:ComplexF64)) @eval begin diff --git a/lib/support/liboneapi_support.jl b/lib/support/liboneapi_support.jl index 8eb2a0ec..c9646c02 100644 --- a/lib/support/liboneapi_support.jl +++ b/lib/support/liboneapi_support.jl @@ -132,12 +132,24 @@ function onemklZgemm(device_queue, transA, transB, m, n, k, alpha, A, lda, B, ld end function onemklHgemm(device_queue, transA, transB, m, n, k, alpha, A, lda, B, ldb, beta, C, - ldc) -@ccall liboneapi_support.onemklHgemm(device_queue::syclQueue_t, transA::onemklTranspose, - transB::onemklTranspose, m::Int64, n::Int64, - k::Int64, alpha::Float16, A::ZePtr{Float16}, - lda::Int64, B::ZePtr{Float16}, ldb::Int64, - beta::Float16, C::ZePtr{Float16}, ldc::Int64)::Cint + ldc) + @ccall liboneapi_support.onemklHgemm(device_queue::syclQueue_t, transA::onemklTranspose, + transB::onemklTranspose, m::Int64, n::Int64, + k::Int64, alpha::Float16, A::ZePtr{Float16}, + lda::Int64, B::ZePtr{Float16}, ldb::Int64, + beta::Float16, C::ZePtr{Float16}, ldc::Int64)::Cint +end + +function onemklHgemmBatched(device_queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, + beta, c, ldc, group_count) + @ccall liboneapi_support.onemklHgemmBatched(device_queue::syclQueue_t, + transa::onemklTranspose, + transb::onemklTranspose, m::Int64, n::Int64, + k::Int64, alpha::Float16, + a::ZePtr{Ptr{Float16}}, lda::Int64, + b::ZePtr{Ptr{Float16}}, ldb::Int64, + beta::Float16, c::ZePtr{Ptr{Float16}}, + ldc::Int64, group_count::Int64)::Cvoid end function onemklSgemmBatched(device_queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, diff --git a/test/onemkl.jl b/test/onemkl.jl index b812b5e3..51d229e7 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -893,8 +893,9 @@ end end end + @testset "BLAS Extension" begin - @testset for T in [Float32, Float64, ComplexF32, ComplexF64] + @testset for T in [Float16, Float32, Float64, ComplexF32, ComplexF64] alpha = rand(T) beta = rand(T) group_count = 20