!8187 move Conv2dBnAct,DenseBnAct to combined.py

Merge pull request !8187 from yuchaojie/quant
pull/8187/MERGE
mindspore-ci-bot committed via Gitee 5 years ago
commit a0cab904ce

@@ -181,11 +181,11 @@ class ExportToQuantInferNetwork:
cell_core = None
fake_quant_act = None
activation = None
-if isinstance(subcell, quant.Conv2dBnAct):
+if isinstance(subcell, nn.Conv2dBnAct):
cell_core = subcell.conv
activation = subcell.activation
fake_quant_act = activation.fake_quant_act if hasattr(activation, "fake_quant_act") else None
-elif isinstance(subcell, quant.DenseBnAct):
+elif isinstance(subcell, nn.DenseBnAct):
cell_core = subcell.dense
activation = subcell.activation
fake_quant_act = activation.fake_quant_act if hasattr(activation, "fake_quant_act") else None
@@ -240,9 +240,9 @@ class ExportManualQuantNetwork(ExportToQuantInferNetwork):
subcell = cells[name]
if subcell == network:
continue
-if isinstance(subcell, quant.Conv2dBnAct):
+if isinstance(subcell, nn.Conv2dBnAct):
network, change = self._convert_subcell(network, change, name, subcell)
-elif isinstance(subcell, quant.DenseBnAct):
+elif isinstance(subcell, nn.DenseBnAct):
network, change = self._convert_subcell(network, change, name, subcell, conv=False)
elif isinstance(subcell, (quant.Conv2dBnFoldQuant, quant.Conv2dBnWithoutFoldQuant,
quant.Conv2dQuant, quant.DenseQuant)):

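Both hunks above change only the namespace of the combined cells; the exporter's unpacking logic is untouched. A minimal sketch of that dispatch pattern, with `unpack_combined_cell` as a hypothetical helper name (not a MindSpore API):

from mindspore import nn

def unpack_combined_cell(subcell):
    # Mirror the exporter: a combined cell is split into its core layer
    # (conv or dense) plus the activation attached to it, if any.
    if isinstance(subcell, nn.Conv2dBnAct):
        return subcell.conv, subcell.activation
    if isinstance(subcell, nn.DenseBnAct):
        return subcell.dense, subcell.activation
    return None, None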
@@ -36,7 +36,7 @@ from .quantizer import Quantizer, OptimizeOption
__all__ = ["QuantizationAwareTraining", "create_quant_config"]
-def create_quant_config(quant_observer=(quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver),
+def create_quant_config(quant_observer=(nn.FakeQuantWithMinMaxObserver, nn.FakeQuantWithMinMaxObserver),
quant_delay=(0, 0),
quant_dtype=(QuantDtype.INT8, QuantDtype.INT8),
per_channel=(False, False),
@@ -48,7 +48,7 @@ def create_quant_config(quant_observer=(quant.FakeQuantWithMinMaxObserver, quant
Args:
quant_observer (Observer, list or tuple): The observer type used to do quantization. The first element represents
    weights and the second element represents data flow.
-Default: (quant.FakeQuantWithMinMaxObserver, quant.FakeQuantWithMinMaxObserver)
+Default: (nn.FakeQuantWithMinMaxObserver, nn.FakeQuantWithMinMaxObserver)
quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
    eval. The first element represents weights and the second element represents data flow. Default: (0, 0)
quant_dtype (QuantDtype, list or tuple): Datatype to use for quantize weights and activations. The first
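As a usage sketch, the factory takes per-element tuples whose first entry configures weights and second entry configures data flow. The keyword names follow the signature above (`symmetric` is assumed to follow the same tuple pattern), and the import path assumes `create_quant_config` is re-exported by `mindspore.compression.quant`, so verify against your MindSpore version:

from mindspore.compression.quant import create_quant_config

# Illustrative values: per-channel symmetric quantization for weights,
# per-tensor asymmetric quantization for data flow, 900-step delay.
quant_config = create_quant_config(quant_delay=(900, 900),
                                   per_channel=(True, False),
                                   symmetric=(True, False))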
@@ -210,8 +210,8 @@ class QuantizationAwareTraining(Quantizer):
self.act_symmetric = Validator.check_bool(symmetric[-1], "symmetric")
self.weight_range = Validator.check_bool(narrow_range[0], "narrow range")
self.act_range = Validator.check_bool(narrow_range[-1], "narrow range")
-self._convert_method_map = {quant.Conv2dBnAct: self._convert_conv,
-                            quant.DenseBnAct: self._convert_dense}
+self._convert_method_map = {nn.Conv2dBnAct: self._convert_conv,
+                            nn.DenseBnAct: self._convert_dense}
self.quant_config = create_quant_config(quant_delay=quant_delay,
quant_dtype=quant_dtype,
per_channel=per_channel,
@@ -257,7 +257,7 @@ class QuantizationAwareTraining(Quantizer):
subcell = cells[name]
if subcell == network:
continue
-elif isinstance(subcell, (quant.Conv2dBnAct, quant.DenseBnAct)):
+elif isinstance(subcell, (nn.Conv2dBnAct, nn.DenseBnAct)):
prefix = subcell.param_prefix
new_subcell = self._convert_method_map[type(subcell)](subcell)
new_subcell.update_parameters_name(prefix + '.')

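With the dispatch map retargeted, a network built from nn.Conv2dBnAct / nn.DenseBnAct goes through quantization-aware training unchanged. A hedged sketch of the flow; `network` is a placeholder Cell and the constructor keywords and values are illustrative assumptions, not recommendations:

from mindspore.compression.quant import QuantizationAwareTraining

quantizer = QuantizationAwareTraining(quant_delay=(900, 900),
                                      per_channel=(True, False),
                                      symmetric=(True, False))
# quantize() walks the cells and, via _convert_method_map, swaps each
# combined cell for its quantized counterpart.
quant_net = quantizer.quantize(network)  # `network` is your float model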
@@ -17,7 +17,7 @@ Layer.
The high-level components (Cells) used to construct the neural network.
"""
-from . import activation, normalization, container, conv, lstm, basic, embedding, pooling, image, quant, math
+from . import activation, normalization, container, conv, lstm, basic, embedding, pooling, image, quant, math, combined
from .activation import *
from .normalization import *
from .container import *
@@ -29,6 +29,7 @@ from .pooling import *
from .image import *
from .quant import *
from .math import *
+from .combined import *
__all__ = []
__all__.extend(activation.__all__)
@@ -42,3 +43,4 @@ __all__.extend(pooling.__all__)
__all__.extend(image.__all__)
__all__.extend(quant.__all__)
__all__.extend(math.__all__)
+__all__.extend(combined.__all__)

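After this re-export, both combined cells are reachable straight from mindspore.nn, which is what the isinstance checks and the dispatch map above rely on; for example:

from mindspore import nn

conv = nn.Conv2dBnAct(3, 16, 3, has_bn=True, activation='relu')
fc = nn.DenseBnAct(16, 10, activation='relu')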
@@ -0,0 +1,215 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Combined cells."""
from mindspore import nn
from mindspore.ops.primitive import Primitive
from mindspore._checkparam import Validator
from .normalization import BatchNorm2d, BatchNorm1d
from .activation import get_activation, LeakyReLU
from ..cell import Cell
__all__ = [
'Conv2dBnAct',
'DenseBnAct'
]
class Conv2dBnAct(Cell):
r"""
A combination of convolution, Batchnorm, and activation layers.
This part is a more detailed overview of the Conv2d op.
Args:
in_channels (int): The number of input channels :math:`C_{in}`.
out_channels (int): The number of output channels :math:`C_{out}`.
kernel_size (Union[int, tuple]): The data type is int or a tuple of 2 integers. Specifies the height
and width of the 2D convolution window. Single int means the value is for both height and width of
the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
width of the kernel.
stride (int): Specifies stride for all spatial dimensions with the same value. The value of stride must be
greater than or equal to 1 and lower than any one of the height and width of the input. Default: 1.
pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
padding (int): Implicit paddings on both sides of the input. Default: 0.
dilation (int): Specifies the dilation rate to use for dilated convolution. If set to be :math:`k > 1`,
there will be :math:`k - 1` pixels skipped for each sampling location. Its value must be greater than
or equal to 1 and lower than any one of the height and width of the input. Default: 1.
group (int): Splits filter into groups, `in_channels` and `out_channels` must be
    divisible by the number of groups. Default: 1.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
It can be a Tensor, a string, an Initializer or a number. When a string is specified,
values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
Initializer for more details. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
Initializer and string are the same as 'weight_init'. Refer to the values of
Initializer for more details. Default: 'zeros'.
has_bn (bool): Specifies whether to use batchnorm. Default: False.
momentum (float): Momentum for the moving average of batchnorm, must be in range [0, 1]. Default: 0.9.
eps (float): Term added to the denominator to improve numerical stability for batchnorm; should be greater
    than 0. Default: 1e-5.
activation (Union[str, Cell, Primitive]): Specifies the activation type. The optional values are as follows:
    'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
    'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
alpha (float): Slope of the activation function at x < 0 for LeakyReLU. Default: 0.2.
after_fake (bool): Determines whether there must be a fake quantization operation after Conv2dBnAct. Default: True.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Examples:
>>> net = nn.Conv2dBnAct(120, 240, 4, has_bn=True, activation='relu')
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> result = net(input)
>>> result.shape
(1, 240, 1024, 640)
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
has_bias=False,
weight_init='normal',
bias_init='zeros',
has_bn=False,
momentum=0.9,
eps=1e-5,
activation=None,
alpha=0.2,
after_fake=True):
super(Conv2dBnAct, self).__init__()
self.conv = nn.Conv2d(in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
pad_mode=pad_mode,
padding=padding,
dilation=dilation,
group=group,
has_bias=has_bias,
weight_init=weight_init,
bias_init=bias_init)
self.has_bn = Validator.check_bool(has_bn, "has_bn")
self.has_act = activation is not None
self.after_fake = Validator.check_bool(after_fake, "after_fake")
if has_bn:
self.batchnorm = BatchNorm2d(out_channels, eps, momentum)
if activation == "leakyrelu":
self.activation = LeakyReLU(alpha)
else:
self.activation = get_activation(activation) if isinstance(activation, str) else activation
if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
raise TypeError("The activation must be str or Cell or Primitive, but got {}.".format(activation))
def construct(self, x):
x = self.conv(x)
if self.has_bn:
x = self.batchnorm(x)
if self.has_act:
x = self.activation(x)
return x
class DenseBnAct(Cell):
r"""
A combination of Dense, Batchnorm, and activation layers.
This part is a more detailed overview of the Dense op.
Args:
in_channels (int): The number of channels in the input space.
out_channels (int): The number of channels in the output space.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
    is the same as the input. The values of str refer to the function `initializer`. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
    the same as the input. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
has_bn (bool): Specifies whether to use batchnorm. Default: False.
momentum (float): Momentum for the moving average of batchnorm, must be in range [0, 1]. Default: 0.9.
eps (float): Term added to the denominator to improve numerical stability for batchnorm; should be greater
    than 0. Default: 1e-5.
activation (Union[str, Cell, Primitive]): Specifies the activation type. The optional values are as follows:
    'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
    'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
alpha (float): Slope of the activation function at x < 0 for LeakyReLU. Default: 0.2.
after_fake (bool): Determines whether there must be a fake quantization operation after DenseBnAct. Default: True.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
Outputs:
Tensor of shape :math:`(N, out\_channels)`.
Examples:
>>> net = nn.DenseBnAct(3, 4)
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
>>> result = net(input)
>>> result.shape
(2, 4)
"""
def __init__(self,
in_channels,
out_channels,
weight_init='normal',
bias_init='zeros',
has_bias=True,
has_bn=False,
momentum=0.9,
eps=1e-5,
activation=None,
alpha=0.2,
after_fake=True):
super(DenseBnAct, self).__init__()
self.dense = nn.Dense(
in_channels,
out_channels,
weight_init,
bias_init,
has_bias)
self.has_bn = Validator.check_bool(has_bn, "has_bn")
self.has_act = activation is not None
self.after_fake = Validator.check_bool(after_fake, "after_fake")
if has_bn:
self.batchnorm = BatchNorm1d(out_channels, eps, momentum)
if activation == "leakyrelu":
self.activation = LeakyReLU(alpha)
else:
self.activation = get_activation(activation) if isinstance(activation, str) else activation
if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
raise TypeError("The activation must be str or Cell or Primitive, but got {}.".format(activation))
def construct(self, x):
x = self.dense(x)
if self.has_bn:
x = self.batchnorm(x)
if self.has_act:
x = self.activation(x)
return x

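The move changes where the cells live, not what they compute: Conv2dBnAct with has_bn=True and activation='relu' still runs Conv2d, then BatchNorm2d, then ReLU. A small self-contained check under that reading:

import numpy as np
import mindspore
from mindspore import Tensor, nn

x = Tensor(np.ones([1, 3, 32, 32]), mindspore.float32)

combined = nn.Conv2dBnAct(3, 8, 3, has_bn=True, activation='relu')
stacked = nn.SequentialCell([nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()])

# Same output shape; the combined cell additionally carries the has_bn,
# has_act and after_fake flags that the quantization passes read.
assert combined(x).shape == stacked(x).shape == (1, 8, 32, 32)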
@@ -17,7 +17,6 @@
from functools import partial
from collections import namedtuple
import numpy as np
-from mindspore import nn
import mindspore.common.dtype as mstype
from mindspore.ops.primitive import Primitive
from mindspore.ops import operations as P
@@ -28,14 +27,12 @@ from mindspore.common.tensor import Tensor
from mindspore._checkparam import Validator, Rel, twice
from mindspore.compression.common import QuantDtype
import mindspore.context as context
-from .normalization import BatchNorm2d, BatchNorm1d
-from .activation import get_activation, ReLU, LeakyReLU
+from .normalization import BatchNorm2d
+from .activation import get_activation, ReLU
from ..cell import Cell
from ...ops.operations import _quant_ops as Q
__all__ = [
-'Conv2dBnAct',
-'DenseBnAct',
'FakeQuantWithMinMaxObserver',
'Conv2dBnFoldQuant',
'Conv2dBnWithoutFoldQuant',
@@ -47,192 +44,6 @@ __all__ = [
]
class Conv2dBnAct(Cell):
r"""
A combination of convolution, Batchnorm, and activation layers.
This part is a more detailed overview of the Conv2d op.
Args:
in_channels (int): The number of input channels :math:`C_{in}`.
out_channels (int): The number of output channels :math:`C_{out}`.
kernel_size (Union[int, tuple]): The data type is int or a tuple of 2 integers. Specifies the height
and width of the 2D convolution window. Single int means the value is for both height and width of
the kernel. A tuple of 2 ints means the first value is for the height and the other is for the
width of the kernel.
stride (int): Specifies stride for all spatial dimensions with the same value. The value of stride must be
greater than or equal to 1 and lower than any one of the height and width of the input. Default: 1.
pad_mode (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
padding (int): Implicit paddings on both sides of the input. Default: 0.
dilation (int): Specifies the dilation rate to use for dilated convolution. If set to be :math:`k > 1`,
there will be :math:`k - 1` pixels skipped for each sampling location. Its value must be greater than
or equal to 1 and lower than any one of the height and width of the input. Default: 1.
group (int): Splits filter into groups, `in_channels` and `out_channels` must be
    divisible by the number of groups. Default: 1.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
It can be a Tensor, a string, an Initializer or a number. When a string is specified,
values from 'TruncatedNormal', 'Normal', 'Uniform', 'HeUniform' and 'XavierUniform' distributions as well
as constant 'One' and 'Zero' distributions are possible. Alias 'xavier_uniform', 'he_uniform', 'ones'
and 'zeros' are acceptable. Uppercase and lowercase are both acceptable. Refer to the values of
Initializer for more details. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Possible
Initializer and string are the same as 'weight_init'. Refer to the values of
Initializer for more details. Default: 'zeros'.
has_bn (bool): Specifies whether to use batchnorm. Default: False.
momentum (float): Momentum for the moving average of batchnorm, must be in range [0, 1]. Default: 0.9.
eps (float): Term added to the denominator to improve numerical stability for batchnorm; should be greater
    than 0. Default: 1e-5.
activation (Union[str, Cell, Primitive]): Specifies the activation type. The optional values are as follows:
    'softmax', 'logsoftmax', 'relu', 'relu6', 'tanh', 'gelu', 'sigmoid',
    'prelu', 'leakyrelu', 'hswish', 'hsigmoid'. Default: None.
alpha (float): Slope of the activation function at x < 0 for LeakyReLU. Default: 0.2.
after_fake (bool): Determines whether there must be a fake quantization operation after Conv2dBnAct. Default: True.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`.
Outputs:
Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`.
Examples:
>>> net = nn.Conv2dBnAct(120, 240, 4, has_bn=True, activation='relu')
>>> input = Tensor(np.ones([1, 120, 1024, 640]), mindspore.float32)
>>> result = net(input)
>>> result.shape
(1, 240, 1024, 640)
"""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
pad_mode='same',
padding=0,
dilation=1,
group=1,
has_bias=False,
weight_init='normal',
bias_init='zeros',
has_bn=False,
momentum=0.9,
eps=1e-5,
activation=None,
alpha=0.2,
after_fake=True):
super(Conv2dBnAct, self).__init__()
self.conv = nn.Conv2d(in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
pad_mode=pad_mode,
padding=padding,
dilation=dilation,
group=group,
has_bias=has_bias,
weight_init=weight_init,
bias_init=bias_init)
self.has_bn = Validator.check_bool(has_bn, "has_bn")
self.has_act = activation is not None
self.after_fake = Validator.check_bool(after_fake, "after_fake")
if has_bn:
self.batchnorm = BatchNorm2d(out_channels, eps, momentum)
if activation == "leakyrelu":
self.activation = LeakyReLU(alpha)
else:
self.activation = get_activation(activation) if isinstance(activation, str) else activation
if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
raise TypeError("The activation must be str or Cell or Primitive, but got {}.".format(activation))
def construct(self, x):
x = self.conv(x)
if self.has_bn:
x = self.batchnorm(x)
if self.has_act:
x = self.activation(x)
return x
class DenseBnAct(Cell):
r"""
A combination of Dense, Batchnorm, and activation layers.
This part is a more detailed overview of the Dense op.
Args:
in_channels (int): The number of channels in the input space.
out_channels (int): The number of channels in the output space.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
    is the same as the input. The values of str refer to the function `initializer`. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
    the same as the input. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
has_bn (bool): Specifies whether to use batchnorm. Default: False.
momentum (float): Momentum for the moving average of batchnorm, must be in range [0, 1]. Default: 0.9.
eps (float): Term added to the denominator to improve numerical stability for batchnorm; should be greater
    than 0. Default: 1e-5.
activation (Union[str, Cell, Primitive]): Specifies the activation type. The optional values are as follows:
    'Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid',
    'PReLU', 'LeakyReLU', 'h-Swish', and 'h-Sigmoid'. Default: None.
alpha (float): Slope of the activation function at x < 0 for LeakyReLU. Default: 0.2.
after_fake (bool): Determines whether there must be a fake quantization operation after DenseBnAct. Default: True.
Inputs:
- **input** (Tensor) - Tensor of shape :math:`(N, in\_channels)`.
Outputs:
Tensor of shape :math:`(N, out\_channels)`.
Examples:
>>> net = nn.DenseBnAct(3, 4)
>>> input = Tensor(np.random.randint(0, 255, [2, 3]), mindspore.float32)
>>> result = net(input)
>>> result.shape
(2, 4)
"""
def __init__(self,
in_channels,
out_channels,
weight_init='normal',
bias_init='zeros',
has_bias=True,
has_bn=False,
momentum=0.9,
eps=1e-5,
activation=None,
alpha=0.2,
after_fake=True):
super(DenseBnAct, self).__init__()
self.dense = nn.Dense(
in_channels,
out_channels,
weight_init,
bias_init,
has_bias)
self.has_bn = Validator.check_bool(has_bn, "has_bn")
self.has_act = activation is not None
self.after_fake = Validator.check_bool(after_fake, "after_fake")
if has_bn:
self.batchnorm = BatchNorm1d(out_channels, eps, momentum)
if activation == "leakyrelu":
self.activation = LeakyReLU(alpha)
self.activation = get_activation(activation) if isinstance(activation, str) else activation
if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation))
def construct(self, x):
x = self.dense(x)
if self.has_bn:
x = self.batchnorm(x)
if self.has_act:
x = self.activation(x)
return x
class BatchNormFoldCell(Cell):
"""
Batch normalization folded.
