parent
3a5d6e5e64
commit
31287cdb43
@ -1,45 +0,0 @@
|
|||||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from paddle.trainer.config_parser import parse_config
|
|
||||||
from paddle.proto import TrainerConfig_pb2
|
|
||||||
import sys
|
|
||||||
|
|
||||||
__all__ = []
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
whole_conf = False
|
|
||||||
binary = False
|
|
||||||
if len(sys.argv) == 2:
|
|
||||||
conf = parse_config(sys.argv[1], '')
|
|
||||||
elif len(sys.argv) == 3:
|
|
||||||
conf = parse_config(sys.argv[1], sys.argv[2])
|
|
||||||
elif len(sys.argv) == 4:
|
|
||||||
conf = parse_config(sys.argv[1], sys.argv[2])
|
|
||||||
if sys.argv[3] == '--whole':
|
|
||||||
whole_conf = True
|
|
||||||
elif sys.argv[3] == '--binary':
|
|
||||||
binary = True
|
|
||||||
else:
|
|
||||||
raise RuntimeError()
|
|
||||||
|
|
||||||
assert isinstance(conf, TrainerConfig_pb2.TrainerConfig)
|
|
||||||
|
|
||||||
if whole_conf:
|
|
||||||
print(conf)
|
|
||||||
else:
|
|
||||||
if binary:
|
|
||||||
sys.stdout.write(conf.model_config.SerializeToString())
|
|
||||||
else:
|
|
||||||
print(conf.model_config)
|
|
@ -1,62 +0,0 @@
|
|||||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
import collections
|
|
||||||
|
|
||||||
from paddle.trainer_config_helpers.layers import LayerOutput
|
|
||||||
from paddle.v2.layer import parse_network
|
|
||||||
from paddle.proto import TrainerConfig_pb2
|
|
||||||
|
|
||||||
__all__ = ["dump_v2_config"]
|
|
||||||
|
|
||||||
|
|
||||||
def dump_v2_config(topology, save_path, binary=False):
|
|
||||||
""" Dump the network topology to a specified file.
|
|
||||||
|
|
||||||
This function is only used to dump network defined by using PaddlePaddle V2
|
|
||||||
APIs. This function will NOT dump configurations related to PaddlePaddle
|
|
||||||
optimizer.
|
|
||||||
|
|
||||||
:param topology: The output layers (can be more than one layers given in a
|
|
||||||
Python List or Tuple) of the entire network. Using the
|
|
||||||
specified layers (if more than one layer is given) as root,
|
|
||||||
traversing back to the data layer(s), all the layers
|
|
||||||
connected to the specified output layers will be dumped.
|
|
||||||
Layers not connceted to the specified will not be dumped.
|
|
||||||
:type topology: LayerOutput|List|Tuple
|
|
||||||
:param save_path: The path to save the dumped network topology.
|
|
||||||
:type save_path: str
|
|
||||||
:param binary: Whether to dump the serialized network topology or not.
|
|
||||||
The default value is false. NOTE that, if you call this
|
|
||||||
function to generate network topology for PaddlePaddle C-API,
|
|
||||||
a serialized version of network topology is required. When
|
|
||||||
using PaddlePaddle C-API, this flag MUST be set to True.
|
|
||||||
:type binary: bool
|
|
||||||
"""
|
|
||||||
|
|
||||||
if isinstance(topology, LayerOutput):
|
|
||||||
topology = [topology]
|
|
||||||
elif isinstance(topology, collections.Sequence):
|
|
||||||
for out_layer in topology:
|
|
||||||
assert isinstance(out_layer, LayerOutput), (
|
|
||||||
"The type of each element in the parameter topology "
|
|
||||||
"should be LayerOutput.")
|
|
||||||
else:
|
|
||||||
raise RuntimeError("Error input type for parameter topology.")
|
|
||||||
|
|
||||||
model_str = parse_network(topology)
|
|
||||||
with open(save_path, "w") as fout:
|
|
||||||
if binary:
|
|
||||||
fout.write(model_str.SerializeToString())
|
|
||||||
else:
|
|
||||||
fout.write(str(model_str))
|
|
File diff suppressed because it is too large
Load Diff
@ -1,140 +0,0 @@
|
|||||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
# Generate dot diagram file for the given paddle model config
|
|
||||||
# The generated file can be viewed using Graphviz (http://graphviz.org)
|
|
||||||
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import six
|
|
||||||
import sys
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
from paddle.trainer.config_parser import parse_config
|
|
||||||
|
|
||||||
|
|
||||||
def make_layer_label(layer_config):
|
|
||||||
label = '%s type=%s' % (layer_config.name, layer_config.type)
|
|
||||||
if layer_config.reversed:
|
|
||||||
label += ' <=='
|
|
||||||
|
|
||||||
label2 = ''
|
|
||||||
if layer_config.active_type:
|
|
||||||
label2 += 'act=%s ' % layer_config.active_type
|
|
||||||
if layer_config.bias_parameter_name:
|
|
||||||
label2 += 'bias=%s ' % layer_config.bias_parameter_name
|
|
||||||
|
|
||||||
if label2:
|
|
||||||
label += '\l' + label2
|
|
||||||
return label
|
|
||||||
|
|
||||||
|
|
||||||
def make_diagram(config_file, dot_file, config_arg_str):
|
|
||||||
config = parse_config(config_file, config_arg_str)
|
|
||||||
make_diagram_from_proto(config.model_config, dot_file)
|
|
||||||
|
|
||||||
|
|
||||||
def make_diagram_from_proto(model_config, dot_file):
|
|
||||||
# print >> sys.stderr, config
|
|
||||||
name2id = {}
|
|
||||||
f = open(dot_file, 'w')
|
|
||||||
submodel_layers = set()
|
|
||||||
|
|
||||||
def make_link(link):
|
|
||||||
return 'l%s -> l%s;' % (name2id[link.layer_name],
|
|
||||||
name2id[link.link_name])
|
|
||||||
|
|
||||||
def make_mem(mem):
|
|
||||||
s = ''
|
|
||||||
if mem.boot_layer_name:
|
|
||||||
s += 'l%s -> l%s;\n' % (name2id[mem.boot_layer_name],
|
|
||||||
name2id[mem.layer_name])
|
|
||||||
s += 'l%s -> l%s [style=dashed];' % (name2id[mem.layer_name],
|
|
||||||
name2id[mem.link_name])
|
|
||||||
return s
|
|
||||||
|
|
||||||
print('digraph graphname {', file=f)
|
|
||||||
print('node [width=0.375,height=0.25];', file=f)
|
|
||||||
for i in six.moves.xrange(len(model_config.layers)):
|
|
||||||
l = model_config.layers[i]
|
|
||||||
name2id[l.name] = i
|
|
||||||
|
|
||||||
i = 0
|
|
||||||
for sub_model in model_config.sub_models:
|
|
||||||
if sub_model.name == 'root':
|
|
||||||
continue
|
|
||||||
print('subgraph cluster_%s {' % i, file=f)
|
|
||||||
print('style=dashed;', file=f)
|
|
||||||
label = '%s ' % sub_model.name
|
|
||||||
if sub_model.reversed:
|
|
||||||
label += '<=='
|
|
||||||
print('label = "%s";' % label, file=f)
|
|
||||||
i += 1
|
|
||||||
submodel_layers.add(sub_model.name)
|
|
||||||
for layer_name in sub_model.layer_names:
|
|
||||||
submodel_layers.add(layer_name)
|
|
||||||
lid = name2id[layer_name]
|
|
||||||
layer_config = model_config.layers[lid]
|
|
||||||
label = make_layer_label(layer_config)
|
|
||||||
print('l%s [label="%s", shape=box];' % (lid, label), file=f)
|
|
||||||
print('}', file=f)
|
|
||||||
|
|
||||||
for i in six.moves.xrange(len(model_config.layers)):
|
|
||||||
l = model_config.layers[i]
|
|
||||||
if l.name not in submodel_layers:
|
|
||||||
label = make_layer_label(l)
|
|
||||||
print('l%s [label="%s", shape=box];' % (i, label), file=f)
|
|
||||||
|
|
||||||
for sub_model in model_config.sub_models:
|
|
||||||
if sub_model.name == 'root':
|
|
||||||
continue
|
|
||||||
for link in sub_model.in_links:
|
|
||||||
print(make_link(link), file=f)
|
|
||||||
for link in sub_model.out_links:
|
|
||||||
print(make_link(link), file=f)
|
|
||||||
for mem in sub_model.memories:
|
|
||||||
print(make_mem(mem), file=f)
|
|
||||||
|
|
||||||
for i in six.moves.xrange(len(model_config.layers)):
|
|
||||||
for l in model_config.layers[i].inputs:
|
|
||||||
print(
|
|
||||||
'l%s -> l%s [label="%s"];' % (name2id[l.input_layer_name], i,
|
|
||||||
l.input_parameter_name),
|
|
||||||
file=f)
|
|
||||||
|
|
||||||
print('}', file=f)
|
|
||||||
f.close()
|
|
||||||
|
|
||||||
|
|
||||||
def usage():
|
|
||||||
print(
|
|
||||||
("Usage: python show_model_diagram.py" +
|
|
||||||
" CONFIG_FILE DOT_FILE [config_str]"),
|
|
||||||
file=sys.stderr)
|
|
||||||
exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
if len(sys.argv) < 3 or len(sys.argv) > 4:
|
|
||||||
usage()
|
|
||||||
|
|
||||||
config_file = sys.argv[1]
|
|
||||||
dot_file = sys.argv[2]
|
|
||||||
config_arg_str = sys.argv[3] if len(sys.argv) == 4 else ''
|
|
||||||
|
|
||||||
try:
|
|
||||||
make_diagram(config_file, dot_file, config_arg_str)
|
|
||||||
except:
|
|
||||||
traceback.print_exc()
|
|
||||||
raise
|
|
@ -1,73 +0,0 @@
|
|||||||
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import gzip
|
|
||||||
import struct
|
|
||||||
import os
|
|
||||||
|
|
||||||
from paddle.trainer_config_helpers.layers import LayerOutput
|
|
||||||
from paddle.v2.parameters import Parameters
|
|
||||||
from paddle.proto import ModelConfig_pb2
|
|
||||||
from paddle.v2.topology import Topology
|
|
||||||
|
|
||||||
|
|
||||||
def merge_v2_model(net, param_file, output_file):
|
|
||||||
'''Merge the model config and parameters into one file.
|
|
||||||
|
|
||||||
The model configuration file describes the model structure which
|
|
||||||
ends with .py. The parameters file stores the parameters of the model
|
|
||||||
which ends with .tar.gz.
|
|
||||||
|
|
||||||
@param net The output layer of the network for inference.
|
|
||||||
@param param_file Path of the parameters (.tar.gz) which is stored by
|
|
||||||
v2 api.
|
|
||||||
@param output_file Path of the merged file which will be generated.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
|
|
||||||
from paddle.utils.merge_model import merge_v2_model
|
|
||||||
# import your network configuration
|
|
||||||
from example_net import net_conf
|
|
||||||
|
|
||||||
net = net_conf(is_predict=True)
|
|
||||||
param_file = './param_pass_00000.tar.gz'
|
|
||||||
output_file = './output.paddle'
|
|
||||||
|
|
||||||
merge_v2_model(net, param_file, output_file)
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
assert isinstance(net, LayerOutput), \
|
|
||||||
"The net should be the output of the network for inference"
|
|
||||||
assert os.path.exists(param_file), \
|
|
||||||
"The model parameters file %s does not exists " % (param_file)
|
|
||||||
|
|
||||||
model_proto = Topology(net).proto()
|
|
||||||
assert isinstance(model_proto, ModelConfig_pb2.ModelConfig)
|
|
||||||
|
|
||||||
with gzip.open(param_file) as f:
|
|
||||||
params = Parameters.from_tar(f)
|
|
||||||
|
|
||||||
if os.path.exists(output_file):
|
|
||||||
os.remove(output_file)
|
|
||||||
|
|
||||||
with open(output_file, 'w') as f:
|
|
||||||
param_names = [param.name for param in model_proto.parameters]
|
|
||||||
conf_str = model_proto.SerializeToString()
|
|
||||||
f.write(struct.pack('q', len(conf_str)))
|
|
||||||
f.write(conf_str)
|
|
||||||
for pname in param_names:
|
|
||||||
params.serialize(pname, f)
|
|
||||||
|
|
||||||
print('Generate %s success!' % (output_file))
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in new issue