!1084 add new interface quant combined
Merge pull request !1084 from SanjayChan/04quantpull/1084/MERGE
commit
4bb5c7b39a
@ -0,0 +1,182 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Use combination of Conv, Dense, Relu, Batchnorm."""
|
||||
|
||||
from .normalization import BatchNorm2d
|
||||
from .activation import get_activation
|
||||
from ..cell import Cell
|
||||
from . import conv, basic
|
||||
from ..._checkparam import ParamValidator as validator
|
||||
|
||||
|
||||
__all__ = ['Conv2d', 'Dense']
|
||||
|
||||
class Conv2d(Cell):
|
||||
r"""
|
||||
A combination of convolution, Batchnorm, activation layer.
|
||||
|
||||
For a more Detailed overview of Conv2d op.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of input channel :math:`C_{in}`.
|
||||
out_channels (int): The number of output channel :math:`C_{out}`.
|
||||
kernel_size (Union[int, tuple]): The data type is int or tuple with 2 integers. Specifies the height
|
||||
and width of the 2D convolution window. Single int means the value if for both height and width of
|
||||
the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
|
||||
width of the kernel.
|
||||
stride (int): Specifies stride for all spatial dimensions with the same value. Value of stride should be
|
||||
greater or equal to 1 but bounded by the height and width of the input. Default: 1.
|
||||
pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
|
||||
padding (int): Implicit paddings on both sides of the input. Default: 0.
|
||||
dilation (int): Specifying the dilation rate to use for dilated convolution. If set to be :math:`k > 1`,
|
||||
there will be :math:`k - 1` pixels skipped for each sampling location. Its value should be greater
|
||||
or equal to 1 and bounded by the height and width of the input. Default: 1.
|
||||
group (int): Split filter into groups, `in_ channels` and `out_channels` should be
|
||||
divisible by the number of groups. Default: 1.
|
||||
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
|
||||
It can be a Tensor, a string, an Initializer or a numbers.Number. When a string is specified,
|
||||
values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
|
||||
as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
|
||||
and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
|
||||
Initializer for more details. Default: 'normal'.
|
||||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
|
||||
Initializer and string are the same as 'weight_init'. Refer to the values of
|
||||
Initializer for more details. Default: 'zeros'.
|
||||
batchnorm (bool): Specifies to used batchnorm or not. Default: None.
|
||||
activation (string): Specifies activation type. The optional values are as following:
|
||||
'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
|
||||
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
|
||||
|
||||
Outputs:
|
||||
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
|
||||
|
||||
Examples:
|
||||
>>> net = combined.Conv2d(120, 240, 4, batchnorm=True, activation='ReLU')
|
||||
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
|
||||
>>> net(input).shape()
|
||||
(1, 240, 1024, 640)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
pad_mode='same',
|
||||
padding=0,
|
||||
dilation=1,
|
||||
group=1,
|
||||
has_bias=False,
|
||||
weight_init='normal',
|
||||
bias_init='zeros',
|
||||
batchnorm=None,
|
||||
activation=None):
|
||||
super(Conv2d, self).__init__()
|
||||
self.conv = conv.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
pad_mode,
|
||||
padding,
|
||||
dilation,
|
||||
group,
|
||||
has_bias,
|
||||
weight_init,
|
||||
bias_init)
|
||||
self.has_bn = batchnorm is not None
|
||||
self.has_act = activation is not None
|
||||
self.batchnorm = batchnorm
|
||||
if batchnorm is True:
|
||||
self.batchnorm = BatchNorm2d(out_channels)
|
||||
elif batchnorm is not None:
|
||||
validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
|
||||
self.activation = get_activation(activation)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv(x)
|
||||
if self.has_bn:
|
||||
x = self.batchnorm(x)
|
||||
if self.has_act:
|
||||
x = self.activation(x)
|
||||
return x
|
||||
|
||||
|
||||
class Dense(Cell):
|
||||
r"""
|
||||
A combination of Dense, Batchnorm, activation layer.
|
||||
|
||||
For a more Detailed overview of Dense op.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of channels in the input space.
|
||||
out_channels (int): The number of channels in the output space.
|
||||
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
|
||||
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
|
||||
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
|
||||
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
|
||||
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
|
||||
activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None.
|
||||
batchnorm (bool): Specifies to used batchnorm or not. Default: None.
|
||||
activation (string): Specifies activation type. The optional values are as following:
|
||||
'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
|
||||
'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
|
||||
|
||||
Outputs:
|
||||
Tensor of shape :math:`(N, out\_channels)`.
|
||||
|
||||
Examples:
|
||||
>>> net = nn.Dense(3, 4)
|
||||
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
|
||||
>>> net(input)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
weight_init='normal',
|
||||
bias_init='zeros',
|
||||
has_bias=True,
|
||||
batchnorm=None,
|
||||
activation=None):
|
||||
super(Dense, self).__init__()
|
||||
self.dense = basic.Dense(
|
||||
in_channels,
|
||||
out_channels,
|
||||
weight_init,
|
||||
bias_init,
|
||||
has_bias)
|
||||
self.has_bn = batchnorm is not None
|
||||
self.has_act = activation is not None
|
||||
if batchnorm is True:
|
||||
self.batchnorm = BatchNorm2d(out_channels)
|
||||
elif batchnorm is not None:
|
||||
validator.check_isinstance('batchnorm', batchnorm, (BatchNorm2d,))
|
||||
self.activation = get_activation(activation)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.dense(x)
|
||||
if self.has_bn:
|
||||
x = self.batchnorm(x)
|
||||
if self.has_act:
|
||||
x = self.activation(x)
|
||||
return x
|
@ -0,0 +1,26 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
quantization.
|
||||
|
||||
User can use aware quantization to train a model. Mindspore supports quantization aware training,
|
||||
which models quantization errors in both the forward and backward passes using fake-quantization
|
||||
ops. Note that the entire computation is carried out in floating point. At the end of quantization
|
||||
aware training, Mindspore provides conversion functions to convert the trained model into lower precision.
|
||||
"""
|
||||
|
||||
from .quant import convert_quant_network
|
||||
|
||||
__all__ = ["convert_quant_network"]
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,100 @@
|
||||
"""MobileNetV2"""
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
def make_divisible(input_x, div_by=8):
|
||||
return int((input_x + div_by) // div_by)
|
||||
|
||||
|
||||
def _conv_bn(in_channel,
|
||||
out_channel,
|
||||
ksize,
|
||||
stride=1):
|
||||
"""Get a conv2d batchnorm and relu layer."""
|
||||
return nn.SequentialCell(
|
||||
[nn.Conv2d(in_channel,
|
||||
out_channel,
|
||||
kernel_size=ksize,
|
||||
stride=stride),
|
||||
nn.BatchNorm2d(out_channel)])
|
||||
|
||||
|
||||
class InvertedResidual(nn.Cell):
|
||||
def __init__(self, inp, oup, stride, expend_ratio):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(inp * expend_ratio)
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
if expend_ratio == 1:
|
||||
self.conv = nn.SequentialCell([
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, group=hidden_dim),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.ReLU6(),
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1),
|
||||
nn.BatchNorm2d(oup)
|
||||
])
|
||||
else:
|
||||
self.conv = nn.SequentialCell([
|
||||
nn.Conv2d(inp, hidden_dim, 1, 1),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.ReLU6(),
|
||||
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 3, stride, group=hidden_dim),
|
||||
nn.BatchNorm2d(hidden_dim),
|
||||
nn.ReLU6(),
|
||||
|
||||
nn.Conv2d(hidden_dim, oup, 1, 1),
|
||||
nn.BatchNorm2d(oup)
|
||||
])
|
||||
|
||||
def construct(self, input_x):
|
||||
out = self.conv(input_x)
|
||||
if self.use_res_connect:
|
||||
out = input_x + out
|
||||
return out
|
||||
|
||||
|
||||
class MobileNetV2(nn.Cell):
|
||||
def __init__(self, num_class=1000, input_size=224, width_mul=1.):
|
||||
super(MobileNetV2, self).__init__()
|
||||
block = InvertedResidual
|
||||
input_channel = 32
|
||||
last_channel = 1280
|
||||
inverted_residual_setting = [
|
||||
[1, 16, 1, 1],
|
||||
[6, 24, 2, 2],
|
||||
[6, 32, 3, 2],
|
||||
[6, 64, 4, 2],
|
||||
[6, 96, 3, 1],
|
||||
[6, 160, 3, 2],
|
||||
[6, 230, 1, 1],
|
||||
]
|
||||
if width_mul > 1.0:
|
||||
last_channel = make_divisible(last_channel * width_mul)
|
||||
self.last_channel = last_channel
|
||||
features = [_conv_bn(3, input_channel, 3, 2)]
|
||||
|
||||
for t, c, n, s in inverted_residual_setting:
|
||||
out_channel = make_divisible(c * width_mul) if t > 1 else c
|
||||
for i in range(n):
|
||||
if i == 0:
|
||||
features.append(block(input_channel, out_channel, s, t))
|
||||
else:
|
||||
features.append(block(input_channel, out_channel, 1, t))
|
||||
input_channel = out_channel
|
||||
|
||||
features.append(_conv_bn(input_channel, self.last_channel, 1))
|
||||
|
||||
self.features = nn.SequentialCell(features)
|
||||
self.mean = P.ReduceMean(keep_dims=False)
|
||||
self.classifier = nn.Dense(self.last_channel, num_class)
|
||||
|
||||
def construct(self, input_x):
|
||||
out = input_x
|
||||
out = self.features(out)
|
||||
out = self.mean(out, (2, 3))
|
||||
out = self.classifier(out)
|
||||
return out
|
@ -0,0 +1,108 @@
|
||||
"""mobile net v2"""
|
||||
from mindspore import nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn.layer import combined
|
||||
|
||||
|
||||
def make_divisible(input_x, div_by=8):
|
||||
return int((input_x + div_by) // div_by)
|
||||
|
||||
|
||||
def _conv_bn(in_channel,
|
||||
out_channel,
|
||||
ksize,
|
||||
stride=1):
|
||||
"""Get a conv2d batchnorm and relu layer."""
|
||||
return nn.SequentialCell(
|
||||
[combined.Conv2d(in_channel,
|
||||
out_channel,
|
||||
kernel_size=ksize,
|
||||
stride=stride,
|
||||
batchnorm=True)])
|
||||
|
||||
|
||||
class InvertedResidual(nn.Cell):
|
||||
def __init__(self, inp, oup, stride, expend_ratio):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
assert stride in [1, 2]
|
||||
|
||||
hidden_dim = int(inp * expend_ratio)
|
||||
self.use_res_connect = self.stride == 1 and inp == oup
|
||||
if expend_ratio == 1:
|
||||
self.conv = nn.SequentialCell([
|
||||
combined.Conv2d(hidden_dim,
|
||||
hidden_dim,
|
||||
3,
|
||||
stride,
|
||||
group=hidden_dim,
|
||||
batchnorm=True,
|
||||
activation='relu6'),
|
||||
combined.Conv2d(hidden_dim, oup, 1, 1,
|
||||
batchnorm=True)
|
||||
])
|
||||
else:
|
||||
self.conv = nn.SequentialCell([
|
||||
combined.Conv2d(inp, hidden_dim, 1, 1,
|
||||
batchnorm=True,
|
||||
activation='relu6'),
|
||||
combined.Conv2d(hidden_dim,
|
||||
hidden_dim,
|
||||
3,
|
||||
stride,
|
||||
group=hidden_dim,
|
||||
batchnorm=True,
|
||||
activation='relu6'),
|
||||
combined.Conv2d(hidden_dim, oup, 1, 1,
|
||||
batchnorm=True)
|
||||
])
|
||||
self.add = P.TensorAdd()
|
||||
|
||||
def construct(self, input_x):
|
||||
out = self.conv(input_x)
|
||||
if self.use_res_connect:
|
||||
out = self.add(input_x, out)
|
||||
return out
|
||||
|
||||
|
||||
class MobileNetV2(nn.Cell):
|
||||
def __init__(self, num_class=1000, input_size=224, width_mul=1.):
|
||||
super(MobileNetV2, self).__init__()
|
||||
block = InvertedResidual
|
||||
input_channel = 32
|
||||
last_channel = 1280
|
||||
inverted_residual_setting = [
|
||||
[1, 16, 1, 1],
|
||||
[6, 24, 2, 2],
|
||||
[6, 32, 3, 2],
|
||||
[6, 64, 4, 2],
|
||||
[6, 96, 3, 1],
|
||||
[6, 160, 3, 2],
|
||||
[6, 230, 1, 1],
|
||||
]
|
||||
if width_mul > 1.0:
|
||||
last_channel = make_divisible(last_channel * width_mul)
|
||||
self.last_channel = last_channel
|
||||
features = [_conv_bn(3, input_channel, 3, 2)]
|
||||
|
||||
for t, c, n, s in inverted_residual_setting:
|
||||
out_channel = make_divisible(c * width_mul) if t > 1 else c
|
||||
for i in range(n):
|
||||
if i == 0:
|
||||
features.append(block(input_channel, out_channel, s, t))
|
||||
else:
|
||||
features.append(block(input_channel, out_channel, 1, t))
|
||||
input_channel = out_channel
|
||||
|
||||
features.append(_conv_bn(input_channel, self.last_channel, 1))
|
||||
|
||||
self.features = nn.SequentialCell(features)
|
||||
self.mean = P.ReduceMean(keep_dims=False)
|
||||
self.classifier = combined.Dense(self.last_channel, num_class)
|
||||
|
||||
def construct(self, input_x):
|
||||
out = input_x
|
||||
out = self.features(out)
|
||||
out = self.mean(out, (2, 3))
|
||||
out = self.classifier(out)
|
||||
return out
|
@ -0,0 +1,94 @@
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
""" tests for quant """
|
||||
import numpy as np
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.quant import quant as qat
|
||||
from mindspore import nn
|
||||
import mindspore.ops.operations as P
|
||||
from mindspore.nn.layer import combined
|
||||
import mindspore.context as context
|
||||
from mobilenetv2_combined import MobileNetV2
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class LeNet5(nn.Cell):
|
||||
"""
|
||||
Lenet network
|
||||
|
||||
Args:
|
||||
num_class (int): Num classes. Default: 10.
|
||||
|
||||
Returns:
|
||||
Tensor, output tensor
|
||||
Examples:
|
||||
>>> LeNet(num_class=10)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, num_class=10):
|
||||
super(LeNet5, self).__init__()
|
||||
self.num_class = num_class
|
||||
self.conv1 = combined.Conv2d(
|
||||
1, 6, kernel_size=5, batchnorm=True, activation='relu6')
|
||||
self.conv2 = combined.Conv2d(6, 16, kernel_size=5, activation='relu')
|
||||
self.fc1 = combined.Dense(16 * 5 * 5, 120, activation='relu')
|
||||
self.fc2 = combined.Dense(120, 84, activation='relu')
|
||||
self.fc3 = combined.Dense(84, self.num_class)
|
||||
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.flattern = nn.Flatten()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn(x)
|
||||
x = self.relu(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.conv2(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.flattern(x)
|
||||
x = self.fc1(x)
|
||||
x = self.fc2(x)
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
|
||||
|
||||
def test_qat_lenet():
|
||||
net = LeNet5()
|
||||
net = qat.convert_quant_network(
|
||||
net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8)
|
||||
|
||||
|
||||
def test_qat_mobile():
|
||||
net = MobileNetV2()
|
||||
img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
|
||||
net = qat.convert_quant_network(
|
||||
net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8)
|
||||
net(img)
|
||||
|
||||
|
||||
def test_qat_mobile_train():
|
||||
net = MobileNetV2(num_class=10)
|
||||
img = Tensor(np.ones((1, 3, 224, 224)).astype(np.float32))
|
||||
label = Tensor(np.ones((1, 10)).astype(np.float32))
|
||||
net = qat.convert_quant_network(
|
||||
net, quant_delay=0, bn_fold=False, freeze_bn=10000, weight_bits=8, act_bits=8)
|
||||
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits(reduction='mean')
|
||||
optimizer = nn.Momentum(net.trainable_params(),
|
||||
learning_rate=0.1, momentum=0.9)
|
||||
net = nn.WithLossCell(net, loss)
|
||||
net = nn.TrainOneStepCell(net, optimizer)
|
||||
net(img, label)
|
Loading…
Reference in new issue