|
|
@ -105,6 +105,7 @@ class Trainer(object):
|
|
|
|
def __init__(self,
|
|
|
|
def __init__(self,
|
|
|
|
train_func,
|
|
|
|
train_func,
|
|
|
|
optimizer,
|
|
|
|
optimizer,
|
|
|
|
|
|
|
|
param_path=None,
|
|
|
|
place=None,
|
|
|
|
place=None,
|
|
|
|
parallel=False,
|
|
|
|
parallel=False,
|
|
|
|
checkpoint_config=None):
|
|
|
|
checkpoint_config=None):
|
|
|
@ -120,8 +121,8 @@ class Trainer(object):
|
|
|
|
# only chief worker will save variables
|
|
|
|
# only chief worker will save variables
|
|
|
|
self.chief = True
|
|
|
|
self.chief = True
|
|
|
|
self.checkpoint = checkpoint_config
|
|
|
|
self.checkpoint = checkpoint_config
|
|
|
|
if self.checkpoint and not isinstance(self.checkpoint,
|
|
|
|
if self.checkpoint and \
|
|
|
|
CheckpointConfig):
|
|
|
|
not isinstance(self.checkpoint, CheckpointConfig):
|
|
|
|
raise TypeError(
|
|
|
|
raise TypeError(
|
|
|
|
"The checkpoint_config shoule be an instance of CheckpointConfig"
|
|
|
|
"The checkpoint_config shoule be an instance of CheckpointConfig"
|
|
|
|
)
|
|
|
|
)
|
|
|
@ -159,6 +160,10 @@ class Trainer(object):
|
|
|
|
io.load_checkpoint(exe, self.checkpoint.checkpoint_dir,
|
|
|
|
io.load_checkpoint(exe, self.checkpoint.checkpoint_dir,
|
|
|
|
self.startup_program)
|
|
|
|
self.startup_program)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if param_path:
|
|
|
|
|
|
|
|
# load params from param_path into scope
|
|
|
|
|
|
|
|
io.load_persistables(exe, dirname=param_path)
|
|
|
|
|
|
|
|
|
|
|
|
def _transpile_nccl2_dist(self):
|
|
|
|
def _transpile_nccl2_dist(self):
|
|
|
|
# PADDLE_TRAINER_IPS
|
|
|
|
# PADDLE_TRAINER_IPS
|
|
|
|
if "PADDLE_TRAINER_IPS" not in os.environ:
|
|
|
|
if "PADDLE_TRAINER_IPS" not in os.environ:
|
|
|
|