|
|
@ -51,17 +51,17 @@ class TranspilerTest(unittest.TestCase):
|
|
|
|
self.origin_prog = main.clone()
|
|
|
|
self.origin_prog = main.clone()
|
|
|
|
return main
|
|
|
|
return main
|
|
|
|
|
|
|
|
|
|
|
|
def get_trainer(self, config=None):
|
|
|
|
def get_trainer(self, config=None, sync_mode=True):
|
|
|
|
t = self._transpiler_instance(config)
|
|
|
|
t = self._transpiler_instance(config, sync_mode)
|
|
|
|
return t.get_trainer_program()
|
|
|
|
return t.get_trainer_program()
|
|
|
|
|
|
|
|
|
|
|
|
def get_pserver(self, ep, config=None):
|
|
|
|
def get_pserver(self, ep, config=None, sync_mode=True):
|
|
|
|
t = self._transpiler_instance(config)
|
|
|
|
t = self._transpiler_instance(config, sync_mode)
|
|
|
|
pserver = t.get_pserver_program(ep)
|
|
|
|
pserver = t.get_pserver_program(ep)
|
|
|
|
startup = t.get_startup_program(ep, pserver)
|
|
|
|
startup = t.get_startup_program(ep, pserver)
|
|
|
|
return pserver, startup
|
|
|
|
return pserver, startup
|
|
|
|
|
|
|
|
|
|
|
|
def _transpiler_instance(self, config=None):
|
|
|
|
def _transpiler_instance(self, config=None, sync_mode=True):
|
|
|
|
if not self.transpiler:
|
|
|
|
if not self.transpiler:
|
|
|
|
main = self.get_main_program()
|
|
|
|
main = self.get_main_program()
|
|
|
|
self.transpiler = fluid.DistributeTranspiler(config=config)
|
|
|
|
self.transpiler = fluid.DistributeTranspiler(config=config)
|
|
|
@ -69,7 +69,8 @@ class TranspilerTest(unittest.TestCase):
|
|
|
|
self.trainer_id,
|
|
|
|
self.trainer_id,
|
|
|
|
program=main,
|
|
|
|
program=main,
|
|
|
|
pservers=self.pserver_eps,
|
|
|
|
pservers=self.pserver_eps,
|
|
|
|
trainers=self.trainers)
|
|
|
|
trainers=self.trainers,
|
|
|
|
|
|
|
|
sync_mode=sync_mode)
|
|
|
|
|
|
|
|
|
|
|
|
return self.transpiler
|
|
|
|
return self.transpiler
|
|
|
|
|
|
|
|
|
|
|
@ -470,8 +471,7 @@ class TestAsyncLocalLookupTable(TestDistLookupTableBase):
|
|
|
|
|
|
|
|
|
|
|
|
def transpiler_test_impl(self):
|
|
|
|
def transpiler_test_impl(self):
|
|
|
|
config = fluid.DistributeTranspilerConfig()
|
|
|
|
config = fluid.DistributeTranspilerConfig()
|
|
|
|
config.sync_mode = False
|
|
|
|
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
|
|
|
|
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.assertEqual(len(pserver1.blocks), 3)
|
|
|
|
self.assertEqual(len(pserver1.blocks), 3)
|
|
|
|
# 0 listen_and_serv
|
|
|
|
# 0 listen_and_serv
|
|
|
@ -503,9 +503,8 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
|
|
|
|
|
|
|
|
|
|
|
|
def transpiler_test_impl(self):
|
|
|
|
def transpiler_test_impl(self):
|
|
|
|
config = fluid.DistributeTranspilerConfig()
|
|
|
|
config = fluid.DistributeTranspilerConfig()
|
|
|
|
config.sync_mode = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config)
|
|
|
|
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config, False)
|
|
|
|
|
|
|
|
|
|
|
|
self.assertEqual(len(pserver1.blocks), 6)
|
|
|
|
self.assertEqual(len(pserver1.blocks), 6)
|
|
|
|
# 0 listen_and_serv
|
|
|
|
# 0 listen_and_serv
|
|
|
@ -525,7 +524,6 @@ class TestAsyncDistLookupTable(TestDistLookupTableBase):
|
|
|
|
|
|
|
|
|
|
|
|
trainer = self.get_trainer(config)
|
|
|
|
trainer = self.get_trainer(config)
|
|
|
|
self.assertEqual(len(trainer.blocks), 1)
|
|
|
|
self.assertEqual(len(trainer.blocks), 1)
|
|
|
|
print([op.type for op in trainer.blocks[0].ops])
|
|
|
|
|
|
|
|
ops = [
|
|
|
|
ops = [
|
|
|
|
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
|
|
|
|
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
|
|
|
|
'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul',
|
|
|
|
'prefetch', 'merge_ids', 'sequence_pool', 'concat', 'mul',
|
|
|
|