@ -3609,18 +3609,18 @@ def switch_case(branch_index, branch_fns, default=None, name=None):
This operator is like a C + + switch / case statement .
Args :
branch_index ( Variable ) : A Tensor with shape [ 1 ] to specify which branch to execute . The data type is ` ` int32 ` ` , ` ` int64 ` ` or ` ` uint8 ` ` .
branch_index ( Tensor ) : A Tensor with shape [ 1 ] to specify which branch to execute . The data type is ` ` int32 ` ` , ` ` int64 ` ` or ` ` uint8 ` ` .
branch_fns ( dict | list | tuple ) : If it ' s a list or tuple, the elements in it could be pairs of (int, callable) or simple callables whose actual index will be used as the index of callable. If it ' s a dict , its key is a python integer and the value is a callable . All callables return the same structure of Tensors .
default ( callable , optional ) : Callable that returns a structure of Tensors .
name ( str , optional ) : The default value is None . Normally there is no need for user to set this property . For more information , please refer to : ref : ` api_guide_Name ` .
Returns :
Variable| list ( Variable ) : Tensors returned by the callable specified by ` ` branch_index ` ` in ` ` branch_fns ` ` ,
Tensor| list ( Tensor ) : Tensors returned by the callable specified by ` ` branch_index ` ` in ` ` branch_fns ` ` ,
or Tensors returned by ` ` default ` ` if ` ` default ` ` is not None and no index matches in ` ` branch_fns ` ` ,
or Tensors returned by the callable with the max index in ` ` branch_fns ` ` if ` ` default ` ` is None and no index matches in ` ` branch_fns ` ` .
Raises :
TypeError : If the type of ` ` branch_index ` ` is not Variable .
TypeError : If the type of ` ` branch_index ` ` is not Tensor .
TypeError : If the data type of ` ` branch_index ` ` is not ` ` int32 ` ` , ` ` int64 ` ` or ` ` uint8 ` ` .
TypeError : If the type of ` ` branch_fns ` ` is not dict , list or tuple .
TypeError : If the elements of ` ` branch_fns ` ` is not 2 - tuple .
@ -3632,40 +3632,41 @@ def switch_case(branch_index, branch_fns, default=None, name=None):
Examples :
. . code - block : : python
import paddle . fluid as fluid
import paddle . fluid . layers as layers
import paddle
paddle . enable_static ( )
def fn_1 ( ) :
return layers . fill_constant ( shape = [ 1 , 2 ] , dtype = ' float32 ' , value = 1 )
return paddle . fill_constant ( shape = [ 1 , 2 ] , dtype = ' float32 ' , value = 1 )
def fn_2 ( ) :
return layers . fill_constant ( shape = [ 2 , 2 ] , dtype = ' int32 ' , value = 2 )
return paddle . fill_constant ( shape = [ 2 , 2 ] , dtype = ' int32 ' , value = 2 )
def fn_3 ( ) :
return layers . fill_constant ( shape = [ 3 ] , dtype = ' int32 ' , value = 3 )
return paddle . fill_constant ( shape = [ 3 ] , dtype = ' int32 ' , value = 3 )
main_program = fluid . default_startup_program ( )
startup_program = fluid . default_main_program ( )
with fluid . program_guard ( main_program , startup_program ) :
index_1 = layers . fill_constant ( shape = [ 1 ] , dtype = ' int32 ' , value = 1 )
index_2 = layers . fill_constant ( shape = [ 1 ] , dtype = ' int32 ' , value = 2 )
main_program = paddle. static . default_startup_program ( )
startup_program = paddle. static . default_main_program ( )
with paddle. static . program_guard ( main_program , startup_program ) :
index_1 = paddle . fill_constant ( shape = [ 1 ] , dtype = ' int32 ' , value = 1 )
index_2 = paddle . fill_constant ( shape = [ 1 ] , dtype = ' int32 ' , value = 2 )
out_1 = layers . switch_case (
out_1 = paddle. static . nn . switch_case (
branch_index = index_1 ,
branch_fns = { 1 : fn_1 , 2 : fn_2 } ,
default = fn_3 )
out_2 = layers . switch_case (
out_2 = paddle. static . nn . switch_case (
branch_index = index_2 ,
branch_fns = [ ( 1 , fn_1 ) , ( 2 , fn_2 ) ] ,
default = fn_3 )
# Argument default is None and no index matches. fn_3 will be called because of the max index 7.
out_3 = layers . switch_case (
out_3 = paddle. static . nn . switch_case (
branch_index = index_2 ,
branch_fns = [ ( 0 , fn_1 ) , ( 4 , fn_2 ) , ( 7 , fn_3 ) ] )
exe = fluid. Executor ( fluid . CPUPlace ( ) )
exe = paddle. static . Executor ( paddle . CPUPlace ( ) )
res_1 , res_2 , res_3 = exe . run ( main_program , fetch_list = [ out_1 , out_2 , out_3 ] )
print ( res_1 ) # [[1. 1.]]
print ( res_2 ) # [[2 2] [2 2]]