!698 Add keep_dtype attribute on operators to keep precision unchanged
From: @li-lei0106 Reviewed-by: Signed-off-by:pull/698/MERGE
commit
9aec8b4f0f
@ -0,0 +1,107 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "keep_dtype_option.h"
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include "graph/debug/ge_attr_define.h"
|
||||
#include "framework/common/util.h"
|
||||
#include "common/util/error_manager/error_manager.h"
|
||||
|
||||
namespace ge {
|
||||
namespace {
|
||||
const size_t kMaxOpsNum = 10;
|
||||
} // namespace
|
||||
bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) {
|
||||
std::vector<std::string> original_op_names;
|
||||
if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto &origin_name : original_op_names) {
|
||||
if (origin_name == op_name) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void KeepDtypeReportError(const std::vector<std::string> &invalid_list) {
|
||||
std::stringstream error_ops;
|
||||
for (size_t i = 0; i < invalid_list.size(); i++) {
|
||||
if (i == kMaxOpsNum) {
|
||||
error_ops << "...";
|
||||
break;
|
||||
}
|
||||
error_ops << invalid_list[i] << " ";
|
||||
}
|
||||
std::string err_msg = "config file contains ";
|
||||
err_msg = err_msg.append(std::to_string(invalid_list.size()))
|
||||
.append(" operators not in the graph, op names:")
|
||||
.append(error_ops.str());
|
||||
ErrorManager::GetInstance().ATCReportErrMessage(
|
||||
"E10042", {"parameter", "reason"}, {"keep_dtype", err_msg.c_str()});
|
||||
GELOGE(FAILED, "%s", err_msg.c_str());
|
||||
}
|
||||
|
||||
Status DealKeepDtypeOption(const ComputeGraphPtr &graph, const std::string &keep_dtype) {
|
||||
GE_CHECK_NOTNULL(graph);
|
||||
if (keep_dtype.empty()) {
|
||||
return SUCCESS;
|
||||
}
|
||||
std::string real_path = RealPath(keep_dtype.c_str());
|
||||
if (real_path.empty()) {
|
||||
GELOGE(PARAM_INVALID, "Can not get real path for %s.", keep_dtype.c_str());
|
||||
return PARAM_INVALID;
|
||||
}
|
||||
std::ifstream ifs(real_path);
|
||||
if (!ifs.is_open()) {
|
||||
GELOGE(FAILED, "Open file %s failed", keep_dtype.c_str());
|
||||
return FAILED;
|
||||
}
|
||||
|
||||
std::string op_name;
|
||||
std::vector<std::string> invalid_list;
|
||||
while (std::getline(ifs, op_name)) {
|
||||
if (op_name.empty()) {
|
||||
continue;
|
||||
}
|
||||
op_name = StringUtils::Trim(op_name);
|
||||
bool is_find = false;
|
||||
for (auto &node_ptr : graph->GetDirectNode()) {
|
||||
auto op_desc = node_ptr->GetOpDesc();
|
||||
GE_CHECK_NOTNULL(op_desc);
|
||||
|
||||
if ((op_desc->GetName() == op_name) || IsOriginalOpFind(op_desc, op_name)) {
|
||||
is_find = true;
|
||||
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1);
|
||||
}
|
||||
}
|
||||
if (!is_find) {
|
||||
invalid_list.push_back(op_name);
|
||||
}
|
||||
}
|
||||
|
||||
if (!invalid_list.empty()) {
|
||||
KeepDtypeReportError(invalid_list);
|
||||
return PARAM_INVALID;
|
||||
}
|
||||
|
||||
return SUCCESS;
|
||||
}
|
||||
} // namespace ge
|
@ -0,0 +1,26 @@
|
||||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef KEEP_DTYPE_OPTION_H_
|
||||
#define KEEP_DTYPE_OPTION_H_
|
||||
|
||||
#include <string>
|
||||
#include "graph/compute_graph.h"
|
||||
#include "framework/common/ge_inner_error_codes.h"
|
||||
|
||||
namespace ge {
|
||||
Status DealKeepDtypeOption(const ComputeGraphPtr &graph, const std::string &keep_dtype);
|
||||
} // namespace
|
||||
#endif // KEEP_DTYPE_OPTION_H_
|
Loading…
Reference in new issue