You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/tests/ut/ge/graph/passes/pass_utils_unittest.cc

178 lines
7.1 KiB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "graph/passes/pass_utils.h"
#include <gtest/gtest.h>
#include <vector>
#include "common/types.h"
#include "graph/types.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph_builder_utils.h"
#include "inc/kernel.h"
#include "inc/kernel_factory.h"
using namespace ge;
class UtestGraphPassesPassUtils : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}
};
class NodeBuilder {
public:
NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared<OpDesc>(name, type); }
NodeBuilder &AddInputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
ge::DataType data_type = DT_FLOAT) {
op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
return *this;
}
NodeBuilder &AddOutputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
ge::DataType data_type = DT_FLOAT) {
op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
return *this;
}
ge::NodePtr Build(const ge::ComputeGraphPtr &graph) { return graph->AddNode(op_desc_); }
private:
ge::GeTensorDescPtr CreateTensorDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
ge::DataType data_type = DT_FLOAT) {
GeShape ge_shape{std::vector<int64_t>(shape)};
ge::GeTensorDescPtr tensor_desc = std::make_shared<ge::GeTensorDesc>();
tensor_desc->SetShape(ge_shape);
tensor_desc->SetFormat(format);
tensor_desc->SetDataType(data_type);
return tensor_desc;
}
ge::OpDescPtr op_desc_;
};
// Builds the graph
//   data -(ctrl)-> const -> node_relu1 -> node_sinh
//                           node_relu1 -(data->ctrl)-> node_relu2 -> node_sinh2
// attaches a one-element INT32 weight to the Constant node, and expects
// PassUtils::SetOutNodeWeight to return SUCCESS for every node that has a data successor.
TEST_F(UtestGraphPassesPassUtils, set_out_node_weight) {
  ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  // data
  ge::NodePtr node_data = NodeBuilder("data", DATA).AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT).Build(graph);
  // const
  ge::NodePtr node_const =
      NodeBuilder("const", CONSTANT).AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT).Build(graph);
  // relu
  ge::NodePtr node_relu = NodeBuilder("node_relu1", RELU)
                              .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .Build(graph);
  // sinh
  ge::NodePtr node_sinh = NodeBuilder("node_sinh", SINH)
                              .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .Build(graph);
  // relu
  ge::NodePtr node_relu2 = NodeBuilder("node_relu2", RELU)
                               .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                               .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                               .Build(graph);
  // sinh
  ge::NodePtr node_sinh2 = NodeBuilder("node_sinh2", SINH)
                               .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                               .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                               .Build(graph);
  // Add edges. Each result is checked so the test cannot pass on a partially wired graph.
  ASSERT_EQ(GRAPH_SUCCESS,
            ge::GraphUtils::AddEdge(node_data->GetOutControlAnchor(), node_const->GetInControlAnchor()));
  ASSERT_EQ(GRAPH_SUCCESS,
            ge::GraphUtils::AddEdge(node_const->GetOutDataAnchor(0), node_relu->GetInDataAnchor(0)));
  ASSERT_EQ(GRAPH_SUCCESS,
            ge::GraphUtils::AddEdge(node_relu->GetOutDataAnchor(0), node_sinh->GetInDataAnchor(0)));
  ASSERT_EQ(GRAPH_SUCCESS,
            ge::GraphUtils::AddEdge(node_relu->GetOutDataAnchor(0), node_relu2->GetInControlAnchor()));
  ASSERT_EQ(GRAPH_SUCCESS,
            ge::GraphUtils::AddEdge(node_relu2->GetOutDataAnchor(0), node_sinh2->GetInDataAnchor(0)));

  for (auto node : graph->GetDirectNode()) {
    if (node->GetType() == CONSTANT) {
      int32_t weight[] = {1};
      GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32);
      GeTensorPtr tensor =
          std::make_shared<GeTensor>(weight_desc, reinterpret_cast<uint8_t *>(weight), sizeof(weight));
      std::vector<GeTensorPtr> tensor_vec = {tensor};
      OpDescUtils::SetWeights(node, tensor_vec);
    }
    if (!node->GetOutDataNodes().empty()) {
      // NOTE(review): this passes the first data successor's own output anchor, not
      // node->GetOutDataAnchor(0). That is kept as-is to preserve the original coverage;
      // confirm whether the node's own anchor was intended.
      auto out_data_anchor = node->GetOutDataNodes().at(0)->GetOutDataAnchor(0);
      Status status = PassUtils::SetOutNodeWeight(out_data_anchor, node);
      EXPECT_EQ(SUCCESS, status);
    }
  }
}
// Only some failure cases, for coverage checks.
// A null node must not be reported as a constant.
TEST_F(UtestGraphPassesPassUtils, is_constant_null) {
  ge::NodePtr null_node;
  EXPECT_FALSE(PassUtils::IsConstant(null_node));
}
// GetInDataNode yields nullptr both for a null node and for an input index the node does not have.
TEST_F(UtestGraphPassesPassUtils, get_in_data_node_fail) {
  ge::NodePtr null_node;
  EXPECT_EQ(nullptr, PassUtils::GetInDataNode(null_node, 0));

  ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  // Single-input relu with no incoming edges; input index 1 is out of range.
  ge::NodePtr relu = NodeBuilder("relu", RELU)
                         .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                         .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                         .Build(graph);
  EXPECT_EQ(nullptr, PassUtils::GetInDataNode(relu, 1));
}
// GetUniqueInDataAnchorIndex returns -1 for a null node and for a node whose input is unconnected.
TEST_F(UtestGraphPassesPassUtils, get_unique_in_data_anchor_index_failed) {
  const int kInvalidIndex = -1;

  ge::NodePtr null_node;
  EXPECT_EQ(kInvalidIndex, PassUtils::GetUniqueInDataAnchorIndex(null_node));

  ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  // Single-input relu with no incoming edges.
  ge::NodePtr relu = NodeBuilder("relu", RELU)
                         .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                         .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                         .Build(graph);
  EXPECT_EQ(kInvalidIndex, PassUtils::GetUniqueInDataAnchorIndex(relu));
}
// Exercises UnlinkNodeWithControlCopy on a single-input relu that has no incoming edges.
TEST_F(UtestGraphPassesPassUtils, unlink_node_with_ctrl_copy_fail) {
  ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  ge::NodePtr relu = NodeBuilder("relu", RELU)
                         .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                         .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                         .Build(graph);
  // Index 1 has no input desc on this node; the expected status here is SUCCESS.
  EXPECT_EQ(ge::SUCCESS, PassUtils::UnlinkNodeWithControlCopy(relu, 1));
  // Index 0 exists but nothing is linked to it; the expected status here is FAILED.
  EXPECT_EQ(ge::FAILED, PassUtils::UnlinkNodeWithControlCopy(relu, 0));
}
// RemoveInactiveBranchToMerge must not report success when given a null input.
TEST_F(UtestGraphPassesPassUtils, null_input) {
  std::vector<NodePtr> removed;
  std::vector<NodePtr> terminals;
  auto result = PassUtils::RemoveInactiveBranchToMerge(nullptr, removed, terminals);
  EXPECT_NE(SUCCESS, result);
}