|
|
|
@ -16,6 +16,7 @@ import argparse
|
|
|
|
|
import paddle.v2.fluid as fluid
|
|
|
|
|
import paddle.v2 as paddle
|
|
|
|
|
import sys
|
|
|
|
|
import numpy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_arg():
|
|
|
|
@ -100,6 +101,8 @@ def main():
|
|
|
|
|
else:
|
|
|
|
|
avg_loss, acc = net_conf(img, label)
|
|
|
|
|
|
|
|
|
|
test_program = fluid.default_main_program().clone()
|
|
|
|
|
|
|
|
|
|
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
|
|
|
|
|
optimizer.minimize(avg_loss)
|
|
|
|
|
|
|
|
|
@ -112,6 +115,8 @@ def main():
|
|
|
|
|
paddle.reader.shuffle(
|
|
|
|
|
paddle.dataset.mnist.train(), buf_size=500),
|
|
|
|
|
batch_size=BATCH_SIZE)
|
|
|
|
|
test_reader = paddle.batch(
|
|
|
|
|
paddle.dataset.mnist.test(), batch_size=BATCH_SIZE)
|
|
|
|
|
feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
|
|
|
|
|
|
|
|
|
|
PASS_NUM = 100
|
|
|
|
@ -119,21 +124,27 @@ def main():
|
|
|
|
|
for batch_id, data in enumerate(train_reader()):
|
|
|
|
|
need_check = (batch_id + 1) % 10 == 0
|
|
|
|
|
|
|
|
|
|
# train a mini-batch, fetch nothing
|
|
|
|
|
exe.run(feed=feeder.feed(data))
|
|
|
|
|
if need_check:
|
|
|
|
|
fetch_list = [avg_loss, acc]
|
|
|
|
|
else:
|
|
|
|
|
fetch_list = []
|
|
|
|
|
|
|
|
|
|
outs = exe.run(feed=feeder.feed(data), fetch_list=fetch_list)
|
|
|
|
|
if need_check:
|
|
|
|
|
avg_loss_np, acc_np = outs
|
|
|
|
|
if float(acc_np) > 0.9:
|
|
|
|
|
acc_set = []
|
|
|
|
|
avg_loss_set = []
|
|
|
|
|
for test_data in test_reader():
|
|
|
|
|
acc_np, avg_loss_np = exe.run(program=test_program,
|
|
|
|
|
feed=feeder.feed(test_data),
|
|
|
|
|
fetch_list=[acc, avg_loss])
|
|
|
|
|
acc_set.append(float(acc_np))
|
|
|
|
|
avg_loss_set.append(float(avg_loss_np))
|
|
|
|
|
# get test acc and loss
|
|
|
|
|
acc_val = numpy.array(acc_set).mean()
|
|
|
|
|
avg_loss_val = numpy.array(avg_loss_set).mean()
|
|
|
|
|
if float(acc_val) > 0.85: # test acc > 85%
|
|
|
|
|
exit(0)
|
|
|
|
|
else:
|
|
|
|
|
print(
|
|
|
|
|
'PassID {0:1}, BatchID {1:04}, Loss {2:2.2}, Acc {3:2.2}'.
|
|
|
|
|
'PassID {0:1}, BatchID {1:04}, Test Loss {2:2.2}, Acc {3:2.2}'.
|
|
|
|
|
format(pass_id, batch_id + 1,
|
|
|
|
|
float(avg_loss_np), float(acc_np)))
|
|
|
|
|
float(avg_loss_val), float(acc_val)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|