From 1f99cd7d868b29631afffa4094223f573584d25d Mon Sep 17 00:00:00 2001 From: lizhenyu Date: Fri, 18 Dec 2020 16:46:17 +0800 Subject: [PATCH] ps cache data process thread support exit when exceptions occur --- .../engine/datasetops/device_queue_op.cc | 4 +- .../minddata/dataset/engine/tdt/tdt_plugin.cc | 4 +- mindspore/ccsrc/pipeline/jit/pipeline.cc | 8 +- .../ps/ps_cache/ascend/ascend_ps_cache.cc | 137 ++++--- .../ps/ps_cache/ascend/ascend_ps_cache.h | 18 +- .../ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc | 68 ++-- .../ccsrc/ps/ps_cache/gpu/gpu_ps_cache.h | 16 +- mindspore/ccsrc/ps/ps_cache/ps_cache_basic.h | 25 +- .../ccsrc/ps/ps_cache/ps_cache_manager.cc | 373 ++++++++++-------- .../ccsrc/ps/ps_cache/ps_cache_manager.h | 35 +- .../ps/ps_cache/ps_data/ps_data_prefetch.cc | 62 ++- .../ps/ps_cache/ps_data/ps_data_prefetch.h | 9 +- mindspore/ccsrc/ps/ps_context.cc | 9 +- .../ccsrc/runtime/device/gpu/gpu_common.h | 10 + mindspore/core/utils/log_adapter.h | 8 + 15 files changed, 470 insertions(+), 316 deletions(-) diff --git a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc index 6025e4c02c..6e38bd5ff2 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/datasetops/device_queue_op.cc @@ -298,7 +298,9 @@ Status DeviceQueueOp::PushDataToGPU() { // Data prefetch only when PS mode enables cache. if (items.size() > 0) { - ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_); + if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name_, items[0].data_ptr_, items[0].data_len_)) { + return Status(StatusCode::kTimeOut, __LINE__, __FILE__, "Failed to prefetch data."); + } } while (!GpuBufferMgr::GetInstance().IsClosed() && !TaskManager::FindMe()->Interrupted()) { BlockQueueStatus_T ret = GpuBufferMgr::GetInstance().Push(handle, items, WAIT_TIME); diff --git a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc index 8760e8c69b..9bfdcacee4 100644 --- a/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc +++ b/mindspore/ccsrc/minddata/dataset/engine/tdt/tdt_plugin.cc @@ -55,7 +55,9 @@ TdtStatus TdtPlugin::hostPush(TensorRow ts_row, bool is_wait, std::string channe #if ENABLE_D // Data prefetch only when PS mode enables cache. if (items.size() > 0) { - ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_); + if (!ps::PsDataPrefetch::GetInstance().PrefetchData(channel_name, items[0].dataPtr_.get(), items[0].dataLen_)) { + return FAILED; + } } #endif if (tdt::TdtHostPushData(channel_name, items) != 0) { diff --git a/mindspore/ccsrc/pipeline/jit/pipeline.cc b/mindspore/ccsrc/pipeline/jit/pipeline.cc index a337bfc37a..c995b92aee 100644 --- a/mindspore/ccsrc/pipeline/jit/pipeline.cc +++ b/mindspore/ccsrc/pipeline/jit/pipeline.cc @@ -53,6 +53,7 @@ #include "ps/util.h" #include "ps/worker.h" #include "ps/ps_cache/ps_data/ps_data_prefetch.h" +#include "ps/ps_cache/ps_cache_manager.h" #endif #if (ENABLE_GE || ENABLE_D) @@ -1083,9 +1084,10 @@ void ClearResAtexit() { pynative::ClearPyNativeSession(); session::ClearPythonParasMap(); #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) - if (ps::Util::IsParamServerMode()) { - if (ps::Util::IsRoleOfWorker()) { - ps::worker.Finalize(); + if (ps::Util::IsParamServerMode() && ps::Util::IsRoleOfWorker()) { + ps::worker.Finalize(); + if (ps::PsDataPrefetch::GetInstance().cache_enable()) { + ps::ps_cache_instance.Finalize(); } } #endif diff --git a/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc b/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc index 917b999db1..779f13802d 100644 --- a/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc +++ b/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.cc @@ -37,155 +37,178 @@ namespace ps { namespace ascend { MS_REG_PS_CACHE(kAscendDevice, AscendPsCache); namespace { -void SetProtoInputs(const std::vector> &data_shape, const std::vector &data_type, +bool SetProtoInputs(const std::vector> &data_shape, const std::vector &data_type, mindspore::NodeDef *proto) { - MS_EXCEPTION_IF_NULL(proto); + MS_ERROR_IF_NULL(proto); if (data_shape.size() != data_type.size()) { - MS_LOG(EXCEPTION) << "The size of data shape is not equal to the size of data type."; + MS_LOG(ERROR) << "The size of data shape is not equal to the size of data type."; + return false; } for (size_t input_index = 0; input_index < data_shape.size(); input_index++) { ::mindspore::Tensor *proto_inputs = proto->add_inputs(); - MS_EXCEPTION_IF_NULL(proto_inputs); + MS_ERROR_IF_NULL(proto_inputs); auto input_shape = data_shape[input_index]; mindspore::TensorShape *tensorShape = proto_inputs->mutable_tensor_shape(); - MS_EXCEPTION_IF_NULL(tensorShape); + MS_ERROR_IF_NULL(tensorShape); for (auto item : input_shape) { mindspore::TensorShape_Dim *dim = tensorShape->add_dim(); - MS_EXCEPTION_IF_NULL(dim); + MS_ERROR_IF_NULL(dim); dim->set_size((::google::protobuf::int64)item); } auto input_type = kernel::AicpuOpUtil::MsTypeToProtoType(data_type[input_index]); proto_inputs->set_tensor_type(input_type); proto_inputs->set_mem_device("HBM"); } + return true; } -void SetProtoOutputs(const std::vector> &data_shape, const std::vector &data_type, +bool SetProtoOutputs(const std::vector> &data_shape, const std::vector &data_type, mindspore::NodeDef *proto) { - MS_EXCEPTION_IF_NULL(proto); + MS_ERROR_IF_NULL(proto); if (data_shape.size() != data_type.size()) { - MS_LOG(EXCEPTION) << "The size of data shape is not equal to the size of data type."; + MS_LOG(ERROR) << "The size of data shape is not equal to the size of data type."; + return false; } for (size_t output_index = 0; output_index < data_shape.size(); output_index++) { ::mindspore::Tensor *proto_outputs = proto->add_outputs(); - MS_EXCEPTION_IF_NULL(proto_outputs); + MS_ERROR_IF_NULL(proto_outputs); auto output_shape = data_shape[output_index]; mindspore::TensorShape *tensorShape = proto_outputs->mutable_tensor_shape(); - MS_EXCEPTION_IF_NULL(tensorShape); + MS_ERROR_IF_NULL(tensorShape); for (auto item : output_shape) { mindspore::TensorShape_Dim *dim = tensorShape->add_dim(); - MS_EXCEPTION_IF_NULL(dim); + MS_ERROR_IF_NULL(dim); dim->set_size((::google::protobuf::int64)item); } auto output_type = kernel::AicpuOpUtil::MsTypeToProtoType(data_type[output_index]); proto_outputs->set_tensor_type(output_type); proto_outputs->set_mem_device("HBM"); } + return true; } -void SetNodedefProto(const std::shared_ptr &op_info, +bool SetNodedefProto(const std::shared_ptr &op_info, const std::shared_ptr &kernel_mod_ptr) { - MS_EXCEPTION_IF_NULL(op_info); - MS_EXCEPTION_IF_NULL(kernel_mod_ptr); + MS_ERROR_IF_NULL(op_info); + MS_ERROR_IF_NULL(kernel_mod_ptr); mindspore::NodeDef proto; proto.set_op(op_info->op_name_); - SetProtoInputs(op_info->input_data_shape_, op_info->input_data_type_, &proto); - SetProtoOutputs(op_info->output_data_shape_, op_info->output_data_type_, &proto); + RETURN_IF_FALSE(SetProtoInputs(op_info->input_data_shape_, op_info->input_data_type_, &proto)); + RETURN_IF_FALSE(SetProtoOutputs(op_info->output_data_shape_, op_info->output_data_type_, &proto)); std::string nodeDefStr; if (!proto.SerializeToString(&nodeDefStr)) { - MS_LOG(EXCEPTION) << "Serialize nodeDef to string failed."; + MS_LOG(ERROR) << "Serialize nodeDef to string failed."; + return false; } MS_LOG(DEBUG) << "Set node def proto, node name:" << op_info->op_name_; kernel_mod_ptr->SetNodeDef(nodeDefStr); + return true; } } // namespace -void AscendPsCache::InitDevice(uint32_t device_id, const void *context) { - MS_EXCEPTION_IF_NULL(context); +bool AscendPsCache::InitDevice(uint32_t device_id, const void *context) { + MS_ERROR_IF_NULL(context); auto ret = rtSetDevice(device_id); if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "Call rtSetDevice, ret[" << ret << "]"; + MS_LOG(ERROR) << "Call rtSetDevice, ret[" << ret << "]"; + return false; } auto rt_context = const_cast(context); ret = rtCtxSetCurrent(rt_context); if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "Call rtCtxSetCurrent, ret[" << ret << "]"; + MS_LOG(ERROR) << "Call rtCtxSetCurrent, ret[" << ret << "]"; + return false; } ret = rtStreamCreate(&stream_, 0); if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "Call rtStreamCreate, ret[" << ret << "]"; + MS_LOG(ERROR) << "Call rtStreamCreate, ret[" << ret << "]"; + return false; } + return true; } void *AscendPsCache::MallocMemory(size_t size) { return device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(size); } -void AscendPsCache::MallocConstantMemory(size_t constant_value) { +bool AscendPsCache::MallocConstantMemory(size_t constant_value) { offset_addr_ = reinterpret_cast(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int))); - MS_EXCEPTION_IF_NULL(offset_addr_); + MS_ERROR_IF_NULL(offset_addr_); rtMemset(offset_addr_, sizeof(int), 0, sizeof(int)); cache_vocab_size_addr_ = reinterpret_cast(device::ascend::AscendMemoryPool::GetInstance().AllocTensorMem(sizeof(int))); - MS_EXCEPTION_IF_NULL(cache_vocab_size_addr_); + MS_ERROR_IF_NULL(cache_vocab_size_addr_); rtMemset(cache_vocab_size_addr_, sizeof(int), constant_value, sizeof(int)); + return true; } -void AscendPsCache::RecordEvent() { +bool AscendPsCache::RecordEvent() { event_.reset(new rtEvent_t()); auto ret = rtEventCreate(&(*event_)); if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "Create event failed"; + MS_LOG(ERROR) << "Create event failed"; + return false; } ret = rtEventRecord(*event_, stream_); if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "Record event failed"; + MS_LOG(ERROR) << "Record event failed"; + return false; } + return true; } -void AscendPsCache::SynchronizeEvent() { +bool AscendPsCache::SynchronizeEvent() { auto ret = rtEventSynchronize(*event_); if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "tEventSynchronize failed"; + MS_LOG(ERROR) << "tEventSynchronize failed"; + return false; } ret = rtEventDestroy(*event_); if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "rtEventDestroy failed"; + MS_LOG(ERROR) << "rtEventDestroy failed"; + return false; } + return true; } -void AscendPsCache::SynchronizeStream() { +bool AscendPsCache::SynchronizeStream() { auto ret = rtStreamSynchronize(stream_); if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "rtStreamSynchronize failed"; + MS_LOG(ERROR) << "rtStreamSynchronize failed"; + return false; } + return true; } -void AscendPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) { - MS_EXCEPTION_IF_NULL(dst); - MS_EXCEPTION_IF_NULL(src); +bool AscendPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) { + MS_ERROR_IF_NULL(dst); + MS_ERROR_IF_NULL(src); auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_HOST_TO_DEVICE, stream_); if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "rtMemcpyAsync failed"; + MS_LOG(ERROR) << "rtMemcpyAsync failed"; + return false; } + return true; } -void AscendPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) { - MS_EXCEPTION_IF_NULL(dst); - MS_EXCEPTION_IF_NULL(src); +bool AscendPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) { + MS_ERROR_IF_NULL(dst); + MS_ERROR_IF_NULL(src); auto ret = rtMemcpyAsync(dst, size, src, size, RT_MEMCPY_DEVICE_TO_HOST, stream_); if (ret != RT_ERROR_NONE) { - MS_EXCEPTION(DeviceProcessError) << "rtMemcpyAsync failed"; + MS_LOG(ERROR) << "rtMemcpyAsync failed"; + return false; } + return true; } -void AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, +bool AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size, size_t embedding_size, size_t swap_out_size) { - MS_EXCEPTION_IF_NULL(hash_table_addr); - MS_EXCEPTION_IF_NULL(swap_out_value_addr); - MS_EXCEPTION_IF_NULL(swap_out_index_addr); + MS_ERROR_IF_NULL(hash_table_addr); + MS_ERROR_IF_NULL(swap_out_value_addr); + MS_ERROR_IF_NULL(swap_out_index_addr); auto hash_swap_out_mod = std::make_shared(); - MS_EXCEPTION_IF_NULL(hash_swap_out_mod); + MS_ERROR_IF_NULL(hash_swap_out_mod); hash_swap_out_mod->SetNodeName(kEmbeddingLookupOpName); std::vector> input_shape; std::vector> output_shape; @@ -197,7 +220,7 @@ void AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr output_shape.push_back({swap_out_size, embedding_size}); auto op_info = std::make_shared(kEmbeddingLookupOpName, input_shape, input_type, output_shape, output_type); - SetNodedefProto(op_info, hash_swap_out_mod); + RETURN_IF_FALSE(SetNodedefProto(op_info, hash_swap_out_mod)); AddressPtrList kernel_inputs; AddressPtrList kernel_outputs = { @@ -208,17 +231,19 @@ void AscendPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr kernel_inputs.push_back(std::make_shared
(offset_addr_, sizeof(int))); auto ret = hash_swap_out_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); if (!ret) { - MS_LOG(EXCEPTION) << "Hash swap out launch failed."; + MS_LOG(ERROR) << "Hash swap out launch failed."; + return false; } + return true; } -void AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, +bool AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size, size_t embedding_size, size_t swap_in_size) { - MS_EXCEPTION_IF_NULL(hash_table_addr); - MS_EXCEPTION_IF_NULL(swap_in_value_addr); - MS_EXCEPTION_IF_NULL(swap_in_index_addr); + MS_ERROR_IF_NULL(hash_table_addr); + MS_ERROR_IF_NULL(swap_in_value_addr); + MS_ERROR_IF_NULL(swap_in_index_addr); auto hash_swap_in_mod = std::make_shared(); - MS_EXCEPTION_IF_NULL(hash_swap_in_mod); + MS_ERROR_IF_NULL(hash_swap_in_mod); hash_swap_in_mod->SetNodeName(kernel::kUpdateCache); std::vector> input_shape; std::vector> output_shape; @@ -245,8 +270,10 @@ void AscendPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, kernel_outputs.push_back(std::make_shared
(offset_addr_, sizeof(int))); auto ret = hash_swap_in_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_); if (!ret) { - MS_LOG(EXCEPTION) << "Hash swap in launch failed."; + MS_LOG(ERROR) << "Hash swap in launch failed."; + return false; } + return true; } } // namespace ascend } // namespace ps diff --git a/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.h b/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.h index 4f7c04d853..d52a11222a 100644 --- a/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.h +++ b/mindspore/ccsrc/ps/ps_cache/ascend/ascend_ps_cache.h @@ -49,17 +49,17 @@ class AscendPsCache : public PsCacheBasic { public: AscendPsCache() = default; ~AscendPsCache() override = default; - void InitDevice(uint32_t device_id, const void *context) override; + bool InitDevice(uint32_t device_id, const void *context) override; void *MallocMemory(size_t size) override; - void MallocConstantMemory(size_t constant_value) override; - void RecordEvent() override; - void SynchronizeEvent() override; - void SynchronizeStream() override; - void CopyHostMemToDevice(void *dst, void *src, size_t size) override; - void CopyDeviceMemToHost(void *dst, void *src, size_t size) override; - void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size, + bool MallocConstantMemory(size_t constant_value) override; + bool RecordEvent() override; + bool SynchronizeEvent() override; + bool SynchronizeStream() override; + bool CopyHostMemToDevice(void *dst, void *src, size_t size) override; + bool CopyDeviceMemToHost(void *dst, void *src, size_t size) override; + bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size, size_t embedding_size, size_t swap_out_size) override; - void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size, + bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size, size_t embedding_size, size_t swap_in_size) override; private: diff --git a/mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc b/mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc index 1100872878..536b142f99 100644 --- a/mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc +++ b/mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.cc @@ -25,67 +25,75 @@ namespace mindspore { namespace ps { namespace gpu { MS_REG_PS_CACHE(kGPUDevice, GPUPsCache); -void GPUPsCache::InitDevice(uint32_t device_id, const void *) { - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaSetDevice(device_id), "Cuda set device failed") - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamCreate(reinterpret_cast(&stream_)), - "Cuda create stream failed"); +bool GPUPsCache::InitDevice(uint32_t device_id, const void *) { + CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaSetDevice(device_id), "Cuda set device failed") + CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaStreamCreate(reinterpret_cast(&stream_)), + "Cuda create stream failed"); + return true; } void *GPUPsCache::MallocMemory(size_t size) { return device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(size); } -void GPUPsCache::RecordEvent() { +bool GPUPsCache::RecordEvent() { event_.reset(new cudaEvent_t()); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventCreate(&(*event_)), "Cuda create event failed"); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventRecord(*event_, reinterpret_cast(stream_)), - "Cuda record event failed"); + CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventCreate(&(*event_)), "Cuda create event failed"); + CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventRecord(*event_, reinterpret_cast(stream_)), + "Cuda record event failed"); + return true; } -void GPUPsCache::SynchronizeEvent() { - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventSynchronize(*event_), "Cuda sync event failed"); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaEventDestroy(*event_), "Cuda destroy event failed"); +bool GPUPsCache::SynchronizeEvent() { + CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventSynchronize(*event_), "Cuda sync event failed"); + CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaEventDestroy(*event_), "Cuda destroy event failed"); + return true; } -void GPUPsCache::SynchronizeStream() { - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast(stream_)), - "Cuda sync stream failed"); +bool GPUPsCache::SynchronizeStream() { + CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(cudaStreamSynchronize(reinterpret_cast(stream_)), + "Cuda sync stream failed"); + return true; } -void GPUPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) { - MS_EXCEPTION_IF_NULL(dst); - MS_EXCEPTION_IF_NULL(src); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( +bool GPUPsCache::CopyHostMemToDevice(void *dst, void *src, size_t size) { + MS_ERROR_IF_NULL(dst); + MS_ERROR_IF_NULL(src); + CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE( cudaMemcpyAsync(dst, src, size, cudaMemcpyHostToDevice, reinterpret_cast(stream_)), "Cuda memcpy failed"); + return true; } -void GPUPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) { - MS_EXCEPTION_IF_NULL(dst); - MS_EXCEPTION_IF_NULL(src); - CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( +bool GPUPsCache::CopyDeviceMemToHost(void *dst, void *src, size_t size) { + MS_ERROR_IF_NULL(dst); + MS_ERROR_IF_NULL(src); + CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE( cudaMemcpyAsync(dst, src, size, cudaMemcpyDeviceToHost, reinterpret_cast(stream_)), "Cuda memcpy failed"); + return true; } -void GPUPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t, +bool GPUPsCache::HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t, size_t embedding_size, size_t swap_out_size) { - MS_EXCEPTION_IF_NULL(hash_table_addr); - MS_EXCEPTION_IF_NULL(swap_out_value_addr); - MS_EXCEPTION_IF_NULL(swap_out_index_addr); + MS_ERROR_IF_NULL(hash_table_addr); + MS_ERROR_IF_NULL(swap_out_value_addr); + MS_ERROR_IF_NULL(swap_out_index_addr); DoHashSwapOut(reinterpret_cast(hash_table_addr), reinterpret_cast(swap_out_value_addr), reinterpret_cast(swap_out_index_addr), swap_out_size, embedding_size, reinterpret_cast(stream_)); + return true; } -void GPUPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t, +bool GPUPsCache::HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t, size_t embedding_size, size_t swap_in_size) { - MS_EXCEPTION_IF_NULL(hash_table_addr); - MS_EXCEPTION_IF_NULL(swap_in_value_addr); - MS_EXCEPTION_IF_NULL(swap_in_index_addr); + MS_ERROR_IF_NULL(hash_table_addr); + MS_ERROR_IF_NULL(swap_in_value_addr); + MS_ERROR_IF_NULL(swap_in_index_addr); DoHashSwapIn(reinterpret_cast(hash_table_addr), reinterpret_cast(swap_in_value_addr), reinterpret_cast(swap_in_index_addr), swap_in_size, embedding_size, reinterpret_cast(stream_)); + return true; } } // namespace gpu } // namespace ps diff --git a/mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.h b/mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.h index 612382c2b3..a0bfbd951f 100644 --- a/mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.h +++ b/mindspore/ccsrc/ps/ps_cache/gpu/gpu_ps_cache.h @@ -28,16 +28,16 @@ class GPUPsCache : public PsCacheBasic { public: GPUPsCache() = default; ~GPUPsCache() override = default; - void InitDevice(uint32_t device_id, const void *context) override; + bool InitDevice(uint32_t device_id, const void *context) override; void *MallocMemory(size_t size) override; - void RecordEvent() override; - void SynchronizeEvent() override; - void SynchronizeStream() override; - void CopyHostMemToDevice(void *dst, void *src, size_t size) override; - void CopyDeviceMemToHost(void *dst, void *src, size_t size) override; - void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size, + bool RecordEvent() override; + bool SynchronizeEvent() override; + bool SynchronizeStream() override; + bool CopyHostMemToDevice(void *dst, void *src, size_t size) override; + bool CopyDeviceMemToHost(void *dst, void *src, size_t size) override; + bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size, size_t embedding_size, size_t swap_out_size) override; - void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size, + bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size, size_t embedding_size, size_t swap_in_size) override; private: diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_basic.h b/mindspore/ccsrc/ps/ps_cache/ps_cache_basic.h index c9c7f703a9..fe2727e7ee 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_cache_basic.h +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_basic.h @@ -21,21 +21,28 @@ namespace mindspore { namespace ps { +#define RETURN_IF_FALSE(condition) \ + do { \ + if (!(condition)) { \ + return false; \ + } \ + } while (false) + class PsCacheBasic { public: PsCacheBasic() = default; virtual ~PsCacheBasic() = default; - virtual void InitDevice(uint32_t device_id, const void *context) = 0; + virtual bool InitDevice(uint32_t device_id, const void *context) = 0; virtual void *MallocMemory(size_t size) = 0; - virtual void MallocConstantMemory(size_t constant_value) {} - virtual void RecordEvent() = 0; - virtual void SynchronizeEvent() = 0; - virtual void SynchronizeStream() = 0; - virtual void CopyHostMemToDevice(void *dst, void *src, size_t size) = 0; - virtual void CopyDeviceMemToHost(void *dst, void *src, size_t size) = 0; - virtual void HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, + virtual bool MallocConstantMemory(size_t constant_value) { return true; } + virtual bool RecordEvent() = 0; + virtual bool SynchronizeEvent() = 0; + virtual bool SynchronizeStream() = 0; + virtual bool CopyHostMemToDevice(void *dst, void *src, size_t size) = 0; + virtual bool CopyDeviceMemToHost(void *dst, void *src, size_t size) = 0; + virtual bool HashSwapOut(void *hash_table_addr, void *swap_out_value_addr, void *swap_out_index_addr, size_t hash_table_size, size_t embedding_size, size_t swap_out_size) = 0; - virtual void HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, + virtual bool HashSwapIn(void *hash_table_addr, void *swap_in_value_addr, void *swap_in_index_addr, size_t hash_table_size, size_t embedding_size, size_t swap_in_size) = 0; protected: diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc index d30db0cf25..7b7b30ffe8 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.cc @@ -170,8 +170,10 @@ void PsCacheManager::AddEmbeddingTable() const { void PsCacheManager::InitParameterServer() { MS_LOG(INFO) << "Embedding table init begin:" << finish_insert_init_info_; std::unique_lock locker(data_mutex_); - insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true; }); - + insert_init_info_.wait(locker, [this] { return finish_insert_init_info_ == true || running_ == false; }); + if (!running_) { + return; + } for (const auto &item : hash_tables_) { const auto ¶m_name = item.first; size_t key = worker.SetParamKey(param_name); @@ -224,7 +226,9 @@ void PsCacheManager::AllocMemForHashTable() { embedding_device_cache_->hash_swap_value_addr_ = reinterpret_cast( embedding_device_cache_->cache_->MallocMemory(max_embedding_size * batch_elements_ * sizeof(float))); MS_EXCEPTION_IF_NULL(embedding_device_cache_->hash_swap_value_addr_); - embedding_device_cache_->cache_->MallocConstantMemory(cache_vocab_size_); + if (!(embedding_device_cache_->cache_->MallocConstantMemory(cache_vocab_size_))) { + MS_LOG(EXCEPTION) << "MallocConstantMemory failed."; + } } void PsCacheManager::SetLocalIdRank() { @@ -250,19 +254,25 @@ void PsCacheManager::set_channel_name(const std::string channel_name) { channel_name_ = channel_name; } -void PsCacheManager::IncreaseStep() { +bool PsCacheManager::IncreaseStep() { if (data_step_ >= UINT64_MAX) { - MS_LOG(EXCEPTION) << "The data step (" << data_step_ << ") will exceed the maximum value of uint64_t."; + MS_LOG(ERROR) << "The data step (" << data_step_ << ") will exceed the maximum value of uint64_t."; + return false; } data_step_++; set_current_graph_step(); if (graph_running_step_ > data_step_) { - MS_LOG(EXCEPTION) << "The graph running step (" << graph_running_step_ << ") exceed the data step (" << data_step_ - << ")."; + MS_LOG(ERROR) << "The graph running step (" << graph_running_step_ << ") exceed the data step (" << data_step_ + << ")."; + return false; } + return true; } void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { + if (terminated_) { + MS_LOG(EXCEPTION) << "ps cache data process thread is terminated."; + } if (graph_step_ >= UINT64_MAX) { MS_LOG(EXCEPTION) << "The graph step(" << graph_step_ << ") will exceed the maximum value of uint64_t."; } @@ -274,7 +284,9 @@ void PsCacheManager::IncreaseGraphStep(const std::string &channel_name) { } graph_step_++; set_channel_name(channel_name); - PsDataPrefetch::GetInstance().TryWakeChannel(channel_name); + if (!PsDataPrefetch::GetInstance().TryWakeChannel(channel_name)) { + MS_LOG(EXCEPTION) << "TryWakeChannel failed, channel name: " << channel_name; + } data_prase_.notify_one(); } @@ -284,74 +296,99 @@ void PsCacheManager::DoProcessData(uint32_t device_id, void *context) { MS_LOG(EXCEPTION) << "Only the sink_mode of dataset supports embeddingLookup cache in parameter server training " "mode, current dataset mode is not sink_mode."; } - auto process_data_thread = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context); - process_data_thread.detach(); + process_data_thread_ = std::thread(&PsCacheManager::ProcessDataTask, this, device_id, context); } void PsCacheManager::ProcessDataTask(uint32_t device_id, void *context) { - MS_EXCEPTION_IF_NULL(embedding_device_cache_); - MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); embedding_device_cache_->cache_->InitDevice(device_id, context); + running_ = true; + bool ret = true; InitParameterServer(); - while (true) { - ProcessData(); + while (ret) { + if (!running_) { + break; + } + ret = ProcessData(); + } + if (!ret) { + terminated_ = true; } } -void PsCacheManager::ProcessData() { - MS_EXCEPTION_IF_NULL(embedding_device_cache_); - MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); +void PsCacheManager::Finalize() { + if (running_) { + running_ = false; + } + PsDataPrefetch::GetInstance().NotifyFinalize(); + insert_init_info_.notify_all(); + data_prase_.notify_all(); + if (process_data_thread_.joinable()) { + process_data_thread_.join(); + } +} + +bool PsCacheManager::ProcessData() { struct timeval start_time, end_time; const uint64_t kUSecondInSecond = 1000000; (void)gettimeofday(&start_time, nullptr); auto channel = channel_name(); if (channel.empty()) { std::unique_lock locker(data_mutex_); - data_prase_.wait(locker, [this] { return !channel_name_.empty(); }); + data_prase_.wait(locker, [this] { return !channel_name_.empty() || running_ == false; }); + if (!running_) { + return false; + } } auto data = PsDataPrefetch::GetInstance().data(channel_name_); if (data == nullptr) { MS_LOG(INFO) << "No data process, channel name:" << channel_name_; std::unique_lock locker(data_mutex_); (void)data_prase_.wait_for(locker, std::chrono::milliseconds(100)); - return; + return true; } - IncreaseStep(); + RETURN_IF_FALSE(IncreaseStep()); auto data_size = PsDataPrefetch::GetInstance().data_size(channel_name_); + if (data_size == 0) { + MS_LOG(ERROR) << "The data_size can not be zero."; + return false; + } auto batch_ids = reinterpret_cast(data); auto batch_ids_len = data_size / sizeof(int); std::unique_ptr hash_index(new int[batch_ids_len]); if (memset_s(&statistics_info_, sizeof(statistics_info_), 0, sizeof(statistics_info_))) { - MS_LOG(EXCEPTION) << "Process data memset failed."; + MS_LOG(ERROR) << "Process data memset failed."; + return false; } // Get hash swap in/out index and ids. - ParseData(batch_ids, batch_ids_len, hash_index.get()); + RETURN_IF_FALSE(ParseData(batch_ids, batch_ids_len, hash_index.get())); for (const auto &item : hash_tables_) { auto key = worker.GetParamKey(item.first); auto hash_info = item.second; - HashSwapHostToServer(key, hash_info); - HashSwapDeviceToHost(hash_info); - HashSwapServerToHost(key, hash_info); - HashSwapHostToDevice(hash_info); + RETURN_IF_FALSE(HashSwapHostToServer(key, hash_info)); + RETURN_IF_FALSE(HashSwapDeviceToHost(hash_info)); + RETURN_IF_FALSE(HashSwapServerToHost(key, hash_info)); + RETURN_IF_FALSE(HashSwapHostToDevice(hash_info)); } // Replace the batch_ids by hash index for getNext-op getting hash index as input. if (memcpy_s(data, data_size, hash_index.get(), data_size) != EOK) { - MS_LOG(EXCEPTION) << "Process data memcpy failed."; + MS_LOG(ERROR) << "Process data memcpy failed."; + return false; } - embedding_device_cache_->cache_->SynchronizeStream(); + RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream()); // Finish the data process and notify data prefetch. - PsDataPrefetch::GetInstance().FinalizeData(channel_name_); + RETURN_IF_FALSE(PsDataPrefetch::GetInstance().FinalizeData(channel_name_)); (void)gettimeofday(&end_time, nullptr); uint64_t cost = kUSecondInSecond * static_cast(end_time.tv_sec - start_time.tv_sec); cost += static_cast(end_time.tv_usec - start_time.tv_usec); MS_LOG(DEBUG) << "Ps cache completes processing data(data step:" << data_step_ << ",graph step:" << graph_running_step_ << " channel name:" << channel_name_ << ", time cost:" << cost / 1000 << "ms)."; + return true; } -void PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) { - MS_EXCEPTION_IF_NULL(batch_ids); - MS_EXCEPTION_IF_NULL(hash_index); +bool PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index) { + MS_ERROR_IF_NULL(batch_ids); + MS_ERROR_IF_NULL(hash_index); for (size_t i = 0; i < batch_ids_len; i++) { bool need_swap_host_to_device = true; bool need_swap_device_to_host = true; @@ -360,12 +397,16 @@ void PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, hash_index[i] = -1; continue; } - hash_index[i] = ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device); + auto index = ParseDeviceData(id, &need_swap_device_to_host, &need_swap_host_to_device); + if (index == INVALID_INDEX_VALUE) { + return false; + } + hash_index[i] = index; if (need_swap_host_to_device) { - ParseHostDataHostToDevice(id); + RETURN_IF_FALSE(ParseHostDataHostToDevice(id)); } if (need_swap_device_to_host) { - ParseHostDataDeviceToHost(id); + RETURN_IF_FALSE(ParseHostDataDeviceToHost(id)); } } // Each 1000 step prints ps cache hit rate. @@ -374,33 +415,28 @@ void PsCacheManager::ParseData(const int *batch_ids, const size_t batch_ids_len, auto hit_rate = SizeToFloat(statistics_info_.hash_hit_count_) / statistics_info_.batch_id_unique_count_; MS_LOG(INFO) << "Ps cache hit rate: " << hit_rate * 100 << "%."; } + return true; } -void PsCacheManager::WaitGraphRun() { +bool PsCacheManager::WaitGraphRun() { MS_LOG(INFO) << "Hash table has no space to insert new data and retries within 2 minutes."; std::unique_lock locker(data_mutex_); if (!data_prase_.wait_for(locker, std::chrono::seconds(120), [this] { return graph_step_ > graph_running_step_; })) { - MS_LOG(EXCEPTION) << "Ps cache data parse timeout, suggest to enlarge the cache size(graph step:" << graph_step_ - << ", graph running step:" << graph_running_step_ << ")."; + MS_LOG(ERROR) << "Ps cache data parse timeout, suggest to enlarge the cache size(graph step:" << graph_step_ + << ", graph running step:" << graph_running_step_ << ")."; + return false; } set_current_graph_step(); + return true; } int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device) { - MS_EXCEPTION_IF_NULL(need_swap_device_to_host); - MS_EXCEPTION_IF_NULL(need_swap_host_to_device); - MS_EXCEPTION_IF_NULL(embedding_device_cache_); int *device_to_host_index = embedding_device_cache_->device_to_host_index.get(); int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get(); int *host_to_device_index = embedding_device_cache_->host_to_device_index.get(); int *host_to_device_ids = embedding_device_cache_->host_to_device_ids.get(); - MS_EXCEPTION_IF_NULL(device_to_host_index); - MS_EXCEPTION_IF_NULL(device_to_host_ids); - MS_EXCEPTION_IF_NULL(host_to_device_index); - MS_EXCEPTION_IF_NULL(host_to_device_ids); auto device_hash_map = embedding_device_cache_->device_hash_map_; - MS_EXCEPTION_IF_NULL(device_hash_map); int index = 0; auto iter = device_hash_map->id_iter(id); if (device_hash_map->IsIdExist(iter)) { @@ -417,7 +453,9 @@ int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, b index = device_hash_map->ParseData(id, device_to_host_index, device_to_host_ids, data_step_, graph_running_step_, &(statistics_info_.device_to_host_size_)); if (index == INVALID_INDEX_VALUE) { - WaitGraphRun(); + if (!WaitGraphRun()) { + return INVALID_INDEX_VALUE; + } continue; } host_to_device_index[statistics_info_.host_to_device_size_] = index; @@ -430,21 +468,20 @@ int PsCacheManager::ParseDeviceData(size_t id, bool *need_swap_device_to_host, b return index; } -void PsCacheManager::ParseHostDataHostToDevice(size_t id) { - MS_EXCEPTION_IF_NULL(embedding_host_cache_); +bool PsCacheManager::ParseHostDataHostToDevice(size_t id) { int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); int *server_to_host_index = embedding_host_cache_->server_to_host_index.get(); int *server_to_host_ids = embedding_host_cache_->server_to_host_ids.get(); int *host_to_device_index = embedding_host_cache_->host_to_device_index.get(); - MS_EXCEPTION_IF_NULL(host_to_server_index); - MS_EXCEPTION_IF_NULL(host_to_server_ids); - MS_EXCEPTION_IF_NULL(server_to_host_index); - MS_EXCEPTION_IF_NULL(server_to_host_ids); - MS_EXCEPTION_IF_NULL(host_to_device_index); + MS_ERROR_IF_NULL(host_to_server_index); + MS_ERROR_IF_NULL(host_to_server_ids); + MS_ERROR_IF_NULL(server_to_host_index); + MS_ERROR_IF_NULL(server_to_host_ids); + MS_ERROR_IF_NULL(host_to_device_index); auto host_hash_map = embedding_host_cache_->host_hash_map_; - MS_EXCEPTION_IF_NULL(host_hash_map); + MS_ERROR_IF_NULL(host_hash_map); auto iter = host_hash_map->id_iter(id); if (host_hash_map->IsIdExist(iter)) { auto index = iter->second; @@ -457,7 +494,7 @@ void PsCacheManager::ParseHostDataHostToDevice(size_t id) { auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, graph_running_step_, &statistics_info_.host_to_server_size_); if (index == INVALID_INDEX_VALUE) { - WaitGraphRun(); + RETURN_IF_FALSE(WaitGraphRun()); continue; } host_to_device_index[statistics_info_.host_to_device_size_ - 1] = index; @@ -466,22 +503,21 @@ void PsCacheManager::ParseHostDataHostToDevice(size_t id) { break; } } + return true; } -void PsCacheManager::ParseHostDataDeviceToHost(size_t id) { - MS_EXCEPTION_IF_NULL(embedding_device_cache_); - MS_EXCEPTION_IF_NULL(embedding_host_cache_); +bool PsCacheManager::ParseHostDataDeviceToHost(size_t id) { int *device_to_host_ids = embedding_device_cache_->device_to_host_ids.get(); int *host_to_server_index = embedding_host_cache_->host_to_server_index.get(); int *host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); int *device_to_host_index = embedding_host_cache_->device_to_host_index.get(); - MS_EXCEPTION_IF_NULL(device_to_host_ids); - MS_EXCEPTION_IF_NULL(host_to_server_index); - MS_EXCEPTION_IF_NULL(host_to_server_ids); - MS_EXCEPTION_IF_NULL(device_to_host_index); + MS_ERROR_IF_NULL(device_to_host_ids); + MS_ERROR_IF_NULL(host_to_server_index); + MS_ERROR_IF_NULL(host_to_server_ids); + MS_ERROR_IF_NULL(device_to_host_index); auto host_hash_map = embedding_host_cache_->host_hash_map_; - MS_EXCEPTION_IF_NULL(host_hash_map); + MS_ERROR_IF_NULL(host_hash_map); int swap_device_to_host_id = device_to_host_ids[statistics_info_.device_to_host_size_ - 1]; auto iter = host_hash_map->id_iter(swap_device_to_host_id); if (host_hash_map->IsIdExist(iter)) { @@ -495,13 +531,14 @@ void PsCacheManager::ParseHostDataDeviceToHost(size_t id) { auto index = host_hash_map->ParseData(id, host_to_server_index, host_to_server_ids, data_step_, graph_running_step_, &statistics_info_.host_to_server_size_); if (index == INVALID_INDEX_VALUE) { - WaitGraphRun(); + RETURN_IF_FALSE(WaitGraphRun()); continue; } device_to_host_index[statistics_info_.device_to_host_size_ - 1] = index; break; } } + return true; } void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, @@ -514,19 +551,21 @@ void PsCacheManager::LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t pos = index * outer_dim_size; auto ret = memcpy_s(output_addr, (indices_lens - i) * lens, input_addr + pos, lens); if (ret != EOK) { - MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed."; + MS_LOG(ERROR) << "LookUpTable task memcpy failed."; + terminated_ = true; } } else { auto ret = memset_s(output_addr, (indices_lens - i) * lens, 0, lens); if (ret != EOK) { - MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; + MS_LOG(ERROR) << "LookUpTable task memset failed."; + terminated_ = true; } } output_addr += outer_dim_size; } } -void PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, +bool PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, const int *indices_addr, float *output_addr) { size_t first_dim_size = host_cache_vocab_size_; size_t outer_dim_size = embedding_size; @@ -553,9 +592,10 @@ void PsCacheManager::LookUpHostHashTable(size_t embedding_size, size_t indices_l for (size_t j = 0; j < i; j++) { threads[j].join(); } + return !terminated_; } -void PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, +bool PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, float *insert_data, float *hash_table_addr) { size_t first_dim_size = host_cache_vocab_size_; size_t thread_num = insert_indices_size / 10000 + 1; @@ -565,8 +605,8 @@ void PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in size_t i; size_t task_offset = 0; - auto insert_hash_table_task = [](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size, - int *insert_indices, float *insert_data, float *hash_table_addr) { + auto insert_hash_table_task = [this](size_t insert_indices_size, size_t outer_dim_size, size_t first_dim_size, + int *insert_indices, float *insert_data, float *hash_table_addr) { auto type_size = sizeof(float); size_t lens = outer_dim_size * type_size; for (size_t i = 0; i < insert_indices_size; ++i) { @@ -574,7 +614,8 @@ void PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in if (index >= 0 && index < SizeToInt(first_dim_size)) { auto ret = memcpy_s(hash_table_addr + index * outer_dim_size, lens, insert_data + i * outer_dim_size, lens); if (ret != EOK) { - MS_LOG(EXCEPTION) << "Insert hash table task memcpy failed."; + MS_LOG(ERROR) << "Insert hash table task memcpy failed."; + terminated_ = true; } } } @@ -596,94 +637,101 @@ void PsCacheManager::InsertHostHashTable(size_t embedding_size, size_t insert_in for (size_t j = 0; j < i; j++) { threads[j].join(); } + return !terminated_; } -void PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) { - MS_EXCEPTION_IF_NULL(embedding_device_cache_); - MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); - MS_EXCEPTION_IF_NULL(embedding_host_cache_); +bool PsCacheManager::HashSwapHostToDevice(const HashTableInfo &hash_info) { + MS_ERROR_IF_NULL(embedding_device_cache_); + MS_ERROR_IF_NULL(embedding_device_cache_->cache_); + MS_ERROR_IF_NULL(embedding_host_cache_); auto host_cache_host_to_device_index = embedding_host_cache_->host_to_device_index.get(); auto device_cache_host_to_device_index = embedding_device_cache_->host_to_device_index.get(); auto swap_indices_size = statistics_info_.host_to_device_size_; if (swap_indices_size == 0) { - return; + return true; } auto embedding_size = hash_info.embedding_size; auto hash_table_addr = reinterpret_cast(hash_info.device_address.addr); auto hash_table_size = hash_info.device_address.size; auto host_hash_table_addr = reinterpret_cast(hash_info.host_address.get()); auto swap_out_data = std::make_unique(swap_indices_size * embedding_size); - LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_cache_host_to_device_index, - swap_out_data.get()); - embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_value_addr_, - swap_out_data.get(), - swap_indices_size * embedding_size * sizeof(float)); - embedding_device_cache_->cache_->CopyHostMemToDevice( - embedding_device_cache_->hash_swap_index_addr_, device_cache_host_to_device_index, swap_indices_size * sizeof(int)); - embedding_device_cache_->cache_->HashSwapIn(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, - embedding_device_cache_->hash_swap_index_addr_, hash_table_size, - embedding_size, swap_indices_size); -} - -void PsCacheManager::HashSwapDeviceToHost(const HashTableInfo &hash_info) { - MS_EXCEPTION_IF_NULL(embedding_device_cache_); - MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); - MS_EXCEPTION_IF_NULL(embedding_host_cache_); + RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, + host_cache_host_to_device_index, swap_out_data.get())); + RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice( + embedding_device_cache_->hash_swap_value_addr_, swap_out_data.get(), + swap_indices_size * embedding_size * sizeof(float))); + RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, + device_cache_host_to_device_index, + swap_indices_size * sizeof(int))); + RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn( + hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_, + hash_table_size, embedding_size, swap_indices_size)); + return true; +} + +bool PsCacheManager::HashSwapDeviceToHost(const HashTableInfo &hash_info) { + MS_ERROR_IF_NULL(embedding_device_cache_); + MS_ERROR_IF_NULL(embedding_device_cache_->cache_); + MS_ERROR_IF_NULL(embedding_host_cache_); auto swap_indices_size = statistics_info_.device_to_host_size_; auto device_cache_device_to_host_index = embedding_device_cache_->device_to_host_index.get(); auto host_cache_device_to_host_index = embedding_host_cache_->device_to_host_index.get(); if (swap_indices_size == 0) { - return; + return true; } auto hash_table_addr = reinterpret_cast(hash_info.device_address.addr); auto hash_table_size = hash_info.device_address.size; auto host_hash_table_addr = reinterpret_cast(hash_info.host_address.get()); auto embedding_size = hash_info.embedding_size; auto swap_out_data = std::make_unique(swap_indices_size * embedding_size); - embedding_device_cache_->cache_->CopyHostMemToDevice( - embedding_device_cache_->hash_swap_index_addr_, device_cache_device_to_host_index, swap_indices_size * sizeof(int)); - embedding_device_cache_->cache_->HashSwapOut(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, - embedding_device_cache_->hash_swap_index_addr_, hash_table_size, - embedding_size, swap_indices_size); - embedding_device_cache_->cache_->CopyDeviceMemToHost(swap_out_data.get(), - embedding_device_cache_->hash_swap_value_addr_, - swap_indices_size * embedding_size * sizeof(float)); - embedding_device_cache_->cache_->SynchronizeStream(); - InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), host_cache_device_to_host_index, - swap_out_data.get(), host_hash_table_addr); -} - -void PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_info) { - MS_EXCEPTION_IF_NULL(embedding_host_cache_); + RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, + device_cache_device_to_host_index, + swap_indices_size * sizeof(int))); + RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapOut( + hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_, + hash_table_size, embedding_size, swap_indices_size)); + RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost( + swap_out_data.get(), embedding_device_cache_->hash_swap_value_addr_, + swap_indices_size * embedding_size * sizeof(float))); + RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeStream()); + RETURN_IF_FALSE(InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), host_cache_device_to_host_index, + swap_out_data.get(), host_hash_table_addr)); + return true; +} + +bool PsCacheManager::HashSwapHostToServer(size_t key, const HashTableInfo &hash_info) { + MS_ERROR_IF_NULL(embedding_host_cache_); auto host_to_server_ids = embedding_host_cache_->host_to_server_ids.get(); auto host_to_server_index = embedding_host_cache_->host_to_server_index.get(); auto swap_indices_size = statistics_info_.host_to_server_size_; if (swap_indices_size == 0) { - return; + return true; } ::ps::SArray lookup_ids(swap_indices_size, 0); ::ps::SArray swap_out_data; auto embedding_size = hash_info.embedding_size; swap_out_data.resize(swap_indices_size * embedding_size); auto host_hash_table_addr = reinterpret_cast(hash_info.host_address.get()); - LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_to_server_index, - swap_out_data.data()); + RETURN_IF_FALSE(LookUpHostHashTable(embedding_size, swap_indices_size, host_hash_table_addr, host_to_server_index, + swap_out_data.data())); auto copy_len = swap_indices_size * sizeof(int); auto ret = memcpy_s(lookup_ids.data(), copy_len, host_to_server_ids, copy_len); if (ret != EOK) { - MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; + MS_LOG(ERROR) << "Lookup id memcpy failed."; + return false; } worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); + return true; } -void PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_info) { - MS_EXCEPTION_IF_NULL(embedding_host_cache_); +bool PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_info) { + MS_ERROR_IF_NULL(embedding_host_cache_); auto swap_indices_size = statistics_info_.server_to_host_size_; auto server_to_host_ids = embedding_host_cache_->server_to_host_ids.get(); auto server_to_host_index = embedding_host_cache_->server_to_host_index.get(); if (swap_indices_size == 0) { - return; + return true; } auto host_hash_table_addr = reinterpret_cast(hash_info.host_address.get()); auto embedding_size = hash_info.embedding_size; @@ -693,47 +741,50 @@ void PsCacheManager::HashSwapServerToHost(size_t key, const HashTableInfo &hash_ auto copy_len = swap_indices_size * sizeof(int); auto ret = memcpy_s(lookup_ids.data(), copy_len, server_to_host_ids, copy_len); if (ret != EOK) { - MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; + MS_LOG(ERROR) << "Lookup id memcpy failed."; + return false; } worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); - InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), server_to_host_index, lookup_result.data(), - host_hash_table_addr); + RETURN_IF_FALSE(InsertHostHashTable(embedding_size, IntToSize(swap_indices_size), server_to_host_index, + lookup_result.data(), host_hash_table_addr)); + return true; } -void PsCacheManager::HashSwapDeviceOut(int *swap_out_index, ::ps::SArray *swap_out_data, +bool PsCacheManager::HashSwapDeviceOut(int *swap_out_index, ::ps::SArray *swap_out_data, const HashTableInfo &hash_info) { - MS_EXCEPTION_IF_NULL(swap_out_index); - MS_EXCEPTION_IF_NULL(swap_out_data); - MS_EXCEPTION_IF_NULL(embedding_device_cache_); - MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); + MS_ERROR_IF_NULL(swap_out_index); + MS_ERROR_IF_NULL(swap_out_data); + MS_ERROR_IF_NULL(embedding_device_cache_); + MS_ERROR_IF_NULL(embedding_device_cache_->cache_); auto swap_out_index_size = statistics_info_.device_to_host_size_; if (swap_out_index_size == 0) { - return; + return true; } auto hash_table_addr = reinterpret_cast(hash_info.device_address.addr); auto hash_table_size = hash_info.device_address.size; auto embedding_size = hash_info.embedding_size; swap_out_data->resize(swap_out_index_size * embedding_size); - embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, swap_out_index, - swap_out_index_size * sizeof(int)); - embedding_device_cache_->cache_->HashSwapOut(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, - embedding_device_cache_->hash_swap_index_addr_, hash_table_size, - embedding_size, swap_out_index_size); - embedding_device_cache_->cache_->CopyDeviceMemToHost(swap_out_data->data(), - embedding_device_cache_->hash_swap_value_addr_, - swap_out_index_size * embedding_size * sizeof(float)); - embedding_device_cache_->cache_->RecordEvent(); -} - -void PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, + RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice( + embedding_device_cache_->hash_swap_index_addr_, swap_out_index, swap_out_index_size * sizeof(int))); + RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapOut( + hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_, + hash_table_size, embedding_size, swap_out_index_size)); + RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyDeviceMemToHost( + swap_out_data->data(), embedding_device_cache_->hash_swap_value_addr_, + swap_out_index_size * embedding_size * sizeof(float))); + RETURN_IF_FALSE(embedding_device_cache_->cache_->RecordEvent()); + return true; +} + +bool PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, size_t key) { - MS_EXCEPTION_IF_NULL(swap_in_ids); - MS_EXCEPTION_IF_NULL(swap_in_index); - MS_EXCEPTION_IF_NULL(embedding_device_cache_); - MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); + MS_ERROR_IF_NULL(swap_in_ids); + MS_ERROR_IF_NULL(swap_in_index); + MS_ERROR_IF_NULL(embedding_device_cache_); + MS_ERROR_IF_NULL(embedding_device_cache_->cache_); auto swap_in_ids_size = statistics_info_.host_to_device_size_; if (swap_in_ids_size == 0) { - return; + return true; } auto hash_table_addr = reinterpret_cast(hash_info.device_address.addr); auto hash_table_size = hash_info.device_address.size; @@ -745,42 +796,44 @@ void PsCacheManager::HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, cons auto copy_len = swap_in_ids_size * sizeof(int); auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_in_ids, copy_len); if (ret != EOK) { - MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; + MS_LOG(ERROR) << "Lookup id memcpy failed."; + return false; } worker.DoPSEmbeddingLookup({key}, lookup_ids, lengths, &lookup_result, mindspore::ps::kEmbeddingLookupCmd); // Hash swap-in in device. - embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_value_addr_, - lookup_result.data(), - swap_in_ids_size * embedding_size * sizeof(float)); - embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, swap_in_index, - swap_in_ids_size * sizeof(int)); - embedding_device_cache_->cache_->HashSwapIn(hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, - embedding_device_cache_->hash_swap_index_addr_, hash_table_size, - embedding_size, swap_in_ids_size); + RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice( + embedding_device_cache_->hash_swap_value_addr_, lookup_result.data(), + swap_in_ids_size * embedding_size * sizeof(float))); + RETURN_IF_FALSE(embedding_device_cache_->cache_->CopyHostMemToDevice(embedding_device_cache_->hash_swap_index_addr_, + swap_in_index, swap_in_ids_size * sizeof(int))); + RETURN_IF_FALSE(embedding_device_cache_->cache_->HashSwapIn( + hash_table_addr, embedding_device_cache_->hash_swap_value_addr_, embedding_device_cache_->hash_swap_index_addr_, + hash_table_size, embedding_size, swap_in_ids_size)); + return true; } -void PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray &swap_out_data, int *swap_out_ids, size_t key) { - MS_EXCEPTION_IF_NULL(embedding_device_cache_); - MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); - MS_EXCEPTION_IF_NULL(swap_out_ids); +bool PsCacheManager::UpdataEmbeddingTable(const ::ps::SArray &swap_out_data, int *swap_out_ids, size_t key) { + MS_ERROR_IF_NULL(embedding_device_cache_); + MS_ERROR_IF_NULL(embedding_device_cache_->cache_); + MS_ERROR_IF_NULL(swap_out_ids); auto swap_out_ids_size = statistics_info_.device_to_host_size_; if (swap_out_ids_size == 0) { - return; + return true; } ::ps::SArray lookup_ids(swap_out_ids_size, 0); auto copy_len = swap_out_ids_size * sizeof(int); auto ret = memcpy_s(lookup_ids.data(), copy_len, swap_out_ids, copy_len); if (ret != EOK) { - MS_LOG(EXCEPTION) << "Lookup id memcpy failed."; + MS_LOG(ERROR) << "Lookup id memcpy failed."; + return false; } // Need synchronize event to ensure that the swap-out in device is completed. - embedding_device_cache_->cache_->SynchronizeEvent(); + RETURN_IF_FALSE(embedding_device_cache_->cache_->SynchronizeEvent()); worker.UpdateEmbeddingTable({key}, lookup_ids, swap_out_data); + return true; } void PsCacheManager::DumpHashTables(bool dump_device_tables) const { - MS_EXCEPTION_IF_NULL(embedding_device_cache_); - MS_EXCEPTION_IF_NULL(embedding_device_cache_->cache_); for (const auto &item : hash_tables_) { const auto ¶m_name = item.first; size_t cache_vocab_size = item.second.cache_vocab_size; diff --git a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h index db8b18bbb0..99eb6cdb00 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h +++ b/mindspore/ccsrc/ps/ps_cache/ps_cache_manager.h @@ -126,6 +126,8 @@ class PsCacheManager { bool initialized_ps_cache() const { return initialized_ps_cache_; } void DoProcessData(uint32_t device_id, void *context); void IncreaseGraphStep(const std::string &channel_name); + bool terminated() const { return terminated_; } + void Finalize(); void DumpHashTables(bool dump_device_tables = false) const; private: @@ -133,7 +135,7 @@ class PsCacheManager { ~PsCacheManager() = default; PsCacheManager(const PsCacheManager &) = delete; PsCacheManager &operator=(const PsCacheManager &) = delete; - void IncreaseStep(); + bool IncreaseStep(); void set_current_graph_step() { graph_running_step_ = graph_step_; } std::string channel_name(); void set_channel_name(const std::string channel_name); @@ -141,23 +143,23 @@ class PsCacheManager { void AllocMemForHashTable(); void SetLocalIdRank(); void ProcessDataTask(uint32_t device_id, void *context); - void ProcessData(); - void ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index); - void WaitGraphRun(); + bool ProcessData(); + bool ParseData(const int *batch_ids, const size_t batch_ids_len, int *hash_index); + bool WaitGraphRun(); int ParseDeviceData(size_t id, bool *need_swap_device_to_host, bool *need_swap_host_to_device); - void ParseHostDataHostToDevice(size_t id); - void ParseHostDataDeviceToHost(size_t id); - void HashSwapDeviceOut(int *swap_out_index, ::ps::SArray *swap_out_data, const HashTableInfo &hash_info); - void HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, size_t key); - void HashSwapHostToDevice(const HashTableInfo &hash_info); - void HashSwapDeviceToHost(const HashTableInfo &hash_info); - void HashSwapHostToServer(size_t key, const HashTableInfo &hash_info); - void HashSwapServerToHost(size_t key, const HashTableInfo &hash_info); - void InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, float *insert_data, + bool ParseHostDataHostToDevice(size_t id); + bool ParseHostDataDeviceToHost(size_t id); + bool HashSwapDeviceOut(int *swap_out_index, ::ps::SArray *swap_out_data, const HashTableInfo &hash_info); + bool HashSwapDeviceIn(int *swap_in_ids, int *swap_in_index, const HashTableInfo &hash_info, size_t key); + bool HashSwapHostToDevice(const HashTableInfo &hash_info); + bool HashSwapDeviceToHost(const HashTableInfo &hash_info); + bool HashSwapHostToServer(size_t key, const HashTableInfo &hash_info); + bool HashSwapServerToHost(size_t key, const HashTableInfo &hash_info); + bool InsertHostHashTable(size_t embedding_size, size_t insert_indices_size, int *insert_indices, float *insert_data, float *hash_table_addr); - void LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, + bool LookUpHostHashTable(size_t embedding_size, size_t indices_lens, const float *hash_table_addr, const int *indices_addr, float *output_addr); - void UpdataEmbeddingTable(const ::ps::SArray &swap_out_data, int *swap_out_ids, size_t key); + bool UpdataEmbeddingTable(const ::ps::SArray &swap_out_data, int *swap_out_ids, size_t key); void LookUpTableTask(size_t indices_lens, size_t outer_dim_size, size_t first_dim_size, const float *input_addr, const int *indices_addr, float *output_addr); bool CheckFinishInsertInitInfo() const; @@ -172,6 +174,7 @@ class PsCacheManager { std::mutex data_mutex_; std::condition_variable data_prase_; std::condition_variable insert_init_info_; + std::thread process_data_thread_; std::map hash_tables_; std::shared_ptr embedding_device_cache_; @@ -185,6 +188,8 @@ class PsCacheManager { std::pair range_bound_; std::atomic_bool finish_insert_init_info_{false}; std::atomic_bool finish_init_parameter_server_{false}; + std::atomic_bool running_{false}; + std::atomic_bool terminated_{false}; }; static PsCacheManager &ps_cache_instance = PsCacheManager::GetInstance(); diff --git a/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc index 33cb35c44d..f7b08ef860 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc +++ b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.cc @@ -28,11 +28,9 @@ void PsDataPrefetch::CreateDataChannel(const std::string &channel_name, size_t s if (iter != ps_data_channel_map_.end()) { MS_LOG(WARNING) << "The ps data channel already exists, channel name:" << channel_name; auto channel = iter->second; - MS_EXCEPTION_IF_NULL(channel); channel->set_step_num(step_num); } else { auto channel = std::make_shared(channel_name, step_num); - MS_EXCEPTION_IF_NULL(channel); (void)ps_data_channel_map_.emplace(channel_name, channel); } } @@ -40,71 +38,95 @@ void PsDataPrefetch::CreateDataChannel(const std::string &channel_name, size_t s std::shared_ptr PsDataPrefetch::ps_data_channel(const std::string &channel_name) const { auto iter = ps_data_channel_map_.find(channel_name); if (iter == ps_data_channel_map_.end()) { - MS_LOG(EXCEPTION) << "The ps data channel does not exist, channel name:" << channel_name; + MS_LOG(ERROR) << "The ps data channel does not exist, channel name:" << channel_name; + return nullptr; } return iter->second; } -void PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size) { +bool PsDataPrefetch::PrefetchData(const std::string &channel_name, void *data, const size_t data_size) { if (cache_enable_ == false) { - return; + return true; } if (data == nullptr) { MS_LOG(WARNING) << "No data prefetch."; - return; + return true; } auto channel = ps_data_channel(channel_name); - MS_EXCEPTION_IF_NULL(channel); + MS_ERROR_IF_NULL(channel); channel->set_data(data, data_size); std::unique_lock locker(data_mutex_); data_ready_ = true; data_process_.notify_one(); + if (!need_wait_) { + return true; + } for (int i = 0; i < 10; i++) { - if (data_prefetch_.wait_for(locker, std::chrono::seconds(30), [this] { return data_ready_ == false; })) { - return; + if (data_prefetch_.wait_for(locker, std::chrono::seconds(30), + [this] { return data_ready_ == false || need_wait_ == false; })) { + return true; } else { MS_LOG(INFO) << "Waiting for ps data process, channel name:" << channel_name << "...(" << i << " / 10)"; } } - MS_LOG(EXCEPTION) << "Ps cache data process timeout, suggest to enlarge the cache size."; + MS_LOG(ERROR) << "Ps cache data process timeout, suggest to enlarge the cache size."; + return false; } -void PsDataPrefetch::FinalizeData(const std::string &channel_name) { +bool PsDataPrefetch::FinalizeData(const std::string &channel_name) { if (cache_enable_ == false) { - return; + return true; } auto channel = ps_data_channel(channel_name); - MS_EXCEPTION_IF_NULL(channel); + MS_ERROR_IF_NULL(channel); channel->ResetData(); std::unique_lock locker(data_mutex_); data_ready_ = false; data_prefetch_.notify_one(); + if (!need_wait_) { + return true; + } for (int i = 0; i < 10; i++) { - if (data_process_.wait_for(locker, std::chrono::seconds(30), [this] { return data_ready_ == true; })) { - return; + if (data_process_.wait_for(locker, std::chrono::seconds(30), + [this] { return data_ready_ == true || need_wait_ == false; })) { + return true; } else { MS_LOG(INFO) << "Waiting for ps data prefetch, channel name:" << channel_name << "...(" << i << " / 10)"; } } - MS_LOG(EXCEPTION) << "Ps cache data prefetch timeout."; + MS_LOG(ERROR) << "Ps cache data prefetch timeout."; + return false; } void *PsDataPrefetch::data(const std::string &channel_name) const { auto channel = ps_data_channel(channel_name); - MS_EXCEPTION_IF_NULL(channel); + if (channel == nullptr) { + return nullptr; + } return channel->data(); } size_t PsDataPrefetch::data_size(const std::string &channel_name) const { auto channel = ps_data_channel(channel_name); - MS_EXCEPTION_IF_NULL(channel); + if (channel == nullptr) { + return 0; + } return channel->data_size(); } -void PsDataPrefetch::TryWakeChannel(const std::string &channel_name) { +void PsDataPrefetch::NotifyFinalize() { + need_wait_ = false; + data_prefetch_.notify_one(); + data_process_.notify_one(); +} + +bool PsDataPrefetch::TryWakeChannel(const std::string &channel_name) { auto channel = ps_data_channel(channel_name); - MS_EXCEPTION_IF_NULL(channel); + if (channel == nullptr) { + return false; + } channel->TryWakeChannel(); + return true; } } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h index 044e6f834b..0d40961076 100644 --- a/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h +++ b/mindspore/ccsrc/ps/ps_cache/ps_data/ps_data_prefetch.h @@ -19,6 +19,7 @@ #include #include #include +#include #include #include "ps/ps_cache/ps_data/ps_data_channel.h" @@ -36,11 +37,12 @@ class EXPORT PsDataPrefetch { EXPORT bool cache_enable() const { return cache_enable_; } EXPORT void set_cache_enable(bool cache_enable) { cache_enable_ = cache_enable; } EXPORT void CreateDataChannel(const std::string &channel_name, size_t step_num); - EXPORT void PrefetchData(const std::string &channel_name, void *data, const size_t data_size); - EXPORT void FinalizeData(const std::string &channel_name); + EXPORT bool PrefetchData(const std::string &channel_name, void *data, const size_t data_size); + EXPORT bool FinalizeData(const std::string &channel_name); + EXPORT void NotifyFinalize(); EXPORT void *data(const std::string &channel_name) const; EXPORT size_t data_size(const std::string &channel_name) const; - EXPORT void TryWakeChannel(const std::string &channel_name); + EXPORT bool TryWakeChannel(const std::string &channel_name); private: PsDataPrefetch() : cache_enable_(false), data_ready_(false) {} @@ -54,6 +56,7 @@ class EXPORT PsDataPrefetch { std::mutex data_mutex_; std::condition_variable data_prefetch_; std::condition_variable data_process_; + std::atomic_bool need_wait_{true}; }; } // namespace ps } // namespace mindspore diff --git a/mindspore/ccsrc/ps/ps_context.cc b/mindspore/ccsrc/ps/ps_context.cc index 7830e713a6..5d3942f96c 100644 --- a/mindspore/ccsrc/ps/ps_context.cc +++ b/mindspore/ccsrc/ps/ps_context.cc @@ -17,10 +17,10 @@ #include "ps/ps_context.h" #include "utils/log_adapter.h" #include "utils/ms_utils.h" -#include "ps/ps_cache/ps_data/ps_data_prefetch.h" #include "backend/kernel_compiler/kernel.h" #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) #include "ps/ps_cache/ps_cache_manager.h" +#include "ps/ps_cache/ps_data/ps_data_prefetch.h" #endif namespace mindspore { @@ -62,7 +62,12 @@ void PSContext::Reset() { is_worker_ = false; is_pserver_ = false; is_sched_ = false; - set_cache_enable(false); +#if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU)) + if (ps::PsDataPrefetch::GetInstance().cache_enable()) { + ps_cache_instance.Finalize(); + set_cache_enable(false); + } +#endif } std::string PSContext::ms_role() const { diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_common.h b/mindspore/ccsrc/runtime/device/gpu/gpu_common.h index d9cefc4275..8b2a97dabf 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_common.h +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_common.h @@ -62,6 +62,16 @@ namespace gpu { } \ } +#define CHECK_CUDA_RET_WITH_RETURN_ERROR_NOTRACE(expression, message) \ + { \ + cudaError_t status = (expression); \ + if (status != cudaSuccess) { \ + MS_LOG(ERROR) << "CUDA Error: " << message << " | Error Number: " << status << " " \ + << cudaGetErrorString(status); \ + return false; \ + } \ + } + #define CHECK_CUDA_RET_WITH_EXCEPT(node, expression, message) \ { \ cudaError_t status = (expression); \ diff --git a/mindspore/core/utils/log_adapter.h b/mindspore/core/utils/log_adapter.h index 84849843bf..7020a30e7a 100644 --- a/mindspore/core/utils/log_adapter.h +++ b/mindspore/core/utils/log_adapter.h @@ -199,6 +199,14 @@ class LogWriter { } \ } while (0) +#define MS_ERROR_IF_NULL(ptr) \ + do { \ + if ((ptr) == nullptr) { \ + MS_LOG(ERROR) << ": The pointer[" << #ptr << "] is null."; \ + return false; \ + } \ + } while (0) + #ifdef DEBUG #include #define MS_ASSERT(f) assert(f)