bug fix in fake quant training in r0.3

pull/1965/head
chenzomi 5 years ago
parent bd8c623b06
commit d7a4ae2e34

@ -141,8 +141,7 @@ class FakeQuantWithMinMax(Cell):
out_channels=1,
quant_delay=0,
symmetric=False,
narrow_range=False,
training=True):
narrow_range=False):
"""init FakeQuantWithMinMax layer"""
super(FakeQuantWithMinMax, self).__init__()
self.min_init = min_init
@ -156,7 +155,6 @@ class FakeQuantWithMinMax(Cell):
self.quant_delay = quant_delay
self.symmetric = symmetric
self.narrow_range = narrow_range
self.training = training
self.is_ascend = context.get_context('device_target') == "Ascend"
# init tensor min and max for fake quant op
@ -190,7 +188,7 @@ class FakeQuantWithMinMax(Cell):
symmetric=self.symmetric,
narrow_range=self.narrow_range,
training=self.training)
if self.ema:
if self.training:
self.ema_update = ema_fun(num_bits=self.num_bits,
ema=self.ema,
ema_decay=self.ema_decay,
@ -206,7 +204,7 @@ class FakeQuantWithMinMax(Cell):
return s
def construct(self, x):
if self.ema and self.is_ascend:
if self.is_ascend and self.training:
min_up, max_up = self.ema_update(x, self.minq, self.maxq)
out = self.fake_quant(x, min_up, max_up)
P.Assign()(self.minq, min_up)

@ -38,12 +38,6 @@ correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \
.input(3, "running_std", None, "required", None) \
.output(0, "dx", True, "required", "all") \
.output(1, "d_batch_std", True, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
DataType.F32_5HD, DataType.F32_5HD) \
.get_op_info()

@ -247,7 +247,7 @@ def convert_quant_network(network,
network (Cell): Obtain a pipeline through network for saving graph summary.
quant_delay (int): Number of steps after which weights and activations are quantized during eval. Default: 0.
bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False.
freeze_bn (bool): Number of steps after which BN parameters used total mean and variance. Default: 0.
freeze_bn (int): Number of steps after which BN parameters used total mean and variance. Default: 0.
weight_bits (int): Number of bits to use for quantizing weights. Default: 8.
act_bits (int): Number of bits to use for quantizing activations. Default: 8.
per_channel (bool): Quantization granularity based on layer or on channel. Default: False.

Loading…
Cancel
Save