integrate_export

pull/7330/head
bai-yangfan 4 years ago
parent 60da54651b
commit 7dd91ffa26

@@ -21,6 +21,6 @@ operations. Note that the entire computation is carried out in floating point. A
 aware training, MindSpore provides conversion functions to convert the trained model into lower precision.
 """
-from .quant import convert_quant_network, export
-__all__ = ["convert_quant_network", "export"]
+from .quant import convert_quant_network, export, manual_export
+__all__ = ["convert_quant_network", "export", "manual_export"]

@ -634,7 +634,7 @@ class ExportManualQuantNetwork:
""" """
__quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"] __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
def __init__(self, network, mean, std_dev, *inputs, is_mindir): def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
network = Validator.check_isinstance('network', network, (nn.Cell,)) network = Validator.check_isinstance('network', network, (nn.Cell,))
self.input_scale = 1 / std_dev self.input_scale = 1 / std_dev
self.input_zero_point = round(mean) self.input_zero_point = round(mean)
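Because `is_mindir` is keyword-only, giving it a default lets the AIR path construct the exporter without passing it at all, which is exactly how `QuantExport` below calls it. A minimal sketch, assuming a quantization-aware `net` and an input shape that are illustrative only:

    import numpy as np
    from mindspore import Tensor
    from mindspore.train.quant import quant

    dummy_input = Tensor(np.ones([1, 1, 32, 32], np.float32))
    # AIR path: is_mindir now defaults to False, so it can be omitted.
    air_exporter = quant.ExportManualQuantNetwork(net, 127.5, 127.5, dummy_input)
    # MINDIR path still opts in explicitly.
    mindir_exporter = quant.ExportManualQuantNetwork(net, 127.5, 127.5, dummy_input, is_mindir=True)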

@@ -30,6 +30,9 @@ from mindspore.common.parameter import Parameter
 from mindspore.common.api import _executor
 from mindspore.common import dtype as mstype
 from mindspore._checkparam import check_input_data
+from mindspore.train.quant import quant
+import mindspore.context as context
+from .._checkparam import Validator

 __all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print",
            "build_searched_strategy", "merge_sliced_parameter"]
@@ -460,7 +463,7 @@ def _fill_param_into_net(net, parameter_list):
     load_param_into_net(net, parameter_dict)

-def export(net, *inputs, file_name, file_format='AIR'):
+def export(net, *inputs, file_name, file_format='AIR', quant_export=None, **kwargs):
     """
     Export the MindSpore prediction model to a file in the specified format.
@@ -469,7 +472,6 @@ def export(net, *inputs, file_name, file_format='AIR'):
         inputs (Tensor): Inputs of the `net`.
         file_name (str): File name of the model to be exported.
         file_format (str): MindSpore currently supports 'AIR', 'ONNX' and 'MINDIR' format for exported model.
-
             - AIR: Ascend Intermediate Representation. An intermediate representation format of Ascend model.
              Recommended suffix for output file is '.air'.
            - ONNX: Open Neural Network eXchange. An open format built to represent machine learning models.
@@ -477,7 +479,17 @@ def export(net, *inputs, file_name, file_format='AIR'):
             - MINDIR: MindSpore Native Intermediate Representation for Anf. An intermediate representation format
              for MindSpore models.
              Recommended suffix for output file is '.mindir'.

+        quant_export (str): Quantization export option; one of 'MANUAL', 'AUTO' or None. Default: None.
     """
-    logger.info("exporting model file:%s format:%s.", file_name, file_format)
-    check_input_data(*inputs, data_class=Tensor)
+    if quant_export == 'MANUAL':
+        mean = kwargs.get('mean', None)
+        std_dev = kwargs.get('std_dev', None)
+        QuantExport(net, file_name, mean, std_dev, *inputs, file_format=file_format, quant_manual_export=True)
+    elif quant_export == 'AUTO':
+        mean = kwargs.get('mean', None)
+        std_dev = kwargs.get('std_dev', None)
+        QuantExport(net, file_name, mean, std_dev, *inputs, file_format=file_format)
+    else:
+        logger.info("exporting model file:%s format:%s.", file_name, file_format)
+        check_input_data(*inputs, data_class=Tensor)
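A hedged usage sketch of the integrated entry point; `net` and the input shape are illustrative only, and `mean`/`std_dev` ride in through `**kwargs`:

    import numpy as np
    from mindspore import Tensor
    from mindspore.train.serialization import export

    dummy_input = Tensor(np.ones([1, 1, 32, 32], np.float32))
    # Plain float export: behaviour is unchanged when quant_export is None.
    export(net, dummy_input, file_name='lenet', file_format='MINDIR')
    # Quantization-aware export: routed through QuantExport below.
    export(net, dummy_input, file_name='lenet_quant', file_format='AIR',
           quant_export='AUTO', mean=127.5, std_dev=127.5)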
@@ -516,6 +528,55 @@ def export(net, *inputs, file_name, file_format='AIR'):
         if is_dump_onnx_in_training:
             net.set_train(mode=True)
+
+
+def QuantExport(network, file_name, mean, std_dev, *inputs, file_format='AIR', quant_manual_export=False):
+    """
+    Export a MindSpore quantization-aware model for inference deployment with AIR or MINDIR.
+
+    Args:
+        network (Cell): MindSpore network produced by `convert_quant_network`.
+        file_name (str): File name of the model to be exported.
+        mean (int, float): Input data mean. Default: 127.5.
+        std_dev (int, float): Input data standard deviation. Default: 127.5.
+        inputs (Tensor): Inputs of the quantization aware training network.
+        file_format (str): MindSpore currently supports 'AIR' and 'MINDIR' format for the exported
+            quantization aware model. Default: 'AIR'.
+
+            - AIR: Graph Engine Intermediate Representation. An intermediate representation format of
+              Ascend model. Recommended suffix for output file is '.air'.
+            - MINDIR: MindSpore Native Intermediate Representation for Anf. An intermediate representation
+              format for MindSpore models. Recommended suffix for output file is '.mindir'.
+        quant_manual_export (bool): Whether the network to be exported was quantized manually. Default: False.
+    """
+    supported_device = ["Ascend", "GPU"]
+    supported_formats = ['AIR', 'MINDIR']
+    # Fall back to the documented defaults only when the values were not provided.
+    mean = 127.5 if mean is None else mean
+    std_dev = 127.5 if std_dev is None else std_dev
+    mean = Validator.check_type("mean", mean, (int, float))
+    std_dev = Validator.check_type("std_dev", std_dev, (int, float))
+    if context.get_context('device_target') not in supported_device:
+        raise KeyError("Unsupported {} device target.".format(context.get_context('device_target')))
+    if file_format not in supported_formats:
+        raise ValueError('Illegal file format {}.'.format(file_format))
+    network.set_train(False)
+    # Pick the exporter that matches the target format and quantization style.
+    if file_format == "MINDIR":
+        if quant_manual_export:
+            exporter = quant.ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True)
+        else:
+            exporter = quant.ExportToQuantInferNetwork(network, mean, std_dev, *inputs, is_mindir=True)
+    else:
+        if quant_manual_export:
+            exporter = quant.ExportManualQuantNetwork(network, mean, std_dev, *inputs)
+        else:
+            exporter = quant.ExportToQuantInferNetwork(network, mean, std_dev, *inputs)
+    deploy_net = exporter.run()
+    export(deploy_net, *inputs, file_name=file_name, file_format=file_format)
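Because `QuantExport` finishes by handing the converted `deploy_net` back to `export()` (with `quant_export` left at its default of None), calling it directly is equivalent to the quant branches in `export()` above; routing back through `export()` keeps the actual file writing in one place. A minimal sketch, assuming `quant_net` came from `convert_quant_network` and `dummy_input` is an illustrative input tensor:

    # Direct call; same effect as
    # export(quant_net, dummy_input, file_name='lenet_quant', file_format='MINDIR',
    #        quant_export='MANUAL', mean=127.5, std_dev=127.5)
    QuantExport(quant_net, 'lenet_quant', 127.5, 127.5, dummy_input,
                file_format='MINDIR', quant_manual_export=True)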
 def parse_print(print_file_name):
     """
