/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include #define private public #include "common/ge_inner_error_codes.h" #include "inc/pass_manager.h" #include "utils/graph_utils.h" #include "graph/passes/parallel_group_pass.h" #undef private namespace ge { namespace { class UtestGraphPassesParallelGgroupPass : public testing::Test { protected: UtestGraphPassesParallelGgroupPass() { graph_ = std::make_shared("test"); sub_graph_ = std::make_shared("test_subgraph"); vector shape_vec{1, 1, 1, 1}; GeShape shape = GeShape(shape_vec); default_tensor_desc_ = std::make_shared(); default_tensor_desc_->SetShape(shape); default_tensor_desc_->SetFormat(FORMAT_NCHW); default_tensor_desc_->SetDataType(DT_FLOAT); } NodePtr NewNode(const std::string &name, const std::string &type, int input_cnt, int output_cnt, bool isSubgraph = false) { OpDescPtr op_desc = std::make_shared(name, type); for (int i = 0; i < input_cnt; ++i) { op_desc->AddInputDesc(default_tensor_desc_->Clone()); } for (int i = 0; i < output_cnt; ++i) { op_desc->AddOutputDesc(default_tensor_desc_->Clone()); } NodePtr node = nullptr; if (isSubgraph) { node = sub_graph_->AddNode(op_desc); (void)node->SetOwnerComputeGraph(sub_graph_); } else { node = graph_->AddNode(op_desc); (void)node->SetOwnerComputeGraph(graph_); } return node; } void BuildDefaultGraph() { /// input /// \ /// sqrt pred /// \ / /// cast /// / \ /// switch_t switch_f /// | | /// F T /// | | /// Merge /// | /// relu /// | /// sqrt1 input_node_ = NewNode("input", RELU, 0, 1); sqrt_node_ = NewNode("sqrt", SQRT, 1, 1); pred_node_ = NewNode("pred", GREATER, 2, 1); cast_node_ = NewNode("cast", CAST, 2, 2); AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1); AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true); switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1); AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false); output_false_node_ = NewNode("false_output", RELU, 1, 1); AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); output_true_node_ = NewNode("true_output", RELU, 1, 1); AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); merge_node_ = NewNode("merge", STREAMMERGE, 2, 1); relu_node_ = NewNode("relu", RELU, 1, 1); sqrt_node1_ = NewNode("sqrt1", SQRT, 1, 1); AttrUtils::SetStr(sqrt_node1_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0)); GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0)); GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1)); GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0)); GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0)); GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0)); GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0)); GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0)); GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1)); GraphUtils::AddEdge(merge_node_->GetOutDataAnchor(0), relu_node_->GetInDataAnchor(0)); GraphUtils::AddEdge(relu_node_->GetOutDataAnchor(0), sqrt_node1_->GetInDataAnchor(0)); output_false_node_->GetOpDesc()->SetIsInputConst({false}); output_true_node_->GetOpDesc()->SetIsInputConst({false}); } void BuildDefaultGraph1() { /// input /// \ /// sqrt pred /// \ / /// Switch /// | | /// ----F T---- /// \ | / \ /// \ Merge1 Merge2 /// \_________| input_node_ = NewNode("input", RELU, 0, 1); AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); pred_node_ = NewNode("pred", GREATER, 2, 1); sqrt_node_ = NewNode("sqrt", SQRT, 1, 1); cast_node_ = NewNode("cast", CAST, 2, 2); switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1); AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true); switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1); AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false); output_false_node_ = NewNode("false_output", RELU, 1, 2); AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); output_true_node_ = NewNode("true_output", RELU, 1, 2); AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); merge_node_ = NewNode("merge", STREAMMERGE, 2, 1); merge_node1_ = NewNode("merge1", STREAMMERGE, 2, 1); GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0)); GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0)); GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1)); GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0)); GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0)); GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0)); GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0)); GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0)); GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1)); GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(0)); GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(1)); output_false_node_->GetOpDesc()->SetIsInputConst({false}); output_true_node_->GetOpDesc()->SetIsInputConst({false}); } void BuildDefaultGraph2() { /// input input1 /// \ \ /// sqrt pred sqrt1 pred1 /// \ / \ / /// Switch Switch1 /// | | _______| /// | | / /// ____F T____ /// \ | / \ /// \ Merge1 Merge2 /// \__________| input_node_ = NewNode("input", RELU, 0, 2); input_node1_ = NewNode("input_1", RELU, 0, 2); sqrt_node_ = NewNode("sqrt", SQRT, 1, 1); pred_node_ = NewNode("pred", GREATER, 2, 1); sqrt_node1_ = NewNode("sqrt_1", SQRT, 1, 1); pred_node1_ = NewNode("pred_1", LESS, 2, 1); cast_node_ = NewNode("cast", CAST, 2, 2); cast_node1_ = NewNode("cast_1", CAST, 2, 2); AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); AttrUtils::SetStr(input_node1_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "2"); switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1); AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true); switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1); AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false); switch_node1_t = NewNode("switch1_t", STREAMSWITCH, 1, 1); AttrUtils::SetBool(switch_node1_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true); switch_node1_f = NewNode("switch1_f", STREAMSWITCH, 1, 1); AttrUtils::SetBool(switch_node1_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false); output_false_node_ = NewNode("false_output", RELU, 2, 2); AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); output_true_node_ = NewNode("true_output", RELU, 2, 2); AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "2"); merge_node_ = NewNode("merge", STREAMMERGE, 2, 1); merge_node1_ = NewNode("merge1", STREAMMERGE, 2, 1); GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0)); GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0)); GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1)); GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0)); GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0)); GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0)); GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0)); GraphUtils::AddEdge(input_node1_->GetOutDataAnchor(0), sqrt_node1_->GetInDataAnchor(0)); GraphUtils::AddEdge(pred_node1_->GetOutDataAnchor(0), cast_node1_->GetInDataAnchor(0)); GraphUtils::AddEdge(sqrt_node1_->GetOutDataAnchor(0), cast_node1_->GetInDataAnchor(1)); GraphUtils::AddEdge(cast_node1_->GetOutDataAnchor(0), switch_node1_t->GetInDataAnchor(0)); GraphUtils::AddEdge(cast_node1_->GetOutDataAnchor(1), switch_node1_f->GetInDataAnchor(0)); GraphUtils::AddEdge(switch_node1_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(1)); GraphUtils::AddEdge(switch_node1_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(1)); GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0)); GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1)); GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(0)); GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(1)); output_false_node_->GetOpDesc()->SetIsInputConst({false}); output_true_node_->GetOpDesc()->SetIsInputConst({false}); } ComputeGraphPtr graph_; ComputeGraphPtr sub_graph_; GeTensorDescPtr default_tensor_desc_; ParallelGroupPass pass_; NodePtr pred_node_; NodePtr pred_node1_; NodePtr cast_node_; NodePtr cast_node1_; NodePtr sqrt_node_; NodePtr sqrt_node1_; NodePtr input_node_; NodePtr input_node1_; NodePtr switch_node_t; NodePtr switch_node_f; NodePtr switch_node1_t; NodePtr switch_node1_f; NodePtr output_false_node_; NodePtr output_true_node_; NodePtr merge_node_; NodePtr merge_node1_; NodePtr relu_node_; }; TEST_F(UtestGraphPassesParallelGgroupPass, null_graph) { ComputeGraphPtr graph = nullptr; auto ret = pass_.Run(graph); EXPECT_EQ(ret, PARAM_INVALID); } TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph) { BuildDefaultGraph(); auto ret = pass_.Run(graph_); EXPECT_EQ(ret, GRAPH_SUCCESS); EXPECT_EQ(true, input_node_->GetOutControlAnchor()->IsLinkedWith(cast_node_->GetInControlAnchor())); EXPECT_EQ(true, merge_node_->GetOutControlAnchor()->IsLinkedWith(sqrt_node1_->GetInControlAnchor())); EXPECT_EQ(false, output_false_node_->GetOutControlAnchor()->IsLinkedWith(output_true_node_->GetInControlAnchor())); } TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph1) { BuildDefaultGraph1(); auto ret = pass_.Run(graph_); EXPECT_EQ(ret, GRAPH_SUCCESS); EXPECT_EQ(true, input_node_->GetOutControlAnchor()->IsLinkedWith(cast_node_->GetInControlAnchor())); } TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph2) { BuildDefaultGraph2(); auto ret = pass_.Run(graph_); EXPECT_EQ(ret, GRAPH_SUCCESS); EXPECT_EQ(true, input_node_->GetOutControlAnchor()->IsLinkedWith(cast_node_->GetInControlAnchor())); EXPECT_EQ(true, input_node1_->GetOutControlAnchor()->IsLinkedWith(cast_node1_->GetInControlAnchor())); } TEST_F(UtestGraphPassesParallelGgroupPass, normal_subgraph) { BuildDefaultGraph1(); NodePtr input_node1 = NewNode("input1", RELU, 0, 1, true); NodePtr input_node2 = NewNode("input2", RELU, 0, 1, true); NodePtr add = NewNode("add", ADD, 2, 1, true); AttrUtils::SetStr(input_node1->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); AttrUtils::SetStr(input_node2->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1"); sub_graph_->SetParentNode(input_node_); sub_graph_->SetParentGraph(graph_); auto ret = graph_->AddSubgraph(sub_graph_->GetName(), sub_graph_); EXPECT_EQ(ret, GRAPH_SUCCESS); ret = input_node_->GetOpDesc()->AddSubgraphName(sub_graph_->GetName()); EXPECT_EQ(ret, GRAPH_SUCCESS); ret = input_node_->GetOpDesc()->SetSubgraphInstanceName(0, sub_graph_->GetName()); EXPECT_EQ(ret, GRAPH_SUCCESS); ret = pass_.Run(sub_graph_); EXPECT_EQ(ret, GRAPH_SUCCESS); ret = pass_.Run(graph_); EXPECT_EQ(ret, GRAPH_SUCCESS); } } // namespace } // namespace ge