|
|
|
@ -40,62 +40,6 @@ bool IsParameterOrValueNode(const AnfNodePtr &node) {
|
|
|
|
|
return real_node->isa<ValueNode>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetInput(const CNodePtr &control_depend, const int index, const FuncGraphPtr &graph, const CNodePtr &hccl_node,
|
|
|
|
|
const std::vector<AnfNodePtr> &memcpy_async_list) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(control_depend);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(hccl_node);
|
|
|
|
|
std::vector<AnfNodePtr> make_tuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
|
|
|
|
|
make_tuple_inputs.insert(make_tuple_inputs.end(), memcpy_async_list.begin(), memcpy_async_list.end());
|
|
|
|
|
make_tuple_inputs.emplace_back(hccl_node);
|
|
|
|
|
auto make_tuple = graph->NewCNode(make_tuple_inputs);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(make_tuple);
|
|
|
|
|
control_depend->set_input(IntToSize(index), make_tuple);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void DealControlForGetitem(const CNodePtr &tuple_getitem, const FuncGraphPtr &graph, const CNodePtr &hccl_node,
|
|
|
|
|
const std::vector<AnfNodePtr> &memcpy_async_list) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(tuple_getitem);
|
|
|
|
|
auto manager = graph->manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
|
auto &node_users = manager->node_users();
|
|
|
|
|
auto iter = node_users.find(tuple_getitem);
|
|
|
|
|
if (iter == node_users.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "node has no output in manager"
|
|
|
|
|
<< " trace: " << trace::DumpSourceLines(hccl_node);
|
|
|
|
|
}
|
|
|
|
|
for (const auto &node_index : iter->second) {
|
|
|
|
|
AnfNodePtr output = node_index.first;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output);
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
|
|
|
|
|
SetInput(output->cast<CNodePtr>(), node_index.second, graph, hccl_node, memcpy_async_list);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void TransferControl(const CNodePtr &hccl_node, const std::vector<AnfNodePtr> &memcpy_async_list,
|
|
|
|
|
const FuncGraphPtr &graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(hccl_node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
auto manager = graph->manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
|
auto &node_users = manager->node_users();
|
|
|
|
|
auto iter = node_users.find(hccl_node);
|
|
|
|
|
if (iter == node_users.end()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "node has no output in manager"
|
|
|
|
|
<< " trace: " << trace::DumpSourceLines(hccl_node);
|
|
|
|
|
}
|
|
|
|
|
// find hccl_node's output which is a control depend
|
|
|
|
|
for (const auto &node_index : iter->second) {
|
|
|
|
|
AnfNodePtr output = node_index.first;
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output);
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimControlDepend)) {
|
|
|
|
|
SetInput(output->cast<CNodePtr>(), node_index.second, graph, hccl_node, memcpy_async_list);
|
|
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(output, prim::kPrimTupleGetItem)) {
|
|
|
|
|
DealControlForGetitem(output->cast<CNodePtr>(), graph, hccl_node, memcpy_async_list);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
// NodeUsersMap, for node B input i use node A, it will be one item in map with key: A, and value: (B, i)
|
|
|
|
|
bool IsNodeOutPutUsedByOtherRealKernel(const AnfNodeIndexSet &node_users) {
|
|
|
|
|
if (node_users.size() == 1) {
|
|
|
|
@ -155,7 +99,7 @@ bool InsertMemcpyAsyncForHcclOp::NeedInsertMemcpy(const FuncGraphPtr &graph, con
|
|
|
|
|
void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, const CNodePtr &hccl_node) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(hccl_node);
|
|
|
|
|
std::vector<AnfNodePtr> memcpy_async_list;
|
|
|
|
|
bool need_memcpy_async = false;
|
|
|
|
|
std::vector<AnfNodePtr> new_inputs = {hccl_node->input(0)};
|
|
|
|
|
for (size_t i = 1; i < hccl_node->size(); ++i) {
|
|
|
|
|
auto input = hccl_node->input(i);
|
|
|
|
@ -164,17 +108,17 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co
|
|
|
|
|
if (memcpy_async == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Create memcpy_async op failed.";
|
|
|
|
|
}
|
|
|
|
|
if (AnfAlgo::IsNodeDynamicShape(input)) {
|
|
|
|
|
if (input->isa<CNode>() && AnfAlgo::IsNodeDynamicShape(input)) {
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), memcpy_async);
|
|
|
|
|
}
|
|
|
|
|
new_inputs.push_back(memcpy_async);
|
|
|
|
|
memcpy_async_list.push_back(memcpy_async);
|
|
|
|
|
need_memcpy_async = true;
|
|
|
|
|
} else {
|
|
|
|
|
new_inputs.push_back(input);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!memcpy_async_list.empty()) {
|
|
|
|
|
if (need_memcpy_async) {
|
|
|
|
|
CNodePtr new_hccl_node = std::make_shared<CNode>(*hccl_node);
|
|
|
|
|
new_hccl_node->set_inputs(new_inputs);
|
|
|
|
|
auto manager = graph->manager();
|
|
|
|
@ -182,9 +126,6 @@ void InsertMemcpyAsyncForHcclOp::InsertMemcpyAsync(const FuncGraphPtr &graph, co
|
|
|
|
|
MS_LOG(DEBUG) << "start replace new_hccl_node to old hccl_node";
|
|
|
|
|
(void)manager->Replace(hccl_node, new_hccl_node);
|
|
|
|
|
MS_LOG(DEBUG) << "end replace";
|
|
|
|
|
|
|
|
|
|
// transer hccl op's control to the memcpy_async
|
|
|
|
|
TransferControl(new_hccl_node, memcpy_async_list, graph);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|