/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/common/transop_util.h"

#include <map>
#include <string>

#include "common/types.h"
#include "graph/utils/type_utils.h"
#include "framework/common/debug/ge_log.h"

namespace {
// Returned by GetTransOpDataIndex when the node is null or its type is not in the trans-op table.
const int kInvalidTransopDataIndex = -1;
// Index of the output whose data type CheckPrecisionLoss compares against the input data type.
const int kTransOpOutIndex = 0;
// Maps a source data type to the destination data type for which CheckPrecisionLoss
// reports a precision-losing conversion.
const std::map<ge::DataType, ge::DataType> precision_loss_transfer_map = {{ge::DT_FLOAT, ge::DT_BOOL}};
}  // namespace

namespace ge {
// Fills the table that maps each trans-op type name to the index of its data input.
TransOpUtil::TransOpUtil() {
  transop_index_map_ = {{TRANSDATA, 0}, {TRANSPOSE, 0},  {TRANSPOSED, 0}, {RESHAPE, 0},
                        {REFORMAT, 0},  {CAST, 0},       {SQUEEZE, 0},    {EXPANDDIMS, 0}};
}

TransOpUtil::~TransOpUtil() = default;

// Returns the single shared instance; it is created on first use.
TransOpUtil &TransOpUtil::Instance() {
  static TransOpUtil inst;
  return inst;
}

// Returns true when `node` is non-null and its type is in the trans-op table.
bool TransOpUtil::IsTransOp(const NodePtr &node) {
  if (node == nullptr) {
    return false;
  }
  return IsTransOp(node->GetType());
}

// Returns true when `type` is in the trans-op table.
bool TransOpUtil::IsTransOp(const std::string &type) {
  const auto &index_map = Instance().transop_index_map_;
  return index_map.find(type) != index_map.end();
}

// Returns the data-input index registered for the type of `node`,
// or kInvalidTransopDataIndex when `node` is null or its type is not registered.
int TransOpUtil::GetTransOpDataIndex(const NodePtr &node) {
  if (node == nullptr) {
    return kInvalidTransopDataIndex;
  }
  return GetTransOpDataIndex(node->GetType());
}

// Returns the data-input index registered for `type`,
// or kInvalidTransopDataIndex when `type` is not registered.
int TransOpUtil::GetTransOpDataIndex(const std::string &type) {
  const auto &index_map = Instance().transop_index_map_;
  const auto it = index_map.find(type);
  if (it == index_map.end()) {
    return kInvalidTransopDataIndex;
  }
  return it->second;
}

// Compares the data type of the node's data input with the data type of output kTransOpOutIndex.
// Returns false (after logging a warning) when that pair is listed in precision_loss_transfer_map,
// or when `src_node` or its op desc is null; returns true otherwise.
bool TransOpUtil::CheckPrecisionLoss(const ge::NodePtr &src_node) {
  if (src_node == nullptr) {
    GELOGW("Node is null, can not check precision loss. ignore pass.");
    return false;
  }
  const auto op_desc = src_node->GetOpDesc();
  if (op_desc == nullptr) {
    GELOGW("Node %s has no op desc, can not check precision loss. ignore pass.", src_node->GetName().c_str());
    return false;
  }

  // NOTE(review): idx is kInvalidTransopDataIndex when the node type is not a registered trans op,
  // and it is handed to GetInputDesc unchanged (as in the original code) - confirm GetInputDesc
  // tolerates an out-of-range index.
  const auto idx = TransOpUtil::GetTransOpDataIndex(src_node);
  const auto input_desc = op_desc->GetInputDesc(idx);
  const auto output_desc = op_desc->GetOutputDesc(kTransOpOutIndex);
  const auto src_dtype = input_desc.GetDataType();
  const auto dst_dtype = output_desc.GetDataType();

  const auto iter = precision_loss_transfer_map.find(src_dtype);
  if (iter != precision_loss_transfer_map.end() && iter->second == dst_dtype) {
    GELOGW("Node %s transfer data type from %s to %s ,it will cause precision loss. ignore pass.",
           src_node->GetName().c_str(), TypeUtils::DataTypeToSerialString(src_dtype).c_str(),
           TypeUtils::DataTypeToSerialString(dst_dtype).c_str());
    return false;
  }
  return true;
}

// Returns every type name in the trans-op table, each followed by a single space.
std::string TransOpUtil::TransopMapToString() {
  std::string buffer;
  for (const auto &entry : Instance().transop_index_map_) {
    buffer += entry.first;
    buffer += " ";
  }
  return buffer;
}
}  // namespace ge