|
|
|
@ -55,6 +55,7 @@ class TestDistRunnerBase(object):
|
|
|
|
|
pserver_prog = t.get_pserver_program(args.current_endpoint)
|
|
|
|
|
startup_prog = t.get_startup_program(args.current_endpoint,
|
|
|
|
|
pserver_prog)
|
|
|
|
|
|
|
|
|
|
place = fluid.CPUPlace()
|
|
|
|
|
exe = fluid.Executor(place)
|
|
|
|
|
exe.run(startup_prog)
|
|
|
|
@ -147,6 +148,8 @@ def runtime_main(test_class):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import paddle.compat as cpt
|
|
|
|
|
import socket
|
|
|
|
|
from contextlib import closing
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestDistBase(unittest.TestCase):
|
|
|
|
@ -156,13 +159,19 @@ class TestDistBase(unittest.TestCase):
|
|
|
|
|
def setUp(self):
    """Initialise the default configuration for a distributed test case.

    Sets two trainers and two parameter servers, builds the comma-separated
    pserver endpoint list from two distinct local ports obtained through
    ``self._find_free_port()``, applies the default interpreter and mode
    flags, and finally calls ``self._setup_config()``.

    NOTE(review): ``_setup_config`` is not visible here; presumably it is
    the hook through which concrete tests adjust these defaults -- confirm.
    """
    self._trainers = 2
    self._pservers = 2

    # Each probe releases its port before returning, so two consecutive
    # probes may be handed the same number. Two parameter servers cannot
    # share an endpoint, so keep probing until the ports differ.
    first_port = self._find_free_port()
    second_port = self._find_free_port()
    while second_port == first_port:
        second_port = self._find_free_port()

    self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (first_port,
                                                        second_port)
    self._python_interp = "python"
    self._sync_mode = True
    self._mem_opt = False
    self._use_reduce = False
    # Called last, after every default above has been assigned.
    self._setup_config()
|
|
|
|
|
|
|
|
|
|
def _find_free_port(self):
|
|
|
|
|
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
|
|
|
|
|
s.bind(('', 0))
|
|
|
|
|
return s.getsockname()[1]
|
|
|
|
|
|
|
|
|
|
def start_pserver(self, model_file, check_error_log):
|
|
|
|
|
ps0_ep, ps1_ep = self._ps_endpoints.split(",")
|
|
|
|
|
ps_cmd = "%s %s --role pserver --endpoints %s --trainer_id 0 --current_endpoint %s --trainers %d --is_dist"
|
|
|
|
|