Merge pull request #7770 from peterzhang2029/nmt_fix
enhance the machine_translation model in unittest.emailweixu-patch-1
commit
f91f134458
@ -0,0 +1,202 @@
|
||||
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import paddle.v2 as paddle
|
||||
import paddle.v2.fluid as fluid
|
||||
import paddle.v2.fluid.core as core
|
||||
import paddle.v2.fluid.framework as framework
|
||||
import paddle.v2.fluid.layers as layers
|
||||
from paddle.v2.fluid.executor import Executor
|
||||
|
||||
dict_size = 30000
|
||||
source_dict_dim = target_dict_dim = dict_size
|
||||
src_dict, trg_dict = paddle.dataset.wmt14.get_dict(dict_size)
|
||||
hidden_dim = 32
|
||||
embedding_dim = 16
|
||||
batch_size = 10
|
||||
max_length = 50
|
||||
topk_size = 50
|
||||
encoder_size = decoder_size = hidden_dim
|
||||
IS_SPARSE = True
|
||||
USE_PEEPHOLES = False
|
||||
|
||||
|
||||
def bi_lstm_encoder(input_seq, hidden_size):
|
||||
input_forward_proj = fluid.layers.fc(input=input_seq,
|
||||
size=hidden_size * 4,
|
||||
bias_attr=True)
|
||||
forward, _ = fluid.layers.dynamic_lstm(
|
||||
input=input_forward_proj,
|
||||
size=hidden_size * 4,
|
||||
use_peepholes=USE_PEEPHOLES)
|
||||
input_backward_proj = fluid.layers.fc(input=input_seq,
|
||||
size=hidden_size * 4,
|
||||
bias_attr=True)
|
||||
backward, _ = fluid.layers.dynamic_lstm(
|
||||
input=input_backward_proj,
|
||||
size=hidden_size * 4,
|
||||
is_reverse=True,
|
||||
use_peepholes=USE_PEEPHOLES)
|
||||
return forward, backward
|
||||
|
||||
|
||||
# FIXME(peterzhang2029): Replace this function with the lstm_unit_op.
|
||||
def lstm_step(x_t, hidden_t_prev, cell_t_prev, size):
|
||||
def linear(inputs):
|
||||
return fluid.layers.fc(input=inputs, size=size, bias_attr=True)
|
||||
|
||||
forget_gate = fluid.layers.sigmoid(x=linear([hidden_t_prev, x_t]))
|
||||
input_gate = fluid.layers.sigmoid(x=linear([hidden_t_prev, x_t]))
|
||||
output_gate = fluid.layers.sigmoid(x=linear([hidden_t_prev, x_t]))
|
||||
cell_tilde = fluid.layers.tanh(x=linear([hidden_t_prev, x_t]))
|
||||
|
||||
cell_t = fluid.layers.sums(input=[
|
||||
fluid.layers.elementwise_mul(
|
||||
x=forget_gate, y=cell_t_prev), fluid.layers.elementwise_mul(
|
||||
x=input_gate, y=cell_tilde)
|
||||
])
|
||||
|
||||
hidden_t = fluid.layers.elementwise_mul(
|
||||
x=output_gate, y=fluid.layers.tanh(x=cell_t))
|
||||
|
||||
return hidden_t, cell_t
|
||||
|
||||
|
||||
def lstm_decoder_without_attention(target_embedding, decoder_boot, context,
|
||||
decoder_size):
|
||||
rnn = fluid.layers.DynamicRNN()
|
||||
|
||||
cell_init = fluid.layers.fill_constant_batch_size_like(
|
||||
input=decoder_boot,
|
||||
value=0.0,
|
||||
shape=[-1, decoder_size],
|
||||
dtype='float32')
|
||||
cell_init.stop_gradient = False
|
||||
|
||||
with rnn.block():
|
||||
current_word = rnn.step_input(target_embedding)
|
||||
context = rnn.static_input(context)
|
||||
|
||||
hidden_mem = rnn.memory(init=decoder_boot, need_reorder=True)
|
||||
cell_mem = rnn.memory(init=cell_init)
|
||||
decoder_inputs = fluid.layers.concat(
|
||||
input=[context, current_word], axis=1)
|
||||
h, c = lstm_step(decoder_inputs, hidden_mem, cell_mem, decoder_size)
|
||||
rnn.update_memory(hidden_mem, h)
|
||||
rnn.update_memory(cell_mem, c)
|
||||
out = fluid.layers.fc(input=h,
|
||||
size=target_dict_dim,
|
||||
bias_attr=True,
|
||||
act='softmax')
|
||||
rnn.output(out)
|
||||
return rnn()
|
||||
|
||||
|
||||
def seq_to_seq_net():
|
||||
"""Construct a seq2seq network."""
|
||||
|
||||
src_word_idx = fluid.layers.data(
|
||||
name='source_sequence', shape=[1], dtype='int64', lod_level=1)
|
||||
|
||||
src_embedding = fluid.layers.embedding(
|
||||
input=src_word_idx,
|
||||
size=[source_dict_dim, embedding_dim],
|
||||
dtype='float32')
|
||||
|
||||
src_forward, src_backward = bi_lstm_encoder(
|
||||
input_seq=src_embedding, hidden_size=encoder_size)
|
||||
|
||||
encoded_vector = fluid.layers.concat(
|
||||
input=[src_forward, src_backward], axis=1)
|
||||
|
||||
enc_vec_last = fluid.layers.sequence_last_step(input=encoded_vector)
|
||||
|
||||
decoder_boot = fluid.layers.fc(input=enc_vec_last,
|
||||
size=decoder_size,
|
||||
bias_attr=False,
|
||||
act='tanh')
|
||||
|
||||
trg_word_idx = fluid.layers.data(
|
||||
name='target_sequence', shape=[1], dtype='int64', lod_level=1)
|
||||
|
||||
trg_embedding = fluid.layers.embedding(
|
||||
input=trg_word_idx,
|
||||
size=[target_dict_dim, embedding_dim],
|
||||
dtype='float32')
|
||||
|
||||
prediction = lstm_decoder_without_attention(trg_embedding, decoder_boot,
|
||||
enc_vec_last, decoder_size)
|
||||
label = fluid.layers.data(
|
||||
name='label_sequence', shape=[1], dtype='int64', lod_level=1)
|
||||
cost = fluid.layers.cross_entropy(input=prediction, label=label)
|
||||
avg_cost = fluid.layers.mean(x=cost)
|
||||
|
||||
return avg_cost
|
||||
|
||||
|
||||
def to_lodtensor(data, place):
|
||||
seq_lens = [len(seq) for seq in data]
|
||||
cur_len = 0
|
||||
lod = [cur_len]
|
||||
for l in seq_lens:
|
||||
cur_len += l
|
||||
lod.append(cur_len)
|
||||
flattened_data = np.concatenate(data, axis=0).astype("int64")
|
||||
flattened_data = flattened_data.reshape([len(flattened_data), 1])
|
||||
res = core.LoDTensor()
|
||||
res.set(flattened_data, place)
|
||||
res.set_lod([lod])
|
||||
return res
|
||||
|
||||
|
||||
def main():
|
||||
avg_cost = seq_to_seq_net()
|
||||
|
||||
optimizer = fluid.optimizer.Adagrad(learning_rate=1e-4)
|
||||
optimizer.minimize(avg_cost)
|
||||
|
||||
train_data = paddle.batch(
|
||||
paddle.reader.shuffle(
|
||||
paddle.dataset.wmt14.train(dict_size), buf_size=1000),
|
||||
batch_size=batch_size)
|
||||
|
||||
place = core.CPUPlace()
|
||||
exe = Executor(place)
|
||||
|
||||
exe.run(framework.default_startup_program())
|
||||
|
||||
batch_id = 0
|
||||
for pass_id in xrange(2):
|
||||
for data in train_data():
|
||||
word_data = to_lodtensor(map(lambda x: x[0], data), place)
|
||||
trg_word = to_lodtensor(map(lambda x: x[1], data), place)
|
||||
trg_word_next = to_lodtensor(map(lambda x: x[2], data), place)
|
||||
outs = exe.run(framework.default_main_program(),
|
||||
feed={
|
||||
'source_sequence': word_data,
|
||||
'target_sequence': trg_word,
|
||||
'label_sequence': trg_word_next
|
||||
},
|
||||
fetch_list=[avg_cost])
|
||||
avg_cost_val = np.array(outs[0])
|
||||
print('pass_id=' + str(pass_id) + ' batch=' + str(batch_id) +
|
||||
" avg_cost=" + str(avg_cost_val))
|
||||
if batch_id > 3:
|
||||
exit(0)
|
||||
batch_id += 1
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
Loading…
Reference in new issue