@@ -136,7 +136,6 @@ class Trainer(object):
        # config for checkpoint
        # only chief worker will save variables
        self.trainer_id = 0
        self.chief = True
        self.checkpoint_cfg = checkpoint_config
        if self.checkpoint_cfg:
            assert isinstance(self.checkpoint_cfg, CheckpointConfig)
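The constructor only validates the checkpoint settings when a config object is actually passed in; with `None` the feature simply stays off. Below is a minimal sketch of handing such a config to the Trainer. Only the keyword `checkpoint_config` and the attributes `checkpoint_dir` / `max_num_checkpoints` are confirmed by this diff; the `fluid.CheckpointConfig` and `fluid.Trainer` constructor arguments follow the old fluid high-level API and should be treated as assumptions.

    import paddle.fluid as fluid

    # assumed constructor arguments; only checkpoint_dir and
    # max_num_checkpoints are read elsewhere in trainer.py
    ckpt_cfg = fluid.CheckpointConfig(
        checkpoint_dir="/tmp/trainer_ckpt", max_num_checkpoints=3)

    trainer = fluid.Trainer(
        train_func=train_program,       # user-defined network builder (assumed name)
        optimizer_func=optimizer_func,  # user-defined optimizer builder (assumed name)
        place=fluid.CPUPlace(),
        checkpoint_config=ckpt_cfg)     # stored as self.checkpoint_cfg above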
@@ -201,7 +200,6 @@ class Trainer(object):
            self.nccl_id_var = None
        else:
            self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
            self.chief = self.trainer_id == 0
            port = os.getenv("PADDLE_PSERVER_PORT")
            worker_ips = os.getenv("PADDLE_TRAINER_IPS")
            worker_endpoints = []
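The hunk is cut off right after `worker_endpoints = []`, so the loop that fills it is not shown. A sketch of how the endpoint list is typically assembled from the two environment variables above (comma-separated trainer IPs joined with a shared port); the loop body is an assumption about the elided lines, not part of this diff:

    import os

    port = os.getenv("PADDLE_PSERVER_PORT")        # e.g. "6174"
    worker_ips = os.getenv("PADDLE_TRAINER_IPS")   # e.g. "192.168.0.2,192.168.0.3"
    worker_endpoints = []
    for ip in worker_ips.split(","):
        # pair every trainer IP with the common port: "192.168.0.2:6174", ...
        worker_endpoints.append(":".join([ip, port]))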
@@ -250,7 +248,7 @@ class Trainer(object):
        # the unique trainer id, starting from 0, needed by trainer
        # only
        self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
        self.chief = self.trainer_id == 0

        # the role, should be either PSERVER or TRAINER
        training_role = os.getenv("PADDLE_TRAINING_ROLE")
        with self._prog_and_scope_guard():
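Both the trainer id and the role are read from the environment, so the same training script can be launched either as a parameter server or as a trainer purely by exporting different variables before start-up. A sketch of the two process environments, written with `os.environ` for consistency with the surrounding code (in practice these are usually exported by the cluster launcher); only variable names visible in this diff are used:

    import os

    # parameter-server process
    os.environ["PADDLE_TRAINING_ROLE"] = "PSERVER"
    os.environ["PADDLE_PSERVER_PORT"] = "6174"

    # trainer process number 0 (trainer ids start from 0, see the comment above)
    os.environ["PADDLE_TRAINING_ROLE"] = "TRAINER"
    os.environ["PADDLE_TRAINER_ID"] = "0"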
@@ -456,7 +454,6 @@ class Trainer(object):
            executor=exe,
            checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
            trainer_id=self.trainer_id,
            is_chief=self.chief,
            trainer_args=self._get_checkpoint_save_args(epoch_id, step_id),
            main_program=self.train_program,
            max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints)
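The keyword arguments collect the state initialised in the earlier hunks: the trainer id, the chief flag, and the checkpoint directory and retention count from the config object. `_get_checkpoint_save_args` itself is not shown in this diff; a plausible sketch, assuming it merely packages the current training position so a restarted trainer knows where to resume:

    def _get_checkpoint_save_args(self, epoch_id, step_id):
        # assumption: the checkpoint only needs to record where training
        # stopped, so the args are just the epoch and step counters
        return {"epoch_id": epoch_id, "step_id": step_id}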