|
|
|
@ -25,10 +25,15 @@
|
|
|
|
|
#include "backend/session/executor_manager.h"
|
|
|
|
|
#include "runtime/device/kernel_runtime_manager.h"
|
|
|
|
|
#include "runtime/dev.h"
|
|
|
|
|
#include "pipeline/jit/pipeline.h"
|
|
|
|
|
#include "frontend/parallel/step_parallel.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
API_FACTORY_REG(GraphCell::GraphImpl, Ascend910, AscendGraphImpl);
|
|
|
|
|
|
|
|
|
|
static constexpr const char *kHcclEnable = "MS_ENABLE_HCCL";
|
|
|
|
|
static constexpr const char *kHcclGroupFile = "PARA_GROUP_FILE";
|
|
|
|
|
|
|
|
|
|
AscendGraphImpl::AscendGraphImpl()
|
|
|
|
|
: session_impl_(nullptr),
|
|
|
|
|
graph_id_(0),
|
|
|
|
@ -209,11 +214,11 @@ Status AscendGraphImpl::Load() {
|
|
|
|
|
}
|
|
|
|
|
session_impl_->GetModelInputsInfo(graph_id_, &inputs_info_, &input_names_);
|
|
|
|
|
session_impl_->GetModelOutputsInfo(graph_id_, &outputs_info_, &output_names_);
|
|
|
|
|
if (inputs_info_.empty() || inputs_info_.size() != input_names_.size()) {
|
|
|
|
|
if (inputs_info_.size() != input_names_.size()) {
|
|
|
|
|
MS_LOG_ERROR << "Get model inputs info failed";
|
|
|
|
|
return kMCInvalidInput;
|
|
|
|
|
}
|
|
|
|
|
if (outputs_info_.empty() || outputs_info_.size() != output_names_.size()) {
|
|
|
|
|
if (outputs_info_.size() != output_names_.size()) {
|
|
|
|
|
MS_LOG_ERROR << "Get model outputs info failed";
|
|
|
|
|
return kMCInvalidInput;
|
|
|
|
|
}
|
|
|
|
@ -287,12 +292,34 @@ AscendGraphImpl::MsEnvGuard::MsEnvGuard(uint32_t device_id) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto env_hccl_mode = common::GetEnv(kHcclEnable);
|
|
|
|
|
if (!env_hccl_mode.empty() && env_hccl_mode != std::to_string(0)) {
|
|
|
|
|
MS_LOG(INFO) << "Enable hccl parallel mode.";
|
|
|
|
|
ms_context->set_param<bool>(MS_CTX_ENABLE_HCCL, true);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ms_context->set_param<int>(MS_CTX_EXECUTION_MODE, kGraphMode);
|
|
|
|
|
ms_context->set_param<uint32_t>(MS_CTX_DEVICE_ID, device_id_);
|
|
|
|
|
ms_context->set_param<std::string>(MS_CTX_DEVICE_TARGET, kAscendDevice);
|
|
|
|
|
auto ret = rtSetDevice(device_id_);
|
|
|
|
|
if (ret != RT_ERROR_NONE) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtSetDevice failed, ret[" << static_cast<int>(ret) << "]";
|
|
|
|
|
|
|
|
|
|
if (ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
|
|
|
|
|
pipeline::InitHccl();
|
|
|
|
|
auto para_group_file = common::GetEnv(kHcclGroupFile);
|
|
|
|
|
if (para_group_file.empty()) {
|
|
|
|
|
MS_LOG(INFO) << "Cannot get Env " << kHcclGroupFile << ", skip.";
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(INFO) << "Get env " << kHcclGroupFile << " success: " << para_group_file;
|
|
|
|
|
if (!parallel::CreateGroupsByCkptFile(para_group_file)) {
|
|
|
|
|
MS_LOG(ERROR) << "CreateGroupsByCkptFile failed.";
|
|
|
|
|
errno_ = kMCFailed;
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto ret = rtSetDevice(device_id_);
|
|
|
|
|
if (ret != RT_ERROR_NONE) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Device " << device_id_ << " call rtSetDevice failed, ret[" << static_cast<int>(ret) << "]";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "Device " << device_id << " init env success.";
|
|
|
|
@ -310,10 +337,18 @@ AscendGraphImpl::MsEnvGuard::~MsEnvGuard() {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto ret = rtDeviceReset(device_id_);
|
|
|
|
|
if (ret != RT_ERROR_NONE) {
|
|
|
|
|
MS_LOG(ERROR) << "Device " << device_id_ << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]";
|
|
|
|
|
return;
|
|
|
|
|
if (ms_context->get_param<bool>(MS_CTX_ENABLE_HCCL)) {
|
|
|
|
|
PythonEnvGuard guard;
|
|
|
|
|
if (!context::CloseTsd(ms_context)) {
|
|
|
|
|
MS_LOG(ERROR) << "CloseTsd failed!";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto ret = rtDeviceReset(device_id_);
|
|
|
|
|
if (ret != RT_ERROR_NONE) {
|
|
|
|
|
MS_LOG(ERROR) << "Device " << device_id_ << " call rtDeviceReset failed, ret[" << static_cast<int>(ret) << "]";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
MS_LOG(INFO) << "End finalize device " << device_id_;
|
|
|
|
|