|
|
|
@ -228,6 +228,8 @@ class ModelCheckpoint(Callback):
|
|
|
|
|
Args:
|
|
|
|
|
run_context (RunContext): Context of the train running.
|
|
|
|
|
"""
|
|
|
|
|
if _is_role_pserver():
|
|
|
|
|
self._prefix = "PServer_" + str(_get_ps_mode_rank()) + "_" + self._prefix
|
|
|
|
|
cb_params = run_context.original_args()
|
|
|
|
|
# save graph (only once)
|
|
|
|
|
if not self._graph_saved:
|
|
|
|
@ -281,8 +283,6 @@ class ModelCheckpoint(Callback):
|
|
|
|
|
if save_ckpt:
|
|
|
|
|
cur_ckpoint_file = self._prefix + "-" + str(cb_params.cur_epoch_num) + "_" \
|
|
|
|
|
+ str(step_num_in_epoch) + ".ckpt"
|
|
|
|
|
if _is_role_pserver():
|
|
|
|
|
cur_ckpoint_file = "PServer_" + str(_get_ps_mode_rank()) + "_" + cur_ckpoint_file
|
|
|
|
|
# update checkpoint file list.
|
|
|
|
|
self._manager.update_ckpoint_filelist(self._directory, self._prefix)
|
|
|
|
|
# keep checkpoint files number equal max number.
|
|
|
|
|