You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/tests/ut/ge/graph/passes/parallel_group_pass_unittes...

305 lines
14 KiB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include <gtest/gtest.h>

#define private public
#include "common/ge_inner_error_codes.h"
#include "inc/pass_manager.h"
#include "utils/graph_utils.h"
#include "graph/passes/parallel_group_pass.h"
#undef private
namespace ge {
namespace {
class UtestGraphPassesParallelGgroupPass : public testing::Test {
protected:
UtestGraphPassesParallelGgroupPass() {
graph_ = std::make_shared<ComputeGraph>("test");
sub_graph_ = std::make_shared<ComputeGraph>("test_subgraph");
vector<int64_t> shape_vec{1, 1, 1, 1};
GeShape shape = GeShape(shape_vec);
default_tensor_desc_ = std::make_shared<GeTensorDesc>();
default_tensor_desc_->SetShape(shape);
default_tensor_desc_->SetFormat(FORMAT_NCHW);
default_tensor_desc_->SetDataType(DT_FLOAT);
}
NodePtr NewNode(const std::string &name, const std::string &type,
int input_cnt, int output_cnt, bool isSubgraph = false) {
OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
for (int i = 0; i < input_cnt; ++i) {
op_desc->AddInputDesc(default_tensor_desc_->Clone());
}
for (int i = 0; i < output_cnt; ++i) {
op_desc->AddOutputDesc(default_tensor_desc_->Clone());
}
NodePtr node = nullptr;
if (isSubgraph) {
node = sub_graph_->AddNode(op_desc);
(void)node->SetOwnerComputeGraph(sub_graph_);
} else {
node = graph_->AddNode(op_desc);
(void)node->SetOwnerComputeGraph(graph_);
}
return node;
}
void BuildDefaultGraph() {
/// input
/// \
/// sqrt pred
/// \ /
/// cast
/// / \
/// switch_t switch_f
/// | |
/// F T
/// | |
/// Merge
/// |
/// relu
/// |
/// sqrt1
input_node_ = NewNode("input", RELU, 0, 1);
sqrt_node_ = NewNode("sqrt", SQRT, 1, 1);
pred_node_ = NewNode("pred", GREATER, 2, 1);
cast_node_ = NewNode("cast", CAST, 2, 2);
AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1);
AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true);
switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1);
AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false);
output_false_node_ = NewNode("false_output", RELU, 1, 1);
AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
output_true_node_ = NewNode("true_output", RELU, 1, 1);
AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
merge_node_ = NewNode("merge", STREAMMERGE, 2, 1);
relu_node_ = NewNode("relu", RELU, 1, 1);
sqrt_node1_ = NewNode("sqrt1", SQRT, 1, 1);
AttrUtils::SetStr(sqrt_node1_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1));
GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0));
GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0));
GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
GraphUtils::AddEdge(merge_node_->GetOutDataAnchor(0), relu_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(relu_node_->GetOutDataAnchor(0), sqrt_node1_->GetInDataAnchor(0));
output_false_node_->GetOpDesc()->SetIsInputConst({false});
output_true_node_->GetOpDesc()->SetIsInputConst({false});
}
void BuildDefaultGraph1() {
/// input
/// \
/// sqrt pred
/// \ /
/// Switch
/// | |
/// ----F T----
/// \ | / \
/// \ Merge1 Merge2
/// \_________|
input_node_ = NewNode("input", RELU, 0, 1);
AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
pred_node_ = NewNode("pred", GREATER, 2, 1);
sqrt_node_ = NewNode("sqrt", SQRT, 1, 1);
cast_node_ = NewNode("cast", CAST, 2, 2);
switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1);
AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true);
switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1);
AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false);
output_false_node_ = NewNode("false_output", RELU, 1, 2);
AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
output_true_node_ = NewNode("true_output", RELU, 1, 2);
AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
merge_node_ = NewNode("merge", STREAMMERGE, 2, 1);
merge_node1_ = NewNode("merge1", STREAMMERGE, 2, 1);
GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1));
GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0));
GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0));
GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(0));
GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(1));
output_false_node_->GetOpDesc()->SetIsInputConst({false});
output_true_node_->GetOpDesc()->SetIsInputConst({false});
}
void BuildDefaultGraph2() {
/// input input1
/// \ \
/// sqrt pred sqrt1 pred1
/// \ / \ /
/// Switch Switch1
/// | | _______|
/// | | /
/// ____F T____
/// \ | / \
/// \ Merge1 Merge2
/// \__________|
input_node_ = NewNode("input", RELU, 0, 2);
input_node1_ = NewNode("input_1", RELU, 0, 2);
sqrt_node_ = NewNode("sqrt", SQRT, 1, 1);
pred_node_ = NewNode("pred", GREATER, 2, 1);
sqrt_node1_ = NewNode("sqrt_1", SQRT, 1, 1);
pred_node1_ = NewNode("pred_1", LESS, 2, 1);
cast_node_ = NewNode("cast", CAST, 2, 2);
cast_node1_ = NewNode("cast_1", CAST, 2, 2);
AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
AttrUtils::SetStr(input_node1_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "2");
switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1);
AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true);
switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1);
AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false);
switch_node1_t = NewNode("switch1_t", STREAMSWITCH, 1, 1);
AttrUtils::SetBool(switch_node1_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true);
switch_node1_f = NewNode("switch1_f", STREAMSWITCH, 1, 1);
AttrUtils::SetBool(switch_node1_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false);
output_false_node_ = NewNode("false_output", RELU, 2, 2);
AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
output_true_node_ = NewNode("true_output", RELU, 2, 2);
AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "2");
merge_node_ = NewNode("merge", STREAMMERGE, 2, 1);
merge_node1_ = NewNode("merge1", STREAMMERGE, 2, 1);
GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1));
GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0));
GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0));
GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(input_node1_->GetOutDataAnchor(0), sqrt_node1_->GetInDataAnchor(0));
GraphUtils::AddEdge(pred_node1_->GetOutDataAnchor(0), cast_node1_->GetInDataAnchor(0));
GraphUtils::AddEdge(sqrt_node1_->GetOutDataAnchor(0), cast_node1_->GetInDataAnchor(1));
GraphUtils::AddEdge(cast_node1_->GetOutDataAnchor(0), switch_node1_t->GetInDataAnchor(0));
GraphUtils::AddEdge(cast_node1_->GetOutDataAnchor(1), switch_node1_f->GetInDataAnchor(0));
GraphUtils::AddEdge(switch_node1_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(1));
GraphUtils::AddEdge(switch_node1_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(1));
GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0));
GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(0));
GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(1));
output_false_node_->GetOpDesc()->SetIsInputConst({false});
output_true_node_->GetOpDesc()->SetIsInputConst({false});
}
ComputeGraphPtr graph_;
ComputeGraphPtr sub_graph_;
GeTensorDescPtr default_tensor_desc_;
ParallelGroupPass pass_;
NodePtr pred_node_;
NodePtr pred_node1_;
NodePtr cast_node_;
NodePtr cast_node1_;
NodePtr sqrt_node_;
NodePtr sqrt_node1_;
NodePtr input_node_;
NodePtr input_node1_;
NodePtr switch_node_t;
NodePtr switch_node_f;
NodePtr switch_node1_t;
NodePtr switch_node1_f;
NodePtr output_false_node_;
NodePtr output_true_node_;
NodePtr merge_node_;
NodePtr merge_node1_;
NodePtr relu_node_;
};
/// Running the pass on a null graph must be rejected with PARAM_INVALID.
TEST_F(UtestGraphPassesParallelGgroupPass, null_graph) {
  ComputeGraphPtr missing_graph = nullptr;
  EXPECT_EQ(pass_.Run(missing_graph), PARAM_INVALID);
}
/// BuildDefaultGraph() creates data edges only. After the pass succeeds, the test
/// expects control edges input -> cast and merge -> sqrt1, and expects no control
/// edge from false_output to true_output.
TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph) {
  BuildDefaultGraph();
  // The edge checks below are meaningless if the pass itself failed.
  ASSERT_EQ(pass_.Run(graph_), GRAPH_SUCCESS);
  EXPECT_TRUE(input_node_->GetOutControlAnchor()->IsLinkedWith(cast_node_->GetInControlAnchor()));
  EXPECT_TRUE(merge_node_->GetOutControlAnchor()->IsLinkedWith(sqrt_node1_->GetInControlAnchor()));
  EXPECT_FALSE(output_false_node_->GetOutControlAnchor()->IsLinkedWith(output_true_node_->GetInControlAnchor()));
}
/// BuildDefaultGraph1() (branch outputs feeding two merges). After the pass
/// succeeds, the test expects a control edge input -> cast.
TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph1) {
  BuildDefaultGraph1();
  ASSERT_EQ(pass_.Run(graph_), GRAPH_SUCCESS);
  EXPECT_TRUE(input_node_->GetOutControlAnchor()->IsLinkedWith(cast_node_->GetInControlAnchor()));
}
/// BuildDefaultGraph2() (two condition chains in parallel groups "1" and "2").
/// After the pass succeeds, the test expects control edges input -> cast and
/// input_1 -> cast_1.
TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph2) {
  BuildDefaultGraph2();
  ASSERT_EQ(pass_.Run(graph_), GRAPH_SUCCESS);
  EXPECT_TRUE(input_node_->GetOutControlAnchor()->IsLinkedWith(cast_node_->GetInControlAnchor()));
  EXPECT_TRUE(input_node1_->GetOutControlAnchor()->IsLinkedWith(cast_node1_->GetInControlAnchor()));
}
/// Attaches sub_graph_ to the `input` node of BuildDefaultGraph1()'s graph.
/// The subgraph holds input1 and input2 (both in parallel group "1") and an
/// unconnected, ungrouped add node. The test checks that the pass succeeds on
/// the subgraph and on the root graph.
TEST_F(UtestGraphPassesParallelGgroupPass, normal_subgraph) {
  BuildDefaultGraph1();
  NodePtr input_node1 = NewNode("input1", RELU, 0, 1, true);
  NodePtr input_node2 = NewNode("input2", RELU, 0, 1, true);
  NodePtr add = NewNode("add", ADD, 2, 1, true);  // Only needs to exist in sub_graph_.
  AttrUtils::SetStr(input_node1->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  AttrUtils::SetStr(input_node2->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");

  // Wiring the subgraph into the root graph is a precondition; stop on failure.
  sub_graph_->SetParentNode(input_node_);
  sub_graph_->SetParentGraph(graph_);
  ASSERT_EQ(graph_->AddSubgraph(sub_graph_->GetName(), sub_graph_), GRAPH_SUCCESS);
  ASSERT_EQ(input_node_->GetOpDesc()->AddSubgraphName(sub_graph_->GetName()), GRAPH_SUCCESS);
  ASSERT_EQ(input_node_->GetOpDesc()->SetSubgraphInstanceName(0, sub_graph_->GetName()), GRAPH_SUCCESS);

  EXPECT_EQ(pass_.Run(sub_graph_), GRAPH_SUCCESS);
  EXPECT_EQ(pass_.Run(graph_), GRAPH_SUCCESS);
}
} // namespace
} // namespace ge