@@ -261,6 +261,7 @@ class ModelCheckpoint(Callback):
         self._manager = CheckpointManager()
         self._prefix = _chg_ckpt_file_name_if_same_exist(self._directory, self._prefix)
         self._graph_saved = False
+        self._need_flush_from_cache = True
 
     def step_end(self, run_context):
         """
@@ -326,7 +327,8 @@ class ModelCheckpoint(Callback):
             return
 
         # if param is cache enable, flush data from cache to host before save_ckpt
-        self._flush_from_cache(cb_params)
+        if self._need_flush_from_cache:
+            self._flush_from_cache(cb_params)
 
         save_ckpt = self._check_save_ckpt(cb_params, force_to_save)
         step_num_in_epoch = int((cb_params.cur_step_num - 1) % cb_params.batch_num + 1)
@@ -365,10 +367,14 @@ class ModelCheckpoint(Callback):
 
     def _flush_from_cache(self, cb_params):
         """Flush cache data to host if tensor is cache enable."""
+        has_cache_params = False
         params = cb_params.train_network.get_parameters()
         for param in params:
             if param.cache_enable:
+                has_cache_params = True
                 Tensor(param).flush_from_cache()
+        if not has_cache_params:
+            self._need_flush_from_cache = False
 
     @property
     def latest_ckpt_file_name(self):
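
Taken together, the three hunks add a _need_flush_from_cache flag: _save_ckpt only walks the parameter list while cache-enabled parameters might exist, and once _flush_from_cache completes a scan without finding any, the flag is cleared and later checkpoints skip the walk. The snippet below is a minimal, self-contained sketch of that flush-once pattern; DummyParam, DummyNetwork, and CheckpointSketch are illustrative stand-ins for the framework's real parameter, network, and callback objects, not its actual API.

# Minimal sketch of the flush-once pattern from the diff above.
# DummyParam and DummyNetwork are hypothetical stand-ins; only the
# control flow mirrors the patched checkpoint-save path.

class DummyParam:
    def __init__(self, name, cache_enable):
        self.name = name
        self.cache_enable = cache_enable

    def flush_from_cache(self):
        print(f"flushed {self.name} from cache to host")


class DummyNetwork:
    def __init__(self, params):
        self._params = params

    def get_parameters(self):
        return self._params


class CheckpointSketch:
    def __init__(self):
        # Assume parameters may be cache-enabled until a scan proves otherwise.
        self._need_flush_from_cache = True

    def save(self, network):
        # Only pay for the parameter walk while the flag is still set.
        if self._need_flush_from_cache:
            self._flush_from_cache(network)
        print("checkpoint saved")

    def _flush_from_cache(self, network):
        has_cache_params = False
        for param in network.get_parameters():
            if param.cache_enable:
                has_cache_params = True
                param.flush_from_cache()
        if not has_cache_params:
            # Nothing is cache-enabled, so skip the scan on future saves.
            self._need_flush_from_cache = False


if __name__ == "__main__":
    ckpt = CheckpointSketch()
    net = DummyNetwork([DummyParam("w0", False), DummyParam("w1", False)])
    ckpt.save(net)  # first save walks the parameters, finds none cache-enabled
    ckpt.save(net)  # later saves skip the walk entirely

Running the sketch shows the scan happening on the first save only; subsequent saves go straight to writing the checkpoint, which is the cost the patch avoids for networks without cache-enabled parameters.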