|
|
|
@ -217,6 +217,57 @@ class TestRoutineOp(unittest.TestCase):
|
|
|
|
|
exe_result = exe.run(fetch_list=[result])
|
|
|
|
|
self.assertEqual(exe_result[0][0], 34)
|
|
|
|
|
|
|
|
|
|
def test_ping_pong(self):
|
|
|
|
|
"""
|
|
|
|
|
Mimics Ping Pong example: https://gobyexample.com/channel-directions
|
|
|
|
|
"""
|
|
|
|
|
with framework.program_guard(framework.Program()):
|
|
|
|
|
result = self._create_tensor('return_value',
|
|
|
|
|
core.VarDesc.VarType.LOD_TENSOR,
|
|
|
|
|
core.VarDesc.VarType.FP64)
|
|
|
|
|
|
|
|
|
|
ping_result = self._create_tensor('ping_return_value',
|
|
|
|
|
core.VarDesc.VarType.LOD_TENSOR,
|
|
|
|
|
core.VarDesc.VarType.FP64)
|
|
|
|
|
|
|
|
|
|
pong_result = self._create_tensor('pong_return_value',
|
|
|
|
|
core.VarDesc.VarType.LOD_TENSOR,
|
|
|
|
|
core.VarDesc.VarType.FP64)
|
|
|
|
|
|
|
|
|
|
def ping(ch, message):
|
|
|
|
|
message_to_send_tmp = fill_constant(
|
|
|
|
|
shape=[1], dtype=core.VarDesc.VarType.FP64, value=0)
|
|
|
|
|
|
|
|
|
|
assign(input=message, output=message_to_send_tmp)
|
|
|
|
|
fluid.channel_send(ch, message_to_send_tmp)
|
|
|
|
|
|
|
|
|
|
def pong(ch1, ch2):
|
|
|
|
|
fluid.channel_recv(ch1, ping_result)
|
|
|
|
|
assign(input=ping_result, output=pong_result)
|
|
|
|
|
fluid.channel_send(ch2, pong_result)
|
|
|
|
|
|
|
|
|
|
pings = fluid.make_channel(
|
|
|
|
|
dtype=core.VarDesc.VarType.LOD_TENSOR, capacity=1)
|
|
|
|
|
pongs = fluid.make_channel(
|
|
|
|
|
dtype=core.VarDesc.VarType.LOD_TENSOR, capacity=1)
|
|
|
|
|
|
|
|
|
|
msg = fill_constant(
|
|
|
|
|
shape=[1], dtype=core.VarDesc.VarType.FP64, value=9)
|
|
|
|
|
|
|
|
|
|
ping(pings, msg)
|
|
|
|
|
pong(pings, pongs)
|
|
|
|
|
|
|
|
|
|
fluid.channel_recv(pongs, result)
|
|
|
|
|
|
|
|
|
|
fluid.channel_close(pings)
|
|
|
|
|
fluid.channel_close(pongs)
|
|
|
|
|
|
|
|
|
|
cpu = core.CPUPlace()
|
|
|
|
|
exe = Executor(cpu)
|
|
|
|
|
|
|
|
|
|
exe_result = exe.run(fetch_list=[result])
|
|
|
|
|
self.assertEqual(exe_result[0][0], 9)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
unittest.main()
|
|
|
|
|