|
|
|
@ -141,24 +141,52 @@ def _exec_save(ckpt_file_name, data_list):
|
|
|
|
|
raise RuntimeError(e.__str__())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
|
|
|
|
|
def save_checkpoint(save_obj, ckpt_file_name, integrated_save=True, async_save=False):
|
|
|
|
|
"""
|
|
|
|
|
Saves checkpoint info to a specified file.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
parameter_list (list): Parameters list, each element is a dictionary
|
|
|
|
|
like {"name":xx, "type":xx, "shape":xx, "data":xx}.
|
|
|
|
|
save_obj (nn.Cell or list): The train network for training or parameters list(each element is a dictionary,
|
|
|
|
|
like {"name":xx, "type":xx, "shape":xx, "data":xx}.)
|
|
|
|
|
ckpt_file_name (str): Checkpoint file name.
|
|
|
|
|
integrated_save (bool): Whether to merge the split parameters before saving in the automatic model parallel scenario. Default: True
|
|
|
|
|
async_save (bool): Whether asynchronous execution saves the checkpoint to a file. Default: False
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
TypeError: If the parameter save_obj is not nn.Cell or list type.
|
|
|
|
|
RuntimeError: Failed to save the Checkpoint file.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
if not isinstance(save_obj, nn.Cell) and not isinstance(save_obj, list):
|
|
|
|
|
raise TypeError("The parameter save_obj should be nn.Cell or list, but got {}".format(type(save_obj)))
|
|
|
|
|
|
|
|
|
|
logger.info("Execute save checkpoint process.")
|
|
|
|
|
|
|
|
|
|
if isinstance(save_obj, nn.Cell):
|
|
|
|
|
save_obj.init_parameters_data()
|
|
|
|
|
param_dict = {}
|
|
|
|
|
for _, param in save_obj.parameters_and_names():
|
|
|
|
|
param_dict[param.name] = param
|
|
|
|
|
param_list = []
|
|
|
|
|
for (key, value) in param_dict.items():
|
|
|
|
|
each_param = {"name": key}
|
|
|
|
|
if isinstance(value.data, Tensor):
|
|
|
|
|
param_data = value.data
|
|
|
|
|
else:
|
|
|
|
|
param_data = Tensor(value.data)
|
|
|
|
|
|
|
|
|
|
# in automatic model parallel scenario, some parameters were split across all the devices,
|
|
|
|
|
# which should be combined before saving
|
|
|
|
|
if integrated_save and key in save_obj.parameter_layout_dict:
|
|
|
|
|
param_data = _get_merged_param_data(save_obj, key, param_data)
|
|
|
|
|
|
|
|
|
|
each_param["data"] = param_data
|
|
|
|
|
param_list.append(each_param)
|
|
|
|
|
save_obj = param_list
|
|
|
|
|
|
|
|
|
|
data_list = {}
|
|
|
|
|
with _ckpt_mutex:
|
|
|
|
|
for param in parameter_list:
|
|
|
|
|
for param in save_obj:
|
|
|
|
|
key = param["name"]
|
|
|
|
|
data_list[key] = []
|
|
|
|
|
if isinstance(param["data"], Parameter):
|
|
|
|
@ -180,6 +208,7 @@ def save_checkpoint(parameter_list, ckpt_file_name, async_save=False):
|
|
|
|
|
thr.start()
|
|
|
|
|
else:
|
|
|
|
|
_exec_save(ckpt_file_name, data_list)
|
|
|
|
|
|
|
|
|
|
logger.info("Save checkpoint process finish.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -354,39 +383,6 @@ def _save_graph(network, file_name):
|
|
|
|
|
os.chmod(file_name, stat.S_IRUSR)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _exec_save_checkpoint(train_network, ckpt_file_name, integrated_save=True, async_save=False):
    """
    Saves checkpoint for 'ms' backend.

    Collects the parameters of `train_network` into a list of
    {"name": ..., "data": ...} dictionaries and passes that list to `save_checkpoint`.

    Args:
        train_network (Network): The train network for training.
        ckpt_file_name (str): The name of checkpoint file.
        integrated_save (bool): Whether to merge the split parameters before saving in the
            automatic model parallel scenario. Default: True.
        async_save (bool): Whether to asynchronously save the checkpoint into a file. Default: False.
    """
    train_network.init_parameters_data()

    # Keyed by parameter name: if a name occurs more than once, the last parameter wins.
    param_dict = {param.name: param for _, param in train_network.parameters_and_names()}

    param_list = []
    for key, value in param_dict.items():
        if isinstance(value.data, Tensor):
            param_data = value.data
        else:
            param_data = Tensor(value.data)

        # in automatic model parallel scenario, some parameters were split across all the devices,
        # which should be combined before saving
        if integrated_save and key in train_network.parameter_layout_dict:
            param_data = _get_merged_param_data(train_network, key, param_data)

        param_list.append({"name": key, "data": param_data})

    # Pass async_save by keyword: the third positional parameter of save_checkpoint may be
    # integrated_save, in which case a positional async_save would be bound to the wrong
    # argument and the requested asynchronous save would be ignored.
    save_checkpoint(param_list, ckpt_file_name, async_save=async_save)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_merged_param_data(net, param_name, param_data):
|
|
|
|
|
"""
|
|
|
|
|
Gets the merged data(tensor) from tensor slice, by device arrangement and tensor map.
|
|
|
|
|