|
|
|
@ -29,6 +29,7 @@ from mindspore.common.api import _executor
|
|
|
|
|
from mindspore.common import dtype as mstype
|
|
|
|
|
from mindspore._checkparam import check_input_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export"]
|
|
|
|
|
|
|
|
|
|
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
|
|
|
|
@ -40,6 +41,8 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
|
|
|
|
|
"Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64,
|
|
|
|
|
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}
|
|
|
|
|
|
|
|
|
|
ModelType = ["normal", "fusion", "quant"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _special_process_par(par, new_par):
|
|
|
|
|
"""
|
|
|
|
@ -101,20 +104,22 @@ def _update_param(param, new_param):
|
|
|
|
|
param.set_parameter_data(type(param.data)(new_param.data))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_checkpoint(parameter_list, ckpoint_file_name):
|
|
|
|
|
def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
|
|
|
|
|
"""
|
|
|
|
|
Saves checkpoint info to a specified file.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
parameter_list (list): Parameters list, each element is a dict
|
|
|
|
|
like {"name":xx, "type":xx, "shape":xx, "data":xx}.
|
|
|
|
|
ckpoint_file_name (str): Checkpoint file name.
|
|
|
|
|
ckpt_file_name (str): Checkpoint file name.
|
|
|
|
|
model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
RuntimeError: Failed to save the Checkpoint file.
|
|
|
|
|
"""
|
|
|
|
|
logger.info("Execute save checkpoint process.")
|
|
|
|
|
checkpoint_list = Checkpoint()
|
|
|
|
|
checkpoint_list.model_type = model_type
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
for param in parameter_list:
|
|
|
|
@ -133,22 +138,23 @@ def save_checkpoint(parameter_list, ckpoint_file_name):
|
|
|
|
|
for dim in param['data'].shape:
|
|
|
|
|
param_tensor.dims.append(dim)
|
|
|
|
|
|
|
|
|
|
with open(ckpoint_file_name, "wb") as f:
|
|
|
|
|
with open(ckpt_file_name, "wb") as f:
|
|
|
|
|
f.write(checkpoint_list.SerializeToString())
|
|
|
|
|
os.chmod(ckpoint_file_name, stat.S_IRUSR)
|
|
|
|
|
os.chmod(ckpt_file_name, stat.S_IRUSR)
|
|
|
|
|
|
|
|
|
|
except BaseException as e:
|
|
|
|
|
logger.error("Failed to save the checkpoint file %s.", ckpoint_file_name)
|
|
|
|
|
logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
|
|
|
|
|
raise RuntimeError(e.__str__())
|
|
|
|
|
logger.info("Save checkpoint process finish.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_checkpoint(ckpoint_file_name, net=None):
|
|
|
|
|
def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
|
|
|
|
|
"""
|
|
|
|
|
Loads checkpoint info from a specified file.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
ckpoint_file_name (str): Checkpoint file name.
|
|
|
|
|
ckpt_file_name (str): Checkpoint file name.
|
|
|
|
|
model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
|
|
|
|
|
net (Cell): Cell network. Default: None.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
@ -157,28 +163,33 @@ def load_checkpoint(ckpoint_file_name, net=None):
|
|
|
|
|
Raises:
|
|
|
|
|
ValueError: Checkpoint file is incorrect.
|
|
|
|
|
"""
|
|
|
|
|
if not isinstance(ckpoint_file_name, str):
|
|
|
|
|
raise ValueError("The ckpoint_file_name must be String.")
|
|
|
|
|
if not isinstance(ckpt_file_name, str):
|
|
|
|
|
raise ValueError("The ckpt_file_name must be string.")
|
|
|
|
|
|
|
|
|
|
if not os.path.exists(ckpoint_file_name) or ckpoint_file_name[-5:] != ".ckpt":
|
|
|
|
|
if model_type not in ModelType:
|
|
|
|
|
raise ValueError(f"The model_type is not in {ModelType}.")
|
|
|
|
|
|
|
|
|
|
if not os.path.exists(ckpt_file_name) or ckpt_file_name[-5:] != ".ckpt":
|
|
|
|
|
raise ValueError("Please input the correct checkpoint file name.")
|
|
|
|
|
|
|
|
|
|
if os.path.getsize(ckpoint_file_name) == 0:
|
|
|
|
|
if os.path.getsize(ckpt_file_name) == 0:
|
|
|
|
|
raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")
|
|
|
|
|
|
|
|
|
|
logger.info("Execute load checkpoint process.")
|
|
|
|
|
checkpoint_list = Checkpoint()
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
with open(ckpoint_file_name, "rb") as f:
|
|
|
|
|
with open(ckpt_file_name, "rb") as f:
|
|
|
|
|
pb_content = f.read()
|
|
|
|
|
checkpoint_list.ParseFromString(pb_content)
|
|
|
|
|
except BaseException as e:
|
|
|
|
|
logger.error("Failed to read the checkpoint file %s, please check the correct of the file.", ckpoint_file_name)
|
|
|
|
|
logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
|
|
|
|
|
raise ValueError(e.__str__())
|
|
|
|
|
|
|
|
|
|
parameter_dict = {}
|
|
|
|
|
|
|
|
|
|
if model_type != checkpoint_list.model_type:
|
|
|
|
|
raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format(
|
|
|
|
|
checkpoint_list.model_type, model_type))
|
|
|
|
|
try:
|
|
|
|
|
for element in checkpoint_list.value:
|
|
|
|
|
data = element.tensor.tensor_content
|
|
|
|
@ -206,7 +217,7 @@ def load_checkpoint(ckpoint_file_name, net=None):
|
|
|
|
|
logger.info("Load checkpoint process finish.")
|
|
|
|
|
|
|
|
|
|
except BaseException as e:
|
|
|
|
|
logger.error("Failed to load the checkpoint file %s.", ckpoint_file_name)
|
|
|
|
|
logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
|
|
|
|
|
raise RuntimeError(e.__str__())
|
|
|
|
|
|
|
|
|
|
if net:
|
|
|
|
@ -303,14 +314,15 @@ def _save_graph(network, file_name):
|
|
|
|
|
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True):
|
|
|
|
|
def _exec_save_checkpoint(train_network, ckpt_file_name, model_type="normal", integrated_save=True):
|
|
|
|
|
"""
|
|
|
|
|
Saves checkpoint for 'ms' backend.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
train_network (Network): The train network for training.
|
|
|
|
|
ckpoint_file_name (str): The name of checkpoint file.
|
|
|
|
|
integrated_save (bool): Whether to intergrated save in automatic model parallel scene.
|
|
|
|
|
ckpt_file_name (str): The name of checkpoint file.
|
|
|
|
|
model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
|
|
|
|
|
integrated_save (bool): Whether to perform an integrated save in the automatic model parallel scene. Default: True.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
param_dict = {}
|
|
|
|
@ -334,7 +346,7 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True
|
|
|
|
|
each_param["data"] = param_data
|
|
|
|
|
param_list.append(each_param)
|
|
|
|
|
|
|
|
|
|
save_checkpoint(param_list, ckpoint_file_name)
|
|
|
|
|
save_checkpoint(param_list, ckpt_file_name, model_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_merged_param_data(net, param_name, param_data):
|
|
|
|
|