You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
77 lines
2.6 KiB
77 lines
2.6 KiB
/**
|
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include <gtest/gtest.h>
|
|
#include <memory>
|
|
|
|
#include "common/ge_inner_error_codes.h"
|
|
#include "common/types.h"
|
|
#include "common/util.h"
|
|
#include "graph/passes/graph_builder_utils.h"
|
|
#include "graph/utils/attr_utils.h"
|
|
#include "graph/debug/ge_attr_define.h"
|
|
|
|
#define private public
|
|
#define protected public
|
|
#include "graph/preprocess/graph_preprocess.h"
|
|
#include "ge/ge_api.h"
|
|
#undef private
|
|
#undef protected
|
|
|
|
using namespace std;
|
|
namespace ge {
|
|
class UtestGraphPreproces : public testing::Test {
|
|
protected:
|
|
void SetUp() {
|
|
}
|
|
void TearDown() {
|
|
}
|
|
};
|
|
|
|
ComputeGraphPtr BuildGraph1(){
|
|
auto builder = ut::GraphBuilder("g1");
|
|
auto data1 = builder.AddNode("data1",DATA,1,1);
|
|
auto data_opdesc = data1->GetOpDesc();
|
|
AttrUtils::SetInt(data_opdesc, ATTR_NAME_INDEX, 0);
|
|
data1->UpdateOpDesc(data_opdesc);
|
|
return builder.GetGraph();
|
|
}
|
|
|
|
TEST_F(UtestGraphPreproces, test_dynamic_input_shape_parse) {
|
|
ge::GraphPrepare graph_prepare;
|
|
graph_prepare.compute_graph_ = BuildGraph1();
|
|
// prepare user_input & graph option
|
|
ge::GeTensorDesc tensor1;
|
|
tensor1.SetFormat(ge::FORMAT_NCHW);
|
|
tensor1.SetShape(ge::GeShape({3, 12, 5, 5}));
|
|
tensor1.SetDataType(ge::DT_FLOAT);
|
|
GeTensor input1(tensor1);
|
|
std::vector<GeTensor> user_input = {input1};
|
|
std::map<string,string> graph_option = {{"ge.exec.dynamicGraphExecuteMode","dynamic_execute"},
|
|
{"ge.exec.dataInputsShapeRange","[3,1~20,2~10,5]"}};
|
|
auto ret = graph_prepare.UpdateInput(user_input, graph_option);
|
|
EXPECT_EQ(ret, ge::SUCCESS);
|
|
// check data node output shape_range and shape
|
|
auto data_node = graph_prepare.compute_graph_->FindNode("data1");
|
|
auto data_output_desc = data_node->GetOpDesc()->GetOutputDescPtr(0);
|
|
vector<int64_t> expect_shape = {3,-1,-1,5};
|
|
auto result_shape = data_output_desc->GetShape();
|
|
EXPECT_EQ(result_shape.GetDimNum(), expect_shape.size());
|
|
for(size_t i =0; i< expect_shape.size(); ++i){
|
|
EXPECT_EQ(result_shape.GetDim(i), expect_shape.at(i));
|
|
}
|
|
}
|
|
} |