From 4ff50db64d4d95b676fb8af24a29a0890f1bd2f9 Mon Sep 17 00:00:00 2001 From: alouhahaha Date: Sat, 27 Mar 2021 16:38:00 +0800 Subject: [PATCH] add depend for fused allgather --- .../ascend/ascend_backend_optimization.cc | 2 + .../enhancer/insert_depend_for_all_gather.cc | 70 +++++++++++++++++++ .../enhancer/insert_depend_for_all_gather.h | 44 ++++++++++++ 3 files changed, 116 insertions(+) create mode 100644 mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_depend_for_all_gather.cc create mode 100644 mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_depend_for_all_gather.h diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc index 3908ed00d3..ad2aa247d4 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -114,6 +114,7 @@ #include "backend/optimizer/ascend/ir_fission/concat_fission.h" #include "backend/optimizer/ascend/ir_fission/pack_fission.h" #include "backend/optimizer/ascend/enhancer/concat_outputs_for_all_gather.h" +#include "backend/optimizer/ascend/enhancer/insert_depend_for_all_gather.h" #include "backend/optimizer/ascend/enhancer/split_inputs_for_reduce_scatter.h" #include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_rnn.h" #include "backend/optimizer/ascend/enhancer/add_placeholder_for_dynamic_gru.h" @@ -361,6 +362,7 @@ void AscendBackendOptimization(const std::shared_ptr &kern other_pm->AddPass(std::make_shared()); other_pm->AddPass(std::make_shared()); other_pm->AddPass(std::make_shared()); + other_pm->AddPass(std::make_shared()); other_pm->AddPass(std::make_shared()); other_pm->AddPass(std::make_shared()); other_pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_depend_for_all_gather.cc b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_depend_for_all_gather.cc new file mode 100644 index 0000000000..a313491d70 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_depend_for_all_gather.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "backend/optimizer/ascend/enhancer/insert_depend_for_all_gather.h" +#include +#include +#include "utils/utils.h" +#include "backend/optimizer/ascend/ascend_helper.h" +#include "backend/session/anf_runtime_algorithm.h" + +namespace mindspore { +namespace opt { +bool InsertDependForAllGather::Run(const FuncGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(graph); + bool changed = false; + std::vector node_list = TopoSort(graph->get_return()); + std::map all_gather_node; + for (auto &node : node_list) { + MS_EXCEPTION_IF_NULL(node); + if (!node->cast() || !AnfAlgo::IsRealKernel(node)) { + continue; + } + auto cnode = node->cast(); + if (AnfAlgo::GetCNodeName(cnode) == kAllGatherOpName && AnfAlgo::HasNodeAttr(kAttrFusion, cnode) && + AnfAlgo::GetNodeAttr(cnode, kAttrFusion) > 0) { + all_gather_node[AnfAlgo::GetNodeAttr(cnode, kAttrFusion)] = node; + } + } + std::vector depends = {NewValueNode(prim::kPrimMakeTuple)}; + auto iter = all_gather_node.begin(); + for (int64_t i = 0; i < SizeToInt(all_gather_node.size()) - 1; ++i) { + auto current_node = iter->second; + auto next_node = (++iter)->second; + auto next_cnode = next_node->cast(); + std::vector inputs = {NewValueNode(std::make_shared(prim::kPrimDepend->name())), + AnfAlgo::GetInputNode(next_cnode, 0), current_node}; + auto new_input = graph->NewCNode(inputs); + new_input->set_abstract(AnfAlgo::GetInputNode(next_cnode, 0)->abstract()); + AnfAlgo::SetNodeInput(next_cnode, new_input, 0); + depends.push_back(new_input); + } + if (depends.size() > 1) { + auto make_tuple = graph->NewCNode(depends); + auto return_node = graph->get_return(); + auto return_cnode = return_node->cast(); + std::vector inputs = {NewValueNode(std::make_shared(prim::kPrimDepend->name())), + AnfAlgo::GetInputNode(return_cnode, 0), make_tuple}; + auto depend_node = graph->NewCNode(inputs); + depend_node->set_abstract(AnfAlgo::GetInputNode(return_cnode, 0)->abstract()); + AnfAlgo::SetNodeInput(return_cnode, depend_node, 0); + changed = true; + } + + return changed; +} +} // namespace opt +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_depend_for_all_gather.h b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_depend_for_all_gather.h new file mode 100644 index 0000000000..6f12e9f437 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/enhancer/insert_depend_for_all_gather.h @@ -0,0 +1,44 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_INSERT_DEPEND_FOR_ALL_GATHER_H_ +#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_INSERT_DEPEND_FOR_ALL_GATHER_H_ +#include +#include +#include +#include + +#include "backend/optimizer/common/pass.h" +#include "ir/func_graph.h" +#include "ir/anf.h" +#include "backend/optimizer/common/helper.h" +#include "backend/optimizer/common/optimizer.h" +#include "backend/optimizer/ascend/ascend_helper.h" + +namespace mindspore { +namespace opt { +class InsertDependForAllGather : public Pass { + public: + InsertDependForAllGather() : Pass("insert_depend_for_all_gather"), kernel_select_(std::make_shared()) {} + ~InsertDependForAllGather() override = default; + bool Run(const FuncGraphPtr &graph) override; + + private: + KernelSelectPtr kernel_select_; +}; +} // namespace opt +} // namespace mindspore +#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_INSERT_DEPEND_FOR_ALL_GATHER_H_