|
|
|
@ -14,11 +14,11 @@
|
|
|
|
|
|
|
|
|
|
from __future__ import print_function
|
|
|
|
|
|
|
|
|
|
import six
|
|
|
|
|
import paddle
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
import paddle.fluid.core as core
|
|
|
|
|
import numpy
|
|
|
|
|
import six
|
|
|
|
|
import os
|
|
|
|
|
import cifar10_small_test_set
|
|
|
|
|
|
|
|
|
@ -101,23 +101,14 @@ def train(use_cuda, train_program, parallel, params_dirname):
|
|
|
|
|
optimizer_func=optimizer_func,
|
|
|
|
|
parallel=parallel)
|
|
|
|
|
|
|
|
|
|
if six.PY2:
|
|
|
|
|
trainer.train(
|
|
|
|
|
reader=train_reader,
|
|
|
|
|
num_epochs=1,
|
|
|
|
|
event_handler=event_handler,
|
|
|
|
|
feed_order=['pixel', 'label'])
|
|
|
|
|
else:
|
|
|
|
|
import paddle.fluid.core as core
|
|
|
|
|
import paddle.compat as cpt
|
|
|
|
|
try:
|
|
|
|
|
trainer.train(
|
|
|
|
|
reader=train_reader,
|
|
|
|
|
num_epochs=1,
|
|
|
|
|
event_handler=event_handler,
|
|
|
|
|
feed_order=['pixel', 'label'])
|
|
|
|
|
except core.EnforceNotMet as ex:
|
|
|
|
|
assert ("kid scope" in cpt.get_exception_message(ex))
|
|
|
|
|
trainer.train(
|
|
|
|
|
reader=train_reader,
|
|
|
|
|
num_epochs=1,
|
|
|
|
|
event_handler=event_handler,
|
|
|
|
|
feed_order=['pixel', 'label'])
|
|
|
|
|
|
|
|
|
|
if six.PY3:
|
|
|
|
|
del trainer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def infer(use_cuda, inference_program, parallel, params_dirname=None):
|
|
|
|
|