Fix multi nccl comm & wait server ready (#28663)

WangXi authored 5 years ago, committed by GitHub
parent e7caf3b8d9
commit e931c7baf9

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
+import copy
 import paddle
 from paddle.fluid.framework import core
 from paddle.fluid import compiler
@@ -51,13 +52,21 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
     # should fix the variable
     def _setup_nccl_op(self, startup_program, main_program, build_strategy):
         trainer_endpoints = self.role_maker._get_trainer_endpoints()
-        trainers = trainer_endpoints
+        other_trainers = copy.copy(trainer_endpoints)
         trainer_id = self.role_maker._worker_index()
         current_endpoint = self.role_maker._get_trainer_endpoints()[trainer_id]
+        other_trainers.remove(current_endpoint)
         trainer_endpoints_env = ",".join(trainer_endpoints)
         trainers_num = self.role_maker._worker_num()
+
+        if trainer_id == 0:
+            wait_server_ready(other_trainers)
+
         nccl_id_var = startup_program.global_block().create_var(
             name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW)
+
         for i in range(1, build_strategy.nccl_comm_num):
             startup_program.global_block().create_var(
                 name="NCCLID_{}".format(i),
@@ -90,7 +99,6 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
             })

     def _try_to_compile(self, startup_program, main_program, loss):
-        import copy
         dist_strategy = self.user_defined_strategy
         local_build_strategy = paddle.fluid.BuildStrategy()
         local_build_strategy.enable_sequential_execution = \
@@ -148,13 +156,12 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
         sync_allreduce = dist_strategy.sync_nccl_allreduce
         if sync_allreduce:
-            exe_strategy.num_threads = local_build_strategy.nccl_comm_num + 1
-            if local_build_strategy.use_hierarchical_allreduce:
-                exe_strategy.num_threads = 2 * local_build_strategy.nccl_comm_num + 1
-            if exe_strategy.num_threads > 4:
+            exe_strategy.num_threads = max(
+                local_build_strategy.nccl_comm_num + 1,
+                exe_strategy.num_threads)
+            if local_build_strategy.nccl_comm_num > 1:
                 logging.warn(
-                    "if you use hierachical_allreduce or "
-                    "with multi nccl comm, please set distributed_strategy.sync_nccl_allreduce=False"
+                    "nccl_comm_num > 1, you may need to set sync_nccl_allreduce=False to ensure that different nccl comms can overlap"
                 )

         sync_batch_norm = local_build_strategy.sync_batch_norm
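
Note: the reworded warning describes a setting users are expected to apply themselves. A hedged usage sketch, assuming the paddle.distributed.fleet.DistributedStrategy fields nccl_comm_num and sync_nccl_allreduce:

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.nccl_comm_num = 2             # more than one NCCL communicator
    strategy.sync_nccl_allreduce = False   # let the communicators overlap, as the warning suggests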
@@ -167,6 +174,11 @@ class GraphExecutionOptimizer(MetaOptimizerBase):
                 "set num_threads=1, nccl_comm_num=1, hierachical_allreduce=False."
             )

+        # NOTE. compatible with compiler, otherwise these values will be overwritten by compiler
+        main_program._nccl_comm_num = local_build_strategy.nccl_comm_num
+        main_program._use_hierarchical_allreduce = local_build_strategy.use_hierarchical_allreduce
+        main_program._hierarchical_allreduce_inter_nranks = local_build_strategy.hierarchical_allreduce_inter_nranks
+
         # TODO(guru4elephant): should be an independent optimizer
         self._setup_nccl_op(startup_program, main_program, local_build_strategy)

@@ -75,6 +75,9 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase):
                     optimizer, strategy=strategy)
                 optimizer.minimize(avg_cost)
+                exe = paddle.fluid.Executor(place=paddle.fluid.CPUPlace())
+                exe.run(paddle.fluid.default_startup_program())
+
         proc_a = launch_func(node_func, node_a)
         proc_a.start()
         proc_b = launch_func(node_func, node_b)
@@ -197,6 +200,9 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase):
                     optimizer, strategy=strategy)
                 optimizer.minimize(avg_cost)
+                exe = paddle.fluid.Executor(place=paddle.fluid.CPUPlace())
+                exe.run(paddle.fluid.default_startup_program())
+
         proc_a = launch_func(node_func, node_a)
         proc_a.start()
         proc_b = launch_func(node_func, node_b)
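
Note: both test hunks add an Executor run of the startup program inside node_func, so each launched trainer process actually initializes its parameters. The launch_func helper itself is not shown in this diff; a rough sketch of the pattern the test relies on (its exact signature is an assumption) is that each node gets its own process with that node's trainer environment applied before node_func runs:

    import multiprocessing
    import os

    def launch_func(func, env_dict):
        # Assumed shape of the test helper: set the per-node environment
        # (trainer id, endpoints, ...) in the child process, then run node_func.
        def _target():
            os.environ.update(env_dict)
            func()
        return multiprocessing.Process(target=_target)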
