diff --git a/ge/ir_build/atc_ir_common.cc b/ge/ir_build/atc_ir_common.cc
index 7d63514c..9eab901f 100755
--- a/ge/ir_build/atc_ir_common.cc
+++ b/ge/ir_build/atc_ir_common.cc
@@ -51,6 +51,7 @@ const char *const kDigitError = "is not digit";
 const char *const kCompressWeightError = "it must be appointed when appoint parameter[--optypelist_for_implmode]";
 const char *const kSelectImplmodeError = "only support high_performance, high_precision";
 const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \"";
+const char *const kKeepDtypeError = "file not found";  // "reason" text reported when the --keep_dtype path check fails
 
 vector SplitInputShape(const std::string &input_shape) {
   vector shape_pair_vec;
@@ -438,6 +439,17 @@ Status CheckCompressWeightParamValid(const std::string enable_compress_weight, c
   return ge::SUCCESS;
 }
 
+Status CheckKeepTypeParamValid(const std::string &keep_dtype) {  // validates the path given to --keep_dtype
+  if ((!keep_dtype.empty()) && (!CheckInputPathValid(keep_dtype, "--keep_dtype"))) {  // an empty value (option unset) is accepted
+    ErrorManager::GetInstance().ATCReportErrMessage(
+      "E10001", {"parameter", "value", "reason"}, {"--keep_dtype", keep_dtype, kKeepDtypeError});
+    GELOGE(ge::PARAM_INVALID, "keep dtype config file not found, file_name:%s", keep_dtype.c_str());
+    return ge::PARAM_INVALID;  // a path was given but CheckInputPathValid rejected it
+  }
+
+  return ge::SUCCESS;
+}
+
 int CheckLogParamValidAndSetLogLevel(const std::string log) {
   int ret = -1;
   if (log == "default") {
diff --git a/ge/ir_build/atc_ir_common.h b/ge/ir_build/atc_ir_common.h
index 1e4af73e..46e58a43 100644
--- a/ge/ir_build/atc_ir_common.h
+++ b/ge/ir_build/atc_ir_common.h
@@ -76,6 +76,7 @@ Status CheckDisableReuseMemoryParamValid(const std::string disable_reuse_memory)
 Status CheckEnableSingleStreamParamValid(const std::string enable_single_stream);
 Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std::string &op_select_implmode);
 Status CheckInputFormat(const std::string &input_format);
+Status CheckKeepTypeParamValid(const std::string &keep_dtype);  // ge::SUCCESS if keep_dtype is empty or passes CheckInputPathValid, else ge::PARAM_INVALID
 void PrintOptionMap(std::map &options, std::string tips);
 void
EraseEndSemicolon(std::string &param);
 }
diff --git a/ge/offline/CMakeLists.txt b/ge/offline/CMakeLists.txt
index af259ecb..48c1cbe7 100644
--- a/ge/offline/CMakeLists.txt
+++ b/ge/offline/CMakeLists.txt
@@ -10,6 +10,7 @@ protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST})
 set(SRC_LIST
     "main.cc"
     "single_op_parser.cc"
+    "keep_dtype_option.cc"
     "../session/omg.cc"
     "../ir_build/atc_ir_common.cc"
 )
diff --git a/ge/offline/keep_dtype_option.cc b/ge/offline/keep_dtype_option.cc
new file mode 100644
index 00000000..5624f21c
--- /dev/null
+++ b/ge/offline/keep_dtype_option.cc
@@ -0,0 +1,135 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "keep_dtype_option.h"
+#include <fstream>
+#include <sstream>
+#include <string>
+#include <vector>
+#include "graph/debug/ge_attr_define.h"
+#include "framework/common/util.h"
+#include "common/util/error_manager/error_manager.h"
+
+namespace ge {
+namespace {
+// Maximum number of unmatched operator names written into the error message.
+const size_t kMaxOpsNum = 10;
+}  // namespace
+
+/// @brief Check whether op_name is among the original op names recorded on op_desc.
+/// @param [in] op_desc  operator description to inspect
+/// @param [in] op_name  operator name to look for
+/// @return true if ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES can be read and contains op_name
+bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) {
+  std::vector<std::string> original_op_names;
+  if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) {
+    return false;
+  }
+
+  for (const auto &origin_name : original_op_names) {
+    if (origin_name == op_name) {
+      return true;
+    }
+  }
+
+  return false;
+}
+
+/// @brief Report and log the operator names that matched no node in the graph.
+///        At most kMaxOpsNum names are listed; any further names are abbreviated as "..".
+/// @param [in] invalid_list  operator names that were not found
+void KeepDtypeReportError(const std::vector<std::string> &invalid_list) {
+  std::stringstream err_msg;
+  size_t list_size = invalid_list.size();
+  err_msg << "config file contains " << list_size;
+  if (list_size == 1) {
+    err_msg << " operator not in the graph, op name:";
+  } else {
+    err_msg << " operators not in the graph, op names:";
+  }
+
+  for (size_t i = 0; i < list_size; i++) {
+    if (i == kMaxOpsNum) {
+      err_msg << "..";
+      break;
+    }
+    err_msg << invalid_list[i];
+    if (i != list_size - 1) {
+      err_msg << " ";
+    }
+  }
+
+  ErrorManager::GetInstance().ATCReportErrMessage(
+    "E10042", {"parameter", "reason"}, {"keep_dtype", err_msg.str().c_str()});
+  GELOGE(FAILED, "%s", err_msg.str().c_str());
+}
+
+/// @brief Set ATTR_NAME_KEEP_DTYPE on every node named in the keep_dtype config file.
+///        The file is read line by line, one operator name per line; a name matches a node by
+///        the node's own name or by one of its recorded original op names.
+/// @param [in] graph       compute graph whose direct nodes are searched
+/// @param [in] keep_dtype  path of the config file; an empty string means the option is unset
+/// @return SUCCESS if keep_dtype is empty or every listed operator was found;
+///         PARAM_INVALID if the path cannot be resolved or a listed operator is not in the graph;
+///         FAILED if the file cannot be opened
+Status DealKeepDtypeOption(const ComputeGraphPtr &graph, const std::string &keep_dtype) {
+  GE_CHECK_NOTNULL(graph);
+  if (keep_dtype.empty()) {
+    return SUCCESS;
+  }
+  std::string real_path = RealPath(keep_dtype.c_str());
+  if (real_path.empty()) {
+    GELOGE(PARAM_INVALID, "Can not get real path for %s.", keep_dtype.c_str());
+    return PARAM_INVALID;
+  }
+  std::ifstream ifs(real_path);
+  if (!ifs.is_open()) {
+    GELOGE(FAILED, "Open file %s failed", keep_dtype.c_str());
+    return FAILED;
+  }
+
+  std::string op_name;
+  std::vector<std::string> invalid_list;
+  while (std::getline(ifs, op_name)) {
+    // Trim first: a line that Trim reduces to an empty string must be skipped here, otherwise
+    // it would be looked up as an operator with an empty name and reported as not in the graph.
+    op_name = StringUtils::Trim(op_name);
+    if (op_name.empty()) {
+      continue;
+    }
+    bool is_find = false;
+    for (auto &node_ptr : graph->GetDirectNode()) {
+      auto op_desc = node_ptr->GetOpDesc();
+      GE_CHECK_NOTNULL(op_desc);
+
+      if ((op_desc->GetName() == op_name) || IsOriginalOpFind(op_desc, op_name)) {
+        is_find = true;
+        (void)AttrUtils::SetInt(op_desc, ATTR_NAME_KEEP_DTYPE, 1);
+      }
+    }
+    if (!is_find) {
+      invalid_list.push_back(op_name);
+    }
+  }
+  ifs.close();
+
+  if (!invalid_list.empty()) {
+    KeepDtypeReportError(invalid_list);
+    return PARAM_INVALID;
+  }
+
+  return SUCCESS;
+}
+}  // namespace ge
diff --git a/ge/offline/keep_dtype_option.h b/ge/offline/keep_dtype_option.h
new file mode 100644
index 00000000..2df2ed8c
--- /dev/null
+++ b/ge/offline/keep_dtype_option.h
@@ -0,0 +1,26 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef KEEP_DTYPE_OPTION_H_
+#define KEEP_DTYPE_OPTION_H_
+
+#include  // NOTE(review): the header name is missing in this copy (presumably <string>) - confirm against the original patch
+#include "graph/compute_graph.h"
+#include "framework/common/ge_inner_error_codes.h"
+
+namespace ge {
+Status DealKeepDtypeOption(const ComputeGraphPtr &graph, const std::string &keep_dtype);  // called from GenerateModel with the --keep_dtype flag value
+}  // namespace ge
+#endif  // KEEP_DTYPE_OPTION_H_
\ No newline at end of file
diff --git a/ge/offline/main.cc b/ge/offline/main.cc
index 46d25d97..e449f159 100755
--- a/ge/offline/main.cc
+++ b/ge/offline/main.cc
@@ -43,6 +43,7 @@
 #include "parser/common/register_tbe.h"
 #include "register/op_registry.h"
 #include "single_op_parser.h"
+#include "keep_dtype_option.h"  // declares ge::DealKeepDtypeOption
 
 using domi::BuildMode;
 using domi::OpRegistrationData;
@@ -109,6 +110,9 @@ DEFINE_string(precision_mode, "force_fp16",
               "Optional; precision mode."
               "Support force_fp16, allow_mix_precision, allow_fp32_to_fp16, must_keep_origin_dtype.");
 
+DEFINE_string(keep_dtype, "",  // default is the empty string, i.e. the option is not set
+              "Optional; config file to specify the precision used by the operator during compilation.");
+
 DEFINE_string(input_format, "",
               "Optional; input_format, format of input data, NCHW;NHWC."
               "Format:\"NHWC\"");
@@ -285,6 +289,8 @@ class GFlagUtils {
         "\n[Operator Tuning]\n"
         " --precision_mode precision mode, support force_fp16(default), allow_mix_precision, "
         "allow_fp32_to_fp16, must_keep_origin_dtype.\n"
+        " --keep_dtype Retains the precision of certain operators in inference "
+        "scenarios by using a configuration file.\n"
         " --auto_tune_mode Set tune mode. E.g.: \"GA,RL\", support configure multiple, spit by ,\n"
         " --op_select_implmode Set op select implmode. Support high_precision, high_performance. "
         "default: high_performance\n"
@@ -421,6 +427,9 @@ class GFlagUtils {
       FLAGS_enable_compress_weight, FLAGS_compress_weight_conf) == ge::SUCCESS,
       ret = ge::FAILED, "check compress weight failed!");
 
+    GE_CHK_BOOL_EXEC(ge::CheckKeepTypeParamValid(FLAGS_keep_dtype) == ge::SUCCESS,  // path check for --keep_dtype
+      ret = ge::FAILED, "check keep dtype failed!");
+
     GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
       !ge::CheckOutputPathValid(FLAGS_check_report, "--check_report"), ret = ge::FAILED,
       "check_report file %s not found!!", FLAGS_check_report.c_str());
@@ -979,6 +988,13 @@ domi::Status GenerateModel(std::map &options, std::string output
     }
   }
 
+  Status ret = ge::DealKeepDtypeOption(ge::GraphUtils::GetComputeGraph(graph), FLAGS_keep_dtype);  // runs before GenerateOfflineModel
+  if (ret != SUCCESS) {  // finalize the generator and GELib before returning the error
+    (void)ge_generator.Finalize();
+    (void)ge::GELib::GetInstance()->Finalize();
+    return ret;
+  }
+
   geRet = ge_generator.GenerateOfflineModel(graph, output, inputs);
   if (geRet != ge::SUCCESS) {
     DOMI_LOGE("GE GenerateOfflineModel execute failed");
diff --git a/ge/offline/module.mk b/ge/offline/module.mk
index 5c7a919c..eb31fcb7 100755
--- a/ge/offline/module.mk
+++ b/ge/offline/module.mk
@@ -11,6 +11,7 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg
 LOCAL_SRC_FILES := \
     main.cc \
     single_op_parser.cc \
+    keep_dtype_option.cc \
     ../session/omg.cc \
     ../ir_build/atc_ir_common.cc \
 
@@ -64,6 +65,7 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg
 LOCAL_SRC_FILES := \
     main.cc \
     single_op_parser.cc \
+    keep_dtype_option.cc \
     ../session/omg.cc \
     ../ir_build/atc_ir_common.cc \
 
@@ -117,6 +119,7 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg
 LOCAL_SRC_FILES := \
     main.cc \
     single_op_parser.cc \
+    keep_dtype_option.cc \
     ../session/omg.cc \
     ../ir_build/atc_ir_common.cc \
 
diff --git a/metadef b/metadef
index af156f82..98a7ac86 160000
--- a/metadef
+++ b/metadef
@@ -1 +1 @@
-Subproject commit af156f825aa53a24bd30ae4065e3ea356cf555ef
+Subproject commit 98a7ac86170097104a94d72b64bd1a8644c5b3c5