@@ -141,8 +141,7 @@ class FakeQuantWithMinMax(Cell):
                  out_channels=1,
                  quant_delay=0,
                  symmetric=False,
-                 narrow_range=False,
-                 training=True):
+                 narrow_range=False):
         """init FakeQuantWithMinMax layer"""
         super(FakeQuantWithMinMax, self).__init__()
         self.min_init = min_init
@@ -156,7 +155,6 @@ class FakeQuantWithMinMax(Cell):
         self.quant_delay = quant_delay
         self.symmetric = symmetric
         self.narrow_range = narrow_range
-        self.training = training
         self.is_ascend = context.get_context('device_target') == "Ascend"
 
         # init tensor min and max for fake quant op
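
The two hunks above delete the explicit `training=True` constructor argument together with its `self.training = training` assignment; the layer instead relies on the training flag managed at the Cell level. Below is a minimal, framework-agnostic sketch of that pattern; every name in it (CellLike, FakeQuantSketch, set_train's exact shape) is hypothetical and stands in for the real MindSpore machinery, not the library's API.

class CellLike:
    """Stand-in for a Cell base class that owns a mutable training flag."""
    def __init__(self):
        self.training = True

    def set_train(self, mode=True):
        self.training = mode

class FakeQuantSketch(CellLike):
    def __init__(self, ema=False):
        super().__init__()   # note: no `training=True` parameter anymore
        self.ema = ema

    def construct(self, x):
        # The mode is read at call time, so one instance serves both phases.
        if self.training:
            return "train: update running min/max, then fake-quantize"
        return "eval: fake-quantize with frozen min/max"

layer = FakeQuantSketch(ema=True)
print(layer.construct(0))   # training branch
layer.set_train(False)
print(layer.construct(0))   # eval branch, same instance

Reading the flag at call time means a single instance can switch between training and inference without being rebuilt, which a constructor-time flag prevents.
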
@@ -190,7 +188,7 @@ class FakeQuantWithMinMax(Cell):
                                     symmetric=self.symmetric,
                                     narrow_range=self.narrow_range,
                                     training=self.training)
-        if self.ema:
+        if self.training:
             self.ema_update = ema_fun(num_bits=self.num_bits,
                                       ema=self.ema,
                                       ema_decay=self.ema_decay,
@@ -206,7 +204,7 @@ class FakeQuantWithMinMax(Cell):
         return s
 
     def construct(self, x):
-        if self.ema and self.is_ascend:
+        if self.is_ascend and self.training:
             min_up, max_up = self.ema_update(x, self.minq, self.maxq)
             out = self.fake_quant(x, min_up, max_up)
             P.Assign()(self.minq, min_up)
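
The last hunk retargets the guard in construct from `self.ema and self.is_ascend` to `self.is_ascend and self.training`, so on Ascend the running min/max statistics are refreshed only while the network is in training mode. A tiny self-contained sketch of the new gate, with hypothetical names, assuming the surrounding code behaves as the diff suggests:

# The EMA min/max refresh now runs only when BOTH conditions hold.
def construct_gate(is_ascend, training, ema):
    if is_ascend and training:          # new condition
        return "ema_update, then fake_quant with refreshed min/max"
    return "fake_quant with stored min/max"

# Previously the gate was `ema and is_ascend`; now an eval-mode pass skips
# the statistics update even when ema=True.
print(construct_gate(is_ascend=True, training=True,  ema=True))   # updates
print(construct_gate(is_ascend=True, training=False, ema=True))   # no update
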