|
|
|
@ -19,6 +19,7 @@ import math
|
|
|
|
|
|
|
|
|
|
import unittest
|
|
|
|
|
import os
|
|
|
|
|
import sys
|
|
|
|
|
import signal
|
|
|
|
|
import subprocess
|
|
|
|
|
|
|
|
|
@ -56,7 +57,7 @@ class TestDistSeResneXt2x2(unittest.TestCase):
|
|
|
|
|
except os.error:
|
|
|
|
|
retry_times -= 1
|
|
|
|
|
|
|
|
|
|
def no_test_with_place(self):
|
|
|
|
|
def test_with_place(self):
|
|
|
|
|
# *ATTENTION* THIS TEST NEEDS AT LEAST 2GPUS TO RUN
|
|
|
|
|
required_envs = {
|
|
|
|
|
"PATH": os.getenv("PATH"),
|
|
|
|
@ -70,9 +71,15 @@ class TestDistSeResneXt2x2(unittest.TestCase):
|
|
|
|
|
local_cmd = "%s dist_se_resnext.py trainer %s 0 %s %d FLASE" % \
|
|
|
|
|
(self._python_interp, "127.0.0.1:1234", "127.0.0.1:1234", 1)
|
|
|
|
|
local_proc = subprocess.Popen(
|
|
|
|
|
local_cmd.split(" "), stdout=subprocess.PIPE, env=env_local)
|
|
|
|
|
local_cmd.split(" "),
|
|
|
|
|
stdout=subprocess.PIPE,
|
|
|
|
|
stderr=subprocess.PIPE,
|
|
|
|
|
env=env_local)
|
|
|
|
|
local_proc.wait()
|
|
|
|
|
local_ret = local_proc.stdout.read()
|
|
|
|
|
out, err = local_proc.communicate()
|
|
|
|
|
local_ret = out
|
|
|
|
|
sys.stderr.write('local_loss: %s\n' % local_ret)
|
|
|
|
|
sys.stderr.write('local_stderr: %s\n' % err)
|
|
|
|
|
|
|
|
|
|
# Run dist train to compare with local results
|
|
|
|
|
ps0, ps1 = self.start_pserver()
|
|
|
|
@ -92,13 +99,22 @@ class TestDistSeResneXt2x2(unittest.TestCase):
|
|
|
|
|
FNULL = open(os.devnull, 'w')
|
|
|
|
|
|
|
|
|
|
tr0_proc = subprocess.Popen(
|
|
|
|
|
tr0_cmd.split(" "), stdout=subprocess.PIPE, stderr=FNULL, env=env0)
|
|
|
|
|
tr0_cmd.split(" "),
|
|
|
|
|
stdout=subprocess.PIPE,
|
|
|
|
|
stderr=subprocess.PIPE,
|
|
|
|
|
env=env0)
|
|
|
|
|
tr1_proc = subprocess.Popen(
|
|
|
|
|
tr1_cmd.split(" "), stdout=subprocess.PIPE, stderr=FNULL, env=env1)
|
|
|
|
|
tr1_cmd.split(" "),
|
|
|
|
|
stdout=subprocess.PIPE,
|
|
|
|
|
stderr=subprocess.PIPE,
|
|
|
|
|
env=env1)
|
|
|
|
|
|
|
|
|
|
tr0_proc.wait()
|
|
|
|
|
tr1_proc.wait()
|
|
|
|
|
loss_data0 = tr0_proc.stdout.read()
|
|
|
|
|
out, err = tr0_proc.communicate()
|
|
|
|
|
sys.stderr.write('dist_stderr: %s\n' % err)
|
|
|
|
|
loss_data0 = out
|
|
|
|
|
sys.stderr.write('dist_loss: %s\n' % loss_data0)
|
|
|
|
|
lines = loss_data0.split("\n")
|
|
|
|
|
dist_first_loss = eval(lines[0].replace(" ", ","))[0]
|
|
|
|
|
dist_last_loss = eval(lines[1].replace(" ", ","))[0]
|
|
|
|
|