From 026e88a340f3f89deb27e1d8060b9fb53db32c90 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 24 Jan 2018 14:34:54 +0800 Subject: [PATCH 1/2] Make get_inference_program support for Evaluator. --- python/paddle/v2/fluid/io.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/python/paddle/v2/fluid/io.py b/python/paddle/v2/fluid/io.py index 499df05e59..d440fab9a5 100644 --- a/python/paddle/v2/fluid/io.py +++ b/python/paddle/v2/fluid/io.py @@ -14,6 +14,7 @@ import os import cPickle as pickle +from paddle.v2.fluid.evaluator import Evaluator from paddle.v2.fluid.framework import Program, Parameter, default_main_program, Variable __all__ = [ @@ -183,6 +184,8 @@ def load_persistables(executor, dirname, main_program=None): def get_inference_program(target_vars, main_program=None): if main_program is None: main_program = default_main_program() + if isinstance(target_vars, Evaluator): + target_vars = target_vars.states + target_vars.metrics if not isinstance(target_vars, list): target_vars = [target_vars] From a6a79c35c9f959ce24e43033a7108ff2fa7cea06 Mon Sep 17 00:00:00 2001 From: wanghaoshuang Date: Wed, 24 Jan 2018 15:47:11 +0800 Subject: [PATCH 2/2] More general implementation. 
def get_inference_program(target_vars, main_program=None):
    """Build an inference Program by pruning ``main_program`` to the targets.

    Args:
        target_vars: a single target or a list of targets; each target is
            either a Variable or an Evaluator. For an Evaluator, its
            ``states`` and ``metrics`` variable lists are used as the
            pruning targets.
        main_program: the Program to prune. Defaults to the default main
            program when None.

    Returns:
        A new Program containing only the operators needed to compute the
        targets, further optimized for inference.
    """
    if main_program is None:
        main_program = default_main_program()
    if not isinstance(target_vars, list):
        target_vars = [target_vars]
    vars = []
    for var in target_vars:
        if isinstance(var, Evaluator):
            # states and metrics are each a list of Variables: flatten them
            # with extend(). append() would nest the lists inside `vars`
            # and hand malformed targets to prune().
            vars.extend(var.states)
            vars.extend(var.metrics)
        else:
            vars.append(var)
    pruned_program = main_program.prune(targets=vars)
    inference_program = pruned_program.inference_optimize()
    return inference_program