|
|
|
@ -24,7 +24,7 @@ from . import learning_rate_scheduler
|
|
|
|
|
import warnings
|
|
|
|
|
from .. import core
|
|
|
|
|
from .base import guard
|
|
|
|
|
from paddle.fluid.dygraph.jit import SaveLoadConfig
|
|
|
|
|
from paddle.fluid.dygraph.jit import SaveLoadConfig, deprecate_save_load_configs
|
|
|
|
|
from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers
|
|
|
|
|
|
|
|
|
|
__all__ = [
|
|
|
|
@ -42,9 +42,9 @@ def deprecate_keep_name_table(func):
|
|
|
|
|
warnings.warn(
|
|
|
|
|
"The argument `keep_name_table` has deprecated, please use `SaveLoadConfig.keep_name_table`.",
|
|
|
|
|
DeprecationWarning)
|
|
|
|
|
configs = SaveLoadConfig()
|
|
|
|
|
configs.keep_name_table = keep_name_table
|
|
|
|
|
return configs
|
|
|
|
|
config = SaveLoadConfig()
|
|
|
|
|
config.keep_name_table = keep_name_table
|
|
|
|
|
return config
|
|
|
|
|
|
|
|
|
|
# deal with arg `keep_name_table`
|
|
|
|
|
if len(args) > 1 and isinstance(args[1], bool):
|
|
|
|
@ -52,7 +52,7 @@ def deprecate_keep_name_table(func):
|
|
|
|
|
args[1] = __warn_and_build_configs__(args[1])
|
|
|
|
|
# deal with kwargs
|
|
|
|
|
elif 'keep_name_table' in kwargs:
|
|
|
|
|
kwargs['configs'] = __warn_and_build_configs__(kwargs[
|
|
|
|
|
kwargs['config'] = __warn_and_build_configs__(kwargs[
|
|
|
|
|
'keep_name_table'])
|
|
|
|
|
kwargs.pop('keep_name_table')
|
|
|
|
|
else:
|
|
|
|
@ -135,8 +135,9 @@ def save_dygraph(state_dict, model_path):
|
|
|
|
|
# TODO(qingqing01): remove dygraph_only to support loading static model.
|
|
|
|
|
# maybe need to unify the loading interface after 2.0 API is ready.
|
|
|
|
|
# @dygraph_only
|
|
|
|
|
@deprecate_save_load_configs
|
|
|
|
|
@deprecate_keep_name_table
|
|
|
|
|
def load_dygraph(model_path, configs=None):
|
|
|
|
|
def load_dygraph(model_path, config=None):
|
|
|
|
|
'''
|
|
|
|
|
:api_attr: imperative
|
|
|
|
|
|
|
|
|
@ -151,7 +152,7 @@ def load_dygraph(model_path, configs=None):
|
|
|
|
|
Args:
|
|
|
|
|
model_path(str) : The file prefix store the state_dict.
|
|
|
|
|
(The path should Not contain suffix '.pdparams')
|
|
|
|
|
configs (SaveLoadConfig, optional): :ref:`api_imperative_jit_saveLoadConfig`
|
|
|
|
|
config (SaveLoadConfig, optional): :ref:`api_imperative_jit_saveLoadConfig`
|
|
|
|
|
object that specifies additional configuration options, these options
|
|
|
|
|
are for compatibility with ``jit.save/io.save_inference_model`` formats.
|
|
|
|
|
Default None.
|
|
|
|
@ -195,6 +196,7 @@ def load_dygraph(model_path, configs=None):
|
|
|
|
|
opti_file_path = model_prefix + ".pdopt"
|
|
|
|
|
|
|
|
|
|
# deal with argument `configs`
|
|
|
|
|
configs = config
|
|
|
|
|
if configs is None:
|
|
|
|
|
configs = SaveLoadConfig()
|
|
|
|
|
|
|
|
|
|