From 931c4f90a2f78313bdd579582163ee54d5b4bffe Mon Sep 17 00:00:00 2001 From: zhoufeng Date: Sat, 20 Feb 2021 15:54:44 +0800 Subject: [PATCH] fix 910 cpp inference multi device id Signed-off-by: zhoufeng --- .../cxx_api/graph/ascend/ascend_graph_impl.cc | 39 +++++++++++-------- .../cxx_api/graph/ascend/ascend_graph_impl.h | 2 +- mindspore/ccsrc/cxx_api/serialization.cc | 10 ++--- 3 files changed, 28 insertions(+), 23 deletions(-) diff --git a/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.cc b/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.cc index b0c8572bc3..3c47124341 100644 --- a/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.cc +++ b/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.cc @@ -276,7 +276,7 @@ Status AscendGraphImpl::Run(const std::vector &inputs, std::vector(ret) << "]"; } - MS_LOG(INFO) << "InitEnv success."; + MS_LOG(INFO) << "Device " << device_id << " init env success."; errno_ = kSuccess; } AscendGraphImpl::MsEnvGuard::~MsEnvGuard() { - MS_LOG(INFO) << "Start finalize env"; + MS_LOG(INFO) << "Start finalize device " << device_id_; session::ExecutorManager::Instance().Clear(); device::KernelRuntimeManager::Instance().ClearRuntimeResource(); auto ms_context = MsContext::GetInstance(); if (ms_context == nullptr) { MS_LOG(ERROR) << "Get Context failed!"; - errno_ = kMCFailed; return; } auto ret = rtDeviceReset(device_id_); if (ret != RT_ERROR_NONE) { - MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtDeviceReset failed, ret[" << static_cast(ret) << "]"; + MS_LOG(ERROR) << "Device " << device_id_ << " call rtDeviceReset failed, ret[" << static_cast(ret) << "]"; + return; } - errno_ = kSuccess; - MS_LOG(INFO) << "End finalize env"; + MS_LOG(INFO) << "End finalize device " << device_id_; } std::shared_ptr AscendGraphImpl::MsEnvGuard::GetEnv(uint32_t device_id) { std::shared_ptr acl_env; std::lock_guard lock(global_ms_env_mutex_); - acl_env = global_ms_env_.lock(); + auto iter = global_ms_env_.find(device_id); + if (iter != global_ms_env_.end()) { + acl_env = iter->second.lock(); + } + if (acl_env != nullptr) { MS_LOG(INFO) << "Env has been initialized, skip."; - } else { - acl_env = std::make_shared(device_id); - if (acl_env->GetErrno() != kSuccess) { - MS_LOG(ERROR) << "Execute aclInit Failed"; - return nullptr; - } - global_ms_env_ = acl_env; - MS_LOG(INFO) << "Env init success"; + return acl_env; } + + acl_env = std::make_shared(device_id); + if (acl_env->GetErrno() != kSuccess) { + MS_LOG(ERROR) << "Init ascend env Failed"; + return nullptr; + } + + global_ms_env_.emplace(device_id, acl_env); + MS_LOG(INFO) << "Env init success"; return acl_env; } -std::weak_ptr AscendGraphImpl::MsEnvGuard::global_ms_env_; +std::map> AscendGraphImpl::MsEnvGuard::global_ms_env_; std::mutex AscendGraphImpl::MsEnvGuard::global_ms_env_mutex_; } // namespace mindspore diff --git a/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.h b/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.h index c4595dab93..b918e4ec4f 100644 --- a/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.h +++ b/mindspore/ccsrc/cxx_api/graph/ascend/ascend_graph_impl.h @@ -73,7 +73,7 @@ class AscendGraphImpl::MsEnvGuard { static std::shared_ptr GetEnv(uint32_t device_id); private: - static std::weak_ptr global_ms_env_; + static std::map> global_ms_env_; static std::mutex global_ms_env_mutex_; Status errno_; diff --git a/mindspore/ccsrc/cxx_api/serialization.cc b/mindspore/ccsrc/cxx_api/serialization.cc index 5ff271d8f1..43dc06952c 100644 --- a/mindspore/ccsrc/cxx_api/serialization.cc +++ b/mindspore/ccsrc/cxx_api/serialization.cc @@ -85,20 +85,20 @@ Graph Serialization::LoadModel(const void *model_data, size_t data_size, ModelTy } Graph Serialization::LoadModel(const std::string &file, ModelType model_type) { - Buffer data = ReadFile(file); - if (data.Data() == nullptr) { - MS_LOG(EXCEPTION) << "Read file " << file << " failed."; - } if (model_type == kMindIR) { FuncGraphPtr anf_graph = nullptr; try { - anf_graph = ConvertStreamToFuncGraph(reinterpret_cast(data.Data()), data.DataSize()); + anf_graph = LoadMindIR(file); } catch (const std::exception &) { MS_LOG(EXCEPTION) << "Load MindIR failed."; } return Graph(std::make_shared(anf_graph, kMindIR)); } else if (model_type == kOM) { + Buffer data = ReadFile(file); + if (data.Data() == nullptr) { + MS_LOG(EXCEPTION) << "Read file " << file << " failed."; + } return Graph(std::make_shared(data, kOM)); } MS_LOG(EXCEPTION) << "Unsupported ModelType " << model_type;