@ -24,8 +24,8 @@ from . import learning_rate_scheduler
import warnings
from . . import core
from . base import guard
from paddle . fluid . dygraph . jit import SaveLoadConfig, deprecate_save_load_configs
from paddle . fluid . dygraph . io import _construct_program_holders , _construct_params_and_buffers , EXTRA_VAR_INFO_FILENAME
from paddle . fluid . dygraph . jit import _ SaveLoadConfig
from paddle . fluid . dygraph . io import _construct_program_holders , _construct_params_and_buffers
__all__ = [
' save_dygraph ' ,
@ -33,35 +33,23 @@ __all__ = [
]
# NOTE(chenweihang): deprecate load_dygraph's argument keep_name_table,
# ensure compatibility when user still use keep_name_table argument
def deprecate_keep_name_table ( func ) :
@functools.wraps ( func )
def wrapper ( * args , * * kwargs ) :
def __warn_and_build_configs__ ( keep_name_table ) :
warnings . warn (
" The argument `keep_name_table` has deprecated, please use `SaveLoadConfig.keep_name_table`. " ,
DeprecationWarning )
config = SaveLoadConfig ( )
config . keep_name_table = keep_name_table
return config
# deal with arg `keep_name_table`
if len ( args ) > 1 and isinstance ( args [ 1 ] , bool ) :
args = list ( args )
args [ 1 ] = __warn_and_build_configs__ ( args [ 1 ] )
# deal with kwargs
elif ' keep_name_table ' in kwargs :
kwargs [ ' config ' ] = __warn_and_build_configs__ ( kwargs [
' keep_name_table ' ] )
kwargs . pop ( ' keep_name_table ' )
else :
# do nothing
pass
def _parse_load_config ( configs ) :
supported_configs = [ ' model_filename ' , ' params_filename ' , ' keep_name_table ' ]
# input check
for key in configs :
if key not in supported_configs :
raise ValueError (
" The additional config ( %s ) of `paddle.fluid.load_dygraph` is not supported. "
% ( key ) )
return func ( * args , * * kwargs )
# construct inner config
inner_config = _SaveLoadConfig ( )
inner_config . model_filename = configs . get ( ' model_filename ' , None )
inner_config . params_filename = configs . get ( ' params_filename ' , None )
inner_config . keep_name_table = configs . get ( ' keep_name_table ' , None )
return wrapper
return inner_config
@dygraph_only
@ -132,12 +120,12 @@ def save_dygraph(state_dict, model_path):
pickle . dump ( model_dict , f , protocol = 2 )
# NOTE(chenweihang): load_dygraph will deprecated in future, we don't
# support new loading features for it
# TODO(qingqing01): remove dygraph_only to support loading static model.
# maybe need to unify the loading interface after 2.0 API is ready.
# @dygraph_only
@deprecate_save_load_configs
@deprecate_keep_name_table
def load_dygraph ( model_path , config = None ) :
def load_dygraph ( model_path , * * configs ) :
'''
: api_attr : imperative
@ -152,10 +140,13 @@ def load_dygraph(model_path, config=None):
Args :
model_path ( str ) : The file prefix store the state_dict .
( The path should Not contain suffix ' .pdparams ' )
config ( SaveLoadConfig , optional ) : : ref : ` api_imperative_jit_saveLoadConfig `
object that specifies additional configuration options , these options
are for compatibility with ` ` jit . save / io . save_inference_model ` ` formats .
Default None .
* * configs ( dict , optional ) : other save configuration options for compatibility . We do not
recommend using these configurations , if not necessary , DO NOT use them . Default None .
The following options are currently supported :
( 1 ) model_filename ( string ) : The inference model file name of the paddle 1. x ` ` save_inference_model ` `
save format . Default file name is : code : ` __model__ ` .
( 2 ) params_filename ( string ) : The persistable variables file name of the paddle 1. x ` ` save_inference_model ` `
save format . No default file name , save variables separately by default .
Returns :
state_dict ( dict ) : the dict store the state_dict
@ -196,8 +187,7 @@ def load_dygraph(model_path, config=None):
opti_file_path = model_prefix + " .pdopt "
# deal with argument `config`
if config is None :
config = SaveLoadConfig ( )
config = _parse_load_config ( configs )
if os . path . exists ( params_file_path ) or os . path . exists ( opti_file_path ) :
# Load state dict by `save_dygraph` save format
@ -246,7 +236,6 @@ def load_dygraph(model_path, config=None):
persistable_var_dict = _construct_params_and_buffers (
model_prefix ,
programs ,
config . separate_params ,
config . params_filename ,
append_suffix = False )
@ -255,9 +244,9 @@ def load_dygraph(model_path, config=None):
for var_name in persistable_var_dict :
para_dict [ var_name ] = persistable_var_dict [ var_name ] . numpy ( )
# if __variables.info__ exists, we can recover structured_name
var_info_ path = os . path . join ( model_prefix ,
EXTRA_VAR_INFO_FILENAME )
# if *.info exists, we can recover structured_name
var_info_ filename = str ( config . params_filename ) + " .info "
var_info_path = os . path . join ( model_prefix , var_info_filename )
if os . path . exists ( var_info_path ) :
with open ( var_info_path , ' rb ' ) as f :
extra_var_info = pickle . load ( f )