[Dy2Static]Support return variable created in only one of If.body or If.orelse (#24841)

* Support return variable in only one of if body or else. 

* remove after_visit in IfElseTransformer.

* Modify the result of get_name_ids in test_ifelse_basic.py 

* Add unittest to test the new case. 

* Modify code according to reviews.
revert-24981-add_device_attr_for_regulization
liym27 5 years ago committed by GitHub
parent 0494239b1f
commit 5ea82e8a45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -52,6 +52,51 @@ def dyfunc_with_if_else2(x, col=100):
return y
def dyfunc_with_if_else3(x):
# Create new var in parent scope, return it in true_fn and false_fn.
# The var is created only in one of If.body or If.orelse node, and it used as gast.Load firstly after gast.If node.
# The transformed code:
"""
q = fluid.dygraph.dygraph_to_static.variable_trans_func.
data_layer_not_check(name='q', shape=[-1], dtype='float32')
z = fluid.dygraph.dygraph_to_static.variable_trans_func.
data_layer_not_check(name='z', shape=[-1], dtype='float32')
def true_fn_0(q, x, y):
x = x + 1
z = x + 2
q = x + 3
return q, x, y, z
def false_fn_0(q, x, y):
y = y + 1
z = x - 2
m = x + 2
n = x + 3
return q, x, y, z
q, x, y, z = fluid.layers.cond(fluid.layers.mean(x)[0] < 5, lambda :
fluid.dygraph.dygraph_to_static.convert_call(true_fn_0)(q, x, y),
lambda : fluid.dygraph.dygraph_to_static.convert_call(false_fn_0)(q,
x, y))
"""
y = x + 1
# NOTE: x_v[0] < 5 is True
if fluid.layers.mean(x).numpy()[0] < 5:
x = x + 1
z = x + 2
q = x + 3
else:
y = y + 1
z = x - 2
m = x + 2
n = x + 3
q = q + 1
n = q + 2
x = n
return x
def nested_if_else(x_v):
batch_size = 16
feat_size = x_v.shape[-1]

@ -64,18 +64,24 @@ class TestDygraphIfElse2(TestDygraphIfElse):
class TestDygraphIfElse3(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = dyfunc_with_if_else3
class TestDygraphNestedIfElse(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = nested_if_else
class TestDygraphIfElse4(TestDygraphIfElse):
class TestDygraphNestedIfElse2(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = nested_if_else_2
class TestDygraphIfElse5(TestDygraphIfElse):
class TestDygraphNestedIfElse3(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = nested_if_else_3

@ -34,7 +34,7 @@ class TestGetNameIds(unittest.TestCase):
def test_fn(x):
return x+1
"""
self.all_name_ids = {'x': [gast.Param()]}
self.all_name_ids = {'x': [gast.Param(), gast.Load()]}
def test_get_name_ids(self):
source = textwrap.dedent(self.source)
@ -82,6 +82,7 @@ class TestGetNameIds2(TestGetNameIds):
gast.Load(),
gast.Store(),
gast.Store(),
gast.Load(),
]
}
@ -113,6 +114,7 @@ class TestGetNameIds3(TestGetNameIds):
gast.Store(),
gast.Load(),
gast.Store(),
gast.Load(),
]
}

Loading…
Cancel
Save