|
|
|
@ -103,8 +103,12 @@ class TestDataBalance(unittest.TestCase):
|
|
|
|
|
exe = fluid.Executor(place)
|
|
|
|
|
exe.run(startup_prog)
|
|
|
|
|
|
|
|
|
|
build_strategy = fluid.BuildStrategy()
|
|
|
|
|
build_strategy.enable_data_balance = True
|
|
|
|
|
parallel_exe = fluid.ParallelExecutor(
|
|
|
|
|
use_cuda=self.use_cuda, main_program=main_prog)
|
|
|
|
|
use_cuda=self.use_cuda,
|
|
|
|
|
main_program=main_prog,
|
|
|
|
|
build_strategy=build_strategy)
|
|
|
|
|
|
|
|
|
|
if (parallel_exe.device_count > self.batch_size):
|
|
|
|
|
print("WARNING: Unittest TestDataBalance skipped. \
|
|
|
|
@ -145,9 +149,12 @@ class TestDataBalance(unittest.TestCase):
|
|
|
|
|
place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
|
|
|
|
|
exe = fluid.Executor(place)
|
|
|
|
|
exe.run(startup_prog)
|
|
|
|
|
|
|
|
|
|
build_strategy = fluid.BuildStrategy()
|
|
|
|
|
build_strategy.enable_data_balance = True
|
|
|
|
|
parallel_exe = fluid.ParallelExecutor(
|
|
|
|
|
use_cuda=self.use_cuda, main_program=main_prog)
|
|
|
|
|
use_cuda=self.use_cuda,
|
|
|
|
|
main_program=main_prog,
|
|
|
|
|
build_strategy=build_strategy)
|
|
|
|
|
|
|
|
|
|
if (parallel_exe.device_count > self.batch_size):
|
|
|
|
|
print("WARNING: Unittest TestDataBalance skipped. \
|
|
|
|
|