|
|
|
@ -211,5 +211,67 @@ class TestSimpleMul(unittest.TestCase):
|
|
|
|
|
self.assertTrue(numpy.allclose(w_g_num, w_g, rtol=0.05))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestSimpleMulWithMemory(unittest.TestCase):
|
|
|
|
|
DATA_WIDTH = 32
|
|
|
|
|
HIDDEN_WIDTH = 10
|
|
|
|
|
DATA_NAME = 'X'
|
|
|
|
|
PARAM_NAME = 'W'
|
|
|
|
|
|
|
|
|
|
class SimpleMulWithMemory(BaseRNN):
|
|
|
|
|
def __init__(self):
|
|
|
|
|
super(TestSimpleMulWithMemory.SimpleMulWithMemory, self).__init__({
|
|
|
|
|
TestSimpleMulWithMemory.DATA_NAME: {
|
|
|
|
|
'shape': [TestSimpleMulWithMemory.DATA_WIDTH]
|
|
|
|
|
}
|
|
|
|
|
}, {'Mem': {
|
|
|
|
|
'shape': [TestSimpleMulWithMemory.HIDDEN_WIDTH]
|
|
|
|
|
}}, {
|
|
|
|
|
TestSimpleMulWithMemory.PARAM_NAME: {
|
|
|
|
|
'shape': [
|
|
|
|
|
TestSimpleMulWithMemory.DATA_WIDTH,
|
|
|
|
|
TestSimpleMulWithMemory.HIDDEN_WIDTH
|
|
|
|
|
]
|
|
|
|
|
}
|
|
|
|
|
}, ['Out'])
|
|
|
|
|
|
|
|
|
|
def step(self, X, Mem, W, Out):
|
|
|
|
|
o = numpy.matmul(X, W)
|
|
|
|
|
assert isinstance(Mem, Memory)
|
|
|
|
|
o += Mem.ex
|
|
|
|
|
Mem.update(o)
|
|
|
|
|
assert isinstance(Out, Output)
|
|
|
|
|
Out.out(o)
|
|
|
|
|
|
|
|
|
|
@prog_scope()
|
|
|
|
|
def test_forward_backward(self):
|
|
|
|
|
py_rnn = TestSimpleMulWithMemory.SimpleMulWithMemory()
|
|
|
|
|
|
|
|
|
|
data = fluid.layers.data(
|
|
|
|
|
name=self.DATA_NAME, shape=[self.DATA_WIDTH], lod_level=1)
|
|
|
|
|
rnn = fluid.layers.DynamicRNN()
|
|
|
|
|
with rnn.block():
|
|
|
|
|
d = rnn.step_input(data)
|
|
|
|
|
mem = rnn.memory(value=0.0, shape=[self.HIDDEN_WIDTH])
|
|
|
|
|
hidden = fluid.layers.fc(input=d,
|
|
|
|
|
size=self.HIDDEN_WIDTH,
|
|
|
|
|
param_attr=self.PARAM_NAME,
|
|
|
|
|
bias_attr=False,
|
|
|
|
|
act=None)
|
|
|
|
|
o = fluid.layers.elementwise_add(x=hidden, y=mem)
|
|
|
|
|
rnn.update_memory(mem, o)
|
|
|
|
|
rnn.output(o)
|
|
|
|
|
|
|
|
|
|
out = rnn()
|
|
|
|
|
last = fluid.layers.sequence_pool(input=out, pool_type='last')
|
|
|
|
|
|
|
|
|
|
cpu = fluid.CPUPlace()
|
|
|
|
|
exe = fluid.Executor(cpu)
|
|
|
|
|
|
|
|
|
|
last_np, = exe.run(feed=py_rnn.to_feed(cpu), fetch_list=[last])
|
|
|
|
|
last_by_py, = py_rnn.exe().values()
|
|
|
|
|
|
|
|
|
|
self.assertTrue(numpy.allclose(last_np, last_by_py))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
unittest.main()
|
|
|
|
|