Add unitTest for `Tensor==constant` for ifElse in dygraph2static (#23407)

* Add unitTest for `Tensor==constant` for ifElse in dygraph2static test=develop
revert-23830-2.0-beta
Aurelius84 5 years ago committed by GitHub
parent 9676ac1c5c
commit 4955c97ee8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -85,6 +85,7 @@ class IfElseTransformer(gast.NodeTransformer):
attribute = node.func
if attribute.attr == 'numpy':
node = attribute.value
self.generic_visit(node)
return node
def visit_IfExp(self, node):
@ -292,12 +293,12 @@ class NodeTestTransformer(gast.NodeTransformer):
return self.visit(self.ast_root)
def visit_Call(self, node):
# self.generic_visit(node)
# Remove `numpy()` statement, like `Tensor.numpy()[i]` -> `Tensor[i]`
if isinstance(node.func, gast.Attribute):
attribute = node.func
if attribute.attr == 'numpy':
node = attribute.value
self.generic_visit(node)
return node
def visit_UnaryOp(self, node):

@ -116,6 +116,30 @@ class TestDygraphIfElse6(TestDygraphIfElse):
self.dyfunc = dyfunc_ifExp_with_while
def dyfunc_ifExp_with_while2(x):
y = [x]
def add_fn(x):
x = x + 1
return x
def map_func(func, tensor_list):
return [func(x) for x in tensor_list]
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
# It will be converted into `layers.cond` as followed.
# map_func(lambda x: fluid.layers.cond(i==0, lambda: x, lambda: add_fn(x), y)
# `i (Tensor) == 0` is supported in dygraph.
y = map_func(lambda x: x if i == 0 else add_fn(x), y)
return y[0]
class TestDygraphIfElse7(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')
self.dyfunc = dyfunc_ifExp_with_while2
class TestDygraphIfElseWithAndOr(TestDygraphIfElse):
def setUp(self):
self.x = np.random.random([10, 16]).astype('float32')

Loading…
Cancel
Save