diff --git a/mindspore/nn/layer/__init__.py b/mindspore/nn/layer/__init__.py index dae18fe663..aed6cb7776 100644 --- a/mindspore/nn/layer/__init__.py +++ b/mindspore/nn/layer/__init__.py @@ -17,7 +17,7 @@ Layer. The high-level components(Cells) used to construct the neural network. """ -from .activation import Softmax, LogSoftmax, ReLU, ReLU6, Tanh, GELU, ELU, Sigmoid, PReLU, get_activation, LeakyReLU +from .activation import Softmax, LogSoftmax, ReLU, ReLU6, Tanh, GELU, ELU, Sigmoid, PReLU, get_activation, LeakyReLU, HSigmoid, HSwish from .normalization import BatchNorm1d, BatchNorm2d, LayerNorm from .container import SequentialCell, CellList from .conv import Conv2d, Conv2dTranspose @@ -26,8 +26,9 @@ from .basic import Dropout, Flatten, Dense, ClipByNorm, Norm, OneHot, ImageGradi from .embedding import Embedding from .pooling import AvgPool2d, MaxPool2d -__all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid', 'PReLU', 'get_activation', 'LeakyReLU', - 'BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'ELU', +__all__ = ['Softmax', 'LogSoftmax', 'ReLU', 'ReLU6', 'Tanh', 'GELU', 'Sigmoid', + 'PReLU', 'get_activation', 'LeakyReLU', 'HSigmoid', 'HSwish', 'ELU', + 'BatchNorm1d', 'BatchNorm2d', 'LayerNorm', 'SequentialCell', 'CellList', 'Conv2d', 'Conv2dTranspose', 'LSTM', diff --git a/mindspore/nn/layer/_quant.py b/mindspore/nn/layer/_quant.py new file mode 100644 index 0000000000..f27af8b269 --- /dev/null +++ b/mindspore/nn/layer/_quant.py @@ -0,0 +1,703 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Aware quantization.""" + +import numpy as np +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore.ops import operations as P +from mindspore.ops import functional as F +from mindspore.common.parameter import Parameter +from mindspore.common.initializer import initializer +from mindspore.common.tensor import Tensor +from mindspore._checkparam import check_int_positive, check_bool, twice +from mindspore.nn.cell import Cell +from mindspore.nn.layer.conv import _Conv +from mindspore.nn.layer.activation import get_activation + +__all__ = [ + 'FakeQuantWithMinMax', + 'Conv2dBatchNormQuant', + 'Conv2dQuant', + 'DenseQuant', + 'ReLUQuant', + 'ReLU6Quant', + 'HSwishQuant', + 'HSigmoidQuant', + 'TensorAddQuant', +] + + +class FakeQuantWithMinMax(Cell): + r""" + Aware Quantization training op. This OP provide Fake quantization observer function on data with min and max. + + Args: + min_init (int, list): The dimension of channel or 1(layer). Default: -6. + max_init (int, list): The dimension of channel or 1(layer). Default: 6. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. + ema (bool): Exponential Moving Average algorithm update min and max. Default: False. + ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.9999. + per_channel (bool): Quantization by layer or channel. 
Default: False. + channel_size (int): declarate the min and max channel size, Default: 1. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. + symmetric (bool): Quantization algorithm use symmetric or not. Default: False. + narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + + Inputs: + - **x** (Tensor) - The input of FakeQuantWithMinMax. + + Outputs: + Tensor, with the same type and shape as the `x`. + + """ + + def __init__(self, + min_init=-6, + max_init=6, + num_bits=8, + ema=False, + ema_decay=0.999, + per_channel=False, + channel_size=1, + quant_delay=0, + symmetric=False, + narrow_range=False): + super(FakeQuantWithMinMax, self).__init__() + + self.min_init = min_init + self.num_bits = num_bits + self.max_init = max_init + self.ema = ema + self.ema_decay = ema_decay + self.per_channel = per_channel + self.channel_size = channel_size + self.quant_delay = quant_delay + self.symmetric = symmetric + self.narrow_range = narrow_range + + if per_channel: + min_array = np.array([self.min_init for i in range( + 0, self.channel_size)]).astype(np.float32) + max_array = np.array([self.max_init for i in range( + 0, self.channel_size)]).astype(np.float32) + self.fake_quant_train = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits, + ema=self.ema, + ema_decay=self.ema_decay, + quant_delay=self.quant_delay, + symmetric=self.symmetric, + narrow_range=self.narrow_range, + training=True) + self.fake_quant_infer = P.FakeQuantWithMinMaxPerChannel(num_bits=self.num_bits, + ema=self.ema, + ema_decay=ema_decay, + quant_delay=quant_delay, + symmetric=self.symmetric, + narrow_range=self.narrow_range, + training=False) + else: + min_array = np.array([min_init]).reshape(1).astype(np.float32) + max_array = np.array([max_init]).reshape(1).astype(np.float32) + self.fake_quant_train = P.FakeQuantWithMinMax(num_bits=self.num_bits, + ema=self.ema, + ema_decay=self.ema_decay, + quant_delay=self.quant_delay, + symmetric=self.symmetric, + narrow_range=self.narrow_range, + training=True) + self.fake_quant_infer = P.FakeQuantWithMinMax(num_bits=self.num_bits, + ema=self.ema, + ema_decay=ema_decay, + quant_delay=quant_delay, + symmetric=self.symmetric, + narrow_range=self.narrow_range, + training=False) + + self.min = Parameter( + Tensor(min_array), name='quant_min', requires_grad=False) + self.max = Parameter( + Tensor(max_array), name='quant_max', requires_grad=False) + + def extend_repr(self): + s = 'min_init={}, max_init={}, ema={}, ema_decay={}, per_channel={}, channel_size={}, quant_delay={}'.format( + self.min_init, self.max_init, self.ema, self.ema_decay, self.per_channel, self.channel_size, + self.quant_delay) + return s + + def construct(self, x): + if self.training: + out = self.fake_quant_train(x, self.min, self.max) + else: + out = self.fake_quant_infer(x, self.min, self.max) + return out + + +class Conv2dBatchNormQuant(Cell): + r""" + 2D convolution with BatchNormal op folded layer. + + For a more Detailed overview of Conv2d op. + + Args: + in_channels (int): The number of input channel :math:`C_{in}`. + out_channels (int): The number of output channel :math:`C_{out}`. + kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window. + stride (int): Specifies stride for all spatial dimensions with the same value. + pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". + padding: (int): Implicit paddings on both sides of the input. Default: 0. 
+ eps (int): Parameters for BatchNormal. Default: 1e-5. + momentum (int): Parameters for BatchNormal op. Default: 0.9. + weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the + convolution kernel. Default: 'None'. + beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the + beta vector. Default: 'None'. + gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the + gamma vector. Default: 'None'. + mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the + mean vector. Default: 'None'. + var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the + variance vector. Default: 'None'. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. + freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000. + fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. + per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. + symmetric (bool): Quantization algorithm use symmetric or not. Default: False. + narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + + Inputs: + - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. + + Outputs: + Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + pad_mode, + padding=0, + eps=1e-5, + momentum=0.9, + weight_init=None, + beta_init=None, + gamma_init=None, + mean_init=None, + var_init=None, + group=1, + quant_delay=0, + freeze_bn=100000, + fake=True, + num_bits=8, + per_channel=False, + symmetric=False, + narrow_range=False): + super(Conv2dBatchNormQuant, self).__init__() + self.stride = stride + self.conv = P.Conv2D(out_channel=out_channels, + kernel_size=kernel_size, + mode=1, + pad_mode=pad_mode, + pad=padding, + stride=stride, + dilation=1, + group=group) + self.fake = fake + self.freeze_bn = freeze_bn + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + + if weight_init is None: + weight_init = initializer( + 'normal', [out_channels, in_channels // group, *kernel_size]) + self.weight = Parameter(weight_init, name='weight') + if gamma_init is None: + gamma_init = initializer('ones', [out_channels]) + self.gamma = Parameter(gamma_init, name='gamma') + if beta_init is None: + beta_init = initializer('zeros', [out_channels]) + self.beta = Parameter(beta_init, name='beta') + if mean_init is None: + mean_init = initializer('zeros', [out_channels]) + self.moving_mean = Parameter( + mean_init, name='moving_mean', requires_grad=False) + if var_init is None: + var_init = initializer('ones', [out_channels]) + self.moving_variance = Parameter( + var_init, name='moving_variance', requires_grad=False) + + self.step = Parameter(initializer( + 'normal', [1], dtype=mstype.int32), name='step', requires_grad=False) + + self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6, + max_init=6, + ema=False, + num_bits=num_bits, + quant_delay=quant_delay, + per_channel=per_channel, + channel_size=out_channels, + symmetric=symmetric, + narrow_range=narrow_range) + + self.batchnorm_fold_train = P.BatchNormFold(epsilon=eps, + momentum=momentum, + is_training=True, + freeze_bn=freeze_bn) + self.batchnorm_fold_infer = P.BatchNormFold(epsilon=eps, + momentum=momentum, + is_training=False, + freeze_bn=freeze_bn) + 
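+        # BN-folding helpers: CorrectionMul rescales the convolution weight by
+        # gamma / running_std, and BatchNormFold2 folds the BatchNorm scale and
+        # bias correction into the convolution output (see construct below).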
self.correct_mul = P.CorrectionMul() + self.relu = P.ReLU() + self.batchnorm_fold2 = P.BatchNormFold2(freeze_bn=freeze_bn) + self.batchnorm_fold2_infer = P.BatchNormFold2(freeze_bn=0) + self.one = Tensor(1, mstype.int32) + self.assignadd = P.AssignAdd() + + def extend_repr(self): + s = 'fake={}, freeze_bn={}'.format(self.fake, self.freeze_bn) + return s + + def construct(self, x): + if self.training: + beta = self.beta + gamma = self.gamma + gmean = self.moving_mean + gvar = self.moving_variance + step = self.step + out_conv = self.conv(x, self.weight) + batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold_train( + out_conv, gmean, gvar, step) + # BN fold1 + weight = self.correct_mul(self.weight, gamma, running_std) + if self.fake: + weight = self.fake_quant_weight(weight) + out = self.conv(x, weight) + # BN fold2 + out = self.batchnorm_fold2( + out, beta, gamma, batch_std, batch_mean, running_std, running_mean, step) + F.control_depend(out, self.assignadd(self.step, self.one)) + else: + step = self.step + out_conv = self.conv(x, self.weight) + batch_mean, batch_std, running_mean, running_std = self.batchnorm_fold_infer( + out_conv, self.moving_mean, self.moving_variance, step) + weight = self.correct_mul(self.weight, self.gamma, running_std) + if self.fake: + weight = self.fake_quant_weight(weight) + out = self.conv(x, weight) + out = self.batchnorm_fold2_infer(out, self.beta, self.gamma, batch_std, batch_mean, + running_std, running_mean, step) + return out + + +class Conv2dQuant(_Conv): + r""" + 2D convolution with fake quant op layer. + + For a more Detailed overview of Conv2d op. + + Args: + in_channels (int): The number of input channel :math:`C_{in}`. + out_channels (int): The number of output channel :math:`C_{out}`. + kernel_size (Union[int, tuple]): Specifies the height and width of the 2D convolution window. + stride (int): Specifies stride for all spatial dimensions with the same value. Default: 1. + pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". + padding: (int): Implicit paddings on both sides of the input. Default: 0. + dilation (int): Specifying the dilation rate to use for dilated convolution. Default: 1. + group (int): Split filter into groups, `in_ channels` and `out_channels` should be + divisible by the number of groups. Default: 1. + has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. + weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. + Default: 'normal'. + bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. + per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. + symmetric (bool): Quantization algorithm use symmetric or not. Default: False. + narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + + Inputs: + - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. + + Outputs: + Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. 
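+
+    Examples:
+        >>> # Illustrative sketch; shapes and quantization settings are arbitrary.
+        >>> conv = Conv2dQuant(1, 6, kernel_size=(2, 2), stride=1, pad_mode="valid")
+        >>> x = Tensor(np.random.rand(1, 1, 5, 5), mstype.float32)
+        >>> result = conv(x)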
+ + """ + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + pad_mode='same', + padding=0, + dilation=1, + group=1, + has_bias=False, + weight_init='normal', + bias_init='zeros', + quant_delay=0, + num_bits=8, + per_channel=False, + symmetric=False, + narrow_range=False): + kernel_size = twice(kernel_size) + super(Conv2dQuant, self).__init__(in_channels, out_channels, kernel_size, stride, pad_mode, padding, dilation, + group, has_bias, weight_init, bias_init) + self.conv2d = P.Conv2D(out_channel=self.out_channels, kernel_size=self.kernel_size, mode=1, + pad_mode=self.pad_mode, pad=self.padding, stride=self.stride, dilation=self.dilation, + group=self.group) + self.bias_add = P.BiasAdd() + if pad_mode not in ('valid', 'same', 'pad'): + raise ValueError('Attr \'pad_mode\' of \'Conv2d\' Op passed ' + + str(pad_mode) + ', should be one of values in \'valid\', \'same\', \'pad\'.') + self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6, + max_init=6, + ema=False, + num_bits=num_bits, + quant_delay=quant_delay, + per_channel=per_channel, + channel_size=out_channels, + symmetric=symmetric, + narrow_range=narrow_range) + + def construct(self, x): + weight_q = self.fake_quant_weight(self.weight) + out = self.conv2d(x, weight_q) + if self.has_bias: + return self.bias_add(out, self.bias) + return out + + +class DenseQuant(Cell): + r""" + The fully connected layer with fake quant op. + + For a more Detailed overview of Dense op. + + Args: + in_channels (int): The dimension of the input space. + out_channels (int): The dimension of the output space. + weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype + is same as input x. The values of str refer to the function `initializer`. Default: 'normal'. + bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is + same as input x. The values of str refer to the function `initializer`. Default: 'zeros'. + has_bias (bool): Specifies whether the layer uses a bias vector. Default: True. + activation (str): Regularizer function applied to the output of the layer, eg. 'relu'. Default: None. + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. + per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. + symmetric (bool): Quantization algorithm use symmetric or not. Default: False. + narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + + Inputs: + - **x** (Tensor) - Tensor of shape :math:`(N, C_{in}, H_{in}, W_{in})`. + + Outputs: + Tensor of shape :math:`(N, C_{out}, H_{out}, W_{out})`. 
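+
+    Examples:
+        >>> # Illustrative sketch with a 2-D input of shape (batch, in_channels).
+        >>> dense = DenseQuant(3, 4)
+        >>> x = Tensor(np.random.rand(2, 3), mstype.float32)
+        >>> result = dense(x)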
+ """ + + def __init__( + self, + in_channels, + out_channels, + weight_init='normal', + bias_init='zeros', + has_bias=True, + activation=None, + num_bits=8, + quant_delay=0, + per_channel=False, + symmetric=False, + narrow_range=False): + super(DenseQuant, self).__init__() + self.in_channels = check_int_positive(in_channels) + self.out_channels = check_int_positive(out_channels) + self.has_bias = check_bool(has_bias) + + if isinstance(weight_init, Tensor): + if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \ + weight_init.shape()[1] != in_channels: + raise ValueError("weight_init shape error") + + self.weight = Parameter(initializer( + weight_init, [out_channels, in_channels]), name="weight") + + if self.has_bias: + if isinstance(bias_init, Tensor): + if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels: + raise ValueError("bias_init shape error") + + self.bias = Parameter(initializer( + bias_init, [out_channels]), name="bias") + + self.matmul = P.MatMul(transpose_b=True) + self.bias_add = P.BiasAdd() + + self.activation = get_activation(activation) + self.activation_flag = self.activation is not None + self.fake_quant_weight = nn.FakeQuantWithMinMax(min_init=-6, + max_init=6, + ema=False, + num_bits=num_bits, + quant_delay=quant_delay, + per_channel=per_channel, + channel_size=out_channels, + symmetric=symmetric, + narrow_range=narrow_range) + + def construct(self, x): + """Use operators to construct to Dense layer.""" + output = self.fake_quant_weight(self.weight) + output = self.matmul(x, output) + if self.has_bias: + output = self.bias_add(output, self.bias) + if self.activation_flag: + return self.activation(output) + return output + + def extend_repr(self): + """A pretty print for Dense layer.""" + str_info = 'in_channels={}, out_channels={}, weight={}, has_bias={}'.format( + self.in_channels, self.out_channels, self.weight, self.has_bias) + if self.has_bias: + str_info = str_info + ', bias={}'.format(self.bias) + if self.activation_flag: + str_info = str_info + ', activation={}'.format(self.activation) + + return str_info + + +class ReLUQuant(Cell): + r""" + ReLUQuant activation function. Add Fake Quant OP after Relu OP. + + For a more Detailed overview of ReLU op. + + Args: + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. + symmetric (bool): Quantization algorithm use symmetric or not. Default: False. + narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + + Inputs: + - **x** (Tensor) - The input of ReLUQuant. + + Outputs: + Tensor, with the same type and shape as the `x`. + + """ + + def __init__(self, + num_bits=8, + quant_delay=0, + symmetric=False, + narrow_range=False): + super(ReLUQuant, self).__init__() + self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=0, + max_init=6, + num_bits=num_bits, + quant_delay=quant_delay, + ema=True, + symmetric=symmetric, + narrow_range=narrow_range) + self.relu = P.ReLU() + + def construct(self, x): + x = self.relu(x) + x = self.fake_quant_act(x) + return x + + +class ReLU6Quant(Cell): + r""" + ReLU6Quant activation function. + + Add Fake Quant OP after Relu6. Not Recommand to used these cell for Fake Quant Op + Will climp the max range of the activation and the relu6 do the same operation. + For a more Detailed overview of ReLU6 op. + + Args: + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. 
+ quant_delay (int): Quantization delay parameters according by global step. Default: 0. + symmetric (bool): Quantization algorithm use symmetric or not. Default: False. + narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + + Inputs: + - **x** (Tensor) - The input of ReLU6Quant. + + Outputs: + Tensor, with the same type and shape as the `x`. + + """ + + def __init__(self, num_bits=8, quant_delay=0, symmetric=False, + narrow_range=False): + super(ReLU6Quant, self).__init__() + self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=0, + max_init=6, + num_bits=num_bits, + quant_delay=quant_delay, + ema=True, + symmetric=symmetric, + narrow_range=narrow_range) + self.relu6 = P.ReLU6() + + def construct(self, x): + x = self.relu6(x) + x = self.fake_quant_act(x) + return x + + +class HSwishQuant(Cell): + r""" + HSwishQuant activation function. Add Fake Quant OP after HSwish OP. + + For a more Detailed overview of HSwish op. + + Args: + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. + symmetric (bool): Quantization algorithm use symmetric or not. Default: False. + narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + + Inputs: + - **x** (Tensor) - The input of HSwishQuant. + + Outputs: + Tensor, with the same type and shape as the `x`. + + """ + + def __init__(self, + num_bits=8, + quant_delay=0, + symmetric=False, + narrow_range=False): + super(HSwishQuant, self).__init__() + self.fake_quant_act_before = nn.FakeQuantWithMinMax(min_init=0, + max_init=6, + num_bits=num_bits, + quant_delay=quant_delay, + ema=True, + symmetric=symmetric, + narrow_range=narrow_range) + self.fake_quant_act_after = nn.FakeQuantWithMinMax(min_init=0, + max_init=6, + num_bits=num_bits, + quant_delay=quant_delay, + ema=True, + symmetric=symmetric, + narrow_range=narrow_range) + self.act = P.HSwish() + + def construct(self, x): + x = self.fake_quant_act_before(x) + x = self.act(x) + x = self.fake_quant_act_after(x) + return x + + +class HSigmoidQuant(Cell): + r""" + HSigmoidQuant activation function. Add Fake Quant OP before and after HSigmoid OP. + + For a more Detailed overview of HSigmoid op. + + Args: + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. + symmetric (bool): Quantization algorithm use symmetric or not. Default: False. + narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + + Inputs: + - **x** (Tensor) - The input of HSigmoidQuant. + + Outputs: + Tensor, with the same type and shape as the `x`. + + """ + + def __init__(self, + num_bits=8, + quant_delay=0, + symmetric=False, + narrow_range=False): + super(HSigmoidQuant, self).__init__() + self.fake_quant_act_before = nn.FakeQuantWithMinMax(min_init=0, + max_init=6, + num_bits=num_bits, + quant_delay=quant_delay, + ema=True, + symmetric=symmetric, + narrow_range=narrow_range) + self.fake_quant_act_after = nn.FakeQuantWithMinMax(min_init=0, + max_init=6, + num_bits=num_bits, + quant_delay=quant_delay, + ema=True, + symmetric=symmetric, + narrow_range=narrow_range) + self.act = P.HSigmoid() + + def construct(self, x): + x = self.fake_quant_act_before(x) + x = self.act(x) + x = self.fake_quant_act_after(x) + return x + + +class TensorAddQuant(Cell): + r""" + Add Fake Quant OP after TensorAdd OP. 
+ + For a more Detailed overview of TensorAdd op. + + Args: + num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. + quant_delay (int): Quantization delay parameters according by global step. Default: 0. + symmetric (bool): Quantization algorithm use symmetric or not. Default: False. + narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + + Inputs: + - **x** (Tensor) - The input of TensorAddQuant. + + Outputs: + Tensor, with the same type and shape as the `x`. + + """ + + def __init__(self, + num_bits=8, + quant_delay=0, + symmetric=False, + narrow_range=False): + super(TensorAddQuant, self).__init__() + self.fake_quant_act = nn.FakeQuantWithMinMax(min_init=-6, + max_init=6, + num_bits=num_bits, + quant_delay=quant_delay, + ema=True, + symmetric=symmetric, + narrow_range=narrow_range) + self.add = P.TensorAdd() + + def construct(self, x1, x2): + x = self.add(x1, x2) + x = self.fake_quant_act(x) + return x diff --git a/mindspore/nn/layer/activation.py b/mindspore/nn/layer/activation.py index ad63dde8bc..12d6c74dcd 100644 --- a/mindspore/nn/layer/activation.py +++ b/mindspore/nn/layer/activation.py @@ -234,7 +234,7 @@ class Tanh(Cell): class GELU(Cell): - """ + r""" Gaussian error linear unit activation function. Applies GELU function to each element of the input. The input is a Tensor with any valid shape. @@ -332,15 +332,74 @@ class PReLU(Cell): return v +class HSwish(Cell): + r""" + rHard swish activation function. + + Applies hswish-type activation element-wise. The input is a Tensor with any valid shape. + + Hard swish is defined as: + + .. math:: + \text{hswish}(x_{i}) = x_{i} * \frac{ReLU6(x_{i} + 3)}{6}, + + where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor. + + Inputs: + - **input_data** (Tensor) - The input of Hswish. + + Outputs: + Tensor, with the same type and shape as the `input_data`. + + """ + def __init__(self): + super(HSwish, self).__init__() + self.hswish = P.HSwish() + + def construct(self, x): + return self.hswish(x) + + +class HSigmoid(Cell): + r""" + Hard sigmoid activation function. + + Applies hard sigmoid activation element-wise. The input is a Tensor with any valid shape. + + Hard sigmoid is defined as: + + .. math:: + \text{hsigmoid}(x_{i}) = max(0, min(1, \ftac{2 * x_{i} + 5}{10})), + + where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor. + + Inputs: + - **input_data** (Tensor) - The input of HSigmoid. + + Outputs: + Tensor, with the same type and shape as the `input_data`. 
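+
+    Examples:
+        >>> # Illustrative usage; any float16/float32 Tensor shape is valid.
+        >>> hsigmoid = HSigmoid()
+        >>> x = Tensor(np.array([-1.0, -2.0, 0.0, 2.0, 1.0]), mstype.float32)
+        >>> result = hsigmoid(x)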
+ + """ + def __init__(self): + super(HSigmoid, self).__init__() + self.hsigmoid = P.HSigmoid() + + def construct(self, x): + return self.hsigmoid(x) + + _activation = { 'softmax': Softmax, 'logsoftmax': LogSoftmax, 'relu': ReLU, + 'relu6': ReLU6, 'tanh': Tanh, 'gelu': GELU, 'sigmoid': Sigmoid, 'prelu': PReLU, - 'leakyrelu': LeakyReLU + 'leakyrelu': LeakyReLU, + 'hswish': HSwish, + 'hsigmoid': HSigmoid, } diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index fbe48aff97..1b18d9f248 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -172,6 +172,28 @@ def get_bprop_relu6(self): return bprop +@bprop_getters.register(P.HSwish) +def get_bprop_hswish(self): + """Grad definition for `HSwish` operation.""" + input_grad = G.HSwishGrad() + + def bprop(x, out, dout): + dx = input_grad(dout, x) + return (dx,) + return bprop + + +@bprop_getters.register(P.HSigmoid) +def get_bprop_hsigmoid(self): + """Grad definition for `HSigmoid` operation.""" + input_grad = G.HSigmoidGrad() + + def bprop(x, out, dout): + dx = input_grad(dout, x) + return (dx,) + return bprop + + @bprop_getters.register(P.Elu) def get_bprop_elu(self): """Grad definition for `Elu` operation.""" diff --git a/mindspore/ops/_grad/grad_quant_ops.py b/mindspore/ops/_grad/grad_quant_ops.py new file mode 100644 index 0000000000..5d4ad22392 --- /dev/null +++ b/mindspore/ops/_grad/grad_quant_ops.py @@ -0,0 +1,82 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Generate bprop for aware quantization ops""" + +from .. 
import operations as P +from .grad_base import bprop_getters +from ..composite.multitype_ops.zeros_like_impl import zeros_like + + +@bprop_getters.register(P.FakeQuantWithMinMax) +def get_bprop_fakequant_with_minmax(self): + """Generate bprop for FakeQuantWithMinMax""" + op = P.FakeQuantWithMinMaxGrad(num_bits=self.num_bits, quant_delay=self.quant_delay) + + def bprop(x, x_min, x_max, out, dout): + dx = op(dout, x, x_min, x_max) + return dx, zeros_like(x_min), zeros_like(x_max) + + return bprop + + +@bprop_getters.register(P.FakeQuantWithMinMaxPerChannel) +def get_bprop_fakequant_with_minmax_perchannel(self): + """Generate bprop for FakeQuantWithMinMaxPerChannel""" + op = P.FakeQuantWithMinMaxPerChannelGrad(num_bits=self.num_bits, quant_delay=self.quant_delay) + + def bprop(x, x_min, x_max, out, dout): + dx = op(dout, x, x_min, x_max) + return dx, zeros_like(x_min), zeros_like(x_max) + + return bprop + + +@bprop_getters.register(P.BatchNormFold) +def get_bprop_batchnorm_fold(self): + """Generate bprop for BatchNormFold""" + op = P.BatchNormFoldGrad(self.epsilon, self.is_training, self.freeze_bn) + + def bprop(x, mean, variance, global_step, out, dout): + dx = op(dout[0], dout[1], x, out[0], out[1], global_step) + return dx, zeros_like(mean), zeros_like(variance), zeros_like(global_step) + + return bprop + + +@bprop_getters.register(P.CorrectionMul) +def get_bprop_correction_mul(self): + """Generate bprop for CorrectionMul""" + grad = P.CorrectionMulGrad() + + def bprop(x, batch_std, running_std, out, dout): + dx, d_batch_std = grad(dout, x, batch_std, running_std) + return dx, d_batch_std, zeros_like(running_std) + + return bprop + + +@bprop_getters.register(P.BatchNormFold2) +def get_bprop_batchnorm_fold2(self): + """Generate bprop for CorrectionAdd""" + op_f = P.BatchNormFold2Grad(freeze_bn=self.freeze_bn) + + def bprop(x, beta, gamma, batch_std, batch_mean, running_std, running_mean, global_step, out, dout): + d_batch_std, d_batch_mean, d_beta, d_gamma, d_x = op_f(dout, x, gamma, batch_std, batch_mean, running_std, + running_mean, global_step) + return d_x, d_beta, d_gamma, d_batch_std, d_batch_mean, zeros_like(running_std), zeros_like(running_mean), \ + zeros_like(global_step) + + return bprop diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py index 45cd856298..8bfca77b38 100644 --- a/mindspore/ops/operations/__init__.py +++ b/mindspore/ops/operations/__init__.py @@ -59,7 +59,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, LogSoftmax, MaxPool, AvgPool, Conv2DBackpropInput, - MaxPoolWithArgmax, OneHot, Pad, PReLU, ReLU, ReLU6, + MaxPoolWithArgmax, OneHot, Pad, PReLU, ReLU, ReLU6, HSwish, HSigmoid, ResizeBilinear, Sigmoid, SigmoidCrossEntropyWithLogits, SmoothL1Loss, Softmax, @@ -68,7 +68,8 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm, TopK, BinaryCrossEntropy, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, ApplyRMSProp, ApplyCenteredRMSProp) from .other_ops import Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, CheckValid, MakeRefKey - +from . 
import _quant_ops +from ._quant_ops import * __all__ = [ 'TensorAdd', @@ -138,6 +139,8 @@ __all__ = [ 'ReLU6', 'Elu', 'Sigmoid', + 'HSwish', + 'HSigmoid', 'Tanh', 'RandomChoiceWithMask', 'ResizeBilinear', @@ -241,4 +244,5 @@ __all__ = [ "ApplyCenteredRMSProp" ] +__all__.extend(_quant_ops.__all__) __all__.sort() diff --git a/mindspore/ops/operations/_grad_ops.py b/mindspore/ops/operations/_grad_ops.py index f38044ab6a..f0a9a2f658 100644 --- a/mindspore/ops/operations/_grad_ops.py +++ b/mindspore/ops/operations/_grad_ops.py @@ -805,6 +805,38 @@ class SigmoidGrad(PrimitiveWithInfer): return out +class HSigmoidGrad(PrimitiveWithInfer): + """Gets the gradient of HSigmoid operation.""" + + @prim_attr_register + def __init__(self): + self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output']) + + def infer_shape(self, y_grad_shape, x_shape): + return x_shape + + def infer_dtype(self, y_grad_dtype, x_dtype): + validator.check_typename("y_grad dtype", y_grad_dtype, (mstype.float16, mstype.float32)) + validator.check_typename("x dtype", x_dtype, (mstype.float16, mstype.float32)) + return x_dtype + + +class HSwishGrad(PrimitiveWithInfer): + """Gets the gradient of HSwish operation.""" + + @prim_attr_register + def __init__(self): + self.init_prim_io_names(inputs=['y_grad', 'x'], outputs=['output']) + + def infer_shape(self, y_grad_shape, x_shape): + return x_shape + + def infer_dtype(self, y_grad_dtype, x_dtype): + validator.check_typename("y_grad dtype", y_grad_dtype, (mstype.float16, mstype.float32)) + validator.check_typename("x_ dtype", x_dtype, (mstype.float16, mstype.float32)) + return x_dtype + + class SigmoidCrossEntropyWithLogitsGrad(PrimitiveWithInfer): """Computes the gradients of `SigmoidCrossEntropyWithLogits`.""" diff --git a/mindspore/ops/operations/_quant_ops.py b/mindspore/ops/operations/_quant_ops.py new file mode 100644 index 0000000000..14d1bc9234 --- /dev/null +++ b/mindspore/ops/operations/_quant_ops.py @@ -0,0 +1,525 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0(the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http: // www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Operators for quantization.""" + +from ..._checkparam import ParamValidator as validator +from ..._checkparam import Rel, check_bool, check_int_positive, check_int +from ..primitive import PrimitiveWithInfer, prim_attr_register +from ...common import dtype as mstype + +__all__ = ["FakeQuantWithMinMax", + "FakeQuantWithMinMaxGrad", + "FakeQuantWithMinMaxPerChannel", + "FakeQuantWithMinMaxPerChannelGrad", + "BatchNormFold", + "BatchNormFoldGrad", + "CorrectionMul", + "CorrectionMulGrad", + "BatchNormFold2", + "BatchNormFold2Grad", + ] + + +class FakeQuantWithMinMax(PrimitiveWithInfer): + r""" + Simulate the quantize and dequantize operations in training time. + + Args: + num_bits (int) : Number bits for aware quantilization. Default: 8. + ema (bool): Use EMA algorithm update value min and max. Default: False. + ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. 
+ quant_delay (int): Quantilization delay parameter. Before delay step in training time not update + simulate aware quantize funcion. After delay step in training time begin simulate the aware + quantize funcion. Default: 0. + symmetric (bool): Quantization algorithm use symmetric or not. Default: False. + narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + training (bool): Training the network or not. Default: True. + + Inputs: + - **x** (Tensor) : float32 Tensor representing the shape of the output tensor. + - **min** (Tensor) : Value of the min range of the input data x. + - **max** (Tensor) : Value of the max range of the input data x. + + Outputs: + - Tensor: Simulate quantize tensor of x. + + Examples: + >>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32) + >>> min_tensor = Tensor(np.array([-6]), mstype.float32) + >>> max_tensor = Tensor(np.array([6]), mstype.float32) + >>> output_tensor = P.FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor) + """ + support_quant_bit = [4, 7, 8] + + @prim_attr_register + def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False, + training=True): + """init FakeQuantWithMinMax OP""" + if num_bits not in self.support_quant_bit: + raise ValueError("Attr \'num_bits\' is not support.") + if ema and not ema_decay: + raise ValueError( + "Attr \'ema\' and \'ema_decay\' should set together.") + + self.ema = check_bool(ema) + self.symmetric = check_bool(symmetric) + self.narrow_range = check_bool(narrow_range) + self.training = check_bool(training) + self.ema_decay = validator.check_number_range( + 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH) + self.num_bits = check_int_positive(num_bits) + self.quant_delay = check_int(quant_delay) + self.init_prim_io_names(inputs=['x', 'min', 'max'], + outputs=['out']) + + def infer_shape(self, x_shape, min_shape, max_shape): + validator.check_integer("x shape", len(x_shape), 1, Rel.GT) + validator.check("min shape", min_shape, "max shape", max_shape) + validator.check_integer("min shape", len(min_shape), 1, Rel.EQ) + validator.check_integer("max shape", len(min_shape), 1, Rel.EQ) + return x_shape + + def infer_dtype(self, x_type, min_type, max_type): + validator.check_typename( + "x type", x_type, (mstype.float16, mstype.float32)) + validator.check_typename("min type", min_type, + (mstype.float16, mstype.float32)) + validator.check_typename("max type", max_type, + (mstype.float16, mstype.float32)) + return x_type + + +class FakeQuantWithMinMaxGrad(PrimitiveWithInfer): + """Performs grad of FakeQuantWithMinMax operation.""" + support_quant_bit = [4, 8] + + @prim_attr_register + def __init__(self, num_bits=8, quant_delay=0): + if num_bits not in self.support_quant_bit: + raise ValueError("Attr \'num_bits\' is not support.") + + self.quant_delay = check_int(quant_delay) + self.num_bits = check_int_positive(num_bits) + self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], + outputs=['dx']) + + def infer_shape(self, dout_shape, x_shape, min_shape, max_shape): + validator.check("dout shape", dout_shape, "x shape", x_shape) + validator.check("min shape", min_shape, "max shape", max_shape) + validator.check_integer("min shape", len(min_shape), 1, Rel.EQ) + validator.check_integer("max shape", len(min_shape), 1, Rel.EQ) + return dout_shape + + def infer_dtype(self, dout_type, x_type, min_type, max_type): + validator.check_typename( + "dout type", dout_type, (mstype.float16, mstype.float32)) + 
validator.check_typename( + "x type", x_type, (mstype.float16, mstype.float32)) + validator.check_typename("min type", min_type, + (mstype.float16, mstype.float32)) + validator.check_typename("max type", max_type, + (mstype.float16, mstype.float32)) + return dout_type + + +class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer): + r""" + Simulate the quantize and dequantize operations in training time base on per channel. + + Args: + num_bits (int) : Number bits to quantilization. Default: 8. + ema (bool): Use EMA algorithm update tensor min and tensor max. Default: False. + ema_decay (int) : EMA algorithm decay parameter. Default: 0.999. + quant_delay (int): Quantilization delay parameter. Before delay step in training time not + update the weight data to simulate quantize operation. After delay step in training time + begin simulate the quantize operation. Default: 0. + symmetric (bool): Quantization algorithm use symmetric or not. Default: False. + narrow_range (bool): Quantization algorithm use narrow range or not. Default: False. + training (bool): Training the network or not. Default: True. + + Inputs: + - **x** (Tensor) : 4-D float32 Tensor representing the shape of the output tensor. + - **min** (int, float) : Value of the min range of the input data. + - **max** (int, float) : Value of the max range of the input data. + + Outputs: + - Tensor, has the same type as input. + + Examples: + >>> input_tensor = Tensor(np.random.rand(3,4,5,5), mstype.float32) + >>> min_tensor = Tensor(np.array([-6.0, -6.5, -4.0, -5.0]), mstype.float32) + >>> max_tensor = Tensor(np.array([6.0, 6.5, 4.0, 5.0]), mstype.float32) + >>> output_tensor = P.FakeQuantWithMinMax(num_bits=8)(input_tensor, min_tensor, max_tensor) + """ + support_quant_bit = [4, 8] + channel_idx = 0 + + @prim_attr_register + def __init__(self, num_bits=8, ema=False, ema_decay=0.999, quant_delay=0, symmetric=False, narrow_range=False, + training=True): + """init FakeQuantWithMinMaxPerChannel OP""" + if num_bits not in self.support_quant_bit: + raise ValueError("Attr \'num_bits\' is not support.") + if ema and not ema_decay: + raise ValueError( + "Attr \'ema\' and \'ema_decay\' should set together.") + + self.ema = check_bool(ema) + self.symmetric = check_bool(symmetric) + self.narrow_range = check_bool(narrow_range) + self.training = check_bool(training) + self.ema_decay = validator.check_number_range( + 'ema_decay', ema_decay, 0, 1, Rel.INC_BOTH) + self.num_bits = check_int_positive(num_bits) + self.quant_delay = check_int(quant_delay) + self.init_prim_io_names(inputs=['x', 'min', 'max'], + outputs=['out']) + + def infer_shape(self, x_shape, min_shape, max_shape): + validator.check_integer("x shape", len(x_shape), 1, Rel.GT) + validator.check_integer( + "min len", min_shape[0], x_shape[self.channel_idx], Rel.EQ) + validator.check_integer( + "max len", max_shape[0], x_shape[self.channel_idx], Rel.EQ) + return x_shape + + def infer_dtype(self, x_type, min_type, max_type): + validator.check_typename( + "x type", x_type, (mstype.float16, mstype.float32)) + validator.check_typename("min type", min_type, + (mstype.float16, mstype.float32)) + validator.check_typename("max type", max_type, + (mstype.float16, mstype.float32)) + return x_type + + +class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer): + """Performs grad of FakeQuantWithMinMaxPerChannel operation.""" + support_quant_bit = [4, 8] + + @prim_attr_register + def __init__(self, num_bits=8, quant_delay=0): + """init FakeQuantWithMinMaxPerChannel Fill""" + if num_bits not in 
self.support_quant_bit: + raise ValueError("Attr \'num_bits\' is not support.") + + self.quant_delay = check_int(quant_delay) + self.num_bits = check_int_positive(num_bits) + self.init_prim_io_names(inputs=['dout', 'x', 'min', 'max'], + outputs=['dx']) + + def infer_shape(self, dout_shape, x_shape, min_shape, max_shape): + validator.check("dout shape", dout_shape, "x shape", x_shape) + validator.check("min shape", min_shape, "max shape", max_shape) + return dout_shape + + def infer_dtype(self, dout_type, x_type, min_type, max_type): + validator.check_typename( + "dout", dout_type, (mstype.float16, mstype.float32)) + validator.check_typename("x", x_type, (mstype.float16, mstype.float32)) + validator.check_typename( + "min", min_type, (mstype.float16, mstype.float32)) + validator.check_typename( + "max", max_type, (mstype.float16, mstype.float32)) + return dout_type + + +class BatchNormFold(PrimitiveWithInfer): + """ + Batch normalization folded. + + Args: + momentum (float): Momentum value should be [0, 1]. Default: 0.1. + epsilon (float): A small float number to avoid dividing by 0. 1e-12 if dtype in + float32 else 1e-3. Default: 1e-12. + is_training (bool): In training mode set True, else set False. Default: True. + freeze_bn (int): Delay in steps at which computation switches from regular batch + norm to frozen mean and std. Default: 0. + + Inputs: + - **x** (Tensor) - Tensor of shape :math:`(N, C)`. + - **mean** (Tensor) - Tensor of shape :math:`(C,)`. + - **variance** (Tensor) - Tensor of shape :math:`(C,)`. + - **global_step** (Tensor) - Tensor to record current global step. + + Outputs: + Tuple of 4 Tensor, the normalized input and the updated parameters. + + - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`. + - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`. + - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`. + - **running_std** (Tensor) - Tensor of shape :math:`(C,)`. 
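+
+    Examples:
+        >>> # Illustrative sketch; per-channel tensors match the channel dim of x.
+        >>> batch_norm_fold = P.BatchNormFold()
+        >>> x = Tensor(np.random.rand(4, 2), mstype.float32)
+        >>> mean = Tensor(np.zeros(2), mstype.float32)
+        >>> variance = Tensor(np.ones(2), mstype.float32)
+        >>> global_step = Tensor(np.array([0]), mstype.int32)
+        >>> batch_mean, batch_std, running_mean, running_std = batch_norm_fold(
+        ...     x, mean, variance, global_step)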
+ + """ + channel = 1 + + @prim_attr_register + def __init__(self, momentum=0.1, epsilon=1e-12, is_training=True, freeze_bn=0): + """init batch norm fold layer""" + self.momentum = validator.check_number_range( + 'momentum', momentum, 0, 1, Rel.INC_BOTH) + self.epsilon = validator.check_float_positive('epsilon', epsilon) + self.is_training = check_bool(is_training) + self.freeze_bn = check_int(freeze_bn) + + self.init_prim_io_names(inputs=['x', 'mean', 'variance', 'global_step'], + outputs=['batch_mean', 'batch_std', 'running_mean', 'running_std']) + + def infer_shape(self, x_shape, mean_shape, variance_shape, global_step_shape): + validator.check("mean shape", mean_shape, + "gamma_shape", variance_shape) + validator.check("mean_shape size", + mean_shape[0], "input channel", x_shape[self.channel]) + validator.check_integer("global_step shape", + len(global_step_shape), 1, Rel.EQ) + return mean_shape, mean_shape, mean_shape, mean_shape + + def infer_dtype(self, x_type, mean_type, variance_type, global_step_type): + validator.check("input type", x_type, "mean type", mean_type) + validator.check("input type", x_type, "variance type", variance_type) + validator.check_typename("input type", x_type, + (mstype.float16, mstype.float32)) + validator.check_typename( + "global_step type", global_step_type, (mstype.int32,)) + return x_type, x_type, x_type, x_type + + +class BatchNormFoldGrad(PrimitiveWithInfer): + """Performs grad of BatchNormFold operation.""" + channel = 1 + + @prim_attr_register + def __init__(self, epsilon=1e-12, is_training=True, freeze_bn=0): + """init BatchNormGrad layer""" + self.is_training = check_bool(is_training) + self.freeze_bn = check_int(freeze_bn) + self.epsilon = validator.check_float_positive('epsilon', epsilon) + self.init_prim_io_names(inputs=['d_batch_mean', 'd_batch_std', 'x', 'batch_mean', 'batch_std', 'global_step'], + outputs=['dx']) + + def infer_shape(self, d_batch_mean_shape, d_batch_std_shape, x_shape, batch_mean_shape, batch_std_shape, + global_step_shape): + validator.check("d_batch_mean shape", d_batch_mean_shape, + "d_batch_std shape", d_batch_std_shape) + validator.check("d_batch_mean shape", d_batch_mean_shape, + "batch_mean shape", batch_mean_shape) + validator.check("d_batch_mean shape", d_batch_mean_shape, + "batch_std shape", batch_std_shape) + validator.check( + "x_shape shape", d_batch_mean_shape[0], "input channel", x_shape[self.channel]) + validator.check_integer("global_step shape", + len(global_step_shape), 1, Rel.EQ) + return x_shape + + def infer_dtype(self, d_batch_mean_type, d_batch_std_type, x_type, batch_mean_type, batch_std_type, + global_step_type): + validator.check("input type", x_type, + "d_batch_mean type", d_batch_mean_type) + validator.check("input type", x_type, + "d_batch_std type", d_batch_std_type) + validator.check("input type", x_type, + "batch_mean type", batch_mean_type) + validator.check("input type", x_type, "batch_std type", batch_std_type) + validator.check_typename("input type", x_type, + (mstype.float16, mstype.float32)) + validator.check_typename( + "global_step type", global_step_type, (mstype.int32,)) + return x_type + + +class CorrectionMul(PrimitiveWithInfer): + """ + Scale the weights with a correction factor to the long term statistics + prior to quantization. This ensures that there is no jitter in the quantized weights + due to batch to batch variation. + + Inputs: + - **x** (Tensor) - Tensor of shape :math:`(N, C)`. + - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`. 
+ - **running_std** (Tensor) - Tensor of shape :math:`(C,)`. + + Outputs: + - **out** (Tensor) - Tensor has the same shape as x. + + """ + channel = 0 + + @prim_attr_register + def __init__(self): + """init correction mul layer""" + self.init_prim_io_names(inputs=['x', 'batch_std', 'running_std'], + outputs=['out']) + + def infer_shape(self, x_shape, batch_std_shape, running_std_shape): + validator.check("batch_std shape", batch_std_shape, + "running_std shape", running_std_shape) + validator.check( + "batch_std size", batch_std_shape[0], "x_shape channel size", x_shape[self.channel]) + return x_shape + + def infer_dtype(self, x_type, batch_std_type, running_std_type): + validator.check("batch_std type", batch_std_type, + "running_std type", running_std_type) + validator.check("batch_std_type", batch_std_type, "x_type", x_type) + validator.check_typename( + "batch_std type", batch_std_type, (mstype.float16, mstype.float32)) + return x_type + + +class CorrectionMulGrad(PrimitiveWithInfer): + """Performs grad of CorrectionMul operation.""" + channel = 0 + + @prim_attr_register + def __init__(self): + """init correction mul layer""" + self.init_prim_io_names(inputs=['dout', 'x', 'gamma', 'running_std'], + outputs=['dx', 'd_gamma']) + + def infer_shape(self, dout_shape, x_shape, gamma_shape, running_std_shape): + validator.check("dout shape", dout_shape, "x_shape x", x_shape) + validator.check( + "gamma size", gamma_shape[0], "dout channel size", dout_shape[self.channel]) + validator.check( + "running_std size", running_std_shape[0], "dout channel size", dout_shape[self.channel]) + return x_shape, gamma_shape + + def infer_dtype(self, dout_type, x_type, gamma_type, running_std_type): + validator.check("x type", x_type, "dout type", dout_type) + validator.check("gamma type", gamma_type, "dout type", dout_type) + validator.check("running_std type", running_std_type, + "dout type", dout_type) + validator.check_typename( + "dout type", dout_type, (mstype.float16, mstype.float32)) + return x_type, x_type + + +class BatchNormFold2(PrimitiveWithInfer): + """ + Scale the bias with a correction factor to the long term statistics + prior to quantization. This ensures that there is no jitter in the quantized bias + due to batch to batch variation. + + Inputs: + - **x** (Tensor) - Tensor of shape :math:`(N, C)`. + - **beta** (Tensor) - Tensor of shape :math:`(C,)`. + - **gamma** (Tensor) - Tensor of shape :math:`(C,)`. + - **batch_std** (Tensor) - Tensor of shape :math:`(C,)`. + - **batch_mean** (Tensor) - Tensor of shape :math:`(C,)`. + - **running_std** (Tensor) - Tensor of shape :math:`(C,)`. + - **running_mean** (Tensor) - Tensor of shape :math:`(C,)`. + - **global_step** (Tensor) - Tensor to record current global step. + + Outputs: + - **y** (Tensor) - Tensor has the same shape as x. 
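+
+    Examples:
+        >>> # Illustrative sketch; all per-channel tensors share the channel size of x.
+        >>> batch_norm_fold2 = P.BatchNormFold2()
+        >>> x = Tensor(np.random.rand(4, 2), mstype.float32)
+        >>> beta = Tensor(np.zeros(2), mstype.float32)
+        >>> gamma = Tensor(np.ones(2), mstype.float32)
+        >>> batch_std = Tensor(np.ones(2), mstype.float32)
+        >>> batch_mean = Tensor(np.zeros(2), mstype.float32)
+        >>> running_std = Tensor(np.ones(2), mstype.float32)
+        >>> running_mean = Tensor(np.zeros(2), mstype.float32)
+        >>> global_step = Tensor(np.array([0]), mstype.int32)
+        >>> y = batch_norm_fold2(x, beta, gamma, batch_std, batch_mean,
+        ...                      running_std, running_mean, global_step)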
+ + """ + channel = 1 + + @prim_attr_register + def __init__(self, freeze_bn=0): + """init conv2d fold layer""" + self.freeze_bn = check_int(freeze_bn) + self.init_prim_io_names(inputs=['x', 'beta', 'gamma', 'batch_std', 'batch_mean', + 'running_std', 'running_mean', 'global_step'], + outputs=['y']) + + def infer_shape(self, x_shape, beta_shape, gamma_shape, batch_std_shape, running_std_shape, batch_mean_shape, + running_mean_shape, global_step_shape): + validator.check("batch_std shape", batch_std_shape, + "running_std shape", running_std_shape) + validator.check("batch_std shape", batch_std_shape, + "batch_mean shape", batch_mean_shape) + validator.check("batch_std shape", batch_std_shape, + "beta shape", beta_shape) + validator.check("batch_std shape", batch_std_shape, + "running_mean shape", running_mean_shape) + validator.check("batch_std shape", batch_std_shape, + "batch_mean shape", gamma_shape) + validator.check( + "batch_std size", batch_std_shape[0], "x_shape channel size", x_shape[self.channel]) + validator.check_integer("global_step shape", + len(global_step_shape), 1, Rel.EQ) + return x_shape + + def infer_dtype(self, x_type, beta_type, gamma_type, batch_std_type, running_std_type, batch_mean_type, + running_mean_type, global_step_type): + validator.check("batch_std type", batch_std_type, + "running_std type", running_std_type) + validator.check("batch_std type", batch_std_type, + "batch_mean type", batch_mean_type) + validator.check("batch_std type", batch_std_type, + "beta type", beta_type) + validator.check("batch_std type", batch_std_type, + "running_mean type", running_mean_type) + validator.check("batch_std type", batch_std_type, + "gamma type", gamma_type) + validator.check("x_type", x_type, "batch_std type", batch_std_type) + validator.check_typename( + "batch_std type", batch_std_type, (mstype.float16, mstype.float32)) + validator.check_typename( + "global_step type", global_step_type, (mstype.int32,)) + return x_type + + +class BatchNormFold2Grad(PrimitiveWithInfer): + """Performs grad of CorrectionAddGrad operation.""" + channel = 1 + + @prim_attr_register + def __init__(self, freeze_bn=0): + """init MulFold layer""" + self.freeze_bn = freeze_bn + self.init_prim_io_names(inputs=['dout', 'x', 'gamma', + 'batch_std', 'batch_mean', + 'running_std', 'running_mean', 'global_step'], + outputs=['d_batch_std', 'd_batch_mean', 'd_beta', 'd_gamma', 'dx']) + + def infer_shape(self, dout_shape, x_shape, gamma_shape, + batch_std_shape, batch_mean_shape, + running_std_shape, running_mean_shape, global_step_shape): + validator.check("batch_std shape", batch_std_shape, + "batch_mean shape", batch_mean_shape) + validator.check("batch_std shape", batch_std_shape, + "running_std shape", running_std_shape) + validator.check("batch_std shape", batch_std_shape, + "running_mean shape", running_mean_shape) + validator.check("batch_std shape", batch_std_shape, + "gamma shape", gamma_shape) + validator.check( + "batch_std size", batch_std_shape[0], "dout channel size", dout_shape[self.channel]) + validator.check_integer("global_step shape", + len(global_step_shape), 1, Rel.EQ) + return gamma_shape, gamma_shape, gamma_shape, gamma_shape, x_shape + + def infer_dtype(self, dout_type, x_type, gamma_type, + batch_std_type, batch_mean_type, + running_std_type, running_mean_type, global_step_type): + validator.check("batch_std type", batch_std_type, + "batch_mean type", batch_mean_type) + validator.check("batch_std type", batch_std_type, + "gamma type", gamma_type) + validator.check("batch_std 
type", batch_std_type, + "running_std type", running_std_type) + validator.check("batch_std type", batch_std_type, + "running_mean type", running_mean_type) + validator.check("batch_std_type", batch_std_type, + "dout type", dout_type) + validator.check_typename( + "batch_std type", batch_std_type, (mstype.float16, mstype.float32)) + validator.check_typename( + "global_step type", global_step_type, (mstype.int32,)) + return gamma_type, gamma_type, gamma_type, gamma_type, gamma_type diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 538d7f3826..91f6d7ec01 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -207,7 +207,7 @@ class ReLU6(PrimitiveWithInfer): class Elu(PrimitiveWithInfer): - """ + r""" Computes exponential linear: `alpha * (exp(x) - 1)` if x < 0, `x` otherwise. The data type of input tensor should be float. @@ -242,6 +242,40 @@ class Elu(PrimitiveWithInfer): return input_x +class HSwish(PrimitiveWithInfer): + r""" + Hard swish activation function. + + Applies hswish-type activation element-wise. The input is a Tensor with any valid shape. + + Hard swish is defined as: + + .. math:: + \text{hswish}(x_{i}) = x_{i} * \frac{ReLU6(x_{i} + 3)}{6}, + + where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor. + + Inputs: + - **input_data** (Tensor) - The input of Hswish. + + Outputs: + Tensor, with the same type and shape as the `input_data`. + + """ + @prim_attr_register + def __init__(self): + self.init_prim_io_names(inputs=['x'], outputs=['output']) + + def infer_shape(self, xshape): + return xshape + + def infer_dtype(self, x_dtype): + validator.check_subclass("x_dtype", x_dtype, mstype.tensor) + validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32)) + return x_dtype + + + class Sigmoid(PrimitiveWithInfer): r""" Sigmoid activation function. @@ -258,6 +292,7 @@ class Sigmoid(PrimitiveWithInfer): Outputs: Tensor, with the same type and shape as the input_x. + """ @prim_attr_register @@ -273,6 +308,40 @@ class Sigmoid(PrimitiveWithInfer): return input_x +class HSigmoid(PrimitiveWithInfer): + r""" + Hard sigmoid activation function. + + Applies hard sigmoid activation element-wise. The input is a Tensor with any valid shape. + + Hard sigmoid is defined as: + + .. math:: + \text{hsigmoid}(x_{i}) = max(0, min(1, \ftac{2 * x_{i} + 5}{10})), + + where :math:`x_{i}` is the :math:`i`-th slice along the given dim of the input Tensor. + + Inputs: + - **input_data** (Tensor) - The input of HSigmoid. + + Outputs: + Tensor, with the same type and shape as the `input_data`. + + """ + + @prim_attr_register + def __init__(self): + self.init_prim_io_names(inputs=['x'], outputs=['output']) + + def infer_shape(self, x_shape): + return x_shape + + def infer_dtype(self, x_dtype): + validator.check_subclass("x_dtype", x_dtype, mstype.tensor) + validator.check_typename("x_dtype", x_dtype, (mstype.float16, mstype.float32)) + return x_dtype + + class Tanh(PrimitiveWithInfer): r""" Tanh activation function. 
diff --git a/tests/ut/python/nn/test_dense.py b/tests/ut/python/nn/test_dense.py index 8581576c6b..0845983bb0 100644 --- a/tests/ut/python/nn/test_dense.py +++ b/tests/ut/python/nn/test_dense.py @@ -27,11 +27,6 @@ def test_dense_none(): nn.Dense(3, 2, None, None) -def test_dense_invalid_activation(): - with pytest.raises(KeyError): - nn.Dense(3, 2, activation='relu6') - - @non_graph_engine def test_dense_str_activation(): dense = nn.Dense(1, 1, activation='relu') diff --git a/tests/ut/python/pynative_mode/nn/test_activation.py b/tests/ut/python/pynative_mode/nn/test_activation.py index 7230fa272b..1b8a6f5d76 100644 --- a/tests/ut/python/pynative_mode/nn/test_activation.py +++ b/tests/ut/python/pynative_mode/nn/test_activation.py @@ -51,11 +51,6 @@ def test_activation_empty(): assert nn.get_activation('') is None -def test_activation_invalid(): - with pytest.raises(KeyError): - nn.get_activation('relu6') - - # test softmax def test_softmax_axis(): layer = nn.Softmax(1) diff --git a/tests/ut/python/pynative_mode/nn/test_dense.py b/tests/ut/python/pynative_mode/nn/test_dense.py index 48bfcc6674..cc9d280521 100644 --- a/tests/ut/python/pynative_mode/nn/test_dense.py +++ b/tests/ut/python/pynative_mode/nn/test_dense.py @@ -68,11 +68,6 @@ def test_dense_none(): nn.Dense(3, 2, None, None) -def test_dense_invalid_activation(): - with pytest.raises(KeyError): - nn.Dense(3, 2, activation='relu6') - - def test_dense_str_activation(): dense = nn.Dense(1, 1, activation='relu') assert isinstance(dense.activation, nn.ReLU)