From 1c3e57968730c5aad6fa1017d0136547c247c4b0 Mon Sep 17 00:00:00 2001 From: wandongdong Date: Tue, 2 Jun 2020 11:33:46 +0800 Subject: [PATCH] fix bug in quant and correction_mul_grad --- .../kernel/gpu/cuda_impl/fake_quant_impl.cu | 7 +++---- mindspore/nn/layer/quant.py | 16 ++++++++-------- .../_op_impl/_custom_op/correction_mul_grad.py | 4 ++-- .../_custom_op/fake_quant_with_min_max.py | 3 +-- 4 files changed, 14 insertions(+), 16 deletions(-) diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_impl.cu index f25727f2c3..db3f8a857f 100644 --- a/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_impl.cu +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_impl.cu @@ -21,7 +21,7 @@ #include "fake_quant_impl.cuh" __global__ void FakeQuantize(const float *input, float *output, const int size, const float *nudge_min, - const float *nudge_max, const float *scale, bool symmetric) { + const float *nudge_max, const float *scale) { float input_x = 0.f; int nudge_input = 0; @@ -35,7 +35,7 @@ __global__ void FakeQuantize(const float *input, float *output, const int size, input_x = nudge_max[0]; } // clamp shift - nudge_input = floor((input_x - nudge_min[0]) / scale[0] + 0.5f); + nudge_input = round((input_x - nudge_min[0]) / scale[0]); // quantize output[i] = nudge_input * scale[0] + nudge_min[0]; @@ -99,8 +99,7 @@ __global__ void UpdateInputMinMax(float *input_min, float *input_max, const floa void CalFakeQuantize(const float *input, float *output, const int size, const float *nudge_min, const float *nudge_max, const float *scale, bool symmetric, cudaStream_t cuda_stream) { - FakeQuantize<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, output, size, nudge_min, nudge_max, scale, - symmetric); + FakeQuantize<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, output, size, nudge_min, nudge_max, scale); return; } diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py index 570c760d15..13421ce908 100644 --- a/mindspore/nn/layer/quant.py +++ b/mindspore/nn/layer/quant.py @@ 
-22,7 +22,7 @@ from mindspore.common.parameter import Parameter from mindspore.common.initializer import initializer from mindspore.common.tensor import Tensor from mindspore._checkparam import check_int_positive, check_bool, twice -from mindspore._checkparam import Validator as validator +from mindspore._checkparam import Validator as validator, Rel from mindspore.nn.cell import Cell from mindspore.nn.layer.activation import get_activation import mindspore.context as context @@ -207,7 +207,7 @@ class FakeQuantWithMinMaxD(Cell): class FakeQuantWithMinMax(Cell): r""" - Aware Quantization training op. This OP provide Fake quantization observer function on data with min and max. + Aware Quantization op. This OP provide Fake quantization observer function on data with min and max. Args: min_init (int, list): The dimension of channel or 1(layer). Default: -6. @@ -243,8 +243,7 @@ class FakeQuantWithMinMax(Cell): out_channels=1, quant_delay=0, symmetric=False, - narrow_range=False, - training=True): + narrow_range=False): """init FakeQuantWithMinMax layer""" super(FakeQuantWithMinMax, self).__init__() @@ -258,7 +257,6 @@ class FakeQuantWithMinMax(Cell): self.quant_delay = quant_delay self.symmetric = symmetric self.narrow_range = narrow_range - self.training = training if per_channel: min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32) @@ -422,11 +420,13 @@ class Conv2dBatchNormQuant(Cell): self.per_channel = per_channel self.symmetric = symmetric self.narrow_range = narrow_range + self.channel_axis = int(group > 1) + self.is_gpu = context.get_context('device_target') == "GPU" # initialize convolution op and Parameter if context.get_context('device_target') == "Ascend" and group > 1: - validator.check_integer('group', group, 'in_channels', in_channels, 'Conv2dBatchNormQuant') - validator.check_integer('group', group, 'in_channels', out_channels, 'Conv2dBatchNormQuant') + validator.check_integer('group', group, in_channels, Rel.EQ, 
'Conv2dBatchNormQuant') + validator.check_integer('group', group, out_channels, Rel.EQ, 'Conv2dBatchNormQuant') self.conv = P.DepthwiseConv2dNative(channel_multiplier=1, kernel_size=self.kernel_size, pad_mode=pad_mode, @@ -472,7 +472,7 @@ class Conv2dBatchNormQuant(Cell): symmetric=symmetric, narrow_range=narrow_range) self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn) - self.correct_mul = P.CorrectionMul() + self.correct_mul = P.CorrectionMul(self.channel_axis) if context.get_context('device_target') == "Ascend": self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn) self.batchnorm_fold2_infer = P.BatchNormFold2_D(freeze_bn=0) diff --git a/mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py b/mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py index 810ce7323c..fc0ea2545e 100644 --- a/mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py +++ b/mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py @@ -93,8 +93,8 @@ def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channe util.check_dtype_rule(inp_dtype_dout, ("float16", "float32")) util.check_dtype_rule(inp_dtype_x, ("float16", "float32")) - util.check_dtype_rule(inp_dtype_batch_std, ("float32",)) - util.check_dtype_rule(inp_dtype_running_std, ("float32",)) + util.check_dtype_rule(inp_dtype_batch_std, ("float16", "float32")) + util.check_dtype_rule(inp_dtype_running_std, ("float16", "float32")) util.compare_tensor_dict_key(dout, x, "dtype") util.compare_tensor_dict_key(dout, x, "shape") util.compare_tensor_dict_key(dx, x, "shape") diff --git a/mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max.py b/mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max.py index 4afdf3a051..f35dfae39b 100644 --- a/mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max.py +++ b/mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max.py @@ -80,8 +80,7 @@ def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, 
quant_min, # FakeQuant input_x = te.lang.cce.vmin(nudge_max, te.lang.cce.vmax(nudge_min, x)) - nudge_input = te.lang.cce.floor(te.lang.cce.vadds(te.lang.cce.vdiv(te.lang.cce.vsub(input_x, nudge_min), scale), - 0.5)) + nudge_input = te.lang.cce.round(te.lang.cce.vdiv(te.lang.cce.vsub(input_x, nudge_min), scale)) res = te.lang.cce.vadd(te.lang.cce.vmul(nudge_input, scale), nudge_min) return res