|
|
|
@ -195,58 +195,11 @@ def load_dygraph(model_path, config=None):
|
|
|
|
|
params_file_path = model_prefix + ".pdparams"
|
|
|
|
|
opti_file_path = model_prefix + ".pdopt"
|
|
|
|
|
|
|
|
|
|
# deal with argument `configs`
|
|
|
|
|
configs = config
|
|
|
|
|
if configs is None:
|
|
|
|
|
configs = SaveLoadConfig()
|
|
|
|
|
|
|
|
|
|
if not os.path.exists(params_file_path) and not os.path.exists(
|
|
|
|
|
opti_file_path):
|
|
|
|
|
# Load state dict by `jit.save/io.save_inference_model` save format
|
|
|
|
|
# NOTE(chenweihang): [ Compatibility of save_inference_model save format ]
|
|
|
|
|
# The model saved by `save_inference_model` does not completely correspond to
|
|
|
|
|
# the information required by the `state_dict` under the dygraph.
|
|
|
|
|
# `save_inference_model` not save structured name, we need to remind
|
|
|
|
|
# the user to configure the `use_structured_name` argument when `set_state_dict`
|
|
|
|
|
# NOTE(chenweihang): `jit.save` doesn't save optimizer state
|
|
|
|
|
|
|
|
|
|
# 1. check model path
|
|
|
|
|
if not os.path.isdir(model_prefix):
|
|
|
|
|
raise ValueError("Model saved directory '%s' is not exists." %
|
|
|
|
|
model_prefix)
|
|
|
|
|
# deal with argument `config`
|
|
|
|
|
if config is None:
|
|
|
|
|
config = SaveLoadConfig()
|
|
|
|
|
|
|
|
|
|
# 2. load program desc & construct _ProgramHolder
|
|
|
|
|
programs = _construct_program_holders(model_path,
|
|
|
|
|
configs.model_filename)
|
|
|
|
|
|
|
|
|
|
# 3. load layer parameters & buffers
|
|
|
|
|
# NOTE: using fluid.dygraph.guard() here will cause import error in py2
|
|
|
|
|
with guard():
|
|
|
|
|
persistable_var_dict = _construct_params_and_buffers(
|
|
|
|
|
model_prefix,
|
|
|
|
|
programs,
|
|
|
|
|
configs.separate_params,
|
|
|
|
|
configs.params_filename,
|
|
|
|
|
append_suffix=False)
|
|
|
|
|
|
|
|
|
|
# 4. construct state_dict
|
|
|
|
|
para_dict = dict()
|
|
|
|
|
for var_name in persistable_var_dict:
|
|
|
|
|
para_dict[var_name] = persistable_var_dict[var_name].numpy()
|
|
|
|
|
|
|
|
|
|
# if __variables.info__ exists, we can recover structured_name
|
|
|
|
|
var_info_path = os.path.join(model_prefix, EXTRA_VAR_INFO_FILENAME)
|
|
|
|
|
if os.path.exists(var_info_path):
|
|
|
|
|
with open(var_info_path, 'rb') as f:
|
|
|
|
|
extra_var_info = pickle.load(f)
|
|
|
|
|
structured_para_dict = dict()
|
|
|
|
|
for var_name in para_dict:
|
|
|
|
|
structured_name = extra_var_info[var_name].get(
|
|
|
|
|
'structured_name', None)
|
|
|
|
|
assert structured_name is not None, "Cannot find saved variable (%s)'s structured name in saved model." % var_name
|
|
|
|
|
structured_para_dict[structured_name] = para_dict[var_name]
|
|
|
|
|
para_dict = structured_para_dict
|
|
|
|
|
else:
|
|
|
|
|
if os.path.exists(params_file_path) or os.path.exists(opti_file_path):
|
|
|
|
|
# Load state dict by `save_dygraph` save format
|
|
|
|
|
para_dict = {}
|
|
|
|
|
if os.path.exists(params_file_path):
|
|
|
|
@ -254,12 +207,103 @@ def load_dygraph(model_path, config=None):
|
|
|
|
|
para_dict = pickle.load(f) if six.PY2 else pickle.load(
|
|
|
|
|
f, encoding='latin1')
|
|
|
|
|
|
|
|
|
|
if not configs.keep_name_table and "StructuredToParameterName@@" in para_dict:
|
|
|
|
|
if not config.keep_name_table and "StructuredToParameterName@@" in para_dict:
|
|
|
|
|
del para_dict["StructuredToParameterName@@"]
|
|
|
|
|
|
|
|
|
|
if os.path.exists(opti_file_path):
|
|
|
|
|
with open(opti_file_path, 'rb') as f:
|
|
|
|
|
opti_dict = pickle.load(f) if six.PY2 else pickle.load(
|
|
|
|
|
f, encoding='latin1')
|
|
|
|
|
else:
|
|
|
|
|
# check model path
|
|
|
|
|
if not os.path.isdir(model_prefix):
|
|
|
|
|
raise ValueError("Model saved directory '%s' is not exists." %
|
|
|
|
|
model_prefix)
|
|
|
|
|
|
|
|
|
|
# check whether model file exists
|
|
|
|
|
if config.model_filename is None:
|
|
|
|
|
model_filename = '__model__'
|
|
|
|
|
else:
|
|
|
|
|
model_filename = config.model_filename
|
|
|
|
|
model_file_path = os.path.join(model_path, model_filename)
|
|
|
|
|
|
|
|
|
|
if os.path.exists(model_file_path):
|
|
|
|
|
# Load state dict by `jit.save/io.save_inference_model` save format
|
|
|
|
|
# NOTE(chenweihang): [ Compatibility of save_inference_model save format ]
|
|
|
|
|
# The model saved by `save_inference_model` does not completely correspond to
|
|
|
|
|
# the information required by the `state_dict` under the dygraph.
|
|
|
|
|
# `save_inference_model` not save structured name, we need to remind
|
|
|
|
|
# the user to configure the `use_structured_name` argument when `set_state_dict`
|
|
|
|
|
# NOTE(chenweihang): `jit.save` doesn't save optimizer state
|
|
|
|
|
|
|
|
|
|
# 1. load program desc & construct _ProgramHolder
|
|
|
|
|
programs = _construct_program_holders(model_path,
|
|
|
|
|
config.model_filename)
|
|
|
|
|
|
|
|
|
|
# 2. load layer parameters & buffers
|
|
|
|
|
# NOTE: using fluid.dygraph.guard() here will cause import error in py2
|
|
|
|
|
with guard():
|
|
|
|
|
persistable_var_dict = _construct_params_and_buffers(
|
|
|
|
|
model_prefix,
|
|
|
|
|
programs,
|
|
|
|
|
config.separate_params,
|
|
|
|
|
config.params_filename,
|
|
|
|
|
append_suffix=False)
|
|
|
|
|
|
|
|
|
|
# 3. construct state_dict
|
|
|
|
|
para_dict = dict()
|
|
|
|
|
for var_name in persistable_var_dict:
|
|
|
|
|
para_dict[var_name] = persistable_var_dict[var_name].numpy()
|
|
|
|
|
|
|
|
|
|
# if __variables.info__ exists, we can recover structured_name
|
|
|
|
|
var_info_path = os.path.join(model_prefix,
|
|
|
|
|
EXTRA_VAR_INFO_FILENAME)
|
|
|
|
|
if os.path.exists(var_info_path):
|
|
|
|
|
with open(var_info_path, 'rb') as f:
|
|
|
|
|
extra_var_info = pickle.load(f)
|
|
|
|
|
structured_para_dict = dict()
|
|
|
|
|
for var_name in para_dict:
|
|
|
|
|
structured_name = extra_var_info[var_name].get(
|
|
|
|
|
'structured_name', None)
|
|
|
|
|
assert structured_name is not None, "Cannot find saved variable (%s)'s structured name in saved model." % var_name
|
|
|
|
|
structured_para_dict[structured_name] = para_dict[
|
|
|
|
|
var_name]
|
|
|
|
|
para_dict = structured_para_dict
|
|
|
|
|
else:
|
|
|
|
|
# load state dict by `io.save_params/persistables` save format
|
|
|
|
|
# TODO(chenweihang): [ Now only supports loading parameters seperately ]
|
|
|
|
|
# If users save all parameters as one file, the [ variable.name -> variable ]
|
|
|
|
|
# mapping info will lost, so users need to give variable list, but users build
|
|
|
|
|
# variable list in dygraph mode is difficult, we recommend users to use
|
|
|
|
|
# paddle.io.load_program_state in this case
|
|
|
|
|
|
|
|
|
|
# Try to load all the files in the directory in VarBase format,
|
|
|
|
|
# the file name is used as the name of VarBase
|
|
|
|
|
load_var_list = []
|
|
|
|
|
|
|
|
|
|
# 1. load file names
|
|
|
|
|
var_name_list = []
|
|
|
|
|
for root, _, files in os.walk(model_path):
|
|
|
|
|
for filename in files:
|
|
|
|
|
file_path = os.path.join(root, filename)
|
|
|
|
|
tmp_var_name = os.path.relpath(file_path, model_path)
|
|
|
|
|
var_name = tmp_var_name.replace("\\", "/")
|
|
|
|
|
var_name_list.append(var_name)
|
|
|
|
|
|
|
|
|
|
# 2. create and load VarBase
|
|
|
|
|
with guard():
|
|
|
|
|
for name in var_name_list:
|
|
|
|
|
new_var = _varbase_creator(name=name, persistable=True)
|
|
|
|
|
_dygraph_tracer().trace_op(
|
|
|
|
|
type='load',
|
|
|
|
|
inputs={},
|
|
|
|
|
outputs={'Out': new_var},
|
|
|
|
|
attrs={'file_path': os.path.join(model_path, name)})
|
|
|
|
|
load_var_list.append(new_var)
|
|
|
|
|
|
|
|
|
|
# 3. construct state_dict
|
|
|
|
|
para_dict = dict()
|
|
|
|
|
for var in load_var_list:
|
|
|
|
|
para_dict[var.name] = var.numpy()
|
|
|
|
|
|
|
|
|
|
return para_dict, opti_dict
|
|
|
|
|