You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
868 lines
36 KiB
868 lines
36 KiB
/**
|
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "graph/passes/net_output_pass.h"
|
|
|
|
#include <gtest/gtest.h>
|
|
|
|
#include "common/ge_inner_error_codes.h"
|
|
#include "common/types.h"
|
|
#include "ge/ge_api.h"
|
|
#include "graph/compute_graph.h"
|
|
#include "graph/debug/graph_debug.h"
|
|
#include "graph/manager/graph_manager.h"
|
|
#include "graph/manager/graph_manager_utils.h"
|
|
#include "graph/operator_reg.h"
|
|
#include "graph/utils/op_desc_utils.h"
|
|
#include "inc/pass_manager.h"
|
|
#include "init/gelib.h"
|
|
#include "opskernel_manager/ops_kernel_manager.h"
|
|
|
|
using namespace std;
|
|
using namespace testing;
|
|
using namespace ge;
|
|
|
|
class UtestGraphPassesNetOutputPass : public testing::Test {
|
|
protected:
|
|
void SetUp() {}
|
|
void TearDown() {}
|
|
};
|
|
|
|
/// Builds a graph named "default" that holds two nodes joined by one data edge:
/// Const1 (CONSTANT, one output) -> Cast1 (CAST, one input and one output).
///
/// @return the newly created graph.
ge::ComputeGraphPtr BuildClearWeightGraph() {
  ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");

  // The cast node is added to the graph before the constant that feeds it.
  ge::OpDescPtr cast_desc = std::make_shared<ge::OpDesc>();
  cast_desc->SetType(CAST);
  cast_desc->SetName("Cast1");
  cast_desc->AddInputDesc(ge::GeTensorDesc());
  cast_desc->AddOutputDesc(ge::GeTensorDesc());
  ge::NodePtr cast = graph->AddNode(cast_desc);

  // The constant has no inputs, only a single output.
  ge::OpDescPtr const_desc = std::make_shared<ge::OpDesc>();
  const_desc->SetType(CONSTANT);
  const_desc->SetName("Const1");
  const_desc->AddOutputDesc(ge::GeTensorDesc());
  ge::NodePtr constant = graph->AddNode(const_desc);

  // Const1:0 -> Cast1:0
  ge::GraphUtils::AddEdge(constant->GetOutDataAnchor(0), cast->GetInDataAnchor(0));

  return graph;
}
|
|
|
|
/// Builds the graph used by the NetOutputPass tests.
///
/// Data edges created (output index -> input index):
///   Data1:0 -> Relu1:0, Relu1:0 -> FullConnection:0, FullConnection:0 -> Relu2:0,
///   Relu2:0 -> Mul:0, Relu2:1 -> Mul:1,
///   Mul:0 -> Mul1:0, Mul:1 -> Mul1:1, Mul:2 -> Mul2:0, Mul:3 -> Mul2:1
/// and, only when with_leaf_node is true, FullConnection:1 -> Relu3:0.
///
/// @param with_leaf_node when true, the extra node Relu3 is added and fed from FullConnection:1.
/// @return the newly created graph, named "default".
ge::ComputeGraphPtr build_graph(bool with_leaf_node = false) {
  ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");

  // Creates an op description with the given type and name, then appends the requested number of
  // default tensor descriptions: inputs first, outputs second.
  const auto make_desc = [](const std::string &type, const std::string &name, int input_num,
                            int output_num) -> ge::OpDescPtr {
    ge::OpDescPtr desc = std::make_shared<ge::OpDesc>();
    desc->SetType(type);
    desc->SetName(name);
    for (int i = 0; i < input_num; ++i) {
      desc->AddInputDesc(ge::GeTensorDesc());
    }
    for (int i = 0; i < output_num; ++i) {
      desc->AddOutputDesc(ge::GeTensorDesc());
    }
    return desc;
  };

  // Connects one data output of `src` to one data input of `dst`.
  const auto link = [](const ge::NodePtr &src, int src_idx, const ge::NodePtr &dst, int dst_idx) {
    ge::GraphUtils::AddEdge(src->GetOutDataAnchor(src_idx), dst->GetInDataAnchor(dst_idx));
  };

  // Nodes are added in a fixed order: Data1, Relu1, Relu2, [Relu3], Mul, Mul1, Mul2, FullConnection.
  ge::NodePtr data1 = graph->AddNode(make_desc(DATA, "Data1", 1, 1));
  ge::NodePtr relu1 = graph->AddNode(make_desc(ACTIVATION, "Relu1", 1, 1));
  ge::NodePtr relu2 = graph->AddNode(make_desc(RELU, "Relu2", 1, 2));

  // The Relu3 description is always created, but it only becomes a node when a leaf is requested.
  ge::OpDescPtr relu3_desc = make_desc(ACTIVATION, "Relu3", 1, 1);
  ge::NodePtr relu3;
  if (with_leaf_node) {
    relu3 = graph->AddNode(relu3_desc);
  }

  ge::NodePtr mul = graph->AddNode(make_desc(MUL, "Mul", 2, 4));
  ge::NodePtr mul1 = graph->AddNode(make_desc(MUL, "Mul1", 2, 1));
  ge::NodePtr mul2 = graph->AddNode(make_desc(MUL, "Mul2", 2, 1));
  ge::NodePtr fc = graph->AddNode(make_desc(FULL_CONNECTION, "FullConnection", 1, 2));

  link(data1, 0, relu1, 0);
  link(relu1, 0, fc, 0);
  link(fc, 0, relu2, 0);
  if (with_leaf_node) {
    link(fc, 1, relu3, 0);
  }
  link(relu2, 0, mul, 0);
  link(relu2, 1, mul, 1);
  link(mul, 0, mul1, 0);
  link(mul, 1, mul1, 1);
  link(mul, 2, mul2, 0);
  link(mul, 3, mul2, 1);

  return graph;
}
|
|
/// Only Relu3:0 is configured as a graph output (no target nodes are set).
/// Expectation: NetOutput has Relu3 as its single data input, and Mul1 and Mul2 as its control inputs.
TEST_F(UtestGraphPassesNetOutputPass, add_ctrl_edge_for_netout_from_leaf_success) {
  ge::ComputeGraphPtr compute_graph = build_graph(true);

  // Configure the user-specified output.
  ge::NodePtr relu3 = compute_graph->FindNode("Relu3");
  std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{relu3, 0}};
  compute_graph->SetGraphOutNodesInfo(output_nodes);

  ge::PassManager pass_managers;
  pass_managers.AddPass("", new (std::nothrow) NetOutputPass);
  Status status = pass_managers.Run(compute_graph);
  EXPECT_EQ(status, ge::SUCCESS);

  // Everything below dereferences the NetOutput node, so stop the test if it is missing.
  ge::NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  ASSERT_NE(net_out_node, nullptr);

  // Check data input.
  const int data_input_count = static_cast<int>(net_out_node->GetInDataNodes().size());
  EXPECT_EQ(data_input_count, 1);

  // Each expected name is struck off once it is seen; the list must end up empty.
  std::vector<std::string> missing_data_inputs{"Relu3"};
  for (const auto &node : net_out_node->GetInDataNodes()) {
    const auto iter = std::find(missing_data_inputs.begin(), missing_data_inputs.end(), node->GetName());
    if (iter != missing_data_inputs.end()) {
      missing_data_inputs.erase(iter);
    }
  }
  EXPECT_TRUE(missing_data_inputs.empty());

  // Check control input.
  const int control_input_count = static_cast<int>(net_out_node->GetInControlNodes().size());
  EXPECT_EQ(control_input_count, 2);

  std::vector<std::string> missing_control_inputs{"Mul1", "Mul2"};
  for (const auto &node : net_out_node->GetInControlNodes()) {
    const auto iter = std::find(missing_control_inputs.begin(), missing_control_inputs.end(), node->GetName());
    if (iter != missing_control_inputs.end()) {
      missing_control_inputs.erase(iter);
    }
  }
  EXPECT_TRUE(missing_control_inputs.empty());
}
|
|
/// Only target nodes (Mul1, Mul2) are configured.
/// Expectation: NetOutput has no data inputs, and Mul1 and Mul2 as its control inputs.
TEST_F(UtestGraphPassesNetOutputPass, only_target_node_success) {
  ge::ComputeGraphPtr compute_graph = build_graph();

  // Construct targets.
  ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  std::vector<ge::NodePtr> target_nodes = {mul1, mul2};
  compute_graph->SetGraphTargetNodesInfo(target_nodes);

  ge::PassManager pass_managers;
  pass_managers.AddPass("", new (std::nothrow) NetOutputPass);
  Status status = pass_managers.Run(compute_graph);
  EXPECT_EQ(status, ge::SUCCESS);

  // Everything below dereferences the NetOutput node, so stop the test if it is missing.
  ge::NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  ASSERT_NE(net_out_node, nullptr);

  // Check data input.
  const int data_input_count = static_cast<int>(net_out_node->GetInDataNodes().size());
  EXPECT_EQ(data_input_count, 0);

  // Check control input.
  const int control_input_count = static_cast<int>(net_out_node->GetInControlNodes().size());
  EXPECT_EQ(control_input_count, 2);

  // Each expected name is struck off once it is seen; the list must end up empty.
  std::vector<std::string> missing_control_inputs{"Mul1", "Mul2"};
  for (const auto &node : net_out_node->GetInControlNodes()) {
    const auto iter = std::find(missing_control_inputs.begin(), missing_control_inputs.end(), node->GetName());
    if (iter != missing_control_inputs.end()) {
      missing_control_inputs.erase(iter);
    }
  }
  EXPECT_TRUE(missing_control_inputs.empty());
}
|
|
/// Two framework-op nodes whose original type is "_Retval" are fed by Mul1:0 and Mul2:0 and are
/// configured as the graph's target nodes.
/// Expectation: NetOutput has no data inputs, Mul1 and Mul2 are its control inputs, and both
/// _Retval nodes are no longer present in the graph.
TEST_F(UtestGraphPassesNetOutputPass, targets_with_retval_success) {
  ge::ComputeGraphPtr compute_graph = build_graph();

  // Imitate the output node of _Retval issued.
  ge::OpDescPtr retval_node_desc1 = std::make_shared<ge::OpDesc>("reval_node1", FRAMEWORKOP);
  retval_node_desc1->AddInputDesc(ge::GeTensorDesc());
  (void)ge::AttrUtils::SetStr(retval_node_desc1, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_Retval");
  (void)ge::AttrUtils::SetInt(retval_node_desc1, RETVAL_ATTR_NAME_INDEX, 0);
  ge::NodePtr retval_node1 = compute_graph->AddNode(retval_node_desc1);
  // The node is dereferenced when the edges are wired below, so a null result must end the test.
  ASSERT_NE(retval_node1, nullptr);

  ge::OpDescPtr retval_node_desc2 = std::make_shared<ge::OpDesc>("reval_node2", FRAMEWORKOP);
  retval_node_desc2->AddInputDesc(ge::GeTensorDesc());
  (void)ge::AttrUtils::SetStr(retval_node_desc2, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_Retval");
  (void)ge::AttrUtils::SetInt(retval_node_desc2, RETVAL_ATTR_NAME_INDEX, 1);
  ge::NodePtr retval_node2 = compute_graph->AddNode(retval_node_desc2);
  ASSERT_NE(retval_node2, nullptr);

  // Construct targets.
  std::vector<ge::NodePtr> target_nodes = {retval_node1, retval_node2};
  compute_graph->SetGraphTargetNodesInfo(target_nodes);

  // Wire Mul1:0 -> reval_node1:0 and Mul2:0 -> reval_node2:0.
  for (const auto &node : compute_graph->GetDirectNode()) {
    if (node->GetName() == "Mul1") {
      ge::GraphUtils::AddEdge(node->GetOutDataAnchor(0), retval_node1->GetInDataAnchor(0));
    } else if (node->GetName() == "Mul2") {
      ge::GraphUtils::AddEdge(node->GetOutDataAnchor(0), retval_node2->GetInDataAnchor(0));
    }
  }

  ge::PassManager pass_managers;
  pass_managers.AddPass("", new (std::nothrow) NetOutputPass);
  Status status = pass_managers.Run(compute_graph);
  EXPECT_EQ(status, ge::SUCCESS);

  // Everything below dereferences the NetOutput node, so stop the test if it is missing.
  ge::NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  ASSERT_NE(net_out_node, nullptr);

  // Check data input.
  const int data_input_count = static_cast<int>(net_out_node->GetInDataNodes().size());
  EXPECT_EQ(data_input_count, 0);

  // Check control input.
  const int control_input_count = static_cast<int>(net_out_node->GetInControlNodes().size());
  EXPECT_EQ(control_input_count, 2);

  // Each expected name is struck off once it is seen; the list must end up empty.
  std::vector<std::string> missing_control_inputs{"Mul1", "Mul2"};
  for (const auto &node : net_out_node->GetInControlNodes()) {
    const auto iter = std::find(missing_control_inputs.begin(), missing_control_inputs.end(), node->GetName());
    if (iter != missing_control_inputs.end()) {
      missing_control_inputs.erase(iter);
    }
  }
  EXPECT_TRUE(missing_control_inputs.empty());

  // Check the deletion of _Retval node.
  retval_node1 = compute_graph->FindNode("reval_node1");
  EXPECT_EQ(retval_node1, nullptr);
  retval_node2 = compute_graph->FindNode("reval_node2");
  EXPECT_EQ(retval_node2, nullptr);
}
|
|
|
|
/// Targets (Mul1, Mul2) and the output (Relu3:0) are both configured and do not overlap.
/// Expectation: NetOutput has Relu3 as its single data input, and Mul1 and Mul2 as its control inputs.
TEST_F(UtestGraphPassesNetOutputPass, output_node_and_target_node_no_duplicate_success) {
  ge::ComputeGraphPtr compute_graph = build_graph(true);

  // Construct targets.
  ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  std::vector<ge::NodePtr> target_nodes = {mul1, mul2};
  compute_graph->SetGraphTargetNodesInfo(target_nodes);

  // Construct the specified output.
  ge::NodePtr relu3 = compute_graph->FindNode("Relu3");
  std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{relu3, 0}};
  compute_graph->SetGraphOutNodesInfo(output_nodes);

  ge::PassManager pass_managers;
  pass_managers.AddPass("", new (std::nothrow) NetOutputPass);
  Status status = pass_managers.Run(compute_graph);
  EXPECT_EQ(status, ge::SUCCESS);

  // Everything below dereferences the NetOutput node, so stop the test if it is missing.
  ge::NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  ASSERT_NE(net_out_node, nullptr);

  // Check data input.
  const int data_input_count = static_cast<int>(net_out_node->GetInDataNodes().size());
  EXPECT_EQ(data_input_count, 1);

  // Each expected name is struck off once it is seen; the list must end up empty.
  std::vector<std::string> missing_data_inputs{"Relu3"};
  for (const auto &node : net_out_node->GetInDataNodes()) {
    const auto iter = std::find(missing_data_inputs.begin(), missing_data_inputs.end(), node->GetName());
    if (iter != missing_data_inputs.end()) {
      missing_data_inputs.erase(iter);
    }
  }
  EXPECT_TRUE(missing_data_inputs.empty());

  // Check control input.
  const int control_input_count = static_cast<int>(net_out_node->GetInControlNodes().size());
  EXPECT_EQ(control_input_count, 2);

  std::vector<std::string> missing_control_inputs{"Mul1", "Mul2"};
  for (const auto &node : net_out_node->GetInControlNodes()) {
    const auto iter = std::find(missing_control_inputs.begin(), missing_control_inputs.end(), node->GetName());
    if (iter != missing_control_inputs.end()) {
      missing_control_inputs.erase(iter);
    }
  }
  EXPECT_TRUE(missing_control_inputs.empty());
}
|
|
/// Mul2 is configured both as a target node and as an output (Mul1:0 and Mul2:0 are the outputs).
/// Expectation: NetOutput has two data inputs and no control inputs.
TEST_F(UtestGraphPassesNetOutputPass, output_node_and_target_node_duplicate_success) {
  ge::ComputeGraphPtr compute_graph = build_graph();

  // Construct targets.
  ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  std::vector<ge::NodePtr> target_nodes = {mul2};
  compute_graph->SetGraphTargetNodesInfo(target_nodes);

  // Construct the specified outputs; Mul2 appears here as well as in the targets.
  std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}, {mul2, 0}};
  compute_graph->SetGraphOutNodesInfo(output_nodes);

  ge::PassManager pass_managers;
  pass_managers.AddPass("", new (std::nothrow) NetOutputPass);
  Status status = pass_managers.Run(compute_graph);
  EXPECT_EQ(status, ge::SUCCESS);

  // Everything below dereferences the NetOutput node, so stop the test if it is missing.
  ge::NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  ASSERT_NE(net_out_node, nullptr);

  // Check data input.
  const int data_input_count = static_cast<int>(net_out_node->GetInDataNodes().size());
  EXPECT_EQ(data_input_count, 2);

  // NOTE(review): only "Mul1" is verified by name although two data inputs are expected;
  // presumably "Mul2" could be listed here too -- confirm against NetOutputPass before tightening.
  std::vector<std::string> missing_data_inputs{"Mul1"};
  for (const auto &node : net_out_node->GetInDataNodes()) {
    const auto iter = std::find(missing_data_inputs.begin(), missing_data_inputs.end(), node->GetName());
    if (iter != missing_data_inputs.end()) {
      missing_data_inputs.erase(iter);
    }
  }
  EXPECT_TRUE(missing_data_inputs.empty());

  // Check control input.
  const int control_input_count = static_cast<int>(net_out_node->GetInControlNodes().size());
  EXPECT_EQ(control_input_count, 0);
}
|
|
|
|
/// The graph already contains a NetOutput node fed by Mul1:0 and Mul2:0, and Mul2 is configured
/// as a target node.
/// Expectation: NetOutput keeps Mul1 as its only data input and has Mul2 as its only control input.
TEST_F(UtestGraphPassesNetOutputPass, net_output_node_and_target_node_success) {
  ge::ComputeGraphPtr compute_graph = build_graph();

  // Pre-create a NetOutput node with two inputs and two outputs.
  ge::OpDescPtr netout = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
  netout->AddInputDesc(ge::GeTensorDesc());
  netout->AddInputDesc(ge::GeTensorDesc());
  netout->AddOutputDesc(ge::GeTensorDesc());
  netout->AddOutputDesc(ge::GeTensorDesc());
  ge::NodePtr netout_node = compute_graph->AddNode(netout);
  // The node is dereferenced when the edges are wired below, so a null result must end the test.
  ASSERT_NE(netout_node, nullptr);

  // Wire Mul1:0 -> NetOutput:0 and Mul2:0 -> NetOutput:1.
  for (const auto &node : compute_graph->GetDirectNode()) {
    if (node->GetName() == "Mul1") {
      ge::GraphUtils::AddEdge(node->GetOutDataAnchor(0), netout_node->GetInDataAnchor(0));
    } else if (node->GetName() == "Mul2") {
      ge::GraphUtils::AddEdge(node->GetOutDataAnchor(0), netout_node->GetInDataAnchor(1));
    }
  }

  // Construct targets.
  ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  std::vector<ge::NodePtr> target_nodes = {mul2};
  compute_graph->SetGraphTargetNodesInfo(target_nodes);

  ge::PassManager pass_managers;
  pass_managers.AddPass("", new (std::nothrow) NetOutputPass);
  Status status = pass_managers.Run(compute_graph);
  EXPECT_EQ(status, ge::SUCCESS);

  // Everything below dereferences the NetOutput node, so stop the test if it is missing.
  ge::NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  ASSERT_NE(net_out_node, nullptr);

  // Check data input.
  const int data_input_count = static_cast<int>(net_out_node->GetInDataNodes().size());
  EXPECT_EQ(data_input_count, 1);

  // Each expected name is struck off once it is seen; the list must end up empty.
  std::vector<std::string> missing_data_inputs{"Mul1"};
  for (const auto &node : net_out_node->GetInDataNodes()) {
    const auto iter = std::find(missing_data_inputs.begin(), missing_data_inputs.end(), node->GetName());
    if (iter != missing_data_inputs.end()) {
      missing_data_inputs.erase(iter);
    }
  }
  EXPECT_TRUE(missing_data_inputs.empty());

  // Check control input.
  const int control_input_count = static_cast<int>(net_out_node->GetInControlNodes().size());
  EXPECT_EQ(control_input_count, 1);

  std::vector<std::string> missing_control_inputs{"Mul2"};
  for (const auto &node : net_out_node->GetInControlNodes()) {
    const auto iter = std::find(missing_control_inputs.begin(), missing_control_inputs.end(), node->GetName());
    if (iter != missing_control_inputs.end()) {
      missing_control_inputs.erase(iter);
    }
  }
  EXPECT_TRUE(missing_control_inputs.empty());
}
|
|
/// graph have netoutput node.User set outputnodes and target nodes at the same time.output nodes
|
|
/// include one common node with target nodes.
|
|
/// Notice: output nodes set is more prio
|
|
TEST_F(UtestGraphPassesNetOutputPass, net_output_node_and_output_nodes_and_target_node_success_1) {
|
|
ge::ComputeGraphPtr compute_graph = build_graph();
|
|
|
|
ge::OpDescPtr netout = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
|
|
netout->AddInputDesc(ge::GeTensorDesc());
|
|
netout->AddInputDesc(ge::GeTensorDesc());
|
|
netout->AddOutputDesc(ge::GeTensorDesc());
|
|
netout->AddOutputDesc(ge::GeTensorDesc());
|
|
ge::NodePtr netout_node = compute_graph->AddNode(netout);
|
|
EXPECT_NE(netout_node, nullptr);
|
|
|
|
for (NodePtr node : compute_graph->GetDirectNode()) {
|
|
if (node->GetName() == "Mul1") {
|
|
GraphUtils::AddEdge(node->GetOutDataAnchor(0), netout_node->GetInDataAnchor(0));
|
|
} else if (node->GetName() == "Mul2") {
|
|
GraphUtils::AddEdge(node->GetOutDataAnchor(0), netout_node->GetInDataAnchor(1));
|
|
}
|
|
}
|
|
// construct targets
|
|
ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
|
|
ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
|
|
std::vector<ge::NodePtr> target_nodes = {mul2};
|
|
compute_graph->SetGraphTargetNodesInfo(target_nodes);
|
|
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}, {mul2, 0}};
|
|
compute_graph->SetGraphOutNodesInfo(output_nodes);
|
|
ge::PassManager pass_managers;
|
|
pass_managers.AddPass("", new (std::nothrow) NetOutputPass);
|
|
Status status = pass_managers.Run(compute_graph);
|
|
EXPECT_EQ(status, ge::SUCCESS);
|
|
// check contain netoutput
|
|
NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
|
|
EXPECT_NE(net_out_node, nullptr);
|
|
/// check input data node of netoutput
|
|
/// Check data input
|
|
int input_data_node_num = net_out_node->GetInDataNodes().size();
|
|
EXPECT_EQ(input_data_node_num, 2);
|
|
|
|
std::vector<string> expect_input_data_result{"Mul1", "Mul2"};
|
|
for (auto node : net_out_node->GetInDataNodes()) {
|
|
auto name = node->GetName();
|
|
auto iter = std::find(expect_input_data_result.begin(), expect_input_data_result.end(), name);
|
|
if (iter != expect_input_data_result.end()) {
|
|
expect_input_data_result.erase(iter);
|
|
}
|
|
}
|
|
input_data_node_num = expect_input_data_result.size();
|
|
EXPECT_EQ(input_data_node_num, 0);
|
|
// Check control input
|
|
int control_node_num = net_out_node->GetInControlNodes().size();
|
|
EXPECT_EQ(control_node_num, 0);
|
|
}
|
|
/// graph have netoutput node.User set outputnodes and target nodes at the same time.output nodes
|
|
/// include one common node with target nodes.
|
|
/// Notice: output nodes set is more prio
|
|
TEST_F(UtestGraphPassesNetOutputPass, net_output_node_and_output_nodes_and_target_node_success_2) {
|
|
ge::ComputeGraphPtr compute_graph = build_graph(true);
|
|
|
|
ge::OpDescPtr netout = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
|
|
netout->AddInputDesc(ge::GeTensorDesc());
|
|
netout->AddOutputDesc(ge::GeTensorDesc());
|
|
ge::NodePtr netout_node = compute_graph->AddNode(netout);
|
|
EXPECT_NE(netout_node, nullptr);
|
|
|
|
for (const auto &node : compute_graph->GetDirectNode()) {
|
|
if (node->GetName() == "Mul1") {
|
|
GraphUtils::AddEdge(node->GetOutDataAnchor(0), netout_node->GetInDataAnchor(0));
|
|
}
|
|
if (node->GetName() == "Mul2") {
|
|
GraphUtils::AddEdge(node->GetOutControlAnchor(), netout_node->GetInControlAnchor());
|
|
}
|
|
if (node->GetName() == "Relu3") {
|
|
GraphUtils::AddEdge(node->GetOutControlAnchor(), netout_node->GetInControlAnchor());
|
|
}
|
|
}
|
|
// construct targets
|
|
ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
|
|
ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
|
|
std::vector<ge::NodePtr> target_nodes = {mul2};
|
|
compute_graph->SetGraphTargetNodesInfo(target_nodes);
|
|
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}};
|
|
compute_graph->SetGraphOutNodesInfo(output_nodes);
|
|
ge::PassManager pass_managers;
|
|
pass_managers.AddPass("", new (std::nothrow) NetOutputPass);
|
|
Status status = pass_managers.Run(compute_graph);
|
|
EXPECT_EQ(status, ge::SUCCESS);
|
|
// check contain netoutput
|
|
NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
|
|
EXPECT_NE(net_out_node, nullptr);
|
|
/// check input data node of netoutput
|
|
/// Check data input
|
|
int input_data_node_num = net_out_node->GetInDataNodes().size();
|
|
EXPECT_EQ(input_data_node_num, 1);
|
|
|
|
std::vector<string> expect_input_data_result{"Mul1"};
|
|
for (const auto &node : net_out_node->GetInDataNodes()) {
|
|
auto name = node->GetName();
|
|
auto iter = std::find(expect_input_data_result.begin(), expect_input_data_result.end(), name);
|
|
if (iter != expect_input_data_result.end()) {
|
|
expect_input_data_result.erase(iter);
|
|
}
|
|
}
|
|
input_data_node_num = expect_input_data_result.size();
|
|
EXPECT_EQ(input_data_node_num, 0);
|
|
// Check control input
|
|
int control_node_num = net_out_node->GetInControlNodes().size();
|
|
EXPECT_EQ(control_node_num, 2);
|
|
std::vector<string> expect_control_data_result{"Mul2", "Relu3"};
|
|
for (const auto &node : net_out_node->GetInControlNodes()) {
|
|
auto name = node->GetName();
|
|
auto iter = std::find(expect_control_data_result.begin(), expect_control_data_result.end(), name);
|
|
if (iter != expect_control_data_result.end()) {
|
|
expect_control_data_result.erase(iter);
|
|
}
|
|
}
|
|
control_node_num = expect_control_data_result.size();
|
|
EXPECT_EQ(control_node_num, 0);
|
|
}
|
|
/// graph have netoutput node.User set outputnodes and target nodes at the same time.output nodes
|
|
/// include one common node with target nodes.
|
|
/// Notice: output nodes set is more prio
|
|
TEST_F(UtestGraphPassesNetOutputPass, net_output_node_and_output_nodes_and_target_node_success_3) {
|
|
ge::ComputeGraphPtr compute_graph = build_graph();
|
|
|
|
ge::OpDescPtr netout = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
|
|
netout->AddInputDesc(ge::GeTensorDesc());
|
|
netout->AddOutputDesc(ge::GeTensorDesc());
|
|
ge::NodePtr netout_node = compute_graph->AddNode(netout);
|
|
EXPECT_NE(netout_node, nullptr);
|
|
|
|
for (const auto &node : compute_graph->GetDirectNode()) {
|
|
if (node->GetName() == "Mul1") {
|
|
GraphUtils::AddEdge(node->GetOutDataAnchor(0), netout_node->GetInDataAnchor(0));
|
|
}
|
|
if (node->GetName() == "Mul2") {
|
|
GraphUtils::AddEdge(node->GetOutControlAnchor(), netout_node->GetInControlAnchor());
|
|
GraphUtils::AddEdge(node->GetOutDataAnchor(0), netout_node->GetInControlAnchor());
|
|
}
|
|
}
|
|
// construct targets
|
|
ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
|
|
ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
|
|
std::vector<ge::NodePtr> target_nodes = {mul2};
|
|
compute_graph->SetGraphTargetNodesInfo(target_nodes);
|
|
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}};
|
|
compute_graph->SetGraphOutNodesInfo(output_nodes);
|
|
ge::PassManager pass_managers;
|
|
pass_managers.AddPass("", new (std::nothrow) NetOutputPass);
|
|
Status status = pass_managers.Run(compute_graph);
|
|
EXPECT_EQ(status, ge::SUCCESS);
|
|
// check contain netoutput
|
|
NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
|
|
EXPECT_NE(net_out_node, nullptr);
|
|
/// check input data node of netoutput
|
|
/// Check data input
|
|
int input_data_node_num = net_out_node->GetInDataNodes().size();
|
|
EXPECT_EQ(input_data_node_num, 1);
|
|
|
|
std::vector<string> expect_input_data_result{"Mul1"};
|
|
for (const auto &node : net_out_node->GetInDataNodes()) {
|
|
auto name = node->GetName();
|
|
auto iter = std::find(expect_input_data_result.begin(), expect_input_data_result.end(), name);
|
|
if (iter != expect_input_data_result.end()) {
|
|
expect_input_data_result.erase(iter);
|
|
}
|
|
}
|
|
input_data_node_num = expect_input_data_result.size();
|
|
EXPECT_EQ(input_data_node_num, 0);
|
|
// Check control input
|
|
int control_node_num = net_out_node->GetInControlNodes().size();
|
|
EXPECT_EQ(control_node_num, 1);
|
|
std::vector<string> expect_control_data_result{"Mul2"};
|
|
for (const auto &node : net_out_node->GetInControlNodes()) {
|
|
auto name = node->GetName();
|
|
auto iter = std::find(expect_control_data_result.begin(), expect_control_data_result.end(), name);
|
|
if (iter != expect_control_data_result.end()) {
|
|
expect_control_data_result.erase(iter);
|
|
}
|
|
}
|
|
control_node_num = expect_control_data_result.size();
|
|
EXPECT_EQ(control_node_num, 0);
|
|
}
|
|
/// Runs NetOutputPass on the default test graph with Mul1:0 and Mul2:0 set as outputs and checks
/// only that the pass manager reports success.
TEST_F(UtestGraphPassesNetOutputPass, no_output_no_target_no_retval_success) {
  ge::ComputeGraphPtr compute_graph = build_graph();

  // Look up the two nodes and register output 0 of each as a graph output.
  ge::NodePtr first_output = compute_graph->FindNode("Mul1");
  ge::NodePtr second_output = compute_graph->FindNode("Mul2");
  std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes;
  output_nodes.push_back(std::make_pair(first_output, 0));
  output_nodes.push_back(std::make_pair(second_output, 0));
  compute_graph->SetGraphOutNodesInfo(output_nodes);

  ge::PassManager pass_managers;
  pass_managers.AddPass("", new (std::nothrow) NetOutputPass);
  EXPECT_EQ(pass_managers.Run(compute_graph), ge::SUCCESS);
}
|
|
|
|
/// The user specifies Mul1:0 and Mul2:0 as graph outputs.
/// Expectation: NetOutput's data inputs are Mul1 then Mul2, in that order, with no control inputs.
TEST_F(UtestGraphPassesNetOutputPass, user_out_node_success) {
  ge::ComputeGraphPtr compute_graph = build_graph();

  // Construct specified output.
  ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}, {mul2, 0}};
  compute_graph->SetGraphOutNodesInfo(output_nodes);

  ge::PassManager pass_managers;
  pass_managers.AddPass("", new (std::nothrow) NetOutputPass);
  Status status = pass_managers.Run(compute_graph);
  EXPECT_EQ(status, ge::SUCCESS);

  // Everything below dereferences the NetOutput node, so stop the test if it is missing.
  ge::NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  ASSERT_NE(net_out_node, nullptr);

  // Check data input: the names are joined in iteration order so that the order is verified too.
  std::string joined_names;
  for (const auto &input_data_node : net_out_node->GetInDataNodes()) {
    joined_names += input_data_node->GetName() + ";";
  }
  EXPECT_EQ(joined_names, "Mul1;Mul2;");

  // Check control input.
  const int control_input_count = static_cast<int>(net_out_node->GetInControlNodes().size());
  EXPECT_EQ(control_input_count, 0);
}
|
|
|
|
/// Two framework-op nodes whose original type is "_Retval" (indices 0 and 1) are fed by Mul1:0
/// and Mul2:0; no outputs or targets are configured explicitly.
/// Expectation: NetOutput's data inputs are Mul1 then Mul2, it has no control inputs, and both
/// _Retval nodes are no longer present in the graph.
TEST_F(UtestGraphPassesNetOutputPass, retval_node_for_out_success) {
  ge::ComputeGraphPtr compute_graph = build_graph();

  // Imitate the output node of _Retval issued.
  ge::OpDescPtr retval_node_desc1 = std::make_shared<ge::OpDesc>("reval_node1", FRAMEWORKOP);
  retval_node_desc1->AddInputDesc(ge::GeTensorDesc());
  (void)ge::AttrUtils::SetStr(retval_node_desc1, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_Retval");
  (void)ge::AttrUtils::SetInt(retval_node_desc1, RETVAL_ATTR_NAME_INDEX, 0);
  ge::NodePtr retval_node1 = compute_graph->AddNode(retval_node_desc1);
  // The node is dereferenced when the edges are wired below, so a null result must end the test.
  ASSERT_NE(retval_node1, nullptr);

  ge::OpDescPtr retval_node_desc2 = std::make_shared<ge::OpDesc>("reval_node2", FRAMEWORKOP);
  retval_node_desc2->AddInputDesc(ge::GeTensorDesc());
  (void)ge::AttrUtils::SetStr(retval_node_desc2, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_Retval");
  (void)ge::AttrUtils::SetInt(retval_node_desc2, RETVAL_ATTR_NAME_INDEX, 1);
  ge::NodePtr retval_node2 = compute_graph->AddNode(retval_node_desc2);
  ASSERT_NE(retval_node2, nullptr);

  // Wire Mul1:0 -> reval_node1:0 and Mul2:0 -> reval_node2:0.
  for (const auto &node : compute_graph->GetDirectNode()) {
    if (node->GetName() == "Mul1") {
      ge::GraphUtils::AddEdge(node->GetOutDataAnchor(0), retval_node1->GetInDataAnchor(0));
    } else if (node->GetName() == "Mul2") {
      ge::GraphUtils::AddEdge(node->GetOutDataAnchor(0), retval_node2->GetInDataAnchor(0));
    }
  }

  ge::PassManager pass_managers;
  pass_managers.AddPass("", new (std::nothrow) NetOutputPass);
  Status status = pass_managers.Run(compute_graph);
  EXPECT_EQ(status, ge::SUCCESS);

  // Everything below dereferences the NetOutput node, so stop the test if it is missing.
  ge::NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  ASSERT_NE(net_out_node, nullptr);

  // Check data input: the names are joined in iteration order so that the order is verified too.
  std::string joined_names;
  for (const auto &input_data_node : net_out_node->GetInDataNodes()) {
    joined_names += input_data_node->GetName() + ";";
  }
  EXPECT_EQ(joined_names, "Mul1;Mul2;");

  // Check control input.
  const int control_input_count = static_cast<int>(net_out_node->GetInControlNodes().size());
  EXPECT_EQ(control_input_count, 0);

  // Check the deletion of _Retval node.
  retval_node1 = compute_graph->FindNode("reval_node1");
  EXPECT_EQ(retval_node1, nullptr);
  retval_node2 = compute_graph->FindNode("reval_node2");
  EXPECT_EQ(retval_node2, nullptr);
}
|
|
|
|
/// A constant node ("const_output", control-dependent on Mul1) is set as the user output, and a
/// "_Retval" framework op with index 0 is fed by Mul2:0.
/// Expectation: NetOutput's data inputs are const_output then Mul2, it has no control inputs, its
/// is_input_const flags are {true, false}, and the _Retval node is no longer present in the graph.
TEST_F(UtestGraphPassesNetOutputPass, check_order_and_const_flag_success) {
  ge::ComputeGraphPtr compute_graph = build_graph();

  // Add the constant node and make it control-dependent on Mul1.
  // Both nodes are dereferenced right away, so null results must end the test.
  ge::OpDescPtr const_node_desc = std::make_shared<ge::OpDesc>("const_output", CONSTANT);
  const_node_desc->AddOutputDesc(ge::GeTensorDesc());
  ge::NodePtr const_node = compute_graph->AddNode(const_node_desc);
  ASSERT_NE(const_node, nullptr);
  ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  ASSERT_NE(mul1, nullptr);
  ge::GraphUtils::AddEdge(mul1->GetOutControlAnchor(), const_node->GetInControlAnchor());

  // Construct specified output.
  std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{const_node, 0}};
  compute_graph->SetGraphOutNodesInfo(output_nodes);

  // Imitate the output node of _Retval issued, fed by Mul2:0.
  ge::OpDescPtr retval_node_desc2 = std::make_shared<ge::OpDesc>("reval_node2", FRAMEWORKOP);
  retval_node_desc2->AddInputDesc(ge::GeTensorDesc());
  (void)ge::AttrUtils::SetStr(retval_node_desc2, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_Retval");
  (void)ge::AttrUtils::SetInt(retval_node_desc2, RETVAL_ATTR_NAME_INDEX, 0);
  ge::NodePtr retval_node2 = compute_graph->AddNode(retval_node_desc2);
  ASSERT_NE(retval_node2, nullptr);
  ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  ASSERT_NE(mul2, nullptr);
  ge::GraphUtils::AddEdge(mul2->GetOutDataAnchor(0), retval_node2->GetInDataAnchor(0));

  ge::PassManager pass_managers;
  pass_managers.AddPass("", new (std::nothrow) NetOutputPass);
  Status status = pass_managers.Run(compute_graph);
  EXPECT_EQ(status, ge::SUCCESS);

  // Everything below dereferences the NetOutput node, so stop the test if it is missing.
  ge::NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  ASSERT_NE(net_out_node, nullptr);

  // Check data input: the names are joined in iteration order so that the order is verified too.
  std::string joined_names;
  for (const auto &input_data_node : net_out_node->GetInDataNodes()) {
    joined_names += input_data_node->GetName() + ";";
  }
  EXPECT_EQ(joined_names, "const_output;Mul2;");

  // Check control input.
  const int control_input_count = static_cast<int>(net_out_node->GetInControlNodes().size());
  EXPECT_EQ(control_input_count, 0);

  // Check is_input_const flag. The size check is fatal because the elements are indexed next.
  std::vector<bool> is_input_const = net_out_node->GetOpDesc()->GetIsInputConst();
  ASSERT_EQ(is_input_const.size(), 2u);
  EXPECT_EQ(is_input_const[0], true);
  EXPECT_EQ(is_input_const[1], false);

  // Check the deletion of _Retval node.
  retval_node2 = compute_graph->FindNode("reval_node2");
  EXPECT_EQ(retval_node2, nullptr);
}
|
|
|
|
/*
|
|
TEST_F(UtestGraphPassesNetOutputPass, out_node_check_fail) {
|
|
ge::ComputeGraphPtr compute_graph = build_graph();
|
|
|
|
// Construct specified output
|
|
ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
|
|
ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
|
|
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes_invalid_name = {{nullptr, 0}, {mul2, 0}};
|
|
compute_graph->SetGraphOutNodesInfo(output_nodes_invalid_name);
|
|
|
|
ge::PassManager pass_managers;
|
|
pass_managers.AddPass("", new (std::nothrow) NetOutputPass);
|
|
Status status = pass_managers.Run(compute_graph);
|
|
EXPECT_EQ(status, ge::INTERNAL_ERROR);
|
|
NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
|
|
EXPECT_EQ(net_out_node, nullptr);
|
|
|
|
std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes_invalid_index = {{mul1, 0}, {mul2, 100}};
|
|
compute_graph->SetGraphOutNodesInfo(output_nodes_invalid_index);
|
|
|
|
status = pass_managers.Run(compute_graph);
|
|
EXPECT_EQ(status, ge::INTERNAL_ERROR);
|
|
net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
|
|
EXPECT_EQ(net_out_node, nullptr);
|
|
}
|
|
*/
|
|
|
|
// Adds two _Retval framework ops that both carry RETVAL_ATTR_NAME_INDEX == 0, wires them to
// Mul1 and Mul2, and expects NetOutputPass to return INTERNAL_ERROR without creating a
// NetOutput node.
// NOTE(review): presumably the duplicated index is what the pass rejects - confirm against
// the NetOutputPass implementation.
TEST_F(UtestGraphPassesNetOutputPass, retval_node_check_fail) {
  ge::ComputeGraphPtr compute_graph = build_graph();
  ASSERT_NE(compute_graph, nullptr);

  // Imitate the output node of _Retval issued
  ge::OpDescPtr retval_node_desc1 = std::make_shared<ge::OpDesc>("reval_node1", FRAMEWORKOP);
  retval_node_desc1->AddInputDesc(ge::GeTensorDesc());
  (void)ge::AttrUtils::SetStr(retval_node_desc1, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_Retval");
  (void)ge::AttrUtils::SetInt(retval_node_desc1, RETVAL_ATTR_NAME_INDEX, 0);
  ge::NodePtr retval_node1 = compute_graph->AddNode(retval_node_desc1);
  // Fatal assertion: the node is dereferenced when the edges are added below.
  ASSERT_NE(retval_node1, nullptr);

  ge::OpDescPtr retval_node_desc2 = std::make_shared<ge::OpDesc>("reval_node2", FRAMEWORKOP);
  retval_node_desc2->AddInputDesc(ge::GeTensorDesc());
  (void)ge::AttrUtils::SetStr(retval_node_desc2, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_Retval");
  // Same index as reval_node1 on purpose.
  (void)ge::AttrUtils::SetInt(retval_node_desc2, RETVAL_ATTR_NAME_INDEX, 0);
  ge::NodePtr retval_node2 = compute_graph->AddNode(retval_node_desc2);
  ASSERT_NE(retval_node2, nullptr);

  // Feed each _Retval node from the matching Mul node.
  for (const NodePtr &node : compute_graph->GetDirectNode()) {
    if (node->GetName() == "Mul1") {
      GraphUtils::AddEdge(node->GetOutDataAnchor(0), retval_node1->GetInDataAnchor(0));
    } else if (node->GetName() == "Mul2") {
      GraphUtils::AddEdge(node->GetOutDataAnchor(0), retval_node2->GetInDataAnchor(0));
    }
  }

  ge::PassManager pass_managers;
  pass_managers.AddPass("", new (std::nothrow) NetOutputPass);
  Status status = pass_managers.Run(compute_graph);
  EXPECT_EQ(status, ge::INTERNAL_ERROR);
  NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  EXPECT_EQ(net_out_node, nullptr);
}
|
|
|
|
// A graph that already contains a bare NetOutput node (added here with no input or output
// descs) is expected to make NetOutputPass return INTERNAL_ERROR.
TEST_F(UtestGraphPassesNetOutputPass, out_node_update_desc_check_fail) {
  ge::ComputeGraphPtr graph = build_graph();

  // Pre-insert a NetOutput node before the pass runs.
  ge::OpDescPtr net_output_desc = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
  ge::NodePtr net_output_node = graph->AddNode(net_output_desc);
  EXPECT_NE(net_output_node, nullptr);

  ge::PassManager manager;
  manager.AddPass("", new (std::nothrow) NetOutputPass);
  Status run_status = manager.Run(graph);
  EXPECT_EQ(run_status, ge::INTERNAL_ERROR);
}
|
|
|
|
// Registers Mul1 and Mul2 as graph outputs, then removes Mul1 from the graph before running
// NetOutputPass. The pass is expected to return SUCCESS.
TEST_F(UtestGraphPassesNetOutputPass, out_node_remove_check_fail) {
  ge::ComputeGraphPtr compute_graph = build_graph();
  ASSERT_NE(compute_graph, nullptr);

  // Construct specified output
  ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  // Fatal assertion: mul1 is dereferenced below to unlink its inputs.
  ASSERT_NE(mul1, nullptr);
  ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}, {mul2, 0}};
  compute_graph->SetGraphOutNodesInfo(output_nodes);

  // Detach both data inputs of Mul1 and remove the node after it was registered as an output.
  auto mul1_in_anchor0 = mul1->GetInDataAnchor(0);
  auto mul1_in_anchor1 = mul1->GetInDataAnchor(1);
  ASSERT_NE(mul1_in_anchor0, nullptr);
  ASSERT_NE(mul1_in_anchor1, nullptr);
  mul1_in_anchor0->UnlinkAll();
  mul1_in_anchor1->UnlinkAll();
  GraphUtils::RemoveNodeWithoutRelink(compute_graph, mul1);
  mul1 = compute_graph->FindNode("Mul1");
  EXPECT_EQ(mul1, nullptr);

  ge::PassManager pass_managers;
  pass_managers.AddPass("", new (std::nothrow) NetOutputPass);
  Status status = pass_managers.Run(compute_graph);
  EXPECT_EQ(status, ge::SUCCESS);
}
|
|
|
|
// Calls OpDescUtils::ClearWeights on the "Cast1" node of the graph built by
// BuildClearWeightGraph() and expects SUCCESS.
TEST_F(UtestGraphPassesNetOutputPass, clear_weight) {
  ge::ComputeGraphPtr graph = BuildClearWeightGraph();
  auto cast_node = graph->FindNode("Cast1");

  Status clear_status = ge::OpDescUtils::ClearWeights(cast_node);
  EXPECT_EQ(ge::SUCCESS, clear_status);
}
|