|
|
@ -18,6 +18,7 @@
|
|
|
|
#include "ps/ps_cache/ps_cache_factory.h"
|
|
|
|
#include "ps/ps_cache/ps_cache_factory.h"
|
|
|
|
#include "backend/kernel_compiler/gpu/cuda_impl/hash_impl.cuh"
|
|
|
|
#include "backend/kernel_compiler/gpu/cuda_impl/hash_impl.cuh"
|
|
|
|
#include "runtime/device/gpu/gpu_common.h"
|
|
|
|
#include "runtime/device/gpu/gpu_common.h"
|
|
|
|
|
|
|
|
#include "runtime/device/gpu/cuda_driver.h"
|
|
|
|
#include "runtime/device/gpu/gpu_memory_allocator.h"
|
|
|
|
#include "runtime/device/gpu/gpu_memory_allocator.h"
|
|
|
|
#include "utils/ms_context.h"
|
|
|
|
#include "utils/ms_context.h"
|
|
|
|
|
|
|
|
|
|
|
@ -26,7 +27,11 @@ namespace ps {
|
|
|
|
namespace gpu {
|
|
|
|
namespace gpu {
|
|
|
|
MS_REG_PS_CACHE(kGPUDevice, GPUPsCache);
|
|
|
|
MS_REG_PS_CACHE(kGPUDevice, GPUPsCache);
|
|
|
|
bool GPUPsCache::InitDevice(uint32_t device_id, const void *) {
|
|
|
|
bool GPUPsCache::InitDevice(uint32_t device_id, const void *) {
|
|
|
|
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaSetDevice(device_id), "Cuda set device failed")
|
|
|
|
bool ret = device::gpu::CudaDriver::SetDevice(UintToInt(device_id));
|
|
|
|
|
|
|
|
if (!ret) {
|
|
|
|
|
|
|
|
MS_LOG(ERROR) << "Failed to set device id:" << device_id;
|
|
|
|
|
|
|
|
return false;
|
|
|
|
|
|
|
|
}
|
|
|
|
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaStreamCreate(reinterpret_cast<CUstream_st **>(&stream_)),
|
|
|
|
CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaStreamCreate(reinterpret_cast<CUstream_st **>(&stream_)),
|
|
|
|
"Cuda create stream failed");
|
|
|
|
"Cuda create stream failed");
|
|
|
|
return true;
|
|
|
|
return true;
|
|
|
|