@ -594,6 +594,13 @@ def save(layer, path, input_spec=None, **configs):
# avoid change user given input_spec
inner_input_spec = None
if input_spec is not None :
for attr_func in dir ( layer ) :
static_func = getattr ( layer , attr_func , None )
if isinstance ( static_func ,
StaticFunction ) and ' forward ' != attr_func :
raise ValueError (
" If there are static functions other than ' forward ' that need to be saved, the input ' input_spec ' should be None, but received the type of ' input_spec ' is %s . "
% type ( input_spec ) )
if not isinstance ( input_spec , list ) :
raise TypeError (
" The input input_spec should be ' list ' , but received input_spec ' s type is %s . "
@ -612,19 +619,23 @@ def save(layer, path, input_spec=None, **configs):
# parse configs
configs = _parse_save_configs ( configs )
# 2. get program from Layer
# TODO(chenweihang): add support for other method, not only forward
if isinstance ( layer . forward , StaticFunction ) :
concrete_program = layer . forward . concrete_program
else :
scope = core . Scope ( )
extra_var_info = dict ( )
for attr_func in dir ( layer ) :
static_func = getattr ( layer , attr_func , None )
if isinstance ( static_func , StaticFunction ) :
concrete_program = static_func . concrete_program
elif ' forward ' == attr_func :
# transform in jit.save, if input_spec is incomplete, declarative will throw error
static_forward = declarative ( layer . forward , input_spec = inner_input_spec )
static_forward = declarative (
layer . forward , input_spec = inner_input_spec )
concrete_program = static_forward . concrete_program
# the input_spec has been used in declarative, which is equal to
# @declarative with input_spec and jit.save without input_spec,
# avoid needless warning
inner_input_spec = None
else :
continue
# 3. build input & output of save_infernece_model
# NOTE(chenweihang): [ Get input variables name ]
@ -654,14 +665,14 @@ def save(layer, path, input_spec=None, **configs):
state_names_dict [ var . name ] = structured_name
# 4. share parameters from Layer to scope & record var info
scope = core . Scope ( )
extra_var_info = dict ( )
for param_or_buffer in concrete_program . parameters :
# share to scope
param_or_buffer_tensor = scope . var ( param_or_buffer . name ) . get_tensor ( )
param_or_buffer_tensor = scope . var ( param_or_buffer . name ) . get_tensor (
)
src_tensor = param_or_buffer . value ( ) . get_tensor ( )
param_or_buffer_tensor . _share_data_with ( src_tensor )
# record var info
if param_or_buffer . name not in extra_var_info :
extra_info_dict = dict ( )
if param_or_buffer . name in state_names_dict :
extra_info_dict [ ' structured_name ' ] = state_names_dict [
@ -678,8 +689,12 @@ def save(layer, path, input_spec=None, **configs):
model_path = dirname
# NOTE(chenweihang): because prefix contains model and params filename,
# so we don't support set model_filename & params_filename
if ' forward ' == attr_func :
model_filename = file_prefix + INFER_MODEL_SUFFIX
params_filename = file_prefix + INFER_PARAMS_SUFFIX
else :
model_filename = file_prefix + ' . ' + attr_func + INFER_MODEL_SUFFIX
params_filename = file_prefix + ' . ' + attr_func + INFER_PARAMS_SUFFIX
with scope_guard ( scope ) :
save_inference_model (
@ -708,6 +723,7 @@ def save(layer, path, input_spec=None, **configs):
# but we can save these information in `jit.save` without changing the original
# storage to improve user experience. So we save extra information into
# file `***.pdiparams.info`
with scope_guard ( scope ) :
extra_var_info_path = path + INFER_PARAMS_INFO_SUFFIX
with open ( extra_var_info_path , ' wb ' ) as f :
pickle . dump ( extra_var_info , f , protocol = 2 )