|
|
@ -1,8 +1,7 @@
|
|
|
|
import os
|
|
|
|
import os
|
|
|
|
import cPickle as pickle
|
|
|
|
import cPickle as pickle
|
|
|
|
|
|
|
|
|
|
|
|
from paddle.v2.fluid.framework import Program, Parameter, g_main_program, \
|
|
|
|
from paddle.v2.fluid.framework import Program, Parameter, default_main_program, Variable
|
|
|
|
Variable
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
__all__ = [
|
|
|
|
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
|
|
|
|
'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
|
|
|
@ -46,7 +45,7 @@ def save_vars(executor, dirname, main_program=None, vars=None, predicate=None):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
if vars is None:
|
|
|
|
if vars is None:
|
|
|
|
if main_program is None:
|
|
|
|
if main_program is None:
|
|
|
|
main_program = g_main_program
|
|
|
|
main_program = default_main_program()
|
|
|
|
if not isinstance(main_program, Program):
|
|
|
|
if not isinstance(main_program, Program):
|
|
|
|
raise TypeError("program should be as Program type or None")
|
|
|
|
raise TypeError("program should be as Program type or None")
|
|
|
|
|
|
|
|
|
|
|
@ -98,7 +97,7 @@ def load_vars(executor, dirname, main_program=None, vars=None, predicate=None):
|
|
|
|
:param executor: executor that save variable
|
|
|
|
:param executor: executor that save variable
|
|
|
|
:param dirname: directory path
|
|
|
|
:param dirname: directory path
|
|
|
|
:param main_program: program. If vars is None, then filter all variables in this
|
|
|
|
:param main_program: program. If vars is None, then filter all variables in this
|
|
|
|
program which fit `predicate`. Default g_program.
|
|
|
|
program which fit `predicate`. Default default_main_program().
|
|
|
|
:param predicate: The Predicate describes a callable that returns a variable
|
|
|
|
:param predicate: The Predicate describes a callable that returns a variable
|
|
|
|
as a bool. If it returns true, the variables will be loaded.
|
|
|
|
as a bool. If it returns true, the variables will be loaded.
|
|
|
|
:param vars: variables need to be loaded. If specify vars, program &
|
|
|
|
:param vars: variables need to be loaded. If specify vars, program &
|
|
|
@ -107,7 +106,7 @@ def load_vars(executor, dirname, main_program=None, vars=None, predicate=None):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
if vars is None:
|
|
|
|
if vars is None:
|
|
|
|
if main_program is None:
|
|
|
|
if main_program is None:
|
|
|
|
main_program = g_main_program
|
|
|
|
main_program = default_main_program()
|
|
|
|
if not isinstance(main_program, Program):
|
|
|
|
if not isinstance(main_program, Program):
|
|
|
|
raise TypeError("program's type should be Program")
|
|
|
|
raise TypeError("program's type should be Program")
|
|
|
|
|
|
|
|
|
|
|
@ -154,7 +153,7 @@ def load_persistables(executor, dirname, main_program=None):
|
|
|
|
|
|
|
|
|
|
|
|
def get_inference_program(target_vars, main_program=None):
|
|
|
|
def get_inference_program(target_vars, main_program=None):
|
|
|
|
if main_program is None:
|
|
|
|
if main_program is None:
|
|
|
|
main_program = g_main_program
|
|
|
|
main_program = default_main_program()
|
|
|
|
if not isinstance(target_vars, list):
|
|
|
|
if not isinstance(target_vars, list):
|
|
|
|
target_vars = [target_vars]
|
|
|
|
target_vars = [target_vars]
|
|
|
|
|
|
|
|
|
|
|
@ -177,12 +176,12 @@ def save_inference_model(dirname,
|
|
|
|
:param target_vars: Variables from which we can get inference results.
|
|
|
|
:param target_vars: Variables from which we can get inference results.
|
|
|
|
:param executor: executor that save inference model
|
|
|
|
:param executor: executor that save inference model
|
|
|
|
:param main_program: original program, which will be pruned to build the inference model.
|
|
|
|
:param main_program: original program, which will be pruned to build the inference model.
|
|
|
|
Default g_main_program.
|
|
|
|
Default default_main_program().
|
|
|
|
|
|
|
|
|
|
|
|
:return: None
|
|
|
|
:return: None
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
if main_program is None:
|
|
|
|
if main_program is None:
|
|
|
|
main_program = g_main_program
|
|
|
|
main_program = default_main_program()
|
|
|
|
if not isinstance(target_vars, list):
|
|
|
|
if not isinstance(target_vars, list):
|
|
|
|
target_vars = [target_vars]
|
|
|
|
target_vars = [target_vars]
|
|
|
|
|
|
|
|
|
|
|
@ -272,10 +271,10 @@ def get_parameter_value_by_name(name, executor, program=None):
|
|
|
|
:param executor: executor for retrieving the value
|
|
|
|
:param executor: executor for retrieving the value
|
|
|
|
:param name: the name of the parameter
|
|
|
|
:param name: the name of the parameter
|
|
|
|
:param program: the program where the variable is found
|
|
|
|
:param program: the program where the variable is found
|
|
|
|
Default g_main_program.
|
|
|
|
Default default_main_program().
|
|
|
|
:return: the LoDTensor for the variable
|
|
|
|
:return: the LoDTensor for the variable
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
if program is None:
|
|
|
|
if program is None:
|
|
|
|
program = g_main_program
|
|
|
|
program = default_main_program()
|
|
|
|
var = program.global_block().var(name)
|
|
|
|
var = program.global_block().var(name)
|
|
|
|
return get_parameter_value(var, executor)
|
|
|
|
return get_parameter_value(var, executor)
|
|
|
|