|
|
|
@ -15,6 +15,7 @@
|
|
|
|
|
import os
|
|
|
|
|
import cPickle as pickle
|
|
|
|
|
|
|
|
|
|
from paddle.v2.fluid.evaluator import Evaluator
|
|
|
|
|
from paddle.v2.fluid.framework import Program, Parameter, default_main_program, Variable
|
|
|
|
|
from . import core
|
|
|
|
|
|
|
|
|
@ -187,8 +188,14 @@ def get_inference_program(target_vars, main_program=None):
|
|
|
|
|
main_program = default_main_program()
|
|
|
|
|
if not isinstance(target_vars, list):
|
|
|
|
|
target_vars = [target_vars]
|
|
|
|
|
|
|
|
|
|
pruned_program = main_program.prune(targets=target_vars)
|
|
|
|
|
vars = []
|
|
|
|
|
for var in target_vars:
|
|
|
|
|
if isinstance(var, Evaluator):
|
|
|
|
|
vars.append(var.states)
|
|
|
|
|
vars.append(var.metrics)
|
|
|
|
|
else:
|
|
|
|
|
vars.append(var)
|
|
|
|
|
pruned_program = main_program.prune(targets=vars)
|
|
|
|
|
inference_program = pruned_program.inference_optimize()
|
|
|
|
|
return inference_program
|
|
|
|
|
|
|
|
|
|