!12884 modify export mindir model

From: @changzherui
Reviewed-by: @kingxian,@xu-yfei
Signed-off-by: @kingxian
pull/12884/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit fe7652e87f

@ -133,19 +133,18 @@ std::shared_ptr<FuncGraph> LoadMindIR(const std::string &file_name, bool is_lite
// Load parameter into graph
if (endsWith(abs_path_buff, "_graph.mindir")) {
char *mindir_name, delimiter = '/';
mindir_name = strrchr(abs_path_buff, delimiter);
int path_len = strlen(abs_path_buff) - strlen(mindir_name) + 1;
int path_len = strlen(abs_path_buff) - strlen("graph.mindir");
memcpy(abs_path, abs_path_buff, path_len);
abs_path[path_len] = '\0';
snprintf(abs_path, sizeof(abs_path), "variables");
snprintf(abs_path + path_len, sizeof(abs_path), "variables");
std::ifstream ifs(abs_path);
if (ifs.good()) {
MS_LOG(DEBUG) << "MindIR file has variables path, load parameter into graph.";
string path = abs_path;
get_all_files(path, &files);
} else {
MS_LOG(ERROR) << "MindIR graph has not variable path. ";
MS_LOG(ERROR) << "MindIR graph has not variable path, load failed";
return nullptr;
}
int file_size = files.size();

@ -638,9 +638,12 @@ def _save_mindir(net, file_name, *inputs):
else:
logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.")
# save parameter
file_prefix = file_name.split("/")[-1]
if file_prefix.endswith(".mindir"):
file_prefix = file_prefix[:-7]
current_path = os.path.abspath(file_name)
dirname = os.path.dirname(current_path)
data_path = dirname + "/variables"
data_path = dirname + "/" + file_prefix + "_variables"
if os.path.exists(data_path):
shutil.rmtree(data_path)
os.makedirs(data_path, exist_ok=True)
@ -675,7 +678,7 @@ def _save_mindir(net, file_name, *inputs):
# save graph
del model.graph.parameter[:]
graph_file_name = file_name + "_graph.mindir"
graph_file_name = dirname + "/" + file_prefix + "_graph.mindir"
with open(graph_file_name, 'wb') as f:
os.chmod(graph_file_name, stat.S_IWUSR | stat.S_IRUSR)
f.write(model.SerializeToString())
@ -1147,7 +1150,6 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy):
def _load_single_param(ckpt_file_name, param_name):
"""Load a parameter from checkpoint."""
logger.info("Execute the process of loading checkpoint files.")
checkpoint_list = Checkpoint()
try:
@ -1155,7 +1157,8 @@ def _load_single_param(ckpt_file_name, param_name):
pb_content = f.read()
checkpoint_list.ParseFromString(pb_content)
except BaseException as e:
logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
logger.error("Failed to read the checkpoint file `%s` during load single parameter,"
" please check the correct of the file.", ckpt_file_name)
raise ValueError(e.__str__())
parameter = None
@ -1189,8 +1192,7 @@ def _load_single_param(ckpt_file_name, param_name):
param_dim.append(dim)
param_value = param_data.reshape(param_dim)
parameter = Parameter(Tensor(param_value, ms_type), name=element.tag)
break
logger.info("Loading checkpoint files process is finished.")
break
except BaseException as e:
logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)

Loading…
Cancel
Save