|
|
|
@ -18,8 +18,8 @@ import unittest
|
|
|
|
|
import numpy as np
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
import paddle.fluid.core as core
|
|
|
|
|
from paddle.fluid.dygraph import declarative
|
|
|
|
|
from paddle.fluid.dygraph import ProgramTranslator
|
|
|
|
|
from paddle.jit import to_static
|
|
|
|
|
from paddle.jit import ProgramTranslator
|
|
|
|
|
|
|
|
|
|
from ifelse_simple_func import dyfunc_with_if_else
|
|
|
|
|
|
|
|
|
@ -27,13 +27,13 @@ SEED = 2020
|
|
|
|
|
np.random.seed(SEED)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@declarative
|
|
|
|
|
@to_static
|
|
|
|
|
def test_return_base(x):
|
|
|
|
|
x = fluid.dygraph.to_variable(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@declarative
|
|
|
|
|
@to_static
|
|
|
|
|
def test_inside_func_base(x):
|
|
|
|
|
x = fluid.dygraph.to_variable(x)
|
|
|
|
|
|
|
|
|
@ -43,7 +43,7 @@ def test_inside_func_base(x):
|
|
|
|
|
return inner_func(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@declarative
|
|
|
|
|
@to_static
|
|
|
|
|
def test_return_if(x):
|
|
|
|
|
x = fluid.dygraph.to_variable(x)
|
|
|
|
|
if x < 0:
|
|
|
|
@ -53,7 +53,7 @@ def test_return_if(x):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@declarative
|
|
|
|
|
@to_static
|
|
|
|
|
def test_return_if_else(x):
|
|
|
|
|
x = fluid.dygraph.to_variable(x)
|
|
|
|
|
if x > 0:
|
|
|
|
@ -66,7 +66,7 @@ def test_return_if_else(x):
|
|
|
|
|
x -= 8888 # useless statement to test our code can handle it.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@declarative
|
|
|
|
|
@to_static
|
|
|
|
|
def test_return_in_while(x):
|
|
|
|
|
x = fluid.dygraph.to_variable(x)
|
|
|
|
|
i = fluid.layers.fill_constant(shape=[1], dtype='int32', value=0)
|
|
|
|
@ -79,7 +79,7 @@ def test_return_in_while(x):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@declarative
|
|
|
|
|
@to_static
|
|
|
|
|
def test_return_in_for(x):
|
|
|
|
|
x = fluid.dygraph.to_variable(x)
|
|
|
|
|
for i in range(10):
|
|
|
|
@ -91,13 +91,13 @@ def test_return_in_for(x):
|
|
|
|
|
return x - 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@declarative
|
|
|
|
|
@to_static
|
|
|
|
|
def test_recursive_return(x):
|
|
|
|
|
x = fluid.dygraph.to_variable(x)
|
|
|
|
|
return dyfunc_with_if_else(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@declarative
|
|
|
|
|
@to_static
|
|
|
|
|
def test_return_different_length_if_body(x):
|
|
|
|
|
x = fluid.dygraph.to_variable(x)
|
|
|
|
|
y = x + 1
|
|
|
|
@ -108,7 +108,7 @@ def test_return_different_length_if_body(x):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@declarative
|
|
|
|
|
@to_static
|
|
|
|
|
def test_return_different_length_else(x):
|
|
|
|
|
x = fluid.dygraph.to_variable(x)
|
|
|
|
|
y = x + 1
|
|
|
|
@ -119,13 +119,13 @@ def test_return_different_length_else(x):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@declarative
|
|
|
|
|
@to_static
|
|
|
|
|
def test_no_return(x):
|
|
|
|
|
x = fluid.dygraph.to_variable(x)
|
|
|
|
|
y = x + 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@declarative
|
|
|
|
|
@to_static
|
|
|
|
|
def test_return_none(x):
|
|
|
|
|
x = fluid.dygraph.to_variable(x)
|
|
|
|
|
y = x + 1
|
|
|
|
@ -136,7 +136,7 @@ def test_return_none(x):
|
|
|
|
|
return x, y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@declarative
|
|
|
|
|
@to_static
|
|
|
|
|
def test_return_no_variable(x):
|
|
|
|
|
x = fluid.dygraph.to_variable(x)
|
|
|
|
|
y = x + 1
|
|
|
|
@ -147,6 +147,38 @@ def test_return_no_variable(x):
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@to_static
|
|
|
|
|
def test_return_list_one_value(x):
|
|
|
|
|
x = fluid.dygraph.to_variable(x)
|
|
|
|
|
x += 1
|
|
|
|
|
return [x]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@to_static
|
|
|
|
|
def test_return_list_many_values(x):
|
|
|
|
|
x = fluid.dygraph.to_variable(x)
|
|
|
|
|
x += 1
|
|
|
|
|
y = x * 2
|
|
|
|
|
z = x * x
|
|
|
|
|
return [x, y, z]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@to_static
|
|
|
|
|
def test_return_tuple_one_value(x):
|
|
|
|
|
x = fluid.dygraph.to_variable(x)
|
|
|
|
|
x += 1
|
|
|
|
|
return (x, )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@to_static
|
|
|
|
|
def test_return_tuple_many_values(x):
|
|
|
|
|
x = fluid.dygraph.to_variable(x)
|
|
|
|
|
x += 1
|
|
|
|
|
y = x * 2
|
|
|
|
|
z = x * x
|
|
|
|
|
return (x, y, z)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestReturnBase(unittest.TestCase):
|
|
|
|
|
def setUp(self):
|
|
|
|
|
self.input = np.ones((1)).astype('int32')
|
|
|
|
@ -158,29 +190,19 @@ class TestReturnBase(unittest.TestCase):
|
|
|
|
|
def init_dygraph_func(self):
|
|
|
|
|
self.dygraph_func = test_return_base
|
|
|
|
|
|
|
|
|
|
def run_dygraph_mode(self):
|
|
|
|
|
self.program_translator.enable(False)
|
|
|
|
|
def _run(self, to_static=False):
|
|
|
|
|
self.program_translator.enable(to_static)
|
|
|
|
|
with fluid.dygraph.guard():
|
|
|
|
|
res = self.dygraph_func(self.input)
|
|
|
|
|
if isinstance(res, (tuple)):
|
|
|
|
|
return tuple(r.numpy() for r in res)
|
|
|
|
|
elif isinstance(res, core.VarBase):
|
|
|
|
|
return res.numpy()
|
|
|
|
|
return res
|
|
|
|
|
|
|
|
|
|
def run_static_mode(self):
|
|
|
|
|
self.program_translator.enable(True)
|
|
|
|
|
with fluid.dygraph.guard():
|
|
|
|
|
res = self.dygraph_func(self.input)
|
|
|
|
|
if isinstance(res, tuple):
|
|
|
|
|
if isinstance(res, (tuple, list)):
|
|
|
|
|
return tuple(r.numpy() for r in res)
|
|
|
|
|
elif isinstance(res, core.VarBase):
|
|
|
|
|
return res.numpy()
|
|
|
|
|
return res
|
|
|
|
|
|
|
|
|
|
def test_transformed_static_result(self):
|
|
|
|
|
dygraph_res = self.run_dygraph_mode()
|
|
|
|
|
static_res = self.run_static_mode()
|
|
|
|
|
dygraph_res = self._run(to_static=False)
|
|
|
|
|
static_res = self._run(to_static=True)
|
|
|
|
|
if isinstance(dygraph_res, tuple):
|
|
|
|
|
self.assertTrue(isinstance(static_res, tuple))
|
|
|
|
|
self.assertEqual(len(dygraph_res), len(static_res))
|
|
|
|
@ -255,5 +277,25 @@ class TestReturnNoVariable(TestReturnBase):
|
|
|
|
|
self.dygraph_func = test_return_no_variable
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestReturnListOneValue(TestReturnBase):
|
|
|
|
|
def init_dygraph_func(self):
|
|
|
|
|
self.dygraph_func = test_return_list_one_value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestReturnListManyValue(TestReturnBase):
|
|
|
|
|
def init_dygraph_func(self):
|
|
|
|
|
self.dygraph_func = test_return_list_many_values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestReturnTupleOneValue(TestReturnBase):
|
|
|
|
|
def init_dygraph_func(self):
|
|
|
|
|
self.dygraph_func = test_return_tuple_one_value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestReturnTupleManyValue(TestReturnBase):
|
|
|
|
|
def init_dygraph_func(self):
|
|
|
|
|
self.dygraph_func = test_return_tuple_many_values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
unittest.main()
|
|
|
|
|