From 9607778f01d07abd36de23c527cc983e1c71afb4 Mon Sep 17 00:00:00 2001
From: bai-yangfan
Date: Wed, 21 Oct 2020 15:39:54 +0800
Subject: [PATCH] mode_export_v3

---
 mindspore/train/quant/quant.py   | 12 +++++++-----
 mindspore/train/serialization.py |  4 ++--
 2 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py
index 6fcf6d7609..4c564d95ce 100644
--- a/mindspore/train/quant/quant.py
+++ b/mindspore/train/quant/quant.py
@@ -391,14 +391,16 @@ class ExportToQuantInferNetwork:
 
         scale_w, zp_w, param_dict["filter_maxq"], param_dict["filter_minq"] = \
             quant_utils.scale_zp_max_min_from_fake_quant_cell(cell_core.fake_quant_weight, np_type)
-        _, _, param_dict["output_maxq"], param_dict["output_minq"] = \
-            quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type)
+        if fake_quant_a_out is not None:
+            _, _, param_dict["output_maxq"], param_dict["output_minq"] = \
+                quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_a_out, np_type)
 
         info = self.quant_info_table.get(w_minq_name, None)
         if info:
             fake_quant_a_in_op, minq_name = info
             if minq_name == 'input':
-                scale_a_in, zp_a_in = self.input_scale, self.input_zero_point
+                scale_a_in, zp_a_in, param_dict["input_maxq"], param_dict["input_minq"] = \
+                    self.input_scale, self.input_zero_point, 'None', 'None'
             else:
                 maxq = self.all_parameters[minq_name[:-4] + "maxq"]
                 minq = self.all_parameters[minq_name]
@@ -483,11 +485,11 @@ class ExportToQuantInferNetwork:
             if isinstance(subcell, quant.Conv2dBnAct):
                 cell_core = subcell.conv
                 activation = subcell.activation
-                fake_quant_act = activation.fake_quant_act
+                fake_quant_act = activation.fake_quant_act if hasattr(activation, "fake_quant_act") else None
             elif isinstance(subcell, quant.DenseBnAct):
                 cell_core = subcell.dense
                 activation = subcell.activation
-                fake_quant_act = activation.fake_quant_act
+                fake_quant_act = activation.fake_quant_act if hasattr(activation, "fake_quant_act") else None
             if cell_core is not None:
                 new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
                 if new_subcell:
diff --git a/mindspore/train/serialization.py b/mindspore/train/serialization.py
index 5220f4a8fd..2d1f7818c7 100644
--- a/mindspore/train/serialization.py
+++ b/mindspore/train/serialization.py
@@ -519,7 +519,7 @@ def export(net, *inputs, file_name, file_format='AIR', **kwargs):
 
     logger.info("exporting model file:%s format:%s.", file_name, file_format)
     check_input_data(*inputs, data_class=Tensor)
-    net = _quant_export(net, *inputs, file_format='AIR', **kwargs)
+    net = _quant_export(net, *inputs, file_format=file_format, **kwargs)
     _export(net, file_name, file_format, *inputs)
 
 
@@ -566,7 +566,7 @@ def _export(net, file_name, file_format, *inputs):
     net.set_train(mode=True)
 
 
-def _quant_export(network, *inputs, file_format='AIR', **kwargs):
+def _quant_export(network, *inputs, file_format, **kwargs):
     """
     Exports MindSpore quantization predict model to deploy with AIR and MINDIR.
    """
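
Usage sketch (not part of the patch): a minimal example of the export call this
change affects, showing a non-AIR format now being forwarded to _quant_export
instead of the previously hard-coded 'AIR'. The network builder, input shape,
and file name below are illustrative assumptions; only export() and the
file_format argument come from this diff.

    import numpy as np

    from mindspore import Tensor
    from mindspore.train.serialization import export

    # build_quant_network() is a placeholder for any quantization-aware
    # (QAT-converted) network; it is not defined in this patch.
    net = build_quant_network()

    # Dummy input matching the network's expected shape (assumed here).
    dummy_input = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32))

    # With this patch, the requested format reaches _quant_export, so a
    # MINDIR export of a quantized network no longer falls back to AIR.
    # Depending on the MindSpore version, an extra quantization keyword
    # (e.g. quant_mode) may be required to trigger the quant export path;
    # that keyword is an assumption, not shown in this diff.
    export(net, dummy_input, file_name="net_quant", file_format="MINDIR")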