diff --git a/mindspore/compression/export/quant_export.py b/mindspore/compression/export/quant_export.py
index 979ad23864..96a9121bf1 100644
--- a/mindspore/compression/export/quant_export.py
+++ b/mindspore/compression/export/quant_export.py
@@ -228,10 +228,17 @@ class ExportManualQuantNetwork(ExportToQuantInferNetwork):
     __quant_op_name__ = ["Add", "Sub", "Mul", "RealDiv"]

     def __init__(self, network, mean, std_dev, *inputs, is_mindir=False):
-        super(ExportManualQuantNetwork, self).__init__(network, mean, std_dev, *inputs, is_mindir)
+        super(ExportManualQuantNetwork, self).__init__(network, mean, std_dev, *inputs, is_mindir=is_mindir)
         self.upcell = None
         self.upname = None

+    def _add_output_min_max_for_op(self, origin_op, fake_quant_cell):
+        if self.is_mindir:
+            np_type = mstype.dtype_to_nptype(self.data_type)
+            _, _, maxq, minq = quant_utils.scale_zp_max_min_from_fake_quant_cell(fake_quant_cell, np_type)
+            origin_op.add_prim_attr('output_maxq', Tensor(maxq))
+            origin_op.add_prim_attr('output_minq', Tensor(minq))
+
     def _convert_quant2deploy(self, network):
         """Convert network's all quant subcell to deploy subcell."""
         cells = network.name_cells()
@@ -247,18 +254,31 @@ class ExportManualQuantNetwork(ExportToQuantInferNetwork):
             elif isinstance(subcell, (quant.Conv2dBnFoldQuant, quant.Conv2dBnWithoutFoldQuant,
                                       quant.Conv2dQuant, quant.DenseQuant)):
                 network, change = self._convert_subcell(network, change, name, subcell, core=False)
-            elif isinstance(subcell, quant.FakeQuantWithMinMaxObserver) and self.upcell:
-                np_type = mstype.dtype_to_nptype(self.data_type)
-                _, _, maxq, minq = quant_utils.scale_zp_max_min_from_fake_quant_cell(subcell, np_type)
-                self.upcell.core_op.add_prim_attr('output_maxq', Tensor(maxq))
-                self.upcell.core_op.add_prim_attr('output_minq', Tensor(minq))
-                network.insert_child_to_cell(self.upname, self.upcell)
+            elif isinstance(subcell, nn.ActQuant) and hasattr(subcell, "get_origin"):
+                if self.upcell:
+                    self._add_output_min_max_for_op(self.upcell.core_op, subcell.fake_quant_act)
+                activation = subcell.get_origin()
+                network.insert_child_to_cell(name, activation)
+                change = True
+            elif isinstance(subcell, nn.TensorAddQuant):
+                if isinstance(subcell.add, _AddFakeQuantAfterSubCell):
+                    add_op = subcell.add.subcell
+                    subcell.__delattr__("add")
+                    subcell.__setattr__("add", add_op)
+                add_op = subcell.add
+                if add_op:
+                    self._add_output_min_max_for_op(add_op, subcell.fake_quant_act)
+                subcell.__delattr__("fake_quant_act")
+                subcell.__setattr__("fake_quant_act", P.identity())
+            elif isinstance(subcell, quant.FakeQuantWithMinMaxObserver):
+                if self.upcell:
+                    self._add_output_min_max_for_op(self.upcell.core_op, subcell)
+                network.__delattr__(name)
+                network.__setattr__(name, P.identity())
             elif isinstance(subcell, _AddFakeQuantAfterSubCell):
                 op = subcell.subcell
                 if op.name in QuantizationAwareTraining.__quant_op_name__ and isinstance(op, ops.Primitive):
-                    if self.is_mindir:
-                        op.add_prim_attr('output_maxq', Tensor(subcell.fake_quant_act.maxq.data.asnumpy()))
-                        op.add_prim_attr('output_minq', Tensor(subcell.fake_quant_act.minq.data.asnumpy()))
+                    self._add_output_min_max_for_op(op, subcell.fake_quant_act)
                     network.__delattr__(name)
                     network.__setattr__(name, op)
                     change = True
@@ -271,15 +291,18 @@ class ExportManualQuantNetwork(ExportToQuantInferNetwork):

     def _convert_subcell(self, network, change, name, subcell, core=True, conv=True):
         """Convert subcell to ant subcell."""
+        new_subcell = None
         if core:
             cell_core = subcell.conv if conv else subcell.dense
             activation = subcell.activation
-            fake_quant_act = activation.fake_quant_act
+            if hasattr(activation, 'fake_quant_act'):
+                fake_quant_act = activation.fake_quant_act
+                new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
         else:
             cell_core = subcell
             activation = None
             fake_quant_act = None
-        new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
+            new_subcell = self._get_quant_block(cell_core, activation, fake_quant_act)
         if new_subcell:
             prefix = subcell.param_prefix
             new_subcell.update_parameters_name(prefix + '.')
diff --git a/model_zoo/official/cv/resnet50_quant/README.md b/model_zoo/official/cv/resnet50_quant/README.md
index 6c75d0d9f4..29d1d99d6d 100644
--- a/model_zoo/official/cv/resnet50_quant/README.md
+++ b/model_zoo/official/cv/resnet50_quant/README.md
@@ -87,6 +87,7 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil
         │   ├──crossentropy.py                # define the crossentropy of resnet50-quant
         ├── train.py                          # training script
         ├── eval.py                           # evaluation script
+        ├── export.py                         # export script

 ```

diff --git a/model_zoo/official/cv/resnet50_quant/README_CN.md b/model_zoo/official/cv/resnet50_quant/README_CN.md
index ac0f43332d..dba61f2ac5 100644
--- a/model_zoo/official/cv/resnet50_quant/README_CN.md
+++ b/model_zoo/official/cv/resnet50_quant/README_CN.md
@@ -95,6 +95,7 @@ ResNet-50总体网络架构如下:
         │   ├──crossentropy.py                # 定义ResNet-50-Quant的交叉熵
         ├── train.py                          # 训练脚本
         ├── eval.py                           # 评估脚本
+        ├── export.py                         # 导出脚本

 ```

diff --git a/model_zoo/official/cv/resnet50_quant/export.py b/model_zoo/official/cv/resnet50_quant/export.py
new file mode 100644
index 0000000000..81f2ef3caa
--- /dev/null
+++ b/model_zoo/official/cv/resnet50_quant/export.py
@@ -0,0 +1,53 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Export Resnet50 on ImageNet"""
+
+import argparse
+import numpy as np
+
+import mindspore
+from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
+from mindspore.compression.quant import QuantizationAwareTraining
+
+from models.resnet_quant_manual import resnet50_quant
+from src.config import config_quant
+
+parser = argparse.ArgumentParser(description='Image classification')
+parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
+parser.add_argument('--file_format', type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format")
+parser.add_argument('--device_target', type=str, default=None, help='Run device target')
+args_opt = parser.parse_args()
+
+if __name__ == '__main__':
+    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, save_graphs=False)
+    # define fusion network
+    network = resnet50_quant(class_num=config_quant.class_num)
+    # convert fusion network to quantization aware network
+    quantizer = QuantizationAwareTraining(bn_fold=True,
+                                          per_channel=[True, False],
+                                          symmetric=[True, False])
+    network = quantizer.quantize(network)
+    # load checkpoint
+    if args_opt.checkpoint_path:
+        param_dict = load_checkpoint(args_opt.checkpoint_path)
+        not_load_param = load_param_into_net(network, param_dict)
+        if not_load_param:
+            raise ValueError("Load param into network fail!")
+    # export network
+    print("============== Starting export ==============")
+    inputs = Tensor(np.ones([1, 3, 224, 224]), mindspore.float32)
+    export(network, inputs, file_name="resnet50_quant", file_format=args_opt.file_format,
+           quant_mode='MANUAL', mean=0., std_dev=48.106)
+    print("============== End export ==============")
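
For context, a minimal usage sketch for the newly added export script, driven through the CLI flags it defines above (--checkpoint_path, --file_format, --device_target); the checkpoint path and device target values are illustrative assumptions, not part of this change:

# Hedged sketch: invoke model_zoo/official/cv/resnet50_quant/export.py from Python.
# "./resnet50_quant.ckpt" and "Ascend" are assumed example values.
import subprocess

subprocess.run(
    [
        "python", "model_zoo/official/cv/resnet50_quant/export.py",
        "--checkpoint_path", "./resnet50_quant.ckpt",
        "--file_format", "MINDIR",
        "--device_target", "Ascend",
    ],
    check=True,  # raise if the export script exits with a non-zero status
)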