diff --git a/ge/graph/passes/atomic_addr_clean_pass.cc b/ge/graph/passes/atomic_addr_clean_pass.cc index 7c6ed8ce..cfee1367 100755 --- a/ge/graph/passes/atomic_addr_clean_pass.cc +++ b/ge/graph/passes/atomic_addr_clean_pass.cc @@ -222,6 +222,39 @@ Status AtomicAddrCleanPass::HandleNormalGraph(ComputeGraphPtr &graph, const vect } } } + return LinkToPotentialPrecedenceNode(graph, clean_addr_node); +} + +// Add control edges from the atomic clean node to all potential precedence nodes which may execute before the atomic +// clean node. We want the atomic clean node to execute with the highest priority in the entire graph. Because of the +// stream concurrency mechanism, placing it at the head alone cannot guarantee that priority. Therefore, we need to add +// control edges from the atomic clean node to the nodes that may be the first node on each stream. Generally, the first +// nodes on each stream are successors of Data/Variable, and Data/Variable won't generate a task or execute, so we link +// to the successors of Data/Variable. 
+Status AtomicAddrCleanPass::LinkToPotentialPrecedenceNode(ComputeGraphPtr &graph, NodePtr &atomic_clean_node) { + GE_CHECK_NOTNULL(graph); /* validate both inputs before they are first dereferenced below */ + GE_CHECK_NOTNULL(atomic_clean_node); + GELOGD("Start to add control edges from %s to all second-nodes behind first-nodes which have no input.", + atomic_clean_node->GetName().c_str()); + auto out_ctrl_anchor = atomic_clean_node->GetOutControlAnchor(); + GE_CHECK_NOTNULL(out_ctrl_anchor); + for (const auto &node : graph->GetDirectNode()) { + GE_CHECK_NOTNULL(node); + bool need_handle = (node->GetType() == DATA || node->GetType() == VARIABLE) && node->GetInAllNodes().empty(); /* a "first-node": Data/Variable with no input node */ + if (!need_handle) { + continue; + } + for (const auto &second_node : node->GetOutAllNodes()) { /* "second-nodes": the successors of a first-node */ + GE_CHECK_NOTNULL(second_node); + auto in_ctrl_anchor = second_node->GetInControlAnchor(); + GE_CHECK_NOTNULL(in_ctrl_anchor); + if (!out_ctrl_anchor->IsLinkedWith(in_ctrl_anchor)) { /* do not add the same control edge twice */ + GE_CHK_STATUS_RET(out_ctrl_anchor->LinkTo(in_ctrl_anchor)); + GELOGD("Add control edge from %s to %s.", atomic_clean_node->GetName().c_str(), second_node->GetName().c_str()); + } + } + } + return SUCCESS; } diff --git a/ge/graph/passes/atomic_addr_clean_pass.h b/ge/graph/passes/atomic_addr_clean_pass.h index 8138d511..96147fa2 100755 --- a/ge/graph/passes/atomic_addr_clean_pass.h +++ b/ge/graph/passes/atomic_addr_clean_pass.h @@ -67,6 +67,14 @@ class AtomicAddrCleanPass : public GraphPass { */ Status LinkToAtomicNode(const NodePtr &atomic_node, NodePtr &atomic_clean_node); + /** + * Link atomic clean node to all potential precedence nodes which may execute before atomic clean node + * @param graph graph whose direct Data/Variable nodes without inputs are scanned for successors + * @param atomic_clean_node node whose out control anchor is linked to the in control anchor of each successor + * @return SUCCESS when all edges are in place; otherwise the status produced by the failing null check or LinkTo + */ + Status LinkToPotentialPrecedenceNode(ComputeGraphPtr &graph, NodePtr &atomic_clean_node); + /** * Check if this node is atomic op. 
* @param node diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index e4b8d8d2..734d6af5 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -606,6 +606,7 @@ set(PASS_TEST_FILES "graph/passes/variable_prepare_pass_unittest.cc" "graph/passes/variable_ref_delete_pass_unittest.cc" "graph/passes/dimension_adjust_pass_unittest.cc" + "graph/passes/atomic_addr_clean_pass_unittest.cc" "graph/passes/pass_utils_unittest.cc" "graph/passes/net_output_pass_unittest.cc" "graph/passes/no_use_reshape_remove_pass_unittest.cc" diff --git a/tests/ut/ge/graph/passes/atomic_addr_clean_pass_unittest.cc b/tests/ut/ge/graph/passes/atomic_addr_clean_pass_unittest.cc new file mode 100644 index 00000000..d9d663d9 --- /dev/null +++ b/tests/ut/ge/graph/passes/atomic_addr_clean_pass_unittest.cc @@ -0,0 +1,96 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include <gtest/gtest.h> +#include "graph/passes/atomic_addr_clean_pass.h" +#include "common/op/ge_op_utils.h" +#include "common/types.h" +#include "graph/anchor.h" +#include "graph/attr_value.h" +#include "graph/compute_graph.h" +#include "graph/op_desc.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" +#include "inc/pass_manager.h" +using namespace testing; + +namespace ge { +class UtestGraphPassesAtomicAddrCleanPass : public Test { /* fixture: every test builds its nodes into graph_ */ +public: + UtestGraphPassesAtomicAddrCleanPass() { + graph_ = std::make_shared<ComputeGraph>("test"); + } + + NodePtr NewNode(const string &name, const string &type, int input_cnt, int output_cnt) { /* adds a node with input_cnt input descs and output_cnt output descs to graph_ */ + OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); + for (int i = 0; i < input_cnt; ++i) { + op_desc->AddInputDesc(GeTensorDesc()); + } + for (int i = 0; i < output_cnt; ++i) { + op_desc->AddOutputDesc(GeTensorDesc()); + } + NodePtr node = graph_->AddNode(op_desc); + return node; + } + + int CountOfAtomicCleanNode() { /* number of direct nodes of graph_ whose type is ATOMICADDRCLEAN */ + int node_num = 0; + for (const NodePtr &node : graph_->GetDirectNode()) { + if (node->GetType() == ATOMICADDRCLEAN) { + ++node_num; + } + } + return node_num; + } + + ComputeGraphPtr graph_; +}; + +/* + * Data Data Atomic_clean + * | | / | + * relu relu | + * | ==> | | + * relu(atomic) relu(atomic) + * | | + * netoutput netoutput + */ +TEST_F(UtestGraphPassesAtomicAddrCleanPass, pass_run_success) { + auto node1 = NewNode("node1", DATA, 0, 1); + + auto node2 = NewNode("node2", RELU, 1, 1); + auto node3 = NewNode("node3", RELU, 1, 1); + auto op_desc = node3->GetOpDesc(); + vector<int64_t> atomic_input_index = {123, 456}; + AttrUtils::SetListInt(op_desc, "atomic_input_index", atomic_input_index); /* marks node3 as the atomic node of the diagram */ + + auto node4 = NewNode("node4", NETOUTPUT, 1, 0); + GraphUtils::AddEdge(node1->GetOutDataAnchor(0), node2->GetInDataAnchor(0)); + GraphUtils::AddEdge(node2->GetOutDataAnchor(0), node3->GetInDataAnchor(0)); + GraphUtils::AddEdge(node3->GetOutDataAnchor(0), 
node4->GetInDataAnchor(0)); + AtomicAddrCleanPass atomic_addr_clean_pass; + Status ret = atomic_addr_clean_pass.Run(graph_); + EXPECT_EQ(ret, SUCCESS); + EXPECT_EQ(1, CountOfAtomicCleanNode()); + + auto atomic_clean = graph_->FindNode("atomic_addr_clean"); + ASSERT_NE(atomic_clean, nullptr); /* fatal on purpose: the node is dereferenced on the next line */ + auto out_ctrl_nodes = atomic_clean->GetOutControlNodes(); + EXPECT_EQ(out_ctrl_nodes.size(), 2); /* per the diagram above: one control edge to relu, one to relu(atomic) */ +} +} // namespace ge