From 827d5443de6cadce25d88be512a2fc4851cbce20 Mon Sep 17 00:00:00 2001
From: chenfei
Date: Fri, 11 Sep 2020 11:08:50 +0800
Subject: [PATCH] display quant delay in ir when run with ascend

---
 mindspore/nn/layer/quant.py | 43 ++++++++++++++++++++++++++++---------------
 1 file changed, 28 insertions(+), 15 deletions(-)

diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index 79661f72e1..a70b9058d4 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -372,7 +372,8 @@ class FakeQuantWithMinMax(Cell):
         if self.is_ascend:
             self.fake_quant_train = quant_fun(num_bits=self.num_bits,
                                               symmetric=self.symmetric,
-                                              narrow_range=self.narrow_range)
+                                              narrow_range=self.narrow_range,
+                                              quant_delay=self.quant_delay)
             self.fake_quant_infer = self.fake_quant_train
         else:
             quant_fun = partial(quant_fun,
@@ -679,28 +680,40 @@ class Conv2dBnWithoutFoldQuant(Cell):
         self.group = group
         self.quant_delay = quant_delay
 
-        weight_shape = [out_channels, in_channels // group, *self.kernel_size]
-        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
-
         self.bias_add = P.BiasAdd()
         if check_bool(has_bias):
             self.bias = Parameter(initializer(bias_init, [out_channels]), name='bias')
         else:
             self.bias = None
-
-        self.conv = P.Conv2D(out_channel=self.out_channels,
-                             kernel_size=self.kernel_size,
-                             mode=1,
-                             pad_mode=self.pad_mode,
-                             pad=self.padding,
-                             stride=self.stride,
-                             dilation=self.dilation,
-                             group=self.group)
+        # initialize convolution op and Parameter
+        if context.get_context('device_target') == "Ascend" and group > 1:
+            validator.check_integer('group', group, in_channels, Rel.EQ)
+            validator.check_integer('group', group, out_channels, Rel.EQ)
+            self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
+                                                kernel_size=self.kernel_size,
+                                                pad_mode=pad_mode,
+                                                pad=padding,
+                                                stride=self.stride,
+                                                dilation=self.dilation)
+            weight_shape = [1, in_channels, *self.kernel_size]
+            channel_axis = 1
+        else:
+            self.conv = P.Conv2D(out_channel=self.out_channels,
+                                 kernel_size=self.kernel_size,
+                                 mode=1,
+                                 pad_mode=self.pad_mode,
+                                 pad=self.padding,
+                                 stride=self.stride,
+                                 dilation=self.dilation,
+                                 group=self.group)
+            weight_shape = [out_channels, in_channels // group, *self.kernel_size]
+            channel_axis = 0
+        self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
         self.fake_quant_weight = FakeQuantWithMinMax(min_init=-6,
                                                      max_init=6,
                                                      ema=False,
                                                      per_channel=per_channel,
-                                                     channel_axis=0,
+                                                     channel_axis=channel_axis,
                                                      num_channels=out_channels,
                                                      num_bits=num_bits,
                                                      symmetric=symmetric,
@@ -1009,6 +1022,7 @@ class ActQuant(_QuantActivation):
     def get_origin(self):
         return self.act
 
+
 class LeakyReLUQuant(_QuantActivation):
     r"""
     LeakyReLUQuant activation function. Add Fake Quant OP after HSwish OP.
@@ -1078,7 +1092,6 @@ class LeakyReLUQuant(_QuantActivation):
         return self.act
 
 
-
 class HSwishQuant(_QuantActivation):
     r"""
     HSwishQuant activation function. Add Fake Quant OP after HSwish OP.
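
Reviewer note (not part of the patch): the second hunk does more than the subject line suggests. On Ascend with group > 1 it now requires a true depthwise convolution (group == in_channels == out_channels), swaps P.Conv2D for P.DepthwiseConv2dNative, and changes both the weight layout and the per-channel fake-quant axis. The sketch below is a minimal, standalone reconstruction of that dispatch so the shape rules are easy to check; select_conv_params is a hypothetical helper, not a MindSpore API, and only the shape and channel-axis rules are taken from the hunk itself.

    # Standalone sketch (plain Python, no MindSpore imports): mirrors the
    # dispatch logic this patch adds to Conv2dBnWithoutFoldQuant.__init__.

    def select_conv_params(device_target, group, in_channels, out_channels, kernel_size):
        """Return (op_kind, weight_shape, channel_axis) as the patched __init__ would."""
        if device_target == "Ascend" and group > 1:
            # Depthwise path: the patch validates group == in_channels and
            # group == out_channels before building DepthwiseConv2dNative.
            if not group == in_channels == out_channels:
                raise ValueError("on Ascend, group > 1 requires "
                                 "group == in_channels == out_channels")
            # With channel_multiplier=1 the weight is [1, C_in, kH, kW], so
            # per-channel fake-quant statistics run along axis 1.
            return "DepthwiseConv2dNative", [1, in_channels, *kernel_size], 1
        # Generic path: weight is [C_out, C_in // group, kH, kW], so
        # per-channel fake-quant statistics run along axis 0.
        return "Conv2D", [out_channels, in_channels // group, *kernel_size], 0

    # Example: a 32-channel depthwise 3x3 conv on Ascend vs. a plain conv on GPU.
    print(select_conv_params("Ascend", 32, 32, 32, (3, 3)))
    # -> ('DepthwiseConv2dNative', [1, 32, 3, 3], 1)
    print(select_conv_params("GPU", 1, 32, 64, (3, 3)))
    # -> ('Conv2D', [64, 32, 3, 3], 0)

Either way num_channels stays out_channels; only the axis that indexes those channels in the weight tensor moves, which is why channel_axis is now computed alongside weight_shape instead of being hard-coded to 0.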