# You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
# Paddle/benchmark/tensorflow/machine_translation.py
#
# 627 lines
# 24 KiB

# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import logging
import os
import time

import numpy as np
import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.layers.core import Dense
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops.rnn_cell_impl import RNNCell, BasicLSTMCell
from tensorflow.python.ops.rnn_cell_impl import LSTMStateTuple
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest
import tensorflow.contrib.seq2seq as seq2seq
from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder
import paddle.v2 as paddle
# Command-line interface of the benchmark. Every help string interpolates its
# default via `%(default)<conv>`; the conversion must match the default's
# type (`d` for int, `f` for float, `s` for str), otherwise argparse raises
# TypeError when rendering --help.
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--embedding_dim",
    type=int,
    default=512,
    help="The dimension of embedding table. (default: %(default)d)")
parser.add_argument(
    "--encoder_size",
    type=int,
    default=512,
    help="The size of encoder bi-rnn unit. (default: %(default)d)")
parser.add_argument(
    "--decoder_size",
    type=int,
    default=512,
    help="The size of decoder rnn unit. (default: %(default)d)")
parser.add_argument(
    "--batch_size",
    type=int,
    default=128,
    help="The sequence number of a mini-batch data. (default: %(default)d)")
parser.add_argument(
    "--dict_size",
    type=int,
    default=30000,
    help="The dictionary capacity. Dictionaries of source sequence and "
    "target dictionary have same capacity. (default: %(default)d)")
parser.add_argument(
    "--max_time_steps",
    type=int,
    default=81,
    help="Max number of time steps for sequence. (default: %(default)d)")
parser.add_argument(
    "--pass_num",
    type=int,
    default=10,
    help="The pass number to train. (default: %(default)d)")
parser.add_argument(
    "--learning_rate",
    type=float,
    default=0.0002,
    help="Learning rate used to train the model. (default: %(default)f)")
parser.add_argument(
    "--infer_only", action='store_true', help="If set, run forward only.")
parser.add_argument(
    "--beam_size",
    type=int,
    default=3,
    help="The width for beam searching. (default: %(default)d)")
parser.add_argument(
    "--max_generation_length",
    type=int,
    default=250,
    help="The maximum length of sequence when doing generation. "
    "(default: %(default)d)")
parser.add_argument(
    "--save_freq",
    type=int,
    default=500,
    help="Save model checkpoint every this interation. (default: %(default)d)")
parser.add_argument(
    "--model_dir",
    type=str,
    default='./checkpoint',
    # `%(default)s`: the default is a string, `%d` would raise TypeError.
    help="Path to save model checkpoints. (default: %(default)s)")
# Alias for TensorFlow's private `_Linear` helper; the attention cell uses it
# to compute all four LSTM gate pre-activations with one projection.
_Linear = core_rnn_cell._Linear  # pylint: disable=invalid-name
# Token ids used as the decoder start token and the end token. The end id is
# also the value used to pad sequences when batching.
START_TOKEN_IDX, END_TOKEN_IDX = 0, 1
class LSTMCellWithSimpleAttention(RNNCell):
    """Add attention mechanism to BasicLSTMCell.

    This class is a wrapper based on tensorflow's `BasicLSTMCell`. On every
    step the previous hidden state `h` attends over the encoder outputs. The
    resulting context vector is fed into the gate projection together with
    the step input and `h`.

    Args:
        num_units: Size of the cell state and of the hidden state.
        encoder_vector: Rank-3, batch-major encoder outputs
            `(batch, src_time, dim)`. They are weighted by the attention
            distribution and summed into the context vector.
        encoder_proj: Encoder outputs whose last axis is `num_units`. They
            are used only to score attention; `_simple_attention` reshapes
            their concat with the projected decoder state to
            `[-1, 2 * num_units]`.
        source_sequence_length: Per-example source lengths. Padded time
            steps (per `sequence_mask`) are zeroed in both encoder tensors.
        forget_bias: Constant added to the forget-gate pre-activation.
        state_is_tuple: If True the state is `LSTMStateTuple(c, h)`.
            Otherwise `c` and `h` are concatenated on axis 1, which is slower.
        activation: Activation for the candidate and output; `tanh` if None.
        reuse: Forwarded to `RNNCell` as `_reuse`.
    """

    def __init__(self,
                 num_units,
                 encoder_vector,
                 encoder_proj,
                 source_sequence_length,
                 forget_bias=1.0,
                 state_is_tuple=True,
                 activation=None,
                 reuse=None):
        super(LSTMCellWithSimpleAttention, self).__init__(_reuse=reuse)
        if not state_is_tuple:
            # Stdlib `logging` (imported at module level); `warning` is the
            # non-deprecated spelling of `warn`.
            logging.warning("%s: Using a concatenated state is slower and will "
                            "soon be deprecated. Use state_is_tuple=True.",
                            self)
        self._num_units = num_units
        # set padding part to 0
        self._encoder_vector = self._reset_padding(encoder_vector,
                                                   source_sequence_length)
        self._encoder_proj = self._reset_padding(encoder_proj,
                                                 source_sequence_length)
        self._forget_bias = forget_bias
        self._state_is_tuple = state_is_tuple
        self._activation = activation or math_ops.tanh
        # Gate projection, created lazily on the first `call` because its
        # input widths are only known then.
        self._linear = None

    @property
    def state_size(self):
        """`LSTMStateTuple(num_units, num_units)` or `2 * num_units`."""
        if self._state_is_tuple:
            return LSTMStateTuple(self._num_units, self._num_units)
        return 2 * self._num_units

    @property
    def output_size(self):
        """The cell emits its hidden state, so this is `num_units`."""
        return self._num_units

    def zero_state(self, batch_size, dtype):
        """Return an all-zero state of `state_size` for `batch_size`.

        The last result is cached on the instance and returned again when
        `(state_size, batch_size, dtype)` are unchanged.
        """
        state_size = self.state_size
        if hasattr(self, "_last_zero_state"):
            (last_state_size, last_batch_size, last_dtype,
             last_output) = getattr(self, "_last_zero_state")
            if (last_batch_size == batch_size and last_dtype == dtype and
                    last_state_size == state_size):
                return last_output
        with ops.name_scope(
                type(self).__name__ + "ZeroState", values=[batch_size]):
            # `_zero_state_tensors` is a private helper of `rnn_cell_impl`; it
            # is not a name in this module's global scope.
            output = rnn_cell_impl._zero_state_tensors(state_size, batch_size,
                                                       dtype)
        self._last_zero_state = (state_size, batch_size, dtype, output)
        return output

    def call(self, inputs, state):
        """Run one decoder step.

        Attend with the previous `h`, then update the LSTM.

        Returns:
            `(new_h, new_state)`. `new_state` has the same structure as
            `state` (tuple or concatenated).
        """
        sigmoid = math_ops.sigmoid
        # Parameters of gates are concatenated into one multiply for efficiency.
        if self._state_is_tuple:
            c, h = state
        else:
            c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

        # get context from encoder outputs
        context = self._simple_attention(self._encoder_vector,
                                         self._encoder_proj, h)

        if self._linear is None:
            self._linear = _Linear([inputs, context, h], 4 * self._num_units,
                                   True)
        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = array_ops.split(
            value=self._linear([inputs, context, h]),
            num_or_size_splits=4,
            axis=1)

        new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
                 self._activation(j))
        new_h = self._activation(new_c) * sigmoid(o)

        if self._state_is_tuple:
            new_state = LSTMStateTuple(new_c, new_h)
        else:
            new_state = array_ops.concat([new_c, new_h], 1)
        return new_h, new_state

    def _simple_attention(self, encoder_vec, encoder_proj, decoder_state):
        """Implement the attention function.

        The implementation has the same logic to the fluid decoder:
        1. Project `decoder_state` to `num_units` (no bias).
        2. Score each source step as `tanh(w . [proj_h ; encoder_proj_t])`
           (no bias).
        3. Softmax the scores over the source-time axis.
        4. Return `sum_t weight_t * encoder_vec_t`.

        NOTE(review): padded steps are not excluded from the softmax. Their
        `encoder_vec` rows are zero (see `_reset_padding`), but they still
        take probability mass.
        """
        decoder_state_proj = tf.contrib.layers.fully_connected(
            inputs=decoder_state,
            num_outputs=self._num_units,
            activation_fn=None,
            biases_initializer=None)
        # Repeat the projected decoder state once per source time step.
        decoder_state_expand = tf.tile(
            tf.expand_dims(
                input=decoder_state_proj, axis=1),
            [1, tf.shape(encoder_proj)[1], 1])
        concated = tf.concat([decoder_state_expand, encoder_proj], axis=2)
        # need reduce the first dimension
        attention_weights = tf.contrib.layers.fully_connected(
            inputs=tf.reshape(
                concated, shape=[-1, self._num_units * 2]),
            num_outputs=1,
            activation_fn=tf.nn.tanh,
            biases_initializer=None)
        attention_weights_reshaped = tf.reshape(
            attention_weights, shape=[tf.shape(encoder_vec)[0], -1, 1])
        # normalize the attention weights using softmax
        attention_weights_normed = tf.nn.softmax(
            attention_weights_reshaped, dim=1)
        scaled = tf.multiply(attention_weights_normed, encoder_vec)
        context = tf.reduce_sum(scaled, axis=1)
        return context

    def _reset_padding(self,
                       memory,
                       memory_sequence_length,
                       check_inner_dims_defined=True):
        """Reset the padding part for encoder inputs.

        This function comes from tensorflow's `_prepare_memory` function.
        Every tensor in `memory` is multiplied by a `sequence_mask` built
        from `memory_sequence_length` over axis 1. Inside the same
        control-dependency block, the mask's batch size is asserted to equal
        the memory batch size. With `memory_sequence_length=None`, `memory`
        is returned unmasked.

        Raises:
            ValueError: if `check_inner_dims_defined` and a memory tensor has
                undefined dimensions beyond axis 1.
        """
        memory = nest.map_structure(
            lambda m: ops.convert_to_tensor(m, name="memory"), memory)
        if memory_sequence_length is not None:
            memory_sequence_length = ops.convert_to_tensor(
                memory_sequence_length, name="memory_sequence_length")
        if check_inner_dims_defined:

            def _check_dims(m):
                if not m.get_shape()[2:].is_fully_defined():
                    raise ValueError(
                        "Expected memory %s to have fully defined inner dims, "
                        "but saw shape: %s" % (m.name, m.get_shape()))

            nest.map_structure(_check_dims, memory)
        if memory_sequence_length is None:
            seq_len_mask = None
        else:
            seq_len_mask = array_ops.sequence_mask(
                memory_sequence_length,
                maxlen=array_ops.shape(nest.flatten(memory)[0])[1],
                dtype=nest.flatten(memory)[0].dtype)
            seq_len_batch_size = (memory_sequence_length.shape[0].value or
                                  array_ops.shape(memory_sequence_length)[0])

        def _maybe_mask(m, seq_len_mask):
            rank = m.get_shape().ndims
            rank = rank if rank is not None else array_ops.rank(m)
            # Trailing singleton axes so the (batch, time) mask broadcasts
            # over the inner dimensions of `m`.
            extra_ones = array_ops.ones(rank - 2, dtype=dtypes.int32)
            m_batch_size = m.shape[0].value or array_ops.shape(m)[0]
            if memory_sequence_length is not None:
                message = ("memory_sequence_length and memory tensor "
                           "batch sizes do not match.")
                with ops.control_dependencies([
                        check_ops.assert_equal(
                            seq_len_batch_size, m_batch_size, message=message)
                ]):
                    seq_len_mask = array_ops.reshape(
                        seq_len_mask,
                        array_ops.concat(
                            (array_ops.shape(seq_len_mask), extra_ones), 0))
                    return m * seq_len_mask
            else:
                return m

        return nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask),
                                  memory)
def seq_to_seq_net(embedding_dim, encoder_size, decoder_size, source_dict_dim,
                   target_dict_dim, is_generating, beam_size,
                   max_generation_length):
    """Build the attention seq2seq graph for training or beam-search decoding.

    Encoder and decoder:
    - The encoder is a bidirectional `BasicLSTMCell` RNN (`encoder_size`
      units per direction) over the source embeddings.
    - The decoder is an `LSTMCellWithSimpleAttention` of `decoder_size`
      units.
    - The decoder's initial hidden state is a tanh projection of the
      backward encoder output at source position 0. Its initial cell state
      is zero.

    Args:
        embedding_dim: Width of the source and target word embeddings.
        encoder_size: Units of each encoder direction.
        decoder_size: Units of the decoder LSTM.
        source_dict_dim: Source vocabulary size.
        target_dict_dim: Target vocabulary size.
        is_generating: False builds the training graph; True builds beam
            search.
        beam_size: Beam width (only used when `is_generating`).
        max_generation_length: Maximum beam-search decoding steps.

    Returns:
        Training: `(feeds, loss)`. `feeds` maps `src_word_idx`,
        `src_sequence_length`, `trg_word_idx`, `trg_sequence_length` and
        `lbl_word_idx` to their placeholders. `loss` is the
        sequence-mask-weighted cross entropy, averaged over time steps and
        batch.
        Generating: `(feeds, predicted_ids)`. `feeds` holds only the two
        source placeholders.
    """
    src_word_idx = tf.placeholder(tf.int32, shape=[None, None])
    src_sequence_length = tf.placeholder(tf.int32, shape=[None, ])

    src_embedding_weights = tf.get_variable("source_word_embeddings",
                                            [source_dict_dim, embedding_dim])
    src_embedding = tf.nn.embedding_lookup(src_embedding_weights, src_word_idx)

    src_forward_cell = tf.nn.rnn_cell.BasicLSTMCell(encoder_size)
    src_reversed_cell = tf.nn.rnn_cell.BasicLSTMCell(encoder_size)
    # no peephole
    encoder_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
        cell_fw=src_forward_cell,
        cell_bw=src_reversed_cell,
        inputs=src_embedding,
        sequence_length=src_sequence_length,
        dtype=tf.float32)

    # concat the forward outputs and backward outputs: both cells have
    # `encoder_size` units, so the last axis holds 2 * encoder_size features.
    encoded_vec = tf.concat(encoder_outputs, axis=2)

    # project the encoder outputs to size of decoder lstm. Flatten with the
    # encoder width, not `embedding_dim`; the two differ whenever
    # --embedding_dim != --encoder_size.
    encoded_proj = tf.contrib.layers.fully_connected(
        inputs=tf.reshape(
            encoded_vec, shape=[-1, encoder_size * 2]),
        num_outputs=decoder_size,
        activation_fn=None,
        biases_initializer=None)
    encoded_proj_reshape = tf.reshape(
        encoded_proj, shape=[-1, tf.shape(encoded_vec)[1], decoder_size])

    # get init state for decoder lstm's H: project the backward encoder
    # output at source position 0, which is `encoder_size` wide.
    backward_first = tf.slice(encoder_outputs[1], [0, 0, 0], [-1, 1, -1])
    decoder_boot = tf.contrib.layers.fully_connected(
        inputs=tf.reshape(
            backward_first, shape=[-1, encoder_size]),
        num_outputs=decoder_size,
        activation_fn=tf.nn.tanh,
        biases_initializer=None)

    # prepare the initial state for decoder lstm (zero cell state C)
    cell_init = tf.zeros(tf.shape(decoder_boot), tf.float32)
    initial_state = LSTMStateTuple(cell_init, decoder_boot)

    # Beam search runs `beam_size` hypotheses per source sentence, so the
    # per-example tensors read by the attention cell are tiled on the batch
    # axis to match.
    if is_generating:
        attention_vec = seq2seq.tile_batch(encoded_vec, beam_size)
        attention_proj = seq2seq.tile_batch(encoded_proj_reshape, beam_size)
        attention_lengths = seq2seq.tile_batch(src_sequence_length, beam_size)
    else:
        attention_vec = encoded_vec
        attention_proj = encoded_proj_reshape
        attention_lengths = src_sequence_length

    # create decoder lstm cell
    decoder_cell = LSTMCellWithSimpleAttention(
        decoder_size,
        attention_vec,
        attention_proj,
        attention_lengths,
        forget_bias=0.0)

    output_layer = Dense(target_dict_dim, name='output_projection')

    if not is_generating:
        trg_word_idx = tf.placeholder(tf.int32, shape=[None, None])
        trg_sequence_length = tf.placeholder(tf.int32, shape=[None, ])
        trg_embedding_weights = tf.get_variable(
            "target_word_embeddings", [target_dict_dim, embedding_dim])
        trg_embedding = tf.nn.embedding_lookup(trg_embedding_weights,
                                               trg_word_idx)

        training_helper = seq2seq.TrainingHelper(
            inputs=trg_embedding,
            sequence_length=trg_sequence_length,
            time_major=False,
            name='training_helper')

        training_decoder = seq2seq.BasicDecoder(
            cell=decoder_cell,
            helper=training_helper,
            initial_state=initial_state,
            output_layer=output_layer)

        # get the max length of target sequence
        max_decoder_length = tf.reduce_max(trg_sequence_length)

        decoder_outputs_train, _, _ = seq2seq.dynamic_decode(
            decoder=training_decoder,
            output_time_major=False,
            impute_finished=True,
            maximum_iterations=max_decoder_length)

        decoder_logits_train = tf.identity(decoder_outputs_train.rnn_output)
        # Named prediction op; not fetched anywhere in this function.
        decoder_pred_train = tf.argmax(
            decoder_logits_train, axis=-1, name='decoder_pred_train')
        masks = tf.sequence_mask(
            lengths=trg_sequence_length,
            maxlen=max_decoder_length,
            dtype=tf.float32,
            name='masks')

        # place holder of label sequence
        lbl_word_idx = tf.placeholder(tf.int32, shape=[None, None])

        # compute the loss
        loss = seq2seq.sequence_loss(
            logits=decoder_logits_train,
            targets=lbl_word_idx,
            weights=masks,
            average_across_timesteps=True,
            average_across_batch=True)

        # return feeding list and loss operator
        return {
            'src_word_idx': src_word_idx,
            'src_sequence_length': src_sequence_length,
            'trg_word_idx': trg_word_idx,
            'trg_sequence_length': trg_sequence_length,
            'lbl_word_idx': lbl_word_idx
        }, loss
    else:
        start_tokens = tf.ones([tf.shape(src_word_idx)[0], ],
                               tf.int32) * START_TOKEN_IDX
        # share the same embedding weights with target word
        trg_embedding_weights = tf.get_variable(
            "target_word_embeddings", [target_dict_dim, embedding_dim])

        inference_decoder = beam_search_decoder.BeamSearchDecoder(
            cell=decoder_cell,
            embedding=lambda tokens: tf.nn.embedding_lookup(
                trg_embedding_weights, tokens),
            start_tokens=start_tokens,
            end_token=END_TOKEN_IDX,
            initial_state=tf.nn.rnn_cell.LSTMStateTuple(
                tf.contrib.seq2seq.tile_batch(initial_state[0], beam_size),
                tf.contrib.seq2seq.tile_batch(initial_state[1], beam_size)),
            beam_width=beam_size,
            output_layer=output_layer)

        # NOTE: impute_finished=True is left at its default here; passing it
        # with this beam-search decoder raised an error.
        decoder_outputs_decode, _, _ = seq2seq.dynamic_decode(
            decoder=inference_decoder,
            output_time_major=False,
            maximum_iterations=max_generation_length)

        predicted_ids = decoder_outputs_decode.predicted_ids

        return {
            'src_word_idx': src_word_idx,
            'src_sequence_length': src_sequence_length
        }, predicted_ids
def print_arguments(args):
    """Print every parsed command-line argument as a `name: value` line.

    The arguments are framed by dashed banner lines and printed in sorted
    name order, so the output is stable from run to run. `dict.items()` is
    used because `iteritems` only exists on Python 2.
    """
    print('----------- Configuration Arguments -----------')
    for arg, value in sorted(vars(args).items()):
        print('%s: %s' % (arg, value))
    print('------------------------------------------------')
def padding_data(data, padding_size, value):
    """Return `data` right-padded with `value` (or cut) to `padding_size` items.

    Sequences longer than `padding_size` are truncated. Shorter ones get
    just enough copies of `value` appended to reach the target length.
    """
    shortfall = max(padding_size - len(data), 0)
    return (data + [value] * shortfall)[:padding_size]
def save(sess, path, var_list=None, global_step=None):
    """Checkpoint the variables in `var_list` from `sess` to `path`.

    `var_list=None` lets `tf.train.Saver` pick all saveable variables.
    `global_step` is forwarded to `Saver.save`. The path returned by
    `Saver.save` is printed.
    """
    written_path = tf.train.Saver(var_list).save(
        sess, save_path=path, global_step=global_step)
    print('Model save at %s' % written_path)
def restore(sess, path, var_list=None):
    """Load the variables in `var_list` into `sess` from checkpoint `path`.

    With `var_list=None` the saver covers all saveable variables of the
    current graph.
    """
    tf.train.Saver(var_list).restore(sess, save_path=path)
    print('model restored from %s' % path)
def adapt_batch_data(data):
    """Turn one mini-batch of `(src, trg, lbl)` id sequences into feed arrays.

    Args:
        data: Iterable of samples whose items 0, 1 and 2 are the source,
            target-input and label id sequences.

    Returns:
        dict with int32 numpy arrays keyed like the feeds of
        `seq_to_seq_net`:
        - `src_word_idx`, `trg_word_idx` and `lbl_word_idx` are 2-D and
          padded with `END_TOKEN_IDX` to the batch's maximum length. Labels
          are padded to the target maximum.
        - `src_sequence_length` and `trg_sequence_length` hold the unpadded
          lengths.
    """
    # Lists, not `map(...)`: each sequence list is traversed more than once,
    # and on Python 3 `map` is a one-shot iterator that the length pass would
    # exhaust, leaving the padded arrays empty.
    src_seq = [sample[0] for sample in data]
    trg_seq = [sample[1] for sample in data]
    lbl_seq = [sample[2] for sample in data]

    src_sequence_length = np.array(
        [len(seq) for seq in src_seq]).astype('int32')
    src_seq_maxlen = np.max(src_sequence_length)
    trg_sequence_length = np.array(
        [len(seq) for seq in trg_seq]).astype('int32')
    trg_seq_maxlen = np.max(trg_sequence_length)

    def _pad(seqs, maxlen):
        """Pad/cut every sequence to `maxlen` and stack into an int32 array."""
        return np.array([
            padding_data(seq, maxlen, END_TOKEN_IDX) for seq in seqs
        ]).astype('int32')

    return {
        'src_word_idx': _pad(src_seq, src_seq_maxlen),
        'src_sequence_length': src_sequence_length,
        'trg_word_idx': _pad(trg_seq, trg_seq_maxlen),
        'trg_sequence_length': trg_sequence_length,
        # labels align with the target-input steps, so share its max length
        'lbl_word_idx': _pad(lbl_seq, trg_seq_maxlen)
    }
def train(clip_norm=None):
    """Train the seq2seq model on WMT14 and report per-pass throughput.

    Hyper-parameters are read from the module-level `args`. Per batch, the
    training loss is printed. After each pass, the mean test loss, words/s
    and seconds/pass are printed. Words counted are the source plus target
    lengths.

    When `args.save_freq > 0`, a checkpoint with prefix
    `<args.model_dir>/tf_seq2seq` and the global step as suffix is written
    every `args.save_freq` batches. Time spent writing checkpoints is
    excluded from the reported timing.

    Args:
        clip_norm: If not None, gradients are clipped to this global norm
            before being applied. The default applies unclipped gradients.
    """
    feeding_dict, loss = seq_to_seq_net(
        embedding_dim=args.embedding_dim,
        encoder_size=args.encoder_size,
        decoder_size=args.decoder_size,
        source_dict_dim=args.dict_size,
        target_dict_dim=args.dict_size,
        is_generating=False,
        beam_size=args.beam_size,
        max_generation_length=args.max_generation_length)

    global_step = tf.Variable(0, trainable=False, name='global_step')
    trainable_params = tf.trainable_variables()
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)

    gradients = tf.gradients(loss, trainable_params)
    # may clip the parameters (opt-in)
    if clip_norm is not None:
        gradients, _ = tf.clip_by_global_norm(gradients, clip_norm)
    updates = optimizer.apply_gradients(
        zip(gradients, trainable_params), global_step=global_step)

    # Build one Saver up front, after all variables (incl. optimizer slots)
    # exist. This avoids adding a fresh set of Saver ops per checkpoint.
    save_enabled = args.save_freq > 0
    saver = tf.train.Saver()
    checkpoint_prefix = os.path.join(args.model_dir, 'tf_seq2seq')
    if save_enabled and not os.path.isdir(args.model_dir):
        os.makedirs(args.model_dir)

    train_batch_generator = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.wmt14.train(args.dict_size), buf_size=1000),
        batch_size=args.batch_size)

    test_batch_generator = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.wmt14.test(args.dict_size), buf_size=1000),
        batch_size=args.batch_size)

    def make_feed(batch):
        """Map every placeholder in `feeding_dict` to its array in `batch`."""
        return {
            placeholder: batch[name]
            for name, placeholder in feeding_dict.items()
        }

    def do_validation(sess):
        """Return the mean per-batch test loss, or NaN for an empty test set."""
        total_loss = 0.0
        count = 0
        for data in test_batch_generator():
            batch_loss, = sess.run([loss],
                                   feed_dict=make_feed(adapt_batch_data(data)))
            total_loss += batch_loss
            count += 1
        return total_loss / count if count else float('nan')

    config = tf.ConfigProto(
        intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        init_g = tf.global_variables_initializer()
        init_l = tf.local_variables_initializer()
        sess.run(init_l)
        sess.run(init_g)

        step = 0
        for pass_id in range(args.pass_num):
            pass_start_time = time.time()
            save_seconds = 0.0
            words_seen = 0
            for batch_id, data in enumerate(train_batch_generator()):
                batch = adapt_batch_data(data)
                words_seen += np.sum(batch['src_sequence_length'])
                words_seen += np.sum(batch['trg_sequence_length'])
                _, train_loss = sess.run([updates, loss],
                                         feed_dict=make_feed(batch))
                print("pass_id=%d, batch_id=%d, train_loss: %f" %
                      (pass_id, batch_id, train_loss))

                step += 1
                if save_enabled and step % args.save_freq == 0:
                    save_start = time.time()
                    save_path = saver.save(
                        sess,
                        save_path=checkpoint_prefix,
                        global_step=global_step)
                    print('Model save at %s' % save_path)
                    save_seconds += time.time() - save_start

            pass_end_time = time.time()
            test_loss = do_validation(sess)
            time_consumed = pass_end_time - pass_start_time - save_seconds
            words_per_sec = words_seen / time_consumed
            print("pass_id=%d, test_loss: %f, words/s: %f, sec/pass: %f" %
                  (pass_id, test_loss, words_per_sec, time_consumed))
def infer():
    """Restore a trained checkpoint and print beam-search translations.

    For each batch of the WMT14 test split, prints the batch's first source
    sentence (with its leading `<s>` / trailing `<e>` tokens removed) and
    the top beam hypothesis decoded for it. Hyper-parameters come from the
    module-level `args`. The checkpoint restored is
    `<args.model_dir>/tf_seq2seq-1500`.
    """
    feeding_dict, predicted_ids = seq_to_seq_net(
        embedding_dim=args.embedding_dim,
        encoder_size=args.encoder_size,
        decoder_size=args.decoder_size,
        source_dict_dim=args.dict_size,
        target_dict_dim=args.dict_size,
        is_generating=True,
        beam_size=args.beam_size,
        max_generation_length=args.max_generation_length)

    src_dict, trg_dict = paddle.dataset.wmt14.get_dict(args.dict_size)
    test_batch_generator = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.wmt14.test(args.dict_size), buf_size=1000),
        batch_size=args.batch_size)

    config = tf.ConfigProto(
        intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        restore(sess, os.path.join(args.model_dir, 'tf_seq2seq-1500'))
        for data in test_batch_generator():
            # A list, not `map(...)`: it is traversed several times below.
            src_seq = [sample[0] for sample in data]
            src_sequence_length = np.array(
                [len(seq) for seq in src_seq]).astype('int32')
            src_seq_maxlen = np.max(src_sequence_length)
            padded_src = np.array([
                padding_data(seq, src_seq_maxlen, END_TOKEN_IDX)
                for seq in src_seq
            ]).astype('int32')

            outputs = sess.run([predicted_ids],
                               feed_dict={
                                   feeding_dict['src_word_idx']: padded_src,
                                   feeding_dict['src_sequence_length']:
                                   src_sequence_length
                               })

            print("\nDecoder result comparison: ")
            # Only sample 0 is decoded below, so show only its source.
            # Remove the boundary tokens as whole tokens
            # (`str.lstrip('<s>')` would strip a character set instead).
            source_words = [src_dict[item] for item in src_seq[0]]
            if source_words and source_words[0] == '<s>':
                source_words = source_words[1:]
            if source_words and source_words[-1] == '<e>':
                source_words = source_words[:-1]
            print(" --> source: " + ' '.join(source_words))

            inference_words = []
            # presumably one row per decoding step, one column per beam;
            # column 0 is taken as the best hypothesis — confirm.
            for step_ids in outputs[0][0]:
                if step_ids[0] == END_TOKEN_IDX:
                    break
                inference_words.append(trg_dict.get(step_ids[0], '<unk>'))
            print(" --> inference: " + ' '.join(inference_words))
if __name__ == '__main__':
    # `args` is a module-level global; train()/infer() read it directly.
    args = parser.parse_args()
    print_arguments(args)
    (infer if args.infer_only else train)()