|
|
|
@ -37,30 +37,6 @@ class TestCollectiveAPIRunnerBase(object):
|
|
|
|
|
raise NotImplementedError(
|
|
|
|
|
"get model should be implemented by child class.")
|
|
|
|
|
|
|
|
|
|
def wait_server_ready(self, endpoints):
    """Block until a TCP connection succeeds to every given endpoint.

    Each pass probes all endpoints once. If any probe fails, the failing
    endpoints are reported on stderr, the method sleeps 3 seconds, and
    the whole set is probed again. There is no upper bound on retries.

    Args:
        endpoints: iterable of "host:port" strings. Passing a single
            string is rejected by the assertion below, since iterating
            it would yield individual characters.
    """
    assert not isinstance(endpoints, string_types)
    while True:
        pending = []
        for endpoint in endpoints:
            parts = endpoint.split(":")
            probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            with closing(probe):
                # Bound each connection attempt to 2 seconds.
                probe.settimeout(2)
                # connect_ex reports failure via a non-zero code
                # instead of raising for ordinary connect errors.
                if probe.connect_ex((parts[0], int(parts[1]))) != 0:
                    pending.append(endpoint)

        if not pending:
            # Every endpoint accepted a connection on this pass.
            break

        sys.stderr.write("server not ready, wait 3 sec to retry...\n")
        sys.stderr.write("not ready endpoints:" + str(pending) + "\n")
        sys.stderr.flush()
        time.sleep(3)
|
|
|
|
|
|
|
|
|
|
def run_trainer(self, args):
|
|
|
|
|
train_prog = fluid.Program()
|
|
|
|
|
startup_prog = fluid.Program()
|
|
|
|
@ -157,8 +133,8 @@ class TestDistBase(unittest.TestCase):
|
|
|
|
|
tr_cmd = "%s %s"
|
|
|
|
|
tr0_cmd = tr_cmd % (self._python_interp, model_file)
|
|
|
|
|
tr1_cmd = tr_cmd % (self._python_interp, model_file)
|
|
|
|
|
tr0_pipe = open("/tmp/tr0_err.log", "w")
|
|
|
|
|
tr1_pipe = open("/tmp/tr1_err.log", "w")
|
|
|
|
|
tr0_pipe = open("/tmp/tr0_err_%d.log" % os.getpid(), "w")
|
|
|
|
|
tr1_pipe = open("/tmp/tr1_err_%d.log" % os.getpid(), "w")
|
|
|
|
|
#print(tr0_cmd)
|
|
|
|
|
tr0_proc = subprocess.Popen(
|
|
|
|
|
tr0_cmd.strip().split(),
|
|
|
|
@ -179,9 +155,9 @@ class TestDistBase(unittest.TestCase):
|
|
|
|
|
# close trainer file
|
|
|
|
|
tr0_pipe.close()
|
|
|
|
|
tr1_pipe.close()
|
|
|
|
|
with open("/tmp/tr0_err.log", "r") as f:
|
|
|
|
|
with open("/tmp/tr0_err_%d.log" % os.getpid(), "r") as f:
|
|
|
|
|
sys.stderr.write('trainer 0 stderr file: %s\n' % f.read())
|
|
|
|
|
with open("/tmp/tr1_err.log", "r") as f:
|
|
|
|
|
with open("/tmp/tr1_err_%d.log" % os.getpid(), "r") as f:
|
|
|
|
|
sys.stderr.write('trainer 1 stderr file: %s\n' % f.read())
|
|
|
|
|
return pickle.loads(tr0_out), pickle.loads(
|
|
|
|
|
tr1_out), tr0_proc.pid, tr1_proc.pid
|
|
|
|
|