fix FakeQuantPerLayer/FakeQuantPerLayerGrad symmetric bug

pull/2194/head
王东旭 5 years ago committed by wangdongxu
parent 7038df8b99
commit 4e09ae83eb

@ -214,7 +214,7 @@ class BatchNormFoldCell(Cell):
Batch normalization folded. Batch normalization folded.
Args: Args:
momentum (float): Momentum value should be [0, 1]. Default: 0.1. momentum (float): Momentum value should be [0, 1]. Default: 0.9.
epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
float32 else 1e-3. Default: 1e-5. float32 else 1e-3. Default: 1e-5.
freeze_bn (int): Delay in steps at which computation switches from regular batch freeze_bn (int): Delay in steps at which computation switches from regular batch
@ -280,6 +280,7 @@ class FakeQuantWithMinMax(Cell):
ema (bool): Exponential Moving Average algorithm update min and max. Default: False. ema (bool): Exponential Moving Average algorithm update min and max. Default: False.
ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999. ema_decay (float): Exponential Moving Average algorithm parameter. Default: 0.999.
per_channel (bool): Quantization by layer or channel. Default: False. per_channel (bool): Quantization by layer or channel. Default: False.
channel_axis (int): Quantization by channel axis. Default: 1.
out_channels (int): declarate the min and max channel size, Default: 1. out_channels (int): declarate the min and max channel size, Default: 1.
quant_delay (int): Quantization delay parameters according by global step. Default: 0. quant_delay (int): Quantization delay parameters according by global step. Default: 0.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False. symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
@ -391,17 +392,17 @@ class Conv2dBatchNormQuant(Cell):
pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same". pad_mode: (str): Specifies padding mode. The optional values are "same", "valid", "pad". Default: "same".
padding: (int): Implicit paddings on both sides of the input. Default: 0. padding: (int): Implicit paddings on both sides of the input. Default: 0.
eps (int): Parameters for BatchNormal. Default: 1e-5. eps (int): Parameters for BatchNormal. Default: 1e-5.
momentum (int): Parameters for BatchNormal op. Default: 0.9. momentum (int): Parameters for BatchNormal op. Default: 0.997.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
convolution kernel. Default: 'None'. convolution kernel. Default: 'normal'.
beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the beta_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
beta vector. Default: 'None'. beta vector. Default: 'zeros'.
gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the gamma_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
gamma vector. Default: 'None'. gamma vector. Default: 'ones'.
mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the mean_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
mean vector. Default: 'None'. mean vector. Default: 'zeros'.
var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the var_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the
variance vector. Default: 'None'. variance vector. Default: 'ones'.
quant_delay (int): Quantization delay parameters according by global step. Default: 0. quant_delay (int): Quantization delay parameters according by global step. Default: 0.
freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000. freeze_bn (int): Quantization freeze BatchNormal op according by global step. Default: 100000.
fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True. fake (bool): Conv2dBatchNormQuant Cell add FakeQuantWithMinMax op or not. Default: True.
@ -434,11 +435,11 @@ class Conv2dBatchNormQuant(Cell):
group=1, group=1,
eps=1e-5, eps=1e-5,
momentum=0.997, momentum=0.997,
weight_init=None, weight_init='normal',
beta_init=None, beta_init='zeros',
gamma_init=None, gamma_init='ones',
mean_init=None, mean_init='zeros',
var_init=None, var_init='ones',
quant_delay=0, quant_delay=0,
freeze_bn=100000, freeze_bn=100000,
fake=True, fake=True,
@ -477,8 +478,7 @@ class Conv2dBatchNormQuant(Cell):
pad=padding, pad=padding,
stride=self.stride, stride=self.stride,
dilation=self.dilation) dilation=self.dilation)
if weight_init is None: weight_shape = [1, in_channels, *self.kernel_size]
weight_init = initializer('normal', [1, in_channels, *self.kernel_size])
channel_axis = 1 channel_axis = 1
else: else:
self.conv = P.Conv2D(out_channel=out_channels, self.conv = P.Conv2D(out_channel=out_channels,
@ -488,24 +488,16 @@ class Conv2dBatchNormQuant(Cell):
stride=self.stride, stride=self.stride,
dilation=self.dilation, dilation=self.dilation,
group=group) group=group)
if weight_init is None: weight_shape = [out_channels, in_channels // group, *self.kernel_size]
weight_init = initializer('normal', [out_channels, in_channels // group, *self.kernel_size])
channel_axis = 0 channel_axis = 0
self.weight = Parameter(weight_init, name='weight') self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
# initialize batchnorm Parameter # initialize batchnorm Parameter
if gamma_init is None: self.gamma = Parameter(initializer(gamma_init, [out_channels]), name='gamma')
gamma_init = initializer('ones', [out_channels]) self.beta = Parameter(initializer(beta_init, [out_channels]), name='beta')
self.gamma = Parameter(gamma_init, name='gamma') self.moving_mean = Parameter(initializer(mean_init, [out_channels]), name='moving_mean', requires_grad=False)
if beta_init is None: self.moving_variance = Parameter(initializer(var_init, [out_channels]), name='moving_variance',
beta_init = initializer('zeros', [out_channels]) requires_grad=False)
self.beta = Parameter(beta_init, name='beta')
if mean_init is None:
mean_init = initializer('zeros', [out_channels])
self.moving_mean = Parameter(mean_init, name='moving_mean', requires_grad=False)
if var_init is None:
var_init = initializer('ones', [out_channels])
self.moving_variance = Parameter(var_init, name='moving_variance', requires_grad=False)
# initialize fake ops # initialize fake ops
self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6, self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
@ -588,8 +580,8 @@ class Conv2dQuant(Cell):
divisible by the number of groups. Default: 1. divisible by the number of groups. Default: 1.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: False. has_bias (bool): Specifies whether the layer uses a bias vector. Default: False.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel. weight_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the convolution kernel.
Default: None. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: None. bias_init (Union[Tensor, str, Initializer, numbers.Number]): Initializer for the bias vector. Default: 'zeros'.
quant_delay (int): Quantization delay parameters according by global step. Default: 0. quant_delay (int): Quantization delay parameters according by global step. Default: 0.
num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8. num_bits (int): Quantization number bit, support 4 and 8bit. Default: 8.
per_channel (bool): FakeQuantWithMinMax Parameters. Default: False. per_channel (bool): FakeQuantWithMinMax Parameters. Default: False.
@ -619,8 +611,8 @@ class Conv2dQuant(Cell):
dilation=1, dilation=1,
group=1, group=1,
has_bias=False, has_bias=False,
weight_init=None, weight_init='normal',
bias_init=None, bias_init='zeros',
quant_delay=0, quant_delay=0,
num_bits=8, num_bits=8,
per_channel=False, per_channel=False,
@ -641,15 +633,14 @@ class Conv2dQuant(Cell):
self.group = group self.group = group
self.quant_delay = quant_delay self.quant_delay = quant_delay
if weight_init is None: weight_shape = [out_channels, in_channels // group, *self.kernel_size]
weight_init = initializer( self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
'normal', [out_channels, in_channels // group, *self.kernel_size])
self.weight = Parameter(weight_init, name='weight')
if bias_init is None:
bias_init = initializer('zeros', [out_channels])
if has_bias:
self.bias = Parameter(bias_init, name='bias')
self.bias_add = P.BiasAdd() self.bias_add = P.BiasAdd()
if check_bool(has_bias):
self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
else:
self.bias = None
self.conv = P.Conv2D(out_channel=self.out_channels, self.conv = P.Conv2D(out_channel=self.out_channels,
kernel_size=self.kernel_size, kernel_size=self.kernel_size,
@ -738,8 +729,8 @@ class DenseQuant(Cell):
self.has_bias = check_bool(has_bias) self.has_bias = check_bool(has_bias)
if isinstance(weight_init, Tensor): if isinstance(weight_init, Tensor):
if weight_init.dim() != 2 or weight_init.shape[0] != out_channels or \ if weight_init.dim() != 2 or weight_init.shape()[0] != out_channels or \
weight_init.shape[1] != in_channels: weight_init.shape()[1] != in_channels:
raise ValueError("weight_init shape error") raise ValueError("weight_init shape error")
self.weight = Parameter(initializer( self.weight = Parameter(initializer(
@ -747,7 +738,7 @@ class DenseQuant(Cell):
if self.has_bias: if self.has_bias:
if isinstance(bias_init, Tensor): if isinstance(bias_init, Tensor):
if bias_init.dim() != 1 or bias_init.shape[0] != out_channels: if bias_init.dim() != 1 or bias_init.shape()[0] != out_channels:
raise ValueError("bias_init shape error") raise ValueError("bias_init shape error")
self.bias = Parameter(initializer( self.bias = Parameter(initializer(

@ -65,7 +65,6 @@ def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0, data_format="NCHW", momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0, data_format="NCHW",
kernel_name="batchnorm_fold"): kernel_name="batchnorm_fold"):
"""batchnorm_fold TBE op""" """batchnorm_fold TBE op"""
momentum = 1.0 - momentum
util.check_kernel_name(kernel_name) util.check_kernel_name(kernel_name)
data_format = data_format.upper() data_format = data_format.upper()
if data_format != "NCHW": if data_format != "NCHW":
@ -120,13 +119,12 @@ def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
variance_div = te.lang.cce.vmuls(x_square_sum, num_rec) variance_div = te.lang.cce.vmuls(x_square_sum, num_rec)
mean_square = te.lang.cce.vmul(batch_mean, batch_mean) mean_square = te.lang.cce.vmul(batch_mean, batch_mean)
batch_var_biased = te.lang.cce.vsub(variance_div, mean_square) batch_var_biased = te.lang.cce.vsub(variance_div, mean_square)
batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_var_biased, epsilon))
if num == 1: if num == 1:
batch_var_scaler = 0.0 batch_var_scaler = 0.0
else: else:
batch_var_scaler = float(num) / (num - 1) batch_var_scaler = float(num) / (num - 1)
batch_variance = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler) batch_var_unbiased = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)
batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_variance, epsilon))
factor = 1.0 - momentum factor = 1.0 - momentum
factor_reverse = momentum factor_reverse = momentum
@ -134,7 +132,7 @@ def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse) mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse)
mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev) mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev)
var_mul = te.lang.cce.vmuls(batch_variance, factor) var_mul = te.lang.cce.vmuls(batch_var_unbiased, factor)
var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse) var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse)
variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev) variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev)

@ -50,15 +50,16 @@ def _fake_quant_per_layer_tbe():
@fusion_manager.register("fake_quant_per_layer") @fusion_manager.register("fake_quant_per_layer")
def fake_quant_per_layer_compute(x, min_val, max_val, y, quant_min, quant_max, def fake_quant_per_layer_compute(x, min_val, max_val, y, quant_min, quant_max, symmetric,
kernel_name="fake_quant_per_layer"): kernel_name="fake_quant_per_layer"):
"""FakeQuantPerLayer""" """FakeQuantPerLayer"""
shape = te.lang.cce.util.shape_to_list(x.shape) shape = te.lang.cce.util.shape_to_list(x.shape)
shape_min = te.lang.cce.util.shape_to_list(min_val.shape) shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
quant_min = te.lang.cce.broadcast(quant_min, shape_min, x.dtype) quant_min = te.lang.cce.broadcast(quant_min, shape_min, x.dtype)
quant_max = te.lang.cce.broadcast(quant_max, shape_min, x.dtype) quant_max = te.lang.cce.broadcast(quant_max, shape_min, x.dtype)
min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype) if symmetric:
max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype) max_val = te.lang.cce.vmax(te.lang.cce.vmuls(min_val, -1.), max_val)
min_val = te.lang.cce.vmuls(max_val, -1.)
# CalNudge(NudgeMinMax) # CalNudge(NudgeMinMax)
scale = te.lang.cce.vdiv(te.lang.cce.vsub( scale = te.lang.cce.vdiv(te.lang.cce.vsub(
@ -119,10 +120,6 @@ def fake_quant_per_layer(x, min_val, max_val, y,
input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),) input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
shape_min, _, _ = util.produce_shapes(min_shape, input_shape) shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0 quant_min = 0
quant_max = 2 ** num_bits - 1 quant_max = 2 ** num_bits - 1
if narrow_range: if narrow_range:
@ -132,7 +129,7 @@ def fake_quant_per_layer(x, min_val, max_val, y,
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res = fake_quant_per_layer_compute(input_data, min_data, max_data, y, res = fake_quant_per_layer_compute(input_data, min_data, max_data, y,
quant_min, quant_max, kernel_name) quant_min, quant_max, symmetric, kernel_name)
with tvm.target.cce(): with tvm.target.cce():
sch = generic.auto_schedule(res) sch = generic.auto_schedule(res)

@ -78,7 +78,7 @@ def _fake_quant_per_layer_grad_tbe():
@fusion_manager.register("fake_quant_per_layer_grad") @fusion_manager.register("fake_quant_per_layer_grad")
def fake_quant_per_layer_grad_compute(dout, x, min_val, max_val, quant_min, quant_max, def fake_quant_per_layer_grad_compute(dout, x, min_val, max_val, quant_min, quant_max, symmetric,
kernel_name="fake_quant_per_layer_grad"): kernel_name="fake_quant_per_layer_grad"):
"""FakeQuantPerLayerGrad""" """FakeQuantPerLayerGrad"""
shape = te.lang.cce.util.shape_to_list(x.shape) shape = te.lang.cce.util.shape_to_list(x.shape)
@ -88,6 +88,10 @@ def fake_quant_per_layer_grad_compute(dout, x, min_val, max_val, quant_min, quan
quant_min = te.lang.cce.broadcast(quant_min, shape_min) quant_min = te.lang.cce.broadcast(quant_min, shape_min)
quant_max = te.lang.cce.broadcast(quant_max, shape_min) quant_max = te.lang.cce.broadcast(quant_max, shape_min)
if symmetric:
max_val = te.lang.cce.vmax(te.lang.cce.vmuls(min_val, -1.), max_val)
min_val = te.lang.cce.vmuls(max_val, -1.)
# CalNudge(NudgeMinMax) # CalNudge(NudgeMinMax)
scale = te.lang.cce.vdiv(te.lang.cce.vsub( scale = te.lang.cce.vdiv(te.lang.cce.vsub(
max_val, min_val), te.lang.cce.vsub(quant_max, quant_min)) max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
@ -142,10 +146,6 @@ def fake_quant_per_layer_grad(dout, x, min_val, max_val, dx,
input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),) input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
shape_min, _, _ = util.produce_shapes(min_shape, input_shape) shape_min, _, _ = util.produce_shapes(min_shape, input_shape)
if symmetric:
quant_min = 0 - 2 ** (num_bits - 1)
quant_max = 2 ** (num_bits - 1) - 1
else:
quant_min = 0 quant_min = 0
quant_max = 2 ** num_bits - 1 quant_max = 2 ** num_bits - 1
if narrow_range: if narrow_range:
@ -155,8 +155,8 @@ def fake_quant_per_layer_grad(dout, x, min_val, max_val, dx,
input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype) input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype) min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype) max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
res = fake_quant_per_layer_grad_compute(dout_data, input_data, min_data, max_data, quant_min, res = fake_quant_per_layer_grad_compute(dout_data, input_data, min_data, max_data,
quant_max, kernel_name) quant_min, quant_max, symmetric, kernel_name)
with tvm.target.cce(): with tvm.target.cce():
sch = generic.auto_schedule(res) sch = generic.auto_schedule(res)

@ -58,7 +58,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
BiasAdd, Conv2D, BiasAdd, Conv2D,
DepthwiseConv2dNative, DepthwiseConv2dNative,
DropoutDoMask, DropoutGrad, Dropout, DropoutDoMask, DropoutGrad, Dropout,
DropoutGenMask, Flatten, FusedBatchNorm, DropoutGenMask, Flatten, FusedBatchNorm, BNTrainingReduce, BNTrainingUpdate,
Gelu, Elu, Gelu, Elu,
GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss, GetNext, L2Normalize, LayerNorm, L2Loss, CTCLoss,
LogSoftmax, LogSoftmax,
@ -76,7 +76,6 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK) ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK)
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
CheckValid, MakeRefKey, Partial, Depend, CheckBprop) CheckValid, MakeRefKey, Partial, Depend, CheckBprop)
from . import _quant_ops
from ._quant_ops import * from ._quant_ops import *
from .thor_ops import * from .thor_ops import *
@ -101,6 +100,9 @@ __all__ = [
'Conv2D', 'Conv2D',
'Flatten', 'Flatten',
'MaxPoolWithArgmax', 'MaxPoolWithArgmax',
'FusedBatchNorm',
'BNTrainingReduce',
'BNTrainingUpdate',
'BatchNorm', 'BatchNorm',
'MaxPool', 'MaxPool',
'TopK', 'TopK',
@ -313,5 +315,4 @@ __all__ = [
"DataFormatDimMap" "DataFormatDimMap"
] ]
__all__.extend(_quant_ops.__all__)
__all__.sort() __all__.sort()

@ -35,7 +35,6 @@ __all__ = ["FakeQuantPerLayer",
"BatchNormFold2Grad", "BatchNormFold2Grad",
"BatchNormFoldD", "BatchNormFoldD",
"BatchNormFoldGradD", "BatchNormFoldGradD",
"BNTrainingReduce",
"BatchNormFold2_D", "BatchNormFold2_D",
"BatchNormFold2GradD", "BatchNormFold2GradD",
"BatchNormFold2GradReduce", "BatchNormFold2GradReduce",
@ -333,7 +332,7 @@ class BatchNormFold(PrimitiveWithInfer):
Batch normalization folded. Batch normalization folded.
Args: Args:
momentum (float): Momentum value should be [0, 1]. Default: 0.1. momentum (float): Momentum value should be [0, 1]. Default: 0.9.
epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in epsilon (float): A small float number to avoid dividing by 0. 1e-5 if dtype in
float32 else 1e-3. Default: 1e-5. float32 else 1e-3. Default: 1e-5.
is_training (bool): In training mode set True, else set False. Default: True. is_training (bool): In training mode set True, else set False. Default: True.
@ -365,7 +364,7 @@ class BatchNormFold(PrimitiveWithInfer):
channel_axis = 1 channel_axis = 1
@prim_attr_register @prim_attr_register
def __init__(self, momentum=0.1, epsilon=1e-5, is_training=True, freeze_bn=0): def __init__(self, momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0):
"""init batch norm fold layer""" """init batch norm fold layer"""
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name) self.epsilon = validator.check_float_positive('epsilon', epsilon, self.name)
@ -697,32 +696,6 @@ class BatchNormFoldGradD(PrimitiveWithInfer):
return x_type return x_type
class BNTrainingReduce(PrimitiveWithInfer):
    """
    Reduce sum at axis [0, 2, 3].

    Inputs:
        - **x** (Tensor) - Input tensor. Given the reduction axes it is expected
          to be 4-D, e.g. :math:`(N, C, H, W)`; the rank is not checked here.

    Outputs:
        - **x_sum** (Tensor) - Tensor of shape :math:`(C,)`, where C is the size of axis 1 of x.
        - **x_square_sum** (Tensor) - Tensor of shape :math:`(C,)`, where C is the size of axis 1 of x.
    """

    @prim_attr_register
    def __init__(self):
        """init _BNTrainingReduce layer"""
        self.init_prim_io_names(inputs=['x'],
                                outputs=['x_sum', 'x_square_sum'])

    def infer_shape(self, x_shape):
        """Both outputs are 1-D with the length of axis 1 of the input."""
        return [x_shape[1]], [x_shape[1]]

    def infer_dtype(self, x_type):
        """Both outputs keep the dtype of the input."""
        return x_type, x_type
class BatchNormFold2_D(PrimitiveWithInfer): class BatchNormFold2_D(PrimitiveWithInfer):
""" """
Scale the bias with a correction factor to the long term statistics Scale the bias with a correction factor to the long term statistics

@ -585,6 +585,50 @@ class FusedBatchNorm(Primitive):
self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name) self.momentum = validator.check_number_range('momentum', momentum, 0, 1, Rel.INC_BOTH, self.name)
class BNTrainingReduce(PrimitiveWithInfer):
    """
    Reduce sum at axis [0, 2, 3].

    Inputs:
        - **x** (Tensor) - Tensor of rank 4, e.g. :math:`(N, C, H, W)`.

    Outputs:
        - **sum** (Tensor) - Tensor of shape :math:`(C,)`.
        - **square_sum** (Tensor) - Tensor of shape :math:`(C,)`.
    """

    @prim_attr_register
    def __init__(self):
        """Register the input and output names of the primitive."""
        self.init_prim_io_names(inputs=['x'], outputs=['sum', 'square_sum'])

    def infer_shape(self, x_shape):
        """Require a rank-4 input; both outputs are 1-D with the length of axis 1."""
        validator.check_integer("x rank", len(x_shape), 4, Rel.EQ, self.name)
        return ([x_shape[1]], [x_shape[1]])

    def infer_dtype(self, x_type):
        """Both outputs keep the dtype of the input."""
        return (x_type, x_type)
class BNTrainingUpdate(PrimitiveWithInfer):
    """
    Primitive registration and shape/dtype inference for bn_training_update.

    Args:
        isRef (bool): Accepted by the constructor but neither validated nor
            used in this method body. Default: True.
        epsilon (float): Range-checked between 0 and 1 with ``Rel.INC_RIGHT``
            (presumably the half-open interval (0, 1] - confirm against the
            validator). Default: 1e-5.
        factor (float): Range-checked in [0, 1]. Default: 0.1.

    Inputs:
        - **x**, **sum**, **square_sum**, **scale**, **b**, **mean**, **variance**.

    Outputs:
        - **y** - Same shape and dtype as `x`.
        - **running_mean**, **running_variance**, **save_mean**,
          **save_inv_variance** - Same shape and dtype as `variance`.
    """

    @prim_attr_register
    def __init__(self, isRef=True, epsilon=1e-5, factor=0.1):
        """Register I/O names and validate `epsilon` and `factor`."""
        self.init_prim_io_names(inputs=['x', 'sum', 'square_sum', 'scale', 'b', 'mean', 'variance'],
                                outputs=['y', 'running_mean', 'running_variance', 'save_mean', 'save_inv_variance'])
        # NOTE(review): `isRef` is intentionally left unvalidated and unused here;
        # the parameter is kept so existing callers keep working.
        self.epsilon = validator.check_number_range('epsilon', epsilon, 0, 1, Rel.INC_RIGHT, self.name)
        self.factor = validator.check_number_range('factor', factor, 0, 1, Rel.INC_BOTH, self.name)

    def infer_shape(self, x, sum, square_sum, scale, b, mean, variance):
        """`y` takes the shape of `x`; the four statistics outputs take the shape of `variance`."""
        return (x, variance, variance, variance, variance)

    def infer_dtype(self, x, sum, square_sum, scale, b, mean, variance):
        """`y` takes the dtype of `x`; the four statistics outputs take the dtype of `variance`."""
        return (x, variance, variance, variance, variance)
class BatchNorm(PrimitiveWithInfer): class BatchNorm(PrimitiveWithInfer):
r""" r"""
Batch Normalization for input data and updated parameters. Batch Normalization for input data and updated parameters.

@ -28,7 +28,7 @@ context.set_context(device_target='GPU')
class Net(nn.Cell): class Net(nn.Cell):
def __init__(self): def __init__(self):
super(Net, self).__init__() super(Net, self).__init__()
self.op = P.BatchNormFold(freeze_bn=10) self.op = P.BatchNormFold(momentum=0.9, freeze_bn=10)
@ms_function @ms_function
def construct(self, x, mean, variance, current_step): def construct(self, x, mean, variance, current_step):
@ -40,8 +40,8 @@ def np_result(x, mean, var, momentum, epsilon):
np_mean = x.mean(axis=(0, 2, 3)) np_mean = x.mean(axis=(0, 2, 3))
np_var = x.var(axis=(0, 2, 3)) np_var = x.var(axis=(0, 2, 3))
n = x.shape[0] * x.shape[2] * x.shape[3] n = x.shape[0] * x.shape[2] * x.shape[3]
mean_update = momentum * np_mean + (1 - momentum) * mean mean_update = (1 - momentum) * np_mean + momentum * mean
var_update = momentum * np_var * n / (n - 1) + (1 - momentum) * var var_update = (1 - momentum) * np_var * n / (n - 1) + momentum * var
np_var = np.sqrt(np_var + epsilon) np_var = np.sqrt(np_var + epsilon)
delay_mean = mean.copy() delay_mean = mean.copy()
delay_std = np.sqrt(var + epsilon) delay_std = np.sqrt(var + epsilon)

Loading…
Cancel
Save