|
|
|
@ -387,7 +387,6 @@ void RemoveCircle(const session::KernelGraph &kernel_graph,
|
|
|
|
|
void UbPatternFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph,
|
|
|
|
|
std::unordered_map<int64_t, BufferFusionInfo_t> *buffer_fusion_infos) const {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(buffer_fusion_infos);
|
|
|
|
|
auto graph_id = kernel_graph->graph_id();
|
|
|
|
|
GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos);
|
|
|
|
|
GetFusionScopeInputNodeList(*kernel_graph, buffer_fusion_infos);
|
|
|
|
|
GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos);
|
|
|
|
@ -397,7 +396,11 @@ void UbPatternFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph,
|
|
|
|
|
for (auto &buffer_fusion_info : *buffer_fusion_infos) {
|
|
|
|
|
buffer_fusion_info.second.kernel_build_info =
|
|
|
|
|
CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.outputs_list);
|
|
|
|
|
buffer_fusion_info.second.graph_id = graph_id;
|
|
|
|
|
// just for full_name_with_scope for every buffer_fusion_info.
|
|
|
|
|
auto fusion_node = CreateFusionOp(buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.outputs_list,
|
|
|
|
|
buffer_fusion_info.second.anf_nodes, kernel_graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(fusion_node);
|
|
|
|
|
buffer_fusion_info.second.full_name = fusion_node->fullname_with_scope();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -412,7 +415,7 @@ bool UbPatternFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph
|
|
|
|
|
buffer_fusion_infos.begin(), buffer_fusion_infos.end(), std::back_inserter(fusion_scope_infos),
|
|
|
|
|
[](const std::pair<int64_t, BufferFusionInfo_t> &buffer_fusion_info) -> mindspore::kernel::FusionScopeInfo {
|
|
|
|
|
return mindspore::kernel::FusionScopeInfo(
|
|
|
|
|
buffer_fusion_info.first, buffer_fusion_info.second.graph_id, buffer_fusion_info.second.inputs_list,
|
|
|
|
|
buffer_fusion_info.first, buffer_fusion_info.second.full_name, buffer_fusion_info.second.inputs_list,
|
|
|
|
|
buffer_fusion_info.second.anf_nodes, buffer_fusion_info.second.outputs_list);
|
|
|
|
|
});
|
|
|
|
|
auto kernel_mods = mindspore::kernel::KernelFusion(fusion_scope_infos);
|
|
|
|
@ -447,6 +450,7 @@ bool UbPatternFusion::ReplaceFusionOp(std::unordered_map<int64_t, BufferFusionIn
|
|
|
|
|
TraceGuard guard(std::make_shared<TraceOpt>(buffer_fusion_info.anf_nodes[0]->debug_info()));
|
|
|
|
|
auto buffer_fusion = CreateFusionOp(buffer_fusion_info.inputs_list, buffer_fusion_info.outputs_list,
|
|
|
|
|
buffer_fusion_info.anf_nodes, kernel_graph);
|
|
|
|
|
buffer_fusion->set_fullname_with_scope(buffer_fusion_info.full_name);
|
|
|
|
|
AnfAlgo::SetSelectKernelBuildInfo(buffer_fusion_info.kernel_build_info, buffer_fusion.get());
|
|
|
|
|
// Set abstract of fusion_op node
|
|
|
|
|
std::vector<TypeId> types;
|
|
|
|
|