use structured name in loaded dict (#27242)

disable_ut_1
Chen Weihang 4 years ago committed by GitHub
parent 5e0dde02b2
commit ac8afe184e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -25,7 +25,7 @@ import warnings
from .. import core from .. import core
from .base import guard from .base import guard
from paddle.fluid.dygraph.jit import SaveLoadConfig, deprecate_save_load_configs from paddle.fluid.dygraph.jit import SaveLoadConfig, deprecate_save_load_configs
from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers, EXTRA_VAR_INFO_FILENAME
__all__ = [ __all__ = [
'save_dygraph', 'save_dygraph',
@ -233,6 +233,19 @@ def load_dygraph(model_path, config=None):
para_dict = dict() para_dict = dict()
for var_name in persistable_var_dict: for var_name in persistable_var_dict:
para_dict[var_name] = persistable_var_dict[var_name].numpy() para_dict[var_name] = persistable_var_dict[var_name].numpy()
# if __variables.info__ exists, we can recover structured_name
var_info_path = os.path.join(model_prefix, EXTRA_VAR_INFO_FILENAME)
if os.path.exists(var_info_path):
with open(var_info_path, 'rb') as f:
extra_var_info = pickle.load(f)
structured_para_dict = dict()
for var_name in para_dict:
structured_name = extra_var_info[var_name].get(
'structured_name', None)
assert structured_name is not None, "Cannot find saved variable (%s)'s structured name in saved model." % var_name
structured_para_dict[structured_name] = para_dict[var_name]
para_dict = structured_para_dict
else: else:
# Load state dict by `save_dygraph` save format # Load state dict by `save_dygraph` save format
para_dict = {} para_dict = {}

@ -255,8 +255,11 @@ class TestJitSaveLoad(unittest.TestCase):
train_layer.eval() train_layer.eval()
# construct new model # construct new model
new_layer = LinearNet(784, 1) new_layer = LinearNet(784, 1)
model_dict, _ = fluid.dygraph.load_dygraph(self.model_path) orig_state_dict = new_layer.state_dict()
new_layer.set_dict(model_dict) load_state_dict, _ = fluid.dygraph.load_dygraph(self.model_path)
for structured_name in orig_state_dict:
self.assertTrue(structured_name in load_state_dict)
new_layer.set_state_dict(load_state_dict)
new_layer.eval() new_layer.eval()
# inference & compare # inference & compare
x = fluid.dygraph.to_variable( x = fluid.dygraph.to_variable(

Loading…
Cancel
Save