!5927 export quant binary model

Merge pull request !5927 from changzherui/export_quant_minary
mindspore-ci-bot 5 years ago committed by Gitee
commit 0b37c55272

@@ -1393,3 +1393,66 @@ class QuantBlock(Cell):
             str_info = str_info + f', activation={self.activation}'
         str_info = str_info + f', dequant={self.dequant}'
         return str_info
+
+
+class QuantMindirBlock(Cell):
+    """A quant binary block of Conv/Dense, activation layer for exporting a MINDIR model.
+
+    Args:
+        core_op (Cell): The operation cell.
+        weight (Tensor): The weight of the cell.
+        bias (Tensor): The bias of the cell. Default: None.
+        activation (str): The regularization function applied to the output of the layer, e.g. 'relu'. Default: None.
+        param_dict (dict): The information of the cell.
+    """
+
+    def __init__(self,
+                 core_op,
+                 weight,
+                 bias=None,
+                 activation=None,
+                 param_dict=None):
+        super(QuantMindirBlock, self).__init__()
+        self.core_op = core_op
+        if activation is not None:
+            self.core_op.add_prim_attr("activation_name", activation.__class__.__name__)
+        self.core_op.add_prim_attr("filter_maxq", Tensor(param_dict["filter_maxq"]))
+        self.core_op.add_prim_attr("filter_minq", Tensor(param_dict["filter_minq"]))
+        self.core_op.add_prim_attr("output_maxq", Tensor(param_dict["output_maxq"]))
+        self.core_op.add_prim_attr("output_minq", Tensor(param_dict["output_minq"]))
+        self.core_op.add_prim_attr("symmetric", Tensor(param_dict["symmetric"]))
+        if hasattr(core_op, 'pad_mode'):
+            self.core_op.add_prim_attr("pad_mode", core_op.pad_mode)
+        self.core_op.add_prim_attr("num_bits", Tensor(8))
+        self.core_op.add_prim_attr("narrow_range", Tensor(False))
+        if param_dict["input_maxq"] is not None:
+            self.core_op.add_prim_attr("input_maxq", Tensor(param_dict["input_maxq"]))
+            self.core_op.add_prim_attr("input_minq", Tensor(param_dict["input_minq"]))
+        else:
+            self.core_op.add_prim_attr("mean", Tensor(param_dict["mean"]))
+            self.core_op.add_prim_attr("std_dev", Tensor(param_dict["std_dev"]))
+        self.weight = weight
+        self.bias = bias
+        self.has_bias = bias is not None
+        self.activation = activation
+        self.has_act = activation is not None
+        if isinstance(activation, ReLU):
+            self.activation = None
+            self.has_act = False
+        self.bias_add = P.BiasAdd()
+
+    def construct(self, x):
+        if self.has_bias:
+            x = self.core_op(x, self.weight, self.bias)
+        else:
+            x = self.core_op(x, self.weight)
+        return x
+
+    def extend_repr(self):
+        str_info = f'core_op={type(self.core_op)}, weight=shape[{self.weight.shape}]'
+        if self.has_bias:
+            str_info = str_info + f', bias=shape[{self.bias.shape}]'
+        if self.has_act:
+            str_info = str_info + f', activation={self.activation}'
+        return str_info
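
For reference, a minimal sketch of how the new block could be instantiated by hand (the Conv2D configuration, shapes, and param_dict values below are illustrative assumptions; in the real export path, ExportToQuantInferNetwork fills param_dict from the network's fake-quant cells):

import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P

# Hypothetical ranges; on the MINDIR path the filter/output max/min are
# always filled in, while input_maxq may stay None, in which case the
# block records mean/std_dev attributes instead.
param_dict = {
    "filter_maxq": np.array([0.5], np.float32),
    "filter_minq": np.array([-0.5], np.float32),
    "output_maxq": np.array([6.0], np.float32),
    "output_minq": np.array([0.0], np.float32),
    "input_maxq": None,
    "input_minq": None,
    "mean": 127.5,
    "std_dev": 127.5,
    "symmetric": False,
}
weight = Tensor(np.ones((16, 3, 3, 3), np.float32))
block = QuantMindirBlock(P.Conv2D(out_channel=16, kernel_size=3),
                         weight, param_dict=param_dict)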

@@ -304,13 +304,14 @@ class ExportToQuantInferNetwork:
         inputs (Tensor): Input tensors of the `quantization aware training network`.
         mean (int): Input data mean. Default: 127.5.
         std_dev (int, float): Input data variance. Default: 127.5.
+        is_mindir (bool): Whether the export format is MINDIR. Default: False.
 
     Returns:
         Cell, Infer network.
     """
     __quant_op_name__ = ["TensorAdd", "Sub", "Mul", "RealDiv"]
 
-    def __init__(self, network, mean, std_dev, *inputs):
+    def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
         network = validator.check_isinstance('network', network, (nn.Cell,))
         # quantize for inputs: q = f / scale + zero_point
         # dequantize for outputs: f = (q - zero_point) * scale
@@ -320,6 +321,9 @@ class ExportToQuantInferNetwork:
         self.network = copy.deepcopy(network)
         self.all_parameters = {p.name: p for p in self.network.get_parameters()}
         self.get_inputs_table(inputs)
+        self.mean = mean
+        self.std_dev = std_dev
+        self.is_mindir = is_mindir
 
     def get_inputs_table(self, inputs):
         """Get the support info for quant export."""
@@ -341,8 +345,24 @@ class ExportToQuantInferNetwork:
         # Calculate the scale and zero point
         w_minq_name = cell_core.fake_quant_weight.minq.name
         np_type = mstype.dtype_to_nptype(self.data_type)
-        scale_w, zp_w = quant_utils.scale_zp_from_fack_quant_cell(cell_core.fake_quant_weight, np_type)
-        scale_a_out, _ = quant_utils.scale_zp_from_fack_quant_cell(fake_quant_a_out, np_type)
+        param_dict = dict()
+        param_dict["filter_maxq"] = None
+        param_dict["filter_minq"] = None
+        param_dict["output_maxq"] = None
+        param_dict["output_minq"] = None
+        param_dict["input_maxq"] = None
+        param_dict["input_minq"] = None
+        param_dict["mean"] = self.mean
+        param_dict["std_dev"] = self.std_dev
+        param_dict["symmetric"] = fake_quant_a_out.symmetric
+        if self.is_mindir:
+            scale_w, zp_w, param_dict["filter_maxq"], param_dict["filter_minq"] = \
+                quant_utils.scale_zp_max_min_from_fack_quant_cell(cell_core.fake_quant_weight, np_type)
+            scale_a_out, _, param_dict["output_maxq"], param_dict["output_minq"] = \
+                quant_utils.scale_zp_max_min_from_fack_quant_cell(fake_quant_a_out, np_type)
+        else:
+            scale_w, zp_w = quant_utils.scale_zp_from_fack_quant_cell(cell_core.fake_quant_weight, np_type)
+            scale_a_out, _ = quant_utils.scale_zp_from_fack_quant_cell(fake_quant_a_out, np_type)
         info = self.quant_info_table.get(w_minq_name, None)
         if info:
             fack_quant_a_in_op, minq_name = info
@@ -351,7 +371,11 @@ class ExportToQuantInferNetwork:
             else:
                 maxq = self.all_parameters[minq_name[:-4] + "maxq"]
                 minq = self.all_parameters[minq_name]
-                scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, minq, maxq, np_type)
+                if self.is_mindir:
+                    scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
+                        quant_utils.scale_zp_max_min_from_data(fack_quant_a_in_op, minq, maxq, np_type)
+                else:
+                    scale_a_in, zp_a_in = quant_utils.scale_zp_from_data(fack_quant_a_in_op, minq, maxq, np_type)
         else:
             logger.warning(f"Do not find `fake_quant` from input with `fake_quant.minq` {w_minq_name}")
             return None
@@ -377,7 +401,8 @@ class ExportToQuantInferNetwork:
             weight, bias = quant_utils.fold_batchnorm(weight, cell_core)
         elif isinstance(cell_core, quant.Conv2dBnWithoutFoldQuant):
             weight, bias = quant_utils.without_fold_batchnorm(weight, cell_core)
+        weight_b = weight
+        bias_b = bias
         # apply the quant
         weight = quant_utils.weight2int(weight, scale_w, zp_w)
         if bias is not None:
@@ -398,10 +423,16 @@ class ExportToQuantInferNetwork:
         if isinstance(cell_core, quant.DenseQuant):
             op_core = P.MatMul()
             weight = np.transpose(weight)
+            weight_b = np.transpose(weight_b)
         else:
             op_core = cell_core.conv
         weight = Tensor(weight, self.data_type)
-        block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
+        weight_b = Tensor(weight_b)
+        bias_b = Tensor(bias_b, mstype.float32)
+        if self.is_mindir:
+            block = quant.QuantMindirBlock(op_core, weight_b, bias_b, activation, param_dict)
+        else:
+            block = quant.QuantBlock(op_core, weight, quant_op, dequant_op, scale_deq, bias, activation)
         return block
 
     def _convert_quant2deploy(self, network):
@@ -475,8 +506,10 @@ def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format='
         raise ValueError('Illegal file format {}.'.format(file_format))
     network.set_train(False)
-    exporter = ExportToQuantInferNetwork(network, mean, std_dev, *inputs)
+    if file_format == "MINDIR":
+        exporter = ExportToQuantInferNetwork(network, mean, std_dev, *inputs, is_mindir=True)
+    else:
+        exporter = ExportToQuantInferNetwork(network, mean, std_dev, *inputs)
     deploy_net = exporter.run()
     serialization.export(deploy_net, *inputs, file_name=file_name, file_format=file_format)
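
With this change, exporting a quantization-aware-trained network to MINDIR goes through the new QuantMindirBlock path automatically. A hedged usage sketch (the qat_network variable, input shape, and module path are assumptions based on this era of the API, not part of the patch):

import numpy as np
from mindspore import Tensor
from mindspore.train.quant import quant

# `qat_network` is assumed to be a quantization-aware-trained network,
# e.g. produced by quant.convert_quant_network(...).
dummy_input = Tensor(np.ones((1, 1, 32, 32), np.float32))
quant.export(qat_network, dummy_input, file_name="lenet_quant",
             file_format="MINDIR", mean=127.5, std_dev=127.5)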

@@ -146,6 +146,20 @@ def scale_zp_from_fack_quant_cell(cell, data_type):
     return scale, zp
 
+
+def scale_zp_max_min_from_fack_quant_cell(cell, data_type):
+    """Calculate the quantization params (scale, zero point, max, and min) from `FakeQuantWithMinMax`."""
+    minq = cell.minq.data.asnumpy()
+    maxq = cell.maxq.data.asnumpy()
+    op = cell.fake_quant_infer
+    scale, zp = cal_quantization_params(
+        minq, maxq, data_type,
+        num_bits=op.num_bits,
+        symmetric=op.symmetric,
+        narrow_range=op.narrow_range)
+    return scale, zp, maxq, minq
+
+
 def scale_zp_from_data(op, minq, maxq, data_type):
     r"""
     Get calculate quantization params for scale and zero point.
@@ -174,6 +188,19 @@ def scale_zp_from_data(op, minq, maxq, data_type):
     return scale, zp
 
+
+def scale_zp_max_min_from_data(op, minq, maxq, data_type):
+    """Calculate the quantization params (scale, zero point, max, and min) from data."""
+    minq = minq.data.asnumpy()
+    maxq = maxq.data.asnumpy()
+    scale, zp = cal_quantization_params(
+        minq, maxq, data_type,
+        num_bits=op.num_bits,
+        symmetric=op.symmetric,
+        narrow_range=op.narrow_range)
+    return scale, zp, maxq, minq
+
+
 def fold_batchnorm(weight, cell_quant):
     r"""
     Fold the batchnorm in `Conv2dBnFoldQuant` to weight.
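
For intuition about what these helpers return: cal_quantization_params derives the usual affine mapping q = round(f / scale) + zero_point from an observed [min, max] range. A simplified numpy illustration of the asymmetric, non-narrow-range 8-bit case (this mirrors the relation; it is not the real helper and does not call it):

import numpy as np

def affine_params(minq, maxq, num_bits=8, narrow_range=False):
    # Simplified illustration of asymmetric affine quantization params.
    quant_min = 1 if narrow_range else 0
    quant_max = 2 ** num_bits - 1
    scale = (maxq - minq) / (quant_max - quant_min)
    zp = np.round(quant_min - minq / scale)
    return scale, zp

scale, zp = affine_params(-1.0, 1.0)
# scale = 2 / 255 ≈ 0.00784, zp = 128; e.g. f = 0.0 maps to q = 128.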
