You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
178 lines
7.1 KiB
178 lines
7.1 KiB
/**
|
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "graph/passes/pass_utils.h"
|
|
|
|
#include <gtest/gtest.h>
|
|
#include <vector>
|
|
|
|
#include "common/types.h"
|
|
#include "graph/types.h"
|
|
#include "graph/utils/graph_utils.h"
|
|
#include "graph/utils/op_desc_utils.h"
|
|
#include "graph_builder_utils.h"
|
|
#include "inc/kernel.h"
|
|
#include "inc/kernel_factory.h"
|
|
|
|
using namespace ge;
|
|
|
|
class UtestGraphPassesPassUtils : public testing::Test {
|
|
protected:
|
|
void SetUp() {}
|
|
|
|
void TearDown() {}
|
|
};
|
|
|
|
class NodeBuilder {
|
|
public:
|
|
NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared<OpDesc>(name, type); }
|
|
|
|
NodeBuilder &AddInputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
|
|
ge::DataType data_type = DT_FLOAT) {
|
|
op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
|
|
return *this;
|
|
}
|
|
|
|
NodeBuilder &AddOutputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
|
|
ge::DataType data_type = DT_FLOAT) {
|
|
op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
|
|
return *this;
|
|
}
|
|
|
|
ge::NodePtr Build(const ge::ComputeGraphPtr &graph) { return graph->AddNode(op_desc_); }
|
|
|
|
private:
|
|
ge::GeTensorDescPtr CreateTensorDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
|
|
ge::DataType data_type = DT_FLOAT) {
|
|
GeShape ge_shape{std::vector<int64_t>(shape)};
|
|
ge::GeTensorDescPtr tensor_desc = std::make_shared<ge::GeTensorDesc>();
|
|
tensor_desc->SetShape(ge_shape);
|
|
tensor_desc->SetFormat(format);
|
|
tensor_desc->SetDataType(data_type);
|
|
return tensor_desc;
|
|
}
|
|
|
|
ge::OpDescPtr op_desc_;
|
|
};
|
|
|
|
TEST_F(UtestGraphPassesPassUtils, set_out_node_weight) {
|
|
ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
|
|
// data
|
|
ge::NodePtr node_data = NodeBuilder("data", DATA).AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT).Build(graph);
|
|
// const
|
|
ge::NodePtr node_const =
|
|
NodeBuilder("const", CONSTANT).AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT).Build(graph);
|
|
// relu
|
|
ge::NodePtr node_relu = NodeBuilder("node_relu1", RELU)
|
|
.AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
|
|
.AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
|
|
.Build(graph);
|
|
// sinh
|
|
ge::NodePtr node_sinh = NodeBuilder("node_sinh", SINH)
|
|
.AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
|
|
.AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
|
|
.Build(graph);
|
|
// relu
|
|
ge::NodePtr node_relu2 = NodeBuilder("node_relu2", RELU)
|
|
.AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
|
|
.AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
|
|
.Build(graph);
|
|
// sinh
|
|
ge::NodePtr node_sinh2 = NodeBuilder("node_sinh2", SINH)
|
|
.AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
|
|
.AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
|
|
.Build(graph);
|
|
|
|
// add edge
|
|
ge::GraphUtils::AddEdge(node_data->GetOutControlAnchor(), node_const->GetInControlAnchor());
|
|
ge::GraphUtils::AddEdge(node_const->GetOutDataAnchor(0), node_relu->GetInDataAnchor(0));
|
|
ge::GraphUtils::AddEdge(node_relu->GetOutDataAnchor(0), node_sinh->GetInDataAnchor(0));
|
|
ge::GraphUtils::AddEdge(node_relu->GetOutDataAnchor(0), node_relu2->GetInControlAnchor());
|
|
ge::GraphUtils::AddEdge(node_relu2->GetOutDataAnchor(0), node_sinh2->GetInDataAnchor(0));
|
|
|
|
for (auto node : graph->GetDirectNode()) {
|
|
if (node->GetType() == CONSTANT) {
|
|
int32_t weight[] = {1};
|
|
GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32);
|
|
GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
|
|
vector<GeTensorPtr> tensor_vec = {tensor};
|
|
OpDescUtils::SetWeights(node, tensor_vec);
|
|
}
|
|
if (!node->GetOutDataNodes().empty()) {
|
|
auto out_data_anchor = node->GetOutDataNodes().at(0)->GetOutDataAnchor(0);
|
|
Status status = PassUtils::SetOutNodeWeight(out_data_anchor, node);
|
|
EXPECT_EQ(SUCCESS, status);
|
|
}
|
|
}
|
|
}
|
|
|
|
// only some failure castes for coverage check
|
|
TEST_F(UtestGraphPassesPassUtils, is_constant_null) {
|
|
ge::NodePtr node = nullptr;
|
|
bool ret = PassUtils::IsConstant(node);
|
|
EXPECT_EQ(false, ret);
|
|
}
|
|
|
|
TEST_F(UtestGraphPassesPassUtils, get_in_data_node_fail) {
|
|
ge::NodePtr node = nullptr;
|
|
NodePtr in_data_node = PassUtils::GetInDataNode(node, 0);
|
|
EXPECT_EQ(nullptr, in_data_node);
|
|
|
|
ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
|
|
// relu
|
|
ge::NodePtr node_relu = NodeBuilder("relu", RELU)
|
|
.AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
|
|
.AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
|
|
.Build(graph);
|
|
NodePtr data_node = PassUtils::GetInDataNode(node_relu, 1);
|
|
EXPECT_EQ(nullptr, data_node);
|
|
}
|
|
|
|
TEST_F(UtestGraphPassesPassUtils, get_unique_in_data_anchor_index_failed) {
|
|
int invalid_index = -1;
|
|
ge::NodePtr node = nullptr;
|
|
int status = PassUtils::GetUniqueInDataAnchorIndex(node);
|
|
EXPECT_EQ(invalid_index, status);
|
|
|
|
ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
|
|
// relu
|
|
ge::NodePtr node_relu = NodeBuilder("relu", RELU)
|
|
.AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
|
|
.AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
|
|
.Build(graph);
|
|
int ret = PassUtils::GetUniqueInDataAnchorIndex(node_relu);
|
|
EXPECT_EQ(invalid_index, ret);
|
|
}
|
|
|
|
TEST_F(UtestGraphPassesPassUtils, unlink_node_with_ctrl_copy_fail) {
|
|
ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
|
|
// relu
|
|
ge::NodePtr node_relu = NodeBuilder("relu", RELU)
|
|
.AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
|
|
.AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
|
|
.Build(graph);
|
|
Status status = PassUtils::UnlinkNodeWithControlCopy(node_relu, 1);
|
|
EXPECT_EQ(ge::SUCCESS, status);
|
|
Status ret = PassUtils::UnlinkNodeWithControlCopy(node_relu, 0);
|
|
EXPECT_EQ(ge::FAILED, ret);
|
|
}
|
|
|
|
TEST_F(UtestGraphPassesPassUtils, null_input) {
|
|
std::vector<NodePtr> deleted_nodes;
|
|
std::vector<NodePtr> end_nodes;
|
|
EXPECT_NE(PassUtils::RemoveInactiveBranchToMerge(nullptr, deleted_nodes, end_nodes), 0);
|
|
}
|