|
|
@ -437,8 +437,16 @@ def _load_persistable_vars(model_path,
|
|
|
|
value: key
|
|
|
|
value: key
|
|
|
|
for key, value in program_holder._suffix_varname_dict.items()
|
|
|
|
for key, value in program_holder._suffix_varname_dict.items()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
# NOTE: some var may not be Parameter
|
|
|
|
|
|
|
|
for name in sorted(extra_var_info):
|
|
|
|
# NOTE(chenweihang): we need load persistable vars based the program,
|
|
|
|
|
|
|
|
# because the program may be pruned when `save_inference_model`, some
|
|
|
|
|
|
|
|
# var in `extra_var_info` may have been pruned
|
|
|
|
|
|
|
|
for name in sorted(inv_suffix_varname_dict):
|
|
|
|
|
|
|
|
if name not in extra_var_info:
|
|
|
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
|
|
|
"The model to be loaded is not complete."
|
|
|
|
|
|
|
|
"The variable `%s` of program cannot be found in loaded model.",
|
|
|
|
|
|
|
|
name)
|
|
|
|
# get suffix var name, see [why need to append suffix to persistable vars]
|
|
|
|
# get suffix var name, see [why need to append suffix to persistable vars]
|
|
|
|
new_name = inv_suffix_varname_dict[name]
|
|
|
|
new_name = inv_suffix_varname_dict[name]
|
|
|
|
# create output varbase
|
|
|
|
# create output varbase
|
|
|
@ -641,19 +649,21 @@ class TranslatedLayer(layers.Layer):
|
|
|
|
# name contains `.` originally, such as `linear_0.w_0`, so here
|
|
|
|
# name contains `.` originally, such as `linear_0.w_0`, so here
|
|
|
|
# need to generate new var name for each var
|
|
|
|
# need to generate new var name for each var
|
|
|
|
self._persistable_var_name_dict = dict()
|
|
|
|
self._persistable_var_name_dict = dict()
|
|
|
|
for name, var in persistable_vars.items():
|
|
|
|
# the TranslatedLayer object holded var names count started from 0
|
|
|
|
if isinstance(var, framework.ParamBase):
|
|
|
|
with unique_name.guard():
|
|
|
|
dy_name = _generate_unique_var_name(PARAMETER_NAME_PREFIX)
|
|
|
|
for name, var in persistable_vars.items():
|
|
|
|
self._persistable_var_name_dict[name] = dy_name
|
|
|
|
if isinstance(var, framework.ParamBase):
|
|
|
|
self.add_parameter(dy_name, var)
|
|
|
|
dy_name = _generate_unique_var_name(PARAMETER_NAME_PREFIX)
|
|
|
|
elif isinstance(var, core.VarBase):
|
|
|
|
self._persistable_var_name_dict[name] = dy_name
|
|
|
|
dy_name = _generate_unique_var_name(BUFFER_NAME_PREFIX)
|
|
|
|
self.add_parameter(dy_name, var)
|
|
|
|
self._persistable_var_name_dict[name] = dy_name
|
|
|
|
elif isinstance(var, core.VarBase):
|
|
|
|
self.register_buffer(dy_name, var)
|
|
|
|
dy_name = _generate_unique_var_name(BUFFER_NAME_PREFIX)
|
|
|
|
else:
|
|
|
|
self._persistable_var_name_dict[name] = dy_name
|
|
|
|
raise TypeError(
|
|
|
|
self.register_buffer(dy_name, var)
|
|
|
|
"Adding persistent variable which to layer is not supported now"
|
|
|
|
else:
|
|
|
|
)
|
|
|
|
raise TypeError(
|
|
|
|
|
|
|
|
"Adding persistent variable which to layer is not supported now"
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
self._is_test = True
|
|
|
|
self._is_test = True
|
|
|
|
|
|
|
|
|
|
|
|