@ -1,8 +1,7 @@
import os
import cPickle as pickle
from paddle . v2 . fluid . framework import Program , Parameter , g_main_program , \
Variable
from paddle . v2 . fluid . framework import Program , Parameter , default_main_program , Variable
__all__ = [
' save_vars ' , ' save_params ' , ' save_persistables ' , ' load_vars ' , ' load_params ' ,
@ -46,7 +45,7 @@ def save_vars(executor, dirname, main_program=None, vars=None, predicate=None):
"""
if vars is None :
if main_program is None :
main_program = g_main_program
main_program = default_main_program( )
if not isinstance ( main_program , Program ) :
raise TypeError ( " program should be as Program type or None " )
@ -98,7 +97,7 @@ def load_vars(executor, dirname, main_program=None, vars=None, predicate=None):
: param executor : executor that save variable
: param dirname : directory path
: param main_program : program . If vars is None , then filter all variables in this
program which fit ` predicate ` . Default g_program .
program which fit ` predicate ` . Default default_main_program( ) .
: param predicate : The Predicate describes a callable that returns a variable
as a bool . If it returns true , the variables will be loaded .
: param vars : variables need to be loaded . If specify vars , program &
@ -107,7 +106,7 @@ def load_vars(executor, dirname, main_program=None, vars=None, predicate=None):
"""
if vars is None :
if main_program is None :
main_program = g_main_program
main_program = default_main_program( )
if not isinstance ( main_program , Program ) :
raise TypeError ( " program ' s type should be Program " )
@ -154,7 +153,7 @@ def load_persistables(executor, dirname, main_program=None):
def get_inference_program ( target_vars , main_program = None ) :
if main_program is None :
main_program = g_main_program
main_program = default_main_program( )
if not isinstance ( target_vars , list ) :
target_vars = [ target_vars ]
@ -177,12 +176,12 @@ def save_inference_model(dirname,
: param target_vars : Variables from which we can get inference results .
: param executor : executor that save inference model
: param main_program : original program , which will be pruned to build the inference model .
Default g_main_program .
Default default_main_program ( ) .
: return : None
"""
if main_program is None :
main_program = g_main_program
main_program = default_main_program( )
if not isinstance ( target_vars , list ) :
target_vars = [ target_vars ]
@ -272,10 +271,10 @@ def get_parameter_value_by_name(name, executor, program=None):
: param executor : executor for retrieving the value
: param name : the name of the parameter
: param program : the program where the variable is found
Default g_main_program .
Default default_main_program ( ) .
: return : the LoDTensor for the variable
"""
if program is None :
program = g_main_program
program = default_main_program( )
var = program . global_block ( ) . var ( name )
return get_parameter_value ( var , executor )