You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/tests/ut/ge/graph/passes/merge_pass_unittest.cc

404 lines
15 KiB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include <cstdint>
#include <memory>
#include <string>
#define private public
#include "graph/passes/merge_pass.h"
#include "common/ge_inner_error_codes.h"
#include "inc/pass_manager.h"
#undef private
namespace ge {
namespace {
class UtestGraphPassesMergePass : public testing::Test {
protected:
UtestGraphPassesMergePass() {
graph_ = std::make_shared<ComputeGraph>("test");
vector<int64_t> shape_vec{1, 1, 224, 224};
GeShape shape = GeShape(shape_vec);
default_tensor_desc_ = std::make_shared<GeTensorDesc>();
default_tensor_desc_->SetShape(shape);
default_tensor_desc_->SetFormat(FORMAT_NCHW);
default_tensor_desc_->SetDataType(DT_FLOAT);
}
NodePtr NewNode(const std::string &name, const std::string &type, int input_cnt, int output_cnt) {
OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
for (int i = 0; i < input_cnt; ++i) {
op_desc->AddInputDesc(default_tensor_desc_->Clone());
}
for (int i = 0; i < output_cnt; ++i) {
op_desc->AddOutputDesc(default_tensor_desc_->Clone());
}
NodePtr node = graph_->AddNode(op_desc);
if (type == CONSTANT) {
int32_t weight[] = {1};
GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32);
GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
vector<GeTensorPtr> tensor_vec = {tensor};
OpDescUtils::SetWeights(node, tensor_vec);
}
return node;
}
ComputeGraphPtr graph_;
GeTensorDescPtr default_tensor_desc_;
MergePass pass_;
};
} // namespace
TEST_F(UtestGraphPassesMergePass, null_input) {
NodePtr node = nullptr;
auto ret = pass_.Run(node);
EXPECT_EQ(ret, PARAM_INVALID);
}
TEST_F(UtestGraphPassesMergePass, filter_non_merge_node) {
auto node = NewNode("Op1", CONSTANT, 0, 1);
auto ret = pass_.Run(node);
EXPECT_EQ(ret, SUCCESS);
}
TEST_F(UtestGraphPassesMergePass, invalid_merge_node) {
auto merge_node = NewNode("Merge", MERGE, 2, 0);
auto ret = pass_.Run(merge_node);
EXPECT_EQ(ret, PARAM_INVALID);
}
/// Op1 Op2
/// \ /
/// \ /
/// Merge
/// |
/// |
/// NetOutput
TEST_F(UtestGraphPassesMergePass, multiple_inputs) {
auto node1 = NewNode("Op1", CONSTANT, 0, 1);
auto node2 = NewNode("Op2", CONSTANT, 0, 1);
auto merge_node = NewNode("Merge", MERGE, 2, 2);
auto net_output_node = NewNode("NetOutput", NETOUTPUT, 1, 1);
GraphUtils::AddEdge(node1->GetOutDataAnchor(0), merge_node->GetInDataAnchor(0));
GraphUtils::AddEdge(node2->GetOutDataAnchor(0), merge_node->GetInDataAnchor(1));
GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), net_output_node->GetInDataAnchor(0));
auto ret = pass_.Run(merge_node);
EXPECT_EQ(ret, SUCCESS);
}
/// Merge
/// | \
/// | \
/// Op1 Op2 Merge2
/// \ | |
/// \ | Op3
/// \ | /
/// NetOutput
TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch_meet_net_output_with_data_anchor) {
auto merge_node = NewNode("Merge", MERGE, 1, 1);
auto merge_node2 = NewNode("Merge2", MERGE, 1, 1);
auto node1 = NewNode("Op1", CONSTANT, 0, 1);
auto node2 = NewNode("Op2", RELU, 1, 1);
auto node3 = NewNode("Op3", RELU, 1, 1);
auto net_output = NewNode("NetOutput", NETOUTPUT, 3, 3);
GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node2->GetInDataAnchor(0));
GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), merge_node2->GetInDataAnchor(0));
GraphUtils::AddEdge(node1->GetOutDataAnchor(0), net_output->GetInDataAnchor(0));
GraphUtils::AddEdge(node2->GetOutDataAnchor(0), net_output->GetInDataAnchor(1));
GraphUtils::AddEdge(merge_node2->GetOutDataAnchor(0), node3->GetInDataAnchor(0));
GraphUtils::AddEdge(node3->GetOutDataAnchor(0), net_output->GetInDataAnchor(2));
auto ret = pass_.Run(merge_node);
EXPECT_EQ(ret, INTERNAL_ERROR);
}
/// Merge
/// | \
/// | \
/// Op1 Op2 Merge2
/// \ | | \
/// \ | Op3
/// \ | :
/// NetOutput
TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch_meet_net_output_with_control_anchor) {
auto merge_node = NewNode("Merge", MERGE, 1, 2);
auto merge_node2 = NewNode("Merge2", MERGE, 1, 2);
auto node1 = NewNode("Op1", CONSTANT, 0, 1);
auto node2 = NewNode("Op2", RELU, 1, 1);
auto node3 = NewNode("Op3", RELU, 1, 1);
auto net_output = NewNode("NetOutput", NETOUTPUT, 3, 3);
GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node2->GetInDataAnchor(0));
GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), merge_node2->GetInDataAnchor(0));
GraphUtils::AddEdge(node1->GetOutDataAnchor(0), net_output->GetInDataAnchor(0));
GraphUtils::AddEdge(node2->GetOutDataAnchor(0), net_output->GetInControlAnchor());
GraphUtils::AddEdge(merge_node2->GetOutDataAnchor(0), node3->GetInDataAnchor(0));
GraphUtils::AddEdge(node3->GetOutDataAnchor(0), net_output->GetInControlAnchor());
auto ret = pass_.Run(merge_node);
EXPECT_EQ(ret, SUCCESS);
}
TEST_F(UtestGraphPassesMergePass, empty_input_cut_branch) {
/// Merge
/// | \
/// | \
/// Op1 Op2 Merge2
/// \ | |
/// \ | Op3
/// \ | /
/// Merge3
auto merge_node = NewNode("Merge", MERGE, 1, 2);
auto merge_node2 = NewNode("Merge2", MERGE, 1, 2);
auto node1 = NewNode("Op1", CONSTANT, 0, 1);
auto node2 = NewNode("Op2", RELU, 1, 1);
auto node3 = NewNode("Op3", RELU, 1, 1);
auto merge_node3 = NewNode("Merge3", MERGE, 3, 3);
GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node2->GetInDataAnchor(0));
GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), merge_node2->GetInDataAnchor(0));
GraphUtils::AddEdge(node1->GetOutDataAnchor(0), merge_node3->GetInDataAnchor(0));
GraphUtils::AddEdge(node2->GetOutDataAnchor(0), merge_node3->GetInDataAnchor(1));
GraphUtils::AddEdge(merge_node2->GetOutDataAnchor(0), node3->GetInDataAnchor(0));
GraphUtils::AddEdge(node3->GetOutDataAnchor(0), merge_node3->GetInDataAnchor(2));
/// Merge
/// |
/// |
/// Op1 Op2 Merge2
/// \ |
/// \ Op3
/// \ /
/// Merge3
auto ret = pass_.Run(merge_node);
EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ(merge_node3->GetInDataNodes().size(), 2);
EXPECT_EQ(merge_node2->GetInDataNodes().size(), 0);
EXPECT_EQ(node2->GetOutDataNodes().size(), 0);
EXPECT_EQ(merge_node3->GetInDataNodes().at(0)->GetName(), "Op1");
EXPECT_EQ(merge_node3->GetInDataNodes().at(1)->GetName(), "Op3");
/// Merge
/// |
/// |
/// Op1 Op2 Merge2
/// \ |
/// \ Op3
/// \
/// Merge3
ret = pass_.Run(merge_node2);
EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ(merge_node3->GetInDataNodes().size(), 1);
EXPECT_EQ(merge_node3->GetInDataNodes().at(0)->GetName(), "Op1");
EXPECT_EQ(node3->GetOutDataNodes().size(), 0);
}
TEST_F(UtestGraphPassesMergePass, single_non_const_input) {
/// Op1
/// |
/// Merge
/// / \
/// Op2 Op3
auto merge_node = NewNode("Merge", MERGE, 1, 2);
auto node1 = NewNode("Op1", RELU, 1, 1);
auto node2 = NewNode("Op2", CONVOLUTION, 1, 1);
auto node3 = NewNode("Op3", CONVOLUTION, 1, 1);
GraphUtils::AddEdge(node1->GetOutDataAnchor(0), merge_node->GetInDataAnchor(0));
GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node2->GetInDataAnchor(0));
GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node3->GetInDataAnchor(0));
EXPECT_EQ(node1->GetOutDataNodes().size(), 1);
auto ret = pass_.Run(merge_node);
EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ(graph_->GetDirectNodesSize(), 3);
EXPECT_EQ(graph_->FindNode("Merge"), nullptr);
EXPECT_EQ(node1->GetOutDataNodes().size(), 2);
EXPECT_EQ(node2->GetInDataNodes().size(), 1);
EXPECT_EQ(node3->GetInDataNodes().size(), 1);
EXPECT_EQ(node2->GetInDataAnchor(0)->GetPeerOutAnchor(), node1->GetOutDataAnchor(0));
EXPECT_EQ(node3->GetInDataAnchor(0)->GetPeerOutAnchor(), node1->GetOutDataAnchor(0));
}
TEST_F(UtestGraphPassesMergePass, single_const_input) {
/// Const
/// |
/// Merge Pass Const
/// / \ ===> / \
/// Op1 Op2 Op1 Op2
auto merge_node = NewNode("Merge", MERGE, 1, 2);
auto const_node = NewNode("Const", CONSTANT, 1, 1);
auto node1 = NewNode("Op1", ADDN, 1, 1);
auto node2 = NewNode("Op2", ADDN, 1, 1);
node1->GetOpDesc()->SetIsInputConst({false});
node2->GetOpDesc()->SetIsInputConst({false});
GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), merge_node->GetInDataAnchor(0));
GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node1->GetInDataAnchor(0));
GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node2->GetInDataAnchor(0));
auto ret = pass_.Run(merge_node);
EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ(graph_->GetDirectNodesSize(), 3);
EXPECT_EQ(graph_->FindNode("Merge").get(), nullptr);
EXPECT_EQ(node1->GetInDataNodes().size(), 1);
EXPECT_EQ(node2->GetInDataNodes().size(), 1);
EXPECT_EQ(node1->GetOpDesc()->GetIsInputConst().at(0), false);
EXPECT_EQ(node2->GetOpDesc()->GetIsInputConst().at(0), false);
}
TEST_F(UtestGraphPassesMergePass, single_const_input_value_index_two_out_nodes) {
/// Const
/// |
/// Merge Pass Const
/// / | ===> / \(control anchor)
/// Op1 | \ Op1 Constant
/// Op2 Op3 |
/// / \
/// Op2 Op3
auto merge_node = NewNode("Merge", MERGE, 1, 2);
auto const_node = NewNode("Const", CONSTANT, 1, 1);
auto node1 = NewNode("Op1", ADDN, 1, 1);
auto node2 = NewNode("Op2", ADDN, 1, 1);
auto node3 = NewNode("Op3", ADDN, 1, 1);
node1->GetOpDesc()->SetIsInputConst({false});
node2->GetOpDesc()->SetIsInputConst({false});
GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), merge_node->GetInDataAnchor(0));
GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node1->GetInDataAnchor(0));
GraphUtils::AddEdge(merge_node->GetOutDataAnchor(1), node2->GetInDataAnchor(0));
GraphUtils::AddEdge(merge_node->GetOutDataAnchor(1), node3->GetInDataAnchor(0));
auto ret = pass_.Run(merge_node);
EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ(graph_->GetDirectNodesSize(), 5);
EXPECT_EQ(graph_->FindNode("Merge").get(), nullptr);
EXPECT_EQ(node1->GetInDataNodes().size(), 1);
EXPECT_EQ(node2->GetInDataNodes().size(), 1);
EXPECT_EQ(node1->GetOpDesc()->GetIsInputConst().at(0), false);
EXPECT_EQ(node2->GetOpDesc()->GetIsInputConst().at(0), false);
NodePtr node_test = graph_->FindNode("Merge_value_index");
EXPECT_NE(node_test.get(), nullptr);
EXPECT_EQ(node_test->GetOutDataNodes().size(), 2);
EXPECT_EQ(node_test->GetAllOutDataAnchors().size(), 1);
EXPECT_EQ(node_test->GetInDataNodes().size(), 0);
EXPECT_EQ(node2->GetInDataNodes().at(0)->GetInControlAnchor()->GetPeerOutControlAnchors().at(0),
const_node->GetOutControlAnchor());
EXPECT_EQ(node3->GetInDataNodes().at(0)->GetInControlAnchor()->GetPeerOutControlAnchors().at(0),
const_node->GetOutControlAnchor());
EXPECT_EQ(node2->GetInDataNodes().at(0), node_test);
EXPECT_EQ(node3->GetInDataNodes().at(0), node_test);
}
TEST_F(UtestGraphPassesMergePass, single_const_input_value_index_two_out_nodes1) {
/// Const
/// |
/// Merge Pass Const
/// / | ===> / \(control anchor)
/// Op1 | \ Op1 Constant
/// Op2 Op3 |
/// / \
/// Op2 Op3
auto merge_node = NewNode("Merge", MERGE, 1, 2);
auto const_node = NewNode("Const", CONSTANT, 1, 1);
auto node1 = NewNode("Op1", ADDN, 1, 1);
auto node2 = NewNode("Op2", ADDN, 1, 1);
auto node3 = NewNode("Op3", ADDN, 1, 1);
node1->GetOpDesc()->SetIsInputConst({false});
node2->GetOpDesc()->SetIsInputConst({false});
GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), merge_node->GetInDataAnchor(0));
GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node1->GetInDataAnchor(0));
GraphUtils::AddEdge(merge_node->GetOutDataAnchor(1), node2->GetInDataAnchor(0));
GraphUtils::AddEdge(merge_node->GetOutDataAnchor(1), node3->GetInDataAnchor(0));
auto ret = pass_.Run(merge_node);
EXPECT_EQ(ret, SUCCESS);
}
TEST_F(UtestGraphPassesMergePass, const_with_control_input) {
/// Switch
/// |
/// Identity
/// .
/// .
/// C
/// |
/// Merge
/// / \
/// Op1 Op2
auto switch_node = NewNode("Switch", SWITCH, 1, 2);
auto identity_node = NewNode("Identity", SWITCH, 1, 1);
auto const_node = NewNode("Const", CONSTANT, 1, 1);
auto merge_node = NewNode("Merge", MERGE, 1, 2);
auto node1 = NewNode("Op1", ADDN, 1, 1);
auto node2 = NewNode("Op2", ADDN, 1, 1);
node1->GetOpDesc()->SetIsInputConst({false});
node2->GetOpDesc()->SetIsInputConst({false});
GraphUtils::AddEdge(switch_node->GetOutDataAnchor(0), identity_node->GetInDataAnchor(0));
GraphUtils::AddEdge(identity_node->GetOutControlAnchor(), const_node->GetInControlAnchor());
GraphUtils::AddEdge(identity_node->GetOutDataAnchor(0), const_node->GetInControlAnchor());
GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), merge_node->GetInDataAnchor(0));
GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node1->GetInDataAnchor(0));
GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node2->GetInDataAnchor(0));
/// Switch
/// |
/// Identity
/// .
/// .
/// C
/// / \
/// Op1 Op2
auto ret = pass_.Run(merge_node);
EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ(graph_->GetDirectNodesSize(), 5);
EXPECT_EQ(graph_->FindNode("Merge").get(), nullptr);
EXPECT_EQ(node1->GetInDataNodes().size(), 1);
EXPECT_EQ(node2->GetInDataNodes().size(), 1);
EXPECT_EQ(node1->GetOpDesc()->GetIsInputConst().at(0), false);
EXPECT_EQ(node2->GetOpDesc()->GetIsInputConst().at(0), false);
EXPECT_EQ(node1->GetInDataNodes().at(0)->GetInControlAnchor()->GetPeerOutDataAnchors().at(0),
identity_node->GetOutDataAnchor(0));
EXPECT_EQ(node1->GetInDataNodes().at(0)->GetInControlAnchor()->GetPeerOutControlAnchors().at(0),
identity_node->GetOutControlAnchor());
EXPECT_EQ(node2->GetInDataNodes().at(0)->GetInControlAnchor()->GetPeerOutDataAnchors().at(0),
identity_node->GetOutDataAnchor(0));
EXPECT_EQ(node2->GetInDataNodes().at(0)->GetInControlAnchor()->GetPeerOutControlAnchors().at(0),
identity_node->GetOutControlAnchor());
}
} // namespace ge