# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.python.framework import dtypes
from tensorflow.python.layers.core import Dense
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops.rnn_cell_impl import RNNCell, BasicLSTMCell
from tensorflow.python.ops.rnn_cell_impl import LSTMStateTuple
from tensorflow.contrib.rnn.python.ops import core_rnn_cell
from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest
import tensorflow.contrib.seq2seq as seq2seq
from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder
import numpy as np
import logging
import os
import argparse
import time

import paddle.v2 as paddle

parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
    "--embedding_dim",
    type=int,
    default=512,
    help="The dimension of the embedding table. (default: %(default)d)")
parser.add_argument(
    "--encoder_size",
    type=int,
    default=512,
    help="The size of the encoder bi-rnn unit. (default: %(default)d)")
parser.add_argument(
    "--decoder_size",
    type=int,
    default=512,
    help="The size of the decoder rnn unit. (default: %(default)d)")
parser.add_argument(
    "--batch_size",
    type=int,
    default=128,
    help="The number of sequences in a mini-batch. (default: %(default)d)")
parser.add_argument(
    "--dict_size",
    type=int,
    default=30000,
    help="The dictionary capacity. The dictionaries of the source and "
    "target sequences have the same capacity. (default: %(default)d)")
parser.add_argument(
    "--max_time_steps",
    type=int,
    default=81,
    help="Max number of time steps for a sequence. (default: %(default)d)")
parser.add_argument(
    "--pass_num",
    type=int,
    default=10,
    help="The number of passes to train. (default: %(default)d)")
parser.add_argument(
    "--learning_rate",
    type=float,
    default=0.0002,
    help="Learning rate used to train the model. (default: %(default)f)")
parser.add_argument(
    "--infer_only", action='store_true', help="If set, run forward only.")
parser.add_argument(
    "--beam_size",
    type=int,
    default=3,
    help="The width for beam search. (default: %(default)d)")
parser.add_argument(
    "--max_generation_length",
    type=int,
    default=250,
    help="The maximum length of a sequence when doing generation. "
    "(default: %(default)d)")
parser.add_argument(
    "--save_freq",
    type=int,
    default=500,
    help="Save a model checkpoint every this many iterations. "
    "(default: %(default)d)")
parser.add_argument(
    "--model_dir",
    type=str,
    default='./checkpoint',
    help="Path to save model checkpoints. (default: %(default)s)")
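
# Example invocations (the script file name here is a placeholder and the flag
# values are illustrative; any flag may be omitted to use its default):
#   python nmt_seq2seq_tf.py --batch_size 64 --pass_num 2
#   python nmt_seq2seq_tf.py --infer_only --beam_size 5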

_Linear = core_rnn_cell._Linear  # pylint: disable=invalid-name
# helper used by `zero_state` below
_zero_state_tensors = rnn_cell_impl._zero_state_tensors  # pylint: disable=invalid-name

START_TOKEN_IDX = 0
END_TOKEN_IDX = 1


class LSTMCellWithSimpleAttention(RNNCell):
    """Add an attention mechanism to BasicLSTMCell.

    This class is a wrapper around TensorFlow's `BasicLSTMCell`.
    """

    def __init__(self,
                 num_units,
                 encoder_vector,
                 encoder_proj,
                 source_sequence_length,
                 forget_bias=1.0,
                 state_is_tuple=True,
                 activation=None,
                 reuse=None):
        super(LSTMCellWithSimpleAttention, self).__init__(_reuse=reuse)
        if not state_is_tuple:
            logging.warn("%s: Using a concatenated state is slower and will "
                         "soon be deprecated. Use state_is_tuple=True.", self)
        self._num_units = num_units
        # set padding part to 0
        self._encoder_vector = self._reset_padding(encoder_vector,
                                                   source_sequence_length)
        self._encoder_proj = self._reset_padding(encoder_proj,
                                                 source_sequence_length)
        self._forget_bias = forget_bias
        self._state_is_tuple = state_is_tuple
        self._activation = activation or math_ops.tanh
        self._linear = None

    @property
    def state_size(self):
        return (LSTMStateTuple(self._num_units, self._num_units)
                if self._state_is_tuple else 2 * self._num_units)

    @property
    def output_size(self):
        return self._num_units
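
    # zero_state caches the zero state it built most recently so repeated
    # calls with the same batch size and dtype reuse the same tensors instead
    # of adding new ops to the graph.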
    def zero_state(self, batch_size, dtype):
        state_size = self.state_size
        if hasattr(self, "_last_zero_state"):
            (last_state_size, last_batch_size, last_dtype,
             last_output) = getattr(self, "_last_zero_state")
            if (last_batch_size == batch_size and last_dtype == dtype and
                    last_state_size == state_size):
                return last_output
        with ops.name_scope(
                type(self).__name__ + "ZeroState", values=[batch_size]):
            output = _zero_state_tensors(state_size, batch_size, dtype)
        self._last_zero_state = (state_size, batch_size, dtype, output)
        return output

    def call(self, inputs, state):
        sigmoid = math_ops.sigmoid
        # Parameters of gates are concatenated into one multiply for efficiency.
        if self._state_is_tuple:
            c, h = state
        else:
            c, h = array_ops.split(value=state, num_or_size_splits=2, axis=1)

        # get context from encoder outputs
        context = self._simple_attention(self._encoder_vector,
                                         self._encoder_proj, h)

        if self._linear is None:
            self._linear = _Linear([inputs, context, h], 4 * self._num_units,
                                   True)
        # i = input_gate, j = new_input, f = forget_gate, o = output_gate
        i, j, f, o = array_ops.split(
            value=self._linear([inputs, context, h]),
            num_or_size_splits=4,
            axis=1)

        new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
                 self._activation(j))
        new_h = self._activation(new_c) * sigmoid(o)

        if self._state_is_tuple:
            new_state = LSTMStateTuple(new_c, new_h)
        else:
            new_state = array_ops.concat([new_c, new_h], 1)
        return new_h, new_state
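
    # The helper below scores every encoder position against the current
    # decoder hidden state and returns a weighted sum of the encoder outputs
    # (the attention "context" vector).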
    def _simple_attention(self, encoder_vec, encoder_proj, decoder_state):
        """Implement the attention function.

        The implementation has the same logic as the Fluid decoder.
        """
        decoder_state_proj = tf.contrib.layers.fully_connected(
            inputs=decoder_state,
            num_outputs=self._num_units,
            activation_fn=None,
            biases_initializer=None)
        decoder_state_expand = tf.tile(
            tf.expand_dims(
                input=decoder_state_proj, axis=1),
            [1, tf.shape(encoder_proj)[1], 1])
        concated = tf.concat([decoder_state_expand, encoder_proj], axis=2)
        # flatten the (batch, time) dimensions before the scoring layer
        attention_weights = tf.contrib.layers.fully_connected(
            inputs=tf.reshape(
                concated, shape=[-1, self._num_units * 2]),
            num_outputs=1,
            activation_fn=tf.nn.tanh,
            biases_initializer=None)
        attention_weights_reshaped = tf.reshape(
            attention_weights, shape=[tf.shape(encoder_vec)[0], -1, 1])
        # normalize the attention weights using softmax
        attention_weights_normed = tf.nn.softmax(
            attention_weights_reshaped, dim=1)
        scaled = tf.multiply(attention_weights_normed, encoder_vec)
        context = tf.reduce_sum(scaled, axis=1)
        return context
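
    # The helper below zeroes out the positions of `memory` that lie beyond
    # each sequence's true length, so padded time steps contribute nothing to
    # the attention context.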
    def _reset_padding(self,
                       memory,
                       memory_sequence_length,
                       check_inner_dims_defined=True):
        """Reset the padding part for encoder inputs.

        This function is adapted from TensorFlow's `_prepare_memory` function.
        """
        memory = nest.map_structure(
            lambda m: ops.convert_to_tensor(m, name="memory"), memory)
        if memory_sequence_length is not None:
            memory_sequence_length = ops.convert_to_tensor(
                memory_sequence_length, name="memory_sequence_length")
        if check_inner_dims_defined:

            def _check_dims(m):
                if not m.get_shape()[2:].is_fully_defined():
                    raise ValueError(
                        "Expected memory %s to have fully defined inner dims, "
                        "but saw shape: %s" % (m.name, m.get_shape()))

            nest.map_structure(_check_dims, memory)
        if memory_sequence_length is None:
            seq_len_mask = None
        else:
            seq_len_mask = array_ops.sequence_mask(
                memory_sequence_length,
                maxlen=array_ops.shape(nest.flatten(memory)[0])[1],
                dtype=nest.flatten(memory)[0].dtype)
            seq_len_batch_size = (memory_sequence_length.shape[0].value or
                                  array_ops.shape(memory_sequence_length)[0])

        def _maybe_mask(m, seq_len_mask):
            rank = m.get_shape().ndims
            rank = rank if rank is not None else array_ops.rank(m)
            extra_ones = array_ops.ones(rank - 2, dtype=dtypes.int32)
            m_batch_size = m.shape[0].value or array_ops.shape(m)[0]
            if memory_sequence_length is not None:
                message = ("memory_sequence_length and memory tensor "
                           "batch sizes do not match.")
                with ops.control_dependencies([
                        check_ops.assert_equal(
                            seq_len_batch_size, m_batch_size, message=message)
                ]):
                    seq_len_mask = array_ops.reshape(
                        seq_len_mask,
                        array_ops.concat(
                            (array_ops.shape(seq_len_mask), extra_ones), 0))
                    return m * seq_len_mask
            else:
                return m

        return nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask),
                                  memory)
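

# The network below is a bi-directional LSTM encoder followed by an
# attention-equipped LSTM decoder. In training mode it returns the feed
# placeholders and the sequence loss; in generation mode it returns the feed
# placeholders and the beam-search predictions.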
def seq_to_seq_net(embedding_dim, encoder_size, decoder_size, source_dict_dim,
                   target_dict_dim, is_generating, beam_size,
                   max_generation_length):
    src_word_idx = tf.placeholder(tf.int32, shape=[None, None])
    src_sequence_length = tf.placeholder(tf.int32, shape=[None, ])

    src_embedding_weights = tf.get_variable("source_word_embeddings",
                                            [source_dict_dim, embedding_dim])
    src_embedding = tf.nn.embedding_lookup(src_embedding_weights, src_word_idx)

    src_forward_cell = tf.nn.rnn_cell.BasicLSTMCell(encoder_size)
    src_reversed_cell = tf.nn.rnn_cell.BasicLSTMCell(encoder_size)
    # no peephole
    encoder_outputs, _ = tf.nn.bidirectional_dynamic_rnn(
        cell_fw=src_forward_cell,
        cell_bw=src_reversed_cell,
        inputs=src_embedding,
        sequence_length=src_sequence_length,
        dtype=tf.float32)

    # concat the forward outputs and backward outputs
    encoded_vec = tf.concat(encoder_outputs, axis=2)

    # project the encoder outputs to the size of the decoder lstm;
    # each direction of the bi-rnn emits `encoder_size` units
    encoded_proj = tf.contrib.layers.fully_connected(
        inputs=tf.reshape(
            encoded_vec, shape=[-1, encoder_size * 2]),
        num_outputs=decoder_size,
        activation_fn=None,
        biases_initializer=None)
    encoded_proj_reshape = tf.reshape(
        encoded_proj, shape=[-1, tf.shape(encoded_vec)[1], decoder_size])

    # get the init state for the decoder lstm's H from the first step of the
    # backward encoder output
    backward_first = tf.slice(encoder_outputs[1], [0, 0, 0], [-1, 1, -1])
    decoder_boot = tf.contrib.layers.fully_connected(
        inputs=tf.reshape(
            backward_first, shape=[-1, encoder_size]),
        num_outputs=decoder_size,
        activation_fn=tf.nn.tanh,
        biases_initializer=None)

    # prepare the initial state for the decoder lstm
    cell_init = tf.zeros(tf.shape(decoder_boot), tf.float32)
    initial_state = LSTMStateTuple(cell_init, decoder_boot)

    # create decoder lstm cell
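    # When generating, the encoder tensors are tiled `beam_size` times with
    # seq2seq.tile_batch so every beam hypothesis attends over its own copy
    # of the encoder outputs.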
    decoder_cell = LSTMCellWithSimpleAttention(
        decoder_size,
        encoded_vec
        if not is_generating else seq2seq.tile_batch(encoded_vec, beam_size),
        encoded_proj_reshape if not is_generating else
        seq2seq.tile_batch(encoded_proj_reshape, beam_size),
        src_sequence_length if not is_generating else
        seq2seq.tile_batch(src_sequence_length, beam_size),
        forget_bias=0.0)

    output_layer = Dense(target_dict_dim, name='output_projection')

    if not is_generating:
        trg_word_idx = tf.placeholder(tf.int32, shape=[None, None])
        trg_sequence_length = tf.placeholder(tf.int32, shape=[None, ])
        trg_embedding_weights = tf.get_variable(
            "target_word_embeddings", [target_dict_dim, embedding_dim])
        trg_embedding = tf.nn.embedding_lookup(trg_embedding_weights,
                                               trg_word_idx)
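
        # TrainingHelper implements teacher forcing: at every step the decoder
        # is fed the ground-truth target embedding instead of its own previous
        # prediction.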
        training_helper = seq2seq.TrainingHelper(
            inputs=trg_embedding,
            sequence_length=trg_sequence_length,
            time_major=False,
            name='training_helper')

        training_decoder = seq2seq.BasicDecoder(
            cell=decoder_cell,
            helper=training_helper,
            initial_state=initial_state,
            output_layer=output_layer)

        # get the max length of the target sequence
        max_decoder_length = tf.reduce_max(trg_sequence_length)

        decoder_outputs_train, _, _ = seq2seq.dynamic_decode(
            decoder=training_decoder,
            output_time_major=False,
            impute_finished=True,
            maximum_iterations=max_decoder_length)

        decoder_logits_train = tf.identity(decoder_outputs_train.rnn_output)
        decoder_pred_train = tf.argmax(
            decoder_logits_train, axis=-1, name='decoder_pred_train')
        masks = tf.sequence_mask(
            lengths=trg_sequence_length,
            maxlen=max_decoder_length,
            dtype=tf.float32,
            name='masks')

        # placeholder for the label sequence
        lbl_word_idx = tf.placeholder(tf.int32, shape=[None, None])

        # compute the loss
        loss = seq2seq.sequence_loss(
            logits=decoder_logits_train,
            targets=lbl_word_idx,
            weights=masks,
            average_across_timesteps=True,
            average_across_batch=True)

        # return the feeding list and the loss operator
        return {
            'src_word_idx': src_word_idx,
            'src_sequence_length': src_sequence_length,
            'trg_word_idx': trg_word_idx,
            'trg_sequence_length': trg_sequence_length,
            'lbl_word_idx': lbl_word_idx
        }, loss
    else:
        start_tokens = tf.ones([tf.shape(src_word_idx)[0], ],
                               tf.int32) * START_TOKEN_IDX
        # share the same embedding weights with the target words
        trg_embedding_weights = tf.get_variable(
            "target_word_embeddings", [target_dict_dim, embedding_dim])
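
        # BeamSearchDecoder expands the batch dimension to batch * beam_size,
        # so the decoder's initial LSTM state is tiled the same way as the
        # encoder tensors above.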
        inference_decoder = beam_search_decoder.BeamSearchDecoder(
            cell=decoder_cell,
            embedding=lambda tokens: tf.nn.embedding_lookup(
                trg_embedding_weights, tokens),
            start_tokens=start_tokens,
            end_token=END_TOKEN_IDX,
            initial_state=tf.nn.rnn_cell.LSTMStateTuple(
                tf.contrib.seq2seq.tile_batch(initial_state[0], beam_size),
                tf.contrib.seq2seq.tile_batch(initial_state[1], beam_size)),
            beam_width=beam_size,
            output_layer=output_layer)

        decoder_outputs_decode, _, _ = seq2seq.dynamic_decode(
            decoder=inference_decoder,
            output_time_major=False,
            # impute_finished=True raises an error with BeamSearchDecoder,
            # so it is left disabled here
            maximum_iterations=max_generation_length)

        predicted_ids = decoder_outputs_decode.predicted_ids

        return {
            'src_word_idx': src_word_idx,
            'src_sequence_length': src_sequence_length
        }, predicted_ids


def print_arguments(args):
    print('----------- Configuration Arguments -----------')
    for arg, value in vars(args).iteritems():
        print('%s: %s' % (arg, value))
    print('------------------------------------------------')
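

# Pad-then-truncate helper: append `padding_size` copies of `value`, then keep
# the first `padding_size` items, so a sequence is either padded or clipped to
# the requested length.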
def padding_data(data, padding_size, value):
    data = data + [value] * padding_size
    return data[:padding_size]


def save(sess, path, var_list=None, global_step=None):
    saver = tf.train.Saver(var_list)
    save_path = saver.save(sess, save_path=path, global_step=global_step)
    print('Model saved at %s' % save_path)


def restore(sess, path, var_list=None):
    # var_list = None returns the list of all saveable variables
    saver = tf.train.Saver(var_list)
    saver.restore(sess, save_path=path)
    print('Model restored from %s' % path)
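

# Turn a list of (source, target, label) id sequences into fixed-shape numpy
# arrays: record the true lengths, then pad every sequence in the batch to the
# batch maximum with END_TOKEN_IDX.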
def adapt_batch_data(data):
    src_seq = map(lambda x: x[0], data)
    trg_seq = map(lambda x: x[1], data)
    lbl_seq = map(lambda x: x[2], data)

    src_sequence_length = np.array(
        [len(seq) for seq in src_seq]).astype('int32')
    src_seq_maxlen = np.max(src_sequence_length)

    trg_sequence_length = np.array(
        [len(seq) for seq in trg_seq]).astype('int32')
    trg_seq_maxlen = np.max(trg_sequence_length)

    src_seq = np.array(
        [padding_data(seq, src_seq_maxlen, END_TOKEN_IDX)
         for seq in src_seq]).astype('int32')

    trg_seq = np.array(
        [padding_data(seq, trg_seq_maxlen, END_TOKEN_IDX)
         for seq in trg_seq]).astype('int32')

    lbl_seq = np.array(
        [padding_data(seq, trg_seq_maxlen, END_TOKEN_IDX)
         for seq in lbl_seq]).astype('int32')

    return {
        'src_word_idx': src_seq,
        'src_sequence_length': src_sequence_length,
        'trg_word_idx': trg_seq,
        'trg_sequence_length': trg_sequence_length,
        'lbl_word_idx': lbl_seq
    }
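

# Training driver: builds the training graph, streams WMT'14 batches through
# the PaddlePaddle reader utilities, and reports the per-batch training loss
# plus a per-pass validation loss and throughput (words/sec).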
def train():
    feeding_dict, loss = seq_to_seq_net(
        embedding_dim=args.embedding_dim,
        encoder_size=args.encoder_size,
        decoder_size=args.decoder_size,
        source_dict_dim=args.dict_size,
        target_dict_dim=args.dict_size,
        is_generating=False,
        beam_size=args.beam_size,
        max_generation_length=args.max_generation_length)

    global_step = tf.Variable(0, trainable=False, name='global_step')
    trainable_params = tf.trainable_variables()
    optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)

    gradients = tf.gradients(loss, trainable_params)
    # clip gradients by global norm before applying them
    clip_gradients, _ = tf.clip_by_global_norm(gradients, 1.0)

    updates = optimizer.apply_gradients(
        zip(clip_gradients, trainable_params), global_step=global_step)

    src_dict, trg_dict = paddle.dataset.wmt14.get_dict(args.dict_size)

    train_batch_generator = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.wmt14.train(args.dict_size), buf_size=1000),
        batch_size=args.batch_size)

    test_batch_generator = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.wmt14.test(args.dict_size), buf_size=1000),
        batch_size=args.batch_size)

    def do_validation():
        total_loss = 0.0
        count = 0
        for batch_id, data in enumerate(test_batch_generator()):
            adapted_batch_data = adapt_batch_data(data)
            outputs = sess.run([loss],
                               feed_dict={
                                   item[1]: adapted_batch_data[item[0]]
                                   for item in feeding_dict.items()
                               })
            total_loss += outputs[0]
            count += 1
        return total_loss / count

    config = tf.ConfigProto(
        intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        init_g = tf.global_variables_initializer()
        init_l = tf.local_variables_initializer()
        sess.run(init_l)
        sess.run(init_g)
        for pass_id in xrange(args.pass_num):
            pass_start_time = time.time()
            words_seen = 0
            for batch_id, data in enumerate(train_batch_generator()):
                adapted_batch_data = adapt_batch_data(data)
                words_seen += np.sum(adapted_batch_data['src_sequence_length'])
                words_seen += np.sum(adapted_batch_data['trg_sequence_length'])
                outputs = sess.run([updates, loss],
                                   feed_dict={
                                       item[1]: adapted_batch_data[item[0]]
                                       for item in feeding_dict.items()
                                   })
                print("pass_id=%d, batch_id=%d, train_loss: %f" %
                      (pass_id, batch_id, outputs[1]))
            pass_end_time = time.time()
            test_loss = do_validation()
            time_consumed = pass_end_time - pass_start_time
            words_per_sec = words_seen / time_consumed
            print("pass_id=%d, test_loss: %f, words/s: %f, sec/pass: %f" %
                  (pass_id, test_loss, words_per_sec, time_consumed))
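

# Inference driver: restores a saved checkpoint and decodes each batch with
# beam search, printing the source sentence next to the best (beam 0)
# hypothesis. Note the hard-coded checkpoint path below.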
def infer():
    feeding_dict, predicted_ids = seq_to_seq_net(
        embedding_dim=args.embedding_dim,
        encoder_size=args.encoder_size,
        decoder_size=args.decoder_size,
        source_dict_dim=args.dict_size,
        target_dict_dim=args.dict_size,
        is_generating=True,
        beam_size=args.beam_size,
        max_generation_length=args.max_generation_length)

    src_dict, trg_dict = paddle.dataset.wmt14.get_dict(args.dict_size)
    # note: this generator reads the WMT'14 training split
    test_batch_generator = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.wmt14.train(args.dict_size), buf_size=1000),
        batch_size=args.batch_size)

    config = tf.ConfigProto(
        intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
    with tf.Session(config=config) as sess:
        restore(sess, './checkpoint/tf_seq2seq-1500')
        for batch_id, data in enumerate(test_batch_generator()):
            src_seq = map(lambda x: x[0], data)

            source_language_seq = [
                src_dict[item] for seq in src_seq for item in seq
            ]

            src_sequence_length = np.array(
                [len(seq) for seq in src_seq]).astype('int32')
            src_seq_maxlen = np.max(src_sequence_length)
            src_seq = np.array([
                padding_data(seq, src_seq_maxlen, END_TOKEN_IDX)
                for seq in src_seq
            ]).astype('int32')

            outputs = sess.run([predicted_ids],
                               feed_dict={
                                   feeding_dict['src_word_idx']: src_seq,
                                   feeding_dict['src_sequence_length']:
                                   src_sequence_length
                               })

            print("\nDecoder result comparison: ")
            source_language_seq = ' '.join(source_language_seq).lstrip(
                '<s>').rstrip('<e>').strip()
            inference_seq = ''
            print(" --> source: " + source_language_seq)
            for item in outputs[0][0]:
                if item[0] == END_TOKEN_IDX:
                    break
                inference_seq += ' ' + trg_dict.get(item[0], '<unk>')
            print(" --> inference: " + inference_seq)


if __name__ == '__main__':
    args = parser.parse_args()
    print_arguments(args)
    if args.infer_only:
        infer()
    else:
        train()