|
|
|
@ -17,6 +17,7 @@ import paddle.v2.fluid as fluid
|
|
|
|
|
import paddle.v2 as paddle
|
|
|
|
|
import sys
|
|
|
|
|
import numpy
|
|
|
|
|
import unittest
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_arg():
|
|
|
|
@ -74,18 +75,18 @@ def conv_net(img, label):
|
|
|
|
|
return loss_net(conv_pool_2, label)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train(args, save_dirname=None):
|
|
|
|
|
print("recognize digits with args: {0}".format(" ".join(sys.argv[1:])))
|
|
|
|
|
|
|
|
|
|
def train(nn_type, use_cuda, parallel, save_dirname):
|
|
|
|
|
if use_cuda and not fluid.core.is_compiled_with_cuda():
|
|
|
|
|
return
|
|
|
|
|
img = fluid.layers.data(name='img', shape=[1, 28, 28], dtype='float32')
|
|
|
|
|
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
|
|
|
|
|
|
|
|
|
|
if args.nn_type == 'mlp':
|
|
|
|
|
if nn_type == 'mlp':
|
|
|
|
|
net_conf = mlp
|
|
|
|
|
else:
|
|
|
|
|
net_conf = conv_net
|
|
|
|
|
|
|
|
|
|
if args.parallel:
|
|
|
|
|
if parallel:
|
|
|
|
|
places = fluid.layers.get_places()
|
|
|
|
|
pd = fluid.layers.ParallelDo(places)
|
|
|
|
|
with pd.do():
|
|
|
|
@ -107,7 +108,7 @@ def train(args, save_dirname=None):
|
|
|
|
|
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
|
|
|
|
|
optimizer.minimize(avg_loss)
|
|
|
|
|
|
|
|
|
|
place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
|
|
|
|
|
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
|
|
|
|
|
|
|
|
|
|
exe = fluid.Executor(place)
|
|
|
|
|
exe.run(fluid.default_startup_program())
|
|
|
|
@ -147,13 +148,14 @@ def train(args, save_dirname=None):
|
|
|
|
|
'PassID {0:1}, BatchID {1:04}, Test Loss {2:2.2}, Acc {3:2.2}'.
|
|
|
|
|
format(pass_id, batch_id + 1,
|
|
|
|
|
float(avg_loss_val), float(acc_val)))
|
|
|
|
|
raise AssertionError("Loss of recognize digits is too large")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def infer(args, save_dirname=None):
|
|
|
|
|
def infer(use_cuda, save_dirname=None):
|
|
|
|
|
if save_dirname is None:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
|
|
|
|
|
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
|
|
|
|
|
exe = fluid.Executor(place)
|
|
|
|
|
|
|
|
|
|
# Use fluid.io.load_inference_model to obtain the inference program desc,
|
|
|
|
@ -174,11 +176,48 @@ def infer(args, save_dirname=None):
|
|
|
|
|
print("infer results: ", results[0])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
args = parse_arg()
|
|
|
|
|
if not args.use_cuda and not args.parallel:
|
|
|
|
|
save_dirname = "recognize_digits_" + args.nn_type + ".inference.model"
|
|
|
|
|
def main(use_cuda, parallel, nn_type):
|
|
|
|
|
if not use_cuda and not parallel:
|
|
|
|
|
save_dirname = "recognize_digits_" + nn_type + ".inference.model"
|
|
|
|
|
else:
|
|
|
|
|
save_dirname = None
|
|
|
|
|
train(args, save_dirname)
|
|
|
|
|
infer(args, save_dirname)
|
|
|
|
|
|
|
|
|
|
train(
|
|
|
|
|
nn_type=nn_type,
|
|
|
|
|
use_cuda=use_cuda,
|
|
|
|
|
parallel=parallel,
|
|
|
|
|
save_dirname=save_dirname)
|
|
|
|
|
infer(use_cuda=use_cuda, save_dirname=save_dirname)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestRecognizeDigits(unittest.TestCase):
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def inject_test_method(use_cuda, parallel, nn_type):
|
|
|
|
|
def __impl__(self):
|
|
|
|
|
prog = fluid.Program()
|
|
|
|
|
startup_prog = fluid.Program()
|
|
|
|
|
scope = fluid.core.Scope()
|
|
|
|
|
with fluid.scope_guard(scope):
|
|
|
|
|
with fluid.program_guard(prog, startup_prog):
|
|
|
|
|
main(use_cuda, parallel, nn_type)
|
|
|
|
|
|
|
|
|
|
fn = 'test_{0}_{1}_{2}'.format(nn_type, 'cuda'
|
|
|
|
|
if use_cuda else 'cpu', 'parallel'
|
|
|
|
|
if parallel else 'normal')
|
|
|
|
|
|
|
|
|
|
setattr(TestRecognizeDigits, fn, __impl__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def inject_all_tests():
|
|
|
|
|
for use_cuda in (False, True):
|
|
|
|
|
for parallel in (False, True):
|
|
|
|
|
for nn_type in ('mlp', 'conv'):
|
|
|
|
|
inject_test_method(use_cuda, parallel, nn_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inject_all_tests()
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
|
unittest.main()
|
|
|
|
|