diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index 31fbbb9651..ae01cab882 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -141,8 +141,7 @@ class FakeQuantWithMinMax(Cell):
                  out_channels=1,
                  quant_delay=0,
                  symmetric=False,
-                 narrow_range=False,
-                 training=True):
+                 narrow_range=False):
         """init FakeQuantWithMinMax layer"""
         super(FakeQuantWithMinMax, self).__init__()
         self.min_init = min_init
@@ -156,7 +155,6 @@ class FakeQuantWithMinMax(Cell):
         self.quant_delay = quant_delay
         self.symmetric = symmetric
         self.narrow_range = narrow_range
-        self.training = training
         self.is_ascend = context.get_context('device_target') == "Ascend"
 
         # init tensor min and max for fake quant op
@@ -190,7 +188,7 @@ class FakeQuantWithMinMax(Cell):
                                        symmetric=self.symmetric,
                                        narrow_range=self.narrow_range,
                                        training=self.training)
-        if self.ema:
+        if self.training:
             self.ema_update = ema_fun(num_bits=self.num_bits,
                                       ema=self.ema,
                                       ema_decay=self.ema_decay,
@@ -206,7 +204,7 @@ class FakeQuantWithMinMax(Cell):
         return s
 
     def construct(self, x):
-        if self.ema and self.is_ascend:
+        if self.is_ascend and self.training:
             min_up, max_up = self.ema_update(x, self.minq, self.maxq)
             out = self.fake_quant(x, min_up, max_up)
             P.Assign()(self.minq, min_up)
diff --git a/mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py b/mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py
index fc0ea2545e..e79217f521 100644
--- a/mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py
+++ b/mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py
@@ -38,12 +38,6 @@ correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \
     .input(3, "running_std", None, "required", None) \
     .output(0, "dx", True, "required", "all") \
     .output(1, "d_batch_std", True, "required", "all") \
-    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
-                  DataType.F16_Default, DataType.F16_Default) \
-    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
-                  DataType.F16_5HD, DataType.F16_5HD) \
-    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
-                  DataType.F32_Default, DataType.F32_Default) \
     .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                   DataType.F32_5HD, DataType.F32_5HD) \
     .get_op_info()
diff --git a/mindspore/train/quant/quant.py b/mindspore/train/quant/quant.py
index e2a035bc77..ff40426931 100644
--- a/mindspore/train/quant/quant.py
+++ b/mindspore/train/quant/quant.py
@@ -247,7 +247,7 @@ def convert_quant_network(network,
         network (Cell): Obtain a pipeline through network for saving graph summary.
         quant_delay (int): Number of steps after which weights and activations are quantized during eval. Default: 0.
         bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False.
-        freeze_bn (bool): Number of steps after which BN parameters used total mean and variance. Default: 0.
+        freeze_bn (int): Number of steps after which BN parameters used total mean and variance. Default: 0.
         weight_bits (int): Number of bits to use for quantizing weights. Default: 8.
         act_bits (int): Number of bits to use for quantizing activations. Default: 8.
         per_channel (bool): Quantization granularity based on layer or on channel. Default: False.
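
For context on how the two Python-side changes are exercised from user code, here is a minimal sketch. It assumes the keyword arguments listed in the `convert_quant_network` docstring above and that the function is importable from `mindspore.train.quant.quant`; the concrete argument values and the `network` object are placeholders, not part of this patch.

```python
# Minimal usage sketch (assumptions noted in comments).
from mindspore.train.quant import quant

# `network` is assumed to be an existing nn.Cell built elsewhere.
quant_net = quant.convert_quant_network(network,
                                        quant_delay=0,
                                        bn_fold=False,
                                        freeze_bn=10000,   # int: step count, not a bool flag
                                        weight_bits=8,
                                        act_bits=8,
                                        per_channel=False)

# FakeQuantWithMinMax no longer accepts a `training=` constructor argument;
# it follows the `training` flag inherited from Cell, so mode switching goes
# through set_train as usual.
quant_net.set_train(True)    # enable EMA min/max updates during training
quant_net.set_train(False)   # freeze min/max for evaluation or export
```

The layer change in `quant.py` gates the Ascend-side EMA min/max update on the Cell-level `training` flag instead of on whether EMA is configured, which is why the explicit `training` constructor argument could be dropped.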