|
|
|
@ -348,6 +348,55 @@ class TestImperative(unittest.TestCase):
|
|
|
|
|
self.assertEqual(mlp._fc2, sublayers[1])
|
|
|
|
|
self.assertEqual(len(sublayers), 2)
|
|
|
|
|
|
|
|
|
|
def test_dygraph_vs_static(self):
|
|
|
|
|
inp1 = np.random.rand(4, 3, 3)
|
|
|
|
|
inp2 = np.random.rand(4, 3, 3)
|
|
|
|
|
|
|
|
|
|
# dynamic graph
|
|
|
|
|
with fluid.dygraph.guard():
|
|
|
|
|
if np.sum(inp1) < np.sum(inp2):
|
|
|
|
|
x = fluid.layers.elementwise_add(inp1, inp2)
|
|
|
|
|
else:
|
|
|
|
|
x = fluid.layers.elementwise_sub(inp1, inp2)
|
|
|
|
|
dygraph_result = x._numpy()
|
|
|
|
|
|
|
|
|
|
# static graph
|
|
|
|
|
with new_program_scope():
|
|
|
|
|
inp_data1 = fluid.layers.data(
|
|
|
|
|
name='inp1', shape=[3, 3], dtype=np.float32)
|
|
|
|
|
inp_data2 = fluid.layers.data(
|
|
|
|
|
name='inp2', shape=[3, 3], dtype=np.float32)
|
|
|
|
|
|
|
|
|
|
a = fluid.layers.expand(
|
|
|
|
|
fluid.layers.reshape(
|
|
|
|
|
fluid.layers.reduce_sum(inp_data1), [1, 1]), [4, 1])
|
|
|
|
|
b = fluid.layers.expand(
|
|
|
|
|
fluid.layers.reshape(
|
|
|
|
|
fluid.layers.reduce_sum(inp_data2), [1, 1]), [4, 1])
|
|
|
|
|
cond = fluid.layers.less_than(x=a, y=b)
|
|
|
|
|
|
|
|
|
|
ie = fluid.layers.IfElse(cond)
|
|
|
|
|
with ie.true_block():
|
|
|
|
|
d1 = ie.input(inp_data1)
|
|
|
|
|
d2 = ie.input(inp_data2)
|
|
|
|
|
d3 = fluid.layers.elementwise_add(d1, d2)
|
|
|
|
|
ie.output(d3)
|
|
|
|
|
|
|
|
|
|
with ie.false_block():
|
|
|
|
|
d1 = ie.input(inp_data1)
|
|
|
|
|
d2 = ie.input(inp_data2)
|
|
|
|
|
d3 = fluid.layers.elementwise_sub(d1, d2)
|
|
|
|
|
ie.output(d3)
|
|
|
|
|
out = ie()
|
|
|
|
|
|
|
|
|
|
exe = fluid.Executor(fluid.CPUPlace(
|
|
|
|
|
) if not core.is_compiled_with_cuda() else fluid.CUDAPlace(0))
|
|
|
|
|
static_result = exe.run(fluid.default_main_program(),
|
|
|
|
|
feed={'inp1': inp1,
|
|
|
|
|
'inp2': inp2},
|
|
|
|
|
fetch_list=out)[0]
|
|
|
|
|
self.assertTrue(np.allclose(dygraph_result, static_result))
|
|
|
|
|
|
|
|
|
|
def test_rnn(self):
|
|
|
|
|
np_inp = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0],
|
|
|
|
|
[10.0, 11.0, 12.0]])
|
|
|
|
|