@@ -131,7 +131,40 @@ class Trainer(object):
         # load params from param_path into scope
         io.load_persistables(exe, dirname=param_path)
 
+    def _transpile_nccl2_dist(self):
+        # PADDLE_TRAINER_IPS
+        if "PADDLE_TRAINER_IPS" not in os.environ:
+            self.nccl_id_var = None
+        else:
+            self.trainer_id = int(os.getenv("PADDLE_TRAINER_ID"))
+            port = os.getenv("PADDLE_PSERVER_PORT")
+            worker_ips = os.getenv("PADDLE_TRAINER_IPS")
+            worker_endpoints = []
+            for ip in worker_ips.split(","):
+                worker_endpoints.append(':'.join([ip, port]))
+            self.num_trainers = len(worker_endpoints)
+            current_endpoint = os.getenv("POD_IP") + ":" + port
+            worker_endpoints.remove(current_endpoint)
+            # TODO(wuyi): use self.nccl_id_var, self.num_trainers and self.trainer_id
+            # in ParallelExecutor to start
+            # distributed training using NCCL2
+            self.nccl_id_var = self.startup_program.global_block().create_var(
+                name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
+            self.startup_program.global_block().append_op(
+                type="gen_nccl_id",
+                inputs={},
+                outputs={"NCCLID": self.nccl_id_var},
+                attrs={
+                    "endpoint": current_endpoint,
+                    "endpoint_list": worker_endpoints,
+                    "trainer_id": self.trainer_id
+                })
+
     def _dist_transpile_if_necessary(self, optimize_ops, params_grads):
+        self._transpile_nccl2_dist()
+        if self.nccl_id_var != None:
+            return
+
         if "PADDLE_TRAINING_ROLE" not in os.environ:
             return
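
For context on the new _transpile_nccl2_dist path: it is driven entirely by environment variables, so a launcher has to export them before the trainer process starts. A minimal sketch of that setup follows; the IPs, port, and trainer id are made-up example values and are not part of this diff.

    import os

    # Example-only values; in a real job the launcher/scheduler fills these in.
    os.environ["PADDLE_TRAINER_IPS"] = "192.168.0.1,192.168.0.2"  # IPs of all trainers
    os.environ["PADDLE_PSERVER_PORT"] = "6170"                    # port shared by all trainers
    os.environ["PADDLE_TRAINER_ID"] = "0"                         # index of this trainer
    os.environ["POD_IP"] = "192.168.0.1"                          # IP of this trainer

With these set, _transpile_nccl2_dist builds worker_endpoints = ["192.168.0.1:6170", "192.168.0.2:6170"], records num_trainers = 2, removes its own endpoint from the list, and appends a gen_nccl_id op to the startup program that fills the persistable NCCLID variable; _dist_transpile_if_necessary then returns early, skipping the rest of the transpile logic.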