fix test_gen_nccl_id_op failed (#30686)

revert-31068-fix_conv3d_windows
WangXi 4 years ago committed by GitHub
parent 164275704d
commit a28a202603
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -16,6 +16,7 @@ import unittest
import os
import copy
from launch_function_helper import wait, _find_free_port
from multiprocessing import Pool, Process
from threading import Thread
os.environ['GLOG_vmodule'] = str("gen_nccl_id_op*=10,gen_comm_id*=10")
@ -30,8 +31,8 @@ def run_gen_ncc_id(attr):
nccl_comm_num = attr['nccl_comm_num']
use_hallreduce = attr['use_hierarchical_allreduce']
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
startup_program = paddle.static.default_startup_program()
main_program = paddle.static.default_main_program()
with paddle.static.program_guard(main_program, startup_program):
nccl_id_var = startup_program.global_block().create_var(
@ -62,8 +63,6 @@ def run_gen_ncc_id(attr):
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
scope = paddle.static.Scope()
with paddle.static.scope_guard(scope):
exe.run(startup_program)
@ -99,13 +98,12 @@ class TestGenNcclIdOp(unittest.TestCase):
procs = []
for i in range(nranks):
attr['trainer_id'] = i
# NOTE. multiprocessing cannot be covered by coverage
p = Thread(target=run_gen_ncc_id, args=(copy.copy(attr), ))
# NOTE: multiprocessing cannot be covered by coverage
p = Process(target=run_gen_ncc_id, args=(attr, ))
p.start()
procs.append(p)
for p in procs:
p.join()
wait(procs, timeout=120)
def test_flat(self):
print(">>> test gen flat nccl id")

Loading…
Cancel
Save