@ -29,6 +29,7 @@ from mindspore.common.api import _executor
from mindspore.common import dtype as mstype
from mindspore._checkparam import check_input_data
# Public names exported by `from <module> import *`.
__all__ = [
    "save_checkpoint",
    "load_checkpoint",
    "load_param_into_net",
    "export",
]
tensor_to_ms_type = {"Int8": mstype.int8, "Uint8": mstype.uint8, "Int16": mstype.int16, "Uint16": mstype.uint16,
@ -40,6 +41,8 @@ tensor_to_np_type = {"Int8": np.int8, "Uint8": np.uint8, "Int16": np.int16, "Uin
"Int32": np.int32, "Uint32": np.uint32, "Int64": np.int64, "Uint64": np.uint64,
"Float16": np.float16, "Float32": np.float32, "Float64": np.float64, "Bool": np.bool_}
# Model type names accepted for the `model_type` argument; `load_checkpoint`
# raises ValueError for any value that is not in this list.
ModelType = [
    "normal",
    "fusion",
    "quant",
]
def _special_process_par(par, new_par):
@ -101,20 +104,22 @@ def _update_param(param, new_param):
def save_checkpoint(parameter_list, ckpoint_file_name):
def save_checkpoint(parameter_list, ckpt_file_name, model_type="normal"):
Saves checkpoint info to a specified file.
parameter_list (list): Parameters list, each element is a dict
like {"name":xx, "type":xx, "shape":xx, "data":xx}.
ckpoint_file_name (str): Checkpoint file name.
ckpt_file_name (str): Checkpoint file name.
model_type (str): The name of model type. Default: "normal".
RuntimeError: Failed to save the Checkpoint file.
logger.info("Execute save checkpoint process.")
checkpoint_list = Checkpoint()
checkpoint_list.model_type = model_type
for param in parameter_list:
@ -133,22 +138,23 @@ def save_checkpoint(parameter_list, ckpoint_file_name):
for dim in param['data'].shape:
with open(ckpoint_file_name, "wb") as f:
with open(ckpt_file_name, "wb") as f:
os.chmod(ckpoint_file_name, stat.S_IRUSR)
os.chmod(ckpt_file_name, stat.S_IRUSR)
except BaseException as e:
logger.error("Failed to save the checkpoint file %s.", ckpoint_file_name)
logger.error("Failed to save the checkpoint file %s.", ckpt_file_name)
raise RuntimeError(e.__str__())
logger.info("Save checkpoint process finish.")
def load_checkpoint(ckpoint_file_name, net=None):
def load_checkpoint(ckpt_file_name, model_type="normal", net=None):
Loads checkpoint info from a specified file.
ckpoint_file_name (str): Checkpoint file name.
ckpt_file_name (str): Checkpoint file name.
model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
net (Cell): Cell network. Default: None
@ -157,28 +163,33 @@ def load_checkpoint(ckpoint_file_name, net=None):
ValueError: Checkpoint file is incorrect.
if not isinstance(ckpoint_file_name, str):
raise ValueError("The ckpoint_file_name must be String.")
if not isinstance(ckpt_file_name, str):
raise ValueError("The ckpt_file_name must be string.")
if not os.path.exists(ckpoint_file_name) or ckpoint_file_name[-5:] != ".ckpt":
if model_type not in ModelType:
raise ValueError(f"The model_type is not in {ModelType}.")
if not os.path.exists(ckpt_file_name) or ckpt_file_name[-5:] != ".ckpt":
raise ValueError("Please input the correct checkpoint file name.")
if os.path.getsize(ckpoint_file_name) == 0:
if os.path.getsize(ckpt_file_name) == 0:
raise ValueError("The checkpoint file may be empty, please make sure enter the correct file name.")
logger.info("Execute load checkpoint process.")
checkpoint_list = Checkpoint()
with open(ckpoint_file_name, "rb") as f:
with open(ckpt_file_name, "rb") as f:
pb_content = f.read()
except BaseException as e:
logger.error("Failed to read the checkpoint file %s, please check the correct of the file.", ckpoint_file_name)
logger.error("Failed to read the checkpoint file `%s`, please check the correct of the file.", ckpt_file_name)
raise ValueError(e.__str__())
parameter_dict = {}
if model_type != checkpoint_list.model_type:
raise KeyError("Checkpoint file model type({}) is not equal to input model type({}).".format(
checkpoint_list.model_type, model_type))
for element in checkpoint_list.value:
data = element.tensor.tensor_content
@ -206,7 +217,7 @@ def load_checkpoint(ckpoint_file_name, net=None):
logger.info("Load checkpoint process finish.")
except BaseException as e:
logger.error("Failed to load the checkpoint file %s.", ckpoint_file_name)
logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
raise RuntimeError(e.__str__())
if net:
@ -303,14 +314,15 @@ def _save_graph(network, file_name):
os.chmod(file_name, stat.S_IWUSR | stat.S_IRUSR)
def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True):
def _exec_save_checkpoint(train_network, ckpt_file_name, model_type="normal", integrated_save=True):
Saves checkpoint for 'ms' backend.
train_network (Network): The train network for training.
ckpoint_file_name (str): The name of checkpoint file.
integrated_save (bool): Whether to intergrated save in automatic model parallel scene.
ckpt_file_name (str): The name of checkpoint file.
model_type (str): The name of model type in `normal`, `fusion` or `quant`. Default: "normal".
integrated_save (bool): Whether to perform an integrated save in the automatic model parallel scenario.
param_dict = {}
@ -334,7 +346,7 @@ def _exec_save_checkpoint(train_network, ckpoint_file_name, integrated_save=True
each_param["data"] = param_data
save_checkpoint(param_list, ckpoint_file_name)
save_checkpoint(param_list, ckpt_file_name, model_type)
def _get_merged_param_data(net, param_name, param_data):