diff --git a/deps/src/onemkl.cpp b/deps/src/onemkl.cpp index d397937f..49e0a321 100644 --- a/deps/src/onemkl.cpp +++ b/deps/src/onemkl.cpp @@ -81,6 +81,22 @@ extern "C" int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA, return 0; } +extern "C" void onemklSaxpy(syclQueue_t device_queue, int64_t n, float alpha, const float *x, std::int64_t incx, float *y, int64_t incy) { + oneapi::mkl::blas::column_major::axpy(device_queue->val, n, alpha, x, incx, y, incy); +} + +extern "C" void onemklDaxpy(syclQueue_t device_queue, int64_t n, double alpha, const double *x, std::int64_t incx, double *y, int64_t incy) { + oneapi::mkl::blas::column_major::axpy(device_queue->val, n, alpha, x, incx, y, incy); +} + +extern "C" void onemklCaxpy(syclQueue_t device_queue, int64_t n, float _Complex alpha, const float _Complex *x, std::int64_t incx, float _Complex *y, int64_t incy) { + oneapi::mkl::blas::column_major::axpy(device_queue->val, n, alpha, reinterpret_cast *>(x), incx, reinterpret_cast *>(y), incy); +} + +extern "C" void onemklZaxpy(syclQueue_t device_queue, int64_t n, double _Complex alpha, const double _Complex *x, std::int64_t incx, double _Complex *y, int64_t incy) { + oneapi::mkl::blas::column_major::axpy(device_queue->val, n, alpha, reinterpret_cast *>(x), incx, reinterpret_cast *>(y), incy); +} + extern "C" void onemklDcopy(syclQueue_t device_queue, int64_t n, const double *x, int64_t incx, double *y, int64_t incy) { oneapi::mkl::blas::column_major::copy(device_queue->val, n, x, incx, y, incy); diff --git a/deps/src/onemkl.h b/deps/src/onemkl.h index 54b0e7d9..1c7ce3ee 100644 --- a/deps/src/onemkl.h +++ b/deps/src/onemkl.h @@ -39,6 +39,11 @@ int onemklZgemm(syclQueue_t device_queue, onemklTranspose transA, const double _Complex *B, int64_t ldb, double _Complex beta, double _Complex *C, int64_t ldc); +void onemklSaxpy(syclQueue_t device_queue, int64_t n, float alpha, const float *x, int64_t incx, float *y, int64_t incy); +void onemklDaxpy(syclQueue_t device_queue, int64_t n, double alpha, const double *x, int64_t incx, double *y, int64_t incy); +void onemklCaxpy(syclQueue_t device_queue, int64_t n, float _Complex alpha, const float _Complex *x, int64_t incx, float _Complex *y, int64_t incy); +void onemklZaxpy(syclQueue_t device_queue, int64_t n, double _Complex alpha, const double _Complex *x, int64_t incx, double _Complex *y, int64_t incy); + void onemklDcopy(syclQueue_t device_queue, int64_t n, const double *x, int64_t incx, double *y, int64_t incy); void onemklScopy(syclQueue_t device_queue, int64_t n, const float *x, diff --git a/lib/mkl/libonemkl.jl b/lib/mkl/libonemkl.jl index 6b0f95c9..dc24674e 100644 --- a/lib/mkl/libonemkl.jl +++ b/lib/mkl/libonemkl.jl @@ -42,6 +42,22 @@ function onemklZgemm(device_queue, transA, transB, m, n, k, alpha, A, lda, B, ld C::ZePtr{ComplexF64}, ldc::Int64)::Cint end +function onemklSaxpy(device_queue, n, alpha, x, incx, y, incy) + @ccall liboneapi_support.onemklSaxpy(device_queue::syclQueue_t, n::Int64, alpha::Cfloat, x::ZePtr{Cfloat}, incx::Int64, y::ZePtr{Cfloat}, incy::Int64)::Cvoid +end + +function onemklDaxpy(device_queue, n, alpha, x, incx, y, incy) + @ccall liboneapi_support.onemklDaxpy(device_queue::syclQueue_t, n::Int64, alpha::Cdouble, x::ZePtr{Cdouble}, incx::Int64, y::ZePtr{Cdouble}, incy::Int64)::Cvoid +end + +function onemklCaxpy(device_queue, n, alpha, x, incx, y, incy) + @ccall liboneapi_support.onemklCaxpy(device_queue::syclQueue_t, n::Int64, alpha::ComplexF32, x::ZePtr{ComplexF32}, incx::Int64, y::ZePtr{ComplexF32}, incy::Int64)::Cvoid +end + +function onemklZaxpy(device_queue, n, alpha, x, incx, y, incy) + @ccall liboneapi_support.onemklZaxpy(device_queue::syclQueue_t, n::Int64, alpha::ComplexF64, x::ZePtr{ComplexF64}, incx::Int64, y::ZePtr{ComplexF64}, incy::Int64)::Cvoid +end + function onemklDcopy(device_queue, n, x, incx, y, incy) @ccall liboneapi_support.onemklDcopy(device_queue::syclQueue_t, n::Int64, x::ZePtr{Cdouble}, incx::Int64, @@ -65,4 +81,3 @@ function onemklCcopy(device_queue, n, x, incx, y, incy) x::ZePtr{ComplexF32}, incx::Int64, y::ZePtr{ComplexF32}, incy::Int64)::Cvoid end - diff --git a/lib/mkl/linalg.jl b/lib/mkl/linalg.jl index d1d6ae6b..a3ae80c5 100644 --- a/lib/mkl/linalg.jl +++ b/lib/mkl/linalg.jl @@ -49,6 +49,11 @@ function gemm_dispatch!(C::oneStridedVecOrMat, A, B, alpha::Number=true, beta::N end end +function LinearAlgebra.axpy!(alpha::Number, x::oneStridedVecOrMat{<:onemklFloat}, y::oneStridedVecOrMat{<:onemklFloat}) where T<:Union{onemklFloat} + length(x)==length(y) || throw(DimensionMismatch("axpy arguments have lengths $(length(x)) and $(length(y))")) + oneMKL.axpy!(length(x), alpha, x, y) +end + for NT in (Number, Real) # NOTE: alpha/beta also ::Real to avoid ambiguities with certain Base methods @eval begin diff --git a/lib/mkl/wrappers.jl b/lib/mkl/wrappers.jl index 54d382c4..010517a7 100644 --- a/lib/mkl/wrappers.jl +++ b/lib/mkl/wrappers.jl @@ -14,8 +14,26 @@ function Base.convert(::Type{onemklTranspose}, trans::Char) end end - - +# level 1 +## axpy primitive +for (fname, elty) in + ((:onemklDaxpy,:Float64), + (:onemklSaxpy,:Float32), + (:onemklZaxpy,:ComplexF64), + (:onemklCaxpy,:ComplexF32)) + @eval begin + function axpy!(n::Integer, + alpha::Number, + x::oneStridedArray{$elty}, + y::oneStridedArray{$elty} + ) + queue = global_queue(context(x), device(x)) + alpha = $elty(alpha) + $fname(sycl_queue(queue), n, alpha, x, stride(x,1), y, stride(y,1)) + y + end + end +end # # BLAS # diff --git a/test/onemkl.jl b/test/onemkl.jl index 5627aefb..e2ab833d 100644 --- a/test/onemkl.jl +++ b/test/onemkl.jl @@ -10,9 +10,17 @@ k = 13 ############################################################################################ @testset "level 1" begin @testset for T in intersect(eltypes, [Float32, Float64, ComplexF32, ComplexF64]) - A = oneArray(rand(T, m)) - B = oneArray{T}(undef, m) - oneMKL.copy!(m,A,B) - @test Array(A) == Array(B) + @testset "copy" begin + A = oneArray(rand(T, m)) + B = oneArray{T}(undef, m) + oneMKL.copy!(m,A,B) + @test Array(A) == Array(B) + end + + @testset "axpy" begin + # Test axpy primitive + alpha = rand(T,1) + @test testf(axpy!, alpha[1], rand(T,m), rand(T,m)) + end end # level 1 testset end