From b296dac9dd3684ac0c49823709d7723cb3131f74 Mon Sep 17 00:00:00 2001 From: huanghui Date: Thu, 28 Jan 2021 16:39:45 +0800 Subject: [PATCH] fix do ub fusion with only one node --- .../ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc | 5 ++++- .../optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc index a64b0cf6b1..1a16060b5e 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/matmul_eltwise_fusion_pass.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,6 +33,9 @@ void MatmulEltwiseFusionPass::MatchMatmulEltwise(const CNodePtr &cnode, const An MS_EXCEPTION_IF_NULL(candidate_fusion); auto manager = kernel_graph.manager(); MS_EXCEPTION_IF_NULL(manager); + if (fusion_id_allocator->HasFusionIdAttr(relu_input)) { + return; + } std::vector output_used_num{SizeToLong(manager->node_users()[relu_input].size())}; AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu_input); std::unordered_set record{cnode, relu_input}; diff --git a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc index 1bb68217a4..60b237da23 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/buffer_fusion/ub_pattern_fusion.cc @@ -1,5 +1,5 @@ /** - * Copyright 2019 Huawei Technologies Co., Ltd + * Copyright 2019-2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -405,6 +405,9 @@ bool UbPatternFusion::ReplaceFusionOp(std::unordered_map(buffer_fusion_info.anf_nodes[0]->debug_info())); auto buffer_fusion = CreateFusionOp(buffer_fusion_info.inputs_list, buffer_fusion_info.outputs_list, buffer_fusion_info.anf_nodes, kernel_graph);