|
|
|
|
@ -638,9 +638,12 @@ def _save_mindir(net, file_name, *inputs):
|
|
|
|
|
else:
|
|
|
|
|
logger.warning("Parameters in the net capacity exceeds 1G, save MindIR model and parameters separately.")
|
|
|
|
|
# save parameter
|
|
|
|
|
file_prefix = file_name.split("/")[-1]
|
|
|
|
|
if file_prefix.endswith(".mindir"):
|
|
|
|
|
file_prefix = file_prefix[:-7]
|
|
|
|
|
current_path = os.path.abspath(file_name)
|
|
|
|
|
dirname = os.path.dirname(current_path)
|
|
|
|
|
data_path = dirname + "/variables"
|
|
|
|
|
data_path = dirname + "/" + file_prefix + "_variables"
|
|
|
|
|
if os.path.exists(data_path):
|
|
|
|
|
shutil.rmtree(data_path)
|
|
|
|
|
os.makedirs(data_path, exist_ok=True)
|
|
|
|
|
@ -675,7 +678,7 @@ def _save_mindir(net, file_name, *inputs):
|
|
|
|
|
|
|
|
|
|
# save graph
|
|
|
|
|
del model.graph.parameter[:]
|
|
|
|
|
graph_file_name = file_name + "_graph.mindir"
|
|
|
|
|
graph_file_name = dirname + "/" + file_prefix + "_graph.mindir"
|
|
|
|
|
with open(graph_file_name, 'wb') as f:
|
|
|
|
|
os.chmod(graph_file_name, stat.S_IWUSR | stat.S_IRUSR)
|
|
|
|
|
f.write(model.SerializeToString())
|
|
|
|
|
@ -1147,7 +1150,6 @@ def _merge_and_split(sliced_params, train_strategy, predict_strategy):
|
|
|
|
|
|
|
|
|
|
def _load_single_param(ckpt_file_name, param_name):
|
|
|
|
|
"""Load a parameter from checkpoint."""
|
|
|
|
|
logger.info("Execute the process of loading checkpoint files.")
|
|
|
|
|
checkpoint_list = Checkpoint()
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
@ -1155,7 +1157,8 @@ def _load_single_param(ckpt_file_name, param_name):
|
|
|
|
|
pb_content = f.read()
|
|
|
|
|
checkpoint_list.ParseFromString(pb_content)
|
|
|
|
|
except BaseException as e:
|
|
|
|
|
logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
|
|
|
|
|
logger.error("Failed to read the checkpoint file `%s` during load single parameter,"
|
|
|
|
|
" please check the correct of the file.", ckpt_file_name)
|
|
|
|
|
raise ValueError(e.__str__())
|
|
|
|
|
|
|
|
|
|
parameter = None
|
|
|
|
|
@ -1189,8 +1192,7 @@ def _load_single_param(ckpt_file_name, param_name):
|
|
|
|
|
param_dim.append(dim)
|
|
|
|
|
param_value = param_data.reshape(param_dim)
|
|
|
|
|
parameter = Parameter(Tensor(param_value, ms_type), name=element.tag)
|
|
|
|
|
break
|
|
|
|
|
logger.info("Loading checkpoint files process is finished.")
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
except BaseException as e:
|
|
|
|
|
logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
|
|
|
|
|
|