|
|
|
@ -24,6 +24,7 @@
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <iterator>
|
|
|
|
|
|
|
|
|
|
#include "kernel/kernel_fusion.h"
|
|
|
|
|
#include "debug/anf_ir_dump.h"
|
|
|
|
@ -461,6 +462,36 @@ void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector<AnfNodePtr> &outputs_list,
|
|
|
|
|
const AnfNodePtr &fusion_kernel) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
|
|
|
|
auto manager = kernel_graph->manager();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(manager);
|
|
|
|
|
for (size_t idx = 0; idx < outputs_list.size(); ++idx) {
|
|
|
|
|
auto output = outputs_list[idx];
|
|
|
|
|
if (output->isa<CNode>() && AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) {
|
|
|
|
|
auto real_output = AnfAlgo::VisitKernel(output, 0);
|
|
|
|
|
auto output_cnode = output->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output_cnode);
|
|
|
|
|
auto input2 = output_cnode->input(2);
|
|
|
|
|
auto output_idx = GetValue<int>(GetValueNode(input2));
|
|
|
|
|
session::AnfWithOutIndex out_pair(real_output.first, output_idx);
|
|
|
|
|
if (kernel_graph->IsInRefOutputMap(out_pair)) {
|
|
|
|
|
auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair);
|
|
|
|
|
session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx);
|
|
|
|
|
kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
session::AnfWithOutIndex out_pair(output, 0);
|
|
|
|
|
if (kernel_graph->IsInRefOutputMap(out_pair)) {
|
|
|
|
|
auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair);
|
|
|
|
|
session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx);
|
|
|
|
|
kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MatchConvBnreduce(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
|
|
|
|
|
std::unordered_set<AnfNodePtr> *fused_set, FusedNodeRecord *candidate_fusion) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
@ -708,7 +739,7 @@ bool BufferFusion::ReplaceFusionOp(std::unordered_map<int32_t, BufferFusionInfo_
|
|
|
|
|
}
|
|
|
|
|
AnfAlgo::SetOutputInferTypeAndShape(types, shapes, buffer_fusion.get());
|
|
|
|
|
AnfAlgo::SetKernelMod(kernel_ptr, buffer_fusion.get());
|
|
|
|
|
// replace node
|
|
|
|
|
SetFusionOpRefInfos(kernel_graph, buffer_fusion_info.outputs_list, buffer_fusion);
|
|
|
|
|
ReplaceOldNode(buffer_fusion_infos, fusion_id, buffer_fusion, kernel_graph);
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|