From 59c4f31fc04c7019ff96a80442f06fd2f98f67a7 Mon Sep 17 00:00:00 2001 From: yangruoqi713 Date: Mon, 10 Aug 2020 09:15:41 +0800 Subject: [PATCH] [MS][LITE] fix bug of arm cpu int8 op: conv_depthwise --- .../arm/nnacl/int8/conv_depthwise_int8.cc | 33 +++++++++---------- mindspore/lite/tools/common/node_util.cc | 2 +- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_depthwise_int8.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_depthwise_int8.cc index 784b754cd8..11167a165b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_depthwise_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/int8/conv_depthwise_int8.cc @@ -159,18 +159,17 @@ void ConvDwInt8(int8_t *output_data, const int16_t *input_data, const int16_t *w if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_; int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_; - const int16_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * C4NUM; - int8_t *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * C4NUM; + const int16_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_; + int8_t *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; #ifdef ENABLE_ARM64 - ConvDwInt8Center( - out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, - conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(int8_t), - sliding->block_channel_ * sizeof(int8_t), sliding->in_sh_step_ * sizeof(int16_t), - sliding->in_sw_step_ * sizeof(int16_t), sliding->in_kh_step_ * sizeof(int16_t), - sliding->in_kw_step_ * sizeof(int16_t), conv_param->conv_quant_arg_.quant_multiplier_[0], - conv_param->conv_quant_arg_.left_shift_[0], conv_param->conv_quant_arg_.right_shift_[0], - conv_param->conv_quant_arg_.quant_args_[2][0].zp_, conv_param->conv_quant_arg_.out_act_min_[0], - conv_param->conv_quant_arg_.out_act_max_[0]); + ConvDwInt8Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(int8_t), + sliding->block_channel_ * sizeof(int8_t), sliding->in_sh_step_ * sizeof(int16_t), + sliding->in_sw_step_ * sizeof(int16_t), sliding->in_kh_step_ * sizeof(int16_t), + sliding->in_kw_step_ * sizeof(int16_t), conv_param->conv_quant_arg_.quant_multiplier_[0], + conv_param->conv_quant_arg_.left_shift_[0], conv_param->conv_quant_arg_.right_shift_[0], + conv_param->conv_quant_arg_.quant_args_[2][0].zp_, conv_param->conv_quant_arg_.out_act_min_[0], + conv_param->conv_quant_arg_.out_act_max_[0]); #else DepthwiseCenterInt8( out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, @@ -315,15 +314,15 @@ void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *in if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { int oh_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_h_; int oh_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_; - int32_t *out_t = output_buffer + oh_h_start * sliding->in_h_step_ + oh_w_start * C4NUM; + int32_t *out_t = output_buffer + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_; const int16_t *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; #ifdef ENABLE_ARM64 - DeconvDwInt8Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, - sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_, - sliding->out_h_step_ * sizeof(int16_t), sliding->block_channel_ * sizeof(int16_t), - sliding->in_sh_step_ * sizeof(int32_t), sliding->in_sw_step_ * sizeof(int32_t), - sliding->in_kh_step_ * sizeof(int32_t), sliding->in_kw_step_ * sizeof(int32_t)); + DeconvDwInt8Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(int16_t), + sliding->block_channel_ * sizeof(int16_t), sliding->in_sh_step_ * sizeof(int32_t), + sliding->in_sw_step_ * sizeof(int32_t), sliding->in_kh_step_ * sizeof(int32_t), + sliding->in_kw_step_ * sizeof(int32_t)); #else DeconvDepthwiseCenterInt8(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_, diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc index 2f3312a8cb..16c146052e 100644 --- a/mindspore/lite/tools/common/node_util.cc +++ b/mindspore/lite/tools/common/node_util.cc @@ -96,7 +96,7 @@ static const std::vector nhwcOpList = { schema::PrimitiveType_Conv2D, schema::PrimitiveType_DeConv2D, schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_DeDepthwiseConv2D, schema::PrimitiveType_Pooling, schema::PrimitiveType_Resize, - schema::PrimitiveType_FusedBatchNorm}; + schema::PrimitiveType_BatchNorm}; static const std::vector fp32FullOpList = { schema::PrimitiveType_Concat, schema::PrimitiveType_Add,