fix 910 cpp inference multi device id

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
pull/12464/head
zhoufeng 4 years ago
parent 02737b5e32
commit 931c4f90a2

@ -276,7 +276,7 @@ Status AscendGraphImpl::Run(const std::vector<MSTensor> &inputs, std::vector<MST
} }
AscendGraphImpl::MsEnvGuard::MsEnvGuard(uint32_t device_id) { AscendGraphImpl::MsEnvGuard::MsEnvGuard(uint32_t device_id) {
MS_LOG(INFO) << "Start to init env."; MS_LOG(INFO) << "Start to init device " << device_id;
device_id_ = device_id; device_id_ = device_id;
RegAllOp(); RegAllOp();
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
@ -294,49 +294,54 @@ AscendGraphImpl::MsEnvGuard::MsEnvGuard(uint32_t device_id) {
MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtSetDevice failed, ret[" << static_cast<int>(ret) << "]"; MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtSetDevice failed, ret[" << static_cast<int>(ret) << "]";
} }
MS_LOG(INFO) << "InitEnv success."; MS_LOG(INFO) << "Device " << device_id << " init env success.";
errno_ = kSuccess; errno_ = kSuccess;
} }
AscendGraphImpl::MsEnvGuard::~MsEnvGuard() { AscendGraphImpl::MsEnvGuard::~MsEnvGuard() {
MS_LOG(INFO) << "Start finalize env"; MS_LOG(INFO) << "Start finalize device " << device_id_;
session::ExecutorManager::Instance().Clear(); session::ExecutorManager::Instance().Clear();
device::KernelRuntimeManager::Instance().ClearRuntimeResource(); device::KernelRuntimeManager::Instance().ClearRuntimeResource();
auto ms_context = MsContext::GetInstance(); auto ms_context = MsContext::GetInstance();
if (ms_context == nullptr) { if (ms_context == nullptr) {
MS_LOG(ERROR) << "Get Context failed!"; MS_LOG(ERROR) << "Get Context failed!";
errno_ = kMCFailed;
return; return;
} }
auto ret = rtDeviceReset(device_id_); auto ret = rtDeviceReset(device_id_);
if (ret != RT_ERROR_NONE) { if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]"; MS_LOG(ERROR) << "Device " << device_id_ << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]";
return;
} }
errno_ = kSuccess; MS_LOG(INFO) << "End finalize device " << device_id_;
MS_LOG(INFO) << "End finalize env";
} }
std::shared_ptr<AscendGraphImpl::MsEnvGuard> AscendGraphImpl::MsEnvGuard::GetEnv(uint32_t device_id) { std::shared_ptr<AscendGraphImpl::MsEnvGuard> AscendGraphImpl::MsEnvGuard::GetEnv(uint32_t device_id) {
std::shared_ptr<MsEnvGuard> acl_env; std::shared_ptr<MsEnvGuard> acl_env;
std::lock_guard<std::mutex> lock(global_ms_env_mutex_); std::lock_guard<std::mutex> lock(global_ms_env_mutex_);
acl_env = global_ms_env_.lock(); auto iter = global_ms_env_.find(device_id);
if (iter != global_ms_env_.end()) {
acl_env = iter->second.lock();
}
if (acl_env != nullptr) { if (acl_env != nullptr) {
MS_LOG(INFO) << "Env has been initialized, skip."; MS_LOG(INFO) << "Env has been initialized, skip.";
} else { return acl_env;
acl_env = std::make_shared<MsEnvGuard>(device_id);
if (acl_env->GetErrno() != kSuccess) {
MS_LOG(ERROR) << "Execute aclInit Failed";
return nullptr;
}
global_ms_env_ = acl_env;
MS_LOG(INFO) << "Env init success";
} }
acl_env = std::make_shared<MsEnvGuard>(device_id);
if (acl_env->GetErrno() != kSuccess) {
MS_LOG(ERROR) << "Init ascend env Failed";
return nullptr;
}
global_ms_env_.emplace(device_id, acl_env);
MS_LOG(INFO) << "Env init success";
return acl_env; return acl_env;
} }
std::weak_ptr<AscendGraphImpl::MsEnvGuard> AscendGraphImpl::MsEnvGuard::global_ms_env_; std::map<uint32_t, std::weak_ptr<AscendGraphImpl::MsEnvGuard>> AscendGraphImpl::MsEnvGuard::global_ms_env_;
std::mutex AscendGraphImpl::MsEnvGuard::global_ms_env_mutex_; std::mutex AscendGraphImpl::MsEnvGuard::global_ms_env_mutex_;
} // namespace mindspore } // namespace mindspore

@ -73,7 +73,7 @@ class AscendGraphImpl::MsEnvGuard {
static std::shared_ptr<MsEnvGuard> GetEnv(uint32_t device_id); static std::shared_ptr<MsEnvGuard> GetEnv(uint32_t device_id);
private: private:
static std::weak_ptr<MsEnvGuard> global_ms_env_; static std::map<uint32_t, std::weak_ptr<MsEnvGuard>> global_ms_env_;
static std::mutex global_ms_env_mutex_; static std::mutex global_ms_env_mutex_;
Status errno_; Status errno_;

@ -85,20 +85,20 @@ Graph Serialization::LoadModel(const void *model_data, size_t data_size, ModelTy
} }
Graph Serialization::LoadModel(const std::string &file, ModelType model_type) { Graph Serialization::LoadModel(const std::string &file, ModelType model_type) {
Buffer data = ReadFile(file);
if (data.Data() == nullptr) {
MS_LOG(EXCEPTION) << "Read file " << file << " failed.";
}
if (model_type == kMindIR) { if (model_type == kMindIR) {
FuncGraphPtr anf_graph = nullptr; FuncGraphPtr anf_graph = nullptr;
try { try {
anf_graph = ConvertStreamToFuncGraph(reinterpret_cast<const char *>(data.Data()), data.DataSize()); anf_graph = LoadMindIR(file);
} catch (const std::exception &) { } catch (const std::exception &) {
MS_LOG(EXCEPTION) << "Load MindIR failed."; MS_LOG(EXCEPTION) << "Load MindIR failed.";
} }
return Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR)); return Graph(std::make_shared<Graph::GraphData>(anf_graph, kMindIR));
} else if (model_type == kOM) { } else if (model_type == kOM) {
Buffer data = ReadFile(file);
if (data.Data() == nullptr) {
MS_LOG(EXCEPTION) << "Read file " << file << " failed.";
}
return Graph(std::make_shared<Graph::GraphData>(data, kOM)); return Graph(std::make_shared<Graph::GraphData>(data, kOM));
} }
MS_LOG(EXCEPTION) << "Unsupported ModelType " << model_type; MS_LOG(EXCEPTION) << "Unsupported ModelType " << model_type;

Loading…
Cancel
Save