You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/tests/ut/ge/graph/passes/prune_pass_unittest.cc

491 lines
20 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include <vector>
#include "omg/omg_inner_types.h"
#define protected public
#define private public
#include "graph/passes/prune_pass.h"
#include "anchor.h"
#include "common/debug/log.h"
#include "common/debug/memory_dumper.h"
#include "common/op/attr_value_util.h"
#include "common/types.h"
#include "framework/common/ge_inner_error_codes.h"
#include "graph/attr_value.h"
#include "graph/debug/ge_attr_define.h"
#include "inc/pass_manager.h"
#undef protected
#undef private
using namespace testing;
using namespace ge;
using namespace std;
class UtestGraphPassesPrunePass : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}
};
// case1:no net_out_put_node
TEST_F(UtestGraphPassesPrunePass, no_net_out_put_node) {
ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
ge::OpDescPtr reverse_op = std::make_shared<ge::OpDesc>();
reverse_op->SetType(REVERSE);
reverse_op->SetName("Reverse");
reverse_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr reverse_node = graph->AddNode(reverse_op);
ge::OpDescPtr floor_op = std::make_shared<ge::OpDesc>();
floor_op->SetType(FLOOR);
floor_op->SetName("Floor");
floor_op->AddInputDesc(ge::GeTensorDesc());
floor_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr floor_node = graph->AddNode(floor_op);
ge::GraphUtils::AddEdge(reverse_node->GetOutDataAnchor(0), floor_node->GetInDataAnchor(0));
uint64_t size_ori = graph->GetDirectNode().size();
PrunePass prune_pass;
std::vector<std::pair<string, GraphPass*>> passes = { {"prune_pass", &prune_pass} };
Status status = PassManager::Run(graph, passes);
EXPECT_EQ(ge::SUCCESS, status);
uint64_t size = graph->GetDirectNode().size();
EXPECT_EQ(size, size_ori);
}
// case2: one net path with one bypass branch
TEST_F(UtestGraphPassesPrunePass, has_net_out_put_node_with_only_one_path) {
ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
ge::OpDescPtr reverse_op = std::make_shared<ge::OpDesc>();
reverse_op->SetType(REVERSE);
reverse_op->SetName("Reverse");
reverse_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr reverse_node = graph->AddNode(reverse_op);
ge::OpDescPtr floor_op = std::make_shared<ge::OpDesc>();
floor_op->SetType(FLOOR);
floor_op->SetName("Floor");
floor_op->AddInputDesc(ge::GeTensorDesc());
floor_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr floor_node = graph->AddNode(floor_op);
ge::OpDescPtr net_output_op = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
net_output_op->AddInputDesc(ge::GeTensorDesc());
net_output_op->AddOutputDesc(ge::GeTensorDesc());
ge::AttrUtils::SetBool(net_output_op, "identity_add_netoutput", true);
ge::NodePtr netoutput_node = graph->AddNode(net_output_op);
ge::OpDescPtr reverse_op1 = std::make_shared<ge::OpDesc>();
reverse_op->SetType(REVERSE);
reverse_op->SetName("Reverse1");
reverse_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr reverse_node1 = graph->AddNode(reverse_op1);
ge::GraphUtils::AddEdge(reverse_node->GetOutDataAnchor(0), floor_node->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(floor_node->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(0));
uint64_t size_ori = graph->GetDirectNode().size();
PrunePass prune_pass;
std::vector<std::pair<string, GraphPass*>> passes = { {"prune_pass", &prune_pass} };
Status status = PassManager::Run(graph, passes);
uint64_t size = graph->GetDirectNode().size();
int diff = size_ori - size;
EXPECT_EQ(ge::SUCCESS, status);
EXPECT_EQ(diff, 1);
}
// case3: one net path with one bypass branch
TEST_F(UtestGraphPassesPrunePass, has_net_out_put_node_with_one_valid_path_and_one_bypass_path) {
ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
// valid path construct (reverse->floor->net_out)
ge::OpDescPtr reverse_op = std::make_shared<ge::OpDesc>();
reverse_op->SetType(REVERSE);
reverse_op->SetName("Reverse");
reverse_op->AddOutputDesc(ge::GeTensorDesc());
reverse_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr reverse_node = graph->AddNode(reverse_op);
ge::OpDescPtr floor_op = std::make_shared<ge::OpDesc>();
floor_op->SetType(FLOOR);
floor_op->SetName("Floor");
floor_op->AddInputDesc(ge::GeTensorDesc());
floor_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr floor_node = graph->AddNode(floor_op);
ge::OpDescPtr net_output_op = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
net_output_op->AddInputDesc(ge::GeTensorDesc());
net_output_op->AddOutputDesc(ge::GeTensorDesc());
ge::AttrUtils::SetBool(net_output_op, "identity_add_netoutput", true);
ge::NodePtr netoutput_node = graph->AddNode(net_output_op);
ge::GraphUtils::AddEdge(reverse_node->GetOutDataAnchor(0), floor_node->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(floor_node->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(0));
// incvalid path construct (reverse->floor1->floor2)
ge::OpDescPtr floor_op1 = std::make_shared<ge::OpDesc>();
floor_op1->SetType(FLOOR);
floor_op1->SetName("Floor1");
floor_op1->AddInputDesc(ge::GeTensorDesc());
floor_op1->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr floor_node1 = graph->AddNode(floor_op1);
ge::OpDescPtr floor_op2 = std::make_shared<ge::OpDesc>();
floor_op2->SetType(FLOOR);
floor_op2->SetName("Floor2");
floor_op2->AddInputDesc(ge::GeTensorDesc());
floor_op2->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr floor_node2 = graph->AddNode(floor_op2);
// isolated node
ge::OpDescPtr floor_op3 = std::make_shared<ge::OpDesc>();
floor_op3->SetType(FLOOR);
floor_op3->SetName("Floor3");
floor_op3->AddInputDesc(ge::GeTensorDesc());
floor_op3->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr floor_node3 = graph->AddNode(floor_op3);
ge::GraphUtils::AddEdge(reverse_node->GetOutDataAnchor(1), floor_node1->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(floor_node1->GetOutDataAnchor(0), floor_node2->GetInDataAnchor(0));
uint64_t size_ori = graph->GetDirectNode().size();
PrunePass prune_pass;
vector<GraphPass *> passes = {&prune_pass};
}
// case 4: multi net path with one common netout(1:multi1)
TEST_F(UtestGraphPassesPrunePass, has_net_out_put_node_with_multi_path) {
ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
ge::OpDescPtr data_op = std::make_shared<ge::OpDesc>();
data_op->SetType(DATA);
data_op->SetName("data");
data_op->AddOutputDesc(ge::GeTensorDesc());
data_op->AddOutputDesc(ge::GeTensorDesc());
data_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr data_node = graph->AddNode(data_op);
ge::OpDescPtr reverse_op1 = std::make_shared<ge::OpDesc>();
reverse_op1->SetType(REVERSE);
reverse_op1->SetName("Reverse1");
reverse_op1->AddInputDesc(ge::GeTensorDesc());
reverse_op1->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr reverse_node1 = graph->AddNode(reverse_op1);
ge::OpDescPtr floor_op1 = std::make_shared<ge::OpDesc>();
floor_op1->SetType(FLOOR);
floor_op1->SetName("Floor1");
floor_op1->AddInputDesc(ge::GeTensorDesc());
floor_op1->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr floor_node1 = graph->AddNode(floor_op1);
ge::OpDescPtr reverse_op2 = std::make_shared<ge::OpDesc>();
reverse_op2->SetType(REVERSE);
reverse_op2->SetName("Reverse2");
reverse_op2->AddInputDesc(ge::GeTensorDesc());
reverse_op2->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr reverse_node2 = graph->AddNode(reverse_op2);
ge::OpDescPtr floor_op2 = std::make_shared<ge::OpDesc>();
floor_op2->SetType(FLOOR);
floor_op2->SetName("Floor2");
floor_op2->AddInputDesc(ge::GeTensorDesc());
floor_op2->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr floor_node2 = graph->AddNode(floor_op2);
ge::OpDescPtr reverse_op3 = std::make_shared<ge::OpDesc>();
reverse_op3->SetType(REVERSE);
reverse_op3->SetName("Reverse3");
reverse_op3->AddInputDesc(ge::GeTensorDesc());
reverse_op3->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr reverse_node3 = graph->AddNode(reverse_op3);
ge::OpDescPtr floor_op3 = std::make_shared<ge::OpDesc>();
floor_op3->SetType(FLOOR);
floor_op3->SetName("Floor3");
floor_op3->AddInputDesc(ge::GeTensorDesc());
floor_op3->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr floor_node3 = graph->AddNode(floor_op3);
ge::OpDescPtr net_output_op = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
net_output_op->AddInputDesc(ge::GeTensorDesc());
net_output_op->AddInputDesc(ge::GeTensorDesc());
net_output_op->AddInputDesc(ge::GeTensorDesc());
net_output_op->AddOutputDesc(ge::GeTensorDesc());
ge::AttrUtils::SetBool(net_output_op, "identity_add_netoutput", true);
ge::NodePtr netoutput_node = graph->AddNode(net_output_op);
ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), reverse_node1->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(1), reverse_node2->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(2), reverse_node3->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(reverse_node1->GetOutDataAnchor(0), floor_node1->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(floor_node1->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(reverse_node2->GetOutDataAnchor(0), floor_node2->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(floor_node2->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(1));
ge::GraphUtils::AddEdge(reverse_node3->GetOutDataAnchor(0), floor_node3->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(floor_node3->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(2));
uint64_t size_ori = graph->GetDirectNode().size();
PrunePass prune_pass;
std::vector<std::pair<string, GraphPass*>> passes = { {"prune_pass", &prune_pass} };
Status status = PassManager::Run(graph, passes);
uint64_t size_after_proc = graph->GetDirectNode().size();
EXPECT_EQ(size_ori, size_after_proc);
}
// case 5: circle,diamand style
TEST_F(UtestGraphPassesPrunePass, multi_net_out_put_node_with_circle_net) {
ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
ge::OpDescPtr data_op = std::make_shared<ge::OpDesc>();
data_op->SetType(DATA);
data_op->SetName("data");
data_op->AddOutputDesc(ge::GeTensorDesc());
data_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr data_node = graph->AddNode(data_op);
ge::OpDescPtr op_1 = std::make_shared<ge::OpDesc>();
op_1->SetType(REVERSE);
op_1->SetName("Reverse1");
op_1->AddInputDesc(ge::GeTensorDesc());
op_1->AddOutputDesc(ge::GeTensorDesc());
op_1->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr node_1 = graph->AddNode(op_1);
ge::OpDescPtr op_2 = std::make_shared<ge::OpDesc>();
op_2->SetType(REVERSE);
op_2->SetName("Reverse2");
op_2->AddInputDesc(ge::GeTensorDesc());
op_2->AddInputDesc(ge::GeTensorDesc());
op_2->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr node_2 = graph->AddNode(op_2);
ge::OpDescPtr op_3 = std::make_shared<ge::OpDesc>();
op_3->SetType(REVERSE);
op_3->SetName("Reverse3");
op_3->AddInputDesc(ge::GeTensorDesc());
op_3->AddInputDesc(ge::GeTensorDesc());
op_3->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr node_3 = graph->AddNode(op_3);
ge::OpDescPtr op_4 = std::make_shared<ge::OpDesc>();
op_4->SetType(REVERSE);
op_4->SetName("Reverse4");
op_4->AddInputDesc(ge::GeTensorDesc());
op_4->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr node_4 = graph->AddNode(op_4);
ge::OpDescPtr op_5 = std::make_shared<ge::OpDesc>();
op_5->SetType(REVERSE);
op_5->SetName("Reverse5");
op_5->AddInputDesc(ge::GeTensorDesc());
op_5->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr node_5 = graph->AddNode(op_5);
ge::OpDescPtr net_output_op = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
net_output_op->AddInputDesc(ge::GeTensorDesc());
net_output_op->AddOutputDesc(ge::GeTensorDesc());
ge::AttrUtils::SetBool(net_output_op, "identity_add_netoutput", true);
ge::NodePtr netoutput_node = graph->AddNode(net_output_op);
ge::GraphUtils::AddEdge(node_1->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(node_2->GetOutDataAnchor(0), node_1->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(node_3->GetOutDataAnchor(0), node_2->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(node_4->GetOutDataAnchor(0), node_3->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(node_1->GetOutDataAnchor(1), node_4->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), node_2->GetInDataAnchor(1));
ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(1), node_5->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(node_5->GetOutDataAnchor(0), node_3->GetInDataAnchor(1));
uint64_t size_ori = graph->GetDirectNode().size();
PrunePass prune_pass;
std::vector<std::pair<string, GraphPass*>> passes = { {"prune_pass", &prune_pass} };
Status status = PassManager::Run(graph, passes);
EXPECT_EQ(ge::SUCCESS, status);
uint64_t size_after_proc = graph->GetDirectNode().size();
EXPECT_EQ(size_ori, size_after_proc);
}
// case 6: two mix circle and multi path,diamand style
TEST_F(UtestGraphPassesPrunePass, mix_two_circle_net) {
ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
ge::OpDescPtr data_op = std::make_shared<ge::OpDesc>();
data_op->SetType(DATA);
data_op->SetName("data");
data_op->AddOutputDesc(ge::GeTensorDesc());
data_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr data_node = graph->AddNode(data_op);
ge::OpDescPtr op_1 = std::make_shared<ge::OpDesc>();
op_1->SetType(REVERSE);
op_1->SetName("Reverse1");
op_1->AddInputDesc(ge::GeTensorDesc());
op_1->AddInputDesc(ge::GeTensorDesc());
op_1->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr node_1 = graph->AddNode(op_1);
ge::OpDescPtr op_2 = std::make_shared<ge::OpDesc>();
op_2->SetType(REVERSE);
op_2->SetName("Reverse2");
op_2->AddInputDesc(ge::GeTensorDesc());
op_2->AddOutputDesc(ge::GeTensorDesc());
op_2->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr node_2 = graph->AddNode(op_2);
ge::OpDescPtr op_3 = std::make_shared<ge::OpDesc>();
op_3->SetType(REVERSE);
op_3->SetName("Reverse3");
op_3->AddInputDesc(ge::GeTensorDesc());
op_3->AddInputDesc(ge::GeTensorDesc());
op_3->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr node_3 = graph->AddNode(op_3);
ge::OpDescPtr op_4 = std::make_shared<ge::OpDesc>();
op_4->SetType(REVERSE);
op_4->SetName("Reverse4");
op_4->AddInputDesc(ge::GeTensorDesc());
op_4->AddInputDesc(ge::GeTensorDesc());
op_4->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr node_4 = graph->AddNode(op_4);
ge::OpDescPtr op_5 = std::make_shared<ge::OpDesc>();
op_5->SetType(REVERSE);
op_5->SetName("Reverse5");
op_5->AddInputDesc(ge::GeTensorDesc());
op_5->AddOutputDesc(ge::GeTensorDesc());
op_5->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr node_5 = graph->AddNode(op_5);
ge::OpDescPtr net_output_op = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
net_output_op->AddInputDesc(ge::GeTensorDesc());
net_output_op->AddOutputDesc(ge::GeTensorDesc());
ge::AttrUtils::SetBool(net_output_op, "identity_add_netoutput", true);
ge::NodePtr netoutput_node = graph->AddNode(net_output_op);
ge::GraphUtils::AddEdge(node_1->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(node_2->GetOutDataAnchor(0), node_1->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(node_5->GetOutDataAnchor(0), node_1->GetInDataAnchor(1));
ge::GraphUtils::AddEdge(node_4->GetOutDataAnchor(0), node_2->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(node_2->GetOutDataAnchor(1), node_3->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(node_5->GetOutDataAnchor(1), node_3->GetInDataAnchor(1));
ge::GraphUtils::AddEdge(node_3->GetOutDataAnchor(0), node_4->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), node_4->GetInDataAnchor(1));
ge::GraphUtils::AddEdge(node_4->GetOutDataAnchor(1), node_5->GetInDataAnchor(0));
// construct two isolated node
ge::OpDescPtr op_6 = std::make_shared<ge::OpDesc>();
op_6->SetType(REVERSE);
op_6->SetName("Reverse");
op_6->AddInputDesc(ge::GeTensorDesc());
op_6->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr node_6 = graph->AddNode(op_6);
ge::OpDescPtr op_7 = std::make_shared<ge::OpDesc>();
op_7->SetType(REVERSE);
op_7->SetName("Reverse");
op_7->AddInputDesc(ge::GeTensorDesc());
op_7->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr node_7 = graph->AddNode(op_7);
uint64_t size_ori = graph->GetDirectNode().size();
PrunePass prune_pass;
vector<GraphPass *> passes = {&prune_pass};
}
// case7: one net path with two DATA node
TEST_F(UtestGraphPassesPrunePass, has_net_out_put_node_with_two_isolate_data_node) {
ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
ge::OpDescPtr reverse_op = std::make_shared<ge::OpDesc>();
reverse_op->SetType(REVERSE);
reverse_op->SetName("Reverse");
reverse_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr reverse_node = graph->AddNode(reverse_op);
ge::OpDescPtr floor_op = std::make_shared<ge::OpDesc>();
floor_op->SetType(FLOOR);
floor_op->SetName("Floor");
floor_op->AddInputDesc(ge::GeTensorDesc());
floor_op->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr floor_node = graph->AddNode(floor_op);
ge::OpDescPtr net_output_op = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
net_output_op->AddInputDesc(ge::GeTensorDesc());
net_output_op->AddOutputDesc(ge::GeTensorDesc());
ge::AttrUtils::SetBool(net_output_op, "identity_add_netoutput", true);
ge::NodePtr netoutput_node = graph->AddNode(net_output_op);
// construct one isolated DATA node (to be deleted)
ge::OpDescPtr reverse_op_1 = std::make_shared<ge::OpDesc>();
reverse_op_1->SetType(REVERSE);
reverse_op_1->SetName("Reverse1");
reverse_op_1->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr reverse_node_1 = graph->AddNode(reverse_op_1);
ge::GraphUtils::AddEdge(reverse_node->GetOutDataAnchor(0), floor_node->GetInDataAnchor(0));
ge::GraphUtils::AddEdge(floor_node->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(0));
// construct two isolated DATA nodes(to be not deleted)
ge::OpDescPtr data_op_1 = std::make_shared<ge::OpDesc>();
data_op_1->SetType(DATA);
data_op_1->SetName("data");
data_op_1->AddOutputDesc(ge::GeTensorDesc());
data_op_1->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr data_node_1 = graph->AddNode(data_op_1);
ge::OpDescPtr data_op_2 = std::make_shared<ge::OpDesc>();
data_op_2->SetType(DATA);
data_op_2->SetName("data1");
data_op_2->AddOutputDesc(ge::GeTensorDesc());
data_op_2->AddOutputDesc(ge::GeTensorDesc());
ge::NodePtr data_node = graph->AddNode(data_op_2);
uint64_t size_ori = graph->GetDirectNode().size();
PrunePass prune_pass;
std::vector<std::pair<string, GraphPass*>> passes = { {"prune_pass", &prune_pass} };
Status status = PassManager::Run(graph, passes);
uint64_t size = graph->GetDirectNode().size();
EXPECT_EQ(ge::SUCCESS, status);
EXPECT_EQ(size_ori, (size + 1));
// it should check net_out_put's input data node and input control node
auto control_vec = netoutput_node->GetInControlNodes();
EXPECT_EQ(control_vec.size(), 2);
// check control_vec contains only data node
for (auto node : control_vec) {
bool result = (node->GetName() == "data" || node->GetName() == "data1") ? true : false;
EXPECT_EQ(result, true);
}
auto data_vec = netoutput_node->GetInDataNodes();
EXPECT_EQ(data_vec.size(), 1);
// check data_vec contains only Floor node
for (auto node : data_vec) {
bool result = (node->GetName() == "Floor") ? true : false;
EXPECT_EQ(result, true);
}
}