!1084 add new interface quant combined

Merge pull request !1084 from SanjayChan/04quant
pull/1084/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 4bb5c7b39a

@@ -97,7 +97,7 @@ class Cell:
After being invoked, the name prefix of each child cell can be obtained through '_param_prefix'.
"""
cells = self.cells_and_names
cells = self.cells_and_names()
for cell_name, cell in cells:
cell._param_prefix = cell_name

@@ -0,0 +1,182 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Use combination of Conv, Dense, Relu, Batchnorm."""
from .normalization import BatchNorm2d
from .activation import get_activation
from ..cell import Cell
from . import conv, basic
from ..._checkparam import ParamValidator as validator
__all__ = ['Conv2d', 'Dense']
class Conv2d(Cell):
r"""
A combination of convolution, batchnorm and activation layers.
See the Conv2d op for a more detailed overview.
Args:
in_channels (int): The number of input channel :math:`C_{in}`.
out_channels (int): The number of output channel :math:`C_{out}`.
kernel_size (Union[int, tuple]): The data type is int or a tuple of 2 integers. Specifies the height
and width of the 2D convolution window. A single int means the value is for both the height and the
width of the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
width of the kernel.
stride (int): Specifies stride for all spatial dimensions with the same value. Value of stride should be
greater or equal to 1 but bounded by the height and width of the input. Default: 1.
pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
padding (int): Implicit paddings on both sides of the input. Default: 0.
dilation (int): Specifies the dilation rate to use for dilated convolution. If set to be :math:`k > 1`,
there will be :math:`k - 1` pixels skipped for each sampling location. Its value should be greater
than or equal to 1 and bounded by the height and width of the input. Default: 1.
group (int): Splits the filter into groups, `in_channels` and `out_channels` should be
divisible by the number of groups. Default: 1.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
Initializer for more details. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
Initializer and string are the same as 'weight_init'. Refer to the values of
Initializer for more details. Default: 'zeros'.
batchnorm (Union[bool, BatchNorm2d]): Specifies whether to apply batchnorm after the convolution;
a BatchNorm2d cell can also be passed in directly. Default: None.
activation (str): Specifies the activation type. The optional values are as follows:
'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Examples:
>>> net = combined.Conv2d(120, 240, 4, batchnorm=True, activation='relu')
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> net(input).shape()
(1, 240, 1024, 640)
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
has_bias=False,
weight_init='normal',
bias_init='zeros',
batchnorm=None,
activation=None):
super(Conv2d, self).__init__()
self.conv = conv.Conv2d(
in_channels,
out_channels,
kernel_size,
stride,
pad_mode,
padding,
dilation,
group,
has_bias,
weight_init,
bias_init)
self.has_bn = batchnorm is not None
self.has_act = activation is not None
self.batchnorm = batchnorm
if batchnorm is True:
self.batchnorm = BatchNorm2d(out_channels)
elif batchnorm is not None:
validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
self.activation = get_activation(activation)
def construct(self, x):
x = self.conv(x)
if self.has_bn:
x = self.batchnorm(x)
if self.has_act:
x = self.activation(x)
return x
class Dense(Cell):
r"""
A combination of dense, batchnorm and activation layers.
See the Dense op for a more detailed overview.
Args:
in_channels (int): The number of channels in the input space.
out_channels (int): The number of channels in the output space.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
batchnorm (Union[bool, BatchNorm2d]): Specifies whether to apply batchnorm after the dense layer;
a BatchNorm2d cell can also be passed in directly. Default: None.
activation (str): Specifies the activation type applied to the output of the layer. The optional values
are as follows: 'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
Outputs:
Tensor of shape :math:`(N, out\_channels)`.
Examples:
>>> net = combined.Dense(3, 4)
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
>>> net(input)
"""
def __init__(self,
in_channels,
out_channels,
weight_init='normal',
bias_init='zeros',
has_bias=True,
batchnorm=None,
activation=None):
super(Dense, self).__init__()
self.dense = basic.Dense(
in_channels,
out_channels,
weight_init,
bias_init,
has_bias)
self.has_bn = batchnorm is not None
self.has_act = activation is not None
self.batchnorm = batchnorm
if batchnorm is True:
self.batchnorm = BatchNorm2d(out_channels)
elif batchnorm is not None:
validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
self.activation = get_activation(activation)
def construct(self, x):
x = self.dense(x)
if self.has_bn:
x = self.batchnorm(x)
if self.has_act:
x = self.activation(x)
return x
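For reference, here is a minimal usage sketch of the new combined layers (not part of this PR's diff). It mirrors the docstring examples above and the import path used by the test files later in this PR; the tensor shapes are illustrative assumptions.

# Hedged usage sketch: a Conv2d + BatchNorm + ReLU block and a Dense + ReLU block built from the combined layers.
import numpy as np
import mindspore
from mindspore import Tensor
from mindspore.nn.layer import combined

conv = combined.Conv2d(120, 240, 4, batchnorm=True, activation='relu')   # conv -> BatchNorm2d -> ReLU
dense = combined.Dense(240, 10, activation='relu')                       # dense -> ReLU

x = Tensor(np.ones([1, 120, 32, 32]), mindspore.float32)
out = conv(x)                                             # (1, 240, 32, 32) with the default pad_mode='same', stride=1

x2 = Tensor(np.ones([2, 240]), mindspore.float32)
out2 = dense(x2)                                          # (2, 10)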

@@ -191,6 +191,8 @@ class Conv2dBatchNormQuant(Cell):
stride,
pad_mode,
padding=0,
dilation=1,
group=1,
eps=1e-5,
momentum=0.9,
weight_init=None,
@@ -198,7 +200,6 @@ class Conv2dBatchNormQuant(Cell):
gamma_init=None,
mean_init=None,
var_init=None,
group=1,
quant_delay=0,
freeze_bn=100000,
fake=True,

@@ -0,0 +1,26 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Quantization.

Users can apply quantization aware training to train a model. MindSpore supports quantization aware
training, which models quantization errors in both the forward and backward passes using fake-quantization
ops. Note that the entire computation is carried out in floating point. At the end of quantization
aware training, MindSpore provides conversion functions to convert the trained model into lower precision.
"""
from .quant import convert_quant_network
__all__ = ["convert_quant_network"]

File diff suppressed because it is too large.

@@ -0,0 +1,100 @@
"""MobileNetV2"""
from mindspore import nn
from mindspore.ops import operations as P
def make_divisible(input_x, div_by=8):
return int((input_x + div_by) // div_by)
def _conv_bn(in_channel,
out_channel,
ksize,
stride=1):
"""Get a conv2d batchnorm and relu layer."""
return nn.SequentialCell(
[nn.Conv2d(in_channel,
out_channel,
kernel_size=ksize,
stride=stride),
nn.BatchNorm2d(out_channel)])
class InvertedResidual(nn.Cell):
def __init__(self, inp, oup, stride, expend_ratio):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = int(inp * expend_ratio)
self.use_res_connect = self.stride == 1 and inp == oup
if expend_ratio == 1:
self.conv = nn.SequentialCell([
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, group=hidden_dim),
nn.BatchNorm2d(hidden_dim),
nn.ReLU6(),
nn.Conv2d(hidden_dim, oup, 1, 1),
nn.BatchNorm2d(oup)
])
else:
self.conv = nn.SequentialCell([
nn.Conv2d(inp, hidden_dim, 1, 1),
nn.BatchNorm2d(hidden_dim),
nn.ReLU6(),
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, group=hidden_dim),
nn.BatchNorm2d(hidden_dim),
nn.ReLU6(),
nn.Conv2d(hidden_dim, oup, 1, 1),
nn.BatchNorm2d(oup)
])
def construct(self, input_x):
out = self.conv(input_x)
if self.use_res_connect:
out = input_x + out
return out
class MobileNetV2(nn.Cell):
def __init__(self, num_class=1000, input_size=224, width_mul=1.):
super(MobileNetV2, self).__init__()
block = InvertedResidual
input_channel = 32
last_channel = 1280
inverted_residual_setting = [
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 230, 1, 1],
]
if width_mul > 1.0:
last_channel = make_divisible(last_channel * width_mul)
self.last_channel = last_channel
features = [_conv_bn(3, input_channel, 3, 2)]
for t, c, n, s in inverted_residual_setting:
out_channel = make_divisible(c * width_mul) if t > 1 else c
for i in range(n):
if i == 0:
features.append(block(input_channel, out_channel, s, t))
else:
features.append(block(input_channel, out_channel, 1, t))
input_channel = out_channel
features.append(_conv_bn(input_channel, self.last_channel, 1))
self.features = nn.SequentialCell(features)
self.mean = P.ReduceMean(keep_dims=False)
self.classifier = nn.Dense(self.last_channel, num_class)
def construct(self, input_x):
out = input_x
out = self.features(out)
out = self.mean(out, (2, 3))
out = self.classifier(out)
return out

@@ -0,0 +1,108 @@
"""mobile net v2"""
from mindspore import nn
from mindspore.ops import operations as P
from mindspore.nn.layer import combined
def make_divisible(input_x, div_by=8):
return int((input_x + div_by) // div_by)
def _conv_bn(in_channel,
out_channel,
ksize,
stride=1):
"""Get a conv2d batchnorm and relu layer."""
return nn.SequentialCell(
[combined.Conv2d(in_channel,
out_channel,
kernel_size=ksize,
stride=stride,
batchnorm=True)])
class InvertedResidual(nn.Cell):
def __init__(self, inp, oup, stride, expend_ratio):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = int(inp * expend_ratio)
self.use_res_connect = self.stride == 1 and inp == oup
if expend_ratio == 1:
self.conv = nn.SequentialCell([
combined.Conv2d(hidden_dim,
hidden_dim,
3,
stride,
group=hidden_dim,
batchnorm=True,
activation='relu6'),
combined.Conv2d(hidden_dim, oup, 1, 1,
batchnorm=True)
])
else:
self.conv = nn.SequentialCell([
combined.Conv2d(inp, hidden_dim, 1, 1,
batchnorm=True,
activation='relu6'),
combined.Conv2d(hidden_dim,
hidden_dim,
3,
stride,
group=hidden_dim,
batchnorm=True,
activation='relu6'),
combined.Conv2d(hidden_dim, oup, 1, 1,
batchnorm=True)
])
self.add = P.TensorAdd()
def construct(self, input_x):
out = self.conv(input_x)
if self.use_res_connect:
out = self.add(input_x, out)
return out
class MobileNetV2(nn.Cell):
def __init__(self, num_class=1000, input_size=224, width_mul=1.):
super(MobileNetV2, self).__init__()
block = InvertedResidual
input_channel = 32
last_channel = 1280
inverted_residual_setting = [
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 230, 1, 1],
]
if width_mul > 1.0:
last_channel = make_divisible(last_channel * width_mul)
self.last_channel = last_channel
features = [_conv_bn(3, input_channel, 3, 2)]
for t, c, n, s in inverted_residual_setting:
out_channel = make_divisible(c * width_mul) if t > 1 else c
for i in range(n):
if i == 0:
features.append(block(input_channel, out_channel, s, t))
else:
features.append(block(input_channel, out_channel, 1, t))
input_channel = out_channel
features.append(_conv_bn(input_channel, self.last_channel, 1))
self.features = nn.SequentialCell(features)
self.mean = P.ReduceMean(keep_dims=False)
self.classifier = combined.Dense(self.last_channel, num_class)
def construct(self, input_x):
out = input_x
out = self.features(out)
out = self.mean(out, (2, 3))
out = self.classifier(out)
return out

@@ -0,0 +1,94 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" tests for quant """
import numpy as np
from mindspore import Tensor
from mindspore.train.quant import quant as qat
from mindspore import nn
import mindspore.ops.operations as P
from mindspore.nn.layer import combined
import mindspore.context as context
from mobilenetv2_combined import MobileNetV2
context.set_context(mode=context.GRAPH_MODE)
class LeNet5(nn.Cell):
"""
Lenet network
Args:
num_class (int): Num classes. Default: 10.
Returns:
Tensor, output tensor
Examples:
>>> LeNet(num_class=10)
"""
def __init__(self, num_class=10):
super(LeNet5, self).__init__()
self.num_class = num_class
self.conv1 = combined.Conv2d(
1, 6, kernel_size=5, batchnorm=True, activation='relu6')
self.conv2 = combined.Conv2d(6, 16, kernel_size=5, activation='relu')
self.fc1 = combined.Dense(16 * 5 * 5, 120, activation='relu')
self.fc2 = combined.Dense(120, 84, activation='relu')
self.fc3 = combined.Dense(84, self.num_class)
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
def construct(self, x):
x = self.conv1(x)  # conv1 already applies batchnorm and relu6
x = self.max_pool2d(x)
x = self.conv2(x)
x = self.max_pool2d(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
return x
def test_qat_lenet():
net = LeNet5()
net = qat.convert_quant_network(
net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8)
def test_qat_mobile():
net = MobileNetV2()
img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
net = qat.convert_quant_network(
net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8)
net(img)
def test_qat_mobile_train():
net = MobileNetV2(num_class=10)
img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
label = Tensor(np.ones((1, 10)).astype(np.float32))
net = qat.convert_quant_network(
net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8)
loss = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
optimizer = nn.Momentum(net.trainable_params(),
learning_rate=0.1, momentum=0.9)
net = nn.WithLossCell(net, loss)
net = nn.TrainOneStepCell(net, optimizer)
net(img, label)