!432 rm empty_tensor input for merge node

From: @chen_yemeng
Reviewed-by: 
Signed-off-by:
pull/432/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 13ff4ac8c0

@ -34,6 +34,11 @@ using domi::SUCCESS;
namespace ge {
const int kValueIndexOutputIndex = 1;
bool IsEmptyTensor(const GeShape &shape) {
const auto &dims = shape.GetDims();
return std::any_of(dims.begin(), dims.end(), [](int64_t dim) { return dim == 0; });
}
Status MergePass::Run(NodePtr &node) {
GELOGD("MergePass running");
if (node == nullptr) {
@ -53,6 +58,11 @@ Status MergePass::Run(NodePtr &node) {
return PARAM_INVALID;
}
if (OptimizeEmptyTensorInput(node) != SUCCESS) {
GELOGE(FAILED, "[%s] remove empty_tensor inputs failed.", node->GetName().c_str());
return FAILED;
}
auto in_data_nodes = node->GetInDataNodes();
switch (in_data_nodes.size()) {
case 0: {
@ -202,4 +212,30 @@ bool MergePass::IsMergeInputNeedOptimized(NodePtr &node) const {
}
return true;
}
Status MergePass::OptimizeEmptyTensorInput(const NodePtr &node) {
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
const auto &peer_data_anchor = in_data_anchor->GetPeerOutAnchor();
if (peer_data_anchor == nullptr) {
continue;
}
if ((peer_data_anchor->GetOwnerNode() == nullptr) ||
(peer_data_anchor->GetOwnerNode()->GetOpDesc() == nullptr)) {
continue;
}
const auto &op_desc = peer_data_anchor->GetOwnerNode()->GetOpDesc();
if (IsEmptyTensor(op_desc->GetOutputDesc(peer_data_anchor->GetIdx()).GetShape())) {
if (GraphUtils::RemoveEdge(peer_data_anchor, in_data_anchor) != GRAPH_SUCCESS) {
GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.",
op_desc->GetName().c_str(), peer_data_anchor->GetIdx(),
node->GetName().c_str(), in_data_anchor->GetIdx());
return FAILED;
}
GELOGD("Remove data edge %s:%d->%s:%d",
op_desc->GetName().c_str(), peer_data_anchor->GetIdx(),
node->GetName().c_str(), in_data_anchor->GetIdx());
}
}
return SUCCESS;
}
} // namespace ge

@ -29,6 +29,7 @@ class MergePass : public BaseNodePass {
Status ChangeIndexToConstant(NodePtr &node, int &value_index);
Status CreateConstByValue(NodePtr &node, int value_index, OpDescPtr &op_desc);
bool IsMergeInputNeedOptimized(NodePtr &node) const;
static Status OptimizeEmptyTensorInput(const NodePtr &node);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_MERGE_PASS_H_

Loading…
Cancel
Save