[Dy2stat] Refine return mechanism in @to_static (#28116)

* remove some judgement

* fix len(outputs) == 1
revert-27871-prv-conv-grad-opt
Aurelius84 4 years ago committed by GitHub
parent 68449d19a5
commit e730516090
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -606,8 +606,10 @@ class ConcreteProgram(object):
error.attach_error_data(e) error.attach_error_data(e)
raise raise
if not isinstance(outputs, if outputs is not None:
(tuple, list)) and outputs is not None: need_wrap_into_list = not isinstance(outputs, (
tuple, list)) or len(outputs) == 1
if need_wrap_into_list:
outputs = [outputs] outputs = [outputs]
main_program = update_op_callstack_with_origin_info(main_program) main_program = update_op_callstack_with_origin_info(main_program)

@ -18,8 +18,8 @@ import unittest
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid.dygraph import declarative from paddle.jit import to_static
from paddle.fluid.dygraph import ProgramTranslator from paddle.jit import ProgramTranslator
from ifelse_simple_func import dyfunc_with_if_else from ifelse_simple_func import dyfunc_with_if_else
@ -27,13 +27,13 @@ SEED = 2020
np.random.seed(SEED) np.random.seed(SEED)
@declarative @to_static
def test_return_base(x): def test_return_base(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
return x return x
@declarative @to_static
def test_inside_func_base(x): def test_inside_func_base(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
@ -43,7 +43,7 @@ def test_inside_func_base(x):
return inner_func(x) return inner_func(x)
@declarative @to_static
def test_return_if(x): def test_return_if(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
if x < 0: if x < 0:
@ -53,7 +53,7 @@ def test_return_if(x):
return x return x
@declarative @to_static
def test_return_if_else(x): def test_return_if_else(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
if x > 0: if x > 0:
@ -66,7 +66,7 @@ def test_return_if_else(x):
x -= 8888 # useless statement to test our code can handle it. x -= 8888 # useless statement to test our code can handle it.
@declarative @to_static
def test_return_in_while(x): def test_return_in_while(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
i = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0) i = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0)
@ -79,7 +79,7 @@ def test_return_in_while(x):
return x return x
@declarative @to_static
def test_return_in_for(x): def test_return_in_for(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
for i in range(10): for i in range(10):
@ -91,13 +91,13 @@ def test_return_in_for(x):
return x - 1 return x - 1
@declarative @to_static
def test_recursive_return(x): def test_recursive_return(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
return dyfunc_with_if_else(x) return dyfunc_with_if_else(x)
@declarative @to_static
def test_return_different_length_if_body(x): def test_return_different_length_if_body(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
y = x + 1 y = x + 1
@ -108,7 +108,7 @@ def test_return_different_length_if_body(x):
return x return x
@declarative @to_static
def test_return_different_length_else(x): def test_return_different_length_else(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
y = x + 1 y = x + 1
@ -119,13 +119,13 @@ def test_return_different_length_else(x):
return x return x
@declarative @to_static
def test_no_return(x): def test_no_return(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
y = x + 1 y = x + 1
@declarative @to_static
def test_return_none(x): def test_return_none(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
y = x + 1 y = x + 1
@ -136,7 +136,7 @@ def test_return_none(x):
return x, y return x, y
@declarative @to_static
def test_return_no_variable(x): def test_return_no_variable(x):
x = fluid.dygraph.to_variable(x) x = fluid.dygraph.to_variable(x)
y = x + 1 y = x + 1
@ -147,6 +147,38 @@ def test_return_no_variable(x):
return return
@to_static
def test_return_list_one_value(x):
x = fluid.dygraph.to_variable(x)
x += 1
return [x]
@to_static
def test_return_list_many_values(x):
x = fluid.dygraph.to_variable(x)
x += 1
y = x * 2
z = x * x
return [x, y, z]
@to_static
def test_return_tuple_one_value(x):
x = fluid.dygraph.to_variable(x)
x += 1
return (x, )
@to_static
def test_return_tuple_many_values(x):
x = fluid.dygraph.to_variable(x)
x += 1
y = x * 2
z = x * x
return (x, y, z)
class TestReturnBase(unittest.TestCase): class TestReturnBase(unittest.TestCase):
def setUp(self): def setUp(self):
self.input = np.ones((1)).astype('int32') self.input = np.ones((1)).astype('int32')
@ -158,29 +190,19 @@ class TestReturnBase(unittest.TestCase):
def init_dygraph_func(self): def init_dygraph_func(self):
self.dygraph_func = test_return_base self.dygraph_func = test_return_base
def run_dygraph_mode(self): def _run(self, to_static=False):
self.program_translator.enable(False) self.program_translator.enable(to_static)
with fluid.dygraph.guard(): with fluid.dygraph.guard():
res = self.dygraph_func(self.input) res = self.dygraph_func(self.input)
if isinstance(res, (tuple)): if isinstance(res, (tuple, list)):
return tuple(r.numpy() for r in res)
elif isinstance(res, core.VarBase):
return res.numpy()
return res
def run_static_mode(self):
self.program_translator.enable(True)
with fluid.dygraph.guard():
res = self.dygraph_func(self.input)
if isinstance(res, tuple):
return tuple(r.numpy() for r in res) return tuple(r.numpy() for r in res)
elif isinstance(res, core.VarBase): elif isinstance(res, core.VarBase):
return res.numpy() return res.numpy()
return res return res
def test_transformed_static_result(self): def test_transformed_static_result(self):
dygraph_res = self.run_dygraph_mode() dygraph_res = self._run(to_static=False)
static_res = self.run_static_mode() static_res = self._run(to_static=True)
if isinstance(dygraph_res, tuple): if isinstance(dygraph_res, tuple):
self.assertTrue(isinstance(static_res, tuple)) self.assertTrue(isinstance(static_res, tuple))
self.assertEqual(len(dygraph_res), len(static_res)) self.assertEqual(len(dygraph_res), len(static_res))
@ -255,5 +277,25 @@ class TestReturnNoVariable(TestReturnBase):
self.dygraph_func = test_return_no_variable self.dygraph_func = test_return_no_variable
class TestReturnListOneValue(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_list_one_value
class TestReturnListManyValue(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_list_many_values
class TestReturnTupleOneValue(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_tuple_one_value
class TestReturnTupleManyValue(TestReturnBase):
def init_dygraph_func(self):
self.dygraph_func = test_return_tuple_many_values
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

Loading…
Cancel
Save