|
|
|
@ -565,6 +565,44 @@ bool GeGenerator::Impl::SetOmSystemInfo(AttrHolder &obj) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GeGenerator::SetModelNameForDump(const GeRootModelPtr &ge_root_model) {
|
|
|
|
|
bool is_unknown_shape = false;
|
|
|
|
|
Status ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape);
|
|
|
|
|
if (ret != SUCCESS) {
|
|
|
|
|
GELOGE(FAILED, "[Check][IsUnknownShape]Check root model is unknown shape failed, model id:%u",
|
|
|
|
|
ge_root_model->GetModelId());
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "Check root model is unknown shape failed, model id:%zu",
|
|
|
|
|
ge_root_model->GetModelId());
|
|
|
|
|
return FAILED;
|
|
|
|
|
}
|
|
|
|
|
GeModelPtr model_root = nullptr;
|
|
|
|
|
if (is_unknown_shape) {
|
|
|
|
|
model_root = MakeShared<GeModel>();
|
|
|
|
|
GE_CHECK_NOTNULL(model_root);
|
|
|
|
|
model_root->SetGraph(GraphUtils::CreateGraphFromComputeGraph(ge_root_model->GetRootGraph()));
|
|
|
|
|
ge_root_model->SetSubgraphInstanceNameToModel(ge_root_model->GetRootGraph()->GetName(), model_root);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ModelHelper model_helper;
|
|
|
|
|
string model_name;
|
|
|
|
|
GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
|
|
|
|
|
Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(),
|
|
|
|
|
model_name);
|
|
|
|
|
if (name_ret != SUCCESS) {
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"});
|
|
|
|
|
GELOGE(FAILED, "[Check][GetModelNameStep]Get model_name failed. Param --output is invalid, root graph name: %s",
|
|
|
|
|
ge_root_model->GetRootGraph()->GetName().c_str());
|
|
|
|
|
REPORT_CALL_ERROR("E19999", "Get model_name failed. Param --output is invalid,",
|
|
|
|
|
"root graph name: %s", ge_root_model->GetRootGraph()->GetName().c_str());
|
|
|
|
|
return PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
|
|
|
|
|
GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
|
|
|
|
|
GE_CHECK_NOTNULL(ge_model);
|
|
|
|
|
ge_model->SetName(model_name);
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs,
|
|
|
|
|
ModelBufferData &model, bool is_offline) {
|
|
|
|
|
rtContext_t ctx = nullptr;
|
|
|
|
@ -599,20 +637,10 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
GE_CHECK_NOTNULL(ge_root_model);
|
|
|
|
|
GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
|
|
|
|
|
ModelHelper model_helper;
|
|
|
|
|
string model_name = "";
|
|
|
|
|
Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(),
|
|
|
|
|
model_name);
|
|
|
|
|
if (name_ret != SUCCESS) {
|
|
|
|
|
ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"});
|
|
|
|
|
GELOGE(FAILED, "Get model_name failed. Param --output is invalid.");
|
|
|
|
|
return PARAM_INVALID;
|
|
|
|
|
ret = SetModelNameForDump(ge_root_model);
|
|
|
|
|
if (ret != SUCCESS) {
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
|
|
|
|
|
GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
|
|
|
|
|
GE_RETURN_WITH_LOG_IF_FALSE(ge_model != nullptr, "ge_model cannot be null");
|
|
|
|
|
ge_model->SetName(model_name);
|
|
|
|
|
ret = impl_->SaveRootModel(file_name_prefix, ge_root_model, model);
|
|
|
|
|
if (ret != SUCCESS) {
|
|
|
|
|
GELOGE(ret, "Save model failed");
|
|
|
|
@ -882,13 +910,12 @@ Status GeGenerator::Impl::SaveRootModel(const string &file_name_prefix, GeRootMo
|
|
|
|
|
"ge root model has no sub model")
|
|
|
|
|
GeModelPtr model_root = nullptr;
|
|
|
|
|
if (is_unknown_shape) {
|
|
|
|
|
model_root = make_shared<GeModel>();
|
|
|
|
|
model_root->SetGraph(GraphUtils::CreateGraphFromComputeGraph(ge_root_model->GetRootGraph()));
|
|
|
|
|
ge_root_model->SetSubgraphInstanceNameToModel(ge_root_model->GetRootGraph()->GetName(), model_root);
|
|
|
|
|
model_root->SetName(ge_root_model->GetRootGraph()->GetName());
|
|
|
|
|
auto name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
|
|
|
|
|
model_root = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
|
|
|
|
|
} else {
|
|
|
|
|
model_root = ge_root_model->GetSubgraphInstanceNameToModel().begin()->second;
|
|
|
|
|
}
|
|
|
|
|
GE_CHECK_NOTNULL(model_root);
|
|
|
|
|
// set atc version
|
|
|
|
|
if (!SetAtcVersionInfo(*(model_root.get()))) {
|
|
|
|
|
GELOGW("SetPackageVersionInfo of atc failed!");
|
|
|
|
|