diff --git a/python/paddle/v2/fluid/io.py b/python/paddle/v2/fluid/io.py index e63643c29e..5b02d2495d 100644 --- a/python/paddle/v2/fluid/io.py +++ b/python/paddle/v2/fluid/io.py @@ -186,12 +186,16 @@ def load_persistables(executor, dirname, main_program=None): def get_inference_program(target_vars, main_program=None): if main_program is None: main_program = default_main_program() - if isinstance(target_vars, Evaluator): - target_vars = target_vars.states + target_vars.metrics if not isinstance(target_vars, list): target_vars = [target_vars] - - pruned_program = main_program.prune(targets=target_vars) + vars = [] + for var in target_vars: + if isinstance(var, Evaluator): + vars.append(var.states) + vars.append(var.metrics) + else: + vars.append(var) + pruned_program = main_program.prune(targets=vars) inference_program = pruned_program.inference_optimize() return inference_program