test atc dynamic shape

pull/512/head
yskhhh 4 years ago
parent c8f1bc7e7b
commit 335e9e892c

@ -24,6 +24,7 @@ set(SRC_LIST
"helper/om_file_helper.cc" "helper/om_file_helper.cc"
"helper/model_helper.cc" "helper/model_helper.cc"
"../model/ge_model.cc" "../model/ge_model.cc"
"../model/ge_root_model.cc"
"auth/file_saver.cc" "auth/file_saver.cc"
"fp16_t.cc" "fp16_t.cc"
"math/fp16_math.cc" "math/fp16_math.cc"

@ -258,6 +258,65 @@ FileSaver::SaveToFile(const string &file_path, ModelFileHeader &file_header, Mod
return SUCCESS; return SUCCESS;
} }
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status
FileSaver::SaveToFile(const string &file_path, ModelFileHeader &file_header,
vector<ModelPartitionTable *> &model_partition_tables,
const vector<vector<ModelPartition>> &all_partition_datas) {
file_header.is_encrypt = ModelEncryptType::UNENCRYPTED;
const Status ret = SaveWithFileHeader(file_path, file_header, model_partition_tables, all_partition_datas);
GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, FAILED, "save file failed, file_path:%s, file header len:%u.",
file_path.c_str(), file_header.length);
return SUCCESS;
}
Status FileSaver::SaveWithFileHeader(const std::string &file_path, const ModelFileHeader &file_header,
vector<ModelPartitionTable *> &model_partition_tables,
const vector<vector<ModelPartition>> &all_partition_datas) {
GE_CHK_BOOL_EXEC(model_partition_tables.size() == all_partition_datas.size(),
return PARAM_INVALID,
"model table size %zu does not match partition size %zu",
model_partition_tables.size(), all_partition_datas.size())
for (size_t index = 0; index < model_partition_tables.size(); ++index) {
auto &cur_partiton_data = all_partition_datas[index];
auto &cur_model_partition_table = *model_partition_tables[index];
GE_CHK_BOOL_RET_STATUS(!cur_partiton_data.empty() && cur_model_partition_table.num != 0
&& cur_model_partition_table.num == cur_partiton_data.size(), FAILED,
"Invalid param:partition data size is (%u), model_partition_table.num is (%zu).",
cur_model_partition_table.num, cur_partiton_data.size());
}
// Open file
int32_t fd = 0;
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(OpenFile(fd, file_path) != SUCCESS, return FAILED);
Status ret = SUCCESS;
do {
// Write file header
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
WriteData(static_cast<const void *>(&file_header), sizeof(ModelFileHeader), fd) != SUCCESS, ret = FAILED;
break);
for (size_t index = 0; index < model_partition_tables.size(); ++index) {
// Write model partition table
auto &cur_tabel = *model_partition_tables[index];
uint32_t table_size = static_cast<uint32_t>(SIZE_OF_MODEL_PARTITION_TABLE(cur_tabel));
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
WriteData(static_cast<const void *>(&cur_tabel), table_size, fd) != SUCCESS, ret = FAILED; break);
// Write partition data
auto &cur_partition_datas = all_partition_datas[index];
for (const auto &partition_data : cur_partition_datas) {
GELOGI("GC:size[%zu]", partition_data.size);
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(
WriteData(static_cast<const void *>(partition_data.data), partition_data.size, fd) != SUCCESS, ret = FAILED;
break);
}
}
} while (0);
// Close file
GE_CHK_BOOL_RET_STATUS(mmClose(fd) == EN_OK, FAILED, "Close file failed.");
return ret;
}
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status FileSaver::SaveToFile(const string &file_path, const void *data, FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status FileSaver::SaveToFile(const string &file_path, const void *data,
int len) { int len) {
if (data == nullptr || len <= 0) { if (data == nullptr || len <= 0) {

@ -74,6 +74,10 @@ class FileSaver {
ModelPartitionTable &model_partition_table, ModelPartitionTable &model_partition_table,
const std::vector<ModelPartition> &partition_datas); const std::vector<ModelPartition> &partition_datas);
static Status SaveToFile(const string &file_path, ModelFileHeader &file_header,
vector<ModelPartitionTable *> &model_partition_tables,
const vector<vector<ModelPartition>> &all_partition_datas);
static Status SaveToBuffWithFileHeader(const ModelFileHeader &file_header, static Status SaveToBuffWithFileHeader(const ModelFileHeader &file_header,
ModelPartitionTable &model_partition_table, ModelPartitionTable &model_partition_table,
const std::vector<ModelPartition> &partitionDatas, const std::vector<ModelPartition> &partitionDatas,
@ -108,6 +112,9 @@ class FileSaver {
static Status SaveWithFileHeader(const std::string &file_path, const ModelFileHeader &file_header, static Status SaveWithFileHeader(const std::string &file_path, const ModelFileHeader &file_header,
ModelPartitionTable &model_partition_table, ModelPartitionTable &model_partition_table,
const std::vector<ModelPartition> &partition_datas); const std::vector<ModelPartition> &partition_datas);
static Status SaveWithFileHeader(const std::string &file_path, const ModelFileHeader &file_header,
vector<ModelPartitionTable *> &model_partition_tables,
const vector<vector<ModelPartition>> &all_partition_datas);
}; };
} // namespace ge } // namespace ge
#endif // GE_COMMON_AUTH_FILE_SAVER_H_ #endif // GE_COMMON_AUTH_FILE_SAVER_H_

@ -7,6 +7,7 @@ GE_COMMON_LOCAL_SRC_FILES := \
helper/om_file_helper.cc \ helper/om_file_helper.cc \
helper/model_helper.cc \ helper/model_helper.cc \
../model/ge_model.cc \ ../model/ge_model.cc \
../model/ge_root_model.cc \
auth/file_saver.cc \ auth/file_saver.cc \
fp16_t.cc \ fp16_t.cc \
math/fp16_math.cc \ math/fp16_math.cc \

File diff suppressed because it is too large Load Diff

@ -52,6 +52,17 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(u
return SUCCESS; return SUCCESS;
} }
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(uint8_t *model_data,
uint32_t model_data_size,
uint32_t model_num) {
Status status = LoadModelPartitionTable(model_data, model_data_size, model_num);
if (status != SUCCESS) {
return status;
}
is_inited_ = true;
return SUCCESS;
}
// Use both // Use both
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetModelPartition(ModelPartitionType type, FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetModelPartition(ModelPartitionType type,
ModelPartition &partition) { ModelPartition &partition) {
@ -79,6 +90,37 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetMod
return SUCCESS; return SUCCESS;
} }
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetModelPartition(ModelPartitionType type,
ModelPartition &partition,
size_t model_index) {
if (!is_inited_) {
GELOGE(PARAM_INVALID, "OmFileLoadHelper has not been initialized!");
return PARAM_INVALID;
}
if (model_index >= model_contexts_.size()) {
GELOGE(PARAM_INVALID, "cur index : %zu, model_contexts size:%zu", model_index, model_contexts_.size());
return PARAM_INVALID;
}
auto &cur_ctx = model_contexts_[model_index];
bool found = false;
for (ModelPartition &part : cur_ctx.partition_datas_) {
if (part.type == type) {
partition = part;
found = true;
break;
}
}
if (!found) {
if (type != ModelPartitionType::TBE_KERNELS && type != ModelPartitionType::WEIGHTS_DATA &&
type != ModelPartitionType::CUST_AICPU_KERNELS) {
GELOGE(FAILED, "GetModelPartition:type:%d is not in partition_datas!", static_cast<int>(type));
return FAILED;
}
}
return SUCCESS;
}
Status OmFileLoadHelper::CheckModelValid(const ge::ModelData &model) const { Status OmFileLoadHelper::CheckModelValid(const ge::ModelData &model) const {
// Parameter validity check // Parameter validity check
if (model.model_data == nullptr) { if (model.model_data == nullptr) {
@ -148,6 +190,61 @@ Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, const uint
return SUCCESS; return SUCCESS;
} }
Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, uint32_t model_data_size, uint32_t model_num) {
if (model_data == nullptr) {
GELOGE(PARAM_INVALID, "Param model_data must not be null!");
return PARAM_INVALID;
}
uint32_t cur_offset = 0;
for (uint32_t index = 0; index < model_num; ++index) {
// Init partition table
auto partition_table = reinterpret_cast<ModelPartitionTable *>(model_data + cur_offset);
size_t partition_table_size = SIZE_OF_MODEL_PARTITION_TABLE(*partition_table);
cur_offset += partition_table_size;
GELOGD("Cur model index %zu: ModelPartitionTable num :%u, "
"ModelFileHeader length :%zu, ModelPartitionTable length :%zu",
index, partition_table->num, sizeof(ModelFileHeader), partition_table_size);
if (model_data_size <= cur_offset) {
GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID, "invalid model data, partition_table->num:%u, model data size %u",
partition_table->num, model_data_size);
return GE_EXEC_MODEL_DATA_SIZE_INVALID;
}
for (uint32_t i = 0; i < partition_table->num; i++) {
ModelPartition partition;
partition.size = partition_table->partition[i].mem_size;
partition.data = model_data + cur_offset;
partition.type = partition_table->partition[i].type;
if (index >= model_contexts_.size()) {
if (index != model_contexts_.size()) {
GELOGE(FAILED, "cur index is %zu make model_contexts_ overflow", index);
return FAILED;
}
OmFileContext tmp_ctx;
tmp_ctx.partition_datas_.push_back(partition);
model_contexts_.push_back(tmp_ctx);
} else {
model_contexts_[index].partition_datas_.push_back(partition);
}
if (partition.size > model_data_size || cur_offset > model_data_size - partition.size) {
GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID, "The partition size %zu is greater than the model data size %u.",
partition.size + cur_offset, model_data_size);
return GE_EXEC_MODEL_DATA_SIZE_INVALID;
}
cur_offset += partition.size;
GELOGD("Partition, type:%d, size:%u, model_index:%zu", static_cast<int>(partition.type), partition.size, index);
}
}
if (cur_offset != model_data_size) {
GELOGE(FAILED, "do not get the complete model, read end offset:%zu, all size:%zu", cur_offset, model_data_size);
return FAILED;
}
return SUCCESS;
}
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::vector<ModelPartition> FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::vector<ModelPartition>
&OmFileSaveHelper::GetModelPartitions() const { &OmFileSaveHelper::GetModelPartitions() const {
return context_.partition_datas_; return context_.partition_datas_;
@ -172,6 +269,28 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelPartitionTable *OmFileSave
return partition_table; return partition_table;
} }
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelPartitionTable *OmFileSaveHelper::GetPartitionTable(
size_t cur_ctx_index) {
auto &cur_ctx = model_contexts_[cur_ctx_index];
auto partition_size = static_cast<uint32_t>(cur_ctx.partition_datas_.size());
// Build ModelPartitionTable, flex array
cur_ctx.partition_table_.clear();
cur_ctx.partition_table_.resize(sizeof(ModelPartitionTable) + sizeof(ModelPartitionMemInfo) * partition_size, 0);
auto partition_table = reinterpret_cast<ModelPartitionTable *>(cur_ctx.partition_table_.data());
partition_table->num = partition_size;
uint32_t mem_offset = 0;
for (uint32_t i = 0; i < partition_size; i++) {
ModelPartition partition = cur_ctx.partition_datas_[i];
partition_table->partition[i] = {partition.type, mem_offset, partition.size};
mem_offset += partition.size;
GELOGD("Partition, type:%d, size:%u", static_cast<int>(partition.type), partition.size);
}
return partition_table;
}
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileSaveHelper::AddPartition(ModelPartition &partition) { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileSaveHelper::AddPartition(ModelPartition &partition) {
if (ge::CheckUint32AddOverflow(context_.model_data_len_, partition.size) != SUCCESS) { if (ge::CheckUint32AddOverflow(context_.model_data_len_, partition.size) != SUCCESS) {
GELOGE(FAILED, "UINT32 %u and %u addition can result in overflow!", context_.model_data_len_, partition.size); GELOGE(FAILED, "UINT32 %u and %u addition can result in overflow!", context_.model_data_len_, partition.size);
@ -182,6 +301,27 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileSaveHelper::AddPar
return SUCCESS; return SUCCESS;
} }
Status OmFileSaveHelper::AddPartition(ModelPartition &partition, size_t cur_index) {
if (ge::CheckUint32AddOverflow(context_.model_data_len_, partition.size) != SUCCESS) {
GELOGE(FAILED, "UINT32 %u and %u addition can result in overflow!", context_.model_data_len_, partition.size);
return FAILED;
}
if (cur_index >= model_contexts_.size()) {
if (cur_index != model_contexts_.size()) {
GELOGE(FAILED, "cur index is %zu make model_contexts_ overflow", cur_index);
return FAILED;
}
OmFileContext tmp_ctx;
tmp_ctx.model_data_len_ += partition.size;
tmp_ctx.partition_datas_.push_back(partition);
model_contexts_.push_back(tmp_ctx);
} else {
model_contexts_[cur_index].model_data_len_ += partition.size;
model_contexts_[cur_index].partition_datas_.push_back(partition);
}
return SUCCESS;
}
Status OmFileSaveHelper::SaveModel(const SaveParam &save_param, const char *output_file, ModelBufferData &model, Status OmFileSaveHelper::SaveModel(const SaveParam &save_param, const char *output_file, ModelBufferData &model,
bool is_offline) { bool is_offline) {
(void)save_param.cert_file; (void)save_param.cert_file;
@ -198,6 +338,10 @@ Status OmFileSaveHelper::SaveModel(const SaveParam &save_param, const char *outp
Status OmFileSaveHelper::SaveModelToFile(const char *output_file, ModelBufferData &model, bool is_offline) { Status OmFileSaveHelper::SaveModelToFile(const char *output_file, ModelBufferData &model, bool is_offline) {
#if !defined(NONSUPPORT_SAVE_TO_FILE) #if !defined(NONSUPPORT_SAVE_TO_FILE)
if (context_.partition_datas_.empty()) {
GE_CHK_BOOL_EXEC(!model_contexts_.empty(), return FAILED, "mode contexts empty");
context_ = model_contexts_.front();
}
uint32_t model_data_len = context_.model_data_len_; uint32_t model_data_len = context_.model_data_len_;
if (model_data_len == 0) { if (model_data_len == 0) {
GELOGE(domi::PARAM_INVALID, "Model data len error! should not be 0"); GELOGE(domi::PARAM_INVALID, "Model data len error! should not be 0");
@ -231,4 +375,53 @@ Status OmFileSaveHelper::SaveModelToFile(const char *output_file, ModelBufferDat
return SUCCESS; return SUCCESS;
#endif #endif
} }
Status OmFileSaveHelper::SaveRootModel(const SaveParam &save_param, const char *output_file,
ModelBufferData &model, bool is_offline) {
(void)save_param.cert_file;
(void)save_param.ek_file;
(void)save_param.encode_mode;
(void)save_param.hw_key_file;
(void)save_param.pri_key_file;
#if !defined(NONSUPPORT_SAVE_TO_FILE)
vector<ModelPartitionTable *> model_partition_tabels;
vector<vector<ModelPartition>> all_model_partitions;
for (size_t ctx_index = 0; ctx_index < model_contexts_.size(); ++ctx_index) {
auto &cur_ctx = model_contexts_[ctx_index];
uint32_t cur_model_data_len = cur_ctx.model_data_len_;
if (cur_model_data_len == 0) {
GELOGE(domi::PARAM_INVALID, "Model data len error! should not be 0");
return domi::PARAM_INVALID;
}
auto tmp_table = GetPartitionTable(ctx_index);
if (tmp_table == nullptr) {
GELOGE(ge::GE_GRAPH_SAVE_FAILED, "SaveModelToFile execute failed: partition_table is NULL.");
return ge::GE_GRAPH_SAVE_FAILED;
}
uint32_t size_of_table = SIZE_OF_MODEL_PARTITION_TABLE(*tmp_table);
FMK_UINT32_ADDCHECK(size_of_table, cur_model_data_len)
FMK_UINT32_ADDCHECK(size_of_table + cur_model_data_len, model_header_.length)
model_header_.length += size_of_table + cur_model_data_len;
model_partition_tabels.push_back(tmp_table);
all_model_partitions.push_back(cur_ctx.partition_datas_);
GELOGD("sizeof(ModelPartitionTable):%u, cur_model_data_len:%u, cur_context_index:%zu",
size_of_table, cur_model_data_len, ctx_index);
}
Status ret;
if (is_offline) {
ret = FileSaver::SaveToFile(output_file, model_header_, model_partition_tabels, all_model_partitions);
} else {
GELOGW("do not support save ge root model to buff now");
return FAILED;
}
if (ret == SUCCESS) {
GELOGD("Save model success without encrypt.");
}
return ret;
#else
return SUCCESS;
#endif
}
} // namespace ge } // namespace ge

@ -801,7 +801,7 @@ const uint32_t XRGB_CHN_NUM = 4;
/// ///
const bool DEFAULT_GLOBAL_POOLING = false; const bool DEFAULT_GLOBAL_POOLING = false;
const uint32_t MODEL_VERSION = 0x10000000; ///< Model version 1.0/// const uint32_t MODEL_VERSION = 0x20000000; ///< Model version 2.0///
// Eltwise's input size // Eltwise's input size
const int ELTWISE_MIN_INPUT_SIZE = 2; const int ELTWISE_MIN_INPUT_SIZE = 2;

@ -240,6 +240,8 @@ class GeGenerator::Impl {
Status SaveModel(const string &file_name_prefix, GeModelPtr &models, ModelBufferData &model); Status SaveModel(const string &file_name_prefix, GeModelPtr &models, ModelBufferData &model);
Status SaveRootModel(const string &file_name_prefix, GeRootModelPtr &model, ModelBufferData &model_buff);
Status SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs, Status SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs,
const vector<GeTensor> &inputs, const vector<GeTensor> &outputs); const vector<GeTensor> &inputs, const vector<GeTensor> &outputs);
@ -505,19 +507,7 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr
GE_CHECK_NOTNULL(ge_root_model); GE_CHECK_NOTNULL(ge_root_model);
GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); GE_CHECK_NOTNULL(ge_root_model->GetRootGraph());
ModelHelper model_helper; ret = impl_->SaveRootModel(file_name_prefix, ge_root_model, model);
string model_name = "";
Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(), model_name);
if (name_ret != SUCCESS) {
ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"});
GELOGE(FAILED, "Get model_name failed. Param --output is invalid");
return PARAM_INVALID;
}
map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel();
GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()];
GE_RETURN_WITH_LOG_IF_FALSE(ge_model != nullptr, "ge_model can not be null");
ge_model->SetName(model_name);
ret = impl_->SaveModel(file_name_prefix, ge_model, model);
if (ret != SUCCESS) { if (ret != SUCCESS) {
GELOGE(ret, "Save model failed"); GELOGE(ret, "Save model failed");
if (impl_->graph_manager_.Finalize() != SUCCESS) { if (impl_->graph_manager_.Finalize() != SUCCESS) {
@ -712,6 +702,44 @@ Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, GeModelPtr &
return SUCCESS; return SUCCESS;
} }
Status GeGenerator::Impl::SaveRootModel(const string &file_name_prefix, GeRootModelPtr &ge_root_model,
ModelBufferData &model_buff) {
bool is_unknown_shape = false;
auto ret = ge_root_model->CheckIsUnknownShape(is_unknown_shape);
if (ret != SUCCESS) {
GELOGE(FAILED, "Check root model is unkonwn shape failed");
return FAILED;
}
GELOGD("begin save root model, cur model is unkonwn shape model ? : %d", is_unknown_shape);
GE_CHK_BOOL_EXEC(!ge_root_model->GetSubgraphInstanceNameToModel().empty(), return FAILED,
"ge root model has no sub model")
GeModelPtr model_root = nullptr;
if (is_unknown_shape) {
model_root = make_shared<GeModel>();
model_root->SetGraph(GraphUtils::CreateGraphFromComputeGraph(ge_root_model->GetRootGraph()));
ge_root_model->SetSubgraphInstanceNameToModel(ge_root_model->GetRootGraph()->GetName(), model_root);
model_root->SetName(ge_root_model->GetRootGraph()->GetName());
} else {
model_root = ge_root_model->GetSubgraphInstanceNameToModel().begin()->second;
}
// set atc version
if (!SetAtcVersionInfo(*(model_root.get()))) {
GELOGW("SetPackageVersionInfo of atc failed!");
}
// set opp version
if (!SetOppVersionInfo(*(model_root.get()))) {
GELOGW("SetPackageVersionInfo of ops failed!");
}
ModelHelper model_helper;
model_helper.SetSaveMode(is_offline_);
ret = model_helper.SaveToOmRootModel(ge_root_model, save_param_, file_name_prefix, model_buff, is_unknown_shape);
if (ret != SUCCESS) {
GELOGE(ret, "Save to om model failed");
return ret;
}
return SUCCESS;
}
Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector<GeTensor> &inputs, Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector<GeTensor> &inputs,
GeRootModelPtr &ge_root_model) { GeRootModelPtr &ge_root_model) {
static std::atomic<GraphId> atomic_graph_id(0); static std::atomic<GraphId> atomic_graph_id(0);

@ -23,6 +23,7 @@
namespace ge { namespace ge {
class GeRootModel { class GeRootModel {
public: public:
GeRootModel() = default;
explicit GeRootModel(ComputeGraphPtr &root_graph) : root_graph_(root_graph), model_id_(INVALID_MODEL_ID) {}; explicit GeRootModel(ComputeGraphPtr &root_graph) : root_graph_(root_graph), model_id_(INVALID_MODEL_ID) {};
~GeRootModel() = default; ~GeRootModel() = default;
@ -35,11 +36,11 @@ class GeRootModel {
void SetModelId(uint32_t model_id) { model_id_ = model_id; } void SetModelId(uint32_t model_id) { model_id_ = model_id; }
uint32_t GetModelId() const { return model_id_; } uint32_t GetModelId() const { return model_id_; }
Status CheckIsUnknownShape(bool &is_dynamic_shape); Status CheckIsUnknownShape(bool &is_dynamic_shape);
void SetRootGraph(ComputeGraphPtr graph) { root_graph_ = graph; }
private: private:
ComputeGraphPtr root_graph_; ComputeGraphPtr root_graph_ = nullptr;
std::map<std::string, GeModelPtr> subgraph_instance_name_to_model_; std::map<std::string, GeModelPtr> subgraph_instance_name_to_model_;
uint32_t model_id_; uint32_t model_id_ = 0;
}; };
} // namespace ge } // namespace ge
using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>; using GeRootModelPtr = std::shared_ptr<ge::GeRootModel>;

@ -25,6 +25,7 @@
#include "common/types.h" #include "common/types.h"
#include "graph/model.h" #include "graph/model.h"
#include "model/ge_model.h" #include "model/ge_model.h"
#include "model/ge_root_model.h"
namespace ge { namespace ge {
class ModelHelper { class ModelHelper {
@ -32,17 +33,22 @@ class ModelHelper {
ModelHelper() = default; ModelHelper() = default;
~ModelHelper(); ~ModelHelper();
Status SaveToOmModel(const GeModelPtr &ge_model, const SaveParam &save_param, Status SaveToOmModel(const GeModelPtr &ge_model, const SaveParam &save_param, const std::string &output_file,
const std::string &output_file, ge::ModelBufferData &model); ge::ModelBufferData &model);
Status SaveToOmRootModel(const GeRootModelPtr &ge_root_model, const SaveParam &save_param, const string &output_file,
ModelBufferData &model, bool is_unknown_shape);
Status SaveOriginalGraphToOmModel(const ge::Graph &graph, const std::string &output_file); Status SaveOriginalGraphToOmModel(const ge::Graph &graph, const std::string &output_file);
Status LoadModel(const ge::ModelData &model_data); Status LoadModel(const ge::ModelData &model_data);
Status LoadRootModel(const ge::ModelData &model_data);
Status GetModelBufferData(ge::ModelBufferData &model); Status GetModelBufferData(ge::ModelBufferData &model);
const ModelFileHeader *GetFileHeader() const { return file_header_; } const ModelFileHeader *GetFileHeader() const { return file_header_; }
GeModelPtr GetGeModel(); GeModelPtr GetGeModel();
GeRootModelPtr GetGeRootModel();
void SetSaveMode(bool val) { is_offline_ = val; } void SetSaveMode(bool val) { is_offline_ = val; }
bool GetSaveMode(void) const { return is_offline_; } bool GetSaveMode(void) const { return is_offline_; }
bool GetModelType() const { return is_unknown_shape_model_; };
Status GetBaseNameFromFileName(const std::string &file_name, std::string &base_name); Status GetBaseNameFromFileName(const std::string &file_name, std::string &base_name);
Status GetModelNameFromMergedGraphName(const std::string &graph_name, std::string &model_name); Status GetModelNameFromMergedGraphName(const std::string &graph_name, std::string &model_name);
@ -50,24 +56,46 @@ class ModelHelper {
private: private:
bool is_assign_model_ = false; bool is_assign_model_ = false;
bool is_offline_ = true; bool is_offline_ = true;
bool is_unknown_shape_model_ = false;
ModelFileHeader *file_header_ = nullptr; ModelFileHeader *file_header_ = nullptr;
// Encrypted model need delete temp model and unencrypted model need not delete model // Encrypted model need delete temp model and unencrypted model need not delete model
uint8_t *model_addr_tmp_ = nullptr; uint8_t *model_addr_tmp_ = nullptr;
uint32_t model_len_tmp_ = 0; uint32_t model_len_tmp_ = 0;
GeModelPtr model_; GeModelPtr model_;
GeRootModelPtr root_model_;
ModelHelper(const ModelHelper &); ModelHelper(const ModelHelper &);
ModelHelper &operator=(const ModelHelper &); ModelHelper &operator=(const ModelHelper &);
Status GenerateGeModel(OmFileLoadHelper &om_load_helper); Status GenerateGeModel(OmFileLoadHelper &om_load_helper);
Status GenerateGeRootModel(OmFileLoadHelper &om_load_helper);
Status LoadModelData(OmFileLoadHelper &om_load_helper); Status LoadModelData(OmFileLoadHelper &om_load_helper);
void SetModelToGeModel(ge::Model &model); void SetModelToGeModel(ge::Model &model);
Status LoadModelData(OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, size_t mode_index);
Status LoadWeights(OmFileLoadHelper &om_load_helper); Status LoadWeights(OmFileLoadHelper &om_load_helper);
Status LoadWeights(OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, size_t mode_index);
Status LoadTask(OmFileLoadHelper &om_load_helper); Status LoadTask(OmFileLoadHelper &om_load_helper);
Status LoadTask(OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, size_t mode_index);
Status LoadTBEKernelStore(OmFileLoadHelper &om_load_helper); Status LoadTBEKernelStore(OmFileLoadHelper &om_load_helper);
Status LoadTBEKernelStore(OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, size_t mode_index);
Status LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper); Status LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper);
Status LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper, GeModelPtr &cur_model, size_t mode_index);
Status ReleaseLocalModelData() noexcept; Status ReleaseLocalModelData() noexcept;
Status SaveModelPartition(std::shared_ptr<OmFileSaveHelper>& om_file_save_helper, Status SaveModelPartition(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper, ModelPartitionType type,
ModelPartitionType type, const uint8_t* data, size_t size); const uint8_t *data, size_t size, size_t model_index);
Status SaveModelDef(shared_ptr<OmFileSaveHelper> &om_file_save_helper, const GeModelPtr &ge_model,
Buffer &model_buffer, size_t model_index = 0);
Status SaveModelWeights(shared_ptr<OmFileSaveHelper> &om_file_save_helper, const GeModelPtr &ge_model,
size_t model_index = 0);
Status SaveModelTbeKernel(shared_ptr<OmFileSaveHelper> &om_file_save_helper, const GeModelPtr &ge_model,
size_t model_index = 0);
Status SaveModelCustAICPU(shared_ptr<OmFileSaveHelper> &om_file_save_helper, const GeModelPtr &ge_model,
size_t model_index = 0);
Status SaveModelTaskDef(shared_ptr<OmFileSaveHelper> &om_file_save_helper, const GeModelPtr &ge_model,
Buffer &task_buffer, size_t model_index = 0);
Status SaveModelHeader(shared_ptr<OmFileSaveHelper> &om_file_save_helper, const GeModelPtr &ge_model,
size_t model_num = 1);
Status SaveAllModelPartiton(shared_ptr<OmFileSaveHelper> &om_file_save_helper, const GeModelPtr &ge_model,
Buffer &model_buffer, Buffer &task_buffer, size_t model_index = 0);
}; };
} // namespace ge } // namespace ge
#endif // INC_FRAMEWORK_COMMON_HELPER_MODEL_HELPER_H_ #endif // INC_FRAMEWORK_COMMON_HELPER_MODEL_HELPER_H_

@ -39,7 +39,7 @@ struct ModelPartition {
struct OmFileContext { struct OmFileContext {
std::vector<ModelPartition> partition_datas_; std::vector<ModelPartition> partition_datas_;
std::vector<char> partition_table_; std::vector<char> partition_table_;
uint32_t model_data_len_; uint32_t model_data_len_ = 0;
}; };
struct SaveParam { struct SaveParam {
@ -57,15 +57,23 @@ class OmFileLoadHelper {
Status Init(uint8_t *model_data, const uint32_t model_data_size); Status Init(uint8_t *model_data, const uint32_t model_data_size);
Status Init(uint8_t *model_data, const uint32_t model_data_size, uint32_t model_num);
Status GetModelPartition(ModelPartitionType type, ModelPartition &partition); Status GetModelPartition(ModelPartitionType type, ModelPartition &partition);
Status GetModelPartition(ModelPartitionType type, ModelPartition &partition, size_t model_index);
OmFileContext context_; OmFileContext context_;
vector<OmFileContext> model_contexts_;
private: private:
Status CheckModelValid(const ge::ModelData &model) const; Status CheckModelValid(const ge::ModelData &model) const;
Status LoadModelPartitionTable(uint8_t *model_data, const uint32_t model_data_size); Status LoadModelPartitionTable(uint8_t *model_data, const uint32_t model_data_size);
Status LoadModelPartitionTable(uint8_t *model_data, const uint32_t model_data_size, uint32_t model_num);
bool is_inited_{false}; bool is_inited_{false};
}; };
@ -79,15 +87,23 @@ class OmFileSaveHelper {
Status AddPartition(ModelPartition &partition); Status AddPartition(ModelPartition &partition);
Status AddPartition(ModelPartition &partition, size_t cur_index);
const std::vector<ModelPartition> &GetModelPartitions() const; const std::vector<ModelPartition> &GetModelPartitions() const;
Status SaveModel(const SaveParam &save_param, const char *target_file, Status SaveModel(const SaveParam &save_param, const char *target_file, ge::ModelBufferData &model,
ge::ModelBufferData& model, bool is_offline = true); bool is_offline = true);
Status SaveModelToFile(const char *output_file, ge::ModelBufferData &model, bool is_offline = true); Status SaveModelToFile(const char *output_file, ge::ModelBufferData &model, bool is_offline = true);
vector<OmFileContext> model_contexts_;
ModelFileHeader model_header_; ModelFileHeader model_header_;
OmFileContext context_; OmFileContext context_;
ModelPartitionTable *GetPartitionTable(size_t cur_ctx_index);
Status SaveRootModel(const SaveParam &save_param, const char *output_file, ModelBufferData &model, bool is_offline);
}; };
} // namespace ge } // namespace ge
#endif // INC_FRAMEWORK_COMMON_HELPER_OM_FILE_HELPER_H_ #endif // INC_FRAMEWORK_COMMON_HELPER_OM_FILE_HELPER_H_

@ -605,7 +605,7 @@ static constexpr uint32_t MODEL_FILE_CHECKSUM_LENGTH = 64;
/// ///
/// @brief length of the reserved field in the model file header /// @brief length of the reserved field in the model file header
/// ///
static constexpr uint32_t MODEL_FILE_RESERVED_LENGTH = 79; static constexpr uint32_t MODEL_FILE_RESERVED_LENGTH = 75;
/// ///
/// @ingroup domi_omg /// @ingroup domi_omg
@ -843,9 +843,10 @@ struct ModelFileHeader {
uint32_t ops = 0; // Computing power (Kops) uint32_t ops = 0; // Computing power (Kops)
uint8_t userdefineinfo[USER_DEFINE_INFO_LENGTH] = {0}; // User-defined information. The value contains 32 characters uint8_t userdefineinfo[USER_DEFINE_INFO_LENGTH] = {0}; // User-defined information. The value contains 32 characters
uint32_t om_ir_version = 0; uint32_t om_ir_version = 0;
uint32_t model_num = 0;
uint8_t platform_version[PLATFORM_VERSION_LEN] = {0}; uint8_t platform_version[PLATFORM_VERSION_LEN] = {0};
uint8_t platform_type = {0}; uint8_t platform_type = {0};
uint8_t reserved[MODEL_FILE_RESERVED_LENGTH] = {0}; // Reserved field 79 uint8_t reserved[MODEL_FILE_RESERVED_LENGTH] = {0}; // Reserved field 75
}; };
static constexpr uint8_t TARGET_TYPE_LTTE_8BIT = 0; static constexpr uint8_t TARGET_TYPE_LTTE_8BIT = 0;

Loading…
Cancel
Save