|
|
|
@ -81,7 +81,7 @@ class FleetDistHeterRunnerBase(object):
|
|
|
|
|
def build_strategy(self, args):
|
|
|
|
|
self.strategy = paddle.distributed.fleet.DistributedStrategy()
|
|
|
|
|
self.strategy.a_sync = True
|
|
|
|
|
|
|
|
|
|
self.strategy.a_sync_configs = {"launch_barrier": True}
|
|
|
|
|
return self.strategy
|
|
|
|
|
|
|
|
|
|
def build_optimizer(self, avg_cost, strategy):
|
|
|
|
@ -237,7 +237,10 @@ class TestFleetHeterBase(unittest.TestCase):
|
|
|
|
|
return heter0_proc, heter1_proc, heter0_pipe, heter1_pipe
|
|
|
|
|
|
|
|
|
|
def _run_cluster(self, model, envs):
|
|
|
|
|
env = {'GRAD_CLIP': str(self._grad_clip_mode)}
|
|
|
|
|
env = {
|
|
|
|
|
'GRAD_CLIP': str(self._grad_clip_mode),
|
|
|
|
|
'FLAGS_eager_delete_tensor_gb': str(-1)
|
|
|
|
|
}
|
|
|
|
|
python_path = self._python_interp
|
|
|
|
|
gloo_path = tempfile.mkdtemp()
|
|
|
|
|
|
|
|
|
@ -286,27 +289,6 @@ class TestFleetHeterBase(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
tr0_ret = tr0.returncode
|
|
|
|
|
tr1_ret = tr0.returncode
|
|
|
|
|
print("tr get returncode: {}".format(tr0_ret))
|
|
|
|
|
if tr0_ret != 0:
|
|
|
|
|
print(
|
|
|
|
|
"========================Error tr0_err begin==========================="
|
|
|
|
|
)
|
|
|
|
|
os.system("cat {}".format(tempfile.gettempdir() + "/tr0_err.log"))
|
|
|
|
|
print(
|
|
|
|
|
"========================Error tr0_err end==========================="
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if tr1_ret != 0:
|
|
|
|
|
print(
|
|
|
|
|
"========================Error tr1_err begin==========================="
|
|
|
|
|
)
|
|
|
|
|
os.system("cat {}".format(tempfile.gettempdir() + "/tr1_err.log"))
|
|
|
|
|
print(
|
|
|
|
|
"========================Error tr1_err end==========================="
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.assertEqual(tr0_ret, 0, "something wrong in tr0, please check")
|
|
|
|
|
self.assertEqual(tr1_ret, 0, "something wrong in tr1, please check")
|
|
|
|
|
|
|
|
|
|
# close trainer file
|
|
|
|
|
tr0_pipe.close()
|
|
|
|
@ -320,7 +302,8 @@ class TestFleetHeterBase(unittest.TestCase):
|
|
|
|
|
ps1.terminate()
|
|
|
|
|
heter0.terminate()
|
|
|
|
|
heter1.terminate()
|
|
|
|
|
|
|
|
|
|
self.assertEqual(tr0_ret, 0, "something wrong in tr0, please check")
|
|
|
|
|
self.assertEqual(tr1_ret, 0, "something wrong in tr1, please check")
|
|
|
|
|
shutil.rmtree(gloo_path)
|
|
|
|
|
return 0, 0
|
|
|
|
|
|
|
|
|
|