/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include <unordered_map>
#include <vector>
#define protected public
#define private public
#include "graph/passes/guarantee_const_pass.h"
#include "../ops_stub.h"
#include "common/ge_inner_error_codes.h"
#include "common/types.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"
#include "inc/pass_manager.h"
#undef protected
#undef private
using namespace testing;
using namespace ge;
using namespace std;
// To check whether the shape of output is correct or not
#define TEST_OPERATOR(op_, input_shapes, output_shapes)                                                 \
  {                                                                                                     \
    auto op = op_;                                                                                      \
    for (auto input_pair : input_shapes) SetInputShape(op, input_pair.first, input_pair.second);        \
    op.InferShapeAndType();                                                                             \
    for (auto output_pair : output_shapes) CheckOutputShape(op, output_pair.first, output_pair.second); \
  }

#define LOOP_VEC(v) for (unsigned i = 0; i < v.size(); i++)
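
// Usage sketch for the macro above (illustrative comments only, not part of the original suite):
// build the name -> shape maps first so the macro arguments contain no bare commas, then invoke it.
// The operator name "gc" below is just a placeholder.
//   auto in_shapes  = unordered_map<string, vector<int64_t>>({{"x", {2, 3}}});
//   auto out_shapes = unordered_map<string, vector<int64_t>>({{"y", {2, 3}}});
//   TEST_OPERATOR(op::GuaranteeConst("gc"), in_shapes, out_shapes);
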
class UtestGraphPassesGuaranteeConstPass : public testing::Test {
 protected:
  void SetUp() { init(); }

  void TearDown() { destroy(); }

 private:
  void init() { guarantee_const_op_remove_pass_ = new ::ge::GuaranteeConstPass(); }

  void destroy() {
    delete guarantee_const_op_remove_pass_;
    guarantee_const_op_remove_pass_ = NULL;
  }

 protected:
  ge::GuaranteeConstPass *guarantee_const_op_remove_pass_;

  void SetInputShape(Operator op, string name, vector<int64_t> shape) {
    TensorDesc td = op.GetInputDesc(name);
    td.SetShape(ge::Shape(shape));
    op.UpdateInputDesc(name, td);
  }

  void CheckOutputShape(Operator op, string name, vector<int64_t> shape) {
    ge::Shape s = op.GetOutputDesc(name).GetShape();
    EXPECT_EQ(s.GetDims().size(), shape.size());
    LOOP_VEC(shape) EXPECT_EQ(s.GetDim(i), shape[i]);
  }

  /// Init the node which will be passed into the graph. isMultiInput indicates whether more than
  /// one input data anchor is used.
  NodePtr init_node(ComputeGraphPtr graph, vector<int64_t> dims_vec, vector<int32_t> data_vec, bool isMultiInput,
                    string type) {
    // middle node plus its input and output neighbours
    OpDescPtr op_def = std::make_shared<OpDesc>("op_def", type);
    OpDescPtr in_op_def = std::make_shared<OpDesc>("op_def_in", "test");
    OpDescPtr out_op_def = std::make_shared<OpDesc>("op_def_out", "test");
    OpDescPtr another_in_op_def = std::make_shared<OpDesc>("another_op_def_in", "test");

    // whether another input data anchor is used
    if (isMultiInput) {
      vector<bool> is_input_const_vec = {true, true};
      op_def->SetIsInputConst(is_input_const_vec);
      AttrUtils::SetInt(op_def, ge::ATTR_NAME_T, (int64_t)DT_INT32);
    }

    // input tensor
    GeTensorDesc tensor_desc(GeShape(dims_vec), FORMAT_NCHW, DT_INT32);
    ge::ConstGeTensorPtr const_tensor =
        std::make_shared<GeTensor>(tensor_desc, (uint8_t *)&data_vec[0], data_vec.size() * sizeof(int32_t));
    ge::AttrUtils::SetTensor(in_op_def, ge::ATTR_NAME_WEIGHTS, const_tensor);
    op_def->AddInputDesc(tensor_desc);

    // add a second input tensor and its producer when another input data anchor is needed
    if (isMultiInput) {
      vector<int64_t> dims_vec_another = {6};
      vector<int32_t> data_vec_another = {1, 2, 3, 4, 5, 6};
      GeTensorDesc another_tensor_desc(GeShape(dims_vec_another), FORMAT_NCHW, DT_INT32);
      ge::ConstGeTensorPtr const_tensor_another = std::make_shared<GeTensor>(
          another_tensor_desc, (uint8_t *)&data_vec_another[0], data_vec_another.size() * sizeof(int32_t));
      ge::AttrUtils::SetTensor(another_in_op_def, ge::ATTR_NAME_WEIGHTS, const_tensor_another);
      op_def->AddInputDesc(another_tensor_desc);
      another_in_op_def->AddOutputDesc(another_tensor_desc);
      out_op_def->AddInputDesc(another_tensor_desc);
    }

    GeTensorDesc tensor_desc_out(GeShape(dims_vec), FORMAT_NCHW, DT_INT32);
    op_def->AddOutputDesc(tensor_desc_out);
    in_op_def->AddOutputDesc(tensor_desc);

    // add attr of out_node
    vector<bool> is_output_const(3, false);
    is_output_const[0] = true;
    out_op_def->SetIsInputConst(is_output_const);
    out_op_def->AddInputDesc(tensor_desc);

    // Add node
    NodePtr in_node = graph->AddNode(in_op_def);
    NodePtr node = graph->AddNode(op_def);
    NodePtr out_node = graph->AddNode(out_op_def);

    // Add edge
    GraphUtils::AddEdge(in_node->GetOutDataAnchor(0), node->GetInDataAnchor(0));
    GraphUtils::AddEdge(node->GetOutDataAnchor(0), out_node->GetInDataAnchor(0));

    // when multiple input nodes are needed (to verify the isolate-node handling)
    if (isMultiInput) {
      NodePtr another_in_node = graph->AddNode(another_in_op_def);
      GraphUtils::AddEdge(another_in_node->GetOutDataAnchor(0), node->GetInDataAnchor(1));
    }
    return node;
  }
};
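
// Descriptive note (added): each test below builds a small graph via init_node,
// op_def_in -> op_def (the node under test) -> op_def_out, optionally with a second producer
// another_op_def_in feeding op_def's second input when isMultiInput is true.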

TEST_F(UtestGraphPassesGuaranteeConstPass, not_changed) {
  // the original type of the op is not guarantee_const
  string type = SIZE;
  // input tensor
  vector<int64_t> dims_vec = {6};
  vector<int32_t> data_vec = {1, 2, 3, 4, 5, 6};
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  NodePtr node = init_node(graph, dims_vec, data_vec, false, type);

  ge::Status ret = guarantee_const_op_remove_pass_->Run(node);
  EXPECT_EQ(SUCCESS, ret);
}

TEST_F(UtestGraphPassesGuaranteeConstPass, get_original_type_fail) {
  string type = GUARANTEECONST;
  // input tensor
  vector<int64_t> dims_vec = {6};
  vector<int32_t> data_vec = {1, 2, 3, 4, 5, 6};
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  NodePtr node = init_node(graph, dims_vec, data_vec, false, type);

  // change the type so that getting the original type fails
  string type2 = "FrameworkOp";
  node->GetOpDesc()->SetType(type2);

  ge::Status ret = guarantee_const_op_remove_pass_->Run(node);
}

TEST_F(UtestGraphPassesGuaranteeConstPass, int32_success_6) {
  // input tensor
  string type = GUARANTEECONST;
  vector<int64_t> dims_vec = {6};
  vector<int32_t> data_vec = {1, 2, 3, 4, 5, 6};
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  NodePtr node = init_node(graph, dims_vec, data_vec, false, type);

  // when the input tensor is [1, 2, 3, 4, 5, 6], the pass returns success
  ge::Status output = guarantee_const_op_remove_pass_->Run(node);
  EXPECT_EQ(ge::SUCCESS, output);
}

TEST_F(UtestGraphPassesGuaranteeConstPass, int32_success_2_3) {
  // input tensor
  string type = GUARANTEECONST;
  vector<int64_t> dims_vec = {2, 3};
  vector<int32_t> data_vec = {1, 2, 3, 4, 5, 6};
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  NodePtr node = init_node(graph, dims_vec, data_vec, false, type);

  // when the input tensor is [[1, 2, 3], [4, 5, 6]], the pass returns success
  ge::Status output = guarantee_const_op_remove_pass_->Run(node);
  EXPECT_EQ(ge::SUCCESS, output);
}

TEST_F(UtestGraphPassesGuaranteeConstPass, isolate_node_failed) {
  // input tensor
  string type = GUARANTEECONST;
  vector<int64_t> dims_vec = {2, 3};
  vector<int32_t> data_vec = {1, 2, 3, 4, 5, 6};
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  // add another input node
  NodePtr node = init_node(graph, dims_vec, data_vec, true, type);

  // when there is more than one input anchor, the pass returns PARAM_INVALID
  ge::Status output = guarantee_const_op_remove_pass_->Run(node);
  EXPECT_EQ(ge::PARAM_INVALID, output);
}

// IR test: the shape and data type of the output should equal those of the input
TEST_F(UtestGraphPassesGuaranteeConstPass, ir_infer_shape) {
  auto input = unordered_map<string, vector<int64_t>>({
      {"x", {3, 5, 3, 4}},
  });
  auto output = unordered_map<string, vector<int64_t>>({
      {"y", {3, 5, 3, 4}},
  });
  auto guaranteeConst = op::GuaranteeConst("guaranteeconst");

  TEST_OPERATOR(guaranteeConst, input, output);
}
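
// A hedged follow-up sketch (not in the original suite): after a successful run the pass is expected
// to isolate and remove the GuaranteeConst node, so looking it up by its name ("op_def", as created
// in init_node) should yield nullptr. This assumes ComputeGraph::FindNode returns nullptr for a node
// that is no longer in the graph.
TEST_F(UtestGraphPassesGuaranteeConstPass, int32_success_node_removed) {
  string type = GUARANTEECONST;
  vector<int64_t> dims_vec = {6};
  vector<int32_t> data_vec = {1, 2, 3, 4, 5, 6};
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  NodePtr node = init_node(graph, dims_vec, data_vec, false, type);

  ge::Status output = guarantee_const_op_remove_pass_->Run(node);
  EXPECT_EQ(ge::SUCCESS, output);
  // The removed node should no longer be found in the compute graph.
  EXPECT_TRUE(graph->FindNode("op_def") == nullptr);
}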