/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/passes/get_original_format_pass.h"

#include <algorithm>
#include <array>
#include <cstddef>
#include <vector>

#include "common/debug/log.h"
#include "common/types.h"
#include "common/util.h"
#include "framework/common/debug/ge_log.h"
#include "framework/omg/omg_inner_types.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/common/local_context.h"

using domi::DOMI_TENSOR_NCHW;
using domi::DOMI_TENSOR_NHWC;
using domi::DOMI_TENSOR_RESERVED;
using domi::FAILED;
using domi::PARAM_INVALID;
using domi::SUCCESS;

namespace ge {
// Pass entry point: annotates the direct nodes of `graph` with their original format.
// Returns SUCCESS, or returns early (via GE_CHECK_NOTNULL / GE_RETURN_WITH_LOG_IF_ERROR)
// when `graph` is null or SetOriginalFormat fails.
Status GetOriginalFormatPass::Run(ge::ComputeGraphPtr graph) {
  GE_CHECK_NOTNULL(graph);
  GE_RETURN_WITH_LOG_IF_ERROR(SetOriginalFormat(graph), "SetOriginalFormat failed");
  return SUCCESS;
}

// Writes ATTR_NAME_FORMAT / ATTR_NAME_INFERRED_FORMAT on the direct nodes of `graph`.
//
// Pass 1: every node's ATTR_NAME_INFERRED_FORMAT is reset to DOMI_TENSOR_RESERVED.
// Pass 2, per node in GetDirectNode() order:
//  - DATA / AIPP_DATA nodes take the format stored in the local OMG context.
//  - Any other node takes the ATTR_NAME_FORMAT shared by all of its in-data producers.
//    A BIASADD producer is looked through to its first input; if it has two non-const
//    inputs, both must carry the same format. Consumers of a BIASADD are additionally
//    tagged with ATTR_NAME_IGNORE_PRED_FORMAT.
//  - If producers disagree, a producer format cannot be read, or a BIASADD has an
//    unexpected number of non-const inputs, the node is skipped for pass 2.
//  - A PERMUTE that transposes between NCHW and NHWC (see IsFormatTranspose) flips the
//    derived format.
//  - If the op parser already set ATTR_NAME_FORMAT it is kept, and only
//    ATTR_NAME_INFERRED_FORMAT is made to mirror it.
// Returns FAILED when an attribute cannot be written; null graph/node/op-desc pointers
// cause an early return through GE_CHECK_NOTNULL.
Status GetOriginalFormatPass::SetOriginalFormat(const ge::ComputeGraphPtr &graph) {
  GE_CHECK_NOTNULL(graph);
  // NOTE(review): these are function-scoped, so a node with no in-data producers (the
  // inner loop never runs) reuses the ori_format left over from the previously visited
  // node. That looks accidental, but it is preserved here; confirm intent before changing.
  int64_t ori_format = 0;
  int64_t tmp_format = 0;

  for (auto &node_ptr : graph->GetDirectNode()) {
    GE_CHECK_NOTNULL(node_ptr);
    GE_IF_BOOL_EXEC(!AttrUtils::SetInt(node_ptr->GetOpDesc(), ATTR_NAME_INFERRED_FORMAT, DOMI_TENSOR_RESERVED),
                    GELOGE(FAILED, "set ATTR_NAME_INFERRED_FORMAT failed");
                    return FAILED);
  }

  for (auto &node_ptr : graph->GetDirectNode()) {
    GE_CHECK_NOTNULL(node_ptr);
    OpDescPtr desc_ptr = node_ptr->GetOpDesc();
    GE_CHECK_NOTNULL(desc_ptr);
    auto is_data = (desc_ptr->GetType() == DATA_TYPE || desc_ptr->GetType() == AIPP_DATA_TYPE);
    if (is_data) {
      // Graph inputs take the format recorded in the local OMG context.
      GELOGI("Data node: %s,format :%d", node_ptr->GetName().c_str(), GetLocalOmgContext().format);
      ori_format = static_cast<int64_t>(GetLocalOmgContext().format);
      GE_IF_BOOL_EXEC(!AttrUtils::SetInt(desc_ptr, ATTR_NAME_FORMAT, ori_format),
                      GELOGE(FAILED, "set ATTR_NAME_FORMAT failed");
                      return FAILED);
      GE_IF_BOOL_EXEC(!AttrUtils::SetInt(desc_ptr, ATTR_NAME_INFERRED_FORMAT, ori_format),
                      GELOGE(FAILED, "set ATTR_NAME_INFERRED_FORMAT failed");
                      return FAILED);
      continue;
    }

    int32_t i = 0;                    // number of producers already compared
    bool continue_flag = false;       // set when this node must be skipped
    bool ignore_pred_format = false;  // set when any producer is a BIASADD
    for (auto &bias_node_ptr : node_ptr->GetInDataNodes()) {
      GE_CHECK_NOTNULL(bias_node_ptr);
      OpDescPtr bias_op_ptr = bias_node_ptr->GetOpDesc();
      GE_CHECK_NOTNULL(bias_op_ptr);
      if (bias_op_ptr->GetType() == BIASADD) {
        ignore_pred_format = true;
        std::size_t tmp_size = ge::OpDescUtils::GetNonConstInputsSize(bias_node_ptr);
        GE_IF_BOOL_EXEC(tmp_size > 2 || tmp_size == 0,
                        GELOGW("bias_node is node followed by %zu nodes, should be 1 or 2", tmp_size);
                        continue_flag = true;
                        break);
        // Copy the NodePtr out of the temporary returned by GetInDataNodes() and null-check
        // it before dereferencing.
        auto first_in_node = bias_node_ptr->GetInDataNodes().at(0);
        GE_CHECK_NOTNULL(first_in_node);
        OpDescPtr tmp_first_op_ptr = first_in_node->GetOpDesc();
        GE_CHECK_NOTNULL(tmp_first_op_ptr);
        // Look through the BIASADD: its first input supplies the format.
        bias_op_ptr = tmp_first_op_ptr;

        // if biasadd have 2 input edges, format should be same
        if (tmp_size == 2) {
          int64_t first_input_format = 0;
          int64_t second_input_format = 0;
          auto second_in_node = bias_node_ptr->GetInDataNodes().at(1);
          GE_CHECK_NOTNULL(second_in_node);
          OpDescPtr tmp_second_op_ptr = second_in_node->GetOpDesc();
          GE_CHECK_NOTNULL(tmp_second_op_ptr);
          GE_IF_BOOL_EXEC(!AttrUtils::GetInt(tmp_first_op_ptr, ATTR_NAME_FORMAT, first_input_format),
                          continue_flag = true;
                          break);
          GE_IF_BOOL_EXEC(!AttrUtils::GetInt(tmp_second_op_ptr, ATTR_NAME_FORMAT, second_input_format),
                          continue_flag = true;
                          break);

          if (first_input_format != second_input_format) {
            GELOGW("biasadd node is followed two nodes with different format, get original format failed");
            continue_flag = true;
            break;
          }
        }
      }

      GE_IF_BOOL_EXEC(!AttrUtils::GetInt(bias_op_ptr, ATTR_NAME_FORMAT, tmp_format),
                      continue_flag = true;
                      break);
      // The first producer fixes the candidate format; all later ones must match it.
      if (i == 0) {
        ori_format = tmp_format;
      }

      GE_IF_BOOL_EXEC(tmp_format != ori_format,
                      GELOGW("node: %s , original format of src nodes must be same!",
                             bias_node_ptr->GetName().c_str());
                      continue_flag = true;
                      break);

      i++;
    }

    GE_IF_BOOL_EXEC(continue_flag, continue);

    // desc_ptr is this node's OpDesc, already null-checked above.
    if (IsFormatTranspose(desc_ptr, static_cast<int32_t>(ori_format))) {
      ori_format = (ori_format == DOMI_TENSOR_NCHW) ? DOMI_TENSOR_NHWC : DOMI_TENSOR_NCHW;
    }

    if (ignore_pred_format) {
      GE_IF_BOOL_EXEC(!AttrUtils::SetBool(desc_ptr, ATTR_NAME_IGNORE_PRED_FORMAT, true),
                      GELOGE(FAILED, "set ATTR_NAME_IGNORE_PRED_FORMAT failed");
                      return FAILED);
    }

    // Do not reset ATTR_NAME_FORMAT if it is set in the OpParser.
    if (!desc_ptr->HasAttr(ATTR_NAME_FORMAT)) {
      GE_IF_BOOL_EXEC(!AttrUtils::SetInt(desc_ptr, ATTR_NAME_FORMAT, ori_format),
                      GELOGE(FAILED, "set ATTR_NAME_FORMAT failed");
                      return FAILED);
      GE_IF_BOOL_EXEC(!AttrUtils::SetInt(desc_ptr, ATTR_NAME_INFERRED_FORMAT, ori_format),
                      GELOGE(FAILED, "set ATTR_NAME_INFERRED_FORMAT failed");
                      return FAILED);
    } else {
      int64_t existing_format = 0;
      GE_RETURN_WITH_LOG_IF_FALSE(AttrUtils::GetInt(desc_ptr, ATTR_NAME_FORMAT, existing_format),
                                  "Get existing_format attr failed");
      if (!AttrUtils::SetInt(desc_ptr, ATTR_NAME_INFERRED_FORMAT, existing_format)) {
        GELOGE(FAILED, "set ATTR_NAME_INFERRED_FORMAT failed");
        return FAILED;
      }
    }
  }
  return SUCCESS;
}

// Returns true when `op_ptr` is a PERMUTE whose PERMUTE_ATTR_ORDER converts between NCHW
// and NHWC in the direction matching `ori_format`:
//  - order {0, 2, 3, 1} with ori_format == DOMI_TENSOR_NCHW (NCHW -> NHWC), or
//  - order {0, 3, 1, 2} with ori_format == DOMI_TENSOR_NHWC (NHWC -> NCHW).
// Returns false for a null op, any other op type, a missing order attribute, or an order
// whose length is not PERMUTE_ORDER_NUM (or not 4).
bool GetOriginalFormatPass::IsFormatTranspose(const ge::OpDescPtr op_ptr, int32_t ori_format) {
  GE_CHK_BOOL_EXEC(op_ptr != nullptr, return false, "opdef is nullptr");
  if (op_ptr->GetType() != PERMUTE) {
    return false;
  }

  std::vector<int32_t> index_list;
  GE_IF_BOOL_EXEC(!AttrUtils::GetListInt(op_ptr, PERMUTE_ATTR_ORDER, index_list), return false);

  static const std::array<int32_t, 4> kPermNchwToNhwc = {0, 2, 3, 1};  // {0,2,3,1} NCHW -> NHWC
  static const std::array<int32_t, 4> kPermNhwcToNchw = {0, 3, 1, 2};  // {0,3,1,2} NHWC -> NCHW

  auto index_size = index_list.size();
  // Also compare against the table size so the element-wise comparison below can never
  // read past either permutation table.
  GE_IF_BOOL_EXEC(static_cast<int32_t>(index_size) != PERMUTE_ORDER_NUM || index_size != kPermNchwToNhwc.size(),
                  return false);

  const bool is_nchw = std::equal(kPermNchwToNhwc.begin(), kPermNchwToNhwc.end(), index_list.begin());
  const bool is_nhwc = std::equal(kPermNhwcToNchw.begin(), kPermNhwcToNchw.end(), index_list.begin());
  // The two tables differ, so at most one of is_nchw / is_nhwc can be true.
  return (is_nchw && ori_format == DOMI_TENSOR_NCHW) || (is_nhwc && ori_format == DOMI_TENSOR_NHWC);
}
}  // namespace ge