|
|
|
@ -18,7 +18,7 @@ from ..wrapped_decorator import signature_safe_contextmanager
|
|
|
|
|
from .layer_function_generator import autodoc, templatedoc
|
|
|
|
|
from .tensor import assign, cast, fill_constant
|
|
|
|
|
from .. import core
|
|
|
|
|
from ..framework import Program, Variable, Operator
|
|
|
|
|
from ..framework import Program, Variable, Operator, in_dygraph_mode
|
|
|
|
|
from ..layer_helper import LayerHelper, unique_name
|
|
|
|
|
from .nn import logical_and, logical_not, logical_or
|
|
|
|
|
from .utils import assert_same_structure, map_structure
|
|
|
|
@ -999,6 +999,20 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
|
|
|
|
|
"the shape of the variable returned by cond should be [],"
|
|
|
|
|
"but given shape as {0}.".format(list(pre_cond.shape)))
|
|
|
|
|
|
|
|
|
|
if in_dygraph_mode():
|
|
|
|
|
now_cond = pre_cond.numpy()[0]
|
|
|
|
|
while (now_cond):
|
|
|
|
|
output_vars = body(*loop_vars)
|
|
|
|
|
if not isinstance(output_vars, (list, tuple)):
|
|
|
|
|
output_vars = [output_vars]
|
|
|
|
|
if len(output_vars) != len(loop_vars):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"body in while_loop should return the same arity "
|
|
|
|
|
"(length and structure) and types as loop_vars")
|
|
|
|
|
now_cond = cond(*output_vars).numpy()[0]
|
|
|
|
|
loop_vars = output_vars
|
|
|
|
|
return loop_vars
|
|
|
|
|
|
|
|
|
|
while_loop_block = While(pre_cond, is_test, name)
|
|
|
|
|
with while_loop_block.block():
|
|
|
|
|
output_vars = body(*loop_vars)
|
|
|
|
|