fix tangwei merge issue test=develop (#15506)

Wu Yi committed 6 years ago via GitHub
parent dec89bd7ed
commit 22db82c053

@@ -159,7 +159,7 @@ class ParallelExecutor(object):
         trainers_endpoints = main._trainers_endpoints
         if num_trainers > 1 and trainers_endpoints:
             assert num_trainers == len(
-                trainers_endpoints), "num_trainers == len(end_points)"
+                trainers_endpoints), "num_trainers == len(endpoints)"
             build_strategy.trainers_endpoints = trainers_endpoints
         # step6: get persistable_vars, places. persistable_vars

@@ -477,13 +477,16 @@ class DistributeTranspiler(object):
                          trainer_id,
                          trainers,
                          current_endpoint,
-                         startup_program=None):
+                         startup_program=None,
+                         wait_port=True):
         if not startup_program:
             startup_program = default_startup_program()
         if trainer_id >= 0:
             worker_endpoints = trainers.split(",")
             # send NCCL_ID to others or recv from trainer 0
             worker_endpoints.remove(current_endpoint)
+            if trainer_id == 0 and wait_port:
+                wait_server_ready(worker_endpoints)
             nccl_id_var = startup_program.global_block().create_var(
                 name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
@@ -564,11 +567,13 @@ class DistributeTranspiler(object):
         if self.config.mode == "nccl2":
             assert (isinstance(trainers, str))
             self.origin_program._trainers_endpoints = trainers.split(",")
             self._transpile_nccl2(
                 trainer_id,
                 trainers,
                 current_endpoint,
-                startup_program=startup_program)
+                startup_program=startup_program,
+                wait_port=self.config.wait_port)
             return
         self.trainer_num = trainers
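
The change plumbs a wait_port switch through to the nccl2 transpile path: before trainer 0 broadcasts NCCLID, it can wait for the other trainers' ports to be reachable. Below is a minimal usage sketch, not part of this commit: it assumes the fluid.DistributeTranspiler / DistributeTranspilerConfig API of this release, and that the config object exposes the wait_port attribute the diff reads via self.config.wait_port. The endpoint addresses are made up for illustration.

import paddle.fluid as fluid

# Hypothetical trainer endpoints; in nccl2 mode `trainers` is a
# comma-separated endpoint string rather than a trainer count.
trainer_endpoints = "192.168.0.1:6170,192.168.0.2:6170"
current_endpoint = "192.168.0.1:6170"

config = fluid.DistributeTranspilerConfig()
config.mode = "nccl2"
# New switch from this change: trainer 0 calls wait_server_ready() on its
# peer endpoints before sending/receiving NCCLID; set False to skip the wait.
config.wait_port = True

t = fluid.DistributeTranspiler(config=config)
t.transpile(
    trainer_id=0,
    trainers=trainer_endpoints,
    current_endpoint=current_endpoint,
    startup_program=fluid.default_startup_program())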
