@ -82,11 +82,14 @@ class SelectCase(object):
RECEIVE = 2
def __init__ ( self ,
select ,
case_idx ,
case_to_execute ,
channel_action_fn = None ,
channel = None ,
value = None ) :
value = None ,
is_copy = False ) :
self . select = select
self . helper = LayerHelper ( ' conditional_block ' )
self . main_program = self . helper . main_program
self . is_scalar_condition = True
@ -99,7 +102,24 @@ class SelectCase(object):
self . action = ( self . SEND
if channel_action_fn . __name__ == ( ' channel_send ' ) else
self . RECEIVE ) if channel_action_fn else self . DEFAULT
self . value = value
X = value
if self . action == self . SEND and is_copy :
# We create of copy of the data we want to send
copied_X = self . select . parent_block . create_var (
name = unique_name . generate ( value . name + ' _copy ' ) ,
type = value . type ,
dtype = value . dtype ,
shape = value . shape ,
lod_level = value . lod_level ,
capacity = value . capacity
if hasattr ( value , ' capacity ' ) else None , )
self . select . parent_block . append_op (
type = " assign " , inputs = { " X " : value } , outputs = { " Out " : copied_X } )
X = copied_X
self . value = X
self . channel = channel
def __enter__ ( self ) :
@ -173,6 +193,7 @@ class SelectCase(object):
class Select ( BlockGuard ) :
def __init__ ( self , name = None ) :
self . helper = LayerHelper ( ' select ' , name = name )
self . parent_block = self . helper . main_program . current_block ( )
self . cases = [ ]
super ( Select , self ) . __init__ ( self . helper . main_program )
@ -183,12 +204,12 @@ class Select(BlockGuard):
super ( Select , self ) . __enter__ ( )
return self
def case ( self , channel_action_fn , channel , value ):
def case ( self , channel_action_fn , channel , value , is_copy = False ):
""" Create a new block for this condition.
"""
select_case = SelectCase (
len ( self . cases ) , self . case_to_execute , channel_action_fn , channel ,
value )
select_case = SelectCase ( self ,
len ( self . cases ) , self . case_to_execute ,
channel_action_fn , channel , value , is_copy )
self . cases . append ( select_case )
@ -197,7 +218,7 @@ class Select(BlockGuard):
def default ( self ) :
""" Create a default case block for this condition.
"""
default_case = SelectCase ( len ( self . cases ) , self . case_to_execute )
default_case = SelectCase ( self , len ( self . cases ) , self . case_to_execute )
self . cases . append ( default_case )
@ -341,17 +362,17 @@ def channel_send(channel, value, is_copy=False):
X = value
if is_copy is True :
if is_copy :
copied_X = helper . create_variable (
name = unique_name . generate ( value . name + ' _copy ' ) ,
type = value . type ,
dtype = value . dtype ,
shape = value . shape ,
lod_level = value . lod_level ,
capacity = value . capacity )
capacity = value . capacity if hasattr ( value , ' capacity ' ) else None )
assign_op = channel_send_block . append_op (
type = " assign _op " , inputs = { " X " : value } , outputs = { " Out " : copied_X } )
type = " assign " , inputs = { " X " : value } , outputs = { " Out " : copied_X } )
X = copied_X
channel_send_block . append_op (