|
|
|
@ -132,6 +132,13 @@ def __get_used_layers__(output_layers):
|
|
|
|
|
add_parent(mem.layer_name, mem.boot_layer_name)
|
|
|
|
|
add_parent(mem.link_name, mem.layer_name)
|
|
|
|
|
|
|
|
|
|
if sub_model.HasField('generator'):
|
|
|
|
|
# according to the implementation of text generation
|
|
|
|
|
# in recurrent layer group, the generated word must be
|
|
|
|
|
# the first out link
|
|
|
|
|
add_parent(sub_model.out_links[0].layer_name,
|
|
|
|
|
sub_model.generator.eos_layer_name)
|
|
|
|
|
|
|
|
|
|
def dfs_travel(layer_name):
|
|
|
|
|
if layer_name in layer_names:
|
|
|
|
|
return
|
|
|
|
@ -175,8 +182,6 @@ def __get_used_submodels__(layer_names):
|
|
|
|
|
for submodel in cp.g_config.model_config.sub_models:
|
|
|
|
|
if submodel.name in layer_names:
|
|
|
|
|
submodel_names.add(submodel.name)
|
|
|
|
|
if submodel.is_recurrent_layer_group:
|
|
|
|
|
layer_names |= set(submodel.layer_names)
|
|
|
|
|
return submodel_names
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|