|
|
|
@ -32,6 +32,7 @@
|
|
|
|
|
#include "operator/ops.h"
|
|
|
|
|
#include "device/kernel_info.h"
|
|
|
|
|
#include "utils/context/ms_context.h"
|
|
|
|
|
#include "pre_activate/common/fusion_id_allocator.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace opt {
|
|
|
|
@ -79,20 +80,6 @@ void DumpFusionScopeInfo(const kernel::FusionScopeInfo &info) {
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
void SetAnfNodeFusionId(const FusedNodeRecord &record_node) {
|
|
|
|
|
MS_LOG(DEBUG) << "Size of opt vector to be fused is " << record_node.size();
|
|
|
|
|
int32_t id = 1;
|
|
|
|
|
for (auto &record : record_node) {
|
|
|
|
|
MS_LOG(DEBUG) << "No" << id << ", opt vector to be fused contain " << record.size() << " opt.";
|
|
|
|
|
for (const auto &candidate : record) {
|
|
|
|
|
ValuePtr fusion_id_v = MakeValue(id);
|
|
|
|
|
AnfAlgo::SetNodeAttr(kOpAttrFusionId, fusion_id_v, candidate);
|
|
|
|
|
MS_LOG(DEBUG) << "No " << id << ": " << candidate->DebugString();
|
|
|
|
|
}
|
|
|
|
|
id++;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool CheckEltWiseNode(FuncGraphManager *manager, std::unordered_set<AnfNodePtr> *record, const CNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(record);
|
|
|
|
@ -482,11 +469,18 @@ void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<A
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
|
|
|
|
|
std::unordered_set<AnfNodePtr> *fused_set, FusedNodeRecord *candidate_fusion) {
|
|
|
|
|
void BufferFusion::SetRecordFusionId(const std::unordered_set<AnfNodePtr> &record) {
|
|
|
|
|
auto id = fusion_id_allocator.AllocateFusionId();
|
|
|
|
|
for (auto node : record) {
|
|
|
|
|
fusion_id_allocator.SetFusionId(node, id);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BufferFusion::MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
|
|
|
|
|
FusedNodeRecord *candidate_fusion) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(fused_set);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
|
|
|
|
auto manager = kernel_graph.manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
@ -496,14 +490,13 @@ void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), conv);
|
|
|
|
|
std::unordered_set<AnfNodePtr> record{cnode, conv};
|
|
|
|
|
candidate_fusion->push_back(record);
|
|
|
|
|
fused_set->insert(record.begin(), record.end());
|
|
|
|
|
SetRecordFusionId(record);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph,
|
|
|
|
|
std::unordered_set<AnfNodePtr> *fused_set, FusedNodeRecord *candidate_fusion) {
|
|
|
|
|
void BufferFusion::MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
|
|
|
|
|
const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(fused_set);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
|
|
|
|
auto manager = kernel_graph.manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
@ -520,14 +513,13 @@ void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, cons
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate);
|
|
|
|
|
std::unordered_set<AnfNodePtr> record{cnode, bnupdate};
|
|
|
|
|
candidate_fusion->push_back(record);
|
|
|
|
|
fused_set->insert(record.begin(), record.end());
|
|
|
|
|
SetRecordFusionId(record);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph,
|
|
|
|
|
std::unordered_set<AnfNodePtr> *fused_set, FusedNodeRecord *candidate_fusion) {
|
|
|
|
|
void BufferFusion::MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
|
|
|
|
|
const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(fused_set);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
|
|
|
|
auto manager = kernel_graph.manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
@ -548,41 +540,37 @@ void MatchBnupdateAddRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, c
|
|
|
|
|
AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate);
|
|
|
|
|
std::unordered_set<AnfNodePtr> record{cnode, relu_input, bnupdate};
|
|
|
|
|
candidate_fusion->push_back(record);
|
|
|
|
|
fused_set->insert(record.begin(), record.end());
|
|
|
|
|
SetRecordFusionId(record);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Collects fusion candidates that are recognised purely by operator name.
//
// The nodes returned by TopoSort(kernel_graph.get_return()) are visited in order.
// A node is skipped when AnfAlgo::IsRealCNodeKernel(node) is false or when it is
// already marked as fused. Remaining nodes are dispatched on their CNode name:
//   - kBNTrainingReduceOpName                  -> MatchConvBnreduce
//   - kReluV2OpName or prim::kPrimRelu->name() -> MatchBnupdateAddRelu when input(1)
//     is a CNode named prim::kPrimTensorAdd->name(), MatchBnupdateRelu when it is a
//     CNode named prim::kPrimTupleGetItem->name(); otherwise nothing happens.
// Matches are reported through `candidate_fusion`, which must not be null.
//
// NOTE(review): this span shows two variants back to back of the signature, the
// skip condition and each matcher call - one threading a `fused_set` pointer, one
// relying on fusion_id_allocator. Presumably only one variant of each is meant to
// be live; confirm against the real file before relying on this span.
void MatchOpNamePattern(const session::KernelGraph &kernel_graph, std::unordered_set<AnfNodePtr> *fused_set,
|
|
|
|
|
FusedNodeRecord *candidate_fusion) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(fused_set);
|
|
|
|
|
void BufferFusion::MatchOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
|
|
|
|
std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
|
|
|
|
|
for (auto &node : node_list) {
|
|
|
|
|
// Variant using fused_set: skip non-kernel nodes and nodes already present in fused_set.
if (!AnfAlgo::IsRealCNodeKernel(node) || fused_set->find(node) != fused_set->end()) {
|
|
|
|
|
// Variant using fusion_id_allocator: skip non-kernel nodes and nodes for which
// HasFusionIdAttr(node) is true (presumably "already assigned to a fusion" - TODO confirm).
if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator.HasFusionIdAttr(node)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (AnfAlgo::GetCNodeName(cnode) == kBNTrainingReduceOpName) {
|
|
|
|
|
MatchConvBnreduce(cnode, kernel_graph, fused_set, candidate_fusion);
|
|
|
|
|
MatchConvBnreduce(cnode, kernel_graph, candidate_fusion);
|
|
|
|
|
} else if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName ||
|
|
|
|
|
AnfAlgo::GetCNodeName(cnode) == prim::kPrimRelu->name()) {
|
|
|
|
|
// The relu's producer decides which pattern is tried; input(1) is presumably the
// first data operand (input(0) being the primitive) - TODO confirm.
auto relu_input = cnode->input(1);
|
|
|
|
|
if (relu_input->isa<CNode>() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTensorAdd->name()) {
|
|
|
|
|
MatchBnupdateAddRelu(cnode, relu_input, kernel_graph, fused_set, candidate_fusion);
|
|
|
|
|
MatchBnupdateAddRelu(cnode, relu_input, kernel_graph, candidate_fusion);
|
|
|
|
|
} else if (relu_input->isa<CNode>() && AnfAlgo::GetCNodeName(relu_input) == prim::kPrimTupleGetItem->name()) {
|
|
|
|
|
MatchBnupdateRelu(cnode, relu_input, kernel_graph, fused_set, candidate_fusion);
|
|
|
|
|
MatchBnupdateRelu(cnode, relu_input, kernel_graph, candidate_fusion);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unordered_set<AnfNodePtr> *fused_set,
|
|
|
|
|
FusedNodeRecord *candidate_fusion) {
|
|
|
|
|
void BufferFusion::MatchFusionTypePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) {
|
|
|
|
|
auto manager = kernel_graph.manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(fused_set);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(candidate_fusion);
|
|
|
|
|
|
|
|
|
|
auto return_node = kernel_graph.get_return();
|
|
|
|
@ -599,7 +587,7 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
todo.pop_front();
|
|
|
|
|
std::unordered_set<AnfNodePtr> record;
|
|
|
|
|
if (visited_set.find(node) != visited_set.end() || fused_set->find(node) != fused_set->end()) {
|
|
|
|
|
if (visited_set.find(node) != visited_set.end() || fusion_id_allocator.HasFusionIdAttr(node)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
// Only fuse real cnode
|
|
|
|
@ -616,7 +604,7 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord
|
|
|
|
|
cnode = FindFusionAnfNode(manager.get(), &visited_set, &record, &todo, cnode);
|
|
|
|
|
if (record.size() >= MIN_PATTERN_SIZE && record.size() <= MAX_PATTERN_SIZE) {
|
|
|
|
|
candidate_fusion->push_back(record);
|
|
|
|
|
fused_set->insert(record.begin(), record.end());
|
|
|
|
|
SetRecordFusionId(record);
|
|
|
|
|
}
|
|
|
|
|
if (record.find(cnode) == record.end()) {
|
|
|
|
|
todo.push_back(cnode);
|
|
|
|
@ -628,7 +616,6 @@ void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, std::unord
|
|
|
|
|
(void)todo.insert(todo.end(), cnode->inputs().begin() + 1, cnode->inputs().end());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
void BufferFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph,
|
|
|
|
|
std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) const {
|
|
|
|
@ -684,7 +671,7 @@ bool BufferFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph) c
|
|
|
|
|
return change;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool BufferFusion::MatchBufferFusionPattern(const session::KernelGraph &kernel_graph) const {
|
|
|
|
|
bool BufferFusion::MatchBufferFusionPattern(const session::KernelGraph &kernel_graph) {
|
|
|
|
|
auto manager = kernel_graph.manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
|
auto return_node = kernel_graph.get_return();
|
|
|
|
@ -694,14 +681,11 @@ bool BufferFusion::MatchBufferFusionPattern(const session::KernelGraph &kernel_g
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "MatchBufferFusionPattern start...";
|
|
|
|
|
FusedNodeRecord candidate_fusion;
|
|
|
|
|
std::unordered_set<AnfNodePtr> fused_set;
|
|
|
|
|
|
|
|
|
|
MatchOpNamePattern(kernel_graph, &fused_set, &candidate_fusion);
|
|
|
|
|
MatchFusionTypePattern(kernel_graph, &fused_set, &candidate_fusion);
|
|
|
|
|
MatchOpNamePattern(kernel_graph, &candidate_fusion);
|
|
|
|
|
MatchFusionTypePattern(kernel_graph, &candidate_fusion);
|
|
|
|
|
|
|
|
|
|
if (!candidate_fusion.empty()) {
|
|
|
|
|
SetAnfNodeFusionId(candidate_fusion);
|
|
|
|
|
} else {
|
|
|
|
|
if (candidate_fusion.empty()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "MatchBufferFusionPattern Success...";
|
|
|
|
@ -741,13 +725,14 @@ bool BufferFusion::Run(const FuncGraphPtr &graph) {
|
|
|
|
|
auto kernel_graph = graph->cast<std::shared_ptr<session::KernelGraph>>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
|
|
|
|
|
|
fusion_id_allocator.Init();
|
|
|
|
|
if (MatchBufferFusionPattern(*kernel_graph)) {
|
|
|
|
|
changed = FuseBufferFusionPattern(kernel_graph.get());
|
|
|
|
|
}
|
|
|
|
|
// clear fusion_id attr
|
|
|
|
|
for (auto &node : graph->nodes()) {
|
|
|
|
|
if (node != nullptr && node->isa<CNode>()) {
|
|
|
|
|
AnfAlgo::EraseNodeAttr(kOpAttrFusionId, node);
|
|
|
|
|
AnfAlgo::EraseNodeAttr(kAttrFusionId, node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return changed;
|
|
|
|
|