|
|
|
@ -2297,11 +2297,6 @@ def copy_var_to_parent_block(var, layer_helper):
|
|
|
|
|
|
|
|
|
|
def cond(pred, true_fn=None, false_fn=None, name=None):
|
|
|
|
|
"""
|
|
|
|
|
:api_attr: Static Graph
|
|
|
|
|
:alias_main: paddle.nn.cond
|
|
|
|
|
:alias: paddle.nn.cond,paddle.nn.control_flow.cond
|
|
|
|
|
:old_api: paddle.fluid.layers.cond
|
|
|
|
|
|
|
|
|
|
This API returns ``true_fn()`` if the predicate ``pred`` is true else
|
|
|
|
|
``false_fn()`` . Users could also set ``true_fn`` or ``false_fn`` to
|
|
|
|
|
``None`` if do nothing and this API will treat the callable simply returns
|
|
|
|
@ -2323,17 +2318,18 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
|
|
|
|
|
semantics. For example:
|
|
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
a = fluid.data(name='a', shape=[-1, 1], dtype='float32')
|
|
|
|
|
b = fluid.data(name='b', shape=[-1, 1], dtype='float32')
|
|
|
|
|
|
|
|
|
|
import paddle
|
|
|
|
|
|
|
|
|
|
a = paddle.zeros((1, 1))
|
|
|
|
|
b = paddle.zeros((1, 1))
|
|
|
|
|
c = a * b
|
|
|
|
|
out = fluid.layers.cond(a < b, lambda: a + c, lambda: b * b)
|
|
|
|
|
out = paddle.nn.cond(a < b, lambda: a + c, lambda: b * b)
|
|
|
|
|
|
|
|
|
|
No matter whether ``a < b`` , ``c = a * b`` will run.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
pred(Variable): A boolean tensor whose numel should be 1. The boolean
|
|
|
|
|
pred(Tensor): A boolean tensor whose numel should be 1. The boolean
|
|
|
|
|
value determines whether to return the result of ``true_fn`` or
|
|
|
|
|
``false_fn`` .
|
|
|
|
|
true_fn(callable, optional): A callable to be performed if ``pred`` is
|
|
|
|
@ -2345,7 +2341,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
|
|
|
|
|
refer to :ref:`api_guide_Name` .
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Variable|list(Variable)|tuple(Variable): returns ``true_fn()`` if the
|
|
|
|
|
Tensor|list(Tensor)|tuple(Tensor): returns ``true_fn()`` if the
|
|
|
|
|
predicate ``pred`` is true else ``false_fn()`` .
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
@ -2356,10 +2352,7 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
|
|
|
|
|
Examples:
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
import paddle.fluid.layers as layers
|
|
|
|
|
from paddle.fluid.executor import Executor
|
|
|
|
|
from paddle.fluid.framework import Program, program_guard
|
|
|
|
|
import paddle
|
|
|
|
|
|
|
|
|
|
#
|
|
|
|
|
# pseudocode:
|
|
|
|
@ -2369,32 +2362,28 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
|
|
|
|
|
# return 3, 2
|
|
|
|
|
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def true_func():
|
|
|
|
|
return layers.fill_constant(
|
|
|
|
|
shape=[1, 2], dtype='int32', value=1), layers.fill_constant(
|
|
|
|
|
shape=[2, 3], dtype='bool', value=True)
|
|
|
|
|
return paddle.fill_constant(shape=[1, 2], dtype='int32',
|
|
|
|
|
value=1), paddle.fill_constant(shape=[2, 3],
|
|
|
|
|
dtype='bool',
|
|
|
|
|
value=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def false_func():
|
|
|
|
|
return layers.fill_constant(
|
|
|
|
|
shape=[3, 4], dtype='float32', value=3), layers.fill_constant(
|
|
|
|
|
shape=[4, 5], dtype='int64', value=2)
|
|
|
|
|
|
|
|
|
|
main_program = Program()
|
|
|
|
|
startup_program = Program()
|
|
|
|
|
with program_guard(main_program, startup_program):
|
|
|
|
|
x = layers.fill_constant(shape=[1], dtype='float32', value=0.1)
|
|
|
|
|
y = layers.fill_constant(shape=[1], dtype='float32', value=0.23)
|
|
|
|
|
pred = layers.less_than(x, y)
|
|
|
|
|
out = layers.cond(pred, true_func, false_func)
|
|
|
|
|
# out is a tuple containing 2 tensors
|
|
|
|
|
|
|
|
|
|
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
|
|
|
|
|
) else fluid.CPUPlace()
|
|
|
|
|
exe = fluid.Executor(place)
|
|
|
|
|
ret = exe.run(main_program, fetch_list=out)
|
|
|
|
|
return paddle.fill_constant(shape=[3, 4], dtype='float32',
|
|
|
|
|
value=3), paddle.fill_constant(shape=[4, 5],
|
|
|
|
|
dtype='int64',
|
|
|
|
|
value=2)
|
|
|
|
|
|
|
|
|
|
x = paddle.fill_constant(shape=[1], dtype='float32', value=0.1)
|
|
|
|
|
y = paddle.fill_constant(shape=[1], dtype='float32', value=0.23)
|
|
|
|
|
pred = paddle.less_than(x=x, y=y, name=None)
|
|
|
|
|
ret = paddle.nn.cond(pred, true_func, false_func)
|
|
|
|
|
# ret is a tuple containing 2 tensors
|
|
|
|
|
# ret[0] = [[1 1]]
|
|
|
|
|
# ret[1] = [[ True True True]
|
|
|
|
|
# [ True True True]]
|
|
|
|
|
# [ True True True]]
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
if in_dygraph_mode():
|
|
|
|
|