|
|
|
@ -581,6 +581,16 @@ def save(layer, path, input_spec=None, **configs):
|
|
|
|
|
"The input layer of paddle.jit.save should be 'Layer', but received layer type is %s."
|
|
|
|
|
% type(layer))
|
|
|
|
|
|
|
|
|
|
# NOTE(chenweihang): If the input layer be wrapped by DataParallel,
|
|
|
|
|
# the args and kwargs of forward method can't be parsed by
|
|
|
|
|
# function_spec, so here we save DataParallel._layers instead
|
|
|
|
|
# DataParallel itself
|
|
|
|
|
# NOTE(chenweihang): using inner_layer, do not change input layer
|
|
|
|
|
if isinstance(layer, paddle.DataParallel):
|
|
|
|
|
inner_layer = layer._layers
|
|
|
|
|
else:
|
|
|
|
|
inner_layer = layer
|
|
|
|
|
|
|
|
|
|
# path check
|
|
|
|
|
file_prefix = os.path.basename(path)
|
|
|
|
|
if file_prefix == "":
|
|
|
|
@ -596,8 +606,8 @@ def save(layer, path, input_spec=None, **configs):
|
|
|
|
|
# avoid change user given input_spec
|
|
|
|
|
inner_input_spec = None
|
|
|
|
|
if input_spec is not None:
|
|
|
|
|
for attr_func in dir(layer):
|
|
|
|
|
static_func = getattr(layer, attr_func, None)
|
|
|
|
|
for attr_func in dir(inner_layer):
|
|
|
|
|
static_func = getattr(inner_layer, attr_func, None)
|
|
|
|
|
if isinstance(static_func,
|
|
|
|
|
StaticFunction) and 'forward' != attr_func:
|
|
|
|
|
raise ValueError(
|
|
|
|
@ -623,14 +633,14 @@ def save(layer, path, input_spec=None, **configs):
|
|
|
|
|
configs = _parse_save_configs(configs)
|
|
|
|
|
scope = core.Scope()
|
|
|
|
|
extra_var_info = dict()
|
|
|
|
|
for attr_func in dir(layer):
|
|
|
|
|
static_func = getattr(layer, attr_func, None)
|
|
|
|
|
for attr_func in dir(inner_layer):
|
|
|
|
|
static_func = getattr(inner_layer, attr_func, None)
|
|
|
|
|
if isinstance(static_func, StaticFunction):
|
|
|
|
|
concrete_program = static_func.concrete_program
|
|
|
|
|
elif 'forward' == attr_func:
|
|
|
|
|
# transform in jit.save, if input_spec is incomplete, declarative will throw error
|
|
|
|
|
static_forward = declarative(
|
|
|
|
|
layer.forward, input_spec=inner_input_spec)
|
|
|
|
|
inner_layer.forward, input_spec=inner_input_spec)
|
|
|
|
|
concrete_program = static_forward.concrete_program
|
|
|
|
|
# the input_spec has been used in declarative, which is equal to
|
|
|
|
|
# @declarative with input_spec and jit.save without input_spec,
|
|
|
|
@ -663,7 +673,7 @@ def save(layer, path, input_spec=None, **configs):
|
|
|
|
|
# saved to inference program may not need by dygraph Layer,
|
|
|
|
|
# we only record the state_dict variable's structured name
|
|
|
|
|
state_names_dict = dict()
|
|
|
|
|
for structured_name, var in six.iteritems(layer.state_dict()):
|
|
|
|
|
for structured_name, var in six.iteritems(inner_layer.state_dict()):
|
|
|
|
|
state_names_dict[var.name] = structured_name
|
|
|
|
|
|
|
|
|
|
# 4. share parameters from Layer to scope & record var info
|
|
|
|
|