You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
87 lines
2.7 KiB
87 lines
2.7 KiB
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
# TODO: Variables: make_channel
|
|
# TODO: Operators: send, close_channel, recv, go, select
|
|
from layers.control_flow import BlockGuard
|
|
from layer_helper import LayerHelper
|
|
|
|
# Public names exported by ``from <this module> import *``: the Go block
# guard and the (placeholder) channel helper functions.
__all__ = [
    'Go', 'make_channel', 'channel_send', 'channel_recv', 'channel_close'
]
|
|
|
|
|
|
class Go(BlockGuard):
    """Block guard that wraps the ops built in its ``with`` body into a ``go`` op.

    Usage::

        with Go():
            ...  # ops created here are recorded in the go sub-block

    On a clean exit, ``construct_go_op`` appends a ``go`` operator to the
    parent block whose ``sub_block`` attribute is the block that was current
    inside the ``with`` body. (What the ``go`` operator does at run time --
    presumably executing the sub-block concurrently -- is defined by the
    operator implementation, not here.)

    Args:
        name: Optional name forwarded to ``LayerHelper("go", name=name)``.
    """

    def __init__(self, name=None):
        self.helper = LayerHelper("go", name=name)
        super(Go, self).__init__(self.helper.main_program)

    def __enter__(self):
        super(Go, self).__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is not None:
            # Still let BlockGuard unwind (assumed: it rolls the program back
            # to the parent block) so a failing body does not leave the
            # program positioned inside the go sub-block. Returning False
            # lets the original exception propagate.
            super(Go, self).__exit__(exc_type, exc_val, exc_tb)
            return False
        # Must run before BlockGuard.__exit__: it reads the go block as the
        # program's current block.
        self.construct_go_op()
        return super(Go, self).__exit__(exc_type, exc_val, exc_tb)

    def construct_go_op(self):
        """Append a ``go`` operator for the current (go) block to its parent.

        Inputs (``X``) are the variables read inside the go block that are not
        produced by an earlier op of the same block; they are looked up in the
        parent block. Outputs (``Out``) are the variables written inside the
        go block that also exist in the parent block. Both lists keep
        first-seen order so the generated program is deterministic.
        """
        main_program = self.helper.main_program
        go_block = main_program.current_block()
        parent_block = main_program.block(go_block.parent_idx)

        # Names written by ops of the go block seen so far.
        inner_outputs = set()

        x_name_list = []
        seen_inputs = set()
        out_vars = []
        seen_outputs = set()

        for op in go_block.ops:
            # Every variable an op reads becomes an input of the go operator,
            # unless an earlier op inside the go block produced it: such local
            # variables are not in the parent block, and parent_block.var()
            # would fail on them. Inputs are handled before this op's own
            # outputs so in-place ops (read and write the same var) count.
            for iname in op.input_names:
                for in_var_name in op.input(iname):
                    if (in_var_name in inner_outputs or
                            in_var_name in seen_inputs):
                        continue
                    seen_inputs.add(in_var_name)
                    x_name_list.append(in_var_name)

            # A written variable is an output of the go operator only if it
            # also exists in the parent block.
            for oname in op.output_names:
                for out_var_name in op.output(oname):
                    inner_outputs.add(out_var_name)
                    if (out_var_name in parent_block.vars and
                            out_var_name not in seen_outputs):
                        seen_outputs.add(out_var_name)
                        out_vars.append(parent_block.var(out_var_name))

        parent_block.append_op(
            type='go',
            inputs={'X': [parent_block.var(x_name) for x_name in x_name_list]},
            # Pass a list (not a set) of Variables, as operator outputs expect.
            outputs={'Out': out_vars},
            attrs={'sub_block': go_block})
|
|
|
|
|
|
def make_channel(dtype, size=0):
    """Create a channel (placeholder -- not implemented yet).

    Args:
        dtype: Intended element data type of the channel; currently unused.
        size: Intended channel capacity (presumably 0 means unbuffered --
            confirm once implemented); currently unused.

    Returns:
        Always ``True``; both arguments are ignored for now.
    """
    placeholder_result = True
    return placeholder_result
|
|
|
|
|
|
def channel_send(channel, value):
    """Send ``value`` on ``channel`` (placeholder -- not implemented yet).

    Args:
        channel: Target channel; currently unused.
        value: Value to send; currently unused.

    Returns:
        Always ``True``; no send takes place yet.
    """
    sent = True
    return sent
|
|
|
|
|
|
def channel_recv(channel):
    """Receive from ``channel`` (placeholder -- not implemented yet).

    Args:
        channel: Source channel; currently unused.

    Returns:
        Always ``True`` -- not a received value; nothing is read yet.
    """
    received = True
    return received
|
|
|
|
|
|
def channel_close(channel):
    """Close ``channel`` (placeholder -- not implemented yet).

    Args:
        channel: Channel to close; currently unused.

    Returns:
        Always ``True``; nothing is closed yet.
    """
    closed = True
    return closed
|