modified: tests/ut/ge/CMakeLists.txt

new file:   tests/ut/ge/graph/passes/replace_with_empty_const_pass_unittest.cc
pull/1157/head
zhaoxinxin 4 years ago
parent 1c733582ef
commit ce359a3c32

@ -688,6 +688,7 @@ set(PASS_TEST_FILES
"graph/passes/no_use_reshape_remove_pass_unittest.cc"
"graph/passes/infershape_pass_unittest.cc"
"graph/passes/multi_batch_clone_pass_unittest.cc"
"graph/passes/replace_with_empty_const_pass_unittest.cc"
)
set(KERNEL_TEST_FILES

@ -0,0 +1,83 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "graph/passes/replace_with_empty_const_pass.h"
#include <gtest/gtest.h>
#include <set>
#include <string>
#include "graph_builder_utils.h"
namespace ge {
class UtestReplaceWithEmptyConstPass : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}
};
namespace {
/// data1 const1
/// \ /
/// add1
/// |
/// cast1(empty)
/// |
/// conv2d
ut::GraphBuilder Graph1Builder() {
ut::GraphBuilder builder = ut::GraphBuilder("g1");
auto data1 = builder.AddNode("data1", "Data", 0, 1);
auto const1 = builder.AddNode("const1", "Const", 0, 1);
auto add1 = builder.AddNode("add1", "Add", 2, 1);
auto cast1 = builder.AddNode("cast1", "Cast", 1, 1);
auto conv2d = builder.AddNode("conv2d", "Conv2D", 1, 0);
add1->GetOpDesc()->AddInputDesc(GeTensorDesc(GeShape({1,1,8,8}),FORMAT_NCHW));
add1->GetOpDesc()->AddInputDesc(GeTensorDesc(GeShape({1,1,8,8}),FORMAT_NCHW));
add1->GetOpDesc()->AddOutputDesc(GeTensorDesc(GeShape({1,1,8,8}),FORMAT_NCHW));
cast1->GetOpDesc()->AddOutputDesc(GeTensorDesc(GeShape({1,1,8,8}),FORMAT_NCHW));
GeTensorDesc empty_tensor(GeShape({1,0,8,8}),FORMAT_NCHW);
cast1->GetOpDesc()->UpdateOutputDesc(0,empty_tensor);
builder.AddDataEdge(data1, 0, add1, 0);
builder.AddDataEdge(const1, 0, add1, 1);
builder.AddDataEdge(add1, 0, cast1, 0);
builder.AddDataEdge(cast1, 0, conv2d, 0);
return builder;
}
} // namespace
TEST_F(UtestReplaceWithEmptyConstPass, replace_whith_empty_const_success) {
auto builder = Graph1Builder();
auto graph = builder.GetGraph();
graph->SetSessionID(0);
ReplaceWithEmptyConstPass replace_with_empty_const_pass;
EXPECT_EQ(graph->GetDirectNodesSize(),5);
// run pass on add1, graph still has 5 nodes
auto add1 = graph->FindNode("add1");
Status ret = replace_with_empty_const_pass.Run(add1);
EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ(graph->GetDirectNodesSize(),5);
auto cast1 = graph->FindNode("cast1");
ret = replace_with_empty_const_pass.Run(cast1)
EXPECT_EQ(cast1->GetOutAllNodes().size(),0);
auto conv2d = graph->FindNode("conv2d");
EXPECT_EQ(conv2d->GetInDataNodes().at(0)->GetType(),"Const");
}
} // namespace ge
Loading…
Cancel
Save