Enhance load program state (#22546)

* enhance load program state; test=develop

* optimize commet; test=develop
revert-22710-feature/integrated_ps_api
hong 5 years ago committed by GitHub
parent 8acd745c25
commit 6980239632
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -20,6 +20,7 @@ import warnings
import six
import logging
import pickle
import contextlib
from functools import reduce
import numpy as np
@ -179,6 +180,17 @@ def _clone_var_in_block_(block, var):
persistable=True)
@contextlib.contextmanager
def _load_program_scope(main=None, startup=None, scope=None):
prog = main if main else paddle.fluid.Program()
startup_prog = startup if startup else paddle.fluid.Program()
scope = scope if scope else paddle.fluid.core.Scope()
with paddle.fluid.scope_guard(scope):
with paddle.fluid.program_guard(prog, startup_prog):
with paddle.fluid.unique_name.guard():
yield
def _get_valid_program(main_program):
if main_program is None:
main_program = default_main_program()
@ -1711,12 +1723,17 @@ def load(program, model_path, executor=None, var_list=None):
set_var(v, load_dict[v.name])
def load_program_state(model_path):
def load_program_state(model_path, var_list=None):
"""
Load program state from local file
Args:
model_path(str): The file prefix store the program
var_list(list, optional): The variable list to load saved with
[ save_params, save_persistables, save_vars ].
Default: None.
The var_list is only used to get name,
will not be modified.
Returns:
state_dict(dict): the dict store Parameter and optimizer information
@ -1737,14 +1754,94 @@ def load_program_state(model_path):
program_state = fluid.load_program_state( "./temp")
"""
parameter_file_name = model_path + ".pdparams"
model_prefix = model_path
if model_prefix.endswith(".pdparams"):
model_prefix = model_prefix[:-9]
elif model_prefix.endswith(".pdopt"):
model_prefix = model_prefix[:-6]
elif model_prefix.endswith(".pdmodel"):
model_prefix = model_prefix[:-8]
parameter_file_name = model_prefix + ".pdparams"
if not os.path.exists(parameter_file_name):
# model file saved with fluid.save is not found, try to load model file saved with
# [save_vars, save_params, save_persistables]
_logger.warning(
"{} not found, try to load model file saved with [ save_params, save_persistables, save_vars ]".
format(parameter_file_name))
var_name_list = []
if var_list is None and os.path.isfile(model_path):
raise ValueError(
"var_list can not be None when model_path is a file type")
for root, dirs, files in os.walk(model_path, topdown=False):
for f in files:
file_path = os.path.join(root, f)
var_temp_name = os.path.relpath(file_path, model_path)
var_temp_name = var_temp_name.replace("\\", "/")
var_name_list.append(var_temp_name)
with _load_program_scope():
load_prog = Program()
load_block = load_prog.global_block()
def clone_var_to_block(block, var):
if not isinstance(var, Variable):
raise TypeError("value in var_list must be variable")
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.dtype,
type=var.type,
lod_level=var.lod_level
if var.desc.type() == core.VarDesc.VarType.LOD_TENSOR else
None,
persistable=True)
loaded_var_list = []
if var_list is not None:
for var in var_list:
loaded_var_list.append(clone_var_to_block(load_block, var))
else:
for var_name in var_name_list:
loaded_var_list.append(
load_block.create_var(
name=var_name, persistable=True))
place = paddle.fluid.CPUPlace()
exe = paddle.fluid.Executor(place)
try:
if os.path.isfile(model_path):
dir_name, file_name = os.path.split(model_path)
else:
dir_name = model_path
file_name = None
load_vars(
executor=exe,
dirname=dir_name,
vars=loaded_var_list,
filename=file_name)
except:
raise RuntimeError(
"Failed to load model file , please make sure model file is saved with the "
"following APIs: save_params, save_persistables, save_vars")
res_dict = {}
for var in loaded_var_list:
res_dict[var.name] = np.asarray(paddle.fluid.global_scope(
).find_var(var.name).get_tensor())
return res_dict
assert os.path.exists(parameter_file_name), \
"Parameter file [{}] not exits".format(parameter_file_name)
with open(parameter_file_name, 'rb') as f:
para_dict = pickle.load(f)
opt_file_name = model_path + ".pdopt"
opt_file_name = model_prefix + ".pdopt"
if os.path.exists(opt_file_name):
with open(opt_file_name, 'rb') as f:
opti_dict = pickle.load(f)

Loading…
Cancel
Save