|
|
|
@ -1851,24 +1851,45 @@ def cond(pred, true_fn=None, false_fn=None, name=None):
|
|
|
|
|
list of tensors.
|
|
|
|
|
|
|
|
|
|
Note:
|
|
|
|
|
The tuples or lists in ``true_fn`` and ``false_fn`` must have same
|
|
|
|
|
shape because of dataflow model of PaddlePaddle while the tensors in the
|
|
|
|
|
tuples or the lists can have different shapes.
|
|
|
|
|
1. The tuples or lists returned by ``true_fn`` and ``false_fn`` must have
|
|
|
|
|
the same shape because of dataflow model of PaddlePaddle while the
|
|
|
|
|
tensors in the tuples or the lists can have different shapes.
|
|
|
|
|
|
|
|
|
|
2. Any tensors or operations created outside of ``true_fn`` and
|
|
|
|
|
``false_fn`` will be executed regardless of which branch is selected at
|
|
|
|
|
runtime. This has frequently surprised users who expected a lazy
|
|
|
|
|
semantics. For example:
|
|
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
a = fluid.data(name='a', shape=[-1, 1], dtype='float32')
|
|
|
|
|
b = fluid.data(name='b', shape=[-1, 1], dtype='float32')
|
|
|
|
|
c = a * b
|
|
|
|
|
out = fluid.layers.cond(a < b, lambda: a + c, lambda: b * b)
|
|
|
|
|
|
|
|
|
|
No matter whether ``a < b`` , ``c = a * b`` will run.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
pred(Variable): A boolean tensor whose numel should be 1. The boolean
|
|
|
|
|
value determines whether to return the result of ``true_fn`` or
|
|
|
|
|
``false_fn``
|
|
|
|
|
true_fn(callable): A callable to be performed if ``pred`` is true
|
|
|
|
|
false_fn(callable): A callable to be performed if ``pred`` is false
|
|
|
|
|
name(str, optional): The default value is ``None``. Normally users
|
|
|
|
|
``false_fn`` .
|
|
|
|
|
true_fn(callable, optional): A callable to be performed if ``pred`` is
|
|
|
|
|
true. The default value is ``None`` .
|
|
|
|
|
false_fn(callable, optional): A callable to be performed if ``pred`` is
|
|
|
|
|
false. The default value is ``None`` .
|
|
|
|
|
name(str, optional): The default value is ``None`` . Normally users
|
|
|
|
|
don't have to set this parameter. For more information, please
|
|
|
|
|
refer to :ref:`api_guide_Name`.
|
|
|
|
|
refer to :ref:`api_guide_Name` .
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
Variable|list(Variable)|tuple(Variable): returns ``true_fn()`` if the
|
|
|
|
|
predicate ``pred`` is true else ``false_fn()`` .
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
TypeError: if ``true_fn`` or ``false_fn`` is not callable.
|
|
|
|
|
ValueError: if ``true_fn`` and ``false_fn`` doesn't return the same
|
|
|
|
|
nest structure of tensors.
|
|
|
|
|
ValueError: if ``true_fn`` and ``false_fn`` don't return the same nest
|
|
|
|
|
structure of tensors.
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
.. code-block:: python
|
|
|
|
|