remove chief

wangkuiyi-patch-1
tangwei12 7 years ago
parent 7fbddaa64a
commit 9e026a93cf

@ -466,7 +466,6 @@ CHECKPOINT_SEPARATOR = "_"
def save_checkpoint(executor,
checkpoint_dir,
trainer_id,
is_chief=False,
trainer_args=None,
main_program=None,
max_num_checkpoints=3):
@ -478,8 +477,7 @@ def save_checkpoint(executor,
:param executor executor for save the value
:param checkpoint_dir the checkpoint directory
:param trainer_id currect trainer id
:param is_chief if the trainer id equals 0, the is_chief will be true
:param trainer_id currect trainer id, if id is equal to 0, the trainer is chief
:param main_program will save all variables in program
:param max_num_checkpoints will keep numbers of checkpoint serials not bigger than max_num_checkpoints
"""
@ -497,7 +495,7 @@ def save_checkpoint(executor,
save_trainer_args(cur_dir, trainer_id, trainer_args)
if is_chief:
if trainer_id == 0:
save_persist_vars_without_grad(executor, cur_dir, main_program)
_scroll_delete(checkpoint_dir, max_num_checkpoints)

@ -136,7 +136,6 @@ class Trainer(object):
# config for checkpoint
# only chief worker will save variables
self.trainer_id = 0
self.chief = True
self.checkpoint_cfg = checkpoint_config
if self.checkpoint_cfg:
assert isinstance(self.checkpoint_cfg, CheckpointConfig)
@ -201,7 +200,6 @@ class Trainer(object):
self.nccl_id_var = None
else:
self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
self.chief = self.trainer_id == 0
port = os.getenv("PADDLE_PSERVER_PORT")
worker_ips = os.getenv("PADDLE_TRAINER_IPS")
worker_endpoints = []
@ -250,7 +248,7 @@ class Trainer(object):
# the unique trainer id, starting from 0, needed by trainer
# only
self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
self.chief = self.trainer_id == 0
# the role, should be either PSERVER or TRAINER
training_role = os.getenv("PADDLE_TRAINING_ROLE")
with self._prog_and_scope_guard():
@ -456,7 +454,6 @@ class Trainer(object):
executor=exe,
checkpoint_dir=self.checkpoint_cfg.checkpoint_dir,
trainer_id=self.trainer_id,
is_chief=self.chief,
trainer_args=self._get_checkpoint_save_args(epoch_id, step_id),
main_program=self.train_program,
max_num_checkpoints=self.checkpoint_cfg.max_num_checkpoints)

Loading…
Cancel
Save