diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py index 4bd87d1480..ddd4722c1b 100644 --- a/mindspore/train/quant/quant.py +++ b/mindspore/train/quant/quant.py @@ -474,8 +474,8 @@ def export(network, *inputs, file_name, mean=127.5, std_dev=127.5, file_format=' def convert_quant_network(network, - bn_fold=False, - freeze_bn=10000, + bn_fold=True, + freeze_bn=1e7, quant_delay=(0, 0), num_bits=(8, 8), per_channel=(False, False), @@ -487,11 +487,11 @@ def convert_quant_network(network, Args: network (Cell): Obtain a pipeline through network for saving graph summary. - bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False. - freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 0. + bn_fold (bool): Flag to use bn fold ops for simulation inference operation. Default: True. + freeze_bn (int): Number of steps after which BatchNorm OP parameters use the total mean and variance. Default: 1e7. quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during eval. The first element represent weights and second element represent data flow. Default: (0, 0) - num_bits (int, list or tuple): Number of bits to use for quantizing weights and activations. The first + num_bits (int, list or tuple): Number of bits used to quantize weights and activations. The first element represent weights and second element represent data flow. Default: (8, 8) per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True` then base on per channel otherwise base on per layer. 
The first element represent weights diff --git a/model_zoo/official/cv/lenet_quant/src/lenet.py b/model_zoo/official/cv/lenet_quant/src/lenet.py index 1efcf9e7d7..18d310c2c7 100644 --- a/model_zoo/official/cv/lenet_quant/src/lenet.py +++ b/model_zoo/official/cv/lenet_quant/src/lenet.py @@ -35,7 +35,9 @@ class LeNet5(nn.Cell): self.num_class = num_class self.conv1 = nn.Conv2d(channel, 6, 5, pad_mode='valid') + self.bn1 = nn.BatchNorm2d(6) self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid') + self.bn2 = nn.BatchNorm2d(16) self.fc1 = nn.Dense(16 * 5 * 5, 120) self.fc2 = nn.Dense(120, 84) self.fc3 = nn.Dense(84, self.num_class) @@ -46,9 +48,11 @@ class LeNet5(nn.Cell): def construct(self, x): x = self.conv1(x) + x = self.bn1(x) x = self.relu(x) x = self.max_pool2d(x) x = self.conv2(x) + x = self.bn2(x) x = self.relu(x) x = self.max_pool2d(x) x = self.flatten(x) diff --git a/model_zoo/official/cv/lenet_quant/src/lenet_fusion.py b/model_zoo/official/cv/lenet_quant/src/lenet_fusion.py index 88b3593502..88b5685218 100644 --- a/model_zoo/official/cv/lenet_quant/src/lenet_fusion.py +++ b/model_zoo/official/cv/lenet_quant/src/lenet_fusion.py @@ -36,8 +36,8 @@ class LeNet5(nn.Cell): self.num_class = num_class # change `nn.Conv2d` to `nn.Conv2dBnAct` - self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', activation='relu') - self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', activation='relu') + self.conv1 = nn.Conv2dBnAct(channel, 6, 5, pad_mode='valid', has_bn=True, activation='relu') + self.conv2 = nn.Conv2dBnAct(6, 16, 5, pad_mode='valid', has_bn=True, activation='relu') # change `nn.Dense` to `nn.DenseBnAct` self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu') self.fc2 = nn.DenseBnAct(120, 84, activation='relu') diff --git a/model_zoo/official/cv/lenet_quant/src/lenet_quant.py b/model_zoo/official/cv/lenet_quant/src/lenet_quant.py new file mode 100644 index 0000000000..12be3f28fa --- /dev/null +++ b/model_zoo/official/cv/lenet_quant/src/lenet_quant.py 
@@ -0,0 +1,61 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Manually constructed quantization network for LeNet""" + +import mindspore.nn as nn + + +class LeNet5(nn.Cell): + """ + LeNet5 network + + Args: + num_class (int): Num classes. Default: 10. + + Returns: + Tensor, output tensor + Examples: + >>> LeNet5(num_class=10) + + """ + + def __init__(self, num_class=10, channel=1): + super(LeNet5, self).__init__() + self.num_class = num_class + + self.conv1 = nn.Conv2dBnFoldQuant(channel, 6, 5, pad_mode='valid', per_channel=True, quant_delay=900) + self.conv2 = nn.Conv2dBnFoldQuant(6, 16, 5, pad_mode='valid', per_channel=True, quant_delay=900) + self.fc1 = nn.DenseQuant(16 * 5 * 5, 120, per_channel=True, quant_delay=900) + self.fc2 = nn.DenseQuant(120, 84, per_channel=True, quant_delay=900) + self.fc3 = nn.DenseQuant(84, self.num_class, per_channel=True, quant_delay=900) + + self.relu = nn.ActQuant(nn.ReLU(), per_channel=False, quant_delay=900) + self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2) + self.flatten = nn.Flatten() + + def construct(self, x): + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.relu(x) + x = self.fc3(x) + return x diff --git 
a/model_zoo/official/cv/lenet_quant/train_quant.py b/model_zoo/official/cv/lenet_quant/train_quant.py index 33c322f4b5..dd6b59a9c8 100644 --- a/model_zoo/official/cv/lenet_quant/train_quant.py +++ b/model_zoo/official/cv/lenet_quant/train_quant.py @@ -57,7 +57,7 @@ if __name__ == "__main__": load_param_into_net(network, param_dict) # convert fusion network to quantization aware network - network = quant.convert_quant_network(network, quant_delay=0, bn_fold=False, freeze_bn=10000) + network = quant.convert_quant_network(network, quant_delay=900, per_channel=[True, False], symmetric=[False, False]) # define network loss net_loss = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True, reduction="mean")