diff --git a/mindspore/train/callback/_checkpoint.py b/mindspore/train/callback/_checkpoint.py index 152e77704e..3912bcd620 100644 --- a/mindspore/train/callback/_checkpoint.py +++ b/mindspore/train/callback/_checkpoint.py @@ -108,13 +108,13 @@ class CheckpointConfig: not keep_checkpoint_max and not keep_checkpoint_per_n_minutes: raise ValueError("The input_param can't be all None or 0") - if save_checkpoint_steps: + if save_checkpoint_steps is not None: save_checkpoint_steps = check_int_non_negative(save_checkpoint_steps) - if save_checkpoint_seconds: + if save_checkpoint_seconds is not None: save_checkpoint_seconds = check_int_non_negative(save_checkpoint_seconds) - if keep_checkpoint_max: + if keep_checkpoint_max is not None: keep_checkpoint_max = check_int_non_negative(keep_checkpoint_max) - if keep_checkpoint_per_n_minutes: + if keep_checkpoint_per_n_minutes is not None: keep_checkpoint_per_n_minutes = check_int_non_negative(keep_checkpoint_per_n_minutes) self._save_checkpoint_steps = save_checkpoint_steps diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py index 8c11981268..e07cfa94c5 100644 --- a/mindspore/train/serialization.py +++ b/mindspore/train/serialization.py @@ -258,7 +258,7 @@ def load_checkpoint(ckpt_file_name, net=None): logger.error("Failed to load the checkpoint file `%s`.", ckpt_file_name) raise RuntimeError(e.__str__()) - if net: + if net is not None: load_param_into_net(net, parameter_dict) return parameter_dict