diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc index 4011e0fe14..806c592f97 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc @@ -65,6 +65,8 @@ #include "pre_activate/ascend/buffer_fusion/buffer_fusion.h" #include "pre_activate/ascend/buffer_fusion/ub_pattern_fusion.h" #include "pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h" +#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h" +#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h" #include "pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h" #include "pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h" #include "pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h" @@ -336,16 +338,18 @@ void AscendBackendUBFusionOptimization(const std::shared_ptrInit(); auto optimizer = std::make_shared(); auto ub_fusion_pm = std::make_shared("ub_fusion_pm"); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator.get())); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator.get())); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator.get())); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator.get())); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator.get())); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator.get())); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator.get())); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator.get())); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator.get())); - ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator.get())); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); + ub_fusion_pm->AddPass(std::make_shared(fusion_id_allocator)); ub_fusion_pm->AddPass(std::make_shared()); optimizer->AddPassManager(ub_fusion_pm); (void)optimizer->Optimize(kernel_graph); diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h index 59b7b25d8d..6cdc5885f6 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_eltwise_fusion_pass.h @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector>; class BnupdateEltwiseEltwiseFusionPass : public FusionBasePass { public: - explicit BnupdateEltwiseEltwiseFusionPass(FusionIdAllocator *idAllocator) + explicit BnupdateEltwiseEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) : FusionBasePass("BnupdateEltwiseEltwiseFusionPass", idAllocator) {} ~BnupdateEltwiseEltwiseFusionPass() override = default; void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h index b9db5c68b7..b5688f3a36 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/bnupdate_eltwise_fusion_pass.h @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector>; class BnupdateEltwiseFusionPass : public FusionBasePass { public: - explicit BnupdateEltwiseFusionPass(FusionIdAllocator *idAllocator) + explicit BnupdateEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) : FusionBasePass("BnupdateEltwiseFusionPass", idAllocator) {} ~BnupdateEltwiseFusionPass() override = default; void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.cc new file mode 100644 index 0000000000..c90d2a17cd --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.cc @@ -0,0 +1,76 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h" +#include +#include +#include +#include +#include "kernel/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "session/anf_runtime_algorithm.h" +#include "operator/ops.h" +#include "utils/context/ms_context.h" +#include "pre_activate/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void Conv2DBackpropEltwiseEltwiseFusionPass::MatchConv2DBackpropInputEltwiseEltwise( + const CNodePtr &cnode, const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + std::unordered_set record{cnode}; + auto eltwise_input = cnode->input(1); + if (CheckDoubleInEltWiseNode(manager.get(), eltwise_input)) { + (void)record.insert(eltwise_input); + } else { + return; + } + auto input_cnode = eltwise_input->cast(); + MS_EXCEPTION_IF_NULL(input_cnode); + auto double_in_eltwise_input = input_cnode->input(1); + if (!double_in_eltwise_input->isa() || !AnfAlgo::IsRealCNodeKernel(double_in_eltwise_input) || + fusion_id_allocator->HasFusionIdAttr(double_in_eltwise_input)) { + return; + } + if (AnfAlgo::CheckPrimitiveType(double_in_eltwise_input, prim::kPrimConv2DBackpropInput)) { + (void)record.insert(double_in_eltwise_input); + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } +} + +void Conv2DBackpropEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && + (cnode->inputs().size() == ELTWISE_INPUT_SIZE || cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE)) { + MatchConv2DBackpropInputEltwiseEltwise(cnode, kernel_graph, candidate_fusion); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h new file mode 100644 index 0000000000..7d779d35f8 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_eltwise_fusion_pass.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_ + +#include +#include + +#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "pre_activate/common/pass.h" +#include "pre_activate/common/fusion_id_allocator.h" +#include "device/kernel_info.h" +#include "kernel/kernel.h" +#include "session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class Conv2DBackpropEltwiseEltwiseFusionPass : public FusionBasePass { + public: + explicit Conv2DBackpropEltwiseEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("Conv2DBackpropEltwiseEltwiseFusionPass", idAllocator) {} + ~Conv2DBackpropEltwiseEltwiseFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchConv2DBackpropInputEltwiseEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.cc new file mode 100644 index 0000000000..a18d578f7f --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.cc @@ -0,0 +1,69 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h" +#include +#include +#include +#include +#include "kernel/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "session/anf_runtime_algorithm.h" +#include "operator/ops.h" +#include "utils/context/ms_context.h" +#include "pre_activate/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void Conv2DBackpropEltwiseFusionPass::MatchConv2DBackpropInputEltwise(const CNodePtr &cnode, + const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + std::unordered_set record{cnode}; + auto eltwise_input = cnode->input(1); + if (!eltwise_input->isa() || !AnfAlgo::IsRealCNodeKernel(eltwise_input) || + fusion_id_allocator->HasFusionIdAttr(eltwise_input)) { + return; + } + if (AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimConv2DBackpropInput)) { + (void)record.insert(eltwise_input); + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } +} + +void Conv2DBackpropEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && + (cnode->inputs().size() == ELTWISE_INPUT_SIZE || cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE)) { + MatchConv2DBackpropInputEltwise(cnode, kernel_graph, candidate_fusion); + } + } +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h new file mode 100644 index 0000000000..171352de9b --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv2dbackprop_eltwise_fusion_pass.h @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ + +#include +#include + +#include "pre_activate/ascend/buffer_fusion/fusion_base_pass.h" +#include "ir/anf.h" +#include "pre_activate/common/pass.h" +#include "pre_activate/common/fusion_id_allocator.h" +#include "device/kernel_info.h" +#include "kernel/kernel.h" +#include "session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class Conv2DBackpropEltwiseFusionPass : public FusionBasePass { + public: + explicit Conv2DBackpropEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) + : FusionBasePass("Conv2DBackpropEltwiseFusionPass", idAllocator) {} + ~Conv2DBackpropEltwiseFusionPass() override = default; + void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; + + private: + void MatchConv2DBackpropInputEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_CONV2DBACKPROP_ELTWISE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h index aa835b5ba7..7a06faa624 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_bnreduce_fusion_pass.h @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector>; class ConvBnReduceFusionPass : public FusionBasePass { public: - explicit ConvBnReduceFusionPass(FusionIdAllocator *idAllocator) + explicit ConvBnReduceFusionPass(FusionIdAllocatorPtr idAllocator) : FusionBasePass("ConvBnReduceFusionPass", idAllocator) {} ~ConvBnReduceFusionPass() override = default; void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.cc index 4c6902816c..c4bfb96109 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.cc +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.cc @@ -27,19 +27,6 @@ namespace mindspore { namespace opt { -bool ConvDoubleInFusionPass::CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { - MS_EXCEPTION_IF_NULL(manager); - if (!node->isa() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { - return false; - } - auto cnode = node->cast(); - MS_EXCEPTION_IF_NULL(cnode); - auto user_nodes = manager->node_users()[node]; - return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL && - AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && user_nodes.size() == ELTWISE_USE && - cnode->inputs().size() == ELTWISE_INPUT_SIZE; -} - void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) { MS_EXCEPTION_IF_NULL(cnode); diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h index 6bcc40789a..062b8182fb 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_double_in_fusion_pass.h @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector>; class ConvDoubleInFusionPass : public FusionBasePass { public: - explicit ConvDoubleInFusionPass(FusionIdAllocator *idAllocator) + explicit ConvDoubleInFusionPass(FusionIdAllocatorPtr idAllocator) : FusionBasePass("ConvDoubleInFusionPass", idAllocator) {} ~ConvDoubleInFusionPass() override = default; void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; @@ -41,7 +41,6 @@ class ConvDoubleInFusionPass : public FusionBasePass { private: void MatchConvDoubleInEltwise(const CNodePtr &cnode, const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); - bool CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node); }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h index 2824b6c883..bf7e581dff 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/conv_single_in_fusion_pass.h @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector>; class ConvSingleInFusionPass : public FusionBasePass { public: - explicit ConvSingleInFusionPass(FusionIdAllocator *idAllocator) + explicit ConvSingleInFusionPass(FusionIdAllocatorPtr idAllocator) : FusionBasePass("ConvSingleInFusionPass", idAllocator) {} ~ConvSingleInFusionPass() override = default; void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h index 05d473bd1a..c2e72f26ff 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/depthwiseconv_eltwise_fusion_pass.h @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector>; class DepthwiseConvEltwiseFusionPass : public FusionBasePass { public: - explicit DepthwiseConvEltwiseFusionPass(FusionIdAllocator *idAllocator) + explicit DepthwiseConvEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) : FusionBasePass("DepthwiseConvEltwiseFusionPass", idAllocator) {} ~DepthwiseConvEltwiseFusionPass() override = default; void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h index 8cf9796e98..54ff0f5982 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/eltwise_fusion_pass.h @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector>; class EltwiseFusionPass : public FusionBasePass { public: - explicit EltwiseFusionPass(FusionIdAllocator *idAllocator) : FusionBasePass("EltwiseFusionPass", idAllocator) {} + explicit EltwiseFusionPass(FusionIdAllocatorPtr idAllocator) : FusionBasePass("EltwiseFusionPass", idAllocator) {} ~EltwiseFusionPass() override = default; void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.cc index 51e39ac9fd..3f5dd98112 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.cc +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.cc @@ -36,6 +36,19 @@ bool FusionBasePass::CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePt cnode->inputs().size() == ELTWISE_INPUT_SIZE; } +bool FusionBasePass::CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(manager); + if (!node->isa() || !AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node)) { + return false; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + auto user_nodes = manager->node_users()[node]; + return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && user_nodes.size() == ELTWISE_USE && + cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE; +} + void FusionBasePass::SetRecordFusionId(const std::unordered_set &record) { auto id = fusion_id_allocator->AllocateFusionId(); for (auto node : record) { diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.h index c44508318e..8d2ed81607 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.h +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/fusion_base_pass.h @@ -32,13 +32,14 @@ namespace opt { const int8_t MAX_ELTWISE_NUM = 3; const int8_t MIN_ELTWISE_SIZE = 2; const int8_t ELTWISE_INPUT_SIZE = 2; +const int8_t ELTWISE_DOUBLE_IN_INPUT_SIZE = 3; const int8_t ELTWISE_USE = 1; const int8_t MAX_ELTWISE_SIZE = 6; using FusedNodeRecord = std::vector>; class FusionBasePass : public Pass { public: - FusionBasePass(const std::string &name, FusionIdAllocator *idAllocator) + FusionBasePass(const std::string &name, FusionIdAllocatorPtr idAllocator) : Pass(name), fusion_id_allocator(idAllocator) {} ~FusionBasePass() override = default; bool Run(const FuncGraphPtr &graph) override; @@ -49,7 +50,8 @@ class FusionBasePass : public Pass { FusedNodeRecord *candidate_fusion) = 0; void SetRecordFusionId(const std::unordered_set &record); bool CheckEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node); - FusionIdAllocator *fusion_id_allocator; + bool CheckDoubleInEltWiseNode(FuncGraphManager *manager, const AnfNodePtr &node); + FusionIdAllocatorPtr fusion_id_allocator; }; } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h index 00178ee678..5baaa6db86 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/matmul_eltwise_fusion_pass.h @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector>; class MatmulEltwiseFusionPass : public FusionBasePass { public: - explicit MatmulEltwiseFusionPass(FusionIdAllocator *idAllocator) + explicit MatmulEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) : FusionBasePass("MatmulEltwiseFusionPass", idAllocator) {} ~MatmulEltwiseFusionPass() override = default; void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h index 082cbf99a0..42d896e96b 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/reduce_eltwise_fusion_pass.h @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector>; class ReduceEltwiseFusionPass : public FusionBasePass { public: - explicit ReduceEltwiseFusionPass(FusionIdAllocator *idAllocator) + explicit ReduceEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) : FusionBasePass("ReduceEltwiseFusionPass", idAllocator) {} ~ReduceEltwiseFusionPass() override = default; void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h index c774d2a8bf..41f06ba1f9 100644 --- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/segment_eltwise_fusion_pass.h @@ -33,7 +33,7 @@ using FusedNodeRecord = std::vector>; class SegmentEltwiseFusionPass : public FusionBasePass { public: - explicit SegmentEltwiseFusionPass(FusionIdAllocator *idAllocator) + explicit SegmentEltwiseFusionPass(FusionIdAllocatorPtr idAllocator) : FusionBasePass("SegmentEltwiseFusionPass", idAllocator) {} ~SegmentEltwiseFusionPass() override = default; void MatchSingleFusionPattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion) override; diff --git a/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.h b/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.h index 8d0cbb3e76..91e83600f2 100644 --- a/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.h +++ b/mindspore/ccsrc/pre_activate/common/fusion_id_allocator.h @@ -16,6 +16,7 @@ #ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_ #define MINDSPORE_CCSRC_PRE_ACTIVATE_COMMON_FUSION_ID_ALLOCATOR_H_ +#include #include "ir/base.h" namespace mindspore { @@ -36,6 +37,7 @@ class FusionIdAllocator { private: int32_t fusion_id; }; +using FusionIdAllocatorPtr = std::shared_ptr; } // namespace opt } // namespace mindspore