|
|
@ -25,7 +25,7 @@ import warnings
|
|
|
|
from .. import core
|
|
|
|
from .. import core
|
|
|
|
from .base import guard
|
|
|
|
from .base import guard
|
|
|
|
from paddle.fluid.dygraph.jit import SaveLoadConfig, deprecate_save_load_configs
|
|
|
|
from paddle.fluid.dygraph.jit import SaveLoadConfig, deprecate_save_load_configs
|
|
|
|
from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers
|
|
|
|
from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers, EXTRA_VAR_INFO_FILENAME
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
__all__ = [
|
|
|
|
'save_dygraph',
|
|
|
|
'save_dygraph',
|
|
|
@ -233,6 +233,19 @@ def load_dygraph(model_path, config=None):
|
|
|
|
para_dict = dict()
|
|
|
|
para_dict = dict()
|
|
|
|
for var_name in persistable_var_dict:
|
|
|
|
for var_name in persistable_var_dict:
|
|
|
|
para_dict[var_name] = persistable_var_dict[var_name].numpy()
|
|
|
|
para_dict[var_name] = persistable_var_dict[var_name].numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# if __variables.info__ exists, we can recover structured_name
|
|
|
|
|
|
|
|
var_info_path = os.path.join(model_prefix, EXTRA_VAR_INFO_FILENAME)
|
|
|
|
|
|
|
|
if os.path.exists(var_info_path):
|
|
|
|
|
|
|
|
with open(var_info_path, 'rb') as f:
|
|
|
|
|
|
|
|
extra_var_info = pickle.load(f)
|
|
|
|
|
|
|
|
structured_para_dict = dict()
|
|
|
|
|
|
|
|
for var_name in para_dict:
|
|
|
|
|
|
|
|
structured_name = extra_var_info[var_name].get(
|
|
|
|
|
|
|
|
'structured_name', None)
|
|
|
|
|
|
|
|
assert structured_name is not None, "Cannot find saved variable (%s)'s structured name in saved model." % var_name
|
|
|
|
|
|
|
|
structured_para_dict[structured_name] = para_dict[var_name]
|
|
|
|
|
|
|
|
para_dict = structured_para_dict
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
# Load state dict by `save_dygraph` save format
|
|
|
|
# Load state dict by `save_dygraph` save format
|
|
|
|
para_dict = {}
|
|
|
|
para_dict = {}
|
|
|
|