refine test_understand_sentiment_lstm (#5781)

* fix

* Fix a bug
release/0.11.0
fengjiayi 8 years ago committed by GitHub
parent 3e9ea34821
commit f04c97a035
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -54,17 +54,17 @@ def to_lodtensor(data, place):
return res return res
def chop_data(data, chop_len=80, batch_len=50): def chop_data(data, chop_len=80, batch_size=50):
data = [(x[0][:chop_len], x[1]) for x in data if len(x[0]) >= chop_len] data = [(x[0][:chop_len], x[1]) for x in data if len(x[0]) >= chop_len]
return data[:batch_len] return data[:batch_size]
def prepare_feed_data(data, place): def prepare_feed_data(data, place):
tensor_words = to_lodtensor(map(lambda x: x[0], data), place) tensor_words = to_lodtensor(map(lambda x: x[0], data), place)
label = np.array(map(lambda x: x[1], data)).astype("int64") label = np.array(map(lambda x: x[1], data)).astype("int64")
label = label.reshape([50, 1]) label = label.reshape([len(label), 1])
tensor_label = core.LoDTensor() tensor_label = core.LoDTensor()
tensor_label.set(label, place) tensor_label.set(label, place)
@ -72,23 +72,30 @@ def prepare_feed_data(data, place):
def main(): def main():
word_dict = paddle.dataset.imdb.word_dict() BATCH_SIZE = 100
cost, acc = lstm_net(dict_dim=len(word_dict), class_dim=2) PASS_NUM = 5
batch_size = 100 word_dict = paddle.dataset.imdb.word_dict()
train_data = paddle.batch( print "load word dict successfully"
paddle.reader.buffered( dict_dim = len(word_dict)
paddle.dataset.imdb.train(word_dict), size=batch_size * 10), class_dim = 2
batch_size=batch_size)
data = chop_data(next(train_data())) cost, acc = lstm_net(dict_dim=dict_dim, class_dim=class_dim)
train_data = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.imdb.train(word_dict), buf_size=BATCH_SIZE * 10),
batch_size=BATCH_SIZE)
place = core.CPUPlace() place = core.CPUPlace()
tensor_words, tensor_label = prepare_feed_data(data, place)
exe = Executor(place) exe = Executor(place)
exe.run(framework.default_startup_program()) exe.run(framework.default_startup_program())
while True: for pass_id in xrange(PASS_NUM):
for data in train_data():
chopped_data = chop_data(data)
tensor_words, tensor_label = prepare_feed_data(chopped_data, place)
outs = exe.run(framework.default_main_program(), outs = exe.run(framework.default_main_program(),
feed={"words": tensor_words, feed={"words": tensor_words,
"label": tensor_label}, "label": tensor_label},
@ -97,8 +104,9 @@ def main():
acc_val = np.array(outs[1]) acc_val = np.array(outs[1])
print("cost=" + str(cost_val) + " acc=" + str(acc_val)) print("cost=" + str(cost_val) + " acc=" + str(acc_val))
if acc_val > 0.9: if acc_val > 0.7:
break exit(0)
exit(1)
if __name__ == '__main__': if __name__ == '__main__':

Loading…
Cancel
Save