|
|
|
@ -2025,35 +2025,63 @@ def load_program_state(model_path, var_list=None):
|
|
|
|
|
None,
|
|
|
|
|
persistable=True)
|
|
|
|
|
|
|
|
|
|
def _load_vars_with_try_catch(exe,
|
|
|
|
|
dirname,
|
|
|
|
|
vars,
|
|
|
|
|
filename,
|
|
|
|
|
raise_error=True):
|
|
|
|
|
try:
|
|
|
|
|
load_vars(
|
|
|
|
|
executor=exe,
|
|
|
|
|
dirname=dirname,
|
|
|
|
|
vars=vars,
|
|
|
|
|
filename=filename)
|
|
|
|
|
return True
|
|
|
|
|
except:
|
|
|
|
|
error_str = "Failed to load model/variables `%s`, please make sure " \
|
|
|
|
|
"model/variables file is saved with the following APIs: " \
|
|
|
|
|
"save_params, save_persistables, save_vars."
|
|
|
|
|
filenames = [var.name for var in vars
|
|
|
|
|
] if filename is None else filename
|
|
|
|
|
if raise_error:
|
|
|
|
|
raise RuntimeError(error_str % filenames)
|
|
|
|
|
else:
|
|
|
|
|
warnings.warn(error_str % filenames, RuntimeWarning)
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
place = paddle.fluid.CPUPlace()
|
|
|
|
|
exe = paddle.fluid.Executor(place)
|
|
|
|
|
|
|
|
|
|
loaded_var_list = []
|
|
|
|
|
|
|
|
|
|
if var_list is not None:
|
|
|
|
|
if os.path.isfile(model_path):
|
|
|
|
|
# when model_path is file, var_list cannot be None
|
|
|
|
|
dir_name, file_name = os.path.split(model_path)
|
|
|
|
|
for var in var_list:
|
|
|
|
|
loaded_var_list.append(clone_var_to_block(load_block, var))
|
|
|
|
|
_load_vars_with_try_catch(exe, dir_name, loaded_var_list,
|
|
|
|
|
file_name)
|
|
|
|
|
else:
|
|
|
|
|
for var_name in var_name_list:
|
|
|
|
|
loaded_var_list.append(
|
|
|
|
|
load_block.create_var(
|
|
|
|
|
name=var_name, persistable=True))
|
|
|
|
|
|
|
|
|
|
place = paddle.fluid.CPUPlace()
|
|
|
|
|
exe = paddle.fluid.Executor(place)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
if os.path.isfile(model_path):
|
|
|
|
|
dir_name, file_name = os.path.split(model_path)
|
|
|
|
|
# var_list can be None or not None
|
|
|
|
|
if var_list is not None:
|
|
|
|
|
for var in var_list:
|
|
|
|
|
loaded_var_list.append(
|
|
|
|
|
clone_var_to_block(load_block, var))
|
|
|
|
|
_load_vars_with_try_catch(exe, model_path, loaded_var_list,
|
|
|
|
|
None)
|
|
|
|
|
else:
|
|
|
|
|
dir_name = model_path
|
|
|
|
|
file_name = None
|
|
|
|
|
load_vars(
|
|
|
|
|
executor=exe,
|
|
|
|
|
dirname=dir_name,
|
|
|
|
|
vars=loaded_var_list,
|
|
|
|
|
filename=file_name)
|
|
|
|
|
except:
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
"Failed to load model file , please make sure model file is saved with the "
|
|
|
|
|
"following APIs: save_params, save_persistables, save_vars")
|
|
|
|
|
for var_name in var_name_list:
|
|
|
|
|
# NOTE(chenweihang): If identify which files the user wants
|
|
|
|
|
# to load from the disk, we load these variables one by one.
|
|
|
|
|
# If a file does not exist, we only warn the user that the
|
|
|
|
|
# file may be an irrelevant file, but does not throw an error
|
|
|
|
|
# to ensure that other legal variables can be loaded.
|
|
|
|
|
temp_var = load_block.create_var(
|
|
|
|
|
name=var_name, persistable=True)
|
|
|
|
|
if _load_vars_with_try_catch(exe, model_path,
|
|
|
|
|
[temp_var], None, False):
|
|
|
|
|
loaded_var_list.append(temp_var)
|
|
|
|
|
|
|
|
|
|
res_dict = {}
|
|
|
|
|
for var in loaded_var_list:
|
|
|
|
|
res_dict[var.name] = np.asarray(paddle.fluid.global_scope(
|
|
|
|
|