diff --git a/algorithms/kernel/optimization_solver/sgd/oneapi/sgd_dense_minibatch_oneapi_impl.i b/algorithms/kernel/optimization_solver/sgd/oneapi/sgd_dense_minibatch_oneapi_impl.i index 0101646e60e..d9545f9ea3d 100644 --- a/algorithms/kernel/optimization_solver/sgd/oneapi/sgd_dense_minibatch_oneapi_impl.i +++ b/algorithms/kernel/optimization_solver/sgd/oneapi/sgd_dense_minibatch_oneapi_impl.i @@ -228,11 +228,20 @@ services::Status SGDKernelOneAPI::compute(HostA const IndicesStatus indicesStatus = (batchIndices ? user : (batchSize < nTerms ? random : all)); services::SharedPtr > ntBatchIndices; + services::SharedPtr > ntBatchIndices2; + services::SharedPtr > ntBatchIndicesSycl; + services::SharedPtr > ntBatchIndices2Sycl; + BlockDescriptor batchIndicesBD; + BlockDescriptor batchIndicesSyclBD; + BlockDescriptor batchIndices2BD; + BlockDescriptor batchIndices2SyclBD; if (indicesStatus == user || indicesStatus == random) { - // Replace by SyclNumericTable when will be RNG on GPU - ntBatchIndices = HomogenNumericTableCPU::create(batchSize, 1, &status); + ntBatchIndices = HomogenNumericTableCPU::create(batchSize, 1, &status); + ntBatchIndices2 = HomogenNumericTableCPU::create(batchSize, 1, &status); + ntBatchIndicesSycl = SyclHomogenNumericTable::create(batchSize, 1, NumericTableIface::doAllocate); + ntBatchIndices2Sycl = SyclHomogenNumericTable::create(batchSize, 1, NumericTableIface::doAllocate); } NumericTablePtr previousBatchIndices = function->sumOfFunctionsParameter->batchIndices; @@ -288,23 +297,84 @@ services::Status SGDKernelOneAPI::compute(HostA *nProceededIterations = static_cast(nIter); + bool isSync = false; + bool isSecondPartOfIndices = false; + bool isFirstPartOfIndicesInitialized = false; + bool isSecondPartOfIndicesInitialized = false; + services::internal::HostAppHelper host(pHost, 10); for (size_t epoch = startIteration; epoch < (startIteration + nIter); epoch++) { - if (epoch % L == 0 || epoch == startIteration) + if ((epoch % (L << 1) == 0) || (epoch == startIteration)) { learningRate = learningRateArray[(epoch / L) % learningRateLength]; consCoeff = consCoeffsArray[(epoch / L) % consCoeffsLength]; - if (indicesStatus == user || indicesStatus == random) + + if ((indicesStatus == user) || (indicesStatus == random)) { DAAL_ITTNOTIFY_SCOPED_TASK(generateUniform); + const int * pValues = nullptr; DAAL_CHECK_STATUS(status, rngTask.get(pValues)); ntBatchIndices->setArray(const_cast(pValues), ntBatchIndices->getNumberOfRows()); + + DAAL_CHECK_STATUS(status, + ntBatchIndices->getBlockOfRows(0, ntBatchIndices->getNumberOfRows(), ReadWriteMode::readOnly, batchIndicesBD)); + const services::Buffer batchIndicesBuffer = batchIndicesBD.getBuffer(); + + DAAL_CHECK_STATUS(status, ntBatchIndicesSycl->getBlockOfRows(0, ntBatchIndicesSycl->getNumberOfRows(), ReadWriteMode::writeOnly, + batchIndicesSyclBD)); + const services::Buffer batchIndicesSyclBuffer = batchIndicesSyclBD.getBuffer(); + + ctx.copy(batchIndicesSyclBuffer, 0, batchIndicesBuffer, 0, batchSize, &status, isSync); + } + if ((indicesStatus == user) || (indicesStatus == random)) + { + DAAL_ITTNOTIFY_SCOPED_TASK(generateUniform); + + const int * pValues2 = nullptr; + DAAL_CHECK_STATUS(status, rngTask.get(pValues2)); + ntBatchIndices2->setArray(const_cast(pValues2), ntBatchIndices2->getNumberOfRows()); + + DAAL_CHECK_STATUS(status, + ntBatchIndices2->getBlockOfRows(0, ntBatchIndices2->getNumberOfRows(), ReadWriteMode::readOnly, batchIndices2BD)); + const services::Buffer batchIndices2Buffer = batchIndices2BD.getBuffer(); + + DAAL_CHECK_STATUS(status, ntBatchIndices2Sycl->getBlockOfRows(0, ntBatchIndices2Sycl->getNumberOfRows(), ReadWriteMode::writeOnly, + batchIndices2SyclBD)); + const services::Buffer batchIndices2SyclBuffer = batchIndices2SyclBD.getBuffer(); + + ctx.copy(batchIndices2SyclBuffer, 0, batchIndices2Buffer, 0, batchSize, &status, isSync); } + + isSecondPartOfIndices = false; + isFirstPartOfIndicesInitialized = false; + isSecondPartOfIndicesInitialized = false; } - DAAL_CHECK_STATUS(status, function->computeNoThrow()); + if ((epoch % L == 0) && !(epoch == startIteration)) + { + isSecondPartOfIndices = true; + } + + if (isSecondPartOfIndices) + { + if (!isSecondPartOfIndicesInitialized) + { + function->sumOfFunctionsParameter->batchIndices = ntBatchIndices2Sycl; + isSecondPartOfIndicesInitialized = true; + } + DAAL_CHECK_STATUS(status, function->computeNoThrow()); + } + else + { + if (!isFirstPartOfIndicesInitialized) + { + function->sumOfFunctionsParameter->batchIndices = ntBatchIndicesSycl; + isFirstPartOfIndicesInitialized = true; + } + DAAL_CHECK_STATUS(status, function->computeNoThrow()); + } if (host.isCancelled(status, 1)) { @@ -332,6 +402,14 @@ services::Status SGDKernelOneAPI::compute(HostA } DAAL_CHECK_STATUS(status, makeStep(argumentSize, prevWorkValueBuff, gradientBuff, workValueBuff, learningRate, consCoeff)); nProceededIters++; + + if ((epoch % (L << 1) == (L << 1) - 1) && !(epoch == startIteration)) + { + ntBatchIndices->releaseBlockOfRows(batchIndicesBD); + ntBatchIndicesSycl->releaseBlockOfRows(batchIndicesSyclBD); + ntBatchIndices2->releaseBlockOfRows(batchIndices2BD); + ntBatchIndices2Sycl->releaseBlockOfRows(batchIndices2SyclBD); + } } if (lastIterationResult) diff --git a/include/oneapi/internal/execution_context.h b/include/oneapi/internal/execution_context.h index 0293356df09..1943a7b149c 100644 --- a/include/oneapi/internal/execution_context.h +++ b/include/oneapi/internal/execution_context.h @@ -339,17 +339,18 @@ class ExecutionContextIface virtual void syrk(math::UpLo upper_lower, math::Transpose trans, size_t n, size_t k, double alpha, const UniversalBuffer & a_buffer, size_t lda, size_t offsetA, double beta, UniversalBuffer & c_buffer, size_t ldc, size_t offsetC, services::Status * status = NULL) = 0; - virtual void axpy(const uint32_t n, const double a, const UniversalBuffer x_buffer, const int incx, - const UniversalBuffer y_buffer, const int incy, services::Status * status = NULL) = 0; + virtual void axpy(const uint32_t n, const double a, const UniversalBuffer x_buffer, const int incx, const UniversalBuffer y_buffer, + const int incy, services::Status * status = NULL) = 0; virtual void potrf(math::UpLo uplo, size_t n, UniversalBuffer & a_buffer, size_t lda, services::Status * status = NULL) = 0; virtual void potrs(math::UpLo uplo, size_t n, size_t ny, UniversalBuffer & a_buffer, size_t lda, UniversalBuffer & b_buffer, size_t ldb, services::Status * status = NULL) = 0; - virtual void copy(UniversalBuffer dest, size_t desOffset, UniversalBuffer src, size_t srcOffset, size_t count, services::Status * status) = 0; + virtual void copy(UniversalBuffer dest, size_t desOffset, UniversalBuffer src, size_t srcOffset, size_t count, services::Status * status, + bool isSync = true) = 0; - virtual void fill(UniversalBuffer dest, double value, services::Status * status) = 0; + virtual void fill(UniversalBuffer dest, double value, services::Status * status, bool isSync = true) = 0; virtual UniversalBuffer allocate(TypeId type, size_t bufferSize, services::Status * status) = 0; @@ -357,7 +358,8 @@ class ExecutionContextIface virtual InfoDevice & getInfoDevice() = 0; - virtual void copy(UniversalBuffer dest, size_t desOffset, void *src, size_t srcOffset, size_t count, services::Status *status) = 0; + virtual void copy(UniversalBuffer dest, size_t desOffset, void * src, size_t srcOffset, size_t count, services::Status * status, + bool isSync = true) = 0; }; /** @@ -414,8 +416,8 @@ class CpuExecutionContextImpl : public Base, public ExecutionContextIface services::internal::tryAssignStatus(status, services::ErrorMethodNotImplemented); } - void axpy(const uint32_t n, const double a, const UniversalBuffer x_buffer, const int incx, - const UniversalBuffer y_buffer, const int incy, services::Status * status = NULL) DAAL_C11_OVERRIDE + void axpy(const uint32_t n, const double a, const UniversalBuffer x_buffer, const int incx, const UniversalBuffer y_buffer, const int incy, + services::Status * status = NULL) DAAL_C11_OVERRIDE { services::internal::tryAssignStatus(status, services::ErrorMethodNotImplemented); } @@ -431,13 +433,13 @@ class CpuExecutionContextImpl : public Base, public ExecutionContextIface services::internal::tryAssignStatus(status, services::ErrorMethodNotImplemented); } - void copy(UniversalBuffer dest, size_t desOffset, UniversalBuffer src, size_t srcOffset, size_t count, - services::Status * status = NULL) DAAL_C11_OVERRIDE + void copy(UniversalBuffer dest, size_t desOffset, UniversalBuffer src, size_t srcOffset, size_t count, services::Status * status = NULL, + bool isSync = true) DAAL_C11_OVERRIDE { services::internal::tryAssignStatus(status, services::ErrorMethodNotImplemented); } - void fill(UniversalBuffer dest, double value, services::Status * status = NULL) DAAL_C11_OVERRIDE + void fill(UniversalBuffer dest, double value, services::Status * status = NULL, bool isSync = true) DAAL_C11_OVERRIDE { services::internal::tryAssignStatus(status, services::ErrorMethodNotImplemented); } @@ -452,12 +454,8 @@ class CpuExecutionContextImpl : public Base, public ExecutionContextIface InfoDevice & getInfoDevice() DAAL_C11_OVERRIDE { return _infoDevice; } - void copy(UniversalBuffer dest, - size_t desOffset, - void *src, - size_t srcOffset, - size_t count, - services::Status *status = NULL) DAAL_C11_OVERRIDE + void copy(UniversalBuffer dest, size_t desOffset, void * src, size_t srcOffset, size_t count, services::Status * status = NULL, + bool isSync = true) DAAL_C11_OVERRIDE { services::internal::tryAssignStatus(status, services::ErrorMethodNotImplemented); } diff --git a/include/oneapi/internal/execution_context_sycl.h b/include/oneapi/internal/execution_context_sycl.h index 7b11af482e8..ce3bcfab455 100644 --- a/include/oneapi/internal/execution_context_sycl.h +++ b/include/oneapi/internal/execution_context_sycl.h @@ -16,20 +16,20 @@ *******************************************************************************/ #ifdef DAAL_SYCL_INTERFACE -#ifndef __DAAL_ONEAPI_INTERNAL_EXECUTION_CONTEXT_SYCL_H__ -#define __DAAL_ONEAPI_INTERNAL_EXECUTION_CONTEXT_SYCL_H__ + #ifndef __DAAL_ONEAPI_INTERNAL_EXECUTION_CONTEXT_SYCL_H__ + #define __DAAL_ONEAPI_INTERNAL_EXECUTION_CONTEXT_SYCL_H__ -#include -#include -#include -#include + #include + #include + #include + #include -#include "services/daal_string.h" -#include "oneapi/internal/execution_context.h" -#include "oneapi/internal/kernel_scheduler_sycl.h" -#include "oneapi/internal/math/blas_executor.h" -#include "oneapi/internal/math/lapack_executor.h" -#include "oneapi/internal/error_handling.h" + #include "services/daal_string.h" + #include "oneapi/internal/execution_context.h" + #include "oneapi/internal/kernel_scheduler_sycl.h" + #include "oneapi/internal/math/blas_executor.h" + #include "oneapi/internal/math/lapack_executor.h" + #include "oneapi/internal/error_handling.h" namespace daal { @@ -50,22 +50,13 @@ class OpenClKernelFactory : public Base, public ClKernelFactoryIface public: ProgramCacheEntry() : _program(nullptr) {} - ~ProgramCacheEntry() - { - delete _program; - } + ~ProgramCacheEntry() { delete _program; } - void setProgram(OpenClProgramRef *program) - { - _program = program; - } + void setProgram(OpenClProgramRef * program) { _program = program; } - OpenClProgramRef * getProgram() - { - return _program; - } + OpenClProgramRef * getProgram() { return _program; } - const char* getName(services::Status * status = nullptr) + const char * getName(services::Status * status = nullptr) { if (!_program) { @@ -90,25 +81,20 @@ class OpenClKernelFactory : public Base, public ClKernelFactoryIface ~KernelCacheEntry() {} - void setKernel(KernelPtr kernel, const char *name) + void setKernel(KernelPtr kernel, const char * name) { - _name = name; + _name = name; _kernel = kernel; } - KernelPtr getKernel() - { - return _kernel; - } + KernelPtr getKernel() { return _kernel; } - const char* getName() - { - return _name.c_str(); - } + const char * getName() { return _name.c_str(); } }; public: - explicit OpenClKernelFactory(cl::sycl::queue & deviceQueue) : _clProgramRef(nullptr), _executionTarget(ExecutionTargetIds::unspecified), _deviceQueue(deviceQueue) + explicit OpenClKernelFactory(cl::sycl::queue & deviceQueue) + : _clProgramRef(nullptr), _executionTarget(ExecutionTargetIds::unspecified), _deviceQueue(deviceQueue) {} void build(ExecutionTargetId target, const char * key, const char * program, const char * options = "", @@ -131,8 +117,8 @@ class OpenClKernelFactory : public Base, public ClKernelFactoryIface } else { - _clProgramCache[id].setProgram(new OpenClProgramRef(_deviceQueue.get_context().get(), - _deviceQueue.get_device().get(), key, program, options, status)); + _clProgramCache[id].setProgram( + new OpenClProgramRef(_deviceQueue.get_context().get(), _deviceQueue.get_device().get(), key, program, options, status)); if (status != nullptr && !status->ok()) { return; @@ -169,14 +155,13 @@ class OpenClKernelFactory : public Base, public ClKernelFactoryIface { return KernelPtr(); } - kernelPtr = KernelPtr(new OpenClKernel(_executionTarget, *_clProgramRef, kernelRef)); + kernelPtr = KernelPtr(new OpenClKernel(_executionTarget, *_clProgramRef, kernelRef)); _kernelCache[id].setKernel(kernelPtr, kernelName); } return kernelPtr; } - ~OpenClKernelFactory() DAAL_C11_OVERRIDE - {} + ~OpenClKernelFactory() DAAL_C11_OVERRIDE {} protected: uint64_t hash(const char * key) @@ -293,14 +278,14 @@ class SyclExecutionContextImpl : public Base, public ExecutionContextIface } } - void copy(UniversalBuffer dest, size_t desOffset, UniversalBuffer src, size_t srcOffset, size_t count, - services::Status * status = nullptr) DAAL_C11_OVERRIDE + void copy(UniversalBuffer dest, size_t desOffset, UniversalBuffer src, size_t srcOffset, size_t count, services::Status * status = nullptr, + bool isSync = true) DAAL_C11_OVERRIDE { DAAL_ASSERT(dest.type() == src.type()); // TODO: Thread safe? try { - BufferCopier::copy(_deviceQueue, dest, desOffset, src, srcOffset, count); + BufferCopier::copy(_deviceQueue, dest, desOffset, src, srcOffset, count, isSync); } catch (cl::sycl::exception const & e) { @@ -308,12 +293,12 @@ class SyclExecutionContextImpl : public Base, public ExecutionContextIface } } - void fill(UniversalBuffer dest, double value, services::Status * status = nullptr) DAAL_C11_OVERRIDE + void fill(UniversalBuffer dest, double value, services::Status * status = nullptr, bool isSync = true) DAAL_C11_OVERRIDE { // TODO: Thread safe? try { - BufferFiller::fill(_deviceQueue, dest, value); + BufferFiller::fill(_deviceQueue, dest, value, isSync); } catch (cl::sycl::exception const & e) { @@ -325,20 +310,15 @@ class SyclExecutionContextImpl : public Base, public ExecutionContextIface InfoDevice & getInfoDevice() DAAL_C11_OVERRIDE { return _infoDevice; } - void copy(UniversalBuffer dest, - size_t desOffset, - void *src, - size_t srcOffset, - size_t count, - services::Status *status = nullptr) DAAL_C11_OVERRIDE + void copy(UniversalBuffer dest, size_t desOffset, void * src, size_t srcOffset, size_t count, services::Status * status = nullptr, + bool isSync = true) DAAL_C11_OVERRIDE { // TODO: Thread safe? try { - ArrayCopier::copy(_deviceQueue, dest, - desOffset, src, srcOffset, count); + ArrayCopier::copy(_deviceQueue, dest, desOffset, src, srcOffset, count, isSync); } - catch (cl::sycl::exception const &e) + catch (cl::sycl::exception const & e) { convertSyclExceptionToStatus(e, status); } diff --git a/include/oneapi/internal/types_utils_cxx11.h b/include/oneapi/internal/types_utils_cxx11.h index 0a3f424bc6d..8209bb108cd 100644 --- a/include/oneapi/internal/types_utils_cxx11.h +++ b/include/oneapi/internal/types_utils_cxx11.h @@ -77,9 +77,11 @@ class BufferCopier UniversalBuffer & srcUnivers; size_t srcOffset; size_t count; + bool isSync; - explicit Execute(cl::sycl::queue & queue, UniversalBuffer & dst, size_t desOffset, UniversalBuffer & src, size_t srcOffset, size_t count) - : queue(queue), dstUnivers(dst), dstOffset(desOffset), srcUnivers(src), srcOffset(srcOffset), count(count) + explicit Execute(cl::sycl::queue & queue, UniversalBuffer & dst, size_t desOffset, UniversalBuffer & src, size_t srcOffset, size_t count, + bool isSync = true) + : queue(queue), dstUnivers(dst), dstOffset(desOffset), srcUnivers(src), srcOffset(srcOffset), count(count), isSync(isSync) {} template @@ -92,14 +94,22 @@ class BufferCopier auto dst_acc = dst.template get_access(cgh, cl::sycl::range<1>(count), cl::sycl::id<1>(dstOffset)); cgh.copy(src_acc, dst_acc); }); - event.wait(); + if (isSync) + { + event.wait(); + } + else + { + dst.set_write_back(false); + } } }; public: - static void copy(cl::sycl::queue & queue, UniversalBuffer & dest, size_t dstOffset, UniversalBuffer & src, size_t srcOffset, size_t count) + static void copy(cl::sycl::queue & queue, UniversalBuffer & dest, size_t dstOffset, UniversalBuffer & src, size_t srcOffset, size_t count, + bool isSync = true) { - Execute op(queue, dest, dstOffset, src, srcOffset, count); + Execute op(queue, dest, dstOffset, src, srcOffset, count, isSync); TypeDispatcher::dispatch(dest.type(), op); } }; @@ -113,41 +123,44 @@ class ArrayCopier private: struct Execute { - cl::sycl::queue &queue; - UniversalBuffer &dstUnivers; + cl::sycl::queue & queue; + UniversalBuffer & dstUnivers; size_t dstOffset; - void *srcArray; + void * srcArray; size_t srcOffset; size_t count; + bool isSync; - explicit Execute(cl::sycl::queue &queue, - UniversalBuffer &dst, size_t desOffset, - void *src, size_t srcOffset, - size_t count) : queue(queue), dstUnivers(dst), - dstOffset(desOffset), srcArray(src), - srcOffset(srcOffset), count(count) { } + explicit Execute(cl::sycl::queue & queue, UniversalBuffer & dst, size_t desOffset, void * src, size_t srcOffset, size_t count, + bool isSync = true) + : queue(queue), dstUnivers(dst), dstOffset(desOffset), srcArray(src), srcOffset(srcOffset), count(count), isSync(isSync) + {} template void operator()(Typelist) { - auto src = (T*)srcArray; - auto dst = dstUnivers.get().toSycl(); - cl::sycl::event event = queue.submit([&](cl::sycl::handler &cgh) { - auto dst_acc = dst.template get_access( - cgh, cl::sycl::range<1>(count), cl::sycl::id<1>(dstOffset)); + auto src = (T *)srcArray; + auto dst = dstUnivers.get().toSycl(); + cl::sycl::event event = queue.submit([&](cl::sycl::handler & cgh) { + auto dst_acc = dst.template get_access(cgh, cl::sycl::range<1>(count), cl::sycl::id<1>(dstOffset)); cgh.copy(src, dst_acc); }); - event.wait(); + if (isSync) + { + event.wait(); + } + else + { + dst.set_write_back(false); + } } }; public: - static void copy(cl::sycl::queue &queue, - UniversalBuffer &dest, size_t dstOffset, - void *src, size_t srcOffset, - size_t count) + static void copy(cl::sycl::queue & queue, UniversalBuffer & dest, size_t dstOffset, void * src, size_t srcOffset, size_t count, + bool isSync = true) { - Execute op(queue, dest, dstOffset, src, srcOffset, count); + Execute op(queue, dest, dstOffset, src, srcOffset, count, isSync); TypeDispatcher::dispatch(dest.type(), op); } }; @@ -164,8 +177,11 @@ class BufferFiller cl::sycl::queue & queue; UniversalBuffer & dstUnivers; double value; + bool isSync; - explicit Execute(cl::sycl::queue & queue, UniversalBuffer & dest, double value) : queue(queue), dstUnivers(dest), value(value) {} + explicit Execute(cl::sycl::queue & queue, UniversalBuffer & dest, double value, bool isSync = true) + : queue(queue), dstUnivers(dest), value(value), isSync(isSync) + {} template void operator()(Typelist) @@ -175,14 +191,21 @@ class BufferFiller auto acc = dst.template get_access(cgh); cgh.fill(acc, static_cast(value)); }); - event.wait(); + if (isSync) + { + event.wait(); + } + else + { + dst.set_write_back(false); + } } }; public: - static void fill(cl::sycl::queue & queue, UniversalBuffer & dest, double value) + static void fill(cl::sycl::queue & queue, UniversalBuffer & dest, double value, bool isSync = true) { - Execute op(queue, dest, value); + Execute op(queue, dest, value, isSync); TypeDispatcher::dispatch(dest.type(), op); } };