910 parallel inference

Signed-off-by: zhoufeng <zhoufeng54@huawei.com>
pull/12582/head
zhoufeng 4 years ago
parent 095d7fb877
commit 7d54d627f1

@ -25,10 +25,15 @@
#include "backend/session/executor_manager.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "runtime/dev.h"
#include "pipeline/jit/pipeline.h"
#include "frontend/parallel/step_parallel.h"
namespace mindspore {
API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl);
static constexpr const char *kHcclEnable = "MS_ENABLE_HCCL";
static constexpr const char *kHcclGroupFile = "PARA_GROUP_FILE";
AscendGraphImpl::AscendGraphImpl()
: session_impl_(nullptr),
graph_id_(0),
@ -209,11 +214,11 @@ Status AscendGraphImpl::Load() {
}
session_impl_->GetModelInputsInfo(graph_id_, &inputs_info_, &input_names_);
session_impl_->GetModelOutputsInfo(graph_id_, &outputs_info_, &output_names_);
if (inputs_info_.empty() || inputs_info_.size() != input_names_.size()) {
if (inputs_info_.size() != input_names_.size()) {
MS_LOG_ERROR << "Get model inputs info failed";
return kMCInvalidInput;
}
if (outputs_info_.empty() || outputs_info_.size() != output_names_.size()) {
if (outputs_info_.size() != output_names_.size()) {
MS_LOG_ERROR << "Get model outputs info failed";
return kMCInvalidInput;
}
@ -287,12 +292,34 @@ AscendGraphImpl::MsEnvGuard::MsEnvGuard(uint32_t device_id) {
return;
}
auto env_hccl_mode = common::GetEnv(kHcclEnable);
if (!env_hccl_mode.empty() && env_hccl_mode != std::to_string(0)) {
MS_LOG(INFO) << "Enable hccl parallel mode.";
ms_context->set_param<bool>(MS_CTX_ENABLE_HCCL, true);
}
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kAscendDevice);
auto ret = rtSetDevice(device_id_);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtSetDevice failed, ret[" << static_cast<int>(ret) << "]";
if (ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
pipeline::InitHccl();
auto para_group_file = common::GetEnv(kHcclGroupFile);
if (para_group_file.empty()) {
MS_LOG(INFO) << "Cannot get Env " << kHcclGroupFile << ", skip.";
} else {
MS_LOG(INFO) << "Get env " << kHcclGroupFile << " success: " << para_group_file;
if (!parallel::CreateGroupsByCkptFile(para_group_file)) {
MS_LOG(ERROR) << "CreateGroupsByCkptFile failed.";
errno_ = kMCFailed;
return;
}
}
} else {
auto ret = rtSetDevice(device_id_);
if (ret != RT_ERROR_NONE) {
MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtSetDevice failed, ret[" << static_cast<int>(ret) << "]";
}
}
MS_LOG(INFO) << "Device " << device_id << " init env success.";
@ -310,10 +337,18 @@ AscendGraphImpl::MsEnvGuard::~MsEnvGuard() {
return;
}
auto ret = rtDeviceReset(device_id_);
if (ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Device " << device_id_ << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]";
return;
if (ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
PythonEnvGuard guard;
if (!context::CloseTsd(ms_context)) {
MS_LOG(ERROR) << "CloseTsd failed!";
return;
}
} else {
auto ret = rtDeviceReset(device_id_);
if (ret != RT_ERROR_NONE) {
MS_LOG(ERROR) << "Device " << device_id_ << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]";
return;
}
}
MS_LOG(INFO) << "End finalize device " << device_id_;

Loading…
Cancel
Save