|
|
|
@ -227,6 +227,7 @@ class TestDistBase(unittest.TestCase):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self._trainers = 2
|
|
|
|
|
self._pservers = 2
|
|
|
|
|
self._port_set = set()
|
|
|
|
|
self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % (
|
|
|
|
|
self._find_free_port(), self._find_free_port())
|
|
|
|
|
self._python_interp = sys.executable
|
|
|
|
@ -242,9 +243,17 @@ class TestDistBase(unittest.TestCase):
|
|
|
|
|
self._after_setup_config()
|
|
|
|
|
|
|
|
|
|
def _find_free_port(self):
|
|
|
|
|
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
|
|
|
|
|
s.bind(('', 0))
|
|
|
|
|
return s.getsockname()[1]
|
|
|
|
|
def __free_port():
|
|
|
|
|
with closing(socket.socket(socket.AF_INET,
|
|
|
|
|
socket.SOCK_STREAM)) as s:
|
|
|
|
|
s.bind(('', 0))
|
|
|
|
|
return s.getsockname()[1]
|
|
|
|
|
|
|
|
|
|
while True:
|
|
|
|
|
port = __free_port()
|
|
|
|
|
if port not in self._port_set:
|
|
|
|
|
self._port_set.add(port)
|
|
|
|
|
return port
|
|
|
|
|
|
|
|
|
|
def start_pserver(self, model_file, check_error_log, required_envs):
|
|
|
|
|
ps0_ep, ps1_ep = self._ps_endpoints.split(",")
|
|
|
|
|