You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
147 lines
4.8 KiB
147 lines
4.8 KiB
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import print_function
|
|
|
|
import os
|
|
import collections
|
|
from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase
|
|
import pickle
|
|
import six
|
|
from . import learning_rate_scheduler
|
|
import warnings
|
|
from .. import core
|
|
|
|
# Public API of this module: the dygraph checkpoint save/load entry points.
__all__ = ['save_dygraph', 'load_dygraph']
|
|
|
|
|
|
@dygraph_only
def save_dygraph(state_dict, model_path):
    '''
    Save a state_dict to disk as a pickle file.

    The file suffix is chosen from the content of ``state_dict``: ".pdparams"
    is used when every value is a ``ParamBase``; as soon as one value is not a
    ``ParamBase`` the suffix ".pdopt" is used instead.

    ``Variable`` / ``VarBase`` values are stored as the result of their
    ``numpy()`` call; any other value is stored unchanged. A table mapping each
    such key to the variable's ``name`` is stored next to the data under the
    key "StructuredToParameterName@@".

    Args:
        state_dict(dict) : The state dict to be saved. Must not be empty.
        model_path(str) : The file prefix to save the state_dict to, in the
            format "dirname/file_prefix". The suffix is appended to it. The
            directory is created when it does not exist yet.

    Returns:
        None

    Raises:
        AssertionError : If the file_prefix part of ``model_path`` is an empty
            str, or if ``state_dict`` is empty.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            with fluid.dygraph.guard():
                emb = fluid.dygraph.Embedding([10, 10])

                state_dict = emb.state_dict()
                fluid.save_dygraph(state_dict, "paddle_dy")

                adam = fluid.optimizer.Adam(
                    learning_rate=fluid.layers.noam_decay(100, 10000),
                    parameter_list=emb.parameters())

                state_dict = adam.state_dict()
                fluid.save_dygraph(state_dict, "paddle_dy")

    '''

    base_name = os.path.basename(model_path)
    assert base_name != "", "model_path MUST be format of dirname/filename [dirname\\filename in Window], Now filename is empty str"

    assert len(state_dict) > 0, "state_dict is empty, no need to save"

    # A single non-ParamBase value is enough to switch to the ".pdopt" suffix.
    only_params = all(isinstance(v, ParamBase) for v in state_dict.values())
    suffix = ".pdparams" if only_params else ".pdopt"

    model_dict = {}
    name_table = {}
    for k, v in state_dict.items():
        if isinstance(v, (Variable, core.VarBase)):
            model_dict[k] = v.numpy()
            # Only variables are known to carry a ``name``. Reading ``v.name``
            # on a plain Python value (e.g. a nested dict or an int) raised
            # AttributeError and made such state dicts impossible to save.
            name_table[k] = v.name
        else:
            model_dict[k] = v
    model_dict["StructuredToParameterName@@"] = name_table

    file_name = model_path + suffix
    dir_name = os.path.dirname(file_name)
    # dir_name is empty when model_path has no directory part; nothing to
    # create in that case.
    if dir_name and not os.path.exists(dir_name):
        os.makedirs(dir_name)

    with open(file_name, 'wb') as f:
        # Protocol 2 is the newest pickle protocol that Python 2 can read.
        pickle.dump(model_dict, f, protocol=2)
|
|
|
|
|
|
@dygraph_only
def load_dygraph(model_path, keep_name_table=False):
    '''
    Load the parameter state_dict, and the optimizer state_dict when one
    exists, from disk.

    Args:
        model_path(str) : The file prefix the state was stored under. The
            suffixes ".pdparams" and ".pdopt" are appended to it, so the path
            must NOT contain a suffix itself.
        keep_name_table(bool, optional) : Whether to keep the structured name
            to parameter name conversion table (stored under the key
            "StructuredToParameterName@@") in the returned parameter dict.
            Default : False

    Returns:
        tuple : ``(para_dict, opti_dict)``. ``para_dict`` is the content of
            the ".pdparams" file. ``opti_dict`` is the content of the ".pdopt"
            file, or None when that file does not exist.

    Raises:
        RuntimeError : If the ".pdparams" file does not exist.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            with fluid.dygraph.guard():
                emb = fluid.dygraph.Embedding([10, 10])

                state_dict = emb.state_dict()
                fluid.save_dygraph(state_dict, "paddle_dy")

                adam = fluid.optimizer.Adam(
                    learning_rate=fluid.layers.noam_decay(100, 10000),
                    parameter_list=emb.parameters())
                state_dict = adam.state_dict()
                fluid.save_dygraph(state_dict, "paddle_dy")

                para_state_dict, opti_state_dict = fluid.load_dygraph("paddle_dy")

    '''

    def _read_pickle(path):
        # NOTE(review): unpickling can execute arbitrary code; only load
        # checkpoint files that come from a trusted source.
        with open(path, 'rb') as stream:
            if six.PY2:
                return pickle.load(stream)
            # Python 3 only: pass an explicit encoding for the 8-bit strings
            # found in pickles written by Python 2.
            return pickle.load(stream, encoding='latin1')

    name_table_key = "StructuredToParameterName@@"

    # The parameter file is mandatory.
    params_file_path = model_path + ".pdparams"
    if not os.path.exists(params_file_path):
        raise RuntimeError("Parameter file [ {} ] not exists".format(
            params_file_path))
    para_dict = _read_pickle(params_file_path)

    if not keep_name_table and name_table_key in para_dict:
        del para_dict[name_table_key]

    # The optimizer file is optional; a missing file is reported as None.
    opti_file_path = model_path + ".pdopt"
    opti_dict = (_read_pickle(opti_file_path)
                 if os.path.exists(opti_file_path) else None)

    return para_dict, opti_dict
|