|
|
|
@ -870,9 +870,21 @@ class TestRemoteNce(TestDistLookupTableBase):
|
|
|
|
|
|
|
|
|
|
def transpiler_test_impl(self):
|
|
|
|
|
trainer, _ = self.get_trainer()
|
|
|
|
|
|
|
|
|
|
out_vars = ["nce_w.block0", "nce_w.block1"]
|
|
|
|
|
in_vars = ["nce_b.block0", "nce_b.block1"]
|
|
|
|
|
|
|
|
|
|
recv_var_names = []
|
|
|
|
|
|
|
|
|
|
for op in trainer.blocks[0].ops:
|
|
|
|
|
if op.type == "recv":
|
|
|
|
|
pass
|
|
|
|
|
for var in op.output("Out"):
|
|
|
|
|
recv_var_names.append(var)
|
|
|
|
|
|
|
|
|
|
for out_var in out_vars:
|
|
|
|
|
self.assertFalse(out_var in recv_var_names)
|
|
|
|
|
for in_var in in_vars:
|
|
|
|
|
self.assertTrue(in_var in recv_var_names)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|