|
|
|
@ -49,7 +49,6 @@ class TestDistTranspiler(unittest.TestCase):
|
|
|
|
|
def test_transpiler(self):
|
|
|
|
|
trainer = self.get_trainer()
|
|
|
|
|
pserver, startup = self.get_pserver(self.current_pserver_ep)
|
|
|
|
|
|
|
|
|
|
self.assertEqual([op.type for op in trainer.global_block().ops],
|
|
|
|
|
self.get_expect_trainer_ops())
|
|
|
|
|
|
|
|
|
@ -67,7 +66,7 @@ class TestDistTranspiler(unittest.TestCase):
|
|
|
|
|
"fill_constant", "fill_constant", "uniform_random", "uniform_random"
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
# the variable #fc_w will be split into two blocks
|
|
|
|
|
# the variable #fc_w will be split into two blocks
|
|
|
|
|
fc_w_var = startup.global_block().var("fc_w.block1")
|
|
|
|
|
self.assertEqual(fc_w_var.shape, (500, 1000))
|
|
|
|
|
|
|
|
|
@ -86,8 +85,12 @@ class TestDistTranspiler(unittest.TestCase):
|
|
|
|
|
optimize_ops, params_grads = self.net_conf()
|
|
|
|
|
|
|
|
|
|
delete_ops(trainer.global_block(), optimize_ops)
|
|
|
|
|
return [op.type for op in trainer.global_block().ops
|
|
|
|
|
] + ["split_byref", "send", "concat"]
|
|
|
|
|
ops = [op.type for op in trainer.global_block().ops] + [
|
|
|
|
|
"split_byref", "send_vars", "send_barrier", "recv", "recv",
|
|
|
|
|
"fetch_barrier", "concat"
|
|
|
|
|
]
|
|
|
|
|
ops.insert(ops.index("elementwise_add_grad") + 1, "send_vars")
|
|
|
|
|
return ops
|
|
|
|
|
|
|
|
|
|
def get_trainer(self):
|
|
|
|
|
return self._transpiler_instance().get_trainer_program()
|
|
|
|
|