You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
221 lines
7.8 KiB
221 lines
7.8 KiB
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
import numpy as np
|
|
import argparse
|
|
import time
|
|
import tensorflow as tf
|
|
|
|
import paddle.v2 as paddle
|
|
|
|
|
|
def parse_args():
    """Build the benchmark's command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes batch_size, stacked_num,
        embedding_dim, hidden_dim, pass_num, learning_rate and infer_only.
    """
    # (flag, default, help) for every integer-valued option, in
    # registration order.
    int_options = (
        ('--batch_size', 32,
         'The sequence number of a batch data. (default: %(default)d)'),
        ('--stacked_num', 5,
         'Number of lstm layers to stack. (default: %(default)d)'),
        ('--embedding_dim', 512,
         'Dimension of embedding table. (default: %(default)d)'),
        ('--hidden_dim', 512,
         'Hidden size of lstm unit. (default: %(default)d)'),
        ('--pass_num', 10,
         'Epoch number to train. (default: %(default)d)'),
    )

    parser = argparse.ArgumentParser("LSTM model benchmark.")
    for flag, default_value, help_text in int_options:
        parser.add_argument(
            flag, type=int, default=default_value, help=help_text)
    parser.add_argument(
        '--learning_rate',
        type=float,
        default=0.0002,
        help='Learning rate used to train. (default: %(default)f)')
    parser.add_argument(
        '--infer_only', action='store_true', help='If set, run forward only.')
    return parser.parse_args()
|
|
|
|
|
|
def print_arguments(args):
    """Print every parsed command-line option, one per line, sorted by name.

    Args:
        args: an ``argparse.Namespace`` (anything ``vars()`` accepts).
    """
    print('----------- Configuration Arguments -----------')
    # dict.items() works on Python 2 and 3; iteritems() is Python 2 only.
    for arg, value in sorted(vars(args).items()):
        print('%s: %s' % (arg, value))
    print('------------------------------------------------')
|
|
|
|
|
|
def dynamic_lstm_model(dict_size,
                       embedding_dim,
                       hidden_dim,
                       stacked_num,
                       class_num=2,
                       is_train=True):
    """Build a stacked-LSTM text classifier graph in TensorFlow.

    Args:
        dict_size: vocabulary size (number of rows in the embedding table).
        embedding_dim: width of each word embedding.
        hidden_dim: number of units in every LSTM layer.
        stacked_num: number of LSTM layers to stack.
        class_num: number of output classes.
        is_train: when False, only the inference graph is built.

    Returns:
        If ``is_train`` is False: ``((word_idx, sequence_length), probs)``
        where ``probs`` is the softmax over ``class_num`` logits.
        Otherwise: ``((word_idx, sequence_length, label), avg_loss, acc,
        g_acc, reset_op)``. Here ``acc`` is the per-batch accuracy,
        ``g_acc`` is the ``(value, update_op)`` pair returned by
        ``tf.metrics.accuracy``, and ``reset_op`` re-initializes that
        metric's local variables.
    """
    # Padded token ids (batch rows x max length) plus each row's true length.
    word_idx = tf.placeholder(tf.int64, shape=[None, None])
    sequence_length = tf.placeholder(tf.int64, shape=[None, ])

    embedding_weights = tf.get_variable('word_embeddings',
                                        [dict_size, embedding_dim])
    embedding = tf.nn.embedding_lookup(embedding_weights, word_idx)

    # Give every layer its own cell object. Reusing one instance
    # (``[cell] * n``) ties all layers to a single cell; in TF >= 1.1 a cell
    # owns its weights, so that either shares weights across layers or
    # errors when the first layer's input width (embedding_dim) differs
    # from hidden_dim.
    stacked_cell = tf.nn.rnn_cell.MultiRNNCell([
        tf.nn.rnn_cell.LSTMCell(num_units=hidden_dim, use_peepholes=False)
        for _ in range(stacked_num)
    ])

    # final_state [LSTMTuple(c, h), LSTMTuple(c, h) ...] total stacked_num
    # LSTMTuples; the top layer's h (final_state[-1][1]) feeds the classifier.
    _, final_state = tf.nn.dynamic_rnn(
        cell=stacked_cell,
        inputs=embedding,
        dtype=tf.float32,
        sequence_length=sequence_length)

    w = tf.Variable(
        tf.truncated_normal([hidden_dim, class_num]), dtype=tf.float32)
    bias = tf.Variable(
        tf.constant(
            value=0.0, shape=[class_num], dtype=tf.float32))
    prediction = tf.matmul(final_state[-1][1], w) + bias

    if not is_train:
        return (word_idx, sequence_length), tf.nn.softmax(prediction)

    label = tf.placeholder(tf.int64, shape=[None, ])
    # One-hot depth must match the logits width (was hard-coded to 2).
    loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=tf.one_hot(label, class_num), logits=prediction)
    avg_loss = tf.reduce_mean(loss)

    correct_count = tf.equal(tf.argmax(prediction, 1), label)
    acc = tf.reduce_mean(tf.cast(correct_count, tf.float32))

    # Streaming accuracy across batches. Its counters are local variables
    # created under this scope, collected here so reset_op can zero them.
    with tf.variable_scope("reset_metrics_accuracy_scope") as scope:
        g_acc = tf.metrics.accuracy(label, tf.argmax(prediction, axis=1))
        metric_vars = tf.contrib.framework.get_variables(
            scope, collection=tf.GraphKeys.LOCAL_VARIABLES)
        reset_op = tf.variables_initializer(metric_vars)

    return (word_idx, sequence_length, label), avg_loss, acc, g_acc, reset_op
|
|
|
|
|
|
def padding_data(data, padding_size, value):
    """Return ``data`` right-padded with ``value`` to ``padding_size`` items.

    Longer inputs are truncated to ``padding_size``. ``data`` is concatenated
    with a list, so it must itself be a list.
    """
    # Keep the list on the left: padding_size may be a numpy integer, and
    # ``np.int64 * list`` would broadcast instead of repeating the list.
    filler = [value] * padding_size
    return (data + filler)[:padding_size]
|
|
|
|
|
|
def _prepare_batch(data):
    """Convert one reader mini-batch into padded int64 numpy feed arrays.

    Args:
        data: list of samples where ``sample[0]`` is a list of word ids and
            ``sample[1]`` is the label (as produced by the paddle readers
            used in ``train``).

    Returns:
        ``(word_idx, sequence_length, label)``: a matrix right-padded with 0
        to the longest sequence in the batch, the vector of original
        lengths, and the label vector.
    """
    sequences = [sample[0] for sample in data]
    sequence_length = np.array(
        [len(seq) for seq in sequences]).astype('int64')
    maxlen = np.max(sequence_length)
    word_idx = np.array(
        [padding_data(seq, maxlen, 0) for seq in sequences]).astype('int64')
    label = np.array([sample[1] for sample in data]).astype('int64')
    return word_idx, sequence_length, label


def train(args):
    """Train the stacked-LSTM classifier on IMDB and report throughput.

    Each pass trains one epoch over the shuffled training set, printing
    per-batch loss/accuracy, then measures streaming accuracy on the test
    set and prints words/s and seconds per pass (validation time is not
    included in the timing).

    Args:
        args: namespace from ``parse_args``; reads embedding_dim, hidden_dim,
            stacked_num, learning_rate, batch_size and pass_num.
    """
    word_dict = paddle.dataset.imdb.word_dict()
    dict_size = len(word_dict)

    feeding_list, avg_loss, acc, g_acc, reset_op = dynamic_lstm_model(
        dict_size, args.embedding_dim, args.hidden_dim, args.stacked_num)

    adam_optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
    train_op = adam_optimizer.minimize(avg_loss)

    train_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.imdb.train(word_dict), buf_size=25000),
        batch_size=args.batch_size)

    test_reader = paddle.batch(
        paddle.reader.shuffle(
            paddle.dataset.imdb.test(word_dict), buf_size=25000),
        batch_size=args.batch_size)

    def make_feed(word_idx, sequence_length, label):
        """Map the model's placeholders to one batch of numpy arrays."""
        return {
            feeding_list[0]: word_idx,
            feeding_list[1]: sequence_length,
            feeding_list[2]: label
        }

    def do_validation(sess):
        """Return the streaming accuracy of the current model on the test set."""
        sess.run(reset_op)
        for data in test_reader():
            # Only update the accuracy metric. Running train_op here (as the
            # previous version did) would fit the model to the test data.
            sess.run(g_acc[1], feed_dict=make_feed(*_prepare_batch(data)))
        # Read the metric's value tensor once all batches are accumulated;
        # unlike returning a loop variable, this cannot hit an unbound name
        # when the test reader yields no batches.
        return sess.run(g_acc[0])

    config = tf.ConfigProto(
        intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        init_g = tf.global_variables_initializer()
        init_l = tf.local_variables_initializer()
        sess.run(init_l)
        sess.run(init_g)

        # range (not xrange) keeps this runnable on Python 3 as well.
        for pass_id in range(args.pass_num):
            # clear accuracy local variable
            sess.run(reset_op)
            pass_start_time = time.time()
            words_seen = 0

            for batch_id, data in enumerate(train_reader()):
                word_idx, sequence_length, label = _prepare_batch(data)
                words_seen += np.sum(sequence_length)

                _, loss, fetch_acc, fetch_g_acc = sess.run(
                    [train_op, avg_loss, acc, g_acc],
                    feed_dict=make_feed(word_idx, sequence_length, label))

                print("pass_id=%d, batch_id=%d, loss: %f, acc: %f, avg_acc: %f"
                      % (pass_id, batch_id, loss, fetch_acc, fetch_g_acc[1]))

            pass_end_time = time.time()
            time_consumed = pass_end_time - pass_start_time
            words_per_sec = words_seen / time_consumed
            test_acc = do_validation(sess)
            print("pass_id=%d, test_acc: %f, words/s: %f, sec/pass: %f" %
                  (pass_id, test_acc, words_per_sec, time_consumed))
|
|
|
|
|
|
if __name__ == '__main__':
    args = parse_args()
    print_arguments(args)

    if args.infer_only:
        # No forward-only benchmark exists in this script yet. Report that
        # instead of exiting silently as if an inference run had happened.
        print('--infer_only is not implemented for this benchmark yet; '
              'nothing was run.')
    else:
        train(args)
|