|
|
|
@ -14,6 +14,7 @@
|
|
|
|
|
import os
|
|
|
|
|
import cPickle as pickle
|
|
|
|
|
|
|
|
|
|
from paddle.v2.fluid.evaluator import Evaluator
|
|
|
|
|
from paddle.v2.fluid.framework import Program, Parameter, default_main_program, Variable
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
@ -183,6 +184,8 @@ def load_persistables(executor, dirname, main_program=None):
|
|
|
|
|
def get_inference_program(target_vars, main_program=None):
|
|
|
|
|
if main_program is None:
|
|
|
|
|
main_program = default_main_program()
|
|
|
|
|
if isinstance(target_vars, Evaluator):
|
|
|
|
|
target_vars = target_vars.states + target_vars.metrics
|
|
|
|
|
if not isinstance(target_vars, list):
|
|
|
|
|
target_vars = [target_vars]
|
|
|
|
|
|
|
|
|
|