Enhance the load interface for inference (#21919)

* enhance load interface; test=develop

* add unit test and comments; test=develop

* fix coverage; test=develop

* use path.join to replace "/"; test=develop

* debug on Windows; test=develop

* fix Windows unit test error; test=develop

* fix comment error; test=develop

* add model suffix check; test=develop

* fix example error; test=develop
release/1.7 · hong committed via GitHub · parent 4c2df8e4d4 · commit 9f7d90d203

@@ -55,6 +55,8 @@ __all__ = [
     'load',
     'load_program_state',
     'set_program_state',
+    'get_program_parameter',
+    'get_program_persistable_vars',
 ] + reader.__all__ + paddle.reader.__all__
 
 _logger = get_logger(
@@ -114,6 +116,50 @@ def is_belong_to_optimizer(var):
     return False
 
 
+def get_program_parameter(program):
+    """
+    Get all the parameters from Program.
+
+    Args:
+        program(Program): The Program to get parameters from
+
+    Returns:
+        list: The list contains all parameters in the program
+
+    Examples:
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+
+            data = fluid.data(name="img", shape=[64, 784])
+            w = fluid.layers.create_parameter(shape=[784, 200], dtype='float32', name='fc_w')
+            b = fluid.layers.create_parameter(shape=[200], dtype='float32', name='fc_b')
+            list_para = fluid.io.get_program_parameter(fluid.default_main_program())
+    """
+    return list(filter(is_parameter, program.list_vars()))
+
+
+def get_program_persistable_vars(program):
+    """
+    Get all the persistable vars from Program.
+
+    Args:
+        program(Program): The Program to get persistable vars from
+
+    Returns:
+        list: The list contains all persistable vars in the program
+
+    Examples:
+        .. code-block:: python
+
+            import paddle.fluid as fluid
+
+            data = fluid.data(name="img", shape=[64, 784])
+            w = fluid.layers.create_parameter(shape=[784, 200], dtype='float32', name='fc_w')
+            b = fluid.layers.create_parameter(shape=[200], dtype='float32', name='fc_b')
+            list_para = fluid.io.get_program_persistable_vars(fluid.default_main_program())
+    """
+    return list(filter(is_persistable, program.list_vars()))
+
+
 def _clone_var_in_block_(block, var):
     assert isinstance(var, Variable)
     if var.desc.type() == core.VarDesc.VarType.LOD_TENSOR:
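A short sketch of how the two new helpers can be paired with save_vars; the toy network, directory, and file name are illustrative assumptions, not part of the patch:

    import paddle.fluid as fluid

    # Toy setup; names and shapes are placeholders.
    img = fluid.data(name="img", shape=[64, 784], dtype="float32")
    pred = fluid.layers.fc(input=img, size=10)
    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    prog = fluid.default_main_program()

    params = fluid.io.get_program_parameter(prog)                # is_parameter filter
    persistables = fluid.io.get_program_persistable_vars(prog)   # is_persistable filter

    # Either list plugs into the vars= argument of save_vars; filename= collapses
    # everything into one binary file, which is the case the new var_list argument
    # of load (below) is designed to read back.
    fluid.io.save_vars(exe, dirname="./single_file_model",
                       vars=persistables, filename="persist_file")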
@@ -1497,16 +1543,23 @@ def save(program, model_path):
         f.write(program.desc.serialize_to_string())
 
 
-def load(program, model_path, executor=None):
+def load(program, model_path, executor=None, var_list=None):
     """
-    This function filter out parameters and optimizer information from program, and then get corresponding value from file.
+    This function gets parameters and optimizer information from program, and then gets the corresponding value from file.
     An exception will throw if shape or dtype of the parameters is not match.
+
+    This function can also load model files saved with [ save_params, save_persistables, save_vars ].
+    var_list can not be None when loading a single model file
+    ( filename is not None when save_params, save_persistables or save_vars is called ).
 
     Args:
         program(Program): The program will be loaded
         model_path(str): The file prefix store the program
         executor(Executor, optional): The executor used for initialize the parameter
                                       When startup program is not run.
+        var_list(list, optional): The variable list used to load a single model file saved with
+                                  [ save_params, save_persistables, save_vars ].
+                                  Default: None
 
     Returns:
         None
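The new single-file path can be exercised roughly as follows; a sketch assuming the model was written by save_persistables with an explicit filename, with illustrative names throughout:

    import paddle.fluid as fluid

    # Same toy setup as the sketch above.
    img = fluid.data(name="img", shape=[64, 784], dtype="float32")
    pred = fluid.layers.fc(input=img, size=10)
    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    prog = fluid.default_main_program()

    # Save every persistable variable into a single file.
    fluid.io.save_persistables(exe, "./single_file_model", filename="persist_file")

    # A single-file model carries no variable list of its own, so load needs
    # both an executor and an explicit var_list.
    var_list = fluid.io.get_program_persistable_vars(prog)
    fluid.io.load(prog, "./single_file_model/persist_file",
                  executor=exe, var_list=var_list)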
@@ -1525,9 +1578,85 @@ def load(program, model_path, executor=None):
 
     assert executor is None or isinstance(executor, Executor)
 
-    parameter_file_name = model_path + ".pdparams"
-    assert os.path.exists(parameter_file_name), \
-        "Parameter file [{}] not exits".format(parameter_file_name)
+    # Strip a known suffix so that "model", "model.pdparams", "model.pdopt"
+    # and "model.pdmodel" all resolve to the same prefix.
+    model_prefix = model_path
+    if model_prefix.endswith(".pdparams"):
+        model_prefix = model_prefix[:-9]
+    elif model_prefix.endswith(".pdopt"):
+        model_prefix = model_prefix[:-6]
+    elif model_prefix.endswith(".pdmodel"):
+        model_prefix = model_prefix[:-8]
+
+    parameter_file_name = model_prefix + ".pdparams"
+    if not os.path.exists(parameter_file_name):
+        # Model file saved by fluid.save not found; try to load a model file
+        # saved with [ save_vars, save_params, save_persistables ].
+        _logger.warning(
+            "{} not found, try to load model file saved with [ save_params, save_persistables, save_vars ]".
+            format(parameter_file_name))
+        if executor is None:
+            raise ValueError(
+                "executor is required when loading model file saved with [ save_params, save_persistables, save_vars ]"
+            )
+        if os.path.isdir(model_path):
+            binary_file_set = set()
+            for root, dirs, files in os.walk(model_path, topdown=False):
+                for f in files:
+                    binary_file_set.add(
+                        os.path.join(root, f).replace("\\", "/"))
+            program_var_list = list(program.list_vars())
+            loaded_var_list = []
+            for var in program_var_list:
+                var_path = os.path.join(model_path, var.name).replace("\\", "/")
+                if var_path in binary_file_set:
+                    loaded_var_list.append(var)
+                    binary_file_set.remove(var_path)
+            if len(binary_file_set) > 0:
+                unused_var_list = " ".join(list(binary_file_set))
+                _logger.warning("variable file [ %s ] not used" %
+                                unused_var_list)
+            try:
+                load_vars(
+                    executor=executor, dirname=model_path, vars=loaded_var_list)
+            except RuntimeError as e:
+                _logger.error(e)
+                raise e
+            except:
+                raise RuntimeError(
+                    "Failed to load model file, please make sure model file is saved with the "
+                    "following APIs: save_params, save_persistables, save_vars")
+            return
+        elif os.path.isfile(model_path):
+            if var_list is None:
+                raise ValueError(
+                    "var_list is required when loading model file saved with [ save_params, save_persistables, save_vars ]"
+                )
+            program_var_list = program.list_vars()
+            program_var_name_set = set([var.name for var in program_var_list])
+
+            # Check that every variable to be loaded is included in the program.
+            for var in var_list:
+                if var.name not in program_var_name_set:
+                    raise LookupError(
+                        "loaded var [{}] not included in program variable list".
+                        format(var.name))
+
+            dir_name, file_name = os.path.split(model_path)
+            try:
+                load_vars(
+                    executor=executor,
+                    dirname=dir_name,
+                    vars=var_list,
+                    filename=file_name)
+            except RuntimeError as e:
+                _logger.error(e)
+                raise e
+            except:
+                raise RuntimeError(
+                    "Failed to load model file, please make sure model file is saved with "
+                    "the following APIs: [ save_params, save_persistables, save_vars ]. "
+                    "When these APIs are called, filename CANNOT be None")
+            return
 
     def set_var(var, ndarray):
         t = global_scope().find_var(var.name).get_tensor()
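The directory branch above can be exercised like this; a sketch assuming a per-variable save (filename left as None), with illustrative paths:

    import paddle.fluid as fluid

    # Same toy setup as the sketches above.
    img = fluid.data(name="img", shape=[64, 784], dtype="float32")
    pred = fluid.layers.fc(input=img, size=10)
    exe = fluid.Executor(fluid.CPUPlace())
    exe.run(fluid.default_startup_program())
    prog = fluid.default_main_program()

    # One binary file per variable, the default layout of save_params.
    fluid.io.save_params(exe, "./per_var_model")

    # load walks the directory, matches files against program.list_vars(),
    # warns about unused files, and only needs an executor (no var_list).
    fluid.io.load(prog, "./per_var_model", executor=exe)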
@@ -1561,7 +1690,7 @@ def load(program, model_path, executor=None):
         filter(is_belong_to_optimizer, program.list_vars()))
 
     if len(optimizer_var_list) > 0:
-        opt_file_name = model_path + ".pdopt"
+        opt_file_name = model_prefix + ".pdopt"
         assert os.path.exists(opt_file_name), \
             "Optimizer file [{}] not exits".format(opt_file_name)
@@ -1603,8 +1732,6 @@ def load_program_state(model_path):
 
             fluid.save( prog, "./temp")
             program_state = fluid.load_program_state( "./temp")
-
-            fluid.set_program_state( prog, program_state)
     """
     parameter_file_name = model_path + ".pdparams"
     assert os.path.exists(parameter_file_name), \
@@ -1653,6 +1780,8 @@ def set_program_state(program, state_dict):
 
             fluid.save( prog, "./temp")
             program_state = fluid.load_program_state( "./temp")
+
+            fluid.set_program_state( prog, program_state)
     """
     parameter_list = list(filter(is_persistable, program.list_vars()))
