|
|
|
@ -20,6 +20,7 @@ import warnings
|
|
|
|
|
import six
|
|
|
|
|
import logging
|
|
|
|
|
import pickle
|
|
|
|
|
import contextlib
|
|
|
|
|
from functools import reduce
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
@ -179,6 +180,17 @@ def _clone_var_in_block_(block, var):
|
|
|
|
|
persistable=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@contextlib.contextmanager
|
|
|
|
|
def _load_program_scope(main=None, startup=None, scope=None):
|
|
|
|
|
prog = main if main else paddle.fluid.Program()
|
|
|
|
|
startup_prog = startup if startup else paddle.fluid.Program()
|
|
|
|
|
scope = scope if scope else paddle.fluid.core.Scope()
|
|
|
|
|
with paddle.fluid.scope_guard(scope):
|
|
|
|
|
with paddle.fluid.program_guard(prog, startup_prog):
|
|
|
|
|
with paddle.fluid.unique_name.guard():
|
|
|
|
|
yield
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_valid_program(main_program):
|
|
|
|
|
if main_program is None:
|
|
|
|
|
main_program = default_main_program()
|
|
|
|
@ -1711,12 +1723,17 @@ def load(program, model_path, executor=None, var_list=None):
|
|
|
|
|
set_var(v, load_dict[v.name])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_program_state(model_path):
|
|
|
|
|
def load_program_state(model_path, var_list=None):
|
|
|
|
|
"""
|
|
|
|
|
Load program state from local file
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
model_path(str): The file prefix store the program
|
|
|
|
|
var_list(list, optional): The variable list to load saved with
|
|
|
|
|
[ save_params, save_persistables, save_vars ].
|
|
|
|
|
Default: None.
|
|
|
|
|
The var_list is only used to get name,
|
|
|
|
|
will not be modified.
|
|
|
|
|
Returns:
|
|
|
|
|
state_dict(dict): the dict store Parameter and optimizer information
|
|
|
|
|
|
|
|
|
@ -1737,14 +1754,94 @@ def load_program_state(model_path):
|
|
|
|
|
program_state = fluid.load_program_state( "./temp")
|
|
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
parameter_file_name = model_path + ".pdparams"
|
|
|
|
|
model_prefix = model_path
|
|
|
|
|
if model_prefix.endswith(".pdparams"):
|
|
|
|
|
model_prefix = model_prefix[:-9]
|
|
|
|
|
elif model_prefix.endswith(".pdopt"):
|
|
|
|
|
model_prefix = model_prefix[:-6]
|
|
|
|
|
elif model_prefix.endswith(".pdmodel"):
|
|
|
|
|
model_prefix = model_prefix[:-8]
|
|
|
|
|
|
|
|
|
|
parameter_file_name = model_prefix + ".pdparams"
|
|
|
|
|
if not os.path.exists(parameter_file_name):
|
|
|
|
|
# model file saved with fluid.save is not found, try to load model file saved with
|
|
|
|
|
# [save_vars, save_params, save_persistables]
|
|
|
|
|
_logger.warning(
|
|
|
|
|
"{} not found, try to load model file saved with [ save_params, save_persistables, save_vars ]".
|
|
|
|
|
format(parameter_file_name))
|
|
|
|
|
|
|
|
|
|
var_name_list = []
|
|
|
|
|
if var_list is None and os.path.isfile(model_path):
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"var_list can not be None when model_path is a file type")
|
|
|
|
|
|
|
|
|
|
for root, dirs, files in os.walk(model_path, topdown=False):
|
|
|
|
|
for f in files:
|
|
|
|
|
file_path = os.path.join(root, f)
|
|
|
|
|
var_temp_name = os.path.relpath(file_path, model_path)
|
|
|
|
|
var_temp_name = var_temp_name.replace("\\", "/")
|
|
|
|
|
var_name_list.append(var_temp_name)
|
|
|
|
|
|
|
|
|
|
with _load_program_scope():
|
|
|
|
|
load_prog = Program()
|
|
|
|
|
load_block = load_prog.global_block()
|
|
|
|
|
|
|
|
|
|
def clone_var_to_block(block, var):
|
|
|
|
|
if not isinstance(var, Variable):
|
|
|
|
|
raise TypeError("value in var_list must be variable")
|
|
|
|
|
return block.create_var(
|
|
|
|
|
name=var.name,
|
|
|
|
|
shape=var.shape,
|
|
|
|
|
dtype=var.dtype,
|
|
|
|
|
type=var.type,
|
|
|
|
|
lod_level=var.lod_level
|
|
|
|
|
if var.desc.type() == core.VarDesc.VarType.LOD_TENSOR else
|
|
|
|
|
None,
|
|
|
|
|
persistable=True)
|
|
|
|
|
|
|
|
|
|
loaded_var_list = []
|
|
|
|
|
|
|
|
|
|
if var_list is not None:
|
|
|
|
|
for var in var_list:
|
|
|
|
|
loaded_var_list.append(clone_var_to_block(load_block, var))
|
|
|
|
|
else:
|
|
|
|
|
for var_name in var_name_list:
|
|
|
|
|
loaded_var_list.append(
|
|
|
|
|
load_block.create_var(
|
|
|
|
|
name=var_name, persistable=True))
|
|
|
|
|
|
|
|
|
|
place = paddle.fluid.CPUPlace()
|
|
|
|
|
exe = paddle.fluid.Executor(place)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
if os.path.isfile(model_path):
|
|
|
|
|
dir_name, file_name = os.path.split(model_path)
|
|
|
|
|
else:
|
|
|
|
|
dir_name = model_path
|
|
|
|
|
file_name = None
|
|
|
|
|
load_vars(
|
|
|
|
|
executor=exe,
|
|
|
|
|
dirname=dir_name,
|
|
|
|
|
vars=loaded_var_list,
|
|
|
|
|
filename=file_name)
|
|
|
|
|
except:
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
"Failed to load model file , please make sure model file is saved with the "
|
|
|
|
|
"following APIs: save_params, save_persistables, save_vars")
|
|
|
|
|
res_dict = {}
|
|
|
|
|
for var in loaded_var_list:
|
|
|
|
|
res_dict[var.name] = np.asarray(paddle.fluid.global_scope(
|
|
|
|
|
).find_var(var.name).get_tensor())
|
|
|
|
|
|
|
|
|
|
return res_dict
|
|
|
|
|
|
|
|
|
|
assert os.path.exists(parameter_file_name), \
|
|
|
|
|
"Parameter file [{}] not exits".format(parameter_file_name)
|
|
|
|
|
|
|
|
|
|
with open(parameter_file_name, 'rb') as f:
|
|
|
|
|
para_dict = pickle.load(f)
|
|
|
|
|
|
|
|
|
|
opt_file_name = model_path + ".pdopt"
|
|
|
|
|
opt_file_name = model_prefix + ".pdopt"
|
|
|
|
|
if os.path.exists(opt_file_name):
|
|
|
|
|
with open(opt_file_name, 'rb') as f:
|
|
|
|
|
opti_dict = pickle.load(f)
|
|
|
|
|