oneMKL: add BLAS level-1 nrm2 wrappers

What this patch changes, file by file:

  * deps/src/onemkl.cpp, deps/src/onemkl.h
      Adds the C entry points onemklDnrm2, onemklSnrm2, onemklCnrm2 and
      onemklZnrm2.  Each one forwards to
      oneapi::mkl::blas::column_major::nrm2 on device_queue->val and then
      calls .wait() on the value that call returns.  The complex variants
      reinterpret_cast the C `_Complex` pointer before forwarding.
  * lib/mkl/libonemkl.jl
      Adds one @ccall binding per entry point.
  * lib/mkl/wrappers.jl
      Generates oneMKL.nrm2(n, x) for Float64, Float32, ComplexF32 and
      ComplexF64; the result is written to a one-element oneArray and
      copied back with Array(result).
  * lib/mkl/linalg.jl
      Defines LinearAlgebra.norm for oneStridedVecOrMat{<:onemklFloat} as
      oneMKL.nrm2(length(x), x).
  * test/onemkl.jl
      Adds an "nrm2" testset and moves the existing iamax/iamin checks into
      their own testset.

NOTE(review): this copy of the patch had every line break collapsed and the
template arguments of both reinterpret_cast calls stripped.  They are
restored below as `const std::complex<float>` / `const std::complex<double>`
to match the `const float _Complex *` / `const double _Complex *`
parameters -- confirm against the upstream commit.

NOTE(review): indentation and the position of blank context lines were
reconstructed so that every hunk matches the line counts in its header; the
exact whitespace is a best guess.  If a hunk is rejected on context, retry
with `git apply --ignore-whitespace`.

NOTE(review): the last hunk of lib/mkl/libonemkl.jl only removes the
trailing newline at the end of the file; that looks unintentional -- confirm
before merging.

diff --git a/deps/src/onemkl.cpp b/deps/src/onemkl.cpp
index 379ec1c8..c9afc204 100644
--- a/deps/src/onemkl.cpp
+++ b/deps/src/onemkl.cpp
@@ -81,6 +81,32 @@ extern "C" int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA,
     return 0;
 }
 
+extern "C" void onemklDnrm2(syclQueue_t device_queue, int64_t n, const double *x,
+                            int64_t incx, double *result) {
+    auto status = oneapi::mkl::blas::column_major::nrm2(device_queue->val, n, x, incx, result);
+    status.wait();
+}
+
+extern "C" void onemklSnrm2(syclQueue_t device_queue, int64_t n, const float *x,
+                            int64_t incx, float *result) {
+    auto status = oneapi::mkl::blas::column_major::nrm2(device_queue->val, n, x, incx, result);
+    status.wait();
+}
+
+extern "C" void onemklCnrm2(syclQueue_t device_queue, int64_t n, const float _Complex *x,
+                            int64_t incx, float *result) {
+    auto status = oneapi::mkl::blas::column_major::nrm2(device_queue->val, n,
+                    reinterpret_cast<const std::complex<float> *>(x), incx, result);
+    status.wait();
+}
+
+extern "C" void onemklZnrm2(syclQueue_t device_queue, int64_t n, const double _Complex *x,
+                            int64_t incx, double *result) {
+    auto status = oneapi::mkl::blas::column_major::nrm2(device_queue->val, n,
+                    reinterpret_cast<const std::complex<double> *>(x), incx, result);
+    status.wait();
+}
+
 extern "C" void onemklDcopy(syclQueue_t device_queue, int64_t n, const double *x,
                             int64_t incx, double *y, int64_t incy) {
     oneapi::mkl::blas::column_major::copy(device_queue->val, n, x, incx, y, incy);
diff --git a/deps/src/onemkl.h b/deps/src/onemkl.h
index 80e9aa31..98406faa 100644
--- a/deps/src/onemkl.h
+++ b/deps/src/onemkl.h
@@ -39,6 +39,16 @@ int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA,
                 const double _Complex *B, int64_t ldb, double _Complex beta,
                 double _Complex *C, int64_t ldc);
 
+// Supported Level-1: Nrm2
+void onemklDnrm2(syclQueue_t device_queue, int64_t n, const double *x,
+                 int64_t incx, double *result);
+void onemklSnrm2(syclQueue_t device_queue, int64_t n, const float *x,
+                 int64_t incx, float *result);
+void onemklCnrm2(syclQueue_t device_queue, int64_t n, const float _Complex *x,
+                 int64_t incx, float *result);
+void onemklZnrm2(syclQueue_t device_queue, int64_t n, const double _Complex *x,
+                 int64_t incx, double *result);
+
 void onemklDcopy(syclQueue_t device_queue, int64_t n, const double *x,
                  int64_t incx, double *y, int64_t incy);
 void onemklScopy(syclQueue_t device_queue, int64_t n, const float *x,
diff --git a/lib/mkl/libonemkl.jl b/lib/mkl/libonemkl.jl
index 3927a12d..56767b67 100644
--- a/lib/mkl/libonemkl.jl
+++ b/lib/mkl/libonemkl.jl
@@ -42,6 +42,31 @@ function onemklZgemm(device_queue, transA, transB, m, n, k, alpha, A, lda, B, ld
                                 C::ZePtr{ComplexF64}, ldc::Int64)::Cint
 end
 
+function onemklDnrm2(device_queue, n, x, incx, result)
+    @ccall liboneapi_support.onemklDnrm2(device_queue::syclQueue_t,
+                                n::Int64, x::ZePtr{Cdouble}, incx::Int64,
+                                result::RefOrZeRef{Cdouble})::Cvoid
+end
+
+function onemklSnrm2(device_queue, n, x, incx, result)
+    @ccall liboneapi_support.onemklSnrm2(device_queue::syclQueue_t,
+                                n::Int64, x::ZePtr{Cfloat}, incx::Int64,
+                                result::RefOrZeRef{Cfloat})::Cvoid
+end
+
+function onemklCnrm2(device_queue, n, x, incx, result)
+    @ccall liboneapi_support.onemklCnrm2(device_queue::syclQueue_t,
+                                n::Int64, x::ZePtr{ComplexF32}, incx::Int64,
+                                result::RefOrZeRef{Cfloat})::Cvoid
+end
+
+function onemklZnrm2(device_queue, n, x, incx, result)
+    @ccall liboneapi_support.onemklZnrm2(device_queue::syclQueue_t,
+                                n::Int64, x::ZePtr{ComplexF64}, incx::Int64,
+                                result::RefOrZeRef{Cdouble})::Cvoid
+end
+
+
 function onemklDcopy(device_queue, n, x, incx, y, incy)
     @ccall liboneapi_support.onemklDcopy(device_queue::syclQueue_t,
                                 n::Int64, x::ZePtr{Cdouble}, incx::Int64,
@@ -104,4 +129,4 @@ end
 function onemklZamin(device_queue, n, x, incx, result)
     @ccall liboneapi_support.onemklZamin(device_queue::syclQueue_t, n::Int64,
                             x::ZePtr{ComplexF64}, incx::Int64, result::ZePtr{Int64})::Cvoid
-end
+end
\ No newline at end of file
diff --git a/lib/mkl/linalg.jl b/lib/mkl/linalg.jl
index d1d6ae6b..94905416 100644
--- a/lib/mkl/linalg.jl
+++ b/lib/mkl/linalg.jl
@@ -49,6 +49,8 @@ function gemm_dispatch!(C::oneStridedVecOrMat, A, B, alpha::Number=true, beta::N
     end
 end
 
+LinearAlgebra.norm(x::oneStridedVecOrMat{<:onemklFloat}) = oneMKL.nrm2(length(x), x)
+
 for NT in (Number, Real)
     # NOTE: alpha/beta also ::Real to avoid ambiguities with certain Base methods
     @eval begin
diff --git a/lib/mkl/wrappers.jl b/lib/mkl/wrappers.jl
index ff4a46e5..43808b17 100644
--- a/lib/mkl/wrappers.jl
+++ b/lib/mkl/wrappers.jl
@@ -14,7 +14,23 @@ function Base.convert(::Type{onemklTranspose}, trans::Char)
     end
 end
 
-
+# level 1
+## nrm2
+for (fname, elty, ret_type) in
+    ((:onemklDnrm2, :Float64,:Float64),
+     (:onemklSnrm2, :Float32,:Float32),
+     (:onemklCnrm2, :ComplexF32,:Float32),
+     (:onemklZnrm2, :ComplexF64,:Float64))
+    @eval begin
+        function nrm2(n::Integer, x::oneStridedArray{$elty})
+            queue = global_queue(context(x), device(x))
+            result = oneArray{$ret_type}([0]);
+            $fname(sycl_queue(queue), n, x, stride(x,1), result)
+            res = Array(result)
+            return res[1]
+        end
+    end
+end
 
 #
 # BLAS
diff --git a/test/onemkl.jl b/test/onemkl.jl
index b9d47ace..5f0e5ad3 100644
--- a/test/onemkl.jl
+++ b/test/onemkl.jl
@@ -13,10 +13,17 @@ m = 20
         oneMKL.copy!(m,A,B)
         @test Array(A) == Array(B)
 
-        # testing oneMKL max and min
-        a = convert.(T, [1.0, 2.0, -0.8, 5.0, 3.0])
-        ca = oneArray(a)
-        @test BLAS.iamax(a) == oneMKL.iamax(ca)
-        @test oneMKL.iamin(ca) == 3
+        @testset "nrm2" begin
+            # Test nrm2 primitive
+            @test testf(norm, rand(T,m))
+        end
+
+        @testset "iamax/iamin" begin
+            # testing oneMKL max and min
+            a = convert.(T, [1.0, 2.0, -0.8, 5.0, 3.0])
+            ca = oneArray(a)
+            @test BLAS.iamax(a) == oneMKL.iamax(ca)
+            @test oneMKL.iamin(ca) == 3
+        end
     end # level 1 testset
 end