parent
cbb532196d
commit
92f52e3bb7
@ -0,0 +1,72 @@
|
||||
# Copyright PaddlePaddle contributors. All Rights Reserved
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import difflib
|
||||
import unittest
|
||||
|
||||
import paddle.trainer_config_helpers as conf_helps
|
||||
import paddle.v2.activation as activation
|
||||
import paddle.v2.data_type as data_type
|
||||
import paddle.v2.layer as layer
|
||||
from paddle.trainer_config_helpers.config_parser_utils import \
|
||||
parse_network_config as parse_network
|
||||
|
||||
|
||||
class RNNTest(unittest.TestCase):
|
||||
def test_simple_rnn(self):
|
||||
dict_dim = 10
|
||||
word_dim = 8
|
||||
hidden_dim = 8
|
||||
|
||||
def test_old_rnn():
|
||||
def step(y):
|
||||
mem = conf_helps.memory(name="rnn_state", size=hidden_dim)
|
||||
out = conf_helps.fc_layer(
|
||||
input=[y, mem],
|
||||
size=hidden_dim,
|
||||
act=activation.Tanh(),
|
||||
bias_attr=True,
|
||||
name="rnn_state")
|
||||
return out
|
||||
|
||||
def test():
|
||||
data1 = conf_helps.data_layer(name="word", size=dict_dim)
|
||||
embd = conf_helps.embedding_layer(input=data1, size=word_dim)
|
||||
conf_helps.recurrent_group(name="rnn", step=step, input=embd)
|
||||
|
||||
return str(parse_network(test))
|
||||
|
||||
def test_new_rnn():
|
||||
def new_step(y):
|
||||
mem = layer.memory(name="rnn_state", size=hidden_dim)
|
||||
out = layer.fc(input=[mem],
|
||||
step_input=y,
|
||||
size=hidden_dim,
|
||||
act=activation.Tanh(),
|
||||
bias_attr=True,
|
||||
name="rnn_state")
|
||||
return out.to_proto(dict())
|
||||
|
||||
data1 = layer.data(
|
||||
name="word", type=data_type.integer_value(dict_dim))
|
||||
embd = layer.embedding(input=data1, size=word_dim)
|
||||
aaa = layer.recurrent_group(name="rnn", step=new_step, input=embd)
|
||||
return str(layer.parse_network(aaa))
|
||||
|
||||
diff = difflib.unified_diff(test_old_rnn().splitlines(1),
|
||||
test_new_rnn().splitlines(1))
|
||||
print ''.join(diff)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
Loading…
Reference in new issue