diff --git a/ge/graph/passes/merge_pass.cc b/ge/graph/passes/merge_pass.cc index d2340037..0b367614 100644 --- a/ge/graph/passes/merge_pass.cc +++ b/ge/graph/passes/merge_pass.cc @@ -34,6 +34,11 @@ using domi::SUCCESS; namespace ge { const int kValueIndexOutputIndex = 1; +bool IsEmptyTensor(const GeShape &shape) { + const auto &dims = shape.GetDims(); + return std::any_of(dims.begin(), dims.end(), [](int64_t dim) { return dim == 0; }); +} + Status MergePass::Run(NodePtr &node) { GELOGD("MergePass running"); if (node == nullptr) { @@ -53,6 +58,11 @@ Status MergePass::Run(NodePtr &node) { return PARAM_INVALID; } + if (OptimizeEmptyTensorInput(node) != SUCCESS) { + GELOGE(FAILED, "[%s] remove empty_tensor inputs failed.", node->GetName().c_str()); + return FAILED; + } + auto in_data_nodes = node->GetInDataNodes(); switch (in_data_nodes.size()) { case 0: { @@ -202,4 +212,30 @@ bool MergePass::IsMergeInputNeedOptimized(NodePtr &node) const { } return true; } + +Status MergePass::OptimizeEmptyTensorInput(const NodePtr &node) { + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + const auto &peer_data_anchor = in_data_anchor->GetPeerOutAnchor(); + if (peer_data_anchor == nullptr) { + continue; + } + if ((peer_data_anchor->GetOwnerNode() == nullptr) || + (peer_data_anchor->GetOwnerNode()->GetOpDesc() == nullptr)) { + continue; + } + const auto &op_desc = peer_data_anchor->GetOwnerNode()->GetOpDesc(); + if (IsEmptyTensor(op_desc->GetOutputDesc(peer_data_anchor->GetIdx()).GetShape())) { + if (GraphUtils::RemoveEdge(peer_data_anchor, in_data_anchor) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.", + op_desc->GetName().c_str(), peer_data_anchor->GetIdx(), + node->GetName().c_str(), in_data_anchor->GetIdx()); + return FAILED; + } + GELOGD("Remove data edge %s:%d->%s:%d", + op_desc->GetName().c_str(), peer_data_anchor->GetIdx(), + node->GetName().c_str(), in_data_anchor->GetIdx()); + } + } + return SUCCESS; +} } // namespace ge diff --git a/ge/graph/passes/merge_pass.h b/ge/graph/passes/merge_pass.h index 2cdb5022..464f2172 100755 --- a/ge/graph/passes/merge_pass.h +++ b/ge/graph/passes/merge_pass.h @@ -29,6 +29,7 @@ class MergePass : public BaseNodePass { Status ChangeIndexToConstant(NodePtr &node, int &value_index); Status CreateConstByValue(NodePtr &node, int value_index, OpDescPtr &op_desc); bool IsMergeInputNeedOptimized(NodePtr &node) const; + static Status OptimizeEmptyTensorInput(const NodePtr &node); }; } // namespace ge #endif // GE_GRAPH_PASSES_MERGE_PASS_H_