!3907 modify ckpt func check parameter

Merge pull request !3907 from changzherui/mod_ckpt_func_param
pull/3907/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 3dcea81721

@ -108,13 +108,13 @@ class CheckpointConfig:
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes: not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
raise ValueError("The input_param can't be all None or 0") raise ValueError("The input_param can't be all None or 0")
if save_checkpoint_steps: if save_checkpoint_steps is not None:
save_checkpoint_steps = check_int_non_negative(save_checkpoint_steps) save_checkpoint_steps = check_int_non_negative(save_checkpoint_steps)
if save_checkpoint_seconds: if save_checkpoint_seconds is not None:
save_checkpoint_seconds = check_int_non_negative(save_checkpoint_seconds) save_checkpoint_seconds = check_int_non_negative(save_checkpoint_seconds)
if keep_checkpoint_max: if keep_checkpoint_max is not None:
keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max) keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max)
if keep_checkpoint_per_n_minutes: if keep_checkpoint_per_n_minutes is not None:
keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes) keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes)
self._save_checkpoint_steps = save_checkpoint_steps self._save_checkpoint_steps = save_checkpoint_steps

@ -258,7 +258,7 @@ def load_checkpoint(ckpt_file_name, net=None):
logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name) logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name)
raise RuntimeError(e.__str__()) raise RuntimeError(e.__str__())
if net: if net is not None:
load_param_into_net(net, parameter_dict) load_param_into_net(net, parameter_dict)
return parameter_dict return parameter_dict

Loading…
Cancel
Save