Python implementation for a proposed Go Op. (#8434)
* Adding Python boilerplate code for Go op * Add very basic test case * Adding the python logic for go routine * Fix syntax * Changing test to notest * Rename Routine to Go * Combining GoGuard and Go in one class * Modify test * Adding fluid close channel * Fixing __init__.py for calling fluid.go() * Adding stubs for channel methods and updating test case * Removing import * * Adding imports from concurrencyemailweixu-patch-1
parent
89ead8d151
commit
74404fadcd
@ -0,0 +1,86 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# TODO: Variables: make_channel
|
||||
# TODO: Operators: send, close_channel, recv, go, select
|
||||
from layers.control_flow import BlockGuard
|
||||
from layer_helper import LayerHelper
|
||||
|
||||
__all__ = [
|
||||
'Go',
|
||||
'make_channel',
|
||||
'channel_send',
|
||||
'channel_recv',
|
||||
'channel_close',
|
||||
]
|
||||
|
||||
|
||||
class Go(BlockGuard):
|
||||
def __init__(self, name=None):
|
||||
self.helper = LayerHelper("go", name=name)
|
||||
super(Go, self).__init__(self.helper.main_program)
|
||||
|
||||
def __enter__(self):
|
||||
super(Go, self).__enter__()
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type is not None:
|
||||
return False
|
||||
self.construct_go_op()
|
||||
return super(Go, self).__exit__(exc_type, exc_val, exc_tb)
|
||||
|
||||
def construct_go_op(self):
|
||||
main_program = self.helper.main_program
|
||||
go_block = main_program.current_block()
|
||||
parent_block = main_program.block(main_program.current_block()
|
||||
.parent_idx)
|
||||
|
||||
x_name_list = set()
|
||||
out_vars = set()
|
||||
for op in go_block.ops:
|
||||
# Iterate over all operators, get all the inputs
|
||||
# and add as input to the Go operator.
|
||||
for iname in op.input_names:
|
||||
for in_var_name in op.input(iname):
|
||||
x_name_list.add(in_var_name)
|
||||
|
||||
# Iterate over all operators , get all the outputs
|
||||
# add to the output list of Go operator only if
|
||||
# they exist in the parent block.
|
||||
for oname in op.output_names:
|
||||
for out_var_name in op.output(oname):
|
||||
if out_var_name in parent_block.vars:
|
||||
out_vars.add(parent_block.var(out_var_name))
|
||||
|
||||
parent_block.append_op(
|
||||
type='go',
|
||||
inputs={'X': [parent_block.var(x_name) for x_name in x_name_list]},
|
||||
outputs={'Out': out_vars},
|
||||
attrs={'sub_block': go_block})
|
||||
|
||||
|
||||
def make_channel(dtype, size=0):
|
||||
return True
|
||||
|
||||
|
||||
def channel_send(channel, value):
|
||||
return True
|
||||
|
||||
|
||||
def channel_recv(channel):
|
||||
return True
|
||||
|
||||
|
||||
def channel_close(channel):
|
||||
return True
|
||||
@ -0,0 +1,38 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import paddle.v2.fluid as fluid
|
||||
import paddle.v2.fluid.core as core
|
||||
from paddle.v2.fluid.executor import Executor
|
||||
|
||||
|
||||
class TestRoutineOp(unittest.TestCase):
|
||||
def test_simple_routine(self):
|
||||
ch = fluid.make_channel(dtype=bool)
|
||||
with fluid.Go():
|
||||
fluid.channel_send(ch, True)
|
||||
|
||||
result = fluid.channel_recv(ch)
|
||||
fluid.channel_close(ch)
|
||||
|
||||
cpu = core.CPUPlace()
|
||||
exe = Executor(cpu)
|
||||
|
||||
outs = exe.run(fetch_list=[result])
|
||||
self.assertEqual(outs[0], True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Loading…
Reference in new issue