|
|
|
@ -132,19 +132,18 @@ class Trainer(object):
|
|
|
|
|
# 1. we need to generate a framework.Program by calling
|
|
|
|
|
# program_func. Reference: fluid.program_guard in
|
|
|
|
|
# test_word2vec.py
|
|
|
|
|
if not isinstance(optimizer, opt_module.Optimizer):
|
|
|
|
|
raise TypeError("The optimizer should be an instance of Optimizer")
|
|
|
|
|
assert isinstance(optimizer, opt_module.Optimizer)
|
|
|
|
|
|
|
|
|
|
# config for checkpoint
|
|
|
|
|
# only chief worker will save variables
|
|
|
|
|
self.trainer_id = 0
|
|
|
|
|
self.chief = True
|
|
|
|
|
self.checkpoint = checkpoint_config
|
|
|
|
|
if self.checkpoint:
|
|
|
|
|
assert isinstance(self.checkpoint, CheckpointConfig)
|
|
|
|
|
self.checkpoint_cfg = checkpoint_config
|
|
|
|
|
if self.checkpoint_cfg:
|
|
|
|
|
assert isinstance(self.checkpoint_cfg, CheckpointConfig)
|
|
|
|
|
serial = io.get_latest_checkpoint_serial(
|
|
|
|
|
self.checkpoint.checkpoint_dir)
|
|
|
|
|
self.checkpoint.load_serial = serial if serial >= 0 else None
|
|
|
|
|
self.checkpoint_cfg.checkpoint_dir)
|
|
|
|
|
self.checkpoint_cfg.load_serial = serial if serial >= 0 else None
|
|
|
|
|
|
|
|
|
|
self.scope = core.Scope()
|
|
|
|
|
|
|
|
|
@ -174,19 +173,20 @@ class Trainer(object):
|
|
|
|
|
exe = executor.Executor(place)
|
|
|
|
|
exe.run(self.startup_program)
|
|
|
|
|
|
|
|
|
|
if self.checkpoint and self.checkpoint.load_serial:
|
|
|
|
|
if self.checkpoint_cfg and self.checkpoint_cfg.load_serial:
|
|
|
|
|
with self._prog_and_scope_guard():
|
|
|
|
|
exe = executor.Executor(place)
|
|
|
|
|
io.load_checkpoint(exe, self.checkpoint.checkpoint_dir,
|
|
|
|
|
self.checkpoint.load_serial,
|
|
|
|
|
io.load_checkpoint(exe, self.checkpoint_cfg.checkpoint_dir,
|
|
|
|
|
self.checkpoint_cfg.load_serial,
|
|
|
|
|
self.startup_program)
|
|
|
|
|
|
|
|
|
|
if not self.checkpoint.is_pserver:
|
|
|
|
|
if not self.checkpoint_cfg.is_pserver:
|
|
|
|
|
epoch_id, step_id = io.load_trainer_args(
|
|
|
|
|
self.checkpoint.checkpoint_dir, self.checkpoint.load_serial,
|
|
|
|
|
self.trainer_id, self._get_checkpoint_load_args())
|
|
|
|
|
self.checkpoint.epoch_id = int(epoch_id)
|
|
|
|
|
self.checkpoint.step_id = int(step_id)
|
|
|
|
|
self.checkpoint_cfg.checkpoint_dir,
|
|
|
|
|
self.checkpoint_cfg.load_serial, self.trainer_id,
|
|
|
|
|
self._get_checkpoint_load_args())
|
|
|
|
|
self.checkpoint_cfg.epoch_id = int(epoch_id)
|
|
|
|
|
self.checkpoint_cfg.step_id = int(step_id)
|
|
|
|
|
|
|
|
|
|
if param_path and os.path.isdir(param_path):
|
|
|
|
|
# load params from param_path into scope
|
|
|
|
@ -256,7 +256,7 @@ class Trainer(object):
|
|
|
|
|
t.transpile(
|
|
|
|
|
self.trainer_id, pservers=pserver_endpoints, trainers=trainers)
|
|
|
|
|
if training_role == "PSERVER":
|
|
|
|
|
if self.checkpoint:
|
|
|
|
|
if self.checkpoint_cfg:
|
|
|
|
|
self.is_pserver = True
|
|
|
|
|
|
|
|
|
|
self.train_program = t.get_pserver_program(current_endpoint)
|
|
|
|
@ -351,10 +351,10 @@ class Trainer(object):
|
|
|
|
|
self._train_by_any_executor(event_handler, exe, num_epochs, reader)
|
|
|
|
|
|
|
|
|
|
def _train_by_any_executor(self, event_handler, exe, num_epochs, reader):
|
|
|
|
|
if self.checkpoint:
|
|
|
|
|
if self.checkpoint_cfg:
|
|
|
|
|
epochs = [
|
|
|
|
|
epoch_id for epoch_id in range(num_epochs)
|
|
|
|
|
if epoch_id >= self.checkpoint.epoch_id
|
|
|
|
|
if epoch_id >= self.checkpoint_cfg.epoch_id
|
|
|
|
|
]
|
|
|
|
|
else:
|
|
|
|
|
epochs = [epoch_id for epoch_id in range(num_epochs)]
|
|
|
|
@ -366,8 +366,8 @@ class Trainer(object):
|
|
|
|
|
self._clean_checkpoint()
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if self.checkpoint and self.checkpoint.load_serial \
|
|
|
|
|
and self.checkpoint.step_id >= step_id and self.checkpoint.epoch_id == epoch_id:
|
|
|
|
|
if self.checkpoint_cfg and self.checkpoint_cfg.load_serial \
|
|
|
|
|
and self.checkpoint_cfg.step_id >= step_id and self.checkpoint_cfg.epoch_id == epoch_id:
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
begin_event = BeginStepEvent(epoch_id, step_id)
|
|
|
|
@ -381,9 +381,11 @@ class Trainer(object):
|
|
|
|
|
else:
|
|
|
|
|
metrics = exe.run(feed=data, fetch_list=[])
|
|
|
|
|
|
|
|
|
|
if self.checkpoint_cfg:
|
|
|
|
|
self._save_checkpoint(epoch_id, step_id)
|
|
|
|
|
event_handler(EndStepEvent(epoch_id, step_id, metrics))
|
|
|
|
|
event_handler(EndEpochEvent(epoch_id))
|
|
|
|
|
if self.checkpoint_cfg:
|
|
|
|
|
self._clean_checkpoint()
|
|
|
|
|
|
|
|
|
|
def _test_by_executor(self, reader, feed_order, fetch_list):
|
|
|
|
@ -424,9 +426,8 @@ class Trainer(object):
|
|
|
|
|
return self._get_parallel_executor()
|
|
|
|
|
|
|
|
|
|
def _clean_checkpoint(self):
|
|
|
|
|
if not self.checkpoint:
|
|
|
|
|
return
|
|
|
|
|
io.clean_checkpoint(checkpoint_dir=self.checkpoint.checkpoint_dir)
|
|
|
|
|
assert self.checkpoint_cfg
|
|
|
|
|
io.clean_checkpoint(checkpoint_dir=self.checkpoint_cfg.checkpoint_dir)
|
|
|
|
|
|
|
|
|
|
def _get_checkpoint_load_args(self):
|
|
|
|
|
"""
|
|
|
|
@ -444,19 +445,18 @@ class Trainer(object):
|
|
|
|
|
return trainer_args
|
|
|
|
|
|
|
|
|
|
def _save_checkpoint(self, epoch_id, step_id):
|
|
|
|
|
if not self.checkpoint:
|
|
|
|
|
return
|
|
|
|
|
assert self.checkpoint_cfg
|
|
|
|
|
|
|
|
|
|
if epoch_id % self.checkpoint.epoch_interval == 0 and step_id % self.checkpoint.step_interval == 0:
|
|
|
|
|
if epoch_id % self.checkpoint_cfg.epoch_interval == 0 and step_id % self.checkpoint_cfg.step_interval == 0:
|
|
|
|
|
exe = executor.Executor(self.place)
|
|
|
|
|
io.save_checkpoint(
|
|
|
|
|
executor=exe,
|
|
|
|
|
checkpoint_dir=self.checkpoint.checkpoint_dir,
|
|
|
|
|
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
|
|
|
|
|
trainer_id=self.trainer_id,
|
|
|
|
|
is_chief=self.chief,
|
|
|
|
|
trainer_args=self._get_checkpoint_save_args(epoch_id, step_id),
|
|
|
|
|
main_program=self.train_program,
|
|
|
|
|
max_num_checkpoints=self.checkpoint.max_num_checkpoints)
|
|
|
|
|
max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_feed_var_list(program, feed_order):
|
|
|
|
|