Enhance load program state (#22546)

* enhance load program state; test=develop

* optimize comment; test=develop
revert-22710-feature/integrated_ps_api
hong 5 years ago committed by GitHub
parent 8acd745c25
commit 6980239632
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -20,6 +20,7 @@ import warnings
import six import six
import logging import logging
import pickle import pickle
import contextlib
from functools import reduce from functools import reduce
import numpy as np import numpy as np
@ -179,6 +180,17 @@ def _clone_var_in_block_(block, var):
persistable=True) persistable=True)
@contextlib.contextmanager
def _load_program_scope(main=None, startup=None, scope=None):
    """Context manager that runs its body under dedicated program/scope guards.

    Args:
        main: Program passed to ``program_guard`` as the main program. When
            falsy, a new ``paddle.fluid.Program()`` is created.
        startup: Program passed to ``program_guard`` as the startup program.
            When falsy, a new ``paddle.fluid.Program()`` is created.
        scope: Scope passed to ``scope_guard``. When falsy, a new
            ``paddle.fluid.core.Scope()`` is created.

    Yields:
        None. The body of the ``with`` statement executes while the scope
        guard, the program guard and the unique-name guard are all active.
    """
    # Resolve the fallbacks in the same order as the parameters are declared.
    main_prog = main or paddle.fluid.Program()
    startup_prog = startup or paddle.fluid.Program()
    load_scope = scope or paddle.fluid.core.Scope()

    # Guards are entered left to right and exited in reverse order.
    with paddle.fluid.scope_guard(load_scope), \
            paddle.fluid.program_guard(main_prog, startup_prog), \
            paddle.fluid.unique_name.guard():
        yield
def _get_valid_program(main_program): def _get_valid_program(main_program):
if main_program is None: if main_program is None:
main_program = default_main_program() main_program = default_main_program()
@ -1711,12 +1723,17 @@ def load(program, model_path, executor=None, var_list=None):
set_var(v, load_dict[v.name]) set_var(v, load_dict[v.name])
def load_program_state(model_path): def load_program_state(model_path, var_list=None):
""" """
Load program state from local file Load program state from local file
Args: Args:
model_path(str): The file prefix store the program model_path(str): The file prefix store the program
var_list(list, optional): The variable list to load saved with
[ save_params, save_persistables, save_vars ].
Default: None.
The var_list is only used to get name,
will not be modified.
Returns: Returns:
state_dict(dict): the dict store Parameter and optimizer information state_dict(dict): the dict store Parameter and optimizer information
@ -1737,14 +1754,94 @@ def load_program_state(model_path):
program_state = fluid.load_program_state( "./temp") program_state = fluid.load_program_state( "./temp")
""" """
parameter_file_name = model_path + ".pdparams" model_prefix = model_path
if model_prefix.endswith(".pdparams"):
model_prefix = model_prefix[:-9]
elif model_prefix.endswith(".pdopt"):
model_prefix = model_prefix[:-6]
elif model_prefix.endswith(".pdmodel"):
model_prefix = model_prefix[:-8]
parameter_file_name = model_prefix + ".pdparams"
if not os.path.exists(parameter_file_name):
# model file saved with fluid.save is not found, try to load model file saved with
# [save_vars, save_params, save_persistables]
_logger.warning(
"{} not found, try to load model file saved with [ save_params, save_persistables, save_vars ]".
format(parameter_file_name))
var_name_list = []
if var_list is None and os.path.isfile(model_path):
raise ValueError(
"var_list can not be None when model_path is a file type")
for root, dirs, files in os.walk(model_path, topdown=False):
for f in files:
file_path = os.path.join(root, f)
var_temp_name = os.path.relpath(file_path, model_path)
var_temp_name = var_temp_name.replace("\\", "/")
var_name_list.append(var_temp_name)
with _load_program_scope():
load_prog = Program()
load_block = load_prog.global_block()
def clone_var_to_block(block, var):
if not isinstance(var, Variable):
raise TypeError("value in var_list must be variable")
return block.create_var(
name=var.name,
shape=var.shape,
dtype=var.dtype,
type=var.type,
lod_level=var.lod_level
if var.desc.type() == core.VarDesc.VarType.LOD_TENSOR else
None,
persistable=True)
loaded_var_list = []
if var_list is not None:
for var in var_list:
loaded_var_list.append(clone_var_to_block(load_block, var))
else:
for var_name in var_name_list:
loaded_var_list.append(
load_block.create_var(
name=var_name, persistable=True))
place = paddle.fluid.CPUPlace()
exe = paddle.fluid.Executor(place)
try:
if os.path.isfile(model_path):
dir_name, file_name = os.path.split(model_path)
else:
dir_name = model_path
file_name = None
load_vars(
executor=exe,
dirname=dir_name,
vars=loaded_var_list,
filename=file_name)
except:
raise RuntimeError(
"Failed to load model file , please make sure model file is saved with the "
"following APIs: save_params, save_persistables, save_vars")
res_dict = {}
for var in loaded_var_list:
res_dict[var.name] = np.asarray(paddle.fluid.global_scope(
).find_var(var.name).get_tensor())
return res_dict
assert os.path.exists(parameter_file_name), \ assert os.path.exists(parameter_file_name), \
"Parameter file [{}] not exits".format(parameter_file_name) "Parameter file [{}] not exits".format(parameter_file_name)
with open(parameter_file_name, 'rb') as f: with open(parameter_file_name, 'rb') as f:
para_dict = pickle.load(f) para_dict = pickle.load(f)
opt_file_name = model_path + ".pdopt" opt_file_name = model_prefix + ".pdopt"
if os.path.exists(opt_file_name): if os.path.exists(opt_file_name):
with open(opt_file_name, 'rb') as f: with open(opt_file_name, 'rb') as f:
opti_dict = pickle.load(f) opti_dict = pickle.load(f)

Loading…
Cancel
Save