From e518921cf45b3f9f26a213290ed2facefa4a286d Mon Sep 17 00:00:00 2001
From: xiaoyisd
Date: Mon, 23 Nov 2020 11:21:14 +0800
Subject: [PATCH] add one conv ascend

---
 mindspore/nn/layer/quant.py | 44 ++++++++++++++++++++++---------------
 1 file changed, 26 insertions(+), 18 deletions(-)

diff --git a/mindspore/nn/layer/quant.py b/mindspore/nn/layer/quant.py
index fead4d6ed9..24a54e78e1 100644
--- a/mindspore/nn/layer/quant.py
+++ b/mindspore/nn/layer/quant.py
@@ -419,12 +419,17 @@ class Conv2dBnFoldQuantOneConv(Cell):
         self.fake = fake
         self.quant_config = quant_config
         self.quant_dtype = quant_dtype
+        data_format = 'NCHW'
+        self.format = Validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
         self.is_gpu = context.get_context('device_target') == "GPU"
         self.is_ascend = context.get_context('device_target') == "Ascend"
+        self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
         if context.get_context("enable_ge"):
             self.is_ge_backend = True
         else:
             self.is_ge_backend = False
+        self.enable_default_train = self.is_graph_mode and \
+                                    (self.is_ge_backend or self.is_ascend)
 
         # initialize convolution op and Parameter
         if context.get_context('device_target') == "Ascend" and group > 1:
@@ -448,6 +453,7 @@ class Conv2dBnFoldQuantOneConv(Cell):
                                  group=group)
             weight_shape = [out_channels, in_channels // group, *self.kernel_size]
             channel_axis = 0
+        self.channel_axis = channel_axis
         self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
         self.bias_add = P.BiasAdd()
         if Validator.check_bool(has_bias):
@@ -490,7 +496,6 @@ class Conv2dBnFoldQuantOneConv(Cell):
         self.mul_var = P.Mul().shard(data_parallel_strategy_one)
         self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
         self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)
-        self.one = Tensor(1, mstype.int32)
         self.reshape = P.Reshape()
 
     def extend_repr(self):
@@ -507,18 +512,22 @@ class Conv2dBnFoldQuantOneConv(Cell):
     def construct(self, x):
         running_std = P.Sqrt()(P.TensorAdd()(self.moving_variance, self.eps))
         scale_factor = self.gamma / running_std
-        weight = self.weight * scale_factor
         if self.channel_axis:
             scale_factor = self.reshape(scale_factor, (1, -1, 1, 1))
         else:
             scale_factor = self.reshape(scale_factor, (-1, 1, 1, 1))
+        weight = self.weight * scale_factor
         if self.fake:
             weight = self.fake_quant_weight(weight)
         conv = self.conv(x, weight)
         scale_factor = self.reshape(scale_factor, (1, -1, 1, 1))
-        conv_orig = conv / scale_factor
+        if self.enable_default_train:
+            scale_factor = P.Reciprocal()(scale_factor)
+            conv_orig = conv * scale_factor
+        else:
+            conv_orig = conv / scale_factor
         if self.training:
-            if not self.is_gpu:
+            if self.enable_default_train:
                 out, batch_mean, batch_var, _, _ = self.bn_train(conv_orig,
                                                                  self.gamma,
                                                                  self.beta,
@@ -531,20 +540,19 @@ class Conv2dBnFoldQuantOneConv(Cell):
                 temp_variance = self.mul_var(mean_sub2, self.momentum)
                 out = F.depend(out, self.assign_sub_mean(self.moving_mean, temp_mean))
                 out = F.depend(out, self.assign_sub_var(self.moving_variance, temp_variance))
-            else:
-                out = self.bn_train(conv_orig,
-                                    self.gamma,
-                                    self.beta,
-                                    self.moving_mean,
-                                    self.moving_variance)[0]
-        else:
-            out = self.bn_infer(conv_orig,
-                                self.gamma,
-                                self.beta,
-                                self.moving_mean,
-                                self.moving_variance)[0]
-
-        return out
+                return out
+
+            return self.bn_train(conv_orig,
+                                 self.gamma,
+                                 self.beta,
+                                 self.moving_mean,
+                                 self.moving_variance)[0]
+
+        return self.bn_infer(conv_orig,
+                             self.gamma,
+                             self.beta,
+                             self.moving_mean,
+                             self.moving_variance)[0]
 
 
 class Conv2dBnFoldQuant(Cell):
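
For illustration only (not part of the patch): a minimal NumPy sketch of the algebra the one-conv fold path relies on. Folding the per-channel batch-norm scale gamma / sqrt(moving_variance + eps) into the weight scales each output channel of the convolution by that factor, so dividing the conv output by the scale, or multiplying by its reciprocal as the Ascend/GE graph-mode (enable_default_train) branch does, recovers the unfolded convolution before the batch-norm statistics are computed. conv2d_nchw and the other names below are hypothetical helpers, not MindSpore APIs, and the exact equivalence assumes the weight fake-quant is bypassed (fake=False).

# Illustrative sketch only; conv2d_nchw is a hypothetical helper, not a MindSpore API.
import numpy as np

def conv2d_nchw(x, w):
    """Naive stride-1, valid-padding convolution; x is NCHW, w is OIHW."""
    n, _, h, wd = x.shape
    c_out, _, kh, kw = w.shape
    out = np.zeros((n, c_out, h - kh + 1, wd - kw + 1))
    for i in range(out.shape[2]):
        for j in range(out.shape[3]):
            patch = x[:, :, i:i + kh, j:j + kw]                       # (n, c_in, kh, kw)
            out[:, :, i, j] = np.tensordot(patch, w, axes=([1, 2, 3], [1, 2, 3]))
    return out

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3, 8, 8))
w = rng.normal(size=(4, 3, 3, 3))

gamma = rng.normal(size=4)
moving_variance = rng.uniform(0.5, 2.0, size=4)
eps = 1e-5
scale_factor = gamma / np.sqrt(moving_variance + eps)                 # one value per output channel

folded = conv2d_nchw(x, w * scale_factor.reshape(-1, 1, 1, 1))        # conv with BN-folded weight
conv_orig_div = folded / scale_factor.reshape(1, -1, 1, 1)            # default branch: divide
conv_orig_mul = folded * (1.0 / scale_factor.reshape(1, -1, 1, 1))    # Ascend/GE branch: Reciprocal + Mul

assert np.allclose(conv_orig_div, conv2d_nchw(x, w))
assert np.allclose(conv_orig_mul, conv2d_nchw(x, w))

# Moving-statistics update used in the enable_default_train branch, in AssignSub form:
# moving <- moving - momentum * (moving - batch)
momentum = 0.1
moving_mean, batch_mean = np.zeros(4), rng.normal(size=4)
moving_mean -= momentum * (moving_mean - batch_mean)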