|
|
|
@ -16,7 +16,7 @@ import contextlib
|
|
|
|
|
from layer_function_generator import autodoc
|
|
|
|
|
from tensor import assign, fill_constant
|
|
|
|
|
from .. import core
|
|
|
|
|
from ..framework import Program, Variable, Operator, Block
|
|
|
|
|
from ..framework import Program, Variable, Operator
|
|
|
|
|
from ..layer_helper import LayerHelper, unique_name
|
|
|
|
|
from ops import logical_and, logical_not, logical_or
|
|
|
|
|
|
|
|
|
@ -29,7 +29,6 @@ __all__ = [
|
|
|
|
|
'WhileGuard',
|
|
|
|
|
'While',
|
|
|
|
|
'Switch',
|
|
|
|
|
'Select',
|
|
|
|
|
'lod_rank_table',
|
|
|
|
|
'max_sequence_len',
|
|
|
|
|
'topk',
|
|
|
|
@ -1212,186 +1211,6 @@ class Switch(object):
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SelectCase(object):
|
|
|
|
|
DEFAULT = 0
|
|
|
|
|
SEND = 1
|
|
|
|
|
RECEIVE = 2
|
|
|
|
|
|
|
|
|
|
def __init__(self,
|
|
|
|
|
case_idx,
|
|
|
|
|
case_to_execute,
|
|
|
|
|
channel_action_fn=None,
|
|
|
|
|
channel=None,
|
|
|
|
|
value=None):
|
|
|
|
|
self.helper = LayerHelper('conditional_block')
|
|
|
|
|
self.main_program = self.helper.main_program
|
|
|
|
|
self.is_scalar_condition = True
|
|
|
|
|
|
|
|
|
|
self.case_to_execute = case_to_execute
|
|
|
|
|
self.idx = case_idx
|
|
|
|
|
|
|
|
|
|
# Since we aren't going to use the `channel_send` or `channel_recv`
|
|
|
|
|
# functions directly, we just need to capture the name.
|
|
|
|
|
self.action = (self.SEND
|
|
|
|
|
if channel_action_fn.__name__ == ('channel_send') else
|
|
|
|
|
self.RECEIVE) if channel_action_fn else (self.DEFAULT)
|
|
|
|
|
self.value = value
|
|
|
|
|
self.channel = channel
|
|
|
|
|
|
|
|
|
|
def __enter__(self):
|
|
|
|
|
self.block = self.main_program.create_block()
|
|
|
|
|
|
|
|
|
|
def construct_op(self):
|
|
|
|
|
main_program = self.helper.main_program
|
|
|
|
|
cases_block = main_program.current_block()
|
|
|
|
|
|
|
|
|
|
inner_outputs = set()
|
|
|
|
|
input_set = set()
|
|
|
|
|
params = set()
|
|
|
|
|
|
|
|
|
|
for op in self.block.ops:
|
|
|
|
|
# Iterate over all operators, get all the inputs
|
|
|
|
|
# and add as input to the SelectCase operator.
|
|
|
|
|
for iname in op.input_names:
|
|
|
|
|
for in_var_name in op.input(iname):
|
|
|
|
|
if in_var_name not in inner_outputs:
|
|
|
|
|
input_set.add(in_var_name)
|
|
|
|
|
|
|
|
|
|
for oname in op.output_names:
|
|
|
|
|
for out_var_name in op.output(oname):
|
|
|
|
|
inner_outputs.add(out_var_name)
|
|
|
|
|
|
|
|
|
|
param_list = [
|
|
|
|
|
cases_block.var(each_name) for each_name in params
|
|
|
|
|
if each_name not in input_set
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
# Iterate over all operators, get all the outputs
|
|
|
|
|
# add to the output list of SelectCase operator only if
|
|
|
|
|
# they exist in the parent block.
|
|
|
|
|
out_vars = []
|
|
|
|
|
for inner_out_name in inner_outputs:
|
|
|
|
|
if inner_out_name in cases_block.vars:
|
|
|
|
|
out_vars.append(cases_block.var(inner_out_name))
|
|
|
|
|
|
|
|
|
|
# First, create an op that will determine whether or not this is the
|
|
|
|
|
# conditional variable to execute.
|
|
|
|
|
should_execute_block = equal(
|
|
|
|
|
fill_constant(
|
|
|
|
|
shape=[1], dtype=core.VarDesc.VarType.INT32, value=self.idx),
|
|
|
|
|
self.case_to_execute)
|
|
|
|
|
|
|
|
|
|
step_scope = cases_block.create_var(
|
|
|
|
|
type=core.VarDesc.VarType.STEP_SCOPES)
|
|
|
|
|
|
|
|
|
|
cases_block.append_op(
|
|
|
|
|
type='conditional_block',
|
|
|
|
|
inputs={'X': [should_execute_block],
|
|
|
|
|
'Params': param_list},
|
|
|
|
|
outputs={'Out': out_vars,
|
|
|
|
|
'Scope': [step_scope]},
|
|
|
|
|
attrs={
|
|
|
|
|
'sub_block': self.block,
|
|
|
|
|
'is_scalar_condition': self.is_scalar_condition
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
return '%s,%s,%s,%s' % (self.idx, self.action, self.channel.name
|
|
|
|
|
if self.channel else '', self.value.name
|
|
|
|
|
if self.value else '')
|
|
|
|
|
|
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
|
|
|
self.main_program.rollback()
|
|
|
|
|
if exc_type is not None:
|
|
|
|
|
return False # re-raise exception
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Select(BlockGuard):
|
|
|
|
|
def __init__(self, name=None):
|
|
|
|
|
self.helper = LayerHelper('select', name=name)
|
|
|
|
|
self.cases = []
|
|
|
|
|
|
|
|
|
|
super(Select, self).__init__(self.helper.main_program)
|
|
|
|
|
self.case_to_execute = fill_constant(
|
|
|
|
|
shape=[1], dtype=core.VarDesc.VarType.INT32, value=-1)
|
|
|
|
|
|
|
|
|
|
def __enter__(self):
|
|
|
|
|
super(Select, self).__enter__()
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
def case(self, channel_action_fn, channel, value):
|
|
|
|
|
"""Create a new block for this condition.
|
|
|
|
|
"""
|
|
|
|
|
select_case = SelectCase(
|
|
|
|
|
len(self.cases), self.case_to_execute, channel_action_fn, channel,
|
|
|
|
|
value)
|
|
|
|
|
|
|
|
|
|
self.cases.append(select_case)
|
|
|
|
|
|
|
|
|
|
return select_case
|
|
|
|
|
|
|
|
|
|
def default(self):
|
|
|
|
|
"""Create a default case block for this condition.
|
|
|
|
|
"""
|
|
|
|
|
default_case = SelectCase(len(self.cases), self.case_to_execute)
|
|
|
|
|
|
|
|
|
|
self.cases.append(default_case)
|
|
|
|
|
|
|
|
|
|
return default_case
|
|
|
|
|
|
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
|
|
|
if exc_type is not None:
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
# Create a select op and another block to wrap its
|
|
|
|
|
# case blocks.
|
|
|
|
|
select_block = self.helper.main_program.current_block()
|
|
|
|
|
parent_block = self.helper.main_program.block(select_block.parent_idx)
|
|
|
|
|
|
|
|
|
|
# Construct each case op, inside the newly created select block.
|
|
|
|
|
serialized_cases = []
|
|
|
|
|
for case in self.cases:
|
|
|
|
|
serialized_cases.append(case.construct_op())
|
|
|
|
|
|
|
|
|
|
intermediate = set()
|
|
|
|
|
params = set()
|
|
|
|
|
|
|
|
|
|
for case_block in select_block.ops:
|
|
|
|
|
if case_block.attrs and 'sub_block' in case_block.attrs:
|
|
|
|
|
for each_op in case_block.attrs['sub_block'].ops:
|
|
|
|
|
assert isinstance(each_op, Operator)
|
|
|
|
|
for iname in each_op.input_names:
|
|
|
|
|
for in_var_name in each_op.input(iname):
|
|
|
|
|
if in_var_name not in intermediate:
|
|
|
|
|
params.add(in_var_name)
|
|
|
|
|
|
|
|
|
|
for oname in each_op.output_names:
|
|
|
|
|
for out_var_name in each_op.output(oname):
|
|
|
|
|
intermediate.add(out_var_name)
|
|
|
|
|
|
|
|
|
|
# TODO(varunarora): Figure out if defining output is needed.
|
|
|
|
|
out_list = [
|
|
|
|
|
parent_block.var(var_name) for var_name in parent_block.vars
|
|
|
|
|
if var_name in intermediate
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
X = [select_block.var_recursive(x_name) for x_name in params]
|
|
|
|
|
|
|
|
|
|
# Needs to be used by `equal` inside the cases block.
|
|
|
|
|
X.append(self.case_to_execute)
|
|
|
|
|
|
|
|
|
|
# Construct the select op.
|
|
|
|
|
parent_block.append_op(
|
|
|
|
|
type='select',
|
|
|
|
|
inputs={'X': X,
|
|
|
|
|
'case_to_execute': self.case_to_execute},
|
|
|
|
|
attrs={'sub_block': select_block,
|
|
|
|
|
'cases': serialized_cases},
|
|
|
|
|
outputs={})
|
|
|
|
|
|
|
|
|
|
return super(Select, self).__exit__(exc_type, exc_val, exc_tb)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IfElseBlockGuard(object):
|
|
|
|
|
def __init__(self, is_true, ifelse):
|
|
|
|
|
if not isinstance(ifelse, IfElse):
|
|
|
|
|