!733 decrease om size

From: @dimitri_rose
Reviewed-by: @ji_chen,@sheng-nan
Signed-off-by: @startzgf168,@ji_chen
pull/733/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 5de4cd5479

@ -157,6 +157,8 @@ set(TRAIN_SRC_LIST
"graph/passes/compile_nodes_pass.cc"
"graph/passes/constant_folding_pass.cc"
"graph/passes/constant_fuse_same_pass.cc"
"graph/passes/remove_same_const_pass.cc"
"graph/passes/useless_control_out_remove_pass.cc"
"graph/passes/control_trigger_pass.cc"
"graph/passes/dimension_adjust_pass.cc"
"graph/passes/dimension_compute_pass.cc"
@ -522,6 +524,8 @@ set(INFER_SRC_LIST
"graph/passes/assign_pass.cc"
"graph/passes/addn_pass.cc"
"graph/passes/common_subexpression_elimination_pass.cc"
"graph/passes/remove_same_const_pass.cc"
"graph/passes/useless_control_out_remove_pass.cc"
"graph/passes/transop_symmetry_elimination_pass.cc"
"graph/passes/save_pass.cc"
"graph/passes/switch_dead_branch_elimination.cc"

@ -191,6 +191,8 @@ OMG_HOST_SRC_FILES := \
graph/passes/control_trigger_pass.cc \
graph/passes/cond_pass.cc \
graph/passes/cond_remove_pass.cc \
graph/passes/remove_same_const_pass.cc \
graph/passes/useless_control_out_remove_pass.cc \
graph/passes/for_pass.cc \
graph/passes/enter_pass.cc \
graph/passes/assign_pass.cc \

@ -126,6 +126,8 @@ LIBGE_LOCAL_SRC_FILES := \
graph/passes/compile_nodes_pass.cc \
graph/passes/constant_folding_pass.cc \
graph/passes/constant_fuse_same_pass.cc \
graph/passes/remove_same_const_pass.cc \
graph/passes/useless_control_out_remove_pass.cc \
graph/passes/control_trigger_pass.cc \
graph/passes/dimension_adjust_pass.cc \
graph/passes/dimension_compute_pass.cc \

@ -224,6 +224,7 @@ Status ModelBuilder::AdjustConstWeightSize(const ge::NodePtr &node, size_t &mem_
GeTensorDesc &tensor_desc = weight->MutableTensorDesc();
size_t output_size = weight->GetData().size();
TensorUtils::SetDataOffset(tensor_desc, mem_offset);
GELOGD("Node: %s, weight size: %zu.", node->GetName().c_str(), output_size);
mem_offset += output_size;
}
return SUCCESS;

@ -78,6 +78,7 @@
#include "graph/passes/prune_pass.h"
#include "graph/passes/ref_identity_delete_op_pass.h"
#include "graph/passes/replace_with_empty_const_pass.h"
#include "graph/passes/remove_same_const_pass.h"
#include "graph/passes/reshape_recovery_pass.h"
#include "graph/passes/reshape_remove_pass.h"
#include "graph/passes/same_transdata_breadth_fusion_pass.h"
@ -92,6 +93,7 @@
#include "graph/passes/transop_symmetry_elimination_pass.h"
#include "graph/passes/transop_without_reshape_fusion_pass.h"
#include "graph/passes/transpose_transdata_pass.h"
#include "graph/passes/useless_control_out_remove_pass.h"
#include "graph/passes/variable_op_pass.h"
#include "graph/passes/variable_prepare_op_pass.h"
#include "graph/passes/variable_ref_delete_op_pass.h"
@ -2150,6 +2152,7 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) {
TransposeTransDataPass transpose_transdata_pass;
TransOpSymmetryEliminationPass symmetry_elimination_pass;
DimensionComputePass dimension_compute_pass;
UselessControlOutRemovePass useless_control_out_remove_pass;
names_to_passes.emplace_back("EnterPass", &enter_pass);
names_to_passes.emplace_back("AddNPass", &addn_pass);
names_to_passes.emplace_back("SwitchDeadBranchElimination", &switch_dead_branch_elimination);
@ -2163,6 +2166,7 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) {
names_to_passes.emplace_back("DimensionComputePass", &dimension_compute_pass);
names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass);
names_to_passes.emplace_back("DimensionAdjustPass", &dimension_adjust_pass);
names_to_passes.emplace_back("UselessControlOutRemovePass", &useless_control_out_remove_pass);
GE_TIMESTAMP_START(names_to_passes);
ret = GEPass(compute_graph).Run(names_to_passes);
GE_TIMESTAMP_END(names_to_passes, "GraphManager::OptimizeStage1_2");
@ -2203,6 +2207,8 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) {
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::VariableRefUselessControlOutDeletePass",
new (std::nothrow) VariableRefUselessControlOutDeletePass))
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ReshapeRecoveryPass", new (std::nothrow) ReshapeRecoveryPass))
GE_CHK_STATUS_RET(
graph_pass.AddPass("OptimizeStage1_3::RemoveSameConstPass", new (std::nothrow) RemoveSameConstPass))
if (options_.train_graph_flag) {
// Priority: The GlobalStepInsertPass should work before graph partitioner.
// Reason: Make sure that the var "global_step" can be partitioned to known sub graph and allocated memory

@ -18,6 +18,8 @@
#include "ge/ge_api_types.h"
#include "graph/common/omg_util.h"
using std::string;
namespace ge {
Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) {
GELOGD("AttachStreamLabelPass Enter.");
@ -187,21 +189,10 @@ Status AttachStreamLabelPass::UpdateEnterNode() {
}
std::stack<NodePtr> enter_nodes;
std::string batch_label;
for (const auto &enter_node : pair.second) {
enter_nodes.emplace(enter_node);
std::string tmp_label;
(void)AttrUtils::GetStr(enter_node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, tmp_label);
if (!tmp_label.empty()) {
if (batch_label.empty()) {
batch_label = tmp_label;
} else if (batch_label != tmp_label) {
GELOGE(FAILED, "multi batch_label exist, label1=%s, label2=%s.", batch_label.c_str(), tmp_label.c_str());
return FAILED;
}
}
}
if (UpdateLoopBranch(enter_nodes, active_label_list[0], batch_label) != SUCCESS) {
if (UpdateLoopBranch(enter_nodes, active_label_list[0]) != SUCCESS) {
GELOGE(FAILED, "Update stream_label for loop_branch failed.");
return FAILED;
}
@ -226,10 +217,7 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no
}
for (const auto &enter_node : enter_nodes) {
GE_CHECK_NOTNULL(enter_node->GetOpDesc());
if (enter_node->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL)) {
GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed.");
}
GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed.");
}
return SUCCESS;
}
@ -241,8 +229,7 @@ Status AttachStreamLabelPass::SetEnterLabel(const std::vector<NodePtr> &enter_no
/// @param [in] batch_label
/// @return Status
///
Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const std::string &stream_label,
const std::string &batch_label) {
Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const string &stream_label) {
std::stack<NodePtr> nodes(enter_nodes);
NodePtr cur_node = nullptr;
while (!nodes.empty()) {
@ -251,11 +238,6 @@ Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack<NodePtr> &enter_
for (const NodePtr &out_node : cur_node->GetOutAllNodes()) {
OpDescPtr out_desc = out_node->GetOpDesc();
GE_CHECK_NOTNULL(out_desc);
std::string tmp_label;
(void)AttrUtils::GetStr(out_desc, ATTR_NAME_BATCH_LABEL, tmp_label);
if (!tmp_label.empty() && (tmp_label != batch_label)) {
continue;
}
std::string out_type = out_desc->GetType();
bool need_skip =
out_desc->HasAttr(ATTR_NAME_STREAM_LABEL) || (out_type == ENTER) || (out_type == REFENTER) ||

@ -58,11 +58,9 @@ class AttachStreamLabelPass : public GraphPass {
/// @brief Update stream_label for loop_branch
/// @param [in] enter_nodes
/// @param [in] stream_label
/// @param [in] batch_label
/// @return Status
///
static Status UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const std::string &stream_label,
const std::string &batch_label);
static Status UpdateLoopBranch(const std::stack<NodePtr> &enter_nodes, const std::string &stream_label);
///
/// @brief Update stream_label start with enter nodes

@ -96,7 +96,7 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, std::unorder
node->GetName().c_str(), node->GetType().c_str());
continue;
}
if (node_to_re_pass->IsAllInNodesSeen(nodes_seen)) {
if (nodes_seen.count(node_to_re_pass.get()) > 0 || node_to_re_pass->IsAllInNodesSeen(nodes_seen)) {
GELOGD("The node %s will be re-pass later", node_to_re_pass->GetName().c_str());
nodes_re_pass.insert(node_to_re_pass);
} else {

@ -80,7 +80,71 @@ Status DimensionAdjustPass::Run(ge::NodePtr &node) {
}
}
ret = DealWithInNodes(node);
if (ret != SUCCESS) {
GELOGE(ret, "DealWithInNodes of %s failed.", node->GetName().c_str());
return ret;
}
std::vector<int> data_relink_io_map = {kDataInputIndex};
return IsolateAndDeleteNode(node, data_relink_io_map);
}
Status DimensionAdjustPass::DealWithInNodes(NodePtr &node) {
GE_CHECK_NOTNULL(node);
GE_CHECK_NOTNULL(node->GetOpDesc());
auto graph = node->GetOwnerComputeGraph();
auto in_data_anchors = node->GetAllInDataAnchors();
for (auto &in_data_anchor : in_data_anchors) {
if (in_data_anchor == nullptr) {
continue;
}
auto in_node_anchor = in_data_anchor->GetPeerOutAnchor();
if (in_node_anchor == nullptr) {
continue;
}
auto in_node = in_node_anchor->GetOwnerNode();
if (in_node->GetType() == SWITCHN) {
auto identity_name = node->GetName() + "_ctrl_identity_" + std::to_string(in_data_anchor->GetIdx());
auto identity =
AddIdentityNodeToGraph(identity_name, node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()), graph);
GE_CHECK_NOTNULL(identity);
GELOGI("Create new identity node[%s] after node %s[type: %s] success.", identity->GetName().c_str(),
in_node->GetName().c_str(), in_node->GetType().c_str());
GE_CHK_STATUS_RET(GraphUtils::AddEdge(in_node_anchor, identity->GetInDataAnchor(0)))
GE_CHECK_NOTNULL(identity->GetOutControlAnchor());
if (identity->GetOutControlAnchor()->IsLinkedWith(node->GetInControlAnchor())) {
continue;
}
GE_CHK_STATUS_RET(GraphUtils::AddEdge(identity->GetOutControlAnchor(), node->GetInControlAnchor()))
}
}
return SUCCESS;
}
NodePtr DimensionAdjustPass::AddIdentityNodeToGraph(const string &name, const GeTensorDesc &tensor,
ComputeGraphPtr &graph) {
if (graph == nullptr) {
GELOGE(INTERNAL_ERROR, "Comput graph ptr is null in creating identity node.");
return nullptr;
}
OpDescPtr desc = MakeShared<OpDesc>("", "");
if (desc == nullptr) {
GELOGE(MEMALLOC_FAILED, "Failed to create op desc.");
return nullptr;
}
desc->SetName(name);
desc->SetType(IDENTITY);
auto ret = desc->AddInputDesc(tensor);
auto ret2 = desc->AddOutputDesc(tensor);
if ((ret != GRAPH_SUCCESS) || (ret2 != GRAPH_SUCCESS)) {
GELOGE(INTERNAL_ERROR, "Failed to add input/output desc in creating identity.");
return nullptr;
}
return graph->AddNodeFront(desc);
}
} // namespace ge

@ -34,6 +34,10 @@ namespace ge {
class DimensionAdjustPass : public BaseNodePass {
public:
Status Run(ge::NodePtr &node) override;
private:
Status DealWithInNodes(ge::NodePtr &node);
NodePtr AddIdentityNodeToGraph(const std::string &name, const GeTensorDesc &tensor, ComputeGraphPtr &graph);
};
} // namespace ge

@ -23,6 +23,7 @@
namespace {
const size_t kOutNodesNum = 1;
const size_t kInCtrlNodesNum = 1;
}
namespace ge {
@ -55,6 +56,7 @@ Status EnterPass::Run(NodePtr &node) {
if (out_ctrl_node == nullptr) {
continue;
}
GELOGI("Remove control edge from %s to %s.", node->GetName().c_str(), out_ctrl_node->GetName().c_str());
if (GraphUtils::RemoveEdge(node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()) != GRAPH_SUCCESS) {
GELOGE(FAILED, "Remove Enter ctrl output fail, %s->%s", node->GetName().c_str(),
out_ctrl_node->GetName().c_str());
@ -62,8 +64,12 @@ Status EnterPass::Run(NodePtr &node) {
}
}
} else {
if (OptimizeEnter(node, in_node) != SUCCESS) {
GELOGE(FAILED, "Optimize enter node[%s] failed.", node->GetName().c_str());
if (OptimizeEnterWithOnlyDataOut(node, in_node) != SUCCESS) {
GELOGE(FAILED, "Optimize enter node[%s] with only out data node failed.", node->GetName().c_str());
return FAILED;
}
if (UnlinkCtrlEdgeBeforeConst(node) != SUCCESS) {
GELOGE(FAILED, "Unlink control edge before const of node[%s]'s out nodes failed.", node->GetName().c_str());
return FAILED;
}
}
@ -72,7 +78,7 @@ Status EnterPass::Run(NodePtr &node) {
return SUCCESS;
}
Status EnterPass::OptimizeEnter(NodePtr &node, NodePtr &in_node) {
Status EnterPass::OptimizeEnterWithOnlyDataOut(NodePtr &node, NodePtr &in_node) {
if ((in_node->GetOutAllNodes().size() != kOutNodesNum) || !node->GetOutControlNodes().empty()) {
return SUCCESS;
}
@ -83,17 +89,61 @@ Status EnterPass::OptimizeEnter(NodePtr &node, NodePtr &in_node) {
}
GE_CHECK_NOTNULL(in_node->GetOutDataAnchor(0));
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0)));
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->Unlink(node->GetInDataAnchor(0)))
const auto &out_data_anchor = node->GetOutDataAnchor(0);
GE_CHECK_NOTNULL(out_data_anchor);
for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor));
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor));
GE_CHK_STATUS_RET(out_data_anchor->Unlink(peer_in_data_anchor))
GE_CHK_STATUS_RET(in_node->GetOutDataAnchor(0)->LinkTo(peer_in_data_anchor))
}
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node));
GE_CHK_STATUS_RET(GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node))
AddNodeDeleted(node);
AddRePassNodesWithInOut(in_node);
return SUCCESS;
}
Status EnterPass::UnlinkCtrlEdgeBeforeConst(NodePtr &node) {
auto out_ctrl_nodes = node->GetOutControlNodes();
if (out_ctrl_nodes.empty()) {
return SUCCESS;
}
auto out_ctrl_anchor = node->GetOutControlAnchor();
GE_CHECK_NOTNULL(out_ctrl_anchor);
for (auto &out_ctrl_node : out_ctrl_nodes) {
GE_CHECK_NOTNULL(out_ctrl_node);
if ((out_ctrl_node->GetType() != CONSTANT) && (out_ctrl_node->GetType() != CONSTANTOP)) {
continue;
}
auto in_ctrl_nodes = out_ctrl_node->GetInControlNodes();
if (in_ctrl_nodes.size() != kInCtrlNodesNum) {
continue;
}
// Skip when has merge out
bool has_merge_out = false;
auto out_nodes_of_const = out_ctrl_node->GetOutAllNodes();
for (const auto &out_node_of_const : out_nodes_of_const) {
GE_CHECK_NOTNULL(out_node_of_const);
if (out_node_of_const->GetType() == MERGE || out_node_of_const->GetType() == REFMERGE) {
has_merge_out = true;
break;
}
}
if (has_merge_out) {
continue;
}
GELOGI("Unlink control edge from %s to %s.", node->GetName().c_str(), out_ctrl_node->GetName().c_str());
GE_CHK_STATUS_RET(out_ctrl_anchor->Unlink(out_ctrl_node->GetInControlAnchor()))
for (auto &out_node_of_const : out_nodes_of_const) {
if (!out_ctrl_anchor->IsLinkedWith(out_node_of_const->GetInControlAnchor())) {
GELOGI("Link control edge from %s to %s.", node->GetName().c_str(), out_node_of_const->GetName().c_str());
GE_CHK_STATUS_RET(out_ctrl_anchor->LinkTo(out_node_of_const->GetInControlAnchor()))
}
}
}
return SUCCESS;
}
} // namespace ge

@ -25,7 +25,8 @@ class EnterPass : public BaseNodePass {
Status Run(NodePtr &node) override;
private:
Status OptimizeEnter(NodePtr &node, NodePtr &in_node);
Status OptimizeEnterWithOnlyDataOut(NodePtr &node, NodePtr &in_node);
Status UnlinkCtrlEdgeBeforeConst(NodePtr &node);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_ENTER_PASS_H_

@ -173,10 +173,7 @@ Status FoldingPass::DealWithInNodes(NodePtr &node) {
continue;
}
auto in_node = in_node_anchor->GetOwnerNode();
if (in_node == nullptr) {
continue;
}
if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH)) {
if ((in_node->GetType() == SWITCH) || (in_node->GetType() == REFSWITCH) || (in_node->GetType() == SWITCHN)) {
GELOGI("The in_node name is %s, and node type is %s.", in_node->GetName().c_str(), in_node->GetType().c_str());
auto ret = in_node_anchor->Unlink(in_data_anchor);
if (ret != SUCCESS) {

@ -89,16 +89,6 @@ Status MergeToStreamMergePass::ReplaceMergeNode(const ComputeGraphPtr &graph, co
GE_CHK_STATUS_RET(SetNextIteration(stream_merge, next_iteration_name), "Set next iteration failed");
}
if (merge_op_desc->HasAttr(ATTR_NAME_BATCH_LABEL)) {
string batch_label;
(void)AttrUtils::GetStr(merge_op_desc, ATTR_NAME_BATCH_LABEL, batch_label);
if (!batch_label.empty()) {
auto stream_merge_desc = stream_merge->GetOpDesc();
GE_CHECK_NOTNULL(stream_merge_desc);
(void)AttrUtils::SetStr(stream_merge_desc, ATTR_NAME_BATCH_LABEL, batch_label);
}
}
return AddActiveNodes(graph, stream_merge);
}

File diff suppressed because it is too large Load Diff

@ -46,13 +46,6 @@ class NextIterationPass : public GraphPass {
///
Status GroupEnterNode(const NodePtr &enter_node);
///
/// @brief Group Enter nodes without batch_label attr
/// @param [in] compute_graph
/// @return Status
///
Status GroupWithNoBatch(const ComputeGraphPtr &graph);
///
/// @brief Find while groups
/// @return Status
@ -97,13 +90,10 @@ class NextIterationPass : public GraphPass {
/// @param [out] target_node
/// @return Status
///
Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input,
const std::string &batch_label, NodePtr &target_node);
Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, NodePtr &target_node);
// map<frame_name, vector<enter_node>>
std::unordered_map<std::string, std::vector<NodePtr>> frame_enter_map_;
// map<frame_name, map<batch_label, LoopCondGroup>>
std::unordered_map<std::string, std::unordered_map<std::string, LoopCondGroupPtr>> loop_group_map_;
// map<frame_name, LoopCondGroup>
std::unordered_map<std::string, LoopCondGroupPtr> loop_group_map_;
};
} // namespace ge
#endif // GE_GRAPH_PASSES_NEXT_ITERATION_PASS_H_

@ -0,0 +1,106 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "remove_same_const_pass.h"
#include <sstream>
#include <string>
#include <set>
#include "common/base64.h"
#include "ge_local_engine/engine/host_cpu_engine.h"
#include "graph/utils/node_utils.h"
namespace ge {
namespace {
std::string GetCseKey(const NodePtr &node) {
std::stringstream ss;
ss << node->GetType() << "control-inputs-";
std::set<std::string> control_in_node_names;
for (auto &src_node : node->GetInControlNodes()) {
control_in_node_names.insert(src_node->GetName());
}
for (auto &name : control_in_node_names) {
ss << name << "-";
}
ss << "attrs-" << AttrUtils::GetAllAttrsStr(node->GetOpDesc());
return ss.str();
}
bool IsConstType(const NodePtr &node) { return (node->GetType() == CONSTANT || node->GetType() == CONSTANTOP); }
} // namespace
Status RemoveSameConstPass::Run(ComputeGraphPtr graph) {
GELOGD("Begin to run RemoveSameConstPass on the graph");
GE_CHECK_NOTNULL(graph);
std::map<std::string, NodePtr> keys_to_node;
for (const auto &node : graph->GetDirectNode()) {
GE_CHECK_NOTNULL(node);
if (!IsConstType(node)) {
continue;
}
bool is_unknown = false;
auto ret = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown);
if (ret != GRAPH_SUCCESS) {
GELOGW("Get node unknown status failed, node name:%s, type:%s.",
node->GetName().c_str(), node->GetType().c_str());
continue;
}
if (is_unknown) {
GELOGI("Current node %s, type %s is unknown shape which should be skip.",
node->GetName().c_str(), node->GetType().c_str());
continue;
}
auto key = GetCseKey(node);
GELOGD("The const node %s cse key %s", node->GetName().c_str(), ge::base64::EncodeToBase64(key).c_str());
auto iter = keys_to_node.find(key);
if (iter == keys_to_node.end()) {
keys_to_node[key] = node;
continue;
}
if (node->GetAllOutDataAnchorsSize() != iter->second->GetAllOutDataAnchorsSize()) {
GELOGW("The const node %s and %s have the same CSE key, but different output anchor count, skip to fusion them",
iter->second->GetName().c_str(), node->GetName().c_str());
continue;
}
std::vector<int> output_map(node->GetAllOutDataAnchorsSize());
for (size_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) {
output_map[i] = i;
}
ret = GraphUtils::ReplaceNodeAnchors(iter->second, node, {}, output_map);
if (ret != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Failed to replace node %s by node %s", node->GetName().c_str(),
iter->second->GetName().c_str(), ret);
return INTERNAL_ERROR;
}
NodeUtils::UnlinkAll(*node);
ret = GraphUtils::RemoveNodeWithoutRelink(graph, node);
if (ret != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Failed to remove node %s from graph", node->GetName().c_str());
return INTERNAL_ERROR;
}
GELOGI("Remove const node %s by RemoveSameConstPass, replace it with node %s", node->GetName().c_str(),
iter->second->GetName().c_str());
}
return SUCCESS;
}
} // namespace ge

@ -0,0 +1,28 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_GRAPH_PASSES_REMOVE_SAME_CONST_PASS_H_
#define GE_GRAPH_PASSES_REMOVE_SAME_CONST_PASS_H_
#include "graph/types.h"
#include "inc/graph_pass.h"
namespace ge {
class RemoveSameConstPass : public GraphPass {
public:
Status Run(ge::ComputeGraphPtr graph) override ;
};
} // namespace ge
#endif //GE_GRAPH_PASSES_REMOVE_SAME_CONST_PASS_H_

@ -0,0 +1,51 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "graph/passes/useless_control_out_remove_pass.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h"
namespace ge {
Status UselessControlOutRemovePass::Run(NodePtr &node) {
GE_CHECK_NOTNULL(node);
if ((node->GetType() != CONSTANT) && (node->GetType() != CONSTANTOP)) {
return SUCCESS;
}
GELOGD("UselessControlOutRemovePass running, node: %s.", node->GetName().c_str());
// const has no control input
if (node->GetInControlNodes().empty()) {
if (node->GetOutDataNodes().empty()) {
// It is an isolated const, just remove it.
GELOGI("Delete isolated const: %s.", node->GetName().c_str());
GE_CHK_STATUS_RET(IsolateAndDeleteNode(node, {}))
AddNodeDeleted(node);
} else {
auto out_ctrl_anchor = node->GetOutControlAnchor();
if (out_ctrl_anchor != nullptr && !out_ctrl_anchor->GetPeerAnchors().empty()) {
GELOGI("Node: %s unlink all out control edge.", node->GetName().c_str());
out_ctrl_anchor->UnlinkAll();
}
}
}
return SUCCESS;
}
} // namespace ge

@ -0,0 +1,29 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_GRAPH_PASSES_USELESS_CONTROL_OUT_REMOVE_PASS_H_
#define GE_GRAPH_PASSES_USELESS_CONTROL_OUT_REMOVE_PASS_H_
#include "graph/passes/base_pass.h"
namespace ge {
class UselessControlOutRemovePass : public BaseNodePass {
public:
Status Run(NodePtr &node) override;
};
} // namespace ge
#endif // GE_GRAPH_PASSES_USELESS_CONTROL_OUT_REMOVE_PASS_H_

File diff suppressed because it is too large Load Diff

@ -18,6 +18,7 @@
#include <map>
#include <queue>
#include <vector>
#include <set>
#include "external/ge/ge_api_error_codes.h"
@ -64,12 +65,26 @@ class MultiBatchGraphCopyer {
private:
Status Init();
Status CheckArguments();
Status RelinkConstCtrlEdge();
Status ExtractUnchangedStructureOutofCycle();
Status GetEnterNodesGroupByFrame(std::map<std::string, std::vector<NodePtr>> &frame_enter);
Status GetNodeNeedExtract(const std::map<std::string, std::vector<NodePtr>> &frame_enter,
std::queue<NodePtr> &nodes_to_extract);
bool AllInDataNodesUnchangeAndNoMergeOut(const NodePtr &node);
Status MoveInEntersInDataAnchorDown(NodePtr &node, OpDescPtr &enter_desc);
Status InsertEnterAfterNode(NodePtr &node, const OpDescPtr &enter_desc, std::set<NodePtr> &out_nodes);
Status MoveCtrlEdgeToOutNodes(NodePtr &node, std::set<NodePtr> &out_nodes);
Status DeleteEnterWithoutDataOut();
// label status for origin_all_nodes_
Status LabelStatus();
Status LabelInBatchBranchStatus();
void LabelStatusForData(const NodePtr &data);
void LabelStatusForGetNextSink(const NodePtr &data);
void InitStatus(std::map<std::string, std::vector<NodePtr>> &frame_enters);
void ResetEnterStatus(std::map<std::string, std::vector<NodePtr>> &frame_enters, const NodePtr &node);
// add nodes functions
Status CreateNewNodes();
@ -81,7 +96,6 @@ class MultiBatchGraphCopyer {
Status InsertSwitchNForData(const NodePtr &node, const size_t &out_anchor_index, const size_t &peer_in_anchor_index,
std::vector<std::pair<Node *, NodePtr>> &dynamic_out_to_switchn);
Status InsertIdentityAfterSwitchN();
Status UpdateMaxShapeToData(const NodePtr &node, size_t out_anchor_index);
Status UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index);

@ -245,6 +245,8 @@ set(COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/hccl_group_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/memcpy_addr_async_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/set_input_output_offset_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc"
"${GE_CODE_DIR}/ge/model/ge_model.cc"
"${GE_CODE_DIR}/ge/common/cust_aicpu_kernel_store.cc"
"${GE_CODE_DIR}/ge/graph/load/new_model_manager/model_utils.cc"
@ -475,6 +477,8 @@ set(GRAPH_PASS_COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/passes/reshape_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/resource_pair_add_control_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/resource_pair_remove_control_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/remove_same_const_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/useless_control_out_remove_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/transop_breadth_fusion_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/transop_without_reshape_fusion_pass.cc"
"${GE_CODE_DIR}/ge/graph/passes/transop_depth_fusion_pass.cc"

Loading…
Cancel
Save