|
|
|
@ -16,6 +16,7 @@ import unittest
|
|
|
|
|
import os
|
|
|
|
|
import copy
|
|
|
|
|
from launch_function_helper import wait, _find_free_port
|
|
|
|
|
from multiprocessing import Pool, Process
|
|
|
|
|
from threading import Thread
|
|
|
|
|
|
|
|
|
|
os.environ['GLOG_vmodule'] = str("gen_nccl_id_op*=10,gen_comm_id*=10")
|
|
|
|
@ -30,8 +31,8 @@ def run_gen_ncc_id(attr):
|
|
|
|
|
nccl_comm_num = attr['nccl_comm_num']
|
|
|
|
|
use_hallreduce = attr['use_hierarchical_allreduce']
|
|
|
|
|
|
|
|
|
|
startup_program = paddle.static.Program()
|
|
|
|
|
main_program = paddle.static.Program()
|
|
|
|
|
startup_program = paddle.static.default_startup_program()
|
|
|
|
|
main_program = paddle.static.default_main_program()
|
|
|
|
|
|
|
|
|
|
with paddle.static.program_guard(main_program, startup_program):
|
|
|
|
|
nccl_id_var = startup_program.global_block().create_var(
|
|
|
|
@ -62,8 +63,6 @@ def run_gen_ncc_id(attr):
|
|
|
|
|
|
|
|
|
|
place = paddle.CPUPlace()
|
|
|
|
|
exe = paddle.static.Executor(place)
|
|
|
|
|
scope = paddle.static.Scope()
|
|
|
|
|
with paddle.static.scope_guard(scope):
|
|
|
|
|
exe.run(startup_program)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -99,13 +98,12 @@ class TestGenNcclIdOp(unittest.TestCase):
|
|
|
|
|
procs = []
|
|
|
|
|
for i in range(nranks):
|
|
|
|
|
attr['trainer_id'] = i
|
|
|
|
|
# NOTE. multiprocessing cannot be covered by coverage
|
|
|
|
|
p = Thread(target=run_gen_ncc_id, args=(copy.copy(attr), ))
|
|
|
|
|
# NOTE: multiprocessing cannot be covered by coverage
|
|
|
|
|
p = Process(target=run_gen_ncc_id, args=(attr, ))
|
|
|
|
|
p.start()
|
|
|
|
|
procs.append(p)
|
|
|
|
|
|
|
|
|
|
for p in procs:
|
|
|
|
|
p.join()
|
|
|
|
|
wait(procs, timeout=120)
|
|
|
|
|
|
|
|
|
|
def test_flat(self):
|
|
|
|
|
print(">>> test gen flat nccl id")
|
|
|
|
|