|
|
|
@ -194,6 +194,7 @@ DEFINE_int32(op_debug_level, 0, "Optional; configure debug level of compiler. 0(
|
|
|
|
|
"1: open TBE compiler, export ccec file and TBE instruction mapping file; 2: open ccec compiler");
|
|
|
|
|
DEFINE_string(enable_scope_fusion_passes, "", "Optional; validate the non-general scope fusion pass,"
|
|
|
|
|
"multiple names can be set and separated by ','.");
|
|
|
|
|
DEFINE_string(display_model_info, "0", "Optional; display model info");
|
|
|
|
|
|
|
|
|
|
class GFlagUtils {
|
|
|
|
|
public:
|
|
|
|
@ -215,7 +216,7 @@ class GFlagUtils {
|
|
|
|
|
"[General]\n"
|
|
|
|
|
" --h/help Show this help message\n"
|
|
|
|
|
" --mode Run mode. 0(default): generate offline model; 1: convert model to JSON format "
|
|
|
|
|
"3: only pre-check; 5: convert pbtxt file to JSON format\n"
|
|
|
|
|
"3: only pre-check; 5: convert pbtxt file to JSON format; 6: display model info\n"
|
|
|
|
|
"\n[Input]\n"
|
|
|
|
|
" --model Model file\n"
|
|
|
|
|
" --weight Weight file. Required when framework is Caffe\n"
|
|
|
|
@ -296,7 +297,8 @@ class GFlagUtils {
|
|
|
|
|
" --save_original_model Control whether to output original model. E.g.: true: output original model\n"
|
|
|
|
|
" --log Generate log with level. Support debug, info, warning, error, null\n"
|
|
|
|
|
" --dump_mode The switch of dump json with shape, to be used with mode 1."
|
|
|
|
|
"0(default): disable; 1: enable.");
|
|
|
|
|
"0(default): disable; 1: enable.\n"
|
|
|
|
|
" --display_model_info enable for display model info; 0(default): close display, 1: open display");
|
|
|
|
|
|
|
|
|
|
gflags::ParseCommandLineNonHelpFlags(&argc, &argv, true);
|
|
|
|
|
// Using gflags to analyze input parameters
|
|
|
|
@ -1133,6 +1135,8 @@ domi::Status GenerateOmModel() {
|
|
|
|
|
|
|
|
|
|
options.insert(std::pair<string, string>(string(ge::ENABLE_SINGLE_STREAM), FLAGS_enable_single_stream));
|
|
|
|
|
|
|
|
|
|
options.insert(std::pair<string, string>(string(ge::DISPLAY_MODEL_INFO), FLAGS_display_model_info));
|
|
|
|
|
|
|
|
|
|
SetDynamicInputSizeOptions();
|
|
|
|
|
|
|
|
|
|
if (!FLAGS_save_original_model.empty()) {
|
|
|
|
@ -1152,10 +1156,34 @@ domi::Status GenerateOmModel() {
|
|
|
|
|
if (ret != domi::SUCCESS) {
|
|
|
|
|
return domi::FAILED;
|
|
|
|
|
}
|
|
|
|
|
if (FLAGS_display_model_info == "1") {
|
|
|
|
|
GELOGI("need to display model info.");
|
|
|
|
|
return ge::ConvertOm(FLAGS_output.c_str(), "", false);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return domi::SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
domi::Status DisplayModelInfo() {
|
|
|
|
|
// No model path passed in
|
|
|
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_om == "",
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"om"});
|
|
|
|
|
return ge::FAILED,
|
|
|
|
|
"Input parameter[--om]'s value is empty!!");
|
|
|
|
|
|
|
|
|
|
// Check if the model path is valid
|
|
|
|
|
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
|
|
|
|
|
FLAGS_om != "" && !ge::CheckInputPathValid(FLAGS_om, "--om"),
|
|
|
|
|
return ge::FAILED,
|
|
|
|
|
"model file path is invalid: %s.", FLAGS_om.c_str());
|
|
|
|
|
|
|
|
|
|
if (FLAGS_framework == -1) {
|
|
|
|
|
return ge::ConvertOm(FLAGS_om.c_str(), "", false);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return ge::FAILED;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
domi::Status ConvertModelToJson() {
|
|
|
|
|
Status ret = GFlagUtils::CheckConverJsonParamFlags();
|
|
|
|
|
GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "Check convert json params flags failed!");
|
|
|
|
@ -1176,6 +1204,8 @@ bool CheckRet(domi::Status ret) {
|
|
|
|
|
GELOGW("ATC convert model to json file failed.");
|
|
|
|
|
} else if (FLAGS_mode == PBTXT_TO_JSON) {
|
|
|
|
|
GELOGW("ATC convert pbtxt to json file failed.");
|
|
|
|
|
} else if (FLAGS_mode == ge::RunMode::DISPLAY_OM_INFO) {
|
|
|
|
|
GELOGW("ATC display om info failed.");
|
|
|
|
|
} else {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
@ -1190,6 +1220,8 @@ bool CheckRet(domi::Status ret) {
|
|
|
|
|
GELOGI("ATC convert model to json file success.");
|
|
|
|
|
} else if (FLAGS_mode == PBTXT_TO_JSON) {
|
|
|
|
|
GELOGI("ATC convert pbtxt to json file success.");
|
|
|
|
|
} else if (FLAGS_mode == ge::RunMode::DISPLAY_OM_INFO) {
|
|
|
|
|
GELOGW("ATC display om info success.");
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
@ -1309,6 +1341,9 @@ int main(int argc, char* argv[]) {
|
|
|
|
|
} else if (FLAGS_mode == ge::RunMode::PBTXT_TO_JSON) {
|
|
|
|
|
GE_CHK_BOOL_EXEC(ConvertPbtxtToJson() == domi::SUCCESS, ret = domi::FAILED;
|
|
|
|
|
break, "ATC convert pbtxt to json execute failed!!");
|
|
|
|
|
} else if (FLAGS_mode == ge::RunMode::DISPLAY_OM_INFO) {
|
|
|
|
|
GE_CHK_BOOL_EXEC(DisplayModelInfo() == domi::SUCCESS, ret = domi::FAILED;
|
|
|
|
|
break, "ATC DisplayModelInfo failed!!");
|
|
|
|
|
} else {
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage(
|
|
|
|
|
"E10001", {"parameter", "value", "reason"}, {"--mode", std::to_string(FLAGS_mode), kModeSupport});
|
|
|
|
|