From 2b07d7ffb305434e3b0397e56b32b53b74e60f63 Mon Sep 17 00:00:00 2001
From: bai-yangfan
Date: Wed, 23 Sep 2020 14:32:36 +0800
Subject: [PATCH] mindir_mode

Change how quantization ranges are collected when a quantization-aware
network is converted for inference export.

quant.py (ExportToQuantInferNetwork):
- Add statistic_weight(): for each index along the first axis of the
  weight it takes the min and the max of weight[i], then returns
  (mean of the maxima, mean of the minima) as Python floats.
- _get_quant_block() no longer branches on is_mindir when computing
  scale / zero point. It always calls
  quant_utils.scale_zp_max_min_from_fake_quant_cell(), so
  param_dict["output_maxq"] and param_dict["output_minq"] are filled in
  whether or not is_mindir is set.
- When is_mindir is set, param_dict["filter_maxq"] and
  param_dict["filter_minq"] are now taken from statistic_weight(weight),
  evaluated after the fold_batchnorm / without_fold_batchnorm handling,
  instead of from the weight fake-quant cell.
- When is_mindir is set, ops unwrapped from _AddFakeQuantAfterSubCell
  get 'output_maxq' / 'output_minq' primitive attributes built from the
  wrapper's fake_quant_act.maxq / fake_quant_act.minq.

quant_utils.py:
- Remove scale_zp_from_fake_quant_cell(); the calls to it in
  _get_quant_block() are removed by this patch.

---
 mindspore/train/quant/quant.py       | 27 +++++++++++++++++++--------
 mindspore/train/quant/quant_utils.py | 25 -------------------------
 2 files changed, 19 insertions(+), 33 deletions(-)

diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py
index c5de924500..c32380b036 100644
--- a/mindspore/train/quant/quant.py
+++ b/mindspore/train/quant/quant.py
@@ -342,6 +342,14 @@ class ExportToQuantInferNetwork:
         network = self._convert_quant2deploy(network)
         return network
 
+    def statistic_weight(self, weight):
+        out_nums = np.shape(weight)[0]
+        sta_metric = np.zeros((out_nums, 2), dtype=np.float32)
+        for num in range(out_nums):
+            sta_metric[num, 0] = np.min(weight[num])
+            sta_metric[num, 1] = np.max(weight[num])
+        return np.mean(sta_metric[:, 1]).tolist(), np.mean(sta_metric[:, 0]).tolist()
+
     def _get_quant_block(self, cell_core, activation, fake_quant_a_out):
         """convet network's quant subcell to deploy subcell"""
         # Calculate the scale and zero point
@@ -357,14 +365,12 @@ class ExportToQuantInferNetwork:
         param_dict["mean"] = self.mean
         param_dict["std_dev"] = self.std_dev
         param_dict["symmetric"] = fake_quant_a_out.symmetric
-        if self.is_mindir:
-            scale_w, zp_w, param_dict["filter_maxq"], param_dict["filter_minq"] = \
-                quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
-            scale_a_out, _, param_dict["output_maxq"], param_dict["output_minq"] = \
-                quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type)
-        else:
-            scale_w, zp_w = quant_utils.scale_zp_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
-            scale_a_out, _ = quant_utils.scale_zp_from_fake_quant_cell(fake_quant_a_out, np_type)
+
+        scale_w, zp_w, _, _ = \
+            quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
+        scale_a_out, _, param_dict["output_maxq"], param_dict["output_minq"] = \
+            quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type)
+
         info = self.quant_info_table.get(w_minq_name, None)
         if info:
             fack_quant_a_in_op, minq_name = info
@@ -403,6 +409,8 @@ class ExportToQuantInferNetwork:
             weight, bias = quant_utils.fold_batchnorm(weight, cell_core)
         elif isinstance(cell_core, quant.Conv2dBnWithoutFoldQuant):
             weight, bias = quant_utils.without_fold_batchnorm(weight, cell_core)
+        if self.is_mindir:
+            param_dict["filter_maxq"], param_dict["filter_minq"] = self.statistic_weight(weight)
         weight_b = weight
         bias_b = bias
         # apply the quant
@@ -467,6 +475,9 @@ class ExportToQuantInferNetwork:
             elif isinstance(subcell, _AddFakeQuantAfterSubCell):
                 op = subcell.subcell
                 if op.name in ConvertToQuantNetwork.__quant_op_name__ and isinstance(op, ops.Primitive):
+                    if self.is_mindir:
+                        op.add_prim_attr('output_maxq', Tensor(subcell.fake_quant_act.maxq.data.asnumpy()))
+                        op.add_prim_attr('output_minq', Tensor(subcell.fake_quant_act.minq.data.asnumpy()))
                     network.__delattr__(name)
                     network.__setattr__(name, op)
                     change = True
diff --git a/mindspore/train/quant/quant_utils.py b/mindspore/train/quant/quant_utils.py
index d115cb6e82..1ca8ae0bd3 100644
--- a/mindspore/train/quant/quant_utils.py
+++ b/mindspore/train/quant/quant_utils.py
@@ -120,31 +120,6 @@ def weight2int(data, scale, zero_point):
 
     return np.round((data / scale) + zero_point)
 
-
-def scale_zp_from_fake_quant_cell(cell, data_type):
-    r"""
-    Get calculate quantization params for scale and zero point From `FakeQuantWithMinMax`.
-
-    Args:
-        cell (Cell): `mindspore.nn.layer.FakeQuantWithMinMax`
-        data_type (numpy type): Can ben `numpy.int8` or `numpy.uint8`.
-
-    Returns:
-        scale (numpy.ndarray): quantization param.
-        zero point (numpy.ndarray): quantization param.
-    """
-    minq = cell.minq.data.asnumpy()
-    maxq = cell.maxq.data.asnumpy()
-    op = cell.fake_quant_infer
-
-    scale, zp = cal_quantization_params(
-        minq, maxq, data_type,
-        num_bits=op.num_bits,
-        symmetric=op.symmetric,
-        narrow_range=op.narrow_range)
-    return scale, zp
-
-
 def scale_zp_max_min_from_fake_quant_cell(cell, data_type):
     """Get calculate quantization params for scale, zero point, max and min from `FakeQuantWithMinMax`."""
     minq = cell.minq.data.asnumpy()