parent
169d1efff3
commit
6f10a03c59
@ -0,0 +1,119 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "graph/passes/fuse_data_nodes_with_common_input_pass.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <set>
|
||||
#include "common/ge_inner_error_codes.h"
|
||||
#include "graph/utils/op_desc_utils.h"
|
||||
#include "graph/utils/type_utils.h"
|
||||
#include "graph/utils/node_utils.h"
|
||||
|
||||
using std::map;
|
||||
using std::vector;
|
||||
using std::set;
|
||||
using std::string;
|
||||
|
||||
namespace ge {
|
||||
Status FuseDataNodesWithCommonInputPass::Run(ge::ComputeGraphPtr graph) {
|
||||
if (graph == nullptr) {
|
||||
GELOGE(GE_GRAPH_PARAM_NULLPTR, "Compute graph is null.");
|
||||
return GE_GRAPH_PARAM_NULLPTR;
|
||||
}
|
||||
GELOGD("FuseDataNodesWithCommonInputPass in.");
|
||||
// key: subgraph, value:--key: peer out anchor to parent node, --value: parent indexes to parent node
|
||||
map<ComputeGraphPtr, map<OutDataAnchorPtr, set<uint32_t>>> subgraphs_to_need_fuse_nodes_info;
|
||||
if (InitNeedFuseNodesInfo(graph, subgraphs_to_need_fuse_nodes_info) != SUCCESS) {
|
||||
GELOGE(FAILED, "InitNeedFuseNodesInfo failed.");
|
||||
return FAILED;
|
||||
}
|
||||
return FuseDataNodes(subgraphs_to_need_fuse_nodes_info);
|
||||
}
|
||||
|
||||
Status FuseDataNodesWithCommonInputPass::InitNeedFuseNodesInfo(ComputeGraphPtr &graph,
|
||||
map<ComputeGraphPtr, map<OutDataAnchorPtr, set<uint32_t>>> &subgraphs_to_need_fuse_nodes_info) {
|
||||
for (const auto &subgraph : graph->GetAllSubgraphs()) {
|
||||
GE_CHECK_NOTNULL(subgraph);
|
||||
auto parent_node = subgraph->GetParentNode();
|
||||
GE_CHECK_NOTNULL(parent_node);
|
||||
if (parent_node->GetType() == CASE || parent_node->GetType() == IF) {
|
||||
auto &peer_out_anchors_to_parent_indexes = subgraphs_to_need_fuse_nodes_info[subgraph];
|
||||
for (const auto &in_data_anchor : parent_node->GetAllInDataAnchors()) {
|
||||
GE_CHECK_NOTNULL(in_data_anchor);
|
||||
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
|
||||
uint32_t parent_index = static_cast<uint32_t>(in_data_anchor->GetIdx());
|
||||
GE_CHECK_NOTNULL(peer_out_anchor);
|
||||
peer_out_anchors_to_parent_indexes[peer_out_anchor].insert(parent_index);
|
||||
GELOGD("Peer node %s is the %d input of parent node %s in %s.",
|
||||
peer_out_anchor->GetOwnerNode()->GetName().c_str(), parent_index, parent_node->GetName().c_str(),
|
||||
subgraph->GetName().c_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
Status FuseDataNodesWithCommonInputPass::FuseDataNodes(
|
||||
const map<ComputeGraphPtr, map<OutDataAnchorPtr, set<uint32_t>>> &subgraphs_to_need_fuse_nodes_info) {
|
||||
for (const auto &subgraph_to_need_fuse_nodes_info : subgraphs_to_need_fuse_nodes_info) {
|
||||
auto subgraph = subgraph_to_need_fuse_nodes_info.first;
|
||||
for (const auto &peer_out_anchors_to_parent_indexes : subgraph_to_need_fuse_nodes_info.second) {
|
||||
if (peer_out_anchors_to_parent_indexes.second.size() <= 1) {
|
||||
continue;
|
||||
}
|
||||
// key: out anchor, value: data nodes with common input will be fused
|
||||
map<OutDataAnchorPtr, vector<NodePtr>> peer_out_anchors_to_need_fuse_nodes;
|
||||
for (const auto &node : subgraph->GetDirectNode()) {
|
||||
if (node->GetType() != DATA) {
|
||||
continue;
|
||||
}
|
||||
GE_CHECK_NOTNULL(node->GetOpDesc());
|
||||
uint32_t parent_index = 0;
|
||||
if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
|
||||
if (peer_out_anchors_to_parent_indexes.second.count(parent_index) > 0) {
|
||||
peer_out_anchors_to_need_fuse_nodes[peer_out_anchors_to_parent_indexes.first].emplace_back(node);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (const auto &peer_out_anchor_to_need_fuse_nodes : peer_out_anchors_to_need_fuse_nodes) {
|
||||
auto need_fuse_data_nodes = peer_out_anchor_to_need_fuse_nodes.second;
|
||||
auto first_node = need_fuse_data_nodes.at(0);
|
||||
for (size_t i = 1; i < need_fuse_data_nodes.size(); ++i) {
|
||||
auto node = need_fuse_data_nodes.at(i);
|
||||
GELOGI("Replace redundant data node %s by %s exist in graph: %s.", node->GetName().c_str(),
|
||||
first_node->GetName().c_str(), subgraph->GetName().c_str());
|
||||
// the data node which can be fused has none input(both data and control in)
|
||||
if (GraphUtils::MoveOutCtrlEdges(node, first_node) != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
if (GraphUtils::ReplaceNodeDataAnchors(first_node, node, {}, {0}) != SUCCESS) {
|
||||
return FAILED;
|
||||
}
|
||||
if (GraphUtils::RemoveNodeWithoutRelink(subgraph, node) != SUCCESS) {
|
||||
GELOGE(FAILED, "[%s] RemoveNodeWithoutRelink failed.", node->GetName().c_str());
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
} // namespace ge
|
@ -0,0 +1,38 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef GE_GRAPH_PASSES_FUSE_DATA_NODES_WITH_COMMON_INPUT_PASS_H_
|
||||
#define GE_GRAPH_PASSES_FUSE_DATA_NODES_WITH_COMMON_INPUT_PASS_H_
|
||||
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "graph/types.h"
|
||||
#include "inc/graph_pass.h"
|
||||
|
||||
namespace ge {
|
||||
class FuseDataNodesWithCommonInputPass : public GraphPass {
|
||||
public:
|
||||
Status Run(ge::ComputeGraphPtr graph) override;
|
||||
|
||||
private:
|
||||
Status InitNeedFuseNodesInfo(ComputeGraphPtr &graph,
|
||||
map<ComputeGraphPtr, map<OutDataAnchorPtr, std::set<uint32_t>>> &subgraphs_to_need_fuse_nodes_info);
|
||||
Status FuseDataNodes(
|
||||
const map<ComputeGraphPtr, map<OutDataAnchorPtr, std::set<uint32_t>>> &subgraphs_to_need_fuse_nodes_info);
|
||||
};
|
||||
} // namespace ge
|
||||
#endif // GE_GRAPH_PASSES_FUSE_DATA_NODES_WITH_COMMON_INPUT_PASS_H_
|
@ -0,0 +1,36 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "graph/passes/no_data_out_const_elimination_pass.h"
|
||||
|
||||
namespace ge {
|
||||
Status NoDataOutConstEliminationPass::Run(NodePtr &node) {
|
||||
GE_CHECK_NOTNULL(node);
|
||||
GELOGD("RemoveConstWithoutDataPass running of %s.", node->GetName().c_str());
|
||||
if (node->GetType() == CONSTANT || node->GetType() == CONSTANTOP) {
|
||||
GE_CHECK_NOTNULL(node->GetOpDesc());
|
||||
// delete const which has no input and no output of data
|
||||
if (node->GetOpDesc()->GetInputsSize() == 0 && node->GetOutDataNodes().size() == 0) {
|
||||
GELOGI("Remove const %s.", node->GetName().c_str());
|
||||
if (IsolateAndDeleteNode(node, {}) != SUCCESS) {
|
||||
GELOGE(FAILED, "IsolateAndDeleteNode %s failed.", node->GetName().c_str());
|
||||
return FAILED;
|
||||
}
|
||||
}
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
} // namespace ge
|
@ -0,0 +1,31 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef GE_GRAPH_PASSES_REMOVE_CONST_WITHOUT_DATA_PASS_H_
|
||||
#define GE_GRAPH_PASSES_REMOVE_CONST_WITHOUT_DATA_PASS_H_
|
||||
|
||||
#include "graph/passes/base_pass.h"
|
||||
#include "framework/common/debug/ge_log.h"
|
||||
#include "framework/common/util.h"
|
||||
|
||||
namespace ge {
|
||||
class NoDataOutConstEliminationPass : public BaseNodePass {
|
||||
public:
|
||||
Status Run(ge::NodePtr &node) override;
|
||||
};
|
||||
} // namespace ge
|
||||
|
||||
#endif // GE_GRAPH_PASSES_REMOVE_CONST_WITHOUT_DATA_PASS_H_
|
@ -0,0 +1,156 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "graph/passes/fuse_data_nodes_with_common_input_pass.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
#include "inc/pass_manager.h"
|
||||
#include "common/ge_inner_error_codes.h"
|
||||
#include "graph_builder_utils.h"
|
||||
#include "graph/utils/tensor_utils.h"
|
||||
#include "graph/utils/op_desc_utils.h"
|
||||
#include "graph/utils/type_utils.h"
|
||||
#include "graph/utils/node_utils.h"
|
||||
|
||||
namespace ge {
|
||||
|
||||
class UtestFuseDataNodesWithCommonInputPass : public testing::Test {
|
||||
protected:
|
||||
void SetUp() {}
|
||||
void TearDown() {}
|
||||
|
||||
public:
|
||||
NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) {
|
||||
GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT);
|
||||
auto op_desc = std::make_shared<OpDesc>(name, type);
|
||||
for (auto i = 0; i < in_num; ++i) {
|
||||
op_desc->AddInputDesc(test_desc);
|
||||
}
|
||||
for (auto i = 0; i < out_num; ++i) {
|
||||
op_desc->AddOutputDesc(test_desc);
|
||||
}
|
||||
return graph->AddNode(op_desc);
|
||||
}
|
||||
};
|
||||
|
||||
/// graph with subgraph
|
||||
/// const
|
||||
/// | | |
|
||||
/// case
|
||||
/// |
|
||||
/// netoutput
|
||||
/// ...
|
||||
/// data0 data1 data2
|
||||
/// | \ /
|
||||
/// conv add
|
||||
TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph1) {
|
||||
PassManager pass_manager;
|
||||
pass_manager.AddPass("FuseDataNodesWithCommonInputPass", new (std::nothrow) FuseDataNodesWithCommonInputPass);
|
||||
ComputeGraphPtr parent_graph = std::make_shared<ComputeGraph>("parent_graph");
|
||||
auto parent_const = MakeNode(parent_graph, 0, 1, "parent_const", "Const");
|
||||
auto parent_case = MakeNode(parent_graph, 3, 1, "parent_case", "Case");
|
||||
auto parent_output = MakeNode(parent_graph, 1, 0, "parent_output", "NetOutput");
|
||||
|
||||
GeTensorDesc tensor_desc(GeShape({1,3,224,224}), FORMAT_NCHW, DT_FLOAT);
|
||||
|
||||
parent_const->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
|
||||
parent_case->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
|
||||
parent_case->GetOpDesc()->UpdateInputDesc(1, tensor_desc);
|
||||
parent_case->GetOpDesc()->UpdateInputDesc(2, tensor_desc);
|
||||
parent_case->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
|
||||
|
||||
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_case->GetInDataAnchor(0));
|
||||
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_case->GetInDataAnchor(1));
|
||||
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_case->GetInDataAnchor(2));
|
||||
GraphUtils::AddEdge(parent_case->GetOutDataAnchor(0), parent_output->GetInDataAnchor(0));
|
||||
|
||||
ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph");
|
||||
auto data0 = MakeNode(parent_graph, 1, 1, "data0", "Data");
|
||||
data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
|
||||
data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
|
||||
auto data1 = MakeNode(parent_graph, 1, 1, "data1", "Data");
|
||||
data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
|
||||
data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
|
||||
auto data2 = MakeNode(parent_graph, 1, 1, "data2", "Data");
|
||||
data2->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
|
||||
data2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
|
||||
(void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0);
|
||||
(void)AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1);
|
||||
(void)AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 2);
|
||||
|
||||
sub_graph->SetParentNode(parent_case);
|
||||
sub_graph->SetParentGraph(parent_graph);
|
||||
EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS);
|
||||
}
|
||||
|
||||
/// graph with subgraph
|
||||
/// const
|
||||
/// / \
|
||||
/// cast1 cast2
|
||||
/// \ /
|
||||
/// case
|
||||
/// |
|
||||
/// netoutput
|
||||
/// ...
|
||||
/// data1 data2
|
||||
/// \ /
|
||||
/// add
|
||||
TEST_F(UtestFuseDataNodesWithCommonInputPass, graph_with_subgraph2) {
|
||||
PassManager pass_manager;
|
||||
pass_manager.AddPass("FuseDataNodesWithCommonInputPass", new (std::nothrow) FuseDataNodesWithCommonInputPass);
|
||||
ComputeGraphPtr parent_graph = std::make_shared<ComputeGraph>("parent_graph");
|
||||
auto parent_const = MakeNode(parent_graph, 0, 1, "parent_const", "Const");
|
||||
auto parent_cast1 = MakeNode(parent_graph, 1, 1, "parent_cast1", "Cast");
|
||||
auto parent_cast2 = MakeNode(parent_graph, 1, 1, "parent_cast2", "Cast");
|
||||
auto parent_case = MakeNode(parent_graph, 2, 1, "parent_case", "Case");
|
||||
auto parent_output = MakeNode(parent_graph, 1, 0, "parent_output", "NetOutput");
|
||||
|
||||
GeTensorDesc tensor_desc(GeShape({1,3,224,224}), FORMAT_NCHW, DT_FLOAT);
|
||||
|
||||
parent_const->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
|
||||
parent_cast1->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
|
||||
parent_cast1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
|
||||
parent_cast2->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
|
||||
parent_cast2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
|
||||
parent_case->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
|
||||
parent_case->GetOpDesc()->UpdateInputDesc(1, tensor_desc);
|
||||
parent_case->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
|
||||
|
||||
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast1->GetInDataAnchor(0));
|
||||
GraphUtils::AddEdge(parent_cast1->GetOutDataAnchor(0), parent_case->GetInDataAnchor(0));
|
||||
GraphUtils::AddEdge(parent_const->GetOutDataAnchor(0), parent_cast2->GetInDataAnchor(0));
|
||||
GraphUtils::AddEdge(parent_cast2->GetOutDataAnchor(0), parent_case->GetInDataAnchor(1));
|
||||
GraphUtils::AddEdge(parent_case->GetOutDataAnchor(0), parent_output->GetInDataAnchor(0));
|
||||
|
||||
ComputeGraphPtr sub_graph = std::make_shared<ComputeGraph>("sub_graph");
|
||||
auto data0 = MakeNode(parent_graph, 1, 1, "data0", "Data");
|
||||
data0->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
|
||||
data0->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
|
||||
auto data1 = MakeNode(parent_graph, 1, 1, "data1", "Data");
|
||||
data1->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
|
||||
data1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
|
||||
(void)AttrUtils::SetInt(data0->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 0);
|
||||
(void)AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1);
|
||||
|
||||
sub_graph->SetParentNode(parent_case);
|
||||
sub_graph->SetParentGraph(parent_graph);
|
||||
EXPECT_EQ(pass_manager.Run(sub_graph), SUCCESS);
|
||||
}
|
||||
} // namespace ge
|
@ -0,0 +1,75 @@
|
||||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "graph/passes/no_data_out_const_elimination_pass.h"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <map>
|
||||
|
||||
#include "common/ge_inner_error_codes.h"
|
||||
#include "graph/utils/graph_utils.h"
|
||||
|
||||
namespace ge {
|
||||
|
||||
class UtestNoDataOutConstEliminationPass : public testing::Test {
|
||||
protected:
|
||||
void SetUp() {}
|
||||
void TearDown() {}
|
||||
|
||||
public:
|
||||
NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) {
|
||||
GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT);
|
||||
auto op_desc = std::make_shared<OpDesc>(name, type);
|
||||
for (auto i = 0; i < in_num; ++i) {
|
||||
op_desc->AddInputDesc(test_desc);
|
||||
}
|
||||
for (auto i = 0; i < out_num; ++i) {
|
||||
op_desc->AddOutputDesc(test_desc);
|
||||
}
|
||||
return graph->AddNode(op_desc);
|
||||
}
|
||||
};
|
||||
|
||||
/// graph with subgraph
|
||||
/// const1
|
||||
/// |(control)
|
||||
/// const2
|
||||
/// |
|
||||
/// output
|
||||
TEST_F(UtestNoDataOutConstEliminationPass, succ_graph1) {
|
||||
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
|
||||
auto const_node1 = MakeNode(graph, 0, 1, "const_node1", "Const");
|
||||
auto const_node2 = MakeNode(graph, 1, 1, "const_node2", "Const");
|
||||
auto output_node = MakeNode(graph, 1, 0, "output_node", "NetOutput");
|
||||
GeTensorDesc tensor_desc(GeShape({1,3,224,224}), FORMAT_NCHW, DT_FLOAT);
|
||||
|
||||
const_node1->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
|
||||
const_node2->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
|
||||
const_node2->GetOpDesc()->UpdateOutputDesc(0, tensor_desc);
|
||||
output_node->GetOpDesc()->UpdateInputDesc(0, tensor_desc);
|
||||
|
||||
GraphUtils::AddEdge(const_node1->GetOutControlAnchor(), const_node2->GetInControlAnchor());
|
||||
GraphUtils::AddEdge(const_node2->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));
|
||||
|
||||
GEPass pass(graph);
|
||||
NamesToPass node_pass;
|
||||
NoDataOutConstEliminationPass no_data_out_const_elimination_pass;
|
||||
node_pass.emplace_back("NoDataOutConstEliminationPass", &no_data_out_const_elimination_pass);
|
||||
EXPECT_EQ(pass.Run(node_pass), SUCCESS);
|
||||
}
|
||||
} // namespace ge
|
Loading…
Reference in new issue