restore param_path

wangkuiyi-patch-1
tangwei12 7 years ago
parent b044724db7
commit dca0b6d9cc

@ -105,6 +105,7 @@ class Trainer(object):
def __init__(self,
train_func,
optimizer,
param_path=None,
place=None,
parallel=False,
checkpoint_config=None):
@ -120,8 +121,8 @@ class Trainer(object):
# only chief worker will save variables
self.chief = True
self.checkpoint = checkpoint_config
if self.checkpoint and not isinstance(self.checkpoint,
CheckpointConfig):
if self.checkpoint and \
not isinstance(self.checkpoint, CheckpointConfig):
raise TypeError(
"The checkpoint_config shoule be an instance of CheckpointConfig"
)
@ -159,6 +160,10 @@ class Trainer(object):
io.load_checkpoint(exe, self.checkpoint.checkpoint_dir,
self.startup_program)
if param_path:
# load params from param_path into scope
io.load_persistables(exe, dirname=param_path)
def _transpile_nccl2_dist(self):
# PADDLE_TRAINER_IPS
if "PADDLE_TRAINER_IPS" not in os.environ:

Loading…
Cancel
Save