add fake init test in test_dist_transpiler

fix_recordio_link
Qiao Longfei 6 years ago
parent a13c788a04
commit 68aeb4e7e9

@ -68,9 +68,10 @@ class FakeInitOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor) Tensor of specified shape will be filled "
"with the specified value");
AddComment(R"DOC(
FakeInitBatchSizeLike Operator.
FakeInit Operator.
Init an op but not alloc tensor for it, it is used for distributed lookup table.
Init an variable but not alloc memory for it, it is used for init the
table parameter at trainer side in distributed lookup table.
)DOC");
}

@ -497,7 +497,7 @@ class TestDistLookupTable(TestDistLookupTableBase):
# 5 save table
self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
trainer, _ = self.get_trainer()
trainer, trainer_startup = self.get_trainer()
self.assertEqual(len(trainer.blocks), 1)
ops = [
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool', 'split_ids',
@ -511,6 +511,16 @@ class TestDistLookupTable(TestDistLookupTableBase):
]
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
startup_ops = [
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'uniform_random', 'recv', 'recv',
'fetch_barrier', 'fake_init'
]
self.assertEqual([op.type for op in trainer_startup.blocks[0].ops],
startup_ops)
class TestAsyncLocalLookupTable(TestDistLookupTableBase):
def net_conf(self):

Loading…
Cancel
Save