|
|
|
@ -2488,9 +2488,6 @@ def _error_message(what, arg_name, op_name, right_value, error_value):
|
|
|
|
|
def case(pred_fn_pairs, default=None, name=None):
|
|
|
|
|
'''
|
|
|
|
|
:api_attr: Static Graph
|
|
|
|
|
:alias_main: paddle.nn.case
|
|
|
|
|
:alias: paddle.nn.case,paddle.nn.control_flow.case
|
|
|
|
|
:old_api: paddle.fluid.layers.case
|
|
|
|
|
|
|
|
|
|
This operator works like an if-elif-elif-else chain.
|
|
|
|
|
|
|
|
|
@ -2500,7 +2497,7 @@ def case(pred_fn_pairs, default=None, name=None):
|
|
|
|
|
name(str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Variable|list(Variable): Tensors returned by the callable from the first pair whose pred is True,
|
|
|
|
|
Tensor|list(Tensor): Tensors returned by the callable from the first pair whose pred is True,
|
|
|
|
|
or Tensors returned by ``default`` if no pred in ``pred_fn_pairs`` is True and ``default`` is not None,
|
|
|
|
|
or Tensors returned by the last callable in ``pred_fn_pairs`` if no pred in ``pred_fn_pairs`` is True and ``default`` is None.
|
|
|
|
|
|
|
|
|
@ -2508,45 +2505,47 @@ def case(pred_fn_pairs, default=None, name=None):
|
|
|
|
|
TypeError: If the type of ``pred_fn_pairs`` is not list or tuple.
|
|
|
|
|
TypeError: If the type of elements in ``pred_fn_pairs`` is not tuple.
|
|
|
|
|
TypeError: If the size of tuples in ``pred_fn_pairs`` is not 2.
|
|
|
|
|
TypeError: If the first element of 2-tuple in ``pred_fn_pairs`` is not Variable.
|
|
|
|
|
TypeError: If the first element of 2-tuple in ``pred_fn_pairs`` is not a Tensor.
|
|
|
|
|
TypeError: If the second element of 2-tuple in ``pred_fn_pairs`` is not callable.
|
|
|
|
|
TypeError: If ``default`` is not None but it is not callable.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
import paddle.fluid.layers as layers
|
|
|
|
|
import paddle
|
|
|
|
|
|
|
|
|
|
paddle.enable_static()
|
|
|
|
|
|
|
|
|
|
def fn_1():
|
|
|
|
|
return layers.fill_constant(shape=[1, 2], dtype='float32', value=1)
|
|
|
|
|
return paddle.fill_constant(shape=[1, 2], dtype='float32', value=1)
|
|
|
|
|
|
|
|
|
|
def fn_2():
|
|
|
|
|
return layers.fill_constant(shape=[2, 2], dtype='int32', value=2)
|
|
|
|
|
return paddle.fill_constant(shape=[2, 2], dtype='int32', value=2)
|
|
|
|
|
|
|
|
|
|
def fn_3():
|
|
|
|
|
return layers.fill_constant(shape=[3], dtype='int32', value=3)
|
|
|
|
|
return paddle.fill_constant(shape=[3], dtype='int32', value=3)
|
|
|
|
|
|
|
|
|
|
main_program = fluid.default_startup_program()
|
|
|
|
|
startup_program = fluid.default_main_program()
|
|
|
|
|
with fluid.program_guard(main_program, startup_program):
|
|
|
|
|
x = layers.fill_constant(shape=[1], dtype='float32', value=0.3)
|
|
|
|
|
y = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
|
|
|
|
|
z = layers.fill_constant(shape=[1], dtype='float32', value=0.2)
|
|
|
|
|
main_program = paddle.static.default_startup_program()
|
|
|
|
|
startup_program = paddle.static.default_main_program()
|
|
|
|
|
|
|
|
|
|
pred_1 = layers.less_than(z, x) # true: 0.2 < 0.3
|
|
|
|
|
pred_2 = layers.less_than(x, y) # false: 0.3 < 0.1
|
|
|
|
|
pred_3 = layers.equal(x, y) # false: 0.3 == 0.1
|
|
|
|
|
with paddle.static.program_guard(main_program, startup_program):
|
|
|
|
|
x = paddle.fill_constant(shape=[1], dtype='float32', value=0.3)
|
|
|
|
|
y = paddle.fill_constant(shape=[1], dtype='float32', value=0.1)
|
|
|
|
|
z = paddle.fill_constant(shape=[1], dtype='float32', value=0.2)
|
|
|
|
|
|
|
|
|
|
pred_1 = paddle.less_than(z, x) # true: 0.2 < 0.3
|
|
|
|
|
pred_2 = paddle.less_than(x, y) # false: 0.3 < 0.1
|
|
|
|
|
pred_3 = paddle.equal(x, y) # false: 0.3 == 0.1
|
|
|
|
|
|
|
|
|
|
# Call fn_1 because pred_1 is True
|
|
|
|
|
out_1 = layers.case(
|
|
|
|
|
out_1 = paddle.static.nn.case(
|
|
|
|
|
pred_fn_pairs=[(pred_1, fn_1), (pred_2, fn_2)], default=fn_3)
|
|
|
|
|
|
|
|
|
|
# Argument default is None and no pred in pred_fn_pairs is True. fn_3 will be called.
|
|
|
|
|
# because fn_3 is the last callable in pred_fn_pairs.
|
|
|
|
|
out_2 = layers.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)])
|
|
|
|
|
out_2 = paddle.static.nn.case(pred_fn_pairs=[(pred_2, fn_2), (pred_3, fn_3)])
|
|
|
|
|
|
|
|
|
|
exe = fluid.Executor(fluid.CPUPlace())
|
|
|
|
|
exe = paddle.static.Executor(paddle.CPUPlace())
|
|
|
|
|
res_1, res_2 = exe.run(main_program, fetch_list=[out_1, out_2])
|
|
|
|
|
print(res_1) # [[1. 1.]]
|
|
|
|
|
print(res_2) # [3 3 3]
|
|
|
|
|