Merge pull request #7816 from wanghaoshuang/infer_prog

Make get_inference_program support for Evaluator.
emailweixu-patch-1
whs 7 years ago committed by GitHub
commit ef8cb8f624
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -15,6 +15,7 @@
import os
import cPickle as pickle
from paddle.v2.fluid.evaluator import Evaluator
from paddle.v2.fluid.framework import Program, Parameter, default_main_program, Variable
from . import core
@ -187,8 +188,14 @@ def get_inference_program(target_vars, main_program=None):
main_program = default_main_program()
if not isinstance(target_vars, list):
target_vars = [target_vars]
pruned_program = main_program.prune(targets=target_vars)
vars = []
for var in target_vars:
if isinstance(var, Evaluator):
vars.append(var.states)
vars.append(var.metrics)
else:
vars.append(var)
pruned_program = main_program.prune(targets=vars)
inference_program = pruned_program.inference_optimize()
return inference_program

Loading…
Cancel
Save