You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc

77 lines
2.6 KiB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <gtest/gtest.h>
#include <memory>
#include "common/ge_inner_error_codes.h"
#include "common/types.h"
#include "common/util.h"
#include "graph/passes/graph_builder_utils.h"
#include "graph/utils/attr_utils.h"
#include "graph/debug/ge_attr_define.h"
#define private public
#define protected public
#include "graph/preprocess/graph_preprocess.h"
#include "ge/ge_api.h"
#undef private
#undef protected
using namespace std;
namespace ge {
class UtestGraphPreproces : public testing::Test {
protected:
void SetUp() {
}
void TearDown() {
}
};
ComputeGraphPtr BuildGraph1(){
auto builder = ut::GraphBuilder("g1");
auto data1 = builder.AddNode("data1",DATA,1,1);
auto data_opdesc = data1->GetOpDesc();
AttrUtils::SetInt(data_opdesc, ATTR_NAME_INDEX, 0);
data1->UpdateOpDesc(data_opdesc);
return builder.GetGraph();
}
TEST_F(UtestGraphPreproces, test_dynamic_input_shape_parse) {
ge::GraphPrepare graph_prepare;
graph_prepare.compute_graph_ = BuildGraph1();
// prepare user_input & graph option
ge::GeTensorDesc tensor1;
tensor1.SetFormat(ge::FORMAT_NCHW);
tensor1.SetShape(ge::GeShape({3, 12, 5, 5}));
tensor1.SetDataType(ge::DT_FLOAT);
GeTensor input1(tensor1);
std::vector<GeTensor> user_input = {input1};
std::map<string,string> graph_option = {{"ge.exec.dynamicGraphExecuteMode","dynamic_execute"},
{"ge.exec.dataInputsShapeRange","[3,1~20,2~10,5]"}};
auto ret = graph_prepare.UpdateInput(user_input, graph_option);
EXPECT_EQ(ret, ge::SUCCESS);
// check data node output shape_range and shape
auto data_node = graph_prepare.compute_graph_->FindNode("data1");
auto data_output_desc = data_node->GetOpDesc()->GetOutputDescPtr(0);
vector<int64_t> expect_shape = {3,-1,-1,5};
auto result_shape = data_output_desc->GetShape();
EXPECT_EQ(result_shape.GetDimNum(), expect_shape.size());
for(size_t i =0; i< expect_shape.size(); ++i){
EXPECT_EQ(result_shape.GetDim(i), expect_shape.at(i));
}
}
}