!8892 Add Conv2dBnFoldQuantOneConv function (GPU and Ascend)

From: @xiaoyisd
Reviewed-by: @sanjaychan, @jjfeing
Signed-off-by: @sanjaychan
pull/8892/MERGE
Committed by: mindspore-ci-bot (via Gitee)
commit 534cc9bbe9
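
Background on the change: Conv2dBnFoldQuantOneConv folds the BatchNorm scale into the convolution weight before fake-quantization, runs a single convolution, and then undoes the scaling so bn_train/bn_infer still see the plain convolution output. Below is a minimal NumPy sketch of that fold/unfold arithmetic; a matmul stands in for P.Conv2D and the shapes and data are illustrative assumptions, not the MindSpore kernel.

    # Sketch of the fold/unfold arithmetic in construct(); matmul replaces P.Conv2D.
    import numpy as np

    rng = np.random.default_rng(0)
    n, cin, cout = 16, 4, 8
    x = rng.standard_normal((n, cin))
    weight = rng.standard_normal((cout, cin))      # self.weight, channel_axis == 0
    gamma = rng.standard_normal(cout)              # BN scale
    moving_var = rng.random(cout) + 0.5            # BN moving variance
    eps = 1e-5

    running_std = np.sqrt(moving_var + eps)        # P.Sqrt()(P.TensorAdd()(var, eps))
    scale_factor = gamma / running_std             # one scale per output channel

    folded = weight * scale_factor[:, None]        # weight = self.weight * scale_factor
    # (the real cell fake-quantizes the folded weight here when self.fake is True)
    conv = x @ folded.T                            # conv = self.conv(x, weight)
    conv_orig = conv / scale_factor                # or conv * reciprocal(scale_factor)

    # The unfold recovers the plain convolution, which bn_train/bn_infer then consume.
    assert np.allclose(conv_orig, x @ weight.T)
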

@@ -419,12 +419,17 @@ class Conv2dBnFoldQuantOneConv(Cell):
         self.fake = fake
         self.quant_config = quant_config
         self.quant_dtype = quant_dtype
+        data_format = 'NCHW'
+        self.format = Validator.check_string(data_format, ['NCHW', 'NHWC'], 'format', self.cls_name)
         self.is_gpu = context.get_context('device_target') == "GPU"
         self.is_Ascend = context.get_context('device_target') == "Ascend"
+        self.is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
         if context.get_context("enable_ge"):
             self.is_ge_backend = True
         else:
             self.is_ge_backend = False
+        self.enable_default_train = self.is_graph_mode and \
+            (self.is_ge_backend or self.is_Ascend)
         # initialize convolution op and Parameter
         if context.get_context('device_target') == "Ascend" and group > 1:
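
The new enable_default_train flag (graph mode, on Ascend or with the GE backend) selects which BatchNorm path construct() takes in the later hunks. A rough plain-Python summary of that dispatch follows; the helper name and return labels are hypothetical and used only for illustration:

    # Hypothetical summary of the branch construct() takes; not a MindSpore API.
    def bn_branch(is_graph_mode, is_ge_backend, is_ascend, training):
        enable_default_train = is_graph_mode and (is_ge_backend or is_ascend)
        if not training:
            return "bn_infer"                      # use moving statistics only
        if enable_default_train:
            # fused training BN that also returns batch mean/variance, so the
            # moving statistics can be updated with AssignSub in the same graph
            return "bn_train + moving-stat update"
        return "bn_train"                          # plain path (e.g. GPU / PyNative)
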
@@ -448,6 +453,7 @@ class Conv2dBnFoldQuantOneConv(Cell):
                                group=group)
         weight_shape = [out_channels, in_channels // group, *self.kernel_size]
         channel_axis = 0
+        self.channel_axis = channel_axis
         self.weight = Parameter(initializer(weight_init, weight_shape), name='weight')
         self.bias_add = P.BiasAdd()
         if Validator.check_bool(has_bias):
@@ -490,7 +496,6 @@ class Conv2dBnFoldQuantOneConv(Cell):
         self.mul_var = P.Mul().shard(data_parallel_strategy_one)
         self.assign_sub_mean = P.AssignSub().shard(data_parallel_strategy)
         self.assign_sub_var = P.AssignSub().shard(data_parallel_strategy)
-        self.one = Tensor(1, mstype.int32)
         self.reshape = P.Reshape()

     def extend_repr(self):
@@ -507,18 +512,22 @@ class Conv2dBnFoldQuantOneConv(Cell):
     def construct(self, x):
         running_std = P.Sqrt()(P.TensorAdd()(self.moving_variance, self.eps))
         scale_factor = self.gamma / running_std
-        weight = self.weight * scale_factor
         if self.channel_axis:
             scale_factor = self.reshape(scale_factor, (1, -1, 1, 1))
         else:
             scale_factor = self.reshape(scale_factor, (-1, 1, 1, 1))
+        weight = self.weight * scale_factor
         if self.fake:
             weight = self.fake_quant_weight(weight)
         conv = self.conv(x, weight)
         scale_factor = self.reshape(scale_factor, (1, -1, 1, 1))
-        conv_orig = conv / scale_factor
+        if self.enable_default_train:
+            scale_factor = P.Reciprocal()(scale_factor)
+            conv_orig = conv * scale_factor
+        else:
+            conv_orig = conv / scale_factor
         if self.training:
-            if not self.is_gpu:
+            if self.enable_default_train:
                 out, batch_mean, batch_var, _, _ = self.bn_train(conv_orig,
                                                                  self.gamma,
                                                                  self.beta,
@@ -531,20 +540,19 @@ class Conv2dBnFoldQuantOneConv(Cell):
                 temp_variance = self.mul_var(mean_sub2, self.momentum)
                 out = F.depend(out, self.assign_sub_mean(self.moving_mean, temp_mean))
                 out = F.depend(out, self.assign_sub_var(self.moving_variance, temp_variance))
-            else:
-                out = self.bn_train(conv_orig,
-                                    self.gamma,
-                                    self.beta,
-                                    self.moving_mean,
-                                    self.moving_variance)[0]
-        else:
-            out = self.bn_infer(conv_orig,
-                                self.gamma,
-                                self.beta,
-                                self.moving_mean,
-                                self.moving_variance)[0]
-        return out
+                return out
+
+            return self.bn_train(conv_orig,
+                                 self.gamma,
+                                 self.beta,
+                                 self.moving_mean,
+                                 self.moving_variance)[0]
+
+        return self.bn_infer(conv_orig,
+                             self.gamma,
+                             self.beta,
+                             self.moving_mean,
+                             self.moving_variance)[0]


 class Conv2dBnFoldQuant(Cell):
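
On the default-train path above, the moving statistics are updated with Sub/Mul/AssignSub, which amounts to the usual exponential moving average. A small NumPy sketch of that update, assuming (as the surrounding code suggests) mean_sub = moving_mean - batch_mean and mean_sub2 = moving_variance - batch_var; the helper is a hypothetical stand-in, not MindSpore code:

    # AssignSub(moving, Mul(Sub(moving, batch), momentum)) is the usual EMA:
    # moving <- moving - momentum * (moving - batch)
    #        == (1 - momentum) * moving + momentum * batch
    import numpy as np

    def update_moving_stats(moving_mean, moving_var, batch_mean, batch_var, momentum):
        moving_mean -= (moving_mean - batch_mean) * momentum
        moving_var -= (moving_var - batch_var) * momentum
        return moving_mean, moving_var

    m = np.array([0.0, 1.0])
    v = np.array([1.0, 1.0])
    m, v = update_moving_stats(m, v, np.array([2.0, 2.0]), np.array([0.5, 2.0]), momentum=0.1)
    # each moving statistic is pulled 10% of the way toward its batch statistic
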
