|
|
|
@ -15,6 +15,7 @@
|
|
|
|
|
"""Model and parameters serialization."""
|
|
|
|
|
import os
|
|
|
|
|
import stat
|
|
|
|
|
from threading import Thread, Lock
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
import mindspore.nn as nn
|
|
|
|
@ -40,6 +41,7 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
|
|
|
|
|
"Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64,
|
|
|
|
|
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}
|
|
|
|
|
|
|
|
|
|
_ckpt_mutex = Lock()
|
|
|
|
|
|
|
|
|
|
def _special_process_par(par, new_par):
|
|
|
|
|
"""
|
|
|
|
@ -101,7 +103,29 @@ def _update_param(param, new_param):
|
|
|
|
|
param.set_parameter_data(type(param.data)(new_param.data))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_checkpoint(parameter_list, ckpt_file_name):
|
|
|
|
|
def _exec_save(ckpt_file_name, data_list):
    """Execute the save-checkpoint-to-file process.

    Args:
        ckpt_file_name (str): Path of the checkpoint file to write.
        data_list (dict): Mapping of parameter name to a 3-item list
            [dims (list of int), tensor_type (str), flat numpy array],
            as prepared by `save_checkpoint`.

    Raises:
        RuntimeError: If building or writing the checkpoint file fails.
    """
    checkpoint_list = Checkpoint()

    try:
        # Hold the mutex for the whole build+write so an async save thread
        # and a synchronous save cannot interleave on the same file.
        with _ckpt_mutex:
            for name, value in data_list.items():
                param_value = checkpoint_list.value.add()
                param_value.tag = name
                param_tensor = param_value.tensor
                param_tensor.dims.extend(value[0])
                param_tensor.tensor_type = value[1]
                # tobytes() replaces the deprecated ndarray.tostring()
                # (removed in NumPy 2.0); the produced bytes are identical.
                param_tensor.tensor_content = value[2].tobytes()

            with open(ckpt_file_name, "wb") as f:
                f.write(checkpoint_list.SerializeToString())
            # Restrict the checkpoint file to owner read-only.
            os.chmod(ckpt_file_name, stat.S_IRUSR)

    except BaseException as e:
        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
        # Chain the original exception so the root cause stays visible.
        raise RuntimeError(str(e)) from e
|
|
|
|
|
|
|
|
|
|
def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
    """
    Saves checkpoint info to a specified file.

    Args:
        parameter_list (list): Parameters list, each element is a dict
                               like {"name":xx, "type":xx, "shape":xx, "data":xx}.
        ckpt_file_name (str): Checkpoint file name.
        async_save (bool): Whether to execute the file save asynchronously
                           on a background thread. Default: False.

    Raises:
        RuntimeError: Failed to save the Checkpoint file.
    """
    logger.info("Execute save checkpoint process.")

    try:
        # Snapshot every parameter into plain python/numpy data under the
        # mutex, so an async writer thread never races with parameter updates.
        data_list = {}
        with _ckpt_mutex:
            for param in parameter_list:
                key = param["name"]
                data_list[key] = []
                if isinstance(param["data"], Parameter):
                    # Materialize lazily-initialized parameter data first.
                    param["data"].init_data()

                # Entry 0: shape dims (a scalar is recorded as dims == [0]).
                dims = []
                if param['data'].shape == ():
                    dims.append(0)
                else:
                    for dim in param['data'].shape:
                        dims.append(dim)
                data_list[key].append(dims)

                # Entry 1: dtype name as a string.
                tensor_type = str(param["data"].dtype)
                data_list[key].append(tensor_type)

                # Entry 2: flattened numpy copy of the tensor payload.
                data = param["data"].asnumpy().reshape(-1)
                data_list[key].append(data)

        if async_save:
            # Fire-and-forget writer thread; _exec_save re-acquires the
            # mutex so concurrent saves serialize on the file.
            thr = Thread(target=_exec_save, args=(ckpt_file_name, data_list))
            thr.start()
        else:
            _exec_save(ckpt_file_name, data_list)
    except BaseException as e:
        logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
        raise RuntimeError(str(e)) from e

    logger.info("Save checkpoint process finish.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -305,7 +329,7 @@ def _save_graph(network, file_name):
|
|
|
|
|
os.chmod(file_name, stat.S_IRUSR)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True):
|
|
|
|
|
def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True, async_save=False):
|
|
|
|
|
"""
|
|
|
|
|
Saves checkpoint for 'ms' backend.
|
|
|
|
|
|
|
|
|
@ -313,6 +337,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True):
|
|
|
|
|
train_network (Network): The train network for training.
|
|
|
|
|
ckpt_file_name (str): The name of checkpoint file.
|
|
|
|
|
integrated_save (bool): Whether to integrated save in automatic model parallel scene.
|
|
|
|
|
async_save (bool): Whether to execute the checkpoint file save asynchronously. Default: False.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
param_dict = {}
|
|
|
|
@ -336,7 +361,7 @@ def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True):
|
|
|
|
|
each_param["data"] = param_data
|
|
|
|
|
param_list.append(each_param)
|
|
|
|
|
|
|
|
|
|
save_checkpoint(param_list, ckpt_file_name)
|
|
|
|
|
save_checkpoint(param_list, ckpt_file_name, async_save)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_merged_param_data(net, param_name, param_data):
|
|
|
|
|