|
|
|
@ -348,10 +348,7 @@ class Trainer(object):
|
|
|
|
|
training_role = os.getenv("PADDLE_TRAINING_ROLE")
|
|
|
|
|
with self._prog_and_scope_guard():
|
|
|
|
|
t = distribute_transpiler.DistributeTranspiler()
|
|
|
|
|
t.transpile(
|
|
|
|
|
self.trainer_id,
|
|
|
|
|
pservers=pserver_endpoints,
|
|
|
|
|
trainers=trainers)
|
|
|
|
|
t.transpile(self.trainer_id, pservers=pserver_endpoints, trainers=trainers)
|
|
|
|
|
if training_role == "PSERVER":
|
|
|
|
|
if self.checkpoint_cfg:
|
|
|
|
|
pserver_id = eplist.index(current_endpoint)
|
|
|
|
|