|
|
@ -55,6 +55,8 @@ __all__ = [
|
|
|
|
'load',
|
|
|
|
'load',
|
|
|
|
'load_program_state',
|
|
|
|
'load_program_state',
|
|
|
|
'set_program_state',
|
|
|
|
'set_program_state',
|
|
|
|
|
|
|
|
'get_program_parameter',
|
|
|
|
|
|
|
|
'get_program_persistable_vars',
|
|
|
|
] + reader.__all__ + paddle.reader.__all__
|
|
|
|
] + reader.__all__ + paddle.reader.__all__
|
|
|
|
|
|
|
|
|
|
|
|
_logger = get_logger(
|
|
|
|
_logger = get_logger(
|
|
|
@ -114,6 +116,50 @@ def is_belong_to_optimizer(var):
|
|
|
|
return False
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_program_parameter(program):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Get all the parameters from Program.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
var(Program): The Program to get parameters
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
|
|
list: The list contains all parameters in the program
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
|
|
|
data = fluid.data(name="img", shape=[64, 784])
|
|
|
|
|
|
|
|
w = fluid.layers.create_parameter(shape=[784, 200], dtype='float32', name='fc_w')
|
|
|
|
|
|
|
|
b = fluid.layers.create_parameter(shape=[200], dtype='float32', name='fc_b')
|
|
|
|
|
|
|
|
list_para = fluid.io.get_program_parameter( fluid.default_main_program() )
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
return list(filter(is_parameter, program.list_vars()))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_program_persistable_vars(program):
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
Get all the persistable vars from Program.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
|
|
|
var(Program): The Program to get persistable vars
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
|
|
|
list: The list contains all persistable vars in the program
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
|
|
|
.. code-block:: python
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import paddle.fluid as fluid
|
|
|
|
|
|
|
|
data = fluid.data(name="img", shape=[64, 784])
|
|
|
|
|
|
|
|
w = fluid.layers.create_parameter(shape=[784, 200], dtype='float32', name='fc_w')
|
|
|
|
|
|
|
|
b = fluid.layers.create_parameter(shape=[200], dtype='float32', name='fc_b')
|
|
|
|
|
|
|
|
list_para = fluid.io.get_program_persistable_vars( fluid.default_main_program() )
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
return list(filter(is_persistable, program.list_vars()))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _clone_var_in_block_(block, var):
|
|
|
|
def _clone_var_in_block_(block, var):
|
|
|
|
assert isinstance(var, Variable)
|
|
|
|
assert isinstance(var, Variable)
|
|
|
|
if var.desc.type() == core.VarDesc.VarType.LOD_TENSOR:
|
|
|
|
if var.desc.type() == core.VarDesc.VarType.LOD_TENSOR:
|
|
|
@ -1497,16 +1543,23 @@ def save(program, model_path):
|
|
|
|
f.write(program.desc.serialize_to_string())
|
|
|
|
f.write(program.desc.serialize_to_string())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load(program, model_path, executor=None):
|
|
|
|
def load(program, model_path, executor=None, var_list=None):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
This function filter out parameters and optimizer information from program, and then get corresponding value from file.
|
|
|
|
This function get parameters and optimizer information from program, and then get corresponding value from file.
|
|
|
|
An exception will throw if shape or dtype of the parameters is not match.
|
|
|
|
An exception will throw if shape or dtype of the parameters is not match.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
This function can also load model file saved with [ save_params, save_persistables, save_vars ].
|
|
|
|
|
|
|
|
var_list can not be None when load single model file
|
|
|
|
|
|
|
|
( filename is not None When save_params, save_persistables or save_vars is called ).
|
|
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
program(Program): The program will be loaded
|
|
|
|
program(Program): The program will be loaded
|
|
|
|
model_path(str): The file prefix store the program
|
|
|
|
model_path(str): The file prefix store the program
|
|
|
|
executor(Executor, optional): The executor used for initialize the parameter
|
|
|
|
executor(Executor, optional): The executor used for initialize the parameter
|
|
|
|
When startup program is not run.
|
|
|
|
When startup program is not run.
|
|
|
|
|
|
|
|
var_list(list, optional): The variable list to load single model file saved with
|
|
|
|
|
|
|
|
[ save_params, save_persistables, save_vars ].
|
|
|
|
|
|
|
|
Default: None
|
|
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
Returns:
|
|
|
|
None
|
|
|
|
None
|
|
|
@ -1525,9 +1578,85 @@ def load(program, model_path, executor=None):
|
|
|
|
|
|
|
|
|
|
|
|
assert executor is None or isinstance(executor, Executor)
|
|
|
|
assert executor is None or isinstance(executor, Executor)
|
|
|
|
|
|
|
|
|
|
|
|
parameter_file_name = model_path + ".pdparams"
|
|
|
|
model_prefix = model_path
|
|
|
|
assert os.path.exists(parameter_file_name), \
|
|
|
|
if model_prefix.endswith(".pdparams"):
|
|
|
|
"Parameter file [{}] not exits".format(parameter_file_name)
|
|
|
|
model_prefix = model_prefix[:-9]
|
|
|
|
|
|
|
|
elif model_prefix.endswith(".pdopt"):
|
|
|
|
|
|
|
|
model_prefix = model_prefix[:-6]
|
|
|
|
|
|
|
|
elif model_prefix.endswith(".pdmodel"):
|
|
|
|
|
|
|
|
model_prefix = model_prefix[:-8]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
parameter_file_name = model_prefix + ".pdparams"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if not os.path.exists(parameter_file_name):
|
|
|
|
|
|
|
|
# model file save by fluid.save not found, try to load model file saved with
|
|
|
|
|
|
|
|
# [save_vars, save_params, save_persistables]
|
|
|
|
|
|
|
|
_logger.warning(
|
|
|
|
|
|
|
|
"{} not found, try to load model file saved with [ save_params, save_persistables, save_vars ]".
|
|
|
|
|
|
|
|
format(parameter_file_name))
|
|
|
|
|
|
|
|
if executor is None:
|
|
|
|
|
|
|
|
raise ValueError(
|
|
|
|
|
|
|
|
"executor is required when loading model file saved with [ save_params, save_persistables, save_vars ]"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
if os.path.isdir(model_path):
|
|
|
|
|
|
|
|
binary_file_set = set()
|
|
|
|
|
|
|
|
for root, dirs, files in os.walk(model_path, topdown=False):
|
|
|
|
|
|
|
|
for f in files:
|
|
|
|
|
|
|
|
binary_file_set.add(
|
|
|
|
|
|
|
|
os.path.join(root, f).replace("\\", "/"))
|
|
|
|
|
|
|
|
program_var_list = list(program.list_vars())
|
|
|
|
|
|
|
|
loaded_var_list = []
|
|
|
|
|
|
|
|
for var in program_var_list:
|
|
|
|
|
|
|
|
var_path = os.path.join(model_path, var.name).replace("\\", "/")
|
|
|
|
|
|
|
|
if var_path in binary_file_set:
|
|
|
|
|
|
|
|
loaded_var_list.append(var)
|
|
|
|
|
|
|
|
binary_file_set.remove(var_path)
|
|
|
|
|
|
|
|
if len(binary_file_set) > 0:
|
|
|
|
|
|
|
|
unused_var_list = " ".join(list(binary_file_set))
|
|
|
|
|
|
|
|
_logger.warning("variable file [ %s ] not used" %
|
|
|
|
|
|
|
|
(" ".join(list(binary_file_set))))
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
load_vars(
|
|
|
|
|
|
|
|
executor=executor, dirname=model_path, vars=loaded_var_list)
|
|
|
|
|
|
|
|
except RuntimeError as e:
|
|
|
|
|
|
|
|
_logger.error(e)
|
|
|
|
|
|
|
|
raise e
|
|
|
|
|
|
|
|
except:
|
|
|
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
|
|
|
"Failed to load model file , please make sure model file is saved with the "
|
|
|
|
|
|
|
|
"following APIs: save_params, save_persistables, save_vars")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
elif os.path.isfile(model_path):
|
|
|
|
|
|
|
|
if var_list == None:
|
|
|
|
|
|
|
|
raise ValueError(
|
|
|
|
|
|
|
|
"var_list is required when loading model file saved with [ save_params, save_persistables, save_vars ]"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
program_var_list = program.list_vars()
|
|
|
|
|
|
|
|
program_var_name_set = set([var.name for var in program_var_list])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# check all the variable inlcuded in program
|
|
|
|
|
|
|
|
for var in var_list:
|
|
|
|
|
|
|
|
if var.name not in program_var_name_set:
|
|
|
|
|
|
|
|
raise LookupError(
|
|
|
|
|
|
|
|
"loaded var [{}] not included in program variable list")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dir_name, file_name = os.path.split(model_path)
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
|
|
|
load_vars(
|
|
|
|
|
|
|
|
executor=executor,
|
|
|
|
|
|
|
|
dirname=dir_name,
|
|
|
|
|
|
|
|
vars=var_list,
|
|
|
|
|
|
|
|
filename=file_name)
|
|
|
|
|
|
|
|
except RuntimeError as e:
|
|
|
|
|
|
|
|
_logger.error(e)
|
|
|
|
|
|
|
|
raise e
|
|
|
|
|
|
|
|
except:
|
|
|
|
|
|
|
|
raise RuntimeError( "Failed to load model file , please make sure model file is saved with the " \
|
|
|
|
|
|
|
|
"the following APIs: [ save_params, save_persistables, save_vars ]. " \
|
|
|
|
|
|
|
|
"When these API called, filename CANNOT be None")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
|
|
def set_var(var, ndarray):
|
|
|
|
def set_var(var, ndarray):
|
|
|
|
t = global_scope().find_var(var.name).get_tensor()
|
|
|
|
t = global_scope().find_var(var.name).get_tensor()
|
|
|
@ -1561,7 +1690,7 @@ def load(program, model_path, executor=None):
|
|
|
|
filter(is_belong_to_optimizer, program.list_vars()))
|
|
|
|
filter(is_belong_to_optimizer, program.list_vars()))
|
|
|
|
|
|
|
|
|
|
|
|
if len(optimizer_var_list) > 0:
|
|
|
|
if len(optimizer_var_list) > 0:
|
|
|
|
opt_file_name = model_path + ".pdopt"
|
|
|
|
opt_file_name = model_prefix + ".pdopt"
|
|
|
|
assert os.path.exists(opt_file_name), \
|
|
|
|
assert os.path.exists(opt_file_name), \
|
|
|
|
"Optimizer file [{}] not exits".format(opt_file_name)
|
|
|
|
"Optimizer file [{}] not exits".format(opt_file_name)
|
|
|
|
|
|
|
|
|
|
|
@ -1603,8 +1732,6 @@ def load_program_state(model_path):
|
|
|
|
fluid.save( prog, "./temp")
|
|
|
|
fluid.save( prog, "./temp")
|
|
|
|
program_state = fluid.load_program_state( "./temp")
|
|
|
|
program_state = fluid.load_program_state( "./temp")
|
|
|
|
|
|
|
|
|
|
|
|
fluid.set_program_state( prog, program_state)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
parameter_file_name = model_path + ".pdparams"
|
|
|
|
parameter_file_name = model_path + ".pdparams"
|
|
|
|
assert os.path.exists(parameter_file_name), \
|
|
|
|
assert os.path.exists(parameter_file_name), \
|
|
|
@ -1653,6 +1780,8 @@ def set_program_state(program, state_dict):
|
|
|
|
fluid.save( prog, "./temp")
|
|
|
|
fluid.save( prog, "./temp")
|
|
|
|
program_state = fluid.load_program_state( "./temp")
|
|
|
|
program_state = fluid.load_program_state( "./temp")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fluid.set_program_state( prog, program_state)
|
|
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
parameter_list = list(filter(is_persistable, program.list_vars()))
|
|
|
|
parameter_list = list(filter(is_persistable, program.list_vars()))
|
|
|
|
|
|
|
|
|
|
|
|