You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/tests/ut/ge/graph/passes/switch_pass_unittest.cc

424 lines
16 KiB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include <cstdint>
#include <string>
#define private public
#include "graph/passes/switch_pass.h"
#include "common/ge_inner_error_codes.h"
#include "inc/pass_manager.h"
#include "utils/graph_utils.h"
#undef private
namespace ge {
namespace {
/// Test fixture for SwitchPass.
///
/// Each test gets a fresh ComputeGraph plus handles to the nodes created by
/// BuildDefaultGraph(), so a test can rewire the graph before running the pass
/// on `switch_node_`.
class UtestGraphPassesSwitchPass : public testing::Test {
 protected:
  UtestGraphPassesSwitchPass() {
    graph_ = std::make_shared<ComputeGraph>("test");
    // Every tensor desc created by NewNode() is a clone of this 1x1x1x1 NCHW float desc.
    std::vector<int64_t> shape_vec{1, 1, 1, 1};
    GeShape shape = GeShape(shape_vec);
    default_tensor_desc_ = std::make_shared<GeTensorDesc>();
    default_tensor_desc_->SetShape(shape);
    default_tensor_desc_->SetFormat(FORMAT_NCHW);
    default_tensor_desc_->SetDataType(DT_FLOAT);
  }

  /// Adds a node named `name` with op type `type` to `graph_`.
  ///
  /// @param input_cnt   number of input descs (each a clone of `default_tensor_desc_`).
  /// @param output_cnt  number of output descs (each a clone of `default_tensor_desc_`).
  /// @return the node added to `graph_`, with `graph_` set as its owner graph.
  NodePtr NewNode(const std::string &name, const std::string &type, int input_cnt, int output_cnt) {
    OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
    for (int i = 0; i < input_cnt; ++i) {
      op_desc->AddInputDesc(default_tensor_desc_->Clone());
    }
    for (int i = 0; i < output_cnt; ++i) {
      op_desc->AddOutputDesc(default_tensor_desc_->Clone());
    }
    NodePtr node = graph_->AddNode(op_desc);
    (void)node->SetOwnerComputeGraph(graph_);
    return node;
  }

  /// Builds the Switch/Merge diamond used by most tests:
  ///
  ///     input    pred
  ///         \    /
  ///         Switch
  ///    out 0 |  | out 1
  ///          F  T
  ///          |  |
  ///          Merge
  ///
  /// F is `output_false_node_` (fed by Switch output 0, feeding Merge input 0);
  /// T is `output_true_node_` (fed by Switch output 1, feeding Merge input 1).
  ///
  /// @param is_input_const  true: `input` is a CONSTANT holding the int32 value 1;
  ///                        false: `input` is a RELU with no inputs.
  /// @param pred_value      non-null: `pred` is a CONSTANT holding `*pred_value`
  ///                        as an int32 scalar; null: `pred` is an unconnected
  ///                        GREATER node (a non-const predicate).
  void BuildDefaultGraph(bool is_input_const, const bool *pred_value = nullptr) {
    bool is_pred_const = pred_value != nullptr;
    if (is_pred_const) {
      pred_node_ = NewNode("pred", CONSTANT, 0, 1);
      int32_t weight[] = {static_cast<int32_t>(*pred_value)};
      GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32);
      GeTensorPtr tensor =
          std::make_shared<GeTensor>(weight_desc, reinterpret_cast<uint8_t *>(weight), sizeof(weight));
      OpDescUtils::SetWeights(pred_node_, {tensor});
    } else {
      pred_node_ = NewNode("pred", GREATER, 2, 1);
    }
    if (is_input_const) {
      int32_t weight[] = {1};
      GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32);
      GeTensorPtr tensor =
          std::make_shared<GeTensor>(weight_desc, reinterpret_cast<uint8_t *>(weight), sizeof(weight));
      input_node_ = NewNode("input", CONSTANT, 0, 1);
      OpDescUtils::SetWeights(input_node_, {tensor});
    } else {
      input_node_ = NewNode("input", RELU, 0, 1);
    }
    switch_node_ = NewNode("switch", SWITCH, 2, 2);
    output_false_node_ = NewNode("false_output", RELU, 1, 1);
    output_true_node_ = NewNode("true_output", RELU, 1, 1);
    merge_node_ = NewNode("merge", MERGE, 2, 1);
    // Switch input 0 is the data, input 1 is the predicate; only the latter may be const.
    switch_node_->GetOpDesc()->SetIsInputConst({false, is_pred_const});
    GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(0));
    GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(1));
    GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
    GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node_->GetInDataAnchor(0));
    GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0));
    GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
    output_false_node_->GetOpDesc()->SetIsInputConst({false});
    output_true_node_->GetOpDesc()->SetIsInputConst({false});
  }

  /// Runs the pass on `switch_node_` and checks that exactly one branch of the
  /// default diamond is still wired from `input` to Merge.
  ///
  /// Intended to follow BuildDefaultGraph() with a const predicate.
  ///
  /// @param expect_output  true if the true branch (T, Switch output 1) should be
  ///                       kept; false if the false branch (F, Switch output 0) should be.
  void TestPickOutput(bool expect_output) {
    auto ret = pass_.Run(switch_node_);
    EXPECT_EQ(ret, SUCCESS);
    // Six nodes were built and Switch is removed; per the original intent, two of
    // the remaining five (pred and the dead branch) are left isolated.
    EXPECT_EQ(graph_->GetDirectNodesSize(), 5U);
    EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1U);
    if (expect_output) {
      EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor().get(), nullptr);
      EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor(), output_true_node_->GetOutDataAnchor(0));
      EXPECT_EQ(output_true_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), input_node_->GetOutDataAnchor(0));
    } else {
      EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), output_false_node_->GetOutDataAnchor(0));
      EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor().get(), nullptr);
      EXPECT_EQ(output_false_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), input_node_->GetOutDataAnchor(0));
    }
  }

  ComputeGraphPtr graph_;
  GeTensorDescPtr default_tensor_desc_;
  SwitchPass pass_;
  NodePtr pred_node_;
  NodePtr input_node_;
  NodePtr switch_node_;
  NodePtr output_false_node_;
  NodePtr output_true_node_;
  NodePtr merge_node_;
};
} // namespace
TEST_F(UtestGraphPassesSwitchPass, null_input) {
  // An empty node pointer must be rejected with PARAM_INVALID.
  NodePtr null_node;
  EXPECT_EQ(pass_.Run(null_node), PARAM_INVALID);
}
TEST_F(UtestGraphPassesSwitchPass, null_pred) {
  // Detach Switch's predicate input (index 1); the pass is still expected to return SUCCESS.
  BuildDefaultGraph(/*is_input_const=*/false);
  switch_node_->GetInDataAnchor(1)->UnlinkAll();
  EXPECT_EQ(pass_.Run(switch_node_), SUCCESS);
}
TEST_F(UtestGraphPassesSwitchPass, null_data) {
  // Detach Switch's data input (index 0); the pass is still expected to return SUCCESS.
  BuildDefaultGraph(/*is_input_const=*/false);
  switch_node_->GetInDataAnchor(0)->UnlinkAll();
  EXPECT_EQ(pass_.Run(switch_node_), SUCCESS);
}
TEST_F(UtestGraphPassesSwitchPass, unsupported_node_type) {
  // Running the pass on a non-Switch node (a CONSTANT here) is expected to return SUCCESS.
  NodePtr const_node = NewNode("Op1", CONSTANT, 0, 1);
  EXPECT_EQ(pass_.Run(const_node), SUCCESS);
}
TEST_F(UtestGraphPassesSwitchPass, empty_output) {
  // Detach both Switch outputs (0 = false branch, 1 = true branch); the pass
  // is still expected to return SUCCESS.
  BuildDefaultGraph(/*is_input_const=*/false);
  for (int out_idx : {0, 1}) {
    switch_node_->GetOutDataAnchor(out_idx)->UnlinkAll();
  }
  EXPECT_EQ(pass_.Run(switch_node_), SUCCESS);
}
TEST_F(UtestGraphPassesSwitchPass, non_const_pred) {
  // No pred_value is given, so the predicate is a non-const GREATER node; the
  // pass is still expected to return SUCCESS.
  BuildDefaultGraph(/*is_input_const=*/false);
  EXPECT_EQ(pass_.Run(switch_node_), SUCCESS);
}
TEST_F(UtestGraphPassesSwitchPass, pick_output_false) {
  // A const int32 predicate of 0 must keep only the false branch.
  const bool pred_is_true = false;
  BuildDefaultGraph(/*is_input_const=*/false, &pred_is_true);
  TestPickOutput(/*expect_output=*/false);
}
TEST_F(UtestGraphPassesSwitchPass, pick_output_false_float) {
  // Overwrite the int32 predicate weight set by BuildDefaultGraph with a
  // DT_FLOAT 0.0; the false branch is expected to be kept.
  bool pred_value = false;
  BuildDefaultGraph(false, &pred_value);
  float weight[] = {0.0f};
  GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_FLOAT);
  GeTensorPtr tensor =
      std::make_shared<GeTensor>(weight_desc, reinterpret_cast<uint8_t *>(weight), sizeof(weight));
  OpDescUtils::SetWeights(pred_node_, {tensor});
  TestPickOutput(false);
}
TEST_F(UtestGraphPassesSwitchPass, pick_output_false_bool) {
  // Overwrite the int32 predicate weight set by BuildDefaultGraph with a
  // DT_BOOL false; the false branch is expected to be kept.
  bool pred_value = false;
  BuildDefaultGraph(false, &pred_value);
  bool weight[] = {false};
  GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_BOOL);
  GeTensorPtr tensor =
      std::make_shared<GeTensor>(weight_desc, reinterpret_cast<uint8_t *>(weight), sizeof(weight));
  OpDescUtils::SetWeights(pred_node_, {tensor});
  TestPickOutput(false);
}
TEST_F(UtestGraphPassesSwitchPass, pick_output_false_u16) {
  // Overwrite the int32 predicate weight set by BuildDefaultGraph with a
  // DT_UINT16 0; the false branch is expected to be kept.
  bool pred_value = false;
  BuildDefaultGraph(false, &pred_value);
  uint16_t weight[] = {0};
  GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_UINT16);
  GeTensorPtr tensor =
      std::make_shared<GeTensor>(weight_desc, reinterpret_cast<uint8_t *>(weight), sizeof(weight));
  OpDescUtils::SetWeights(pred_node_, {tensor});
  TestPickOutput(false);
}
TEST_F(UtestGraphPassesSwitchPass, pick_output_true) {
  // A const int32 predicate of 1 must keep only the true branch.
  const bool pred_is_true = true;
  BuildDefaultGraph(/*is_input_const=*/false, &pred_is_true);
  TestPickOutput(/*expect_output=*/true);
}
TEST_F(UtestGraphPassesSwitchPass, pick_output_true_double) {
  // Overwrite the int32 predicate weight set by BuildDefaultGraph with a
  // DT_DOUBLE 1.0; the true branch is expected to be kept.
  bool pred_value = true;
  BuildDefaultGraph(false, &pred_value);
  double weight[] = {1.0};
  GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_DOUBLE);
  GeTensorPtr tensor =
      std::make_shared<GeTensor>(weight_desc, reinterpret_cast<uint8_t *>(weight), sizeof(weight));
  OpDescUtils::SetWeights(pred_node_, {tensor});
  TestPickOutput(true);
}
TEST_F(UtestGraphPassesSwitchPass, pick_output_true_int64) {
  // Overwrite the int32 predicate weight set by BuildDefaultGraph with a
  // DT_INT64 1; the true branch is expected to be kept.
  bool pred_value = true;
  BuildDefaultGraph(false, &pred_value);
  int64_t weight[] = {1};
  GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT64);
  GeTensorPtr tensor =
      std::make_shared<GeTensor>(weight_desc, reinterpret_cast<uint8_t *>(weight), sizeof(weight));
  OpDescUtils::SetWeights(pred_node_, {tensor});
  TestPickOutput(true);
}
TEST_F(UtestGraphPassesSwitchPass, inactive_output_not_exists) {
  /// Graph given to the pass (the inactive true branch has been removed):
  ///   input   pred(false)
  ///       \   /
  ///       Switch
  ///         |
  ///         F
  ///         |
  ///       Merge
  bool pred_value = false;
  BuildDefaultGraph(false, &pred_value);
  // Remove T from the graph and clear everything left on Switch's true output (index 1).
  output_true_node_->GetOutDataAnchor(0)->UnlinkAll();
  GraphUtils::RemoveNodeWithoutRelink(graph_, output_true_node_);
  switch_node_->GetOutDataAnchor(1)->UnlinkAll();
  /// Expected after the pass:
  ///   input
  ///     |
  ///     F
  ///     |
  ///   Merge
  auto ret = pass_.Run(switch_node_);
  EXPECT_EQ(ret, SUCCESS);
  // pred, input, F and Merge remain.
  EXPECT_EQ(graph_->GetDirectNodesSize(), 4U);
  EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1U);
  EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), output_false_node_->GetOutDataAnchor(0));
  EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor().get(), nullptr);
  EXPECT_EQ(output_false_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), input_node_->GetOutDataAnchor(0));
}
TEST_F(UtestGraphPassesSwitchPass, const_input_pick_output_true) {
  /// Graph given to the pass (Switch's true output feeds both T1 and T2, whose
  /// results are combined by T3 before reaching Merge input 1):
  ///      const    pred(true)
  ///          \    /
  ///          Switch
  ///        /   |   \__
  ///       F    T1    T2
  ///       |     \    /
  ///       |      T3
  ///       |     /
  ///       Merge
  bool pred_value = true;
  BuildDefaultGraph(true, &pred_value);
  auto output_true_node2 = NewNode("true_output2", RELU, 1, 1);
  auto output_true_node3 = NewNode("true_output3", ADD, 2, 1);
  GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node2->GetInDataAnchor(0));
  // Re-route T1 so it reaches Merge through T3 instead of directly.
  GraphUtils::RemoveEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
  GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), output_true_node3->GetInDataAnchor(0));
  GraphUtils::AddEdge(output_true_node2->GetOutDataAnchor(0), output_true_node3->GetInDataAnchor(1));
  GraphUtils::AddEdge(output_true_node3->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
  // Expected after the pass: Switch is gone, T1 reads the const input directly,
  // T3 still feeds Merge input 1, and Merge input 0 is empty.
  auto ret = pass_.Run(switch_node_);
  EXPECT_EQ(ret, SUCCESS);
  // Eight nodes were built; one (Switch) is expected to be removed.
  EXPECT_EQ(graph_->GetDirectNodesSize(), 7U);
  EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1U);
  EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor().get(), nullptr);
  EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor(), output_true_node3->GetOutDataAnchor(0));
  EXPECT_EQ(output_true_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), input_node_->GetOutDataAnchor(0));
  // NOTE(review): T3 input 0 is fed by T1's *output*, so this only checks that T2
  // is not fed by T1's output; confirm what T2's expected source is and assert it directly.
  EXPECT_NE(output_true_node2->GetInDataAnchor(0)->GetPeerOutAnchor(),
            output_true_node3->GetInDataAnchor(0)->GetPeerOutAnchor());
}
TEST_F(UtestGraphPassesSwitchPass, after_switch_const_take_false_branch) {
  /// Graph given to the pass (dotted lines are control edges from Switch):
  ///            C    pred(false)
  ///             \   /
  ///             Switch
  ///             .    .
  ///             .    .
  ///   C_1 -->   F    T   <-- C_2
  ///             |    |
  ///             Merge
  bool pred_value = false;
  BuildDefaultGraph(true, &pred_value);
  // Replace Switch's data edges to F and T with control edges and feed F and T
  // from two fresh constants instead.
  switch_node_->GetOutDataAnchor(0)->UnlinkAll();
  switch_node_->GetOutDataAnchor(1)->UnlinkAll();
  NodePtr const_node_1 = NewNode("const_1", CONSTANT, 0, 1);
  NodePtr const_node_2 = NewNode("const_2", CONSTANT, 0, 1);
  GraphUtils::AddEdge(const_node_1->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
  GraphUtils::AddEdge(const_node_2->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0));
  GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(0), output_false_node_->GetInControlAnchor());
  GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node_->GetInControlAnchor());
  /// Expected after the pass (Switch removed, only F still reaches Merge):
  ///     C    pred(false)
  ///
  ///     C_1    C_2
  ///      |      |
  ///      F      T
  ///      |
  ///     Merge
  auto ret = pass_.Run(switch_node_);
  EXPECT_EQ(ret, SUCCESS);
  // Eight nodes were built; one (Switch) is expected to be removed.
  EXPECT_EQ(graph_->GetDirectNodesSize(), 7U);
  EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1U);
  EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), output_false_node_->GetOutDataAnchor(0));
  EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor().get(), nullptr);
  EXPECT_EQ(output_false_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), const_node_1->GetOutDataAnchor(0));
}
TEST_F(UtestGraphPassesSwitchPass, after_switch_const_take_true_branch) {
  /// Graph given to the pass (dotted lines are control edges from Switch):
  ///            C    pred(true)
  ///             \   /
  ///             Switch
  ///             .    .
  ///             .    .
  ///   C_1 -->   F    T   <-- C_2
  ///             |    |
  ///             Merge
  bool pred_value = true;
  BuildDefaultGraph(true, &pred_value);
  // Replace Switch's data edges to F and T with control edges and feed F and T
  // from two fresh constants instead.
  switch_node_->GetOutDataAnchor(0)->UnlinkAll();
  switch_node_->GetOutDataAnchor(1)->UnlinkAll();
  NodePtr const_node_1 = NewNode("const_1", CONSTANT, 0, 1);
  NodePtr const_node_2 = NewNode("const_2", CONSTANT, 0, 1);
  GraphUtils::AddEdge(const_node_1->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
  GraphUtils::AddEdge(const_node_2->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0));
  GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(0), output_false_node_->GetInControlAnchor());
  GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node_->GetInControlAnchor());
  /// Expected after the pass (Switch removed, only T still reaches Merge):
  ///     C_1    C_2
  ///      |      |
  ///      F      T
  ///             |
  ///           Merge
  auto ret = pass_.Run(switch_node_);
  EXPECT_EQ(ret, SUCCESS);
  // Eight nodes were built; one (Switch) is expected to be removed.
  EXPECT_EQ(graph_->GetDirectNodesSize(), 7U);
  EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1U);
  EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor().get(), nullptr);
  EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor(), output_true_node_->GetOutDataAnchor(0));
  EXPECT_EQ(output_true_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), const_node_2->GetOutDataAnchor(0));
}
TEST_F(UtestGraphPassesSwitchPass, dead_output_connected_to_merge) {
  /// Scenario named by this test (the dead false output wired straight to Merge):
  ///   input   pred(true)
  ///       \   /
  ///       Switch
  ///       |    |
  ///       |    T
  ///       |    |
  ///       Merge
  bool pred_value = true;
  BuildDefaultGraph(false, &pred_value);
  // Remove F from the graph and clear everything left on Switch's false output (index 0).
  output_false_node_->GetOutDataAnchor(0)->UnlinkAll();
  GraphUtils::RemoveNodeWithoutRelink(graph_, output_false_node_);
  switch_node_->GetOutDataAnchor(0)->UnlinkAll();
  // NOTE(review): nothing re-links Switch output 0 to Merge input 0, so the graph
  // actually given to the pass is the one drawn below, not the scenario above.
  // Confirm whether an edge Switch out 0 -> Merge in 0 was intended here.
  /// Graph actually given to the pass:
  ///   input   pred(true)
  ///       \   /
  ///       Switch
  ///            |
  ///            T
  ///            |
  ///          Merge
  auto ret = pass_.Run(switch_node_);
  EXPECT_EQ(ret, SUCCESS);
  /// Expected after the pass:
  ///   input
  ///     |
  ///     T
  ///     |
  ///   Merge
  // pred, input, T and Merge remain.
  EXPECT_EQ(graph_->GetDirectNodesSize(), 4U);
  EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1U);
  EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor().get(), nullptr);
  EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor(), output_true_node_->GetOutDataAnchor(0));
  EXPECT_EQ(output_true_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), input_node_->GetOutDataAnchor(0));
}
} // namespace ge