|
|
|
@ -477,13 +477,16 @@ class DistributeTranspiler(object):
|
|
|
|
|
trainer_id,
|
|
|
|
|
trainers,
|
|
|
|
|
current_endpoint,
|
|
|
|
|
startup_program=None):
|
|
|
|
|
startup_program=None,
|
|
|
|
|
wait_port=True):
|
|
|
|
|
if not startup_program:
|
|
|
|
|
startup_program = default_startup_program()
|
|
|
|
|
if trainer_id >= 0:
|
|
|
|
|
worker_endpoints = trainers.split(",")
|
|
|
|
|
# send NCCL_ID to others or recv from trainer 0
|
|
|
|
|
worker_endpoints.remove(current_endpoint)
|
|
|
|
|
if trainer_id == 0 and wait_port:
|
|
|
|
|
wait_server_ready(worker_endpoints)
|
|
|
|
|
|
|
|
|
|
nccl_id_var = startup_program.global_block().create_var(
|
|
|
|
|
name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
|
|
|
|
@ -564,11 +567,13 @@ class DistributeTranspiler(object):
|
|
|
|
|
|
|
|
|
|
if self.config.mode == "nccl2":
|
|
|
|
|
assert (isinstance(trainers, str))
|
|
|
|
|
self.origin_program._trainers_endpoints = trainers.split(",")
|
|
|
|
|
self._transpile_nccl2(
|
|
|
|
|
trainer_id,
|
|
|
|
|
trainers,
|
|
|
|
|
current_endpoint,
|
|
|
|
|
startup_program=startup_program)
|
|
|
|
|
startup_program=startup_program,
|
|
|
|
|
wait_port=self.config.wait_port)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
self.trainer_num = trainers
|
|
|
|
|