|
|
|
@ -206,6 +206,8 @@ DEFINE_string(mdl_bank_path, "", "Optional; model bank path");
|
|
|
|
|
|
|
|
|
|
DEFINE_string(op_bank_path, "", "Optional; op bank path");
|
|
|
|
|
|
|
|
|
|
DEFINE_string(display_model_info, "0", "Optional; display model info");
|
|
|
|
|
|
|
|
|
|
class GFlagUtils {
|
|
|
|
|
public:
|
|
|
|
|
/**
|
|
|
|
@ -225,7 +227,8 @@ class GFlagUtils {
|
|
|
|
|
"===== Basic Functionality =====\n"
|
|
|
|
|
"[General]\n"
|
|
|
|
|
" --h/help Show this help message\n"
|
|
|
|
|
" --mode Run mode. 0(default): generate offline model; 1: convert model to JSON format "
|
|
|
|
|
" --mode Run mode. 0(default): generate offline model; 1: convert model to JSON format; "
|
|
|
|
|
"6: display model info"
|
|
|
|
|
"3: only pre-check; 5: convert ge dump txt file to JSON format\n"
|
|
|
|
|
"\n[Input]\n"
|
|
|
|
|
" --model Model file\n"
|
|
|
|
@ -313,7 +316,8 @@ class GFlagUtils {
|
|
|
|
|
" --op_compiler_cache_dir Set the save path of operator compilation cache files.\n"
|
|
|
|
|
"Default value: $HOME/atc_data\n"
|
|
|
|
|
" --op_compiler_cache_mode Set the operator compilation cache mode."
|
|
|
|
|
"Options are disable(default), enable and force(force to refresh the cache)");
|
|
|
|
|
"Options are disable(default), enable and force(force to refresh the cache)\n"
|
|
|
|
|
" --display_model_info enable for display model info; 0(default): close display, 1: open display");
|
|
|
|
|
|
|
|
|
|
gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true);
|
|
|
|
|
// Using gflags to analyze input parameters
|
|
|
|
@ -862,7 +866,7 @@ domi::Status GenerateInfershapeJson() {
|
|
|
|
|
static Status ConvertModelToJson(int fwk_type, const string &model_file, const string &json_file) {
|
|
|
|
|
Status ret = ge::SUCCESS;
|
|
|
|
|
if (fwk_type == -1) {
|
|
|
|
|
ret = ge::ConvertOmModelToJson(model_file.c_str(), json_file.c_str());
|
|
|
|
|
ret = ge::ConvertOm(model_file.c_str(), json_file.c_str(), true);
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1176,6 +1180,8 @@ domi::Status GenerateOmModel() {
|
|
|
|
|
options.insert(std::pair<string, string>(string(ge::MDL_BANK_PATH_FLAG), FLAGS_mdl_bank_path));
|
|
|
|
|
|
|
|
|
|
options.insert(std::pair<string, string>(string(ge::OP_BANK_PATH_FLAG), FLAGS_op_bank_path));
|
|
|
|
|
|
|
|
|
|
options.insert(std::pair<string, string>(string(ge::DISPLAY_MODEL_INFO), FLAGS_display_model_info));
|
|
|
|
|
// set enable scope fusion passes
|
|
|
|
|
SetEnableScopeFusionPasses(FLAGS_enable_scope_fusion_passes);
|
|
|
|
|
// print atc option map
|
|
|
|
@ -1188,6 +1194,11 @@ domi::Status GenerateOmModel() {
|
|
|
|
|
return domi::FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (FLAGS_display_model_info == "1") {
|
|
|
|
|
GELOGI("need to display model info.");
|
|
|
|
|
return ge::ConvertOm(FLAGS_output.c_str(), "", false);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return domi::SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1201,6 +1212,26 @@ domi::Status ConvertModelToJson() {
|
|
|
|
|
return domi::SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
domi::Status DisplayModelInfo() {
|
|
|
|
|
// No model path passed in
|
|
|
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_om == "",
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"om"});
|
|
|
|
|
return ge::FAILED,
|
|
|
|
|
"Input parameter[--om]'s value is empty!!");
|
|
|
|
|
|
|
|
|
|
// Check if the model path is valid
|
|
|
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
|
|
|
|
|
FLAGS_om != "" && !ge::CheckInputPathValid(FLAGS_om, "--om"),
|
|
|
|
|
return ge::FAILED,
|
|
|
|
|
"model file path is invalid: %s.", FLAGS_om.c_str());
|
|
|
|
|
|
|
|
|
|
if (FLAGS_framework == -1) {
|
|
|
|
|
return ge::ConvertOm(FLAGS_om.c_str(), "", false);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return ge::FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool CheckRet(domi::Status ret) {
|
|
|
|
|
if (ret != domi::SUCCESS) {
|
|
|
|
|
if (FLAGS_mode == ONLY_PRE_CHECK) {
|
|
|
|
@ -1344,6 +1375,9 @@ int main(int argc, char* argv[]) {
|
|
|
|
|
} else if (FLAGS_mode == ge::RunMode::PBTXT_TO_JSON) {
|
|
|
|
|
GE_CHK_BOOL_EXEC(ConvertPbtxtToJson() == domi::SUCCESS, ret = domi::FAILED;
|
|
|
|
|
break, "ATC convert pbtxt to json execute failed!!");
|
|
|
|
|
} else if (FLAGS_mode == ge::RunMode::DISPLAY_OM_INFO) {
|
|
|
|
|
GE_CHK_BOOL_EXEC(DisplayModelInfo() == domi::SUCCESS, ret = domi::FAILED;
|
|
|
|
|
break, "ATC DisplayModelInfo failed!!");
|
|
|
|
|
} else {
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage(
|
|
|
|
|
"E10001", {"parameter", "value", "reason"}, {"--mode", std::to_string(FLAGS_mode), kModeSupport});
|
|
|
|
|