|
|
|
@ -24,7 +24,7 @@ import mindspore.context as context
|
|
|
|
|
from mindspore.train.serialization import _exec_save_checkpoint, _fill_param_into_net, _save_graph
|
|
|
|
|
from mindspore.train._utils import _make_directory
|
|
|
|
|
from mindspore import log as logger
|
|
|
|
|
from mindspore._checkparam import check_int_non_negative
|
|
|
|
|
from mindspore._checkparam import check_int_non_negative, check_bool
|
|
|
|
|
from mindspore.common.tensor import Tensor
|
|
|
|
|
from .summary.summary_record import _cache_summary_tensor_data
|
|
|
|
|
|
|
|
|
@ -150,6 +150,8 @@ class CheckpointConfig:
|
|
|
|
|
keep_checkpoint_max (int): Maximum step to save checkpoint. Default: 5.
|
|
|
|
|
keep_checkpoint_per_n_minutes (int): Keep one checkpoint every n minutes. Default: 0.
|
|
|
|
|
Can't be used with keep_checkpoint_max at the same time.
|
|
|
|
|
integrated_save (bool): Whether to intergrated save in automatic model parall scene. Default: True.
|
|
|
|
|
Integrated save function is only supported in automatic parall scene, not supported in manual parallel.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
ValueError: If the input_param is None or 0.
|
|
|
|
@ -163,7 +165,8 @@ class CheckpointConfig:
|
|
|
|
|
save_checkpoint_steps=1,
|
|
|
|
|
save_checkpoint_seconds=0,
|
|
|
|
|
keep_checkpoint_max=5,
|
|
|
|
|
keep_checkpoint_per_n_minutes=0):
|
|
|
|
|
keep_checkpoint_per_n_minutes=0,
|
|
|
|
|
integrated_save=True):
|
|
|
|
|
|
|
|
|
|
if not save_checkpoint_steps and not save_checkpoint_seconds and \
|
|
|
|
|
not keep_checkpoint_max and not keep_checkpoint_per_n_minutes:
|
|
|
|
@ -191,6 +194,8 @@ class CheckpointConfig:
|
|
|
|
|
if not self._keep_checkpoint_per_n_minutes or self._keep_checkpoint_per_n_minutes == 0:
|
|
|
|
|
self._keep_checkpoint_max = 1
|
|
|
|
|
|
|
|
|
|
self._integrated_save = check_bool(integrated_save)
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def save_checkpoint_steps(self):
|
|
|
|
|
"""Get the value of _save_checkpoint_steps."""
|
|
|
|
@ -211,6 +216,11 @@ class CheckpointConfig:
|
|
|
|
|
"""Get the value of _keep_checkpoint_per_n_minutes."""
|
|
|
|
|
return self._keep_checkpoint_per_n_minutes
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def integrated_save(self):
|
|
|
|
|
"""Get the value of _integrated_save."""
|
|
|
|
|
return self._integrated_save
|
|
|
|
|
|
|
|
|
|
def get_checkpoint_policy(self):
|
|
|
|
|
"""Get the policy of checkpoint."""
|
|
|
|
|
checkpoint_policy = {'save_checkpoint_steps': self._save_checkpoint_steps,
|
|
|
|
@ -619,7 +629,7 @@ class ModelCheckpoint(Callback):
|
|
|
|
|
_set_cur_net(cb_params.train_network)
|
|
|
|
|
cb_params.train_network.exec_checkpoint_graph()
|
|
|
|
|
|
|
|
|
|
_exec_save_checkpoint(cb_params.train_network, gen_file)
|
|
|
|
|
_exec_save_checkpoint(cb_params.train_network, gen_file, self._config.integrated_save)
|
|
|
|
|
|
|
|
|
|
if os.path.exists(gen_file):
|
|
|
|
|
shutil.move(gen_file, cur_file)
|
|
|
|
|