Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,20 @@ services::Status SGDKernelOneAPI<algorithmFPType, miniBatch, cpu>::compute(HostA

const IndicesStatus indicesStatus = (batchIndices ? user : (batchSize < nTerms ? random : all));
services::SharedPtr<HomogenNumericTableCPU<int, cpu> > ntBatchIndices;
services::SharedPtr<HomogenNumericTableCPU<int, cpu> > ntBatchIndices2;
services::SharedPtr<SyclHomogenNumericTable<int> > ntBatchIndicesSycl;
services::SharedPtr<SyclHomogenNumericTable<int> > ntBatchIndices2Sycl;
BlockDescriptor<int> batchIndicesBD;
BlockDescriptor<int> batchIndicesSyclBD;
BlockDescriptor<int> batchIndices2BD;
BlockDescriptor<int> batchIndices2SyclBD;

if (indicesStatus == user || indicesStatus == random)
{
// Replace by SyclNumericTable when will be RNG on GPU
ntBatchIndices = HomogenNumericTableCPU<int, cpu>::create(batchSize, 1, &status);
ntBatchIndices = HomogenNumericTableCPU<int, cpu>::create(batchSize, 1, &status);
ntBatchIndices2 = HomogenNumericTableCPU<int, cpu>::create(batchSize, 1, &status);
ntBatchIndicesSycl = SyclHomogenNumericTable<int>::create(batchSize, 1, NumericTableIface::doAllocate);
ntBatchIndices2Sycl = SyclHomogenNumericTable<int>::create(batchSize, 1, NumericTableIface::doAllocate);
}

NumericTablePtr previousBatchIndices = function->sumOfFunctionsParameter->batchIndices;
Expand Down Expand Up @@ -288,23 +297,84 @@ services::Status SGDKernelOneAPI<algorithmFPType, miniBatch, cpu>::compute(HostA

*nProceededIterations = static_cast<int>(nIter);

bool isSync = false;
bool isSecondPartOfIndices = false;
bool isFirstPartOfIndicesInitialized = false;
bool isSecondPartOfIndicesInitialized = false;

services::internal::HostAppHelper host(pHost, 10);
for (size_t epoch = startIteration; epoch < (startIteration + nIter); epoch++)
{
if (epoch % L == 0 || epoch == startIteration)
if ((epoch % (L << 1) == 0) || (epoch == startIteration))
{
learningRate = learningRateArray[(epoch / L) % learningRateLength];
consCoeff = consCoeffsArray[(epoch / L) % consCoeffsLength];
if (indicesStatus == user || indicesStatus == random)

if ((indicesStatus == user) || (indicesStatus == random))
{
DAAL_ITTNOTIFY_SCOPED_TASK(generateUniform);

const int * pValues = nullptr;
DAAL_CHECK_STATUS(status, rngTask.get(pValues));
ntBatchIndices->setArray(const_cast<int *>(pValues), ntBatchIndices->getNumberOfRows());

DAAL_CHECK_STATUS(status,
ntBatchIndices->getBlockOfRows(0, ntBatchIndices->getNumberOfRows(), ReadWriteMode::readOnly, batchIndicesBD));
const services::Buffer<int> batchIndicesBuffer = batchIndicesBD.getBuffer();

DAAL_CHECK_STATUS(status, ntBatchIndicesSycl->getBlockOfRows(0, ntBatchIndicesSycl->getNumberOfRows(), ReadWriteMode::writeOnly,
batchIndicesSyclBD));
const services::Buffer<int> batchIndicesSyclBuffer = batchIndicesSyclBD.getBuffer();

ctx.copy(batchIndicesSyclBuffer, 0, batchIndicesBuffer, 0, batchSize, &status, isSync);
}
if ((indicesStatus == user) || (indicesStatus == random))
{
DAAL_ITTNOTIFY_SCOPED_TASK(generateUniform);

const int * pValues2 = nullptr;
DAAL_CHECK_STATUS(status, rngTask.get(pValues2));
ntBatchIndices2->setArray(const_cast<int *>(pValues2), ntBatchIndices2->getNumberOfRows());

DAAL_CHECK_STATUS(status,
ntBatchIndices2->getBlockOfRows(0, ntBatchIndices2->getNumberOfRows(), ReadWriteMode::readOnly, batchIndices2BD));
const services::Buffer<int> batchIndices2Buffer = batchIndices2BD.getBuffer();

DAAL_CHECK_STATUS(status, ntBatchIndices2Sycl->getBlockOfRows(0, ntBatchIndices2Sycl->getNumberOfRows(), ReadWriteMode::writeOnly,
batchIndices2SyclBD));
const services::Buffer<int> batchIndices2SyclBuffer = batchIndices2SyclBD.getBuffer();

ctx.copy(batchIndices2SyclBuffer, 0, batchIndices2Buffer, 0, batchSize, &status, isSync);
}

isSecondPartOfIndices = false;
isFirstPartOfIndicesInitialized = false;
isSecondPartOfIndicesInitialized = false;
}

DAAL_CHECK_STATUS(status, function->computeNoThrow());
if ((epoch % L == 0) && !(epoch == startIteration))
{
isSecondPartOfIndices = true;
}

if (isSecondPartOfIndices)
{
if (!isSecondPartOfIndicesInitialized)
{
function->sumOfFunctionsParameter->batchIndices = ntBatchIndices2Sycl;
isSecondPartOfIndicesInitialized = true;
}
DAAL_CHECK_STATUS(status, function->computeNoThrow());
}
else
{
if (!isFirstPartOfIndicesInitialized)
{
function->sumOfFunctionsParameter->batchIndices = ntBatchIndicesSycl;
isFirstPartOfIndicesInitialized = true;
}
DAAL_CHECK_STATUS(status, function->computeNoThrow());
}

if (host.isCancelled(status, 1))
{
Expand Down Expand Up @@ -332,6 +402,14 @@ services::Status SGDKernelOneAPI<algorithmFPType, miniBatch, cpu>::compute(HostA
}
DAAL_CHECK_STATUS(status, makeStep(argumentSize, prevWorkValueBuff, gradientBuff, workValueBuff, learningRate, consCoeff));
nProceededIters++;

if ((epoch % (L << 1) == (L << 1) - 1) && !(epoch == startIteration))
{
ntBatchIndices->releaseBlockOfRows(batchIndicesBD);
ntBatchIndicesSycl->releaseBlockOfRows(batchIndicesSyclBD);
ntBatchIndices2->releaseBlockOfRows(batchIndices2BD);
ntBatchIndices2Sycl->releaseBlockOfRows(batchIndices2SyclBD);
}
}

if (lastIterationResult)
Expand Down
30 changes: 14 additions & 16 deletions include/oneapi/internal/execution_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,25 +339,27 @@ class ExecutionContextIface
virtual void syrk(math::UpLo upper_lower, math::Transpose trans, size_t n, size_t k, double alpha, const UniversalBuffer & a_buffer, size_t lda,
size_t offsetA, double beta, UniversalBuffer & c_buffer, size_t ldc, size_t offsetC, services::Status * status = NULL) = 0;

virtual void axpy(const uint32_t n, const double a, const UniversalBuffer x_buffer, const int incx,
const UniversalBuffer y_buffer, const int incy, services::Status * status = NULL) = 0;
virtual void axpy(const uint32_t n, const double a, const UniversalBuffer x_buffer, const int incx, const UniversalBuffer y_buffer,
const int incy, services::Status * status = NULL) = 0;

virtual void potrf(math::UpLo uplo, size_t n, UniversalBuffer & a_buffer, size_t lda, services::Status * status = NULL) = 0;

virtual void potrs(math::UpLo uplo, size_t n, size_t ny, UniversalBuffer & a_buffer, size_t lda, UniversalBuffer & b_buffer, size_t ldb,
services::Status * status = NULL) = 0;

virtual void copy(UniversalBuffer dest, size_t desOffset, UniversalBuffer src, size_t srcOffset, size_t count, services::Status * status) = 0;
virtual void copy(UniversalBuffer dest, size_t desOffset, UniversalBuffer src, size_t srcOffset, size_t count, services::Status * status,
bool isSync = true) = 0;

virtual void fill(UniversalBuffer dest, double value, services::Status * status) = 0;
virtual void fill(UniversalBuffer dest, double value, services::Status * status, bool isSync = true) = 0;

virtual UniversalBuffer allocate(TypeId type, size_t bufferSize, services::Status * status) = 0;

virtual ClKernelFactoryIface & getClKernelFactory() = 0;

virtual InfoDevice & getInfoDevice() = 0;

virtual void copy(UniversalBuffer dest, size_t desOffset, void *src, size_t srcOffset, size_t count, services::Status *status) = 0;
virtual void copy(UniversalBuffer dest, size_t desOffset, void * src, size_t srcOffset, size_t count, services::Status * status,
bool isSync = true) = 0;
};

/**
Expand Down Expand Up @@ -414,8 +416,8 @@ class CpuExecutionContextImpl : public Base, public ExecutionContextIface
services::internal::tryAssignStatus(status, services::ErrorMethodNotImplemented);
}

void axpy(const uint32_t n, const double a, const UniversalBuffer x_buffer, const int incx,
const UniversalBuffer y_buffer, const int incy, services::Status * status = NULL) DAAL_C11_OVERRIDE
void axpy(const uint32_t n, const double a, const UniversalBuffer x_buffer, const int incx, const UniversalBuffer y_buffer, const int incy,
services::Status * status = NULL) DAAL_C11_OVERRIDE
{
services::internal::tryAssignStatus(status, services::ErrorMethodNotImplemented);
}
Expand All @@ -431,13 +433,13 @@ class CpuExecutionContextImpl : public Base, public ExecutionContextIface
services::internal::tryAssignStatus(status, services::ErrorMethodNotImplemented);
}

void copy(UniversalBuffer dest, size_t desOffset, UniversalBuffer src, size_t srcOffset, size_t count,
services::Status * status = NULL) DAAL_C11_OVERRIDE
void copy(UniversalBuffer dest, size_t desOffset, UniversalBuffer src, size_t srcOffset, size_t count, services::Status * status = NULL,
bool isSync = true) DAAL_C11_OVERRIDE
{
services::internal::tryAssignStatus(status, services::ErrorMethodNotImplemented);
}

void fill(UniversalBuffer dest, double value, services::Status * status = NULL) DAAL_C11_OVERRIDE
void fill(UniversalBuffer dest, double value, services::Status * status = NULL, bool isSync = true) DAAL_C11_OVERRIDE
{
services::internal::tryAssignStatus(status, services::ErrorMethodNotImplemented);
}
Expand All @@ -452,12 +454,8 @@ class CpuExecutionContextImpl : public Base, public ExecutionContextIface

InfoDevice & getInfoDevice() DAAL_C11_OVERRIDE { return _infoDevice; }

void copy(UniversalBuffer dest,
size_t desOffset,
void *src,
size_t srcOffset,
size_t count,
services::Status *status = NULL) DAAL_C11_OVERRIDE
void copy(UniversalBuffer dest, size_t desOffset, void * src, size_t srcOffset, size_t count, services::Status * status = NULL,
bool isSync = true) DAAL_C11_OVERRIDE
{
services::internal::tryAssignStatus(status, services::ErrorMethodNotImplemented);
}
Expand Down
90 changes: 35 additions & 55 deletions include/oneapi/internal/execution_context_sycl.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,20 @@
*******************************************************************************/

#ifdef DAAL_SYCL_INTERFACE
#ifndef __DAAL_ONEAPI_INTERNAL_EXECUTION_CONTEXT_SYCL_H__
#define __DAAL_ONEAPI_INTERNAL_EXECUTION_CONTEXT_SYCL_H__
#ifndef __DAAL_ONEAPI_INTERNAL_EXECUTION_CONTEXT_SYCL_H__
#define __DAAL_ONEAPI_INTERNAL_EXECUTION_CONTEXT_SYCL_H__

#include <vector>
#include <cstring>
#include <CL/cl.h>
#include <CL/sycl.hpp>
#include <vector>
#include <cstring>
#include <CL/cl.h>
#include <CL/sycl.hpp>

#include "services/daal_string.h"
#include "oneapi/internal/execution_context.h"
#include "oneapi/internal/kernel_scheduler_sycl.h"
#include "oneapi/internal/math/blas_executor.h"
#include "oneapi/internal/math/lapack_executor.h"
#include "oneapi/internal/error_handling.h"
#include "services/daal_string.h"
#include "oneapi/internal/execution_context.h"
#include "oneapi/internal/kernel_scheduler_sycl.h"
#include "oneapi/internal/math/blas_executor.h"
#include "oneapi/internal/math/lapack_executor.h"
#include "oneapi/internal/error_handling.h"

namespace daal
{
Expand All @@ -50,22 +50,13 @@ class OpenClKernelFactory : public Base, public ClKernelFactoryIface
public:
ProgramCacheEntry() : _program(nullptr) {}

~ProgramCacheEntry()
{
delete _program;
}
~ProgramCacheEntry() { delete _program; }

void setProgram(OpenClProgramRef *program)
{
_program = program;
}
void setProgram(OpenClProgramRef * program) { _program = program; }

OpenClProgramRef * getProgram()
{
return _program;
}
OpenClProgramRef * getProgram() { return _program; }

const char* getName(services::Status * status = nullptr)
const char * getName(services::Status * status = nullptr)
{
if (!_program)
{
Expand All @@ -90,25 +81,20 @@ class OpenClKernelFactory : public Base, public ClKernelFactoryIface

~KernelCacheEntry() {}

void setKernel(KernelPtr kernel, const char *name)
void setKernel(KernelPtr kernel, const char * name)
{
_name = name;
_name = name;
_kernel = kernel;
}

KernelPtr getKernel()
{
return _kernel;
}
KernelPtr getKernel() { return _kernel; }

const char* getName()
{
return _name.c_str();
}
const char * getName() { return _name.c_str(); }
};

public:
explicit OpenClKernelFactory(cl::sycl::queue & deviceQueue) : _clProgramRef(nullptr), _executionTarget(ExecutionTargetIds::unspecified), _deviceQueue(deviceQueue)
explicit OpenClKernelFactory(cl::sycl::queue & deviceQueue)
: _clProgramRef(nullptr), _executionTarget(ExecutionTargetIds::unspecified), _deviceQueue(deviceQueue)
{}

void build(ExecutionTargetId target, const char * key, const char * program, const char * options = "",
Expand All @@ -131,8 +117,8 @@ class OpenClKernelFactory : public Base, public ClKernelFactoryIface
}
else
{
_clProgramCache[id].setProgram(new OpenClProgramRef(_deviceQueue.get_context().get(),
_deviceQueue.get_device().get(), key, program, options, status));
_clProgramCache[id].setProgram(
new OpenClProgramRef(_deviceQueue.get_context().get(), _deviceQueue.get_device().get(), key, program, options, status));
if (status != nullptr && !status->ok())
{
return;
Expand Down Expand Up @@ -169,14 +155,13 @@ class OpenClKernelFactory : public Base, public ClKernelFactoryIface
{
return KernelPtr();
}
kernelPtr = KernelPtr(new OpenClKernel(_executionTarget, *_clProgramRef, kernelRef));
kernelPtr = KernelPtr(new OpenClKernel(_executionTarget, *_clProgramRef, kernelRef));
_kernelCache[id].setKernel(kernelPtr, kernelName);
}
return kernelPtr;
}

~OpenClKernelFactory() DAAL_C11_OVERRIDE
{}
~OpenClKernelFactory() DAAL_C11_OVERRIDE {}

protected:
uint64_t hash(const char * key)
Expand Down Expand Up @@ -293,27 +278,27 @@ class SyclExecutionContextImpl : public Base, public ExecutionContextIface
}
}

void copy(UniversalBuffer dest, size_t desOffset, UniversalBuffer src, size_t srcOffset, size_t count,
services::Status * status = nullptr) DAAL_C11_OVERRIDE
void copy(UniversalBuffer dest, size_t desOffset, UniversalBuffer src, size_t srcOffset, size_t count, services::Status * status = nullptr,
bool isSync = true) DAAL_C11_OVERRIDE
{
DAAL_ASSERT(dest.type() == src.type());
// TODO: Thread safe?
try
{
BufferCopier::copy(_deviceQueue, dest, desOffset, src, srcOffset, count);
BufferCopier::copy(_deviceQueue, dest, desOffset, src, srcOffset, count, isSync);
}
catch (cl::sycl::exception const & e)
{
convertSyclExceptionToStatus(e, status);
}
}

void fill(UniversalBuffer dest, double value, services::Status * status = nullptr) DAAL_C11_OVERRIDE
void fill(UniversalBuffer dest, double value, services::Status * status = nullptr, bool isSync = true) DAAL_C11_OVERRIDE
{
// TODO: Thread safe?
try
{
BufferFiller::fill(_deviceQueue, dest, value);
BufferFiller::fill(_deviceQueue, dest, value, isSync);
}
catch (cl::sycl::exception const & e)
{
Expand All @@ -325,20 +310,15 @@ class SyclExecutionContextImpl : public Base, public ExecutionContextIface

InfoDevice & getInfoDevice() DAAL_C11_OVERRIDE { return _infoDevice; }

void copy(UniversalBuffer dest,
size_t desOffset,
void *src,
size_t srcOffset,
size_t count,
services::Status *status = nullptr) DAAL_C11_OVERRIDE
void copy(UniversalBuffer dest, size_t desOffset, void * src, size_t srcOffset, size_t count, services::Status * status = nullptr,
bool isSync = true) DAAL_C11_OVERRIDE
{
// TODO: Thread safe?
try
{
ArrayCopier::copy(_deviceQueue, dest,
desOffset, src, srcOffset, count);
ArrayCopier::copy(_deviceQueue, dest, desOffset, src, srcOffset, count, isSync);
}
catch (cl::sycl::exception const &e)
catch (cl::sycl::exception const & e)
{
convertSyclExceptionToStatus(e, status);
}
Expand Down
Loading