|
|
|
@ -107,8 +107,7 @@ def train(use_cuda, train_program, parallel, params_dirname):
|
|
|
|
|
event_handler=event_handler,
|
|
|
|
|
feed_order=['pixel', 'label'])
|
|
|
|
|
|
|
|
|
|
if six.PY3:
|
|
|
|
|
del trainer
|
|
|
|
|
return trainer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def infer(use_cuda, inference_program, parallel, params_dirname=None):
|
|
|
|
@ -132,12 +131,15 @@ def main(use_cuda, parallel):
|
|
|
|
|
save_path = "image_classification_vgg.inference.model"
|
|
|
|
|
|
|
|
|
|
os.environ['CPU_NUM'] = str(4)
|
|
|
|
|
train(
|
|
|
|
|
trainer = train(
|
|
|
|
|
use_cuda=use_cuda,
|
|
|
|
|
train_program=train_network,
|
|
|
|
|
params_dirname=save_path,
|
|
|
|
|
parallel=parallel)
|
|
|
|
|
|
|
|
|
|
if six.PY3:
|
|
|
|
|
del trainer
|
|
|
|
|
|
|
|
|
|
# FIXME(zcd): in the inference stage, the number of
|
|
|
|
|
# input data is one, it is not appropriate to use parallel.
|
|
|
|
|
if parallel and use_cuda:
|
|
|
|
|