From 23a586b61da7ba6099f5c41ca37c88c620c98683 Mon Sep 17 00:00:00 2001 From: "Etone.Chan" Date: Tue, 28 Apr 2020 20:25:15 +0800 Subject: [PATCH] set RefInfo of Buffer Fusion kernel --- .../ascend/buffer_fusion/buffer_fusion.cc | 33 ++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc index 8581f1165d..851831383b 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include "kernel/kernel_fusion.h" #include "debug/anf_ir_dump.h" @@ -461,6 +462,36 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, } } +void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector &outputs_list, + const AnfNodePtr &fusion_kernel) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto manager = kernel_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + for (size_t idx = 0; idx < outputs_list.size(); ++idx) { + auto output = outputs_list[idx]; + if (output->isa() && AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) { + auto real_output = AnfAlgo::VisitKernel(output, 0); + auto output_cnode = output->cast(); + MS_EXCEPTION_IF_NULL(output_cnode); + auto input2 = output_cnode->input(2); + auto output_idx = GetValue(GetValueNode(input2)); + session::AnfWithOutIndex out_pair(real_output.first, output_idx); + if (kernel_graph->IsInRefOutputMap(out_pair)) { + auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair); + session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx); + kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair); + } + } else { + session::AnfWithOutIndex out_pair(output, 0); + if (kernel_graph->IsInRefOutputMap(out_pair)) { + auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair); + session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx); + kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair); + } + } + } +} + void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, std::unordered_set *fused_set, FusedNodeRecord *candidate_fusion) { MS_EXCEPTION_IF_NULL(cnode); @@ -708,7 +739,7 @@ bool BufferFusion::ReplaceFusionOp(std::unordered_map