|
|
|
|
@ -62,27 +62,20 @@ class CheckpointConfig(object):
|
|
|
|
|
max_num_checkpoints=3,
|
|
|
|
|
epoch_interval=1,
|
|
|
|
|
step_interval=10):
|
|
|
|
|
if checkpoint_dir is None:
|
|
|
|
|
self.checkpoint_dir = os.getcwd()
|
|
|
|
|
else:
|
|
|
|
|
self.checkpoint_dir = checkpoint_dir
|
|
|
|
|
|
|
|
|
|
self.max_num_checkpoints = max_num_checkpoints
|
|
|
|
|
|
|
|
|
|
if epoch_interval < 1:
|
|
|
|
|
self.epoch_interval = 1
|
|
|
|
|
else:
|
|
|
|
|
self.epoch_interval = epoch_interval
|
|
|
|
|
|
|
|
|
|
if step_interval < 1:
|
|
|
|
|
self.step_interval = 10
|
|
|
|
|
else:
|
|
|
|
|
self.step_interval = step_interval
|
|
|
|
|
assert epoch_interval >= 1
|
|
|
|
|
assert step_interval >= 1
|
|
|
|
|
|
|
|
|
|
self.checkpoint_dir = checkpoint_dir if checkpoint_dir is not None else os.getcwd(
|
|
|
|
|
)
|
|
|
|
|
self.max_num_checkpoints = max_num_checkpoints
|
|
|
|
|
self.epoch_interval = epoch_interval
|
|
|
|
|
self.step_interval = step_interval
|
|
|
|
|
self.epoch_id = 0
|
|
|
|
|
self.step_id = 0
|
|
|
|
|
self.load_serial = None
|
|
|
|
|
self.is_pserver = False
|
|
|
|
|
self.has_lookup_table = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_and_get_place(place):
|
|
|
|
|
@ -181,13 +174,18 @@ class Trainer(object):
|
|
|
|
|
self.checkpoint_cfg.load_serial,
|
|
|
|
|
self.startup_program)
|
|
|
|
|
|
|
|
|
|
if not self.checkpoint_cfg.is_pserver:
|
|
|
|
|
epoch_id, step_id = io.load_trainer_args(
|
|
|
|
|
self.checkpoint_cfg.checkpoint_dir,
|
|
|
|
|
self.checkpoint_cfg.load_serial, self.trainer_id,
|
|
|
|
|
self._get_checkpoint_load_args())
|
|
|
|
|
self.checkpoint_cfg.epoch_id = int(epoch_id)
|
|
|
|
|
self.checkpoint_cfg.step_id = int(step_id)
|
|
|
|
|
if not self.checkpoint_cfg.is_pserver:
|
|
|
|
|
epoch_id, step_id = io.load_trainer_args(
|
|
|
|
|
self.checkpoint_cfg.checkpoint_dir,
|
|
|
|
|
self.checkpoint_cfg.load_serial, self.trainer_id,
|
|
|
|
|
self._get_checkpoint_load_args())
|
|
|
|
|
self.checkpoint_cfg.epoch_id = int(epoch_id)
|
|
|
|
|
self.checkpoint_cfg.step_id = int(step_id)
|
|
|
|
|
else:
|
|
|
|
|
if self.checkpoint_cfg.has_lookup_table:
|
|
|
|
|
io.load_lookup_table_vars(
|
|
|
|
|
exe, self.checkpoint_cfg.checkpoint_dir, 0,
|
|
|
|
|
"table_name")
|
|
|
|
|
|
|
|
|
|
if param_path and os.path.isdir(param_path):
|
|
|
|
|
# load params from param_path into scope
|
|
|
|
|
|