|
|
|
@ -32,7 +32,7 @@ DEFAULT_BATCH_SIZE = 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestDistRunnerBase(object):
|
|
|
|
|
def get_model(self, batch_size=DEFAULT_BATCH_SIZE):
|
|
|
|
|
def get_model(self, batch_size=DEFAULT_BATCH_SIZE, lr=0.1):
|
|
|
|
|
raise NotImplementedError(
|
|
|
|
|
"get_model should be implemented by child classes.")
|
|
|
|
|
|
|
|
|
@ -56,6 +56,7 @@ class TestDistRunnerBase(object):
|
|
|
|
|
return t
|
|
|
|
|
|
|
|
|
|
def run_pserver(self, args):
|
|
|
|
|
self.lr = args.lr
|
|
|
|
|
self.get_model(batch_size=args.batch_size)
|
|
|
|
|
# NOTE: pserver should not call memory optimize
|
|
|
|
|
t = self.get_transpiler(args.trainer_id,
|
|
|
|
@ -71,6 +72,7 @@ class TestDistRunnerBase(object):
|
|
|
|
|
exe.run(pserver_prog)
|
|
|
|
|
|
|
|
|
|
def run_trainer(self, args):
|
|
|
|
|
self.lr = args.lr
|
|
|
|
|
test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
|
|
|
|
|
self.get_model(batch_size=args.batch_size)
|
|
|
|
|
|
|
|
|
@ -189,6 +191,7 @@ def runtime_main(test_class):
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
'--use_reader_alloc', action='store_true', required=False)
|
|
|
|
|
parser.add_argument('--batch_size', required=False, type=int, default=2)
|
|
|
|
|
parser.add_argument('--lr', required=False, type=float, default=0.001)
|
|
|
|
|
parser.add_argument(
|
|
|
|
|
'--batch_merge_repeat', required=False, type=int, default=1)
|
|
|
|
|
|
|
|
|
@ -234,6 +237,7 @@ class TestDistBase(unittest.TestCase):
|
|
|
|
|
self._dc_asgd = False # must use with async mode
|
|
|
|
|
self._use_reader_alloc = True
|
|
|
|
|
self._nccl2_mode = False
|
|
|
|
|
self._lr = 0.001
|
|
|
|
|
self._setup_config()
|
|
|
|
|
self._after_setup_config()
|
|
|
|
|
|
|
|
|
@ -284,7 +288,8 @@ class TestDistBase(unittest.TestCase):
|
|
|
|
|
batch_size=DEFAULT_BATCH_SIZE,
|
|
|
|
|
batch_merge_repeat=1):
|
|
|
|
|
|
|
|
|
|
cmd = "%s %s --role trainer" % (self._python_interp, model)
|
|
|
|
|
cmd = "%s %s --role trainer --lr %f" % (self._python_interp, model,
|
|
|
|
|
self._lr)
|
|
|
|
|
if batch_size != DEFAULT_BATCH_SIZE:
|
|
|
|
|
cmd += " --batch_size %d" % batch_size
|
|
|
|
|
if batch_merge_repeat > 1:
|
|
|
|
@ -330,13 +335,13 @@ class TestDistBase(unittest.TestCase):
|
|
|
|
|
|
|
|
|
|
ps0_ep, ps1_ep = self._ps_endpoints.split(",")
|
|
|
|
|
|
|
|
|
|
tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --update_method pserver"
|
|
|
|
|
tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --update_method pserver --lr %f"
|
|
|
|
|
tr0_cmd = tr_cmd % \
|
|
|
|
|
(self._python_interp, model, self._ps_endpoints,
|
|
|
|
|
0, ps0_ep, self._trainers)
|
|
|
|
|
0, ps0_ep, self._trainers, self._lr)
|
|
|
|
|
tr1_cmd = tr_cmd % \
|
|
|
|
|
(self._python_interp, model, self._ps_endpoints,
|
|
|
|
|
1, ps1_ep, self._trainers)
|
|
|
|
|
1, ps1_ep, self._trainers, self._lr)
|
|
|
|
|
|
|
|
|
|
if self._sync_mode:
|
|
|
|
|
tr0_cmd += " --sync_mode"
|
|
|
|
@ -425,13 +430,13 @@ class TestDistBase(unittest.TestCase):
|
|
|
|
|
worker_endpoints = self._ps_endpoints.split(",")
|
|
|
|
|
w0_ep, w1_ep = worker_endpoints
|
|
|
|
|
|
|
|
|
|
tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method nccl2"
|
|
|
|
|
tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --update_method nccl2 --lr %f"
|
|
|
|
|
tr0_cmd = tr_cmd % \
|
|
|
|
|
(self._python_interp, model, self._ps_endpoints,
|
|
|
|
|
0, w0_ep)
|
|
|
|
|
0, w0_ep, self._lr / 2)
|
|
|
|
|
tr1_cmd = tr_cmd % \
|
|
|
|
|
(self._python_interp, model, self._ps_endpoints,
|
|
|
|
|
1, w1_ep)
|
|
|
|
|
1, w1_ep, self._lr / 2)
|
|
|
|
|
|
|
|
|
|
if self._mem_opt:
|
|
|
|
|
tr0_cmd += " --mem_opt"
|
|
|
|
|