Support IR build for online model

pull/1392/MERGE^2
zhengyuanhua 4 years ago
parent da104ed39d
commit c6eeb0a745

@ -190,6 +190,71 @@ Status FileSaver::SaveToBuffWithFileHeader(const ModelFileHeader &file_header,
return SUCCESS;
}
Status FileSaver::SaveToBuffWithFileHeader(const ModelFileHeader &file_header,
                                           vector<ModelPartitionTable *> &model_partition_tables,
                                           const std::vector<std::vector<ModelPartition>> &all_partition_datas,
                                           ge::ModelBufferData &model) {
  // Serialize |file_header|, every partition table and all partition payloads
  // into one contiguous heap buffer owned by |model.data|.
  // Layout: [ModelFileHeader][table 0][table 0 payloads...][table 1][table 1 payloads...]...
  // Returns SUCCESS on success; PARAM_INVALID/FAILED otherwise.
  GE_CHK_BOOL_RET_STATUS(model_partition_tables.size() == all_partition_datas.size(), PARAM_INVALID,
                         "Model table size %zu does not match partition size %zu.",
                         model_partition_tables.size(), all_partition_datas.size());
  for (size_t index = 0; index < model_partition_tables.size(); ++index) {
    // Fix: the original dereferenced model_partition_tables[index] without a null check.
    GE_CHK_BOOL_RET_STATUS(model_partition_tables[index] != nullptr, PARAM_INVALID,
                           "Model partition table %zu is null.", index);
    auto &cur_partition_data = all_partition_datas[index];
    auto &cur_model_partition_table = *model_partition_tables[index];
    GE_CHK_BOOL_RET_STATUS(!cur_partition_data.empty() && cur_model_partition_table.num != 0
                           && cur_model_partition_table.num == cur_partition_data.size(), FAILED,
                           "Invalid param: partition data size is (%zu), model_partition_table.num is (%u).",
                           cur_partition_data.size(), cur_model_partition_table.num);
  }
  // Compute the total buffer size, guarding EVERY addition against uint32 overflow.
  // Fix: the original only checked the partition-payload additions, not the
  // partition-table size additions, so a huge table count could silently wrap.
  uint32_t model_header_size = sizeof(ModelFileHeader);
  uint32_t total_size = model_header_size;
  for (size_t index = 0; index < model_partition_tables.size(); ++index) {
    auto &cur_model_partition_table = *model_partition_tables[index];
    uint32_t table_size = static_cast<uint32_t>(SIZE_OF_MODEL_PARTITION_TABLE(cur_model_partition_table));
    GE_CHK_BOOL_RET_STATUS(ge::CheckUint32AddOverflow(total_size, table_size) == SUCCESS, FAILED,
                           "Add uint32 overflow!");
    total_size += table_size;
    auto &cur_partition_data = all_partition_datas[index];
    for (const auto &partition_data : cur_partition_data) {
      auto ret = ge::CheckUint32AddOverflow(total_size, partition_data.size);
      GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, FAILED, "Add uint32 overflow!");
      total_size += partition_data.size;
    }
  }
  // Allocate the output buffer; ownership is transferred to |model.data| so the
  // memory is freed when the last reference is dropped (including on the early
  // FAILED returns below).
  auto buff = reinterpret_cast<uint8_t *>(malloc(total_size));
  GE_CHK_BOOL_RET_STATUS(buff != nullptr, FAILED, "Malloc failed!");
  GE_PRINT_DYNAMIC_MEMORY(malloc, "File buffer.", total_size)
  model.data.reset(buff, [](uint8_t *ptr) {
    GELOGD("Free online model memory.");
    free(ptr);  // no point nulling a by-value lambda parameter afterwards
  });
  model.length = total_size;
  // Copy the header, then each table followed by its payloads. |left_space|
  // tracks the remaining capacity so every memcpy_s stays bounds-checked.
  uint32_t left_space = total_size;
  auto ret_header = memcpy_s(buff, left_space, reinterpret_cast<void *>(const_cast<ModelFileHeader *>(&file_header)),
                             model_header_size);
  GE_CHK_BOOL_RET_STATUS(ret_header == 0, FAILED, "Memcpy_s failed!");
  buff += model_header_size;
  left_space -= model_header_size;
  for (size_t index = 0; index < model_partition_tables.size(); ++index) {
    auto &cur_table = *model_partition_tables[index];
    uint32_t table_size = static_cast<uint32_t>(SIZE_OF_MODEL_PARTITION_TABLE(cur_table));
    auto ret_table = memcpy_s(buff, left_space, reinterpret_cast<void *>(&cur_table), table_size);
    GE_CHK_BOOL_RET_STATUS(ret_table == 0, FAILED, "Memcpy_s failed!");
    buff += table_size;
    left_space -= table_size;
    auto &cur_partition_data = all_partition_datas[index];
    for (const auto &partition_data : cur_partition_data) {
      auto ret_part = memcpy_s(buff, left_space,
                               reinterpret_cast<void *>(const_cast<uint8_t *>(partition_data.data)),
                               partition_data.size);
      GE_CHK_BOOL_RET_STATUS(ret_part == 0, FAILED, "Memcpy_s failed!");
      buff += partition_data.size;
      left_space -= partition_data.size;
    }
  }
  return SUCCESS;
}
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status FileSaver::CheckPath(const std::string &file_path) {
// Determine file path length
if (file_path.size() >= MMPA_MAX_PATH) {

@ -83,6 +83,11 @@ class FileSaver {
const std::vector<ModelPartition> &partitionDatas,
ge::ModelBufferData& model);
static Status SaveToBuffWithFileHeader(const ModelFileHeader &file_header,
std::vector<ModelPartitionTable *> &model_partition_tables,
const std::vector<std::vector<ModelPartition>> &all_partition_datas,
ge::ModelBufferData &model);
static Status SaveToFile(const string &file_path, const void *data, int len);
protected:
@ -113,8 +118,8 @@ class FileSaver {
ModelPartitionTable &model_partition_table,
const std::vector<ModelPartition> &partition_datas);
static Status SaveWithFileHeader(const std::string &file_path, const ModelFileHeader &file_header,
vector<ModelPartitionTable *> &model_partition_tables,
const vector<vector<ModelPartition>> &all_partition_datas);
std::vector<ModelPartitionTable *> &model_partition_tables,
const std::vector<std::vector<ModelPartition>> &all_partition_datas);
};
} // namespace ge
#endif // GE_COMMON_AUTH_FILE_SAVER_H_

@ -416,8 +416,7 @@ Status OmFileSaveHelper::SaveRootModel(const SaveParam &save_param, const char *
if (is_offline) {
ret = FileSaver::SaveToFile(output_file, model_header_, model_partition_tabels, all_model_partitions);
} else {
GELOGW("do not support save ge root model to buff now");
return FAILED;
ret = FileSaver::SaveToBuffWithFileHeader(model_header_, model_partition_tabels, all_model_partitions, model);
}
if (ret == SUCCESS) {
GELOGD("Save model success without encrypt.");

@ -17,6 +17,7 @@
#include "hybrid_model_executor.h"
#include "graph/ge_context.h"
#include "graph/runtime_inference_context.h"
#include "graph/utils/tensor_utils.h"
#include "common/dump/dump_manager.h"
namespace ge {
@ -50,6 +51,11 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) {
auto root_graph_item = model_->GetRootGraphItem();
GE_CHECK_NOTNULL(root_graph_item);
if (root_graph_item->IsDynamic()) {
GE_CHK_STATUS_RET(CheckInputShapeByShapeRange(root_graph_item, args),
"[%s] check input node shape by shape range failed.",
root_graph_item->GetName().c_str());
}
GE_CHK_RT_RET(rtMemcpyAsync(context_.global_step, sizeof(uint64_t), &context_.iteration,
sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE_EX, context_.stream));
SubgraphExecutor executor(model_->GetRootGraphItem(), &context_);
@ -138,5 +144,55 @@ Status HybridModelExecutor::ResetExecutionContext(GraphExecutionContext &context
GE_CHK_GRAPH_STATUS_RET(RuntimeInferenceContext::CreateContext(ctx_id), "Failed to Destroy RuntimeInferenceContext");
return SUCCESS;
}
Status HybridModelExecutor::CheckInputShapeByShapeRange(const GraphItem *graph_item,
                                                        HybridModelExecutor::ExecuteArgs &args) {
  // Validate that each user-supplied input shape is static and falls inside the
  // shape range declared on the corresponding graph input node. Inputs without a
  // declared range (or not needed by the graph) are skipped.
  GE_CHECK_NOTNULL(graph_item);
  auto input_nodes = graph_item->GetInputNodes();
  if (args.input_desc.size() < input_nodes.size()) {
    REPORT_INNER_ERROR("E19999", "[%s] Number of inputs [%zu] is not sufficient for graph which needs [%zu] inputs.",
                       graph_item->GetName().c_str(), args.input_desc.size(), input_nodes.size());
    GELOGE(INTERNAL_ERROR, "[%s] Number of inputs [%zu] is not sufficient for graph which needs [%zu] inputs.",
           graph_item->GetName().c_str(), args.input_desc.size(), input_nodes.size());
    return INTERNAL_ERROR;
  }
  for (size_t idx = 0; idx < input_nodes.size(); ++idx) {
    const auto &node = input_nodes[idx];
    if (node == nullptr) {
      GELOGD("[%s] Input[%zu] is not needed by graph, skip it.", graph_item->GetName().c_str(), idx);
      continue;
    }
    GeTensorDescPtr graph_input_desc = node->MutableInputDesc(idx);
    GE_CHECK_NOTNULL(graph_input_desc);
    std::vector<std::pair<int64_t, int64_t>> range;
    if (graph_input_desc->GetShapeRange(range) != SUCCESS) {
      REPORT_INNER_ERROR("E19999", "[%s] Input[%zu] get shape range failed", graph_item->GetName().c_str(), idx);
      GELOGE(INTERNAL_ERROR, "[%s] Input[%zu] get shape range failed", graph_item->GetName().c_str(), idx);
      return INTERNAL_ERROR;
    }
    if (range.empty()) {
      // No declared range means there is nothing to enforce for this input.
      GELOGD("[%s] Input[%zu] shape is not needed to check by shape range, skip it.", graph_item->GetName().c_str(),
             idx);
      continue;
    }
    ConstGeTensorDescPtr user_desc = args.input_desc[idx];
    GE_CHECK_NOTNULL(user_desc);
    GeShape user_shape = user_desc->GetShape();
    if (user_shape.IsUnknownShape()) {
      // A range check is only meaningful against a fully static user shape.
      REPORT_INNER_ERROR("E19999", "[%s] Input desc shape [%zu] designed by user must be static.",
                         graph_item->GetName().c_str(), idx);
      GELOGE(INTERNAL_ERROR, "[%s] Input desc shape [%zu] designed by user must be static.",
             graph_item->GetName().c_str(), idx);
      return INTERNAL_ERROR;
    }
    if (TensorUtils::CheckShapeByShapeRange(user_shape, range) != SUCCESS) {
      GELOGE(PARAM_INVALID, "[Check][InputShape] [%s] check input [%zu] shape failed by shape range.",
             graph_item->GetName().c_str(), idx);
      return PARAM_INVALID;
    }
  }
  return SUCCESS;
}
} // namespace hybrid
} // namespace ge

@ -52,6 +52,7 @@ class HybridModelExecutor {
Status Cleanup();
Status InitExecutionContext();
static Status ResetExecutionContext(GraphExecutionContext &context);
static Status CheckInputShapeByShapeRange(const GraphItem *graph_item, HybridModelExecutor::ExecuteArgs &args);
HybridModel *model_;
uint32_t device_id_;

@ -44,27 +44,6 @@ ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item(
}
}
Status ShapeInferenceState::CheckInputShapeByShapeRange(const GeTensorDesc &tensor_desc,
                                                        const GeTensorDesc &target_tensor_desc) const {
  // Verify that the shape of |target_tensor_desc| lies within the shape range
  // declared on |tensor_desc|. An empty range means there is nothing to check.
  std::vector<std::pair<int64_t, int64_t>> range;
  if (tensor_desc.GetShapeRange(range) != SUCCESS) {
    GELOGE(PARAM_INVALID, "Get shape range failed.");
    return PARAM_INVALID;
  }
  if (range.empty()) {
    GELOGD("Shape range is empty, no need to check input shape.");
    return SUCCESS;
  }
  GeShape checked_shape = target_tensor_desc.GetShape();
  if (TensorUtils::CheckShapeByShapeRange(checked_shape, range) != SUCCESS) {
    GELOGE(PARAM_INVALID, "Check shape by shape range failed.");
    return PARAM_INVALID;
  }
  return SUCCESS;
}
Status ShapeInferenceState::UpdateInputShape(int idx, const GeTensorDesc &target) {
if (node_item.IsInputShapeStatic(idx)) {
GELOGD("[%s] Trying to update static shape, idx = %d. old shape = [%s], new shape = [%s]",

@ -58,8 +58,6 @@ struct ShapeInferenceState {
const vector<GeTensorDesc> &GetOutputTensorDesc() const;
Status CheckInputShapeByShapeRange(const GeTensorDesc &tensor_desc, const GeTensorDesc &target_tensor_desc) const;
const NodeItem &node_item;
private:

@ -59,7 +59,7 @@ const char *const kKeepDtypeError = "file not found";
const char *const kInputShapeRangeInvalid = "format of shape range is invalid";
const char *const kShapeRangeValueConvertError = "transfer from string to int64 error";
const char *const kInputShapeRangeSample1 = "\"input_name1:[n1~n2,c1,h1,w1]\"";
const char *const kInputShapeRangeSample2 = "\"[]\"";
const char *const kInputShapeRangeSample2 = "\"[1~20]\"";
const char *const kInputShapeRangeSample3 = "\"[1~20,3,3~6,-1]\"";
vector<string> SplitInputShape(const std::string &input_shape) {
@ -302,8 +302,8 @@ bool ParseSingleShapeRange(std::string &shape_range, vector<pair<int64_t, int64_
}
}
bool is_square_brackets = (square_brackets[0] == '[') && (square_brackets[1] == ']') &&
(square_brackets.size() == kSquareBracketsSize);
bool is_square_brackets = (square_brackets.size() == kSquareBracketsSize) &&
(square_brackets[0] == '[') && (square_brackets[1] == ']');
if (!is_square_brackets) {
ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"shape_range", "reason", "sample"},
{shape_range, kInputShapeRangeInvalid, kInputShapeRangeSample2});

@ -503,8 +503,17 @@ graphStatus Impl::CreateInputsForIRBuild(const ge::Graph &graph, vector<ge::GeTe
string data_type_str = ge::TypeUtils::DataTypeToSerialString(data_type);
GELOGD("Data op get data type:%s from InputDesc in ge ir graph.", data_type_str.c_str());
std::vector<std::pair<int64_t, int64_t>> shape_range;
if (tensor.GetShapeRange(shape_range) != GRAPH_SUCCESS) {
GELOGE(FAILED, "[Creat][Input] Data op [%s] get shape range failed.", data_op_name.c_str());
return FAILED;
}
ge::GeTensor inputTensor;
ge::GeTensorDesc desc(data_shape, ge::Format(data_format), data_type);
if (desc.SetShapeRange(shape_range) != GRAPH_SUCCESS) {
GELOGE(FAILED, "[Creat][Input] Data op [%s] set shape range failed.", data_op_name.c_str());
return FAILED;
}
inputTensor.SetTensorDesc(desc);
inputs.push_back(inputTensor);
}

Loading…
Cancel
Save