|
|
|
@ -23,6 +23,7 @@ from paddle import compat as cpt
|
|
|
|
|
from paddle.fluid import core
|
|
|
|
|
from paddle.fluid import framework
|
|
|
|
|
from paddle.fluid import backward
|
|
|
|
|
from paddle.fluid import unique_name
|
|
|
|
|
from paddle.fluid.dygraph import layers
|
|
|
|
|
from paddle.fluid.layers import nn
|
|
|
|
|
from paddle.fluid.dygraph.base import switch_to_static_graph
|
|
|
|
@ -31,6 +32,9 @@ __all__ = ['TranslatedLayer']
|
|
|
|
|
|
|
|
|
|
VARIABLE_FILENAME = "__variables__"
|
|
|
|
|
EXTRA_VAR_INFO_FILENAME = "__variables.info__"
|
|
|
|
|
LOADED_VAR_SUFFIX = "load"
|
|
|
|
|
PARAMETER_NAME_PREFIX = "param"
|
|
|
|
|
BUFFER_NAME_PREFIX = "buffer"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_program_desc(model_file_path):
|
|
|
|
@ -107,33 +111,30 @@ def _get_all_var_names(program_desc):
|
|
|
|
|
return all_var_names
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@switch_to_static_graph
def _append_loaded_suffix(name):
    """
    Append the loaded suffix to the given variable name.

    The name is converted to text with ``cpt.to_text``, joined to
    ``LOADED_VAR_SUFFIX`` with a ``.`` and handed to
    ``unique_name.generate_with_ignorable_key``; the generated name is
    returned. The suffix is appended even when the name already carries one.

    e.g. x ==> x.load_0, x.load_0 ==> x.load_0.load_0

    Args:
        name: the original variable name.

    Returns:
        The name returned by ``unique_name.generate_with_ignorable_key``.
    """
    # NOTE: the original name is not recovered by string manipulation;
    # callers that need it keep their own new-name -> old-name mapping.
    name = cpt.to_text(name)
    return unique_name.generate_with_ignorable_key(
        '.'.join((name, LOADED_VAR_SUFFIX)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _remove_loaded_suffix(name):
    """
    Remove the loaded suffix from the given variable name.

    Every occurrence of ``core.loaded_var_suffix()`` in the text form of
    ``name`` is replaced with an empty string.

    e.g. x@LOADED ==> x
    """
    loaded_suffix = core.loaded_var_suffix()
    return cpt.to_text(name).replace(loaded_suffix, '')
|
|
|
|
|
@switch_to_static_graph
def _generate_unique_var_name(prefix):
    """
    Generate a variable name from the given prefix through
    ``unique_name.generate_with_ignorable_key``.
    """
    generated_name = unique_name.generate_with_ignorable_key(prefix)
    return generated_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _append_loaded_suffix_to_var(program_desc):
|
|
|
|
|
suffix_varname_dict = dict()
|
|
|
|
|
persistable_vars = _get_persistable_vars(program_desc)
|
|
|
|
|
for var_desc in persistable_vars:
|
|
|
|
|
old_name = var_desc.name()
|
|
|
|
|
new_name = _append_loaded_suffix(var_desc.name())
|
|
|
|
|
suffix_varname_dict[new_name] = old_name
|
|
|
|
|
var_desc.set_name(new_name)
|
|
|
|
|
for block_idx in six.moves.range(program_desc.num_blocks()):
|
|
|
|
|
block = program_desc.block(block_idx)
|
|
|
|
@ -141,6 +142,7 @@ def _append_loaded_suffix_to_var(program_desc):
|
|
|
|
|
op = block.op(op_idx)
|
|
|
|
|
op._rename_input(old_name, new_name)
|
|
|
|
|
op._rename_output(old_name, new_name)
|
|
|
|
|
return suffix_varname_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@switch_to_static_graph
|
|
|
|
@ -187,6 +189,9 @@ class _ProgramHolder(object):
|
|
|
|
|
# execution scope
|
|
|
|
|
self._inner_scope = core.Scope()
|
|
|
|
|
|
|
|
|
|
# append suffix var name dict
|
|
|
|
|
self._suffix_varname_dict = None
|
|
|
|
|
|
|
|
|
|
# forward program
|
|
|
|
|
self._infer_program_desc = self._preprocess(program_desc)
|
|
|
|
|
# forward + backward program
|
|
|
|
@ -272,7 +277,7 @@ class _ProgramHolder(object):
|
|
|
|
|
self._append_scale_to_output(tmp_program)
|
|
|
|
|
|
|
|
|
|
# 4. Persistable vars processing
|
|
|
|
|
# - append @LOADED suffix to persistable vars
|
|
|
|
|
# - append loaded suffix to persistable vars
|
|
|
|
|
# NOTE: [why need to append suffix to persistable vars]
|
|
|
|
|
# Dygraph and static graph mode use the same naming mechanism.
|
|
|
|
|
# If users want to load the model for fine-tuning, it is possible
|
|
|
|
@ -281,10 +286,7 @@ class _ProgramHolder(object):
|
|
|
|
|
# and later after loading, a new linear is added. At this time,
|
|
|
|
|
# there will be a problem of duplicate names, so here is unified
|
|
|
|
|
# to add the LOADED suffix to the parameters of the model loaded
|
|
|
|
|
# during training. And in order to avoid multiple @LOADED suffix
|
|
|
|
|
# are appended to variable name, we only append @LOADED suffix to
|
|
|
|
|
# the variable that not contains @LOADED suffix.
|
|
|
|
|
_append_loaded_suffix_to_var(program_desc)
|
|
|
|
|
self._suffix_varname_dict = _append_loaded_suffix_to_var(program_desc)
|
|
|
|
|
# - get persistable var
|
|
|
|
|
self._persistable_names = _get_persistable_var_names(program_desc)
|
|
|
|
|
|
|
|
|
@ -298,7 +300,7 @@ class _ProgramHolder(object):
|
|
|
|
|
for i, out in enumerate(self._output_descs):
|
|
|
|
|
var = program.global_block().var(out.name())
|
|
|
|
|
var = nn.scale(
|
|
|
|
|
var, 1., name="static_model_runner/scale_{}".format(i))
|
|
|
|
|
var, 1., name="translated_layer/scale_{}".format(i))
|
|
|
|
|
scale_output_vars.append(var)
|
|
|
|
|
# 2. update output names & descs
|
|
|
|
|
for i, var in enumerate(scale_output_vars):
|
|
|
|
@ -363,7 +365,7 @@ def _load_persistable_vars_by_program(model_path,
|
|
|
|
|
persistable_vars = _get_persistable_vars(program_holder.infer_program)
|
|
|
|
|
load_var_dict = {}
|
|
|
|
|
for each_var in persistable_vars:
|
|
|
|
|
orig_each_name = _remove_loaded_suffix(each_var.name())
|
|
|
|
|
orig_each_name = program_holder._suffix_varname_dict[each_var.name()]
|
|
|
|
|
if _is_parameter(each_var, program_holder.infer_program):
|
|
|
|
|
# create output varbase
|
|
|
|
|
new_var = framework.ParamBase(
|
|
|
|
@ -421,6 +423,7 @@ def _load_persistable_vars_by_program(model_path,
|
|
|
|
|
|
|
|
|
|
def _load_persistable_vars(model_path,
|
|
|
|
|
var_info_path,
|
|
|
|
|
program_holder,
|
|
|
|
|
separate_params=False,
|
|
|
|
|
params_filename=None):
|
|
|
|
|
# 1. load extra var info
|
|
|
|
@ -430,10 +433,14 @@ def _load_persistable_vars(model_path,
|
|
|
|
|
# 2. construct var dict
|
|
|
|
|
load_var_dict = dict()
|
|
|
|
|
load_var_list = []
|
|
|
|
|
inv_suffix_varname_dict = {
|
|
|
|
|
value: key
|
|
|
|
|
for key, value in program_holder._suffix_varname_dict.items()
|
|
|
|
|
}
|
|
|
|
|
# NOTE: some var may not be Parameter
|
|
|
|
|
for name in sorted(extra_var_info):
|
|
|
|
|
# append suffix, see [why need to append suffix to persistable vars]
|
|
|
|
|
new_name = _append_loaded_suffix(name)
|
|
|
|
|
# get suffix var name, see [why need to append suffix to persistable vars]
|
|
|
|
|
new_name = inv_suffix_varname_dict[name]
|
|
|
|
|
# create output varbase
|
|
|
|
|
if extra_var_info[name].get('trainable', None) is not None:
|
|
|
|
|
# use default shape and dtype
|
|
|
|
@ -506,7 +513,8 @@ def _construct_params_and_buffers(model_path,
|
|
|
|
|
var_info_path = os.path.join(model_path, EXTRA_VAR_INFO_FILENAME)
|
|
|
|
|
if os.path.exists(var_info_path):
|
|
|
|
|
var_dict = _load_persistable_vars(model_path, var_info_path,
|
|
|
|
|
separate_params, params_filename)
|
|
|
|
|
programs['forward'], separate_params,
|
|
|
|
|
params_filename)
|
|
|
|
|
else:
|
|
|
|
|
var_dict = _load_persistable_vars_by_program(
|
|
|
|
|
model_path, programs['forward'], params_filename)
|
|
|
|
@ -625,11 +633,23 @@ class TranslatedLayer(layers.Layer):
|
|
|
|
|
|
|
|
|
|
self._program_holder_dict = programs
|
|
|
|
|
|
|
|
|
|
# NOTE(chenweihang): [ why not use var name directly? ]
|
|
|
|
|
# When adding a parameter or buffer to a Layer with the following APIs,
|
|
|
|
|
# the variable name can't contain `.`, because that may cause
|
|
|
|
|
# AttributeError when access the newly added parameter or buffer
|
|
|
|
|
# in the form of `self.**.**`, but the ParamBase or VarBase
|
|
|
|
|
# name contains `.` originally, such as `linear_0.w_0`, so here
|
|
|
|
|
# need to generate new var name for each var
|
|
|
|
|
self._persistable_var_name_dict = dict()
|
|
|
|
|
for name, var in persistable_vars.items():
|
|
|
|
|
if isinstance(var, framework.ParamBase):
|
|
|
|
|
self.add_parameter(name, var)
|
|
|
|
|
dy_name = _generate_unique_var_name(PARAMETER_NAME_PREFIX)
|
|
|
|
|
self._persistable_var_name_dict[name] = dy_name
|
|
|
|
|
self.add_parameter(dy_name, var)
|
|
|
|
|
elif isinstance(var, core.VarBase):
|
|
|
|
|
self.register_buffer(name, var)
|
|
|
|
|
dy_name = _generate_unique_var_name(BUFFER_NAME_PREFIX)
|
|
|
|
|
self._persistable_var_name_dict[name] = dy_name
|
|
|
|
|
self.register_buffer(dy_name, var)
|
|
|
|
|
else:
|
|
|
|
|
raise TypeError(
|
|
|
|
|
"Adding persistent variable which to layer is not supported now"
|
|
|
|
@ -700,10 +720,11 @@ class TranslatedLayer(layers.Layer):
|
|
|
|
|
|
|
|
|
|
persistable_vars = []
|
|
|
|
|
for var_name in program_holder.persistable_names:
|
|
|
|
|
if var_name in self._parameters:
|
|
|
|
|
persistable_vars.append(self._parameters[var_name])
|
|
|
|
|
elif var_name in self._buffers:
|
|
|
|
|
persistable_vars.append(self._buffers[var_name])
|
|
|
|
|
dy_var_name = self._persistable_var_name_dict[var_name]
|
|
|
|
|
if dy_var_name in self._parameters:
|
|
|
|
|
persistable_vars.append(self._parameters[dy_var_name])
|
|
|
|
|
elif dy_var_name in self._buffers:
|
|
|
|
|
persistable_vars.append(self._buffers[dy_var_name])
|
|
|
|
|
else:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"The persistable variable %s is not exists in current TranslatedLayer."
|
|
|
|
|