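// NOTE: The following lines are residue from the repository web viewer this file was copied from;
// they are not part of the source code.
// You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
// graphengine/ge/graph/passes/get_original_format_pass.cc
//
// 184 lines
// 7.2 KiB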

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "graph/passes/get_original_format_pass.h"
#include <vector>
#include "common/debug/log.h"
#include "common/types.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "framework/omg/omg_inner_types.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/common/local_context.h"
using domi::DOMI_TENSOR_NCHW;
using domi::DOMI_TENSOR_NHWC;
using domi::DOMI_TENSOR_RESERVED;
using domi::FAILED;
using domi::PARAM_INVALID;
using domi::SUCCESS;
namespace ge {
/**
 * Pass entry point: rejects a null graph, then delegates all work to SetOriginalFormat.
 *
 * @param graph graph to process.
 * @return SUCCESS when SetOriginalFormat succeeds; otherwise the failing status is propagated
 *         (and logged) by the GE_* macros.
 */
Status GetOriginalFormatPass::Run(ge::ComputeGraphPtr graph) {
  GE_CHECK_NOTNULL(graph);
  // Capture the helper's status first so the error-propagation step reads on its own line.
  const Status set_format_status = SetOriginalFormat(graph);
  GE_RETURN_WITH_LOG_IF_ERROR(set_format_status, "SetOriginalFormat failed");
  return SUCCESS;
}
/**
 * Annotates every direct node of @p graph with its original data format.
 *
 * Pass 1 sets ATTR_NAME_INFERRED_FORMAT to DOMI_TENSOR_RESERVED on every node.
 * Pass 2 walks the nodes in GetDirectNode() order:
 *  - Data / AippData nodes take the format stored in the local OMG context and record it in both
 *    ATTR_NAME_FORMAT and ATTR_NAME_INFERRED_FORMAT.
 *  - Any other node derives its format from the ATTR_NAME_FORMAT of its input data nodes. A BiasAdd
 *    input is looked through to the BiasAdd's first input (and, with two non-const inputs, both must
 *    agree). If a format cannot be read or the inputs disagree, the node is skipped and keeps only
 *    what pass 1 wrote.
 *  - If IsFormatTranspose() reports an NCHW<->NHWC Permute, the derived format is flipped.
 *  - A node that already carries ATTR_NAME_FORMAT keeps it; only ATTR_NAME_INFERRED_FORMAT is
 *    synchronized to that existing value.
 *
 * @param graph graph to annotate; must not be null.
 * @return SUCCESS; FAILED when an attribute write fails; or the error status produced by the
 *         GE_CHECK_NOTNULL / GE_RETURN_WITH_LOG_IF_FALSE macros.
 */
Status GetOriginalFormatPass::SetOriginalFormat(const ge::ComputeGraphPtr &graph) {
  GE_CHECK_NOTNULL(graph);
  // NOTE(review): ori_format lives outside the node loop, so a non-Data node with no input data nodes
  // never assigns it and inherits the value left by the previously visited node. Confirm this
  // carry-over is intended before changing it.
  int64_t ori_format = 0;

  // Pass 1: mark every node's inferred format as not yet known.
  for (auto &node_ptr : graph->GetDirectNode()) {
    GE_CHECK_NOTNULL(node_ptr);
    GE_IF_BOOL_EXEC(!AttrUtils::SetInt(node_ptr->GetOpDesc(), ATTR_NAME_INFERRED_FORMAT, DOMI_TENSOR_RESERVED),
                    GELOGE(FAILED, "set ATTR_NAME_INFERRED_FORMAT failed");
                    return FAILED);
  }

  // Pass 2: derive and record the original format of each node.
  for (auto &node_ptr : graph->GetDirectNode()) {
    GE_CHECK_NOTNULL(node_ptr);
    OpDescPtr desc_ptr = node_ptr->GetOpDesc();
    GE_CHECK_NOTNULL(desc_ptr);
    auto is_data = (desc_ptr->GetType() == DATA_TYPE || desc_ptr->GetType() == AIPP_DATA_TYPE);
    if (is_data) {
      // Graph inputs take the format recorded in the local OMG context.
      GELOGI("Data node: %s,format :%d", node_ptr->GetName().c_str(), GetLocalOmgContext().format);
      ori_format = static_cast<int64_t>(GetLocalOmgContext().format);
      GE_IF_BOOL_EXEC(!AttrUtils::SetInt(desc_ptr, ATTR_NAME_FORMAT, ori_format),
                      GELOGE(FAILED, "set ATTR_NAME_FORMAT failed");
                      return FAILED);
      GE_IF_BOOL_EXEC(!AttrUtils::SetInt(desc_ptr, ATTR_NAME_INFERRED_FORMAT, ori_format),
                      GELOGE(FAILED, "set ATTR_NAME_INFERRED_FORMAT failed");
                      return FAILED);
      continue;
    }

    bool is_first_input = true;
    bool skip_node = false;           // inputs did not yield one readable, consistent format
    bool ignore_pred_format = false;  // at least one input is a BiasAdd
    for (auto &in_node_ptr : node_ptr->GetInDataNodes()) {
      GE_CHECK_NOTNULL(in_node_ptr);
      OpDescPtr format_src_op = in_node_ptr->GetOpDesc();
      GE_CHECK_NOTNULL(format_src_op);
      if (format_src_op->GetType() == BIASADD) {
        // Look through the BiasAdd: the format is read from the BiasAdd's first input instead.
        ignore_pred_format = true;
        std::size_t non_const_inputs = ge::OpDescUtils::GetNonConstInputsSize(in_node_ptr);
        GE_IF_BOOL_EXEC(non_const_inputs > 2 || non_const_inputs == 0,
                        GELOGW("bias_node is node followed by %zu nodes, should be 1 or 2", non_const_inputs);
                        skip_node = true; break);
        OpDescPtr first_input_op = in_node_ptr->GetInDataNodes().at(0)->GetOpDesc();
        GE_CHECK_NOTNULL(first_input_op);
        format_src_op = first_input_op;
        // If the BiasAdd has two non-const inputs, both must carry the same format.
        if (non_const_inputs == 2) {
          int64_t first_input_format = 0;
          int64_t second_input_format = 0;
          OpDescPtr second_input_op = in_node_ptr->GetInDataNodes().at(1)->GetOpDesc();
          GE_CHECK_NOTNULL(second_input_op);
          GE_IF_BOOL_EXEC(
            !AttrUtils::GetInt(first_input_op, ATTR_NAME_FORMAT, first_input_format), skip_node = true; break);
          GE_IF_BOOL_EXEC(
            !AttrUtils::GetInt(second_input_op, ATTR_NAME_FORMAT, second_input_format), skip_node = true; break);
          if (first_input_format != second_input_format) {
            GELOGW("biasadd node is followed two nodes with different format, get original format failed");
            skip_node = true;
            break;
          }
        }
      }
      int64_t input_format = 0;
      GE_IF_BOOL_EXEC(!AttrUtils::GetInt(format_src_op, ATTR_NAME_FORMAT, input_format), skip_node = true; break;);
      // The first input fixes the expected format; every later input must match it.
      if (is_first_input) {
        ori_format = input_format;
      }
      GE_IF_BOOL_EXEC(input_format != ori_format,
                      GELOGW("node: %s , original format of src nodes must be same!", in_node_ptr->GetName().c_str());
                      skip_node = true; break;);
      is_first_input = false;
    }
    GE_IF_BOOL_EXEC(skip_node, continue);

    // An NCHW<->NHWC Permute flips the format seen downstream of it.
    if (IsFormatTranspose(desc_ptr, static_cast<int32_t>(ori_format))) {
      ori_format = (ori_format == DOMI_TENSOR_NCHW) ? DOMI_TENSOR_NHWC : DOMI_TENSOR_NCHW;
    }
    if (ignore_pred_format) {
      GE_IF_BOOL_EXEC(!AttrUtils::SetBool(desc_ptr, ATTR_NAME_IGNORE_PRED_FORMAT, true),
                      GELOGE(FAILED, "set ATTR_NAME_IGNORE_PRED_FORMAT failed");
                      return FAILED);
    }
    // Do not reset ATTR_NAME_FORMAT if it is set in the OpParser.
    if (!desc_ptr->HasAttr(ATTR_NAME_FORMAT)) {
      GE_IF_BOOL_EXEC(!AttrUtils::SetInt(desc_ptr, ATTR_NAME_FORMAT, ori_format),
                      GELOGE(FAILED, "set ATTR_NAME_FORMAT failed");
                      return FAILED);
      GE_IF_BOOL_EXEC(!AttrUtils::SetInt(desc_ptr, ATTR_NAME_INFERRED_FORMAT, ori_format),
                      GELOGE(FAILED, "set ATTR_NAME_INFERRED_FORMAT failed");
                      return FAILED);
    } else {
      // Keep the pre-existing format and mirror it into the inferred-format attribute.
      int64_t existing_format = 0;
      GE_RETURN_WITH_LOG_IF_FALSE(AttrUtils::GetInt(desc_ptr, ATTR_NAME_FORMAT, existing_format),
                                  "Get existing_format attr failed");
      if (!AttrUtils::SetInt(desc_ptr, ATTR_NAME_INFERRED_FORMAT, existing_format)) {
        GELOGE(FAILED, "set ATTR_NAME_INFERRED_FORMAT failed");
        return FAILED;
      }
    }
  }
  return SUCCESS;
}
/**
 * Reports whether @p op_ptr is a Permute that transposes a tensor of format @p ori_format
 * between NCHW and NHWC.
 *
 * @param op_ptr     op descriptor to inspect; a null pointer logs "opdef is nullptr" and yields false.
 * @param ori_format format derived for the op's inputs (compared against DOMI_TENSOR_NCHW / DOMI_TENSOR_NHWC).
 * @return true only when op_ptr is a PERMUTE whose PERMUTE_ATTR_ORDER attribute is readable, has
 *         PERMUTE_ORDER_NUM entries, and equals {0, 2, 3, 1} for NCHW input or {0, 3, 1, 2} for
 *         NHWC input; false otherwise.
 */
bool GetOriginalFormatPass::IsFormatTranspose(const ge::OpDescPtr op_ptr, int32_t ori_format) {
  GE_CHK_BOOL_EXEC(op_ptr != nullptr, return false, "opdef is nullptr");
  if (op_ptr->GetType() != PERMUTE) {
    return false;
  }
  std::vector<int32_t> index_list;
  GE_IF_BOOL_EXEC(!AttrUtils::GetListInt(op_ptr, PERMUTE_ATTR_ORDER, index_list), return false);
  GE_IF_BOOL_EXEC(static_cast<int32_t>(index_list.size()) != PERMUTE_ORDER_NUM, return false);
  // 4-D axis orders: {0,2,3,1} turns NCHW into NHWC, {0,3,1,2} turns NHWC into NCHW.
  // Whole-vector comparison checks the size too, so these tables are never read out of bounds
  // even if PERMUTE_ORDER_NUM were ever to differ from 4.
  static const std::vector<int32_t> kNchwToNhwc = {0, 2, 3, 1};
  static const std::vector<int32_t> kNhwcToNchw = {0, 3, 1, 2};
  if (ori_format == DOMI_TENSOR_NCHW) {
    return index_list == kNchwToNhwc;
  }
  if (ori_format == DOMI_TENSOR_NHWC) {
    return index_list == kNhwcToNchw;
  }
  return false;
}
} // namespace ge