|
|
@ -438,7 +438,7 @@ class TestLocalLookupTable(TestDistLookupTableBase):
|
|
|
|
# 2 optimize for table adam
|
|
|
|
# 2 optimize for table adam
|
|
|
|
# NOTE: if param is not selected rows, the grad will scaled to grad / trainer_num
|
|
|
|
# NOTE: if param is not selected rows, the grad will scaled to grad / trainer_num
|
|
|
|
self.assertEqual([op.type for op in pserver1.blocks[2].ops],
|
|
|
|
self.assertEqual([op.type for op in pserver1.blocks[2].ops],
|
|
|
|
["sum", "adam", "scale", "scale"])
|
|
|
|
["sum", "scale", "adam", "scale", "scale"])
|
|
|
|
|
|
|
|
|
|
|
|
trainer, _ = self.get_trainer()
|
|
|
|
trainer, _ = self.get_trainer()
|
|
|
|
self.assertEqual(len(trainer.blocks), 1)
|
|
|
|
self.assertEqual(len(trainer.blocks), 1)
|
|
|
|