|
|
|
@ -82,8 +82,18 @@ class TestDistRunnerBase(object):
|
|
|
|
|
strategy = fluid.ExecutionStrategy()
|
|
|
|
|
strategy.num_threads = 1
|
|
|
|
|
strategy.allow_op_delay = False
|
|
|
|
|
build_stra = fluid.BuildStrategy()
|
|
|
|
|
|
|
|
|
|
if args.use_reduce:
|
|
|
|
|
build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
|
|
|
|
|
else:
|
|
|
|
|
build_stra.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.AllReduce
|
|
|
|
|
|
|
|
|
|
exe = fluid.ParallelExecutor(
|
|
|
|
|
True, loss_name=avg_cost.name, exec_strategy=strategy)
|
|
|
|
|
True,
|
|
|
|
|
loss_name=avg_cost.name,
|
|
|
|
|
exec_strategy=strategy,
|
|
|
|
|
build_strategy=build_stra)
|
|
|
|
|
|
|
|
|
|
feed_var_list = [
|
|
|
|
|
var for var in trainer_prog.global_block().vars.values()
|
|
|
|
@ -123,6 +133,7 @@ def runtime_main(test_class):
|
|
|
|
|
'--current_endpoint', type=str, required=False, default="")
|
|
|
|
|
parser.add_argument('--sync_mode', action='store_true')
|
|
|
|
|
parser.add_argument('--mem_opt', action='store_true')
|
|
|
|
|
parser.add_argument('--use_reduce', action='store_true')
|
|
|
|
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
@ -149,20 +160,25 @@ class TestDistBase(unittest.TestCase):
|
|
|
|
|
self._python_interp = "python"
|
|
|
|
|
self._sync_mode = True
|
|
|
|
|
self._mem_opt = False
|
|
|
|
|
self._use_reduce = False
|
|
|
|
|
self._setup_config()
|
|
|
|
|
|
|
|
|
|
def start_pserver(self, model_file, check_error_log):
|
|
|
|
|
|
|
|
|
|
ps0_ep, ps1_ep = self._ps_endpoints.split(",")
|
|
|
|
|
ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist %s %s"
|
|
|
|
|
sync_mode_str = "--sync_mode" if self._sync_mode else ""
|
|
|
|
|
mem_opt_str = "--mem_opt" if self._mem_opt else ""
|
|
|
|
|
ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist"
|
|
|
|
|
ps0_cmd = ps_cmd % \
|
|
|
|
|
(self._python_interp, model_file, self._ps_endpoints, ps0_ep,
|
|
|
|
|
self._trainers, sync_mode_str, mem_opt_str)
|
|
|
|
|
self._trainers)
|
|
|
|
|
ps1_cmd = ps_cmd % \
|
|
|
|
|
(self._python_interp, model_file, self._ps_endpoints, ps1_ep,
|
|
|
|
|
self._trainers, sync_mode_str, mem_opt_str)
|
|
|
|
|
self._trainers)
|
|
|
|
|
|
|
|
|
|
if self._sync_mode:
|
|
|
|
|
ps0_cmd += " --sync_mode"
|
|
|
|
|
ps1_cmd += " --sync_mode"
|
|
|
|
|
if self._mem_opt:
|
|
|
|
|
ps0_cmd += " --mem_opt"
|
|
|
|
|
ps1_cmd += " --mem_opt"
|
|
|
|
|
|
|
|
|
|
ps0_pipe = subprocess.PIPE
|
|
|
|
|
ps1_pipe = subprocess.PIPE
|
|
|
|
@ -242,17 +258,23 @@ class TestDistBase(unittest.TestCase):
|
|
|
|
|
self._wait_ps_ready(ps1.pid)
|
|
|
|
|
|
|
|
|
|
ps0_ep, ps1_ep = self._ps_endpoints.split(",")
|
|
|
|
|
tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist %s %s"
|
|
|
|
|
sync_mode_str = "--sync_mode" if self._sync_mode else ""
|
|
|
|
|
mem_opt_str = "--mem_opt" if self._mem_opt else ""
|
|
|
|
|
tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist"
|
|
|
|
|
tr0_cmd = tr_cmd % \
|
|
|
|
|
(self._python_interp, model_file, self._ps_endpoints,
|
|
|
|
|
0, ps0_ep,
|
|
|
|
|
self._trainers, sync_mode_str, mem_opt_str)
|
|
|
|
|
0, ps0_ep, self._trainers)
|
|
|
|
|
tr1_cmd = tr_cmd % \
|
|
|
|
|
(self._python_interp, model_file, self._ps_endpoints,
|
|
|
|
|
1, ps1_ep,
|
|
|
|
|
self._trainers, sync_mode_str, mem_opt_str)
|
|
|
|
|
1, ps1_ep, self._trainers)
|
|
|
|
|
|
|
|
|
|
if self._sync_mode:
|
|
|
|
|
tr0_cmd += " --sync_mode"
|
|
|
|
|
tr1_cmd += " --sync_mode"
|
|
|
|
|
if self._mem_opt:
|
|
|
|
|
tr0_cmd += " --mem_opt"
|
|
|
|
|
tr1_cmd += " --mem_opt"
|
|
|
|
|
if self._use_reduce:
|
|
|
|
|
tr0_cmd += " --use_reduce"
|
|
|
|
|
tr1_cmd += " --use_reduce"
|
|
|
|
|
|
|
|
|
|
env0 = {"CUDA_VISIBLE_DEVICES": "0"}
|
|
|
|
|
env1 = {"CUDA_VISIBLE_DEVICES": "1"}
|
|
|
|
@ -303,6 +325,8 @@ class TestDistBase(unittest.TestCase):
|
|
|
|
|
# FIXME: use terminate() instead of sigkill.
|
|
|
|
|
os.kill(ps0.pid, signal.SIGKILL)
|
|
|
|
|
os.kill(ps1.pid, signal.SIGKILL)
|
|
|
|
|
ps0.terminate()
|
|
|
|
|
ps1.terminate()
|
|
|
|
|
ps0.wait()
|
|
|
|
|
ps1.wait()
|
|
|
|
|
FNULL.close()
|
|
|
|
|