commit
a1ca7b48c7
@ -1 +1 @@
|
||||
Subproject commit 412ebe82c96620b5f7c942a7ab87a45bf14c5621
|
||||
Subproject commit 383f7f751d6612e9dbde9e22a2960098fdbf3792
|
@ -1,106 +0,0 @@
|
||||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "runtime/device/ascend/tasksink/runtime_utils.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "hccl/hcom.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "hccl/hccl_types.h"
|
||||
#include "utils/utils.h"
|
||||
|
||||
// Prefixes of the tag strings that HcomDistribute builds for each hcom_* collective call.
// A full tag has the form "<prefix><task_counter>_0".
constexpr auto kHcomBroadcast = "hcom_broadcast_";
|
||||
constexpr auto kHcomAllGather = "hcom_all_gather_";
|
||||
constexpr auto kHcomAllReduce = "hcom_all_reduce_";
|
||||
constexpr auto kHcomReduceScatter = "hcom_reduce_scatter_";
|
||||
// Separator placed between the task counter and the trailing index of a tag.
constexpr auto kUnderline = "_";
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
namespace tasksink {
|
||||
/**
 * Binds a runtime model to a stream through hcom_bind_model.
 *
 * @param model  runtime model handle forwarded to hcom_bind_model.
 * @param stream runtime stream handle forwarded to hcom_bind_model.
 * @return true when hcom_bind_model reports HCCL_SUCCESS, false otherwise (an error is logged).
 */
bool RuntimeUtils::HcomBindModel(rtModel_t model, rtStream_t stream) {
  const HcclResult bind_status = hcom_bind_model(model, stream);
  if (bind_status == HCCL_SUCCESS) {
    return true;
  }
  // NOTE(review): the message carries a "0x" prefix but the value is streamed as a plain int;
  // confirm whether the log stream is expected to be in hex mode.
  MS_LOG(ERROR) << "Call hcom_bind_model failed, ret: 0x" << static_cast<int>(bind_status);
  return false;
}
|
||||
|
||||
/**
 * Releases the binding of a runtime model through hcom_unbind_model.
 *
 * @param model runtime model handle forwarded to hcom_unbind_model.
 * @return true when hcom_unbind_model reports HCCL_SUCCESS, false otherwise (an error is logged).
 */
bool RuntimeUtils::HcomUnbindModel(rtModel_t model) {
  const HcclResult unbind_status = hcom_unbind_model(model);
  if (unbind_status == HCCL_SUCCESS) {
    return true;
  }
  // NOTE(review): the message carries a "0x" prefix but the value is streamed as a plain int;
  // confirm whether the log stream is expected to be in hex mode.
  MS_LOG(ERROR) << "Call hcom_unbind_model failed, ret: 0x" << static_cast<int>(unbind_status);
  return false;
}
|
||||
|
||||
bool RuntimeUtils::HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info, rtStream_t stream) {
|
||||
MS_LOG(INFO) << "hccl distribute start";
|
||||
MS_EXCEPTION_IF_NULL(task_info);
|
||||
HcclResult ret;
|
||||
static uint32_t task_counter = 0;
|
||||
auto hccl_group = task_info->group();
|
||||
if (task_info->hccl_type() == kBroadcastOpName) {
|
||||
// call hcom broadcast interface to run op
|
||||
const string tag_broadcast = kHcomBroadcast + std::to_string(task_counter++) + kUnderline + std::to_string(0);
|
||||
ret = hcom_broadcast(tag_broadcast.c_str(), task_info->input_data_addr(), static_cast<u64>(task_info->count()),
|
||||
static_cast<HcclDataType>(task_info->data_type()), static_cast<u32>(task_info->root_id()),
|
||||
hccl_group.c_str(), stream);
|
||||
if (ret != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "hcom_broadcast fail, return ret: " << static_cast<int>(ret);
|
||||
return false;
|
||||
}
|
||||
} else if (task_info->hccl_type() == kAllGatherOpName) {
|
||||
// call hcom allgather interface to run op
|
||||
const string tag_all_gather = kHcomAllGather + std::to_string(task_counter++) + kUnderline + std::to_string(0);
|
||||
ret = hcom_all_gather(tag_all_gather.c_str(), task_info->input_data_addr(), task_info->output_data_addr(),
|
||||
static_cast<u64>(task_info->count()), static_cast<HcclDataType>(task_info->data_type()),
|
||||
hccl_group.c_str(), stream);
|
||||
if (ret != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "hcom_all_gather fail, return ret: " << ret;
|
||||
return false;
|
||||
}
|
||||
} else if (task_info->hccl_type() == kAllReduceOpName) {
|
||||
// call hcom allreduce interface to run op
|
||||
const string tag_all_reduce = kHcomAllReduce + std::to_string(task_counter++) + kUnderline + std::to_string(0);
|
||||
ret = hcom_all_reduce(tag_all_reduce.c_str(), task_info->input_data_addr(), task_info->output_data_addr(),
|
||||
static_cast<u64>(task_info->count()), static_cast<HcclDataType>(task_info->data_type()),
|
||||
static_cast<HcclReduceOp>(task_info->op_type()), hccl_group.c_str(), stream);
|
||||
if (ret != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "hcom_all_reduce fail, return ret: " << ret;
|
||||
return false;
|
||||
}
|
||||
} else if (task_info->hccl_type() == kReduceScatterOpName) {
|
||||
// call hcom reducescatter interface to run op
|
||||
const string tag_reduce_scatter =
|
||||
kHcomReduceScatter + std::to_string(task_counter++) + kUnderline + std::to_string(0);
|
||||
ret = hcom_reduce_scatter(tag_reduce_scatter.c_str(), task_info->input_data_addr(), task_info->output_data_addr(),
|
||||
static_cast<u64>(task_info->count()), static_cast<HcclDataType>(task_info->data_type()),
|
||||
static_cast<HcclReduceOp>(task_info->op_type()), hccl_group.c_str(), stream);
|
||||
if (ret != HCCL_SUCCESS) {
|
||||
MS_LOG(ERROR) << "hcom_reduce_scatter fail, return ret: " << ret;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace tasksink
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
@ -1,39 +0,0 @@
|
||||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TASKSINK_RUNTIME_UTILS_H_
|
||||
#define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TASKSINK_RUNTIME_UTILS_H_
|
||||
|
||||
#include <memory>
|
||||
#include "runtime/rt.h"
|
||||
#include "framework/ge_runtime/task_info.h"
|
||||
|
||||
using ge::model_runner::HcclTaskInfo;
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace ascend {
|
||||
namespace tasksink {
|
||||
// Collection of static helpers around the hcom_* runtime calls; the class holds no state.
class RuntimeUtils {
|
||||
public:
|
||||
// Binds `model` to `stream` via hcom_bind_model; returns false (and logs) on failure.
static bool HcomBindModel(rtModel_t model, rtStream_t stream);
|
||||
// Unbinds `model` via hcom_unbind_model; returns false (and logs) on failure.
static bool HcomUnbindModel(rtModel_t model);
|
||||
// Issues the collective described by `task_info` on `stream`; returns false when the hcom call fails.
static bool HcomDistribute(const std::shared_ptr<HcclTaskInfo> &task_info, rtStream_t stream);
|
||||
};
|
||||
} // namespace tasksink
|
||||
} // namespace ascend
|
||||
} // namespace device
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TASKSINK_RUNTIME_UTILS_H_
|
@ -0,0 +1,8 @@
|
||||
# Collect every .cc file under this directory as the sources of the hccl adapter.
file(GLOB_RECURSE HCCL_ADAPTER_SRC_LIST ./*.cc)
|
||||
# Define SUBMODULE_ID for these sources as the SM_HCCL_ADPT sub-module id.
set_property(SOURCE ${HCCL_ADAPTER_SRC_LIST} PROPERTY COMPILE_DEFINITIONS SUBMODULE_ID=mindspore::SubModuleId::SM_HCCL_ADPT)
|
||||
# Build the adapter as its own shared library.
add_library(hccl_adapter SHARED ${HCCL_ADAPTER_SRC_LIST})
|
||||
# Headers under the build tree's proto/ge directory are only needed while compiling the adapter.
target_include_directories(hccl_adapter PRIVATE ${CMAKE_BINARY_DIR}/proto/ge)
|
||||
# Ensure the `graph` target is built before the adapter.
add_dependencies(hccl_adapter graph)
|
||||
# On Linux, ask the linker to use mindspore_log_init as the library's init function.
if (CMAKE_SYSTEM_NAME MATCHES "Linux")
|
||||
target_link_options(hccl_adapter PRIVATE -Wl,-init,mindspore_log_init)
|
||||
endif ()
|
@ -0,0 +1,129 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "runtime/hccl_adapter/converter.h"
|
||||
#include <map>
|
||||
#include <algorithm>
|
||||
#include <tuple>
|
||||
#define google ascend_private
|
||||
#include "register/ops_kernel_builder_registry.h"
|
||||
#include "graph/compute_graph.h"
|
||||
#include "graph/debug/ge_attr_define.h"
|
||||
#undef google
|
||||
#include "backend/session/anf_runtime_algorithm.h"
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "mindspore/core/base/core_ops.h"
|
||||
#include "transform/graph_ir/util.h"
|
||||
|
||||
// GE op type names returned by GetGeNodeName for the supported collective primitives.
// NOTE(review): "AllRudece" in the identifier below is a misspelling of "AllReduce"; the
// string value itself is correct. Renaming requires touching GetGeNodeName as well.
static constexpr char kGeOpNameHcclAllRudece[] = "HcomAllReduce";
|
||||
static constexpr char kGeOpNameHcclAllGather[] = "HcomAllGather";
|
||||
static constexpr char kGeOpNameHcclBroadcast[] = "HcomBroadcast";
|
||||
static constexpr char kGeOpNameHcclReduceScatter[] = "HcomReduceScatter";
|
||||
// Name of the integer attribute ParseDomiTask reads from the op description to get the stream count.
static constexpr char kGeNodeAttrUsedStreamNum[] = "used_stream_num";
|
||||
// Placeholder name given to the stub OpDesc and ComputeGraph created in GenerateStubGeNode.
static constexpr char kStubDataStructureName[] = "any_name_can_work";
|
||||
|
||||
/**
 * Maps an HCCL data type onto the corresponding GE data type.
 *
 * Only FP32, FP16, INT8 and INT32 are supported; any other value is reported through
 * MS_LOG(EXCEPTION).
 *
 * @param datatype HCCL data type to translate.
 * @return the matching ge::DataType.
 */
static ge::DataType ConvertHcclDTypeToGeDType(HcclDataType datatype) {
  // Qualified as std::map and made const: the table is read-only and must not depend on a
  // using-declaration leaking out of a third-party header.
  static const std::map<HcclDataType, ge::DataType> kHcomDataTypeMap = {
    {HCCL_DATA_TYPE_FP32, ge::DT_FLOAT},
    {HCCL_DATA_TYPE_FP16, ge::DT_FLOAT16},
    {HCCL_DATA_TYPE_INT8, ge::DT_INT8},
    {HCCL_DATA_TYPE_INT32, ge::DT_INT32},
  };

  const auto iter = kHcomDataTypeMap.find(datatype);
  if (iter == kHcomDataTypeMap.end()) {
    MS_LOG(EXCEPTION) << "Unknown hccl data type " << datatype;
  }

  return iter->second;
}
|
||||
|
||||
namespace mindspore::hccl {
|
||||
std::string GetGeNodeName(const CNodePtr &cnode) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (IsPrimitiveCNode(cnode, prim::kPrimAllReduce)) {
|
||||
return kGeOpNameHcclAllRudece;
|
||||
} else if (IsPrimitiveCNode(cnode, prim::kPrimAllGather)) {
|
||||
return kGeOpNameHcclAllGather;
|
||||
} else if (IsPrimitiveCNode(cnode, prim::kPrimBroadcast)) {
|
||||
return kGeOpNameHcclBroadcast;
|
||||
} else if (IsPrimitiveCNode(cnode, prim::kPrimReduceScatter)) {
|
||||
return kGeOpNameHcclReduceScatter;
|
||||
}
|
||||
|
||||
MS_LOG(EXCEPTION) << "Unknown hccl node type " << cnode->DebugString();
|
||||
}
|
||||
|
||||
std::tuple<ge::NodePtr, ge::ComputeGraphPtr> GenerateStubGeNode(const AnfNodePtr &anf_node, HcclDataType datatype) {
|
||||
MS_EXCEPTION_IF_NULL(anf_node);
|
||||
auto cnode = anf_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
std::string ge_node_name = GetGeNodeName(cnode);
|
||||
|
||||
ge::OpDescPtr op_desc = std::make_shared<ge::OpDesc>(kStubDataStructureName, ge_node_name);
|
||||
MS_EXCEPTION_IF_NULL(op_desc);
|
||||
for (size_t i = 1; i < cnode->size(); ++i) {
|
||||
auto &input = cnode->input(i);
|
||||
std::vector<int64_t> ge_shape;
|
||||
auto ms_shape = AnfAlgo::GetOutputInferShape(input, 0);
|
||||
std::transform(ms_shape.begin(), ms_shape.end(), std::back_inserter(ge_shape),
|
||||
[](size_t in) { return static_cast<int64_t>(in); });
|
||||
op_desc->AddInputDesc(
|
||||
ge::GeTensorDesc(ge::GeShape(ge_shape), ge::Format::FORMAT_NCHW,
|
||||
transform::TransformUtil::ConvertDataType(AnfAlgo::GetOutputInferDataType(input, 0))));
|
||||
}
|
||||
|
||||
// set node data type
|
||||
bool ret = ge::AttrUtils::SetDataType(*op_desc, ge::HCOM_ATTR_DATA_TYPE, ConvertHcclDTypeToGeDType(datatype));
|
||||
if (!ret) {
|
||||
MS_LOG(EXCEPTION) << "Set attr " << ge::HCOM_ATTR_DATA_TYPE << " for ge node of " << cnode->DebugString()
|
||||
<< " failed.";
|
||||
}
|
||||
|
||||
// set rank size
|
||||
if (AnfAlgo::HasNodeAttr(kAttrRankSize, cnode)) {
|
||||
auto rank_size = AnfAlgo::GetNodeAttr<int64_t>(cnode, kAttrRankSize);
|
||||
ret = ge::AttrUtils::SetInt(*op_desc, ge::HCOM_ATTR_RANK_SIZE, rank_size);
|
||||
if (!ret) {
|
||||
MS_LOG(EXCEPTION) << "Set attr " << ge::HCOM_ATTR_RANK_SIZE << " for ge node of " << cnode->DebugString()
|
||||
<< " failed.";
|
||||
}
|
||||
}
|
||||
|
||||
ge::ComputeGraphPtr ge_graph = std::make_shared<ge::ComputeGraph>(kStubDataStructureName);
|
||||
MS_EXCEPTION_IF_NULL(ge_graph);
|
||||
auto ge_node = ge_graph->AddNode(op_desc);
|
||||
return {ge_node, ge_graph};
|
||||
}
|
||||
|
||||
/**
 * Packs the data of one generated domi task into an HcclTaskInfo.
 *
 * @param op       op description the task was generated for; must not be null. It must report
 *                 exactly one workspace size and carry the "used_stream_num" integer attribute,
 *                 otherwise the problem is reported through MS_LOG(EXCEPTION).
 * @param task_def task definition whose private_def() payload is copied into the result.
 * @return {private_def, workspace size, stream count}.
 */
HcclTaskInfo ParseDomiTask(const ge::OpDescPtr &op, const domi::TaskDef &task_def) {
  MS_EXCEPTION_IF_NULL(op);

  // workspace size
  auto workspace_bytes = op->GetWorkspaceBytes();
  if (workspace_bytes.size() != 1) {
    MS_LOG(EXCEPTION) << "Unexpected workspace size " << workspace_bytes.size();
  }
  int64_t single_workspace = workspace_bytes[0];

  // stream num
  int64_t used_streams;
  if (!ge::AttrUtils::GetInt(*op, kGeNodeAttrUsedStreamNum, used_streams)) {
    MS_LOG(EXCEPTION) << "Get attr " << kGeNodeAttrUsedStreamNum << " for ge node " << op->GetType() << " failed.";
  }

  return {task_def.private_def(), single_workspace, used_streams};
}
|
||||
} // namespace mindspore::hccl
|
@ -0,0 +1,38 @@
|
||||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_RUNTIME_HCCL_ADAPTER_CONVERTER_H
|
||||
#define MINDSPORE_RUNTIME_HCCL_ADAPTER_CONVERTER_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
#define google ascend_private
|
||||
#include "graph/node.h"
|
||||
#include "common/opskernel/ops_kernel_info_types.h"
|
||||
#include "proto/task.pb.h"
|
||||
#undef google
|
||||
#include "runtime/hccl_adapter/hccl_adapter.h"
|
||||
#include "mindspore/core/ir/anf.h"
|
||||
|
||||
namespace mindspore::hccl {
|
||||
// Builds a stub GE node mirroring the given HCCL anf node.
// The owning graph is returned together with the node so the caller can keep it referenced.
|
||||
std::tuple<ge::NodePtr, ge::ComputeGraphPtr> GenerateStubGeNode(const AnfNodePtr &anf_node, HcclDataType datatype);
|
||||
// Converts one generated domi task, plus the workspace size and stream count of `op`, into an HcclTaskInfo.
HcclTaskInfo ParseDomiTask(const ge::OpDescPtr &op, const domi::TaskDef &task_def);
|
||||
// Returns the GE op type name for the collective primitive of `cnode`.
std::string GetGeNodeName(const CNodePtr &cnode);
|
||||
} // namespace mindspore::hccl
|
||||
#endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_CONVERTER_H
|
@ -0,0 +1,165 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "runtime/hccl_adapter/hccl_adapter.h"
|
||||
#include <map>
|
||||
#include <algorithm>
|
||||
#define google ascend_private
|
||||
#include "register/ops_kernel_builder_registry.h"
|
||||
#include "common/opskernel/ops_kernel_info_store.h"
|
||||
#include "external/ge/ge_api_types.h"
|
||||
#undef google
|
||||
#include "utils/log_adapter.h"
|
||||
#include "utils/ms_utils.h"
|
||||
#include "runtime/hccl_adapter/converter.h"
|
||||
#include "runtime/hccl_adapter/hcom_graph_adaptor.h"
|
||||
|
||||
// Key under which the HCCL ops-kernel info store is looked up among the stores reported by the plugin.
static constexpr const char *kHcclOpsKernelInfoStore = "ops_kernel_info_hccl";
|
||||
// Environment variable read to fill the deploy-mode option; "0" is used when it is unset.
static constexpr const char *kHcclDeployModeEnv = "DEPLOY_MODE";
|
||||
// The globals below are set by InitHccl and released by FinalizeHccl.
// Access to them is not synchronised, so thread safety is not guaranteed.
|
||||
static std::shared_ptr<ge::OpsKernelInfoStore> ops_kernel_info_store = nullptr;
|
||||
static ge::OpsKernelBuilderPtr ops_kernel_builder = nullptr;
|
||||
|
||||
namespace mindspore::hccl {
|
||||
static std::map<std::string, std::string> GenHcclOptions(uint32_t device_id, std::string_view rank_id,
|
||||
std::string_view rank_file) {
|
||||
auto env_deploy_mode = common::GetEnv(kHcclDeployModeEnv);
|
||||
if (env_deploy_mode.empty()) {
|
||||
MS_LOG(WARNING) << kHcclDeployModeEnv << " is not set in ENV. Now set to default value 0";
|
||||
env_deploy_mode = "0";
|
||||
}
|
||||
|
||||
return std::map<std::string, std::string>({{ge::OPTION_EXEC_IS_USEHCOM, "1"},
|
||||
{ge::OPTION_EXEC_IS_USEHVD, "0"},
|
||||
{ge::OPTION_EXEC_HCCL_FLAG, "1"},
|
||||
{ge::OPTION_EXEC_DEVICE_ID, std::to_string(device_id)},
|
||||
{ge::OPTION_EXEC_RANK_ID, rank_id.data()},
|
||||
{ge::OPTION_EXEC_POD_NAME, rank_id.data()},
|
||||
{ge::OPTION_EXEC_RANK_TABLE_FILE, rank_file.data()},
|
||||
{ge::OPTION_GRAPH_RUN_MODE, "1"},
|
||||
{ge::OPTION_EXEC_HCCL_FLAG, "1"},
|
||||
{ge::OPTION_EXEC_DEPLOY_MODE, env_deploy_mode}});
|
||||
}
|
||||
|
||||
bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file) {
|
||||
MS_LOG(INFO) << "Start init hccl adapter.";
|
||||
// get ops_kernel_builder
|
||||
std::map<std::string, ge::OpsKernelBuilderPtr> all_builders = ge::OpsKernelBuilderRegistry::GetInstance().GetAll();
|
||||
if (all_builders.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "Builders size should be 1 (hccl builder), but is " << all_builders.size();
|
||||
}
|
||||
|
||||
MS_LOG(INFO) << "Get builder " << all_builders.begin()->first;
|
||||
ops_kernel_builder = all_builders.begin()->second;
|
||||
MS_EXCEPTION_IF_NULL(ops_kernel_builder);
|
||||
// init ops_kernel_builder
|
||||
auto options = GenHcclOptions(device_id, rank_id, rank_file);
|
||||
auto ret = ops_kernel_builder->Initialize(options);
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Init builder failed, ret = " << ret;
|
||||
}
|
||||
|
||||
// get ops_kernel_info_store
|
||||
ret = ::Initialize(options);
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Init plugin so failed, ret = " << ret;
|
||||
}
|
||||
|
||||
std::map<std::string, std::shared_ptr<ge::OpsKernelInfoStore>> all_ops_kernel_info_stores;
|
||||
::GetOpsKernelInfoStores(all_ops_kernel_info_stores);
|
||||
for (auto &[name, ptr] : all_ops_kernel_info_stores) {
|
||||
if (name == kHcclOpsKernelInfoStore) {
|
||||
ops_kernel_info_store = ptr;
|
||||
break;
|
||||
}
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(ops_kernel_info_store);
|
||||
ret = ops_kernel_info_store->Initialize(options);
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(EXCEPTION) << "Init info store failed, ret = " << ret;
|
||||
}
|
||||
MS_LOG(INFO) << "Init hccl adapter success.";
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FinalizeHccl() {
|
||||
MS_LOG(INFO) << "Start destroy hccl adapter.";
|
||||
if (ops_kernel_info_store != nullptr) {
|
||||
auto ret = ops_kernel_info_store->Finalize();
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(ERROR) << "Destory info store failed, ret = " << ret;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (ops_kernel_builder != nullptr) {
|
||||
auto ret = ops_kernel_builder->Finalize();
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(ERROR) << "Destory builder failed, ret = " << ret;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
::Finalize();
|
||||
ge::OpsKernelBuilderRegistry::GetInstance().UnregisterAll();
|
||||
ops_kernel_info_store.reset();
|
||||
ops_kernel_builder.reset();
|
||||
MS_LOG(INFO) << "Destroy hccl adapter success.";
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTaskInfo> *task_info_lists) {
|
||||
MS_EXCEPTION_IF_NULL(ops_kernel_builder);
|
||||
MS_EXCEPTION_IF_NULL(task_info_lists);
|
||||
MS_LOG(INFO) << "Start generate task for hccl node " << node->DebugString();
|
||||
auto [ge_node, ge_graph] = GenerateStubGeNode(node, datatype);
|
||||
MS_EXCEPTION_IF_NULL(ge_node);
|
||||
auto op = ge_node->GetOpDesc();
|
||||
MS_EXCEPTION_IF_NULL(op);
|
||||
|
||||
MS_LOG(INFO) << "Start to call CalcOpRunningParam";
|
||||
ge::Status ret = ops_kernel_builder->CalcOpRunningParam(*ge_node);
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(ERROR) << "OpsKernelBuilder CalcOpRunningParam failed, ret = " << ret;
|
||||
return false;
|
||||
}
|
||||
MS_LOG(INFO) << "Start to call GenerateTask";
|
||||
ge::RunContext unused_ctx;
|
||||
std::vector<domi::TaskDef> domi_tasks;
|
||||
ret = ops_kernel_builder->GenerateTask(*ge_node, unused_ctx, domi_tasks);
|
||||
if (ret != ge::SUCCESS) {
|
||||
MS_LOG(ERROR) << "OpsKernelBuilder GenerateTask failed, ret = " << ret;
|
||||
return false;
|
||||
}
|
||||
|
||||
task_info_lists->clear();
|
||||
std::transform(domi_tasks.begin(), domi_tasks.end(), std::back_inserter(*task_info_lists),
|
||||
[&op](const domi::TaskDef &task_def) -> HcclTaskInfo { return ParseDomiTask(op, task_def); });
|
||||
MS_LOG(INFO) << "Generate task for node " << node->DebugString() << " success.";
|
||||
ge_graph.reset();
|
||||
return true;
|
||||
}
|
||||
|
||||
/**
 * Placeholder: performs no work and ignores `node`.
 *
 * @return always true.
 */
bool CalcOpRunningParam(const AnfNodePtr &node) {
  return true;
}
|
||||
|
||||
/**
 * Exposes the HCCL ops-kernel info store as an untyped, non-owning pointer.
 *
 * @return raw pointer held by the global ops_kernel_info_store; null when the store is not set.
 */
void *GetHcclOpsKernelInfoStore() {
  auto *raw_store = ops_kernel_info_store.get();
  return raw_store;
}
|
||||
|
||||
std::string GetHcclType(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
return GetGeNodeName(cnode);
|
||||
}
|
||||
} // namespace mindspore::hccl
|
@ -0,0 +1,43 @@
|
||||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_RUNTIME_HCCL_ADAPTER_HCCL_ADAPTER_H
|
||||
#define MINDSPORE_RUNTIME_HCCL_ADAPTER_HCCL_ADAPTER_H
|
||||
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "mindspore/core/ir/anf.h"
|
||||
#include "external/hccl/hccl_types.h"
|
||||
|
||||
#define MS_API __attribute__((visibility("default")))
|
||||
|
||||
namespace mindspore::hccl {
|
||||
// Description of one HCCL task as produced by GenTask.
// The integer members are zero-initialised by default so that a default-constructed object never
// carries indeterminate values; brace initialisation with all three fields keeps working.
struct MS_API HcclTaskInfo {
  // Payload copied from the generated task definition's private_def().
  std::string private_def;
  // Workspace size reported by the op description of the task.
  int64_t workspace_size{0};
  // Value of the "used_stream_num" attribute of the op description.
  int64_t stream_num{0};
};
|
||||
|
||||
// Initialises the builder, plugin and info store of the adapter; returns true on success.
MS_API bool InitHccl(uint32_t device_id, std::string_view rank_id, std::string_view rank_file);
|
||||
// Finalises and releases what InitHccl set up; returns false when the info store or builder fails to finalise.
MS_API bool FinalizeHccl();
|
||||
// Generates the task descriptions for `node` into `task_info_lists` (cleared first); returns false on builder failure.
MS_API bool GenTask(const AnfNodePtr &node, HcclDataType datatype, std::vector<HcclTaskInfo> *task_info_lists);
|
||||
// Currently a placeholder that always returns true.
MS_API bool CalcOpRunningParam(const AnfNodePtr &node);
|
||||
// Returns the HCCL ops-kernel info store as an untyped, non-owning pointer (null when not initialised).
MS_API void *GetHcclOpsKernelInfoStore();
|
||||
// Returns the GE op type name of the given HCCL node.
MS_API std::string GetHcclType(const AnfNodePtr &node);
|
||||
} // namespace mindspore::hccl
|
||||
#undef MS_API
|
||||
#endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_HCCL_ADAPTER_H
|
@ -0,0 +1,32 @@
|
||||
/**
|
||||
* Copyright 2019 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_RUNTIME_HCCL_ADAPTER_HCOM_GRAPH_ADAPTOR_H
|
||||
#define MINDSPORE_RUNTIME_HCCL_ADAPTER_HCOM_GRAPH_ADAPTOR_H
|
||||
|
||||
#include <string>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include "mindspore/core/ir/anf.h"
|
||||
#include "common/opskernel/ops_kernel_info_store.h"
|
||||
|
||||
// C-linkage entry points declared here and called with global qualification (::Initialize, ...)
// from the hccl adapter; their definitions are not part of this header.
extern "C" {
|
||||
// Initialises the plugin with the given option map.
ge::Status Initialize(const std::map<std::string, std::string> &);
|
||||
// Finalises the plugin.
ge::Status Finalize();
|
||||
// Fills the given map with the ops-kernel info stores, keyed by name.
void GetOpsKernelInfoStores(std::map<std::string, std::shared_ptr<ge::OpsKernelInfoStore>> &);
|
||||
}
|
||||
|
||||
#endif // MINDSPORE_RUNTIME_HCCL_ADAPTER_HCOM_GRAPH_ADAPTOR_H
|
Loading…
Reference in new issue