From dfa7b79d75232c06adae13a64c658a04d9bfd423 Mon Sep 17 00:00:00 2001
From: wangcong
Date: Fri, 15 May 2020 19:57:14 +0800
Subject: [PATCH] buffer fusion code refactor

---
 .../ascend/ascend_backend_optimization.cc     |  38 ++
 .../ascend/ascend_backend_optimization.h      |   1 +
 .../ascend/buffer_fusion/buffer_fusion.cc     |   3 +
 .../pass/bnupdate_eltwise_fusion_pass.cc      |  96 ++++
 .../pass/bnupdate_eltwise_fusion_pass.h       |  50 ++
 .../pass/depthwiseconv_eltwise_fusion_pass.cc | 107 +++++
 .../pass/depthwiseconv_eltwise_fusion_pass.h  |  50 ++
 .../buffer_fusion/pass/fusion_base_pass.cc    |  38 ++
 .../buffer_fusion/pass/fusion_base_pass.h     |  50 ++
 .../pass/fusion_type_fusion_pass.cc           | 245 ++++++++++
 .../pass/fusion_type_fusion_pass.h            |  47 ++
 .../ascend/buffer_fusion/tbe_buffer_fusion.cc | 435 ++++++++++++++++++
 .../ascend/buffer_fusion/tbe_buffer_fusion.h  |  50 ++
 13 files changed, 1210 insertions(+)
 create mode 100644 mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/bnupdate_eltwise_fusion_pass.cc
 create mode 100644 mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/bnupdate_eltwise_fusion_pass.h
 create mode 100644 mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/depthwiseconv_eltwise_fusion_pass.cc
 create mode 100644 mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/depthwiseconv_eltwise_fusion_pass.h
 create mode 100644 mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/fusion_base_pass.cc
 create mode 100644 mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/fusion_base_pass.h
 create mode 100644 mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/fusion_type_fusion_pass.cc
 create mode 100644 mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/fusion_type_fusion_pass.h
 create mode 100644 mindspore/ccsrc/pre_activate/ascend/buffer_fusion/tbe_buffer_fusion.cc
 create mode 100644 mindspore/ccsrc/pre_activate/ascend/buffer_fusion/tbe_buffer_fusion.h

diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
index 16a82500b4..bfa2022a80 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.cc
@@ -60,6 +60,10 @@
 #include "pre_activate/ascend/format_type/merge_cast_to_op.h"
 #include "pre_activate/ascend/format_type/check_consistency.h"
 #include "pre_activate/ascend/buffer_fusion/buffer_fusion.h"
+#include "pre_activate/ascend/buffer_fusion/tbe_buffer_fusion.h"
+#include "pre_activate/ascend/buffer_fusion/pass/depthwiseconv_eltwise_fusion_pass.h"
+#include "pre_activate/ascend/buffer_fusion/pass/bnupdate_eltwise_fusion_pass.h"
+#include "pre_activate/ascend/buffer_fusion/pass/fusion_type_fusion_pass.h"
 #include "pre_activate/ascend/format_type/deal_ref_trans_and_cast.h"
 #include "pre_activate/ascend/enhancer/add_memcpy_async.h"
 #include "pre_activate/ascend/format_type/insert_cast_for_runop.h"
@@ -296,5 +300,39 @@ void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kern
     DumpIRProto(kernel_graph, "after_hwopt");
   }
 }
+
+void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  if (!context_ptr->ir_fusion_flag()) {
+    MS_LOG(INFO) << "UBFusion is not enabled, skip";
+    return;
+  }
+  bool save_graphs = context_ptr->save_graphs_flag();
+  auto save_graphs_path = context_ptr->save_graphs_path();
+  if (save_graphs_path.empty()) {
+    save_graphs_path = ".";
+  }
+  if (save_graphs) {
+    std::string file_path = save_graphs_path + "/" + "hwopt_d_ub_fusion_before.ir";
+    DumpIR(file_path, kernel_graph);
+  }
+  auto fusion_id_allocator = std::make_shared<FusionIdAllocator>();
+  MS_EXCEPTION_IF_NULL(fusion_id_allocator);
+  fusion_id_allocator->Init();
+  auto optimizer = std::make_shared<GraphOptimizer>();
+  auto ub_fusion_pm = std::make_shared<PassManager>("ub_fusion_pm");
+  ub_fusion_pm->AddPass(std::make_shared<BnupdateEltwiseFusionPass>(fusion_id_allocator.get()));
+  ub_fusion_pm->AddPass(std::make_shared<DepthwiseConvEltwiseFusionPass>(fusion_id_allocator.get()));
+  ub_fusion_pm->AddPass(std::make_shared<FusionTypeFusionPass>(fusion_id_allocator.get()));
+  ub_fusion_pm->AddPass(std::make_shared<TbeBufferFusion>());
+  optimizer->AddPassManager(ub_fusion_pm);
+  (void)optimizer->Optimize(kernel_graph);
+  kernel_graph->SetExecOrderByDefault();
+  if (save_graphs) {
+    std::string file_path = save_graphs_path + "/" + "hwopt_d_ub_fusion_after.ir";
+    DumpIR(file_path, kernel_graph);
+  }
+}
 } // namespace opt
 } // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.h b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.h
index 65e70def85..914b4c053a 100644
--- a/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.h
+++ b/mindspore/ccsrc/pre_activate/ascend/ascend_backend_optimization.h
@@ -26,6 +26,7 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
 void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendBackendOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 void AscendBackendIRFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
+void AscendBackendUBFusionOptimization(const std::shared_ptr<session::KernelGraph> &kernel_graph);
 } // namespace opt
 } // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc
index 694ee4d236..e44e3dff5d 100644
--- a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc
+++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/buffer_fusion.cc
@@ -616,6 +616,9 @@ void BufferFusion::MatchOpNamePattern(const session::KernelGraph &kernel_graph,
         MatchBnupdateAddRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
       } else if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) {
         MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
+      } else if (eltwise_input->isa<CNode>() &&
+                 AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimDepthwiseConv2dNative)) {
+        MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true);
       }
     }
   } else if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimDepthwiseConv2dNative->name()) {
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/bnupdate_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/bnupdate_eltwise_fusion_pass.cc
new file mode 100644
index 0000000000..ade115acde
--- /dev/null
+++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/bnupdate_eltwise_fusion_pass.cc
@@ -0,0 +1,96 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "pre_activate/ascend/buffer_fusion/pass/bnupdate_eltwise_fusion_pass.h"
+#include <vector>
+#include <unordered_set>
+#include <memory>
+#include <string>
+#include "kernel/kernel_fusion.h"
+#include "debug/anf_ir_dump.h"
+#include "session/anf_runtime_algorithm.h"
+#include "operator/ops.h"
+#include "utils/context/ms_context.h"
+#include "pre_activate/common/fusion_id_allocator.h"
+
+namespace mindspore {
+namespace opt {
+void BnupdateEltwiseFusionPass::MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input,
+                                                  const session::KernelGraph &kernel_graph,
+                                                  FusedNodeRecord *candidate_fusion) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_EXCEPTION_IF_NULL(candidate_fusion);
+  auto manager = kernel_graph.manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto getitem = relu_input->cast<CNodePtr>();
+  auto bnupdate = getitem->input(1);
+  if (bnupdate->isa<CNode>() && AnfAlgo::GetCNodeName(bnupdate) == kBNTrainingUpdateOpName) {
+    std::vector<int> output_used_num(AnfAlgo::GetOutputTensorNum(bnupdate), 0);
+    for (auto out_getitem : manager->node_users()[bnupdate]) {
+      auto out_getitem_ptr = out_getitem.first->cast<CNodePtr>();
+      auto input2 = out_getitem_ptr->input(2);
+      auto output_idx = GetValue<int>(GetValueNode(input2));
+      output_used_num[output_idx] = SizeToInt(manager->node_users()[out_getitem.first].size());
+    }
+    AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), bnupdate);
+    std::unordered_set<AnfNodePtr> record{cnode, bnupdate};
+    candidate_fusion->push_back(record);
+    SetRecordFusionId(record);
+  }
+}
+
+void BnupdateEltwiseFusionPass::MatchBnupdateOpNamePattern(const session::KernelGraph &kernel_graph,
+                                                           FusedNodeRecord *candidate_fusion) {
+  MS_EXCEPTION_IF_NULL(candidate_fusion);
+  std::vector<AnfNodePtr> node_list = TopoSort(kernel_graph.get_return());
+  for (auto &node : node_list) {
+    if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) ||
+        AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) {
+      continue;
+    }
+    auto cnode = node->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
+        AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
+      auto eltwise_input = cnode->input(1);
+      if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimRelu)) {
+        if (eltwise_input->isa<CNode>() && AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimTupleGetItem)) {
+          MatchBnupdateRelu(cnode, eltwise_input, kernel_graph, candidate_fusion);
+        }
+      }
+    }
+  }
+}
+
+bool BnupdateEltwiseFusionPass::MatchUBFusionPattern(const session::KernelGraph &kernel_graph) {
+  auto manager = kernel_graph.manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  auto return_node = kernel_graph.get_return();
+  MS_EXCEPTION_IF_NULL(return_node);
+  if (return_node->inputs().size() <= 1) {
+    return false;
+  }
+  MS_LOG(DEBUG) << "MatchBufferFusionPattern start...";
+  FusedNodeRecord candidate_fusion;
+
+  MatchBnupdateOpNamePattern(kernel_graph, &candidate_fusion);
+  if (candidate_fusion.empty()) {
+    return false;
+  }
+  MS_LOG(DEBUG) << "MatchBufferFusionPattern Success...";
+  return true;
+}
+} // namespace opt
+} // namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/bnupdate_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/bnupdate_eltwise_fusion_pass.h
new file mode 100644
index 0000000000..ab112d004a
--- /dev/null
+++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/bnupdate_eltwise_fusion_pass.h
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_
+
+#include <unordered_set>
+#include <vector>
+
+#include "pre_activate/ascend/buffer_fusion/pass/fusion_base_pass.h"
+#include "ir/anf.h"
+#include "pre_activate/common/pass.h"
+#include "pre_activate/common/fusion_id_allocator.h"
+#include "device/kernel_info.h"
+#include "kernel/kernel.h"
+#include "session/kernel_graph.h"
+
+namespace mindspore {
+namespace opt {
+using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
+
+class BnupdateEltwiseFusionPass : public FusionBasePass {
+ public:
+  BnupdateEltwiseFusionPass() : FusionBasePass("BnupdateEltwiseFusionPass") {}
+  explicit BnupdateEltwiseFusionPass(FusionIdAllocator *idAllocator)
+      : FusionBasePass("BnupdateEltwiseFusionPass", idAllocator) {}
+  ~BnupdateEltwiseFusionPass() override = default;
+  bool MatchUBFusionPattern(const session::KernelGraph &kernel_graph) override;
+
+ private:
+  void MatchBnupdateRelu(const CNodePtr &cnode, const AnfNodePtr &relu_input, const session::KernelGraph &kernel_graph,
+                         FusedNodeRecord *candidate_fusion);
+  void MatchBnupdateOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
+};
+} // namespace opt
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_BNUPDATE_ELTWISE_FUSION_PASS_H_
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/depthwiseconv_eltwise_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/depthwiseconv_eltwise_fusion_pass.cc
new file mode 100644
index 0000000000..124b61bf1a
--- /dev/null
+++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/depthwiseconv_eltwise_fusion_pass.cc
@@ -0,0 +1,107 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "pre_activate/ascend/buffer_fusion/pass/depthwiseconv_eltwise_fusion_pass.h" + +#include +#include +#include +#include +#include "kernel/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "session/anf_runtime_algorithm.h" +#include "operator/ops.h" +#include "utils/context/ms_context.h" +#include "pre_activate/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void DepthwiseConvEltwiseFusionPass::MatchDepthwiseConvRelu(const CNodePtr &cnode, + const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion, bool is_order) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(candidate_fusion); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + if (is_order) { + // DepthwiseConvolution--->Elemwise + auto depthwise_conv = cnode->input(1); + MS_EXCEPTION_IF_NULL(depthwise_conv); + if (cnode->isa() && IsPrimitiveCNode(depthwise_conv, prim::kPrimDepthwiseConv2dNative)) { + std::vector output_used_num{SizeToInt(manager->node_users()[depthwise_conv].size())}; + AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), depthwise_conv); + std::unordered_set record{cnode, depthwise_conv}; + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } + } else { + // Elemwise-->DepthwiseConvolution + auto relu = cnode->input(1); + MS_EXCEPTION_IF_NULL(relu); + if (cnode->isa() && (IsPrimitiveCNode(relu, prim::kPrimRelu) || IsPrimitiveCNode(relu, prim::kPrimReluV2))) { + std::vector output_used_num{SizeToInt(manager->node_users()[relu].size())}; + AnfAlgo::SetNodeAttr(kAttrOutputUsedNum, MakeValue(output_used_num), relu); + std::unordered_set record{cnode, relu}; + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } + } +} + +void DepthwiseConvEltwiseFusionPass::MatchDepthwiseOpNamePattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + MS_EXCEPTION_IF_NULL(candidate_fusion); + std::vector node_list = TopoSort(kernel_graph.get_return()); + for (auto &node : node_list) { + if (!AnfAlgo::IsRealCNodeKernel(node) || fusion_id_allocator->HasFusionIdAttr(node) || + AnfAlgo::CheckPrimitiveType(node, prim::kPrimReturn)) { + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) { + auto eltwise_input = cnode->input(1); + if (AnfAlgo::GetCNodeName(cnode) == kReluV2OpName || AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimRelu)) { + if (eltwise_input->isa() && + AnfAlgo::CheckPrimitiveType(eltwise_input, prim::kPrimDepthwiseConv2dNative)) { + MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, true); + } + } + } else if (AnfAlgo::GetCNodeName(cnode) == prim::kPrimDepthwiseConv2dNative->name()) { + MatchDepthwiseConvRelu(cnode, kernel_graph, candidate_fusion, false); + } + } +} + +bool DepthwiseConvEltwiseFusionPass::MatchUBFusionPattern(const session::KernelGraph &kernel_graph) { + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + auto return_node = kernel_graph.get_return(); + MS_EXCEPTION_IF_NULL(return_node); + if (return_node->inputs().size() <= 1) { + return false; + } + MS_LOG(DEBUG) << "MatchBufferFusionPattern start..."; + FusedNodeRecord candidate_fusion; + MatchDepthwiseOpNamePattern(kernel_graph, &candidate_fusion); + if (candidate_fusion.empty()) { + return false; + } + MS_LOG(DEBUG) << "MatchBufferFusionPattern Success..."; + return true; +} +} // namespace opt +} 
// namespace mindspore
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/depthwiseconv_eltwise_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/depthwiseconv_eltwise_fusion_pass.h
new file mode 100644
index 0000000000..66ede52d04
--- /dev/null
+++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/depthwiseconv_eltwise_fusion_pass.h
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_
+
+#include <unordered_set>
+#include <vector>
+
+#include "pre_activate/ascend/buffer_fusion/pass/fusion_base_pass.h"
+#include "ir/anf.h"
+#include "pre_activate/common/pass.h"
+#include "pre_activate/common/fusion_id_allocator.h"
+#include "device/kernel_info.h"
+#include "kernel/kernel.h"
+#include "session/kernel_graph.h"
+
+namespace mindspore {
+namespace opt {
+using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
+
+class DepthwiseConvEltwiseFusionPass : public FusionBasePass {
+ public:
+  DepthwiseConvEltwiseFusionPass() : FusionBasePass("DepthwiseConvEltwiseFusionPass") {}
+  explicit DepthwiseConvEltwiseFusionPass(FusionIdAllocator *idAllocator)
+      : FusionBasePass("DepthwiseConvEltwiseFusionPass", idAllocator) {}
+  ~DepthwiseConvEltwiseFusionPass() override = default;
+  bool MatchUBFusionPattern(const session::KernelGraph &kernel_graph) override;
+
+ private:
+  void MatchDepthwiseConvRelu(const CNodePtr &cnode, const session::KernelGraph &kernel_graph,
+                              FusedNodeRecord *candidate_fusion, bool is_order);
+  void MatchDepthwiseOpNamePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion);
+};
+} // namespace opt
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_DEPTHWISECONV_ELTWISE_FUSION_PASS_H_
diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/fusion_base_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/fusion_base_pass.cc
new file mode 100644
index 0000000000..e993bdeb78
--- /dev/null
+++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/fusion_base_pass.cc
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2019 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +#include "pre_activate/ascend/buffer_fusion/pass/fusion_base_pass.h" +#include +#include +#include "debug/anf_ir_dump.h" +#include "utils/context/ms_context.h" +#include "pre_activate/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +void FusionBasePass::SetRecordFusionId(const std::unordered_set &record) { + auto id = fusion_id_allocator->AllocateFusionId(); + for (auto node : record) { + fusion_id_allocator->SetFusionId(node, id); + } +} +bool FusionBasePass::Run(const FuncGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + auto kernel_graph = graph->cast>(); + MS_EXCEPTION_IF_NULL(kernel_graph); + return MatchUBFusionPattern(*kernel_graph); +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/fusion_base_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/fusion_base_pass.h new file mode 100644 index 0000000000..4a6161ca08 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/fusion_base_pass.h @@ -0,0 +1,50 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_ +#include +#include +#include +#include + +#include "ir/anf.h" +#include "pre_activate/common/pass.h" +#include "pre_activate/common/fusion_id_allocator.h" +#include "device/kernel_info.h" +#include "kernel/kernel.h" +#include "session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class FusionBasePass : public Pass { + public: + explicit FusionBasePass(const std::string &name) : Pass(name) {} + FusionBasePass(const std::string &name, FusionIdAllocator *idAllocator) + : Pass(name), fusion_id_allocator(idAllocator) {} + ~FusionBasePass() override = default; + bool Run(const FuncGraphPtr &graph) override; + + protected: + virtual bool MatchUBFusionPattern(const session::KernelGraph &kernel_graph) = 0; + void SetRecordFusionId(const std::unordered_set &record); + FusionIdAllocator *fusion_id_allocator; +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_BASE_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/fusion_type_fusion_pass.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/fusion_type_fusion_pass.cc new file mode 100644 index 0000000000..4ea98e580b --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/fusion_type_fusion_pass.cc @@ -0,0 +1,245 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "pre_activate/ascend/buffer_fusion/pass/fusion_type_fusion_pass.h" + +#include +#include +#include +#include +#include +#include + +#include "kernel/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "session/anf_runtime_algorithm.h" +#include "utils/context/ms_context.h" +#include "pre_activate/common/fusion_id_allocator.h" + +namespace mindspore { +namespace opt { +namespace { +const int8_t MAX_PATTERN_SIZE = 7; +const int8_t MIN_PATTERN_SIZE = 2; +const int8_t ELTWISE_INPUT_SIZE = 2; +const int8_t ELTWISE_USE = 1; +const int8_t MULTI_ELTWISE_USE = 2; +const int8_t MAX_MULTI_ELTWISE_SIZE = 4; +const int8_t MAX_PURE_BUFFER_SUCC_SIZE = 3; +constexpr auto kOpAttrFusionId = "fusion_id"; + +bool CheckEltWiseNode(FuncGraphManager *manager, std::unordered_set *record, const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(manager); + MS_EXCEPTION_IF_NULL(record); + auto user_nodes = manager->node_users()[node]; + return (AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL && + AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && + (user_nodes.size() <= ELTWISE_USE || record->size() == 0)); +} + +// Common method to check for predecessors and successors in a fusion pattern +std::tuple FindPredAndSuccEltWiseNodes(const int8_t &max_size, FuncGraphManager *manager, + std::unordered_set *visited_set, + std::deque *todo, + std::unordered_set *record, const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(manager); + MS_EXCEPTION_IF_NULL(visited_set); + MS_EXCEPTION_IF_NULL(todo); + MS_EXCEPTION_IF_NULL(record); + MS_EXCEPTION_IF_NULL(node); + + CNodePtr new_node = node; + if (new_node->inputs().size() < ELTWISE_INPUT_SIZE) { + return std::make_tuple(false, new_node); + } + int8_t index = 1; + auto &users = manager->node_users(); + while (CheckEltWiseNode(manager, record, new_node)) { + (void)record->insert(new_node); + (void)visited_set->insert(new_node); + (void)todo->insert(todo->end(), new_node->inputs().begin() + 1, new_node->inputs().end()); + + auto cnode = new_node->input(1); + MS_EXCEPTION_IF_NULL(cnode); + if (!cnode->isa()) { + return std::make_tuple(false, new_node); + } + new_node = cnode->cast(); + MS_EXCEPTION_IF_NULL(new_node); + + if (!AnfAlgo::IsRealKernel(new_node) || new_node->inputs().size() < ELTWISE_INPUT_SIZE || + users[(new_node)].size() >= MULTI_ELTWISE_USE || visited_set->find(new_node) != visited_set->end()) { + return std::make_tuple(false, new_node); + } + + if (index >= max_size) { + break; + } + index++; + } + return std::make_tuple(true, new_node); +} + +std::tuple MatchGeneralPattern(FuncGraphManager *manager, std::unordered_set *record, + std::unordered_set *visited_set, + std::deque *todo, const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(manager); + MS_EXCEPTION_IF_NULL(record); + MS_EXCEPTION_IF_NULL(visited_set); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(todo); + CNodePtr new_node = node; + auto &users = manager->node_users(); + if (users[(new_node)].size() >= MULTI_ELTWISE_USE) { + return std::make_tuple(false, new_node); + } + + (void)record->insert(node); + (void)visited_set->insert(node); + 
(void)todo->insert(todo->end(), new_node->inputs().begin() + 1, new_node->inputs().end()); + + if (node->inputs().size() < 2) { + return std::make_tuple(false, new_node); + } + // only check the first real input, will check all + auto cnode = node->input(1); + MS_EXCEPTION_IF_NULL(cnode); + if (!cnode->isa()) { + return std::make_tuple(false, new_node); + } + new_node = cnode->cast(); + MS_EXCEPTION_IF_NULL(new_node); + + if (!AnfAlgo::IsRealKernel(new_node) || users[(new_node)].size() >= MULTI_ELTWISE_USE || + visited_set->find(new_node) != visited_set->end()) { + return std::make_tuple(false, new_node); + } + return std::make_tuple(true, new_node); +} + +CNodePtr FindFusionAnfNode(FuncGraphManager *manager, std::unordered_set *visited_set, + std::unordered_set *record, std::deque *todo, const CNodePtr &node) { + MS_EXCEPTION_IF_NULL(manager); + MS_EXCEPTION_IF_NULL(visited_set); + MS_EXCEPTION_IF_NULL(record); + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(todo); + // find fusion pattern predecessor nodes + auto ret = FindPredAndSuccEltWiseNodes(MAX_MULTI_ELTWISE_SIZE, manager, visited_set, todo, record, node); + auto new_node = std::get<1>(ret); + auto node_use_size = manager->node_users()[new_node].size(); + if (!std::get<0>(ret) || (record->size() > 1 && node_use_size > 1) || record->size() >= MAX_MULTI_ELTWISE_SIZE || + AnfAlgo::GetKernelType(new_node) != KernelType::TBE_KERNEL) { + return new_node; + } + + // key of fusion precessor + auto node_fusion_type = AnfAlgo::GetFusionType(new_node); + switch (node_fusion_type) { + case kernel::FusionType::COMMREDUCE: + case kernel::FusionType::SEGMENT: + ret = MatchGeneralPattern(manager, record, visited_set, todo, new_node); + new_node = std::get<1>(ret); + if (!std::get<0>(ret)) { + return new_node; + } + break; + case kernel::FusionType::ELEMWISE: + return new_node; + // -fallthrough to default and return + case kernel::FusionType::CONVLUTION: + (void)record->insert(new_node); + default: + (void)visited_set->insert(new_node); + if (new_node != nullptr) { + (void)todo->insert(todo->end(), new_node->inputs().begin() + 1, new_node->inputs().end()); + } + return new_node; + } + // find fusion pattern successor nodes + ret = FindPredAndSuccEltWiseNodes(MAX_PURE_BUFFER_SUCC_SIZE, manager, visited_set, todo, record, new_node); + return std::get<1>(ret); +} +} // namespace + +void FusionTypeFusionPass::MatchFusionTypePattern(const session::KernelGraph &kernel_graph, + FusedNodeRecord *candidate_fusion) { + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + MS_EXCEPTION_IF_NULL(candidate_fusion); + + auto return_node = kernel_graph.get_return(); + MS_EXCEPTION_IF_NULL(return_node); + if (return_node->inputs().size() <= 1) { + return; + } + std::deque todo; + todo.push_back(return_node->input(1)); + std::unordered_set visited_set; + + while (!todo.empty()) { + auto node = todo.front(); + MS_EXCEPTION_IF_NULL(node); + todo.pop_front(); + std::unordered_set record; + if (visited_set.find(node) != visited_set.end() || fusion_id_allocator->HasFusionIdAttr(node)) { + continue; + } + // Only fuse real cnode + if (!AnfAlgo::IsRealCNodeKernel(node)) { + auto cnode = node->cast(); + if (cnode != nullptr) { + (void)todo.insert(todo.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); + } + continue; + } + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + // cnode maybe updated + cnode = FindFusionAnfNode(manager.get(), &visited_set, &record, &todo, cnode); + if (record.size() >= MIN_PATTERN_SIZE && record.size() 
<= MAX_PATTERN_SIZE) { + candidate_fusion->push_back(record); + SetRecordFusionId(record); + } + if (record.find(cnode) == record.end()) { + todo.push_back(cnode); + } + // no node matched + if (record.size() == 0) { + (void)visited_set.insert(node); + } + (void)todo.insert(todo.end(), cnode->inputs().begin() + 1, cnode->inputs().end()); + } +} + +bool FusionTypeFusionPass::MatchUBFusionPattern(const session::KernelGraph &kernel_graph) { + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + auto return_node = kernel_graph.get_return(); + MS_EXCEPTION_IF_NULL(return_node); + if (return_node->inputs().size() <= 1) { + return false; + } + MS_LOG(DEBUG) << "MatchBufferFusionPattern start..."; + FusedNodeRecord candidate_fusion; + MatchFusionTypePattern(kernel_graph, &candidate_fusion); + if (candidate_fusion.empty()) { + return false; + } + MS_LOG(DEBUG) << "MatchBufferFusionPattern Success..."; + return true; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/fusion_type_fusion_pass.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/fusion_type_fusion_pass.h new file mode 100644 index 0000000000..2e008633b4 --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/pass/fusion_type_fusion_pass.h @@ -0,0 +1,47 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_TYPE_FUSION_PASS_H_ +#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_TYPE_FUSION_PASS_H_ +#include +#include +#include + +#include "pre_activate/ascend/buffer_fusion/pass/fusion_base_pass.h" +#include "ir/anf.h" +#include "pre_activate/common/pass.h" +#include "pre_activate/common/fusion_id_allocator.h" +#include "device/kernel_info.h" +#include "kernel/kernel.h" +#include "session/kernel_graph.h" + +namespace mindspore { +namespace opt { +using FusedNodeRecord = std::vector>; + +class FusionTypeFusionPass : public FusionBasePass { + public: + FusionTypeFusionPass() : FusionBasePass("FusionTypeFusionPass") {} + explicit FusionTypeFusionPass(FusionIdAllocator *idAllocator) : FusionBasePass("FusionTypeFusionPass", idAllocator) {} + ~FusionTypeFusionPass() override = default; + bool MatchUBFusionPattern(const session::KernelGraph &kernel_graph) override; + + private: + void MatchFusionTypePattern(const session::KernelGraph &kernel_graph, FusedNodeRecord *candidate_fusion); +}; +} // namespace opt +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_PASS_FUSION_TYPE_FUSION_PASS_H_ diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/tbe_buffer_fusion.cc b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/tbe_buffer_fusion.cc new file mode 100644 index 0000000000..7eb9ac7bdf --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/tbe_buffer_fusion.cc @@ -0,0 +1,435 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#include "pre_activate/ascend/buffer_fusion/tbe_buffer_fusion.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "kernel/kernel_fusion.h" +#include "debug/anf_ir_dump.h" +#include "session/anf_runtime_algorithm.h" +#include "operator/ops.h" +#include "device/kernel_info.h" +#include "utils/context/ms_context.h" + +namespace mindspore { +namespace opt { +namespace { +const int8_t MAX_PATTERN_SIZE = 7; +const int8_t MIN_PATTERN_SIZE = 2; +const int8_t ELTWISE_INPUT_SIZE = 2; +const int8_t ELTWISE_USE = 1; +const int8_t MULTI_ELTWISE_USE = 2; +const int8_t MAX_MULTI_ELTWISE_SIZE = 4; +const int8_t MAX_PURE_BUFFER_SUCC_SIZE = 3; +constexpr auto kOpAttrFusionId = "fusion_id"; + +#ifdef DEBUG +std::string GetFusionTypeName(const kernel::FusionType &type) { + switch (type) { + case kernel::FusionType::COMMREDUCE: + return "COMMREDUCE"; + case kernel::FusionType::SEGMENT: + return "SEGMENT"; + case kernel::FusionType::ELEMWISE: + return "ELEMWISE"; + case kernel::FusionType::CONVLUTION: + return "CONVLUTION"; + case kernel::FusionType::OPAQUE: + return "OPAQUE"; + default: + return "OPAQUE"; + } +} + +void DumpFusionScopeInfo(const kernel::FusionScopeInfo &info) { + MS_LOG(INFO) << "=== Dump FusionScopeInfo start id: " << info.scope_id; + for (auto &node : info.input_nodes) { + MS_LOG(INFO) << "=== Input: " << node->DebugString(); + } + for (auto &node : info.output_nodes) { + MS_LOG(INFO) << "=== Output: " << node->DebugString(); + } + for (auto &node : info.compute_nodes) { + MS_LOG(INFO) << "=== Compute: (" << node->DebugString() << ")-(" << GetFusionTypeName(AnfAlgo::GetFusionType(node)) + << ")"; + } + MS_LOG(INFO) << "=== Dump FusionScopeInfo end"; +} +#endif +CNodePtr CreateFusionOp(const std::vector &inputs_list, const std::vector &outputs_list, + const std::vector &anf_nodes, session::KernelGraph *kernel_graph) { + MS_LOG(DEBUG) << "Start Create FusionOp Kernel"; + MS_EXCEPTION_IF_NULL(kernel_graph); + std::string fusion_op_name = "FusionOp"; + for (auto node : anf_nodes) { + fusion_op_name += '_' + AnfAlgo::GetCNodeName(node); + } + auto fusion_op = std::make_shared(fusion_op_name); + MS_EXCEPTION_IF_NULL(fusion_op); + + std::vector input_names; + for (uint8_t i = 0; i < inputs_list.size(); i++) { + input_names.emplace_back("input" + std::to_string(i)); + } + std::vector output_names; + for (uint8_t i = 0; i < outputs_list.size(); i++) { + output_names.emplace_back("output" + std::to_string(i)); + } + + ValuePtr input_names_v = MakeValue(input_names); + ValuePtr output_names_v = MakeValue(output_names); + fusion_op->set_attr("input_names", input_names_v); + fusion_op->set_attr("output_names", output_names_v); + std::vector fusion_inputs_list = inputs_list; + auto value_node = std::make_shared(fusion_op); + (void)fusion_inputs_list.insert(fusion_inputs_list.begin(), value_node); + auto buffer_fusion_kernel = kernel_graph->NewCNode(fusion_inputs_list); + if (buffer_fusion_kernel == nullptr) { + MS_LOG(EXCEPTION) << "New FusionOp kernel failed!"; + } + buffer_fusion_kernel->set_scope((anf_nodes.back())->scope()); + + return buffer_fusion_kernel; +} + +kernel::KernelBuildInfoPtr CreateFusionOpKernelInfo(const std::vector &inputs_list, + const std::vector &outputs_list) { + MS_LOG(DEBUG) << "Start Create Kernel Info"; + kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; + // inputs format and data type + std::vector inputs_format; + std::vector inputs_data_type; + for (const auto &input : inputs_list) { + auto real_input = 
AnfAlgo::VisitKernel(input, 0); + inputs_format.push_back(AnfAlgo::GetOutputFormat(real_input.first, real_input.second)); + inputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(real_input.first, real_input.second)); + } + // outputs format and data type + std::vector outputs_format; + std::vector outputs_data_type; + for (const auto &output : outputs_list) { + if (AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) { + auto tuple_getitem = output->cast(); + MS_EXCEPTION_IF_NULL(tuple_getitem); + outputs_format.push_back(AnfAlgo::GetOutputFormat( + tuple_getitem->input(1), IntToSize(GetValue(GetValueNode(tuple_getitem->input(2)))))); + outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType( + tuple_getitem->input(1), IntToSize(GetValue(GetValueNode(tuple_getitem->input(2)))))); + } else { + outputs_format.push_back(AnfAlgo::GetOutputFormat(output, 0)); + outputs_data_type.push_back(AnfAlgo::GetOutputDeviceDataType(output, 0)); + } + } + builder.SetInputsFormat(inputs_format); + builder.SetInputsDeviceType(inputs_data_type); + builder.SetOutputsFormat(outputs_format); + builder.SetOutputsDeviceType(outputs_data_type); + builder.SetKernelType(KernelType::TBE_KERNEL); + return builder.Build(); +} + +AnfNodePtr CreateTupleGetItem(const AnfNodePtr &buffer_fusion_kernel, session::KernelGraph *kernel_graph, + size_t output_index) { + MS_EXCEPTION_IF_NULL(kernel_graph); + std::vector tuple_getitem_inputs_list; + auto value = std::make_shared(prim::kPrimTupleGetItem); + MS_EXCEPTION_IF_NULL(value); + auto idx = NewValueNode(SizeToInt(output_index)); + MS_EXCEPTION_IF_NULL(idx); + int temp = SizeToInt(output_index); + auto imm = std::make_shared(temp); + auto abstract_scalar = std::make_shared(imm); + idx->set_abstract(abstract_scalar); + tuple_getitem_inputs_list.push_back(value); + tuple_getitem_inputs_list.push_back(buffer_fusion_kernel); + tuple_getitem_inputs_list.push_back(idx); + auto tuple_item = kernel_graph->NewCNode(tuple_getitem_inputs_list); + MS_EXCEPTION_IF_NULL(tuple_item); + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(buffer_fusion_kernel, output_index)}, + {AnfAlgo::GetOutputInferShape(buffer_fusion_kernel, output_index)}, + tuple_item.get()); + return tuple_item; +} + +void ReplaceInputNodeInOtherFusionScope(std::unordered_map *buffer_fusion_infos, + int32_t fusion_id, const AnfNodePtr &output_item, + const AnfNodePtr &replace_item) { + for (int32_t id = fusion_id + 1; id <= SizeToInt(buffer_fusion_infos->size()); ++id) { + auto itr = std::find((*buffer_fusion_infos)[id].inputs_list.begin(), (*buffer_fusion_infos)[id].inputs_list.end(), + output_item); + if (itr != (*buffer_fusion_infos)[id].inputs_list.end()) { + MS_LOG(DEBUG) << "replace input of other pattern, id = " << id; + *itr = replace_item; + } + } +} + +void ReplaceOldNode(std::unordered_map *buffer_fusion_infos, int32_t fusion_id, + const AnfNodePtr &buffer_fusion_kernel, session::KernelGraph *kernel_graph) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto manager = kernel_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id]; + if (buffer_fusion_info.outputs_list.size() == 1) { // single output + (void)manager->Replace(buffer_fusion_info.outputs_list[0], buffer_fusion_kernel); + ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[0], + buffer_fusion_kernel); + } else { // multiple output + for (size_t index = 0; index < buffer_fusion_info.outputs_list.size(); ++index) 
{ + auto tuple_item = CreateTupleGetItem(buffer_fusion_kernel, kernel_graph, index); + (void)manager->Replace(buffer_fusion_info.outputs_list[index], tuple_item); + ReplaceInputNodeInOtherFusionScope(buffer_fusion_infos, fusion_id, buffer_fusion_info.outputs_list[index], + tuple_item); + } + } +} + +void GetFusionScopeComputeNodeList(session::KernelGraph *kernel_graph, + std::unordered_map *buffer_fusion_infos) { + MS_EXCEPTION_IF_NULL(buffer_fusion_infos); + auto nodes = TopoSort(kernel_graph->get_return()); + for (auto &node : nodes) { + MS_EXCEPTION_IF_NULL(node); + if (!node->isa()) { + continue; + } + auto cnode = node->cast(); + if (AnfAlgo::IsRealCNodeKernel(cnode) && AnfAlgo::HasNodeAttr(kOpAttrFusionId, cnode)) { + auto fusion_id = AnfAlgo::GetNodeAttr(cnode, kOpAttrFusionId); + (*buffer_fusion_infos)[fusion_id].anf_nodes.push_back(cnode); + } + } +} + +void GetFusionScopeInputNodeList(const session::KernelGraph &kernel_graph, + std::unordered_map *buffer_fusion_infos) { + MS_EXCEPTION_IF_NULL(buffer_fusion_infos); + auto manager = kernel_graph.manager(); + MS_EXCEPTION_IF_NULL(manager); + + for (auto &buffer_fusion_info : *buffer_fusion_infos) { + auto fusion_id = buffer_fusion_info.first; + auto fusion_info = buffer_fusion_info.second; + for (const auto &node : fusion_info.anf_nodes) { + auto cnode = node->cast(); + for (size_t idx = 1; idx < cnode->inputs().size(); ++idx) { + auto real_input = AnfAlgo::VisitKernel(cnode->input(idx), 0); + if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), real_input.first) == + fusion_info.anf_nodes.end()) { + if (std::find((*buffer_fusion_infos)[fusion_id].inputs_list.begin(), + (*buffer_fusion_infos)[fusion_id].inputs_list.end(), + cnode->input(idx)) == (*buffer_fusion_infos)[fusion_id].inputs_list.end()) { + (*buffer_fusion_infos)[fusion_id].inputs_list.push_back(cnode->input(idx)); + } + } + } + } + } +} + +bool TupleGetitemNodeCompare(const AnfNodePtr &node1, const AnfNodePtr &node2) { + MS_EXCEPTION_IF_NULL(node1); + MS_EXCEPTION_IF_NULL(node2); + auto getitem1 = node1->cast(); + auto getitem2 = node2->cast(); + MS_EXCEPTION_IF_NULL(getitem1); + MS_EXCEPTION_IF_NULL(getitem2); + auto output_idx1 = GetValue(GetValueNode(getitem1->input(2))); + auto output_idx2 = GetValue(GetValueNode(getitem2->input(2))); + return output_idx1 < output_idx2; +} + +void GetFusionScopeOutputNodeList(session::KernelGraph *kernel_graph, + std::unordered_map *buffer_fusion_infos) { + MS_EXCEPTION_IF_NULL(kernel_graph); + MS_EXCEPTION_IF_NULL(buffer_fusion_infos); + auto manager = kernel_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + + for (auto &buffer_fusion_info : *buffer_fusion_infos) { + auto fusion_id = buffer_fusion_info.first; + auto fusion_info = buffer_fusion_info.second; + for (const auto &node : fusion_info.anf_nodes) { + if (AnfAlgo::GetOutputTensorNum(node) == 1) { + for (auto use_node : manager->node_users()[node]) { + if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), use_node.first) == + fusion_info.anf_nodes.end()) { + (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(node); + break; + } + } + } else { + int prev_idx = 0; + std::vector tuple_getitem_nodes; + std::transform(manager->node_users()[node].begin(), manager->node_users()[node].end(), + std::back_inserter(tuple_getitem_nodes), + [](const std::pair &use_node) { return use_node.first; }); + std::sort(tuple_getitem_nodes.begin(), tuple_getitem_nodes.end(), TupleGetitemNodeCompare); + for (auto getitem : tuple_getitem_nodes) { + 
auto getitem_ptr = getitem->cast(); + auto input2 = getitem_ptr->input(2); + auto output_idx = GetValue(GetValueNode(input2)); + for (int stub_idx = prev_idx; stub_idx < output_idx; ++stub_idx) { + auto stub_node = CreateTupleGetItem(node, kernel_graph, IntToSize(stub_idx)); + (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(stub_node); + } + prev_idx = output_idx + 1; + for (auto item_use_node : manager->node_users()[getitem]) { + if (std::find(fusion_info.anf_nodes.begin(), fusion_info.anf_nodes.end(), item_use_node.first) == + fusion_info.anf_nodes.end()) { + (*buffer_fusion_infos)[fusion_id].outputs_list.push_back(getitem); + break; + } + } + } + } + } + } +} + +void SetFusionOpRefInfos(session::KernelGraph *kernel_graph, const std::vector &outputs_list, + const AnfNodePtr &fusion_kernel) { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto manager = kernel_graph->manager(); + MS_EXCEPTION_IF_NULL(manager); + for (size_t idx = 0; idx < outputs_list.size(); ++idx) { + auto output = outputs_list[idx]; + if (output->isa() && AnfAlgo::GetCNodeName(output) == prim::kPrimTupleGetItem->name()) { + auto real_output = AnfAlgo::VisitKernel(output, 0); + auto output_cnode = output->cast(); + MS_EXCEPTION_IF_NULL(output_cnode); + auto input2 = output_cnode->input(2); + auto output_idx = GetValue(GetValueNode(input2)); + session::AnfWithOutIndex out_pair(real_output.first, output_idx); + if (kernel_graph->IsInRefOutputMap(out_pair)) { + auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair); + session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx); + kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair); + } + } else { + session::AnfWithOutIndex out_pair(output, 0); + if (kernel_graph->IsInRefOutputMap(out_pair)) { + auto origin_pair = kernel_graph->GetRefCorrespondOutput(out_pair); + session::AnfWithOutIndex fusion_final_pair(fusion_kernel, idx); + kernel_graph->AddRefCorrespondPairs(fusion_final_pair, origin_pair); + } + } + } +} +} // namespace + +void TbeBufferFusion::GetBufferFusionInfo(session::KernelGraph *kernel_graph, + std::unordered_map *buffer_fusion_infos) const { + MS_EXCEPTION_IF_NULL(buffer_fusion_infos); + GetFusionScopeComputeNodeList(kernel_graph, buffer_fusion_infos); + GetFusionScopeInputNodeList(*kernel_graph, buffer_fusion_infos); + GetFusionScopeOutputNodeList(kernel_graph, buffer_fusion_infos); + for (auto &buffer_fusion_info : *buffer_fusion_infos) { + buffer_fusion_info.second.kernel_build_info = + CreateFusionOpKernelInfo(buffer_fusion_info.second.inputs_list, buffer_fusion_info.second.outputs_list); + } +} + +bool TbeBufferFusion::FuseBufferFusionPattern(session::KernelGraph *kernel_graph) const { + MS_EXCEPTION_IF_NULL(kernel_graph); + bool change = false; + std::unordered_map buffer_fusion_infos; + buffer_fusion_infos.clear(); + GetBufferFusionInfo(kernel_graph, &buffer_fusion_infos); + + std::vector fusion_scope_infos; + for (auto &buffer_fusion_info : buffer_fusion_infos) { + mindspore::kernel::FusionScopeInfo fusion_scope_info; + fusion_scope_info.scope_id = buffer_fusion_info.first; + fusion_scope_info.input_nodes = buffer_fusion_info.second.inputs_list; + fusion_scope_info.compute_nodes = buffer_fusion_info.second.anf_nodes; + fusion_scope_info.output_nodes = buffer_fusion_info.second.outputs_list; + fusion_scope_infos.push_back(fusion_scope_info); +#ifdef DEBUG + DumpFusionScopeInfo(fusion_scope_info); +#endif + } + auto kernel_mods = mindspore::kernel::KernelFusion(fusion_scope_infos); + std::vector fusion_ids; + for (auto 
&buffer_fusion_info : buffer_fusion_infos) { + MS_LOG(DEBUG) << "anf node size: " << buffer_fusion_info.second.anf_nodes.size() + << ", inputs_list size: " << buffer_fusion_info.second.inputs_list.size() + << ", outputs list size: " << buffer_fusion_info.second.outputs_list.size(); + fusion_ids.push_back(buffer_fusion_info.first); + } + // Replace fusion op from return to head + std::sort(fusion_ids.begin(), fusion_ids.end()); + for (auto &fusion_id : fusion_ids) { + // Get kernel mod when supporting tbe + if (kernel_mods.find(fusion_id) == kernel_mods.end() || kernel_mods[fusion_id] == nullptr) { + MS_LOG(DEBUG) << "fusion id: " << fusion_id << ", fusion op compiling failed"; + continue; + } + change = ReplaceFusionOp(&buffer_fusion_infos, fusion_id, kernel_mods[fusion_id], kernel_graph); + } + MS_LOG(DEBUG) << "End Buffer Fusion"; + return change; +} + +bool TbeBufferFusion::ReplaceFusionOp(std::unordered_map *buffer_fusion_infos, + int32_t fusion_id, const kernel::KernelModPtr &kernel_ptr, + session::KernelGraph *kernel_graph) const { + auto buffer_fusion_info = (*buffer_fusion_infos)[fusion_id]; + auto buffer_fusion = CreateFusionOp(buffer_fusion_info.inputs_list, buffer_fusion_info.outputs_list, + buffer_fusion_info.anf_nodes, kernel_graph); + AnfAlgo::SetSelectKernelBuildInfo(buffer_fusion_info.kernel_build_info, buffer_fusion.get()); + // Set abstract of fusion_op node + std::vector types; + std::vector> shapes; + for (const auto &out_node : buffer_fusion_info.outputs_list) { + for (size_t idx = 0; idx < AnfAlgo::GetOutputTensorNum(out_node); ++idx) { + types.push_back(AnfAlgo::GetOutputInferDataType(out_node, idx)); + shapes.push_back(AnfAlgo::GetOutputInferShape(out_node, idx)); + } + } + if (types.empty() || shapes.empty()) { + MS_LOG(WARNING) << "buffer_fusion_info.outputs_list is empty"; + return false; + } + AnfAlgo::SetOutputInferTypeAndShape(types, shapes, buffer_fusion.get()); + AnfAlgo::SetKernelMod(kernel_ptr, buffer_fusion.get()); + SetFusionOpRefInfos(kernel_graph, buffer_fusion_info.outputs_list, buffer_fusion); + ReplaceOldNode(buffer_fusion_infos, fusion_id, buffer_fusion, kernel_graph); + return true; +} + +bool TbeBufferFusion::Run(const FuncGraphPtr &graph) { + bool changed = false; + MS_EXCEPTION_IF_NULL(graph); + auto kernel_graph = graph->cast>(); + MS_EXCEPTION_IF_NULL(kernel_graph); + changed = FuseBufferFusionPattern(kernel_graph.get()); + // clear fusion_id attr + for (auto &node : graph->nodes()) { + if (node != nullptr && node->isa()) { + AnfAlgo::EraseNodeAttr(kAttrFusionId, node); + } + } + return changed; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/tbe_buffer_fusion.h b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/tbe_buffer_fusion.h new file mode 100644 index 0000000000..09f62c3ddd --- /dev/null +++ b/mindspore/ccsrc/pre_activate/ascend/buffer_fusion/tbe_buffer_fusion.h @@ -0,0 +1,50 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_TBE_BUFFER_FUSION_H_
+#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_TBE_BUFFER_FUSION_H_
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "ir/anf.h"
+#include "pre_activate/common/pass.h"
+#include "pre_activate/common/fusion_id_allocator.h"
+#include "device/kernel_info.h"
+#include "kernel/kernel.h"
+#include "session/kernel_graph.h"
+#include "pre_activate/ascend/buffer_fusion/buffer_fusion.h"
+
+namespace mindspore {
+namespace opt {
+using FusedNodeRecord = std::vector<std::unordered_set<AnfNodePtr>>;
+
+class TbeBufferFusion : public Pass {
+ public:
+  TbeBufferFusion() : Pass("TbeBufferFusion") {}
+  ~TbeBufferFusion() override = default;
+  bool Run(const FuncGraphPtr &graph) override;
+
+ private:
+  void GetBufferFusionInfo(session::KernelGraph *kernel_graph,
+                           std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos) const;
+  bool ReplaceFusionOp(std::unordered_map<int32_t, BufferFusionInfo_t> *buffer_fusion_infos, int32_t fusion_id,
+                       const kernel::KernelModPtr &kernel_ptr, session::KernelGraph *kernel_graph) const;
+  bool FuseBufferFusionPattern(session::KernelGraph *kernel_graph) const;
+};
+} // namespace opt
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_BUFFER_FUSION_TBE_BUFFER_FUSION_H_