|
|
|
@ -442,10 +442,10 @@ class TestDistBase(unittest.TestCase):
|
|
|
|
|
tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method nccl2 --lr %f"
|
|
|
|
|
tr0_cmd = tr_cmd % \
|
|
|
|
|
(self._python_interp, model, self._ps_endpoints,
|
|
|
|
|
0, w0_ep, self._lr / 2)
|
|
|
|
|
0, w0_ep, self._lr)
|
|
|
|
|
tr1_cmd = tr_cmd % \
|
|
|
|
|
(self._python_interp, model, self._ps_endpoints,
|
|
|
|
|
1, w1_ep, self._lr / 2)
|
|
|
|
|
1, w1_ep, self._lr)
|
|
|
|
|
|
|
|
|
|
if self._mem_opt:
|
|
|
|
|
tr0_cmd += " --mem_opt"
|
|
|
|
|