|
|
|
@ -17,6 +17,7 @@ from __future__ import print_function
|
|
|
|
|
import argparse
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
import paddle
|
|
|
|
|
import six
|
|
|
|
|
import sys
|
|
|
|
|
import numpy
|
|
|
|
|
import unittest
|
|
|
|
@ -79,11 +80,22 @@ def train(use_cuda, train_program, params_dirname):
|
|
|
|
|
paddle.dataset.mnist.train(), buf_size=500),
|
|
|
|
|
batch_size=BATCH_SIZE)
|
|
|
|
|
|
|
|
|
|
if six.PY2:
|
|
|
|
|
trainer.train(
|
|
|
|
|
num_epochs=1,
|
|
|
|
|
event_handler=event_handler,
|
|
|
|
|
reader=train_reader,
|
|
|
|
|
feed_order=['img', 'label'])
|
|
|
|
|
else:
|
|
|
|
|
import paddle.fluid.core as core
|
|
|
|
|
try:
|
|
|
|
|
trainer.train(
|
|
|
|
|
num_epochs=1,
|
|
|
|
|
event_handler=event_handler,
|
|
|
|
|
reader=train_reader,
|
|
|
|
|
feed_order=['img', 'label'])
|
|
|
|
|
except core.EnforceNotMet as ex:
|
|
|
|
|
assert ("kid scope" in cpt.get_exception_message(ex))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def infer(use_cuda, inference_program, params_dirname=None):
|
|
|
|
|