!7463 move train.quant to compression module & add QuantizationAwareTraining

Merge pull request !7463 from yuchaojie/quant2
pull/7463/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 3b55a25f8d

@ -0,0 +1,17 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Compression export module.
"""

@ -13,14 +13,9 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
""" """
Quantization. Compression quant module.
User can use quantization aware to train a model. MindSpore supports quantization aware training,
which models quantization errors in both the forward and backward passes using fake-quantization
operations. Note that the entire computation is carried out in floating point. At the end of quantization
aware training, MindSpore provides conversion functions to convert the trained model into lower precision.
""" """
from .quant import convert_quant_network, export, manual_export from .quantizer import *
from .qat import *
__all__ = ["convert_quant_network", "export", "manual_export"] from .quant_utils import *

File diff suppressed because it is too large Load Diff

@ -17,6 +17,9 @@
import numpy as np import numpy as np
__all__ = ["load_nonquant_param_into_quant_net"]
def cal_quantization_params(input_min, def cal_quantization_params(input_min,
input_max, input_max,
data_type, data_type,

@ -0,0 +1,52 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Base Class of Quantizer."""
from abc import ABC, abstractmethod
from enum import Enum
__all__ = ["OptimizeOption", "Quantizer"]
class OptimizeOption(Enum):
"""
An enum for the model quantization optimize option.
"""
# using quantization aware training
QAT = "QAT"
def __str__(self):
return self.value
class Quantizer(ABC):
"""
Base class of Quantizer. You can implement different kind of quantizer to get different quantization result.
Notes:
This class is an abstract class.
Args:
optimize_option (OptimizeOption, list or tuple): Specifies the quant algorithm and options. Default: None.
"""
def __init__(self,
optimize_option=None):
if not isinstance(optimize_option, list) and not isinstance(optimize_option, tuple):
optimize_option = [optimize_option]
self.optimize_option = optimize_option
@abstractmethod
def quantize(self, network):
pass

@ -30,7 +30,7 @@ from mindspore.common.parameter import Parameter
from mindspore.common.api import _executor from mindspore.common.api import _executor
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore._checkparam import check_input_data, Validator from mindspore._checkparam import check_input_data, Validator
from mindspore.train.quant import quant from mindspore.compression.export import quant_export
import mindspore.context as context import mindspore.context as context
__all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print", __all__ = ["save_checkpoint", "load_checkpoint", "load_param_into_net", "export", "parse_print",
@ -596,14 +596,14 @@ def _quant_export(network, *inputs, file_format, **kwargs):
network.set_train(False) network.set_train(False)
if file_format == "MINDIR": if file_format == "MINDIR":
if quant_mode == 'MANUAL': if quant_mode == 'MANUAL':
exporter = quant.ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True) exporter = quant_export.ExportManualQuantNetwork(network, mean, std_dev, *inputs, is_mindir=True)
else: else:
exporter = quant.ExportToQuantInferNetwork(network, mean, std_dev, *inputs, is_mindir=True) exporter = quant_export.ExportToQuantInferNetwork(network, mean, std_dev, *inputs, is_mindir=True)
else: else:
if quant_mode == 'MANUAL': if quant_mode == 'MANUAL':
exporter = quant.ExportManualQuantNetwork(network, mean, std_dev, *inputs) exporter = quant_export.ExportManualQuantNetwork(network, mean, std_dev, *inputs)
else: else:
exporter = quant.ExportToQuantInferNetwork(network, mean, std_dev, *inputs) exporter = quant_export.ExportToQuantInferNetwork(network, mean, std_dev, *inputs)
deploy_net = exporter.run() deploy_net = exporter.run()
return deploy_net return deploy_net

@ -25,7 +25,7 @@ from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train import Model from mindspore.train import Model
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
from mindspore.train.quant import quant from mindspore.compression.quant import QuantizationAwareTraining
from src.dataset import create_dataset from src.dataset import create_dataset
from src.config import mnist_cfg as cfg from src.config import mnist_cfg as cfg
from src.lenet_fusion import LeNet5 as LeNet5Fusion from src.lenet_fusion import LeNet5 as LeNet5Fusion
@ -47,8 +47,12 @@ if __name__ == "__main__":
# define fusion network # define fusion network
network = LeNet5Fusion(cfg.num_classes) network = LeNet5Fusion(cfg.num_classes)
# convert fusion network to quantization aware network # convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000, quantizer = QuantizationAwareTraining(quant_delay=0,
per_channel=[True, False], symmetric=[True, False]) bn_fold=False,
freeze_bn=10000,
per_channel=[True, False],
symmetric=[True, False])
network = quantizer.quantize(network)
# define loss # define loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")

@ -22,7 +22,7 @@ import numpy as np
import mindspore import mindspore
from mindspore import Tensor from mindspore import Tensor
from mindspore import context from mindspore import context
from mindspore.train.quant import quant from mindspore.compression.quant import QuantizationAwareTraining
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from src.config import mnist_cfg as cfg from src.config import mnist_cfg as cfg
@ -44,8 +44,12 @@ if __name__ == "__main__":
# define fusion network # define fusion network
network = LeNet5Fusion(cfg.num_classes) network = LeNet5Fusion(cfg.num_classes)
# convert fusion network to quantization aware network # convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000, quantizer = QuantizationAwareTraining(quant_delay=0,
per_channel=[True, False], symmetric=[True, False]) bn_fold=False,
freeze_bn=10000,
per_channel=[True, False],
symmetric=[True, False])
network = quantizer.quantize(network)
# load quantization aware network checkpoint # load quantization aware network checkpoint
param_dict = load_checkpoint(args.ckpt_path) param_dict = load_checkpoint(args.ckpt_path)
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)

@ -26,8 +26,8 @@ from mindspore.train.serialization import load_checkpoint
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train import Model from mindspore.train import Model
from mindspore.nn.metrics import Accuracy from mindspore.nn.metrics import Accuracy
from mindspore.train.quant import quant from mindspore.compression.quant import QuantizationAwareTraining
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net
from mindspore.common import set_seed from mindspore.common import set_seed
from src.dataset import create_dataset from src.dataset import create_dataset
from src.config import mnist_cfg as cfg from src.config import mnist_cfg as cfg
@ -59,8 +59,11 @@ if __name__ == "__main__":
load_nonquant_param_into_quant_net(network, param_dict) load_nonquant_param_into_quant_net(network, param_dict)
# convert fusion network to quantization aware network # convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=900, bn_fold=False, per_channel=[True, False], quantizer = QuantizationAwareTraining(quant_delay=900,
bn_fold=False,
per_channel=[True, False],
symmetric=[True, False]) symmetric=[True, False])
network = quantizer.quantize(network)
# define network loss # define network loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")

@ -21,7 +21,7 @@ from mindspore import context
from mindspore import nn from mindspore import nn
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.quant import quant from mindspore.compression.quant import QuantizationAwareTraining
from src.mobilenetV2 import mobilenetV2 from src.mobilenetV2 import mobilenetV2
from src.dataset import create_dataset from src.dataset import create_dataset
@ -51,7 +51,10 @@ if __name__ == '__main__':
# define fusion network # define fusion network
network = mobilenetV2(num_classes=config_device_target.num_classes) network = mobilenetV2(num_classes=config_device_target.num_classes)
# convert fusion network to quantization aware network # convert fusion network to quantization aware network
network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) quantizer = QuantizationAwareTraining(bn_fold=True,
per_channel=[True, False],
symmetric=[True, False])
network = quantizer.quantize(network)
# define network loss # define network loss
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

@ -21,7 +21,7 @@ import mindspore
from mindspore import Tensor from mindspore import Tensor
from mindspore import context from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore.train.quant import quant from mindspore.compression.quant import QuantizationAwareTraining
from src.mobilenetV2 import mobilenetV2 from src.mobilenetV2 import mobilenetV2
from src.config import config_ascend_quant from src.config import config_ascend_quant
@ -42,7 +42,10 @@ if __name__ == '__main__':
# define fusion network # define fusion network
network = mobilenetV2(num_classes=cfg.num_classes) network = mobilenetV2(num_classes=cfg.num_classes)
# convert fusion network to quantization aware network # convert fusion network to quantization aware network
network = quant.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) quantizer = QuantizationAwareTraining(bn_fold=True,
per_channel=[True, False],
symmetric=[True, False])
network = quantizer.quantize(network)
# load checkpoint # load checkpoint
param_dict = load_checkpoint(args_opt.checkpoint_path) param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(network, param_dict) load_param_into_net(network, param_dict)

@ -26,8 +26,8 @@ from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_checkpoint from mindspore.train.serialization import load_checkpoint
from mindspore.communication.management import init, get_group_size, get_rank from mindspore.communication.management import init, get_group_size, get_rank
from mindspore.train.quant import quant from mindspore.compression.quant import QuantizationAwareTraining
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net
from mindspore.common import set_seed from mindspore.common import set_seed
from src.dataset import create_dataset from src.dataset import create_dataset
@ -99,10 +99,10 @@ def train_on_ascend():
param_dict = load_checkpoint(args_opt.pre_trained) param_dict = load_checkpoint(args_opt.pre_trained)
load_nonquant_param_into_quant_net(network, param_dict) load_nonquant_param_into_quant_net(network, param_dict)
# convert fusion network to quantization aware network # convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quantizer = QuantizationAwareTraining(bn_fold=True,
bn_fold=True,
per_channel=[True, False], per_channel=[True, False],
symmetric=[True, False]) symmetric=[True, False])
network = quantizer.quantize(network)
# get learning rate # get learning rate
lr = Tensor(get_lr(global_step=config.start_epoch * step_size, lr = Tensor(get_lr(global_step=config.start_epoch * step_size,
@ -162,12 +162,12 @@ def train_on_gpu():
load_nonquant_param_into_quant_net(network, param_dict) load_nonquant_param_into_quant_net(network, param_dict)
# convert fusion network to quantization aware network # convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quantizer = QuantizationAwareTraining(bn_fold=True,
bn_fold=True,
per_channel=[True, False], per_channel=[True, False],
symmetric=[True, False], symmetric=[True, False],
freeze_bn=1000000, freeze_bn=1000000,
quant_delay=step_size * 2) quant_delay=step_size * 2)
network = quantizer.quantize(network)
# get learning rate # get learning rate
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)

@ -26,7 +26,7 @@ from models.resnet_quant_manual import resnet50_quant #manually construct quanta
from mindspore import context from mindspore import context
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.train.quant import quant from mindspore.compression.quant import QuantizationAwareTraining
parser = argparse.ArgumentParser(description='Image classification') parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path') parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
@ -43,12 +43,13 @@ if args_opt.device_target == "Ascend":
if __name__ == '__main__': if __name__ == '__main__':
# define fusion network # define fusion network
net = resnet50_quant(class_num=config.class_num) network = resnet50_quant(class_num=config.class_num)
# convert fusion network to quantization aware network # convert fusion network to quantization aware network
net = quant.convert_quant_network(net, quantizer = QuantizationAwareTraining(bn_fold=True,
bn_fold=True, per_channel=[True, False],
per_channel=[True, False], symmetric=[True, False])
symmetric=[True, False]) network = quantizer.quantize(network)
# define network loss # define network loss
if not config.use_label_smooth: if not config.use_label_smooth:
config.label_smooth_factor = 0.0 config.label_smooth_factor = 0.0
@ -65,13 +66,13 @@ if __name__ == '__main__':
# load checkpoint # load checkpoint
if args_opt.checkpoint_path: if args_opt.checkpoint_path:
param_dict = load_checkpoint(args_opt.checkpoint_path) param_dict = load_checkpoint(args_opt.checkpoint_path)
not_load_param = load_param_into_net(net, param_dict) not_load_param = load_param_into_net(network, param_dict)
if not_load_param: if not_load_param:
raise ValueError("Load param into net fail!") raise ValueError("Load param into network fail!")
net.set_train(False) network.set_train(False)
# define model # define model
model = Model(net, loss_fn=loss, metrics={'acc'}) model = Model(network, loss_fn=loss, metrics={'acc'})
print("============== Starting Validation ==============") print("============== Starting Validation ==============")
res = model.eval(dataset) res = model.eval(dataset)

@ -17,14 +17,14 @@ import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore import Tensor from mindspore import Tensor
from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant as Conv2dBatchNormQuant from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant
from mindspore.train.quant import quant from mindspore.compression.quant import qat
_ema_decay = 0.999 _ema_decay = 0.999
_symmetric = True _symmetric = True
_fake = True _fake = True
_per_channel = True _per_channel = True
_quant_config = quant.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False)) _quant_config = qat.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False))
def _weight_variable(shape, factor=0.01): def _weight_variable(shape, factor=0.01):
@ -90,8 +90,8 @@ class ConvBNReLU(nn.Cell):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
super(ConvBNReLU, self).__init__() super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2 padding = (kernel_size - 1) // 2
conv = Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding, conv = Conv2dBnFoldQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
group=groups, fake=_fake, quant_config=_quant_config) group=groups, fake=_fake, quant_config=_quant_config)
layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()] layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()]
self.features = nn.SequentialCell(layers) self.features = nn.SequentialCell(layers)
@ -126,14 +126,14 @@ class ResidualBlock(nn.Cell):
channel = out_channel // self.expansion channel = out_channel // self.expansion
self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1) self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride) self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake, self.conv3 = nn.SequentialCell([Conv2dBnFoldQuant(channel, out_channel, fake=_fake,
quant_config=_quant_config, quant_config=_quant_config,
kernel_size=1, stride=1, pad_mode='same', padding=0), kernel_size=1, stride=1, pad_mode='same', padding=0),
FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, symmetric=False) FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, symmetric=False)
]) if _fake else Conv2dBatchNormQuant(channel, out_channel, fake=_fake, ]) if _fake else Conv2dBnFoldQuant(channel, out_channel, fake=_fake,
quant_config=_quant_config, quant_config=_quant_config,
kernel_size=1, stride=1, kernel_size=1, stride=1,
pad_mode='same', padding=0) pad_mode='same', padding=0)
self.down_sample = False self.down_sample = False
@ -142,20 +142,19 @@ class ResidualBlock(nn.Cell):
self.down_sample_layer = None self.down_sample_layer = None
if self.down_sample: if self.down_sample:
self.down_sample_layer = nn.SequentialCell([Conv2dBatchNormQuant(in_channel, out_channel, self.down_sample_layer = nn.SequentialCell([Conv2dBnFoldQuant(in_channel, out_channel,
quant_config=_quant_config, quant_config=_quant_config,
kernel_size=1, stride=stride, kernel_size=1, stride=stride,
pad_mode='same', padding=0), pad_mode='same', padding=0),
FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay,
symmetric=False) symmetric=False)
]) if _fake else Conv2dBatchNormQuant(in_channel, out_channel, ]) if _fake else Conv2dBnFoldQuant(in_channel, out_channel,
fake=_fake, fake=_fake,
quant_config=\ quant_config=_quant_config,
_quant_config, kernel_size=1,
kernel_size=1, stride=stride,
stride=stride, pad_mode='same',
pad_mode='same', padding=0)
padding=0)
self.add = nn.TensorAddQuant() self.add = nn.TensorAddQuant()
self.relu = P.ReLU() self.relu = P.ReLU()

@ -25,8 +25,8 @@ from mindspore.context import ParallelMode
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.serialization import load_checkpoint from mindspore.train.serialization import load_checkpoint
from mindspore.train.quant import quant from mindspore.compression.quant import QuantizationAwareTraining
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net
from mindspore.communication.management import init from mindspore.communication.management import init
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.common.initializer as weight_init import mindspore.common.initializer as weight_init
@ -113,7 +113,10 @@ if __name__ == '__main__':
step_size = dataset.get_dataset_size() step_size = dataset.get_dataset_size()
# convert fusion network to quantization aware network # convert fusion network to quantization aware network
net = quant.convert_quant_network(net, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) quantizer = QuantizationAwareTraining(bn_fold=True,
per_channel=[True, False],
symmetric=[True, False])
network = quantizer.quantize(network)
# get learning rate # get learning rate
lr = get_lr(lr_init=config.lr_init, lr = get_lr(lr_init=config.lr_init,

@ -29,7 +29,7 @@ from mindspore.context import ParallelMode
from mindspore import context from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
import mindspore as ms import mindspore as ms
from mindspore.train.quant import quant from mindspore.compression.quant import QuantizationAwareTraining
from src.yolo import YOLOV3DarkNet53 from src.yolo import YOLOV3DarkNet53
from src.logger import get_logger from src.logger import get_logger
@ -265,10 +265,10 @@ def test():
# convert fusion network to quantization aware network # convert fusion network to quantization aware network
if config.quantization_aware: if config.quantization_aware:
network = quant.convert_quant_network(network, quantizer = QuantizationAwareTraining(bn_fold=True,
bn_fold=True,
per_channel=[True, False], per_channel=[True, False],
symmetric=[True, False]) symmetric=[True, False])
network = quantizer.quantize(network)
args.logger.info(args.pretrained) args.logger.info(args.pretrained)
if os.path.isfile(args.pretrained): if os.path.isfile(args.pretrained):

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""hub config.""" """hub config."""
from mindspore.train.quant import quant from mindspore.compression.quant import QuantizationAwareTraining
from src.yolo import YOLOV3DarkNet53 from src.yolo import YOLOV3DarkNet53
from src.config import ConfigYOLOV3DarkNet53 from src.config import ConfigYOLOV3DarkNet53
@ -24,9 +24,9 @@ def create_network(name, *args, **kwargs):
config = ConfigYOLOV3DarkNet53() config = ConfigYOLOV3DarkNet53()
# convert fusion network to quantization aware network # convert fusion network to quantization aware network
if config.quantization_aware: if config.quantization_aware:
yolov3_darknet53_quant = quant.convert_quant_network(yolov3_darknet53_quant, quantizer = QuantizationAwareTraining(bn_fold=True,
bn_fold=True, per_channel=[True, False],
per_channel=[True, False], symmetric=[True, False])
symmetric=[True, False]) yolov3_darknet53_quant = quantizer.quantize(yolov3_darknet53_quant)
return yolov3_darknet53_quant return yolov3_darknet53_quant
raise NotImplementedError(f"{name} is not implemented in the repo") raise NotImplementedError(f"{name} is not implemented in the repo")

@ -27,7 +27,7 @@ from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import ModelCheckpoint, RunContext from mindspore.train.callback import ModelCheckpoint, RunContext
from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig from mindspore.train.callback import _InternalCallbackParam, CheckpointConfig
import mindspore as ms import mindspore as ms
from mindspore.train.quant import quant from mindspore.compression.quant import QuantizationAwareTraining
from mindspore.common import set_seed from mindspore.common import set_seed
from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper from src.yolo import YOLOV3DarkNet53, YoloWithLossCell, TrainingWrapper
@ -168,10 +168,10 @@ def train():
config = ConfigYOLOV3DarkNet53() config = ConfigYOLOV3DarkNet53()
# convert fusion network to quantization aware network # convert fusion network to quantization aware network
if config.quantization_aware: if config.quantization_aware:
network = quant.convert_quant_network(network, quantizer = QuantizationAwareTraining(bn_fold=True,
bn_fold=True,
per_channel=[True, False], per_channel=[True, False],
symmetric=[True, False]) symmetric=[True, False])
network = quantizer.quantize(network)
network = YoloWithLossCell(network) network = YoloWithLossCell(network)
args.logger.info('finish get network') args.logger.info('finish get network')

@ -26,8 +26,8 @@ from mindspore.nn.metrics import Accuracy
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
from mindspore.train import Model from mindspore.train import Model
from mindspore.train.quant import quant from mindspore.compression.quant import QuantizationAwareTraining
from mindspore.train.quant.quant_utils import load_nonquant_param_into_quant_net from mindspore.compression.quant.quant_utils import load_nonquant_param_into_quant_net
from dataset import create_dataset from dataset import create_dataset
from config import nonquant_cfg, quant_cfg from config import nonquant_cfg, quant_cfg
from lenet import LeNet5 from lenet import LeNet5
@ -73,8 +73,11 @@ def train_lenet_quant():
load_nonquant_param_into_quant_net(network, param_dict) load_nonquant_param_into_quant_net(network, param_dict)
# convert fusion network to quantization aware network # convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=900, bn_fold=False, per_channel=[True, False], quantizer = QuantizationAwareTraining(quant_delay=900,
symmetric=[False, False]) bn_fold=False,
per_channel=[True, False],
symmetric=[True, False])
network = quantizer.quantize(network)
# define network loss # define network loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
@ -103,8 +106,12 @@ def eval_quant():
# define fusion network # define fusion network
network = LeNet5Fusion(cfg.num_classes) network = LeNet5Fusion(cfg.num_classes)
# convert fusion network to quantization aware network # convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000, quantizer = QuantizationAwareTraining(quant_delay=0,
per_channel=[True, False]) bn_fold=False,
freeze_bn=10000,
per_channel=[True, False],
symmetric=[True, False])
network = quantizer.quantize(network)
# define loss # define loss
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
@ -131,8 +138,12 @@ def export_lenet():
# define fusion network # define fusion network
network = LeNet5Fusion(cfg.num_classes) network = LeNet5Fusion(cfg.num_classes)
# convert fusion network to quantization aware network # convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000, quantizer = QuantizationAwareTraining(quant_delay=0,
per_channel=[True, False], symmetric=[True, False]) bn_fold=False,
freeze_bn=10000,
per_channel=[True, False],
symmetric=[True, False])
network = quantizer.quantize(network)
# export network # export network
inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32) inputs = Tensor(np.ones([1, 1, cfg.image_height, cfg.image_width]), mstype.float32)

@ -23,7 +23,7 @@ from mindspore import context
from mindspore import Tensor from mindspore import Tensor
from mindspore import nn from mindspore import nn
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.quant import quant from mindspore.compression.quant import QuantizationAwareTraining
from mindspore.common import set_seed from mindspore.common import set_seed
from dataset import create_dataset from dataset import create_dataset
@ -84,10 +84,10 @@ def test_mobilenetv2_quant():
step_size = dataset.get_dataset_size() step_size = dataset.get_dataset_size()
# convert fusion network to quantization aware network # convert fusion network to quantization aware network
network = quant.convert_quant_network(network, quantizer = QuantizationAwareTraining(bn_fold=True,
bn_fold=True,
per_channel=[True, False], per_channel=[True, False],
symmetric=[True, False]) symmetric=[True, False])
network = quantizer.quantize(network)
# get learning rate # get learning rate
lr = Tensor(get_lr(global_step=config.start_epoch * step_size, lr = Tensor(get_lr(global_step=config.start_epoch * step_size,

@ -18,14 +18,14 @@ import mindspore.nn as nn
import mindspore.common.initializer as weight_init import mindspore.common.initializer as weight_init
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore import Tensor from mindspore import Tensor
from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant as Conv2dBatchNormQuant from mindspore.nn import FakeQuantWithMinMaxObserver, Conv2dBnFoldQuant
from mindspore.train.quant import quant from mindspore.compression.quant import qat
_ema_decay = 0.999 _ema_decay = 0.999
_symmetric = True _symmetric = True
_fake = True _fake = True
_per_channel = True _per_channel = True
_quant_config = quant.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False)) _quant_config = qat.get_quant_config(per_channel=(_per_channel, False), symmetric=(_symmetric, False))
def _weight_variable(shape, factor=0.01): def _weight_variable(shape, factor=0.01):
@ -91,8 +91,8 @@ class ConvBNReLU(nn.Cell):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
super(ConvBNReLU, self).__init__() super(ConvBNReLU, self).__init__()
padding = (kernel_size - 1) // 2 padding = (kernel_size - 1) // 2
conv = Conv2dBatchNormQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding, conv = Conv2dBnFoldQuant(in_planes, out_planes, kernel_size, stride, pad_mode='pad', padding=padding,
group=groups, fake=_fake, quant_config=_quant_config) group=groups, fake=_fake, quant_config=_quant_config)
layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()] layers = [conv, nn.ActQuant(nn.ReLU())] if _fake else [conv, nn.ReLU()]
self.features = nn.SequentialCell(layers) self.features = nn.SequentialCell(layers)
@ -127,14 +127,14 @@ class ResidualBlock(nn.Cell):
channel = out_channel // self.expansion channel = out_channel // self.expansion
self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1) self.conv1 = ConvBNReLU(in_channel, channel, kernel_size=1, stride=1)
self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride) self.conv2 = ConvBNReLU(channel, channel, kernel_size=3, stride=stride)
self.conv3 = nn.SequentialCell([Conv2dBatchNormQuant(channel, out_channel, fake=_fake, self.conv3 = nn.SequentialCell([Conv2dBnFoldQuant(channel, out_channel, fake=_fake,
quant_config=_quant_config, quant_config=_quant_config,
kernel_size=1, stride=1, pad_mode='same', padding=0), kernel_size=1, stride=1, pad_mode='same', padding=0),
FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, symmetric=False) FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, symmetric=False)
]) if _fake else Conv2dBatchNormQuant(channel, out_channel, fake=_fake, ]) if _fake else Conv2dBnFoldQuant(channel, out_channel, fake=_fake,
quant_config=_quant_config, quant_config=_quant_config,
kernel_size=1, stride=1, kernel_size=1, stride=1,
pad_mode='same', padding=0) pad_mode='same', padding=0)
self.down_sample = False self.down_sample = False
@ -143,20 +143,19 @@ class ResidualBlock(nn.Cell):
self.down_sample_layer = None self.down_sample_layer = None
if self.down_sample: if self.down_sample:
self.down_sample_layer = nn.SequentialCell([Conv2dBatchNormQuant(in_channel, out_channel, self.down_sample_layer = nn.SequentialCell([Conv2dBnFoldQuant(in_channel, out_channel,
quant_config=_quant_config, quant_config=_quant_config,
kernel_size=1, stride=stride, kernel_size=1, stride=stride,
pad_mode='same', padding=0), pad_mode='same', padding=0),
FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay, FakeQuantWithMinMaxObserver(ema=True, ema_decay=_ema_decay,
symmetric=False) symmetric=False)
]) if _fake else Conv2dBatchNormQuant(in_channel, out_channel, ]) if _fake else Conv2dBnFoldQuant(in_channel, out_channel,
fake=_fake, fake=_fake,
quant_config=\ quant_config=_quant_config,
_quant_config, kernel_size=1,
kernel_size=1, stride=stride,
stride=stride, pad_mode='same',
pad_mode='same', padding=0)
padding=0)
self.add = nn.TensorAddQuant() self.add = nn.TensorAddQuant()
self.relu = P.ReLU() self.relu = P.ReLU()

@ -22,7 +22,7 @@ from mindspore import context
from mindspore import Tensor from mindspore import Tensor
from mindspore.nn.optim.momentum import Momentum from mindspore.nn.optim.momentum import Momentum
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.quant import quant from mindspore.compression.quant import QuantizationAwareTraining
from mindspore import set_seed from mindspore import set_seed
from resnet_quant_manual import resnet50_quant from resnet_quant_manual import resnet50_quant
@ -89,10 +89,10 @@ def test_resnet50_quant():
step_size = dataset.get_dataset_size() step_size = dataset.get_dataset_size()
# convert fusion network to quantization aware network # convert fusion network to quantization aware network
net = quant.convert_quant_network(net, quantizer = QuantizationAwareTraining(bn_fold=True,
bn_fold=True, per_channel=[True, False],
per_channel=[True, False], symmetric=[True, False])
symmetric=[True, False]) net = quantizer.quantize(net)
# get learning rate # get learning rate
lr = Tensor(get_lr(lr_init=config.lr_init, lr = Tensor(get_lr(lr_init=config.lr_init,

@ -19,7 +19,8 @@ import pytest
import mindspore.context as context import mindspore.context as context
from mindspore import Tensor from mindspore import Tensor
from mindspore import nn from mindspore import nn
from mindspore.train.quant import quant as qat from mindspore.compression.quant import QuantizationAwareTraining
from mindspore.compression.export import quant_export
from model_zoo.official.cv.mobilenetv2_quant.src.mobilenetV2 import mobilenetV2 from model_zoo.official.cv.mobilenetv2_quant.src.mobilenetV2 import mobilenetV2
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
@ -66,27 +67,35 @@ class LeNet5(nn.Cell):
def test_qat_lenet(): def test_qat_lenet():
img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32)) img = Tensor(np.ones((32, 1, 32, 32)).astype(np.float32))
net = LeNet5() net = LeNet5()
net = qat.convert_quant_network( quantizer = QuantizationAwareTraining(bn_fold=True,
net, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) per_channel=[True, False],
symmetric=[True, False])
net = quantizer.quantize(net)
# should load the checkpoint. mock here # should load the checkpoint. mock here
net.init_parameters_data() net.init_parameters_data()
qat.export(net, img, file_name="quant.pb") quant_export.export(net, img, file_name="quant.pb")
@pytest.mark.skip(reason="no `te.lang.cce` in ut env") @pytest.mark.skip(reason="no `te.lang.cce` in ut env")
def test_qat_mobile_per_channel_tf(): def test_qat_mobile_per_channel_tf():
network = mobilenetV2(num_classes=1000) network = mobilenetV2(num_classes=1000)
img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32)) img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
network = qat.convert_quant_network(network, bn_fold=True, per_channel=[True, False], symmetric=[True, False]) quantizer = QuantizationAwareTraining(bn_fold=True,
per_channel=[True, False],
symmetric=[True, False])
network = quantizer.quantize(network)
# should load the checkpoint. mock here # should load the checkpoint. mock here
network.init_parameters_data() network.init_parameters_data()
qat.export(network, img, file_name="quant.pb") quant_export.export(network, img, file_name="quant.pb")
@pytest.mark.skip(reason="no `te.lang.cce` in ut env") @pytest.mark.skip(reason="no `te.lang.cce` in ut env")
def test_qat_mobile_per_channel_ff(): def test_qat_mobile_per_channel_ff():
network = mobilenetV2(num_classes=1000) network = mobilenetV2(num_classes=1000)
img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32)) img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
network = qat.convert_quant_network(network, bn_fold=True, per_channel=[False, False], symmetric=[True, False]) quantizer = QuantizationAwareTraining(bn_fold=True,
per_channel=[False, False],
symmetric=[True, False])
network = quantizer.quantize(network)
# should load the checkpoint. mock here # should load the checkpoint. mock here
network.init_parameters_data() network.init_parameters_data()
qat.export(network, img, file_name="quant.pb") quant_export.export(network, img, file_name="quant.pb")

Loading…
Cancel
Save