|
|
|
/**
|
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
|
|
*
|
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
* you may not use this file except in compliance with the License.
|
|
|
|
* You may obtain a copy of the License at
|
|
|
|
*
|
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
*
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
* See the License for the specific language governing permissions and
|
|
|
|
* limitations under the License.
|
|
|
|
*/
|
|
|
|
|
|
|
|
#include "graph/passes/permute_pass.h"
|
|
|
|
#include <queue>
|
|
|
|
#include <vector>
|
|
|
|
#include "common/debug/log.h"
|
|
|
|
#include "common/types.h"
|
|
|
|
#include "graph/utils/attr_utils.h"
|
|
|
|
#include "graph/utils/op_desc_utils.h"
|
|
|
|
#include "inc/kernel.h"
|
|
|
|
#include "inc/kernel_factory.h"
|
|
|
|
#include "framework/omg/omg_inner_types.h"
|
|
|
|
#include "graph/common/local_context.h"
|
|
|
|
|
|
|
|
using domi::DOMI_TENSOR_ND;
|
|
|
|
using domi::DOMI_TENSOR_NHWC;
|
|
|
|
using domi::SUCCESS;
|
|
|
|
using domi::TENSORFLOW;
|
|
|
|
|
|
|
|
namespace ge {
|
|
|
|
Status PermutePass::Run(ComputeGraphPtr graph) {
|
|
|
|
GE_CHECK_NOTNULL(graph);
|
|
|
|
std::vector<NodePtr> isolate_nodes;
|
|
|
|
for (NodePtr &node : graph->GetDirectNode()) {
|
|
|
|
OpDescPtr op_desc_ptr = node->GetOpDesc();
|
|
|
|
GE_CHECK_NOTNULL(op_desc_ptr);
|
|
|
|
GE_IF_BOOL_EXEC(
|
|
|
|
op_desc_ptr->GetType() == PERMUTE && GetLocalOmgContext().type == domi::TENSORFLOW,
|
|
|
|
/// Input format 5D means NHWC in 4D way. So if input origin foramt is NCHW and
|
|
|
|
/// permute paramter list is [0,3,1,2], this permute can be optimised.
|
|
|
|
GE_IF_BOOL_EXEC(
|
|
|
|
GetLocalOmgContext().format != DOMI_TENSOR_ND,
|
|
|
|
// Get input origin foramt
|
|
|
|
for (NodePtr &n
|
|
|
|
: graph->GetDirectNode()) {
|
|
|
|
GE_IF_BOOL_EXEC(
|
|
|
|
n->GetOpDesc()->GetType() == PERMUTE, std::queue<NodePtr> q_node; q_node.push(n); bool jump_out = false;
|
|
|
|
while (!q_node.empty()) {
|
|
|
|
NodePtr n_temp = q_node.back();
|
|
|
|
q_node.pop();
|
|
|
|
for (auto &inNode : n_temp->GetInDataNodes()) {
|
|
|
|
int64_t cur_format = 0;
|
|
|
|
GE_IF_BOOL_EXEC(AttrUtils::GetInt(inNode->GetOpDesc(), ATTR_NAME_FORMAT, cur_format),
|
|
|
|
GE_IF_BOOL_EXEC(!AttrUtils::SetInt(n->GetOpDesc(), "permute_src_format", cur_format),
|
|
|
|
GELOGW("set permute_src_format failed");
|
|
|
|
continue);
|
|
|
|
jump_out = true; break);
|
|
|
|
q_node.push(inNode);
|
|
|
|
}
|
|
|
|
GE_IF_BOOL_EXEC(jump_out, break);
|
|
|
|
});
|
|
|
|
}
|
|
|
|
|
|
|
|
int64_t permute_src_format = 0;
|
|
|
|
GE_IF_BOOL_EXEC(!AttrUtils::GetInt(op_desc_ptr, "permute_src_format", permute_src_format), continue);
|
|
|
|
// Get dim_index_
|
|
|
|
std::vector<int64_t> index_list; GE_CHK_BOOL_RET_STATUS(
|
|
|
|
AttrUtils::GetListInt(op_desc_ptr, PERMUTE_ATTR_ORDER, index_list), INTERNAL_ERROR, "get index list failed");
|
|
|
|
|
|
|
|
size_t index_size = index_list.size(); GE_IF_BOOL_EXEC(index_size == 0, continue);
|
|
|
|
|
|
|
|
GE_IF_BOOL_EXEC(index_size == 4 && (permute_src_format == DOMI_TENSOR_NHWC && index_list.at(0) == 0 &&
|
|
|
|
index_list.at(1) == 3 && index_list.at(2) == 1 && index_list.at(3) == 2),
|
|
|
|
isolate_nodes.push_back(node);
|
|
|
|
continue);
|
|
|
|
int64_t conv_format = 0; GE_IF_BOOL_EXEC(
|
|
|
|
index_size == 4 &&
|
|
|
|
(index_list.at(0) == 0 && index_list.at(1) == 2 && index_list.at(2) == 3 && index_list.at(3) == 1),
|
|
|
|
GE_IF_BOOL_EXEC(
|
|
|
|
(node->GetOutDataNodesSize() > 0 && node->GetOutDataNodes().at(0) != nullptr &&
|
|
|
|
node->GetOutDataNodes().at(0)->GetOpDesc() != nullptr) &&
|
|
|
|
((node->GetOutDataNodesSize() != 0 &&
|
|
|
|
CONVOLUTION == node->GetOutDataNodes().at(0)->GetOpDesc()->GetType() &&
|
|
|
|
AttrUtils::GetInt(node->GetOutDataNodes().at(0)->GetOpDesc(), ATTR_NAME_FORMAT, conv_format) &&
|
|
|
|
conv_format == DOMI_TENSOR_NHWC) ||
|
|
|
|
(node->GetOutDataNodesSize() != 0 &&
|
|
|
|
node->GetOutDataNodes().at(0)->GetOpDesc()->GetType() == DEPCONVOLUTION) ||
|
|
|
|
(node->GetOutDataNodesSize() != 0 &&
|
|
|
|
node->GetOutDataNodes().at(0)->GetOpDesc()->GetType() == DECONVOLUTION) ||
|
|
|
|
(node->GetOutDataNodesSize() != 0 && node->GetOutDataNodes().at(0)->GetOpDesc()->GetType() == PAD &&
|
|
|
|
node->GetOutDataNodes().at(0)->GetOutDataNodesSize() != 0 &&
|
|
|
|
node->GetOutDataNodes().at(0)->GetOutDataNodes().at(0) != nullptr &&
|
|
|
|
node->GetOutDataNodes().at(0)->GetOutDataNodes().at(0)->GetOpDesc() != nullptr &&
|
|
|
|
node->GetOutDataNodes().at(0)->GetOutDataNodes().at(0)->GetOpDesc()->GetType() == CONVOLUTION)),
|
|
|
|
isolate_nodes.push_back(node);
|
|
|
|
continue););););
|
|
|
|
}
|
|
|
|
|
|
|
|
GE_IF_BOOL_EXEC(
|
|
|
|
isolate_nodes.size() != 0, for (auto &node
|
|
|
|
: isolate_nodes) {
|
|
|
|
// Adding an attribute indicates that the predecessor Permute has been deleted for the Builder to process.
|
|
|
|
for (auto &outNode : node->GetOutDataNodes()) {
|
|
|
|
OpDescPtr op_desc_ptr = outNode->GetOpDesc();
|
|
|
|
GE_CHECK_NOTNULL(op_desc_ptr);
|
|
|
|
if (!AttrUtils::SetBool(op_desc_ptr, ATTR_NAME_PRED_PERMUTE_DELETED, true)) {
|
|
|
|
GELOGE(INTERNAL_ERROR, "set ATTR_NAME_PRED_PERMUTE_DELETED failed");
|
|
|
|
return INTERNAL_ERROR;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
GE_RETURN_WITH_LOG_IF_ERROR(graph->RemoveNode(node), "[%s]:remove permute node failed",
|
|
|
|
node->GetOpDesc()->GetName().c_str());
|
|
|
|
});
|
|
|
|
return SUCCESS;
|
|
|
|
}
|
|
|
|
} // namespace ge
|