|
|
|
@ -61,9 +61,10 @@ class TestDistRunnerBase(object):
|
|
|
|
|
exe.run(startup_prog)
|
|
|
|
|
exe.run(pserver_prog)
|
|
|
|
|
|
|
|
|
|
def run_trainer(self, place, args):
|
|
|
|
|
def run_trainer(self, use_cuda, args):
|
|
|
|
|
import paddle
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
|
|
|
|
|
test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
|
|
|
|
|
self.get_model(batch_size=2)
|
|
|
|
|
if args.mem_opt:
|
|
|
|
@ -91,7 +92,7 @@ class TestDistRunnerBase(object):
|
|
|
|
|
build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
|
|
|
|
|
|
|
|
|
|
exe = fluid.ParallelExecutor(
|
|
|
|
|
True,
|
|
|
|
|
use_cuda,
|
|
|
|
|
loss_name=avg_cost.name,
|
|
|
|
|
exec_strategy=strategy,
|
|
|
|
|
build_strategy=build_stra)
|
|
|
|
@ -142,9 +143,8 @@ def runtime_main(test_class):
|
|
|
|
|
if args.role == "pserver" and args.is_dist:
|
|
|
|
|
model.run_pserver(args)
|
|
|
|
|
else:
|
|
|
|
|
p = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
|
|
|
|
|
) else fluid.CPUPlace()
|
|
|
|
|
model.run_trainer(p, args)
|
|
|
|
|
use_cuda = True if core.is_compiled_with_cuda() else False
|
|
|
|
|
model.run_trainer(use_cuda, args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import paddle.compat as cpt
|
|
|
|
@ -225,11 +225,12 @@ class TestDistBase(unittest.TestCase):
|
|
|
|
|
def check_with_place(self, model_file, delta=1e-3, check_error_log=False):
|
|
|
|
|
# TODO(typhoonzero): should auto adapt GPU count on the machine.
|
|
|
|
|
required_envs = {
|
|
|
|
|
"PATH": os.getenv("PATH"),
|
|
|
|
|
"PYTHONPATH": os.getenv("PYTHONPATH"),
|
|
|
|
|
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH"),
|
|
|
|
|
"PATH": os.getenv("PATH", ""),
|
|
|
|
|
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
|
|
|
|
|
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
|
|
|
|
|
"FLAGS_fraction_of_gpu_memory_to_use": "0.15",
|
|
|
|
|
"FLAGS_cudnn_deterministic": "1"
|
|
|
|
|
"FLAGS_cudnn_deterministic": "1",
|
|
|
|
|
"CPU_NUM": "1"
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if check_error_log:
|
|
|
|
|