|
|
|
@ -34,6 +34,11 @@ using domi::SUCCESS;
|
|
|
|
|
namespace ge {
|
|
|
|
|
const int kValueIndexOutputIndex = 1;
|
|
|
|
|
|
|
|
|
|
bool IsEmptyTensor(const GeShape &shape) {
|
|
|
|
|
const auto &dims = shape.GetDims();
|
|
|
|
|
return std::any_of(dims.begin(), dims.end(), [](int64_t dim) { return dim == 0; });
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status MergePass::Run(NodePtr &node) {
|
|
|
|
|
GELOGD("MergePass running");
|
|
|
|
|
if (node == nullptr) {
|
|
|
|
@ -53,6 +58,11 @@ Status MergePass::Run(NodePtr &node) {
|
|
|
|
|
return PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (OptimizeEmptyTensorInput(node) != SUCCESS) {
|
|
|
|
|
GELOGE(FAILED, "[%s] remove empty_tensor inputs failed.", node->GetName().c_str());
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto in_data_nodes = node->GetInDataNodes();
|
|
|
|
|
switch (in_data_nodes.size()) {
|
|
|
|
|
case 0: {
|
|
|
|
@ -202,4 +212,30 @@ bool MergePass::IsMergeInputNeedOptimized(NodePtr &node) const {
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status MergePass::OptimizeEmptyTensorInput(const NodePtr &node) {
|
|
|
|
|
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
|
|
|
|
|
const auto &peer_data_anchor = in_data_anchor->GetPeerOutAnchor();
|
|
|
|
|
if (peer_data_anchor == nullptr) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if ((peer_data_anchor->GetOwnerNode() == nullptr) ||
|
|
|
|
|
(peer_data_anchor->GetOwnerNode()->GetOpDesc() == nullptr)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
const auto &op_desc = peer_data_anchor->GetOwnerNode()->GetOpDesc();
|
|
|
|
|
if (IsEmptyTensor(op_desc->GetOutputDesc(peer_data_anchor->GetIdx()).GetShape())) {
|
|
|
|
|
if (GraphUtils::RemoveEdge(peer_data_anchor, in_data_anchor) != GRAPH_SUCCESS) {
|
|
|
|
|
GELOGE(FAILED, "Remove data edge %s:%d->%s:%d failed.",
|
|
|
|
|
op_desc->GetName().c_str(), peer_data_anchor->GetIdx(),
|
|
|
|
|
node->GetName().c_str(), in_data_anchor->GetIdx());
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
GELOGD("Remove data edge %s:%d->%s:%d",
|
|
|
|
|
op_desc->GetName().c_str(), peer_data_anchor->GetIdx(),
|
|
|
|
|
node->GetName().c_str(), in_data_anchor->GetIdx());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
} // namespace ge
|
|
|
|
|