|
|
|
@ -21,41 +21,42 @@ from paddle.v2.parameters import Parameters
|
|
|
|
|
from paddle.proto import ModelConfig_pb2
|
|
|
|
|
from paddle.v2.topology import Topology
|
|
|
|
|
|
|
|
|
|
def merge_model(net_out, param_file, output_file):
|
|
|
|
|
|
|
|
|
|
def merge_v2_model(net, param_file, output_file):
|
|
|
|
|
'''Integrate the model config and model parameters into one file.
|
|
|
|
|
|
|
|
|
|
The model configuration file describes the model structure which
|
|
|
|
|
ends with .py. The parameters file stores the parameters of the model
|
|
|
|
|
which ends with .tar.gz.
|
|
|
|
|
|
|
|
|
|
@param net_out the output layer of the network
|
|
|
|
|
@param param_file path of the model parameters file(a gzip file).
|
|
|
|
|
@param output_file path of the merged file which will be generated
|
|
|
|
|
@param net The output layer of the network.
|
|
|
|
|
@param param_file Path of the model parameters(.tar.gz) which is stored by v2 api.
|
|
|
|
|
@param output_file Path of the merged file which will be generated.
|
|
|
|
|
|
|
|
|
|
Usage:
|
|
|
|
|
|
|
|
|
|
from paddle.util.merge_model import merge_model
|
|
|
|
|
from paddle.util.merge_model import merge_v2_model
|
|
|
|
|
# import your network configuration
|
|
|
|
|
from mobilenet import mobile_net
|
|
|
|
|
|
|
|
|
|
net_out = mobile_net(3*224*224, 102)
|
|
|
|
|
param_file = YOUR_MODEL_PARAM_PATH
|
|
|
|
|
output_file = OUTPUT_MERGED_FILE_PATH
|
|
|
|
|
net = mobile_net(3*224*224, 102)
|
|
|
|
|
param_file = './param_pass_00000.tar.gz'
|
|
|
|
|
output_file = './output.paddle'
|
|
|
|
|
|
|
|
|
|
merge_model(net_out, param_file, output_file)
|
|
|
|
|
merge_v2_model(net, param_file, output_file)
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
assert isinstance(net_out, LayerOutput), \
|
|
|
|
|
"The net_out should be the output of the network"
|
|
|
|
|
assert isinstance(net, LayerOutput), \
|
|
|
|
|
"The net should be the output of the network"
|
|
|
|
|
assert os.path.exists(param_file), \
|
|
|
|
|
"The model parameters file %s does not exists " % (param_file)
|
|
|
|
|
|
|
|
|
|
model_proto = Topology(net_out).proto()
|
|
|
|
|
model_proto = Topology(net).proto()
|
|
|
|
|
assert isinstance(model_proto, ModelConfig_pb2.ModelConfig)
|
|
|
|
|
|
|
|
|
|
with gzip.open(param_file) as f:
|
|
|
|
|
params = Parameters.from_tar(f)
|
|
|
|
|
with gzip.open(param_file) as f:
|
|
|
|
|
params = Parameters.from_tar(f)
|
|
|
|
|
|
|
|
|
|
if os.path.exists(output_file):
|
|
|
|
|
os.remove(output_file)
|
|
|
|
|