|
|
|
@ -16,6 +16,7 @@
|
|
|
|
|
#include "backend/optimizer/pass/communication_op_fusion.h"
|
|
|
|
|
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include <set>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
|
|
|
|
@ -89,6 +90,13 @@ std::string GetFusionGroupKey(const AnfNodePtr &node) {
|
|
|
|
|
}
|
|
|
|
|
return group + op + std::to_string(fusion);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CheckInputs(const std::vector<AnfNodePtr> &fusion_inputs) {
|
|
|
|
|
std::set<AnfNodePtr> inputs_set(fusion_inputs.begin(), fusion_inputs.end());
|
|
|
|
|
if (inputs_set.size() < fusion_inputs.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Different communication op in one segment cannot share the same input";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
bool CommunicationOpFusion::GetSplitSegments(const CommunicationOpInfo &communication_op_info, size_t *segment_num,
|
|
|
|
@ -163,6 +171,7 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
fusion_inputs.insert(fusion_inputs.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
|
|
|
|
|
}
|
|
|
|
|
CheckInputs(fusion_inputs);
|
|
|
|
|
AnfNodePtr fused_node = func_graph->NewCNode(fusion_inputs);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(fused_node);
|
|
|
|
|
auto kernel_info = std::make_shared<device::KernelInfo>();
|
|
|
|
@ -172,9 +181,6 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr
|
|
|
|
|
for (size_t idx = start_index; idx <= end_index; ++idx) {
|
|
|
|
|
auto cnode = communication_op_info.communication_op_nodes[idx];
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
AnfAlgo::CopyNodeAttr("fusion", cnode, fused_node);
|
|
|
|
|
AnfAlgo::CopyNodeAttr("op", cnode, fused_node);
|
|
|
|
|
AnfAlgo::CopyNodeAttr("group", cnode, fused_node);
|
|
|
|
|
abstract_list.push_back(cnode->abstract());
|
|
|
|
|
}
|
|
|
|
|
auto kernel_build_info = GenerateKernelBuildInfo(communication_op_info, start_index, end_index);
|
|
|
|
@ -182,6 +188,9 @@ AnfNodePtr CommunicationOpFusion::CreateFusedCommunicationOp(const FuncGraphPtr
|
|
|
|
|
auto abstract_tuple = std::make_shared<abstract::AbstractTuple>(abstract_list);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(abstract_tuple);
|
|
|
|
|
fused_node->set_abstract(abstract_tuple);
|
|
|
|
|
AnfAlgo::CopyNodeAttr("fusion", communication_op_info.communication_op_nodes[end_index], fused_node);
|
|
|
|
|
AnfAlgo::CopyNodeAttr("op", communication_op_info.communication_op_nodes[end_index], fused_node);
|
|
|
|
|
AnfAlgo::CopyNodeAttr("group", communication_op_info.communication_op_nodes[end_index], fused_node);
|
|
|
|
|
return fused_node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|