From 0e60235107640ece611fb8305ab38451edf3d054 Mon Sep 17 00:00:00 2001
From: zhanyuan
Date: Thu, 25 Mar 2021 13:27:25 +0800
Subject: [PATCH] [MSLITE] Fix the bug of matmul int8 for arm32

---
 mindspore/lite/nnacl/int8/matmul_int8.c         | 13 ++++++++++
 mindspore/lite/nnacl/int8/matmul_int8.h         |  1 +
 .../kernel/arm/int8/matmul_base_int8.cc         | 25 +++++++++++++++----
 .../kernel/arm/int8/matmul_base_int8.h          |  2 ++
 4 files changed, 36 insertions(+), 5 deletions(-)

diff --git a/mindspore/lite/nnacl/int8/matmul_int8.c b/mindspore/lite/nnacl/int8/matmul_int8.c
index 857f640001..ac80749c0c 100644
--- a/mindspore/lite/nnacl/int8/matmul_int8.c
+++ b/mindspore/lite/nnacl/int8/matmul_int8.c
@@ -32,6 +32,19 @@ void RowMajor2Row2x16MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int co
   }
 }
 
+void RowMajor2Col16x2MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
+  int row16 = UP_ROUND(row, C16NUM);
+  int stride = sizeof(int8_t) * C16NUM * C2NUM;
+  for (int r = 0; r < row; ++r) {
+    for (int c = 0; c < col; ++c) {
+      int stride_idx = c / C2NUM * (row16 / C16NUM) + r / C16NUM;
+      int dst_idx = stride * stride_idx + c % C2NUM * C16NUM + r % C16NUM;
+      int src_idx = r * col + c;
+      dst_ptr[dst_idx] = src_ptr[src_idx];
+    }
+  }
+}
+
 void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
   int col4 = UP_ROUND(col, C4NUM);
   for (int r = 0; r < row; r++) {
diff --git a/mindspore/lite/nnacl/int8/matmul_int8.h b/mindspore/lite/nnacl/int8/matmul_int8.h
index c10c1b6149..2389e98ed9 100644
--- a/mindspore/lite/nnacl/int8/matmul_int8.h
+++ b/mindspore/lite/nnacl/int8/matmul_int8.h
@@ -48,6 +48,7 @@ void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
 /* 4x16 16x2 -> 4x2 */
 /* arm32 conv1x1 */
 void RowMajor2Row2x16MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
+void RowMajor2Col16x2MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
 void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16,
                       size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
                       int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_base_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_base_int8.cc
index ac8fdad9b5..c711b130e3 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_base_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_base_int8.cc
@@ -33,7 +33,7 @@ int MatmulBaseInt8Run(void *cdata, int task_id) {
 }
 
 int MatmulBaseInt8CPUKernel::RunImpl(int task_id) {
-  int stride = thread_stride_ * C4NUM;
+  int stride = thread_stride_ * col_tile_;
   int cur_stride = task_id * stride;
   int res_stride = param_->col_ - cur_stride;
   int cur_oc = MSMIN(stride, res_stride);
@@ -155,16 +155,23 @@ void MatmulBaseInt8CPUKernel::InitQuantParam() {
 void MatmulBaseInt8CPUKernel::InitParameter() {
   param_->a_const_ = (in_tensors_[0]->data_c() != nullptr);
   param_->b_const_ = (in_tensors_[1]->data_c() != nullptr);
+#ifdef ENABLE_ARM32
+  row_tile_ = C4NUM;
+  col_tile_ = C2NUM;
+#else
+  row_tile_ = C4NUM;
+  col_tile_ = C4NUM;
+#endif
   return;
 }
 
 void MatmulBaseInt8CPUKernel::ResizeParameter() {
-  param_->row_align_ = UP_ROUND(param_->row_, C4NUM);
-  param_->col_align_ = UP_ROUND(param_->col_, C4NUM);
+  param_->row_align_ = UP_ROUND(param_->row_, row_tile_);
+  param_->col_align_ = UP_ROUND(param_->col_, col_tile_);
   param_->deep_16_ = UP_ROUND(param_->deep_, C16NUM);
 
-  thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(param_->col_align_, C4NUM));
-  thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, C4NUM), thread_count_);
+  thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(param_->col_align_, col_tile_));
+  thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, col_tile_), thread_count_);
   return;
 }
 
@@ -195,11 +202,19 @@ void MatmulBaseInt8CPUKernel::TransferB() {
     auto current_b_pack = pack_b_ptr_ + i * param_->col_align_ * param_->deep_16_;
     auto current_sums = weight_bias_sums_ + i * param_->col_align_;
     if (param_->b_transpose_) {
+#ifdef ENABLE_ARM32
+      RowMajor2Row2x16MajorInt8(current_weight, current_b_pack, param_->col_, param_->deep_);
+#else
       RowMajor2Row16x4MajorInt8(current_weight, current_b_pack, param_->col_, param_->deep_);
+#endif
       CalcWeightBiasSums(current_weight, param_->deep_, param_->col_, quant_.input_.zp_, quant_.filter_zp_, bias_ptr_,
                          current_sums, ColMajor, filter_per_channel_);
     } else {
+#ifdef ENABLE_ARM32
+      RowMajor2Col16x2MajorInt8(current_weight, current_b_pack, param_->deep_, param_->col_);
+#else
       RowMajor2Col16x4MajorInt8(current_weight, param_->deep_, param_->col_, current_b_pack);
+#endif
       CalcWeightBiasSums(current_weight, param_->deep_, param_->col_, quant_.input_.zp_, quant_.filter_zp_, bias_ptr_,
                          current_sums, RowMajor, false);
     }
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_base_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_base_int8.h
index ab96da216b..b62e7820ed 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/matmul_base_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/matmul_base_int8.h
@@ -75,6 +75,8 @@ class MatmulBaseInt8CPUKernel : public LiteKernel {
   int8_t *batch_b_ptr_ = nullptr;
   int8_t *batch_c_ptr_ = nullptr;
   int *batch_sums_ = nullptr;
+  int row_tile_ = C4NUM;
+  int col_tile_ = C4NUM;
 };
 }  // namespace mindspore::kernel
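
For context on what the new packing routine does, the following self-contained C sketch reproduces the 16x2 weight layout that RowMajor2Col16x2MajorInt8 builds for the arm32 MatMulInt8_4x2_r kernel. It is not part of the patch: the helper is renamed PackCol16x2, the UP_ROUND/C2NUM/C16NUM macros are re-defined locally on the assumption that they match nnacl's definitions, and main() is only an illustrative driver using a 20x3 matrix that needs padding in both dimensions.

/*
 * Standalone sketch mirroring the 16x2 packing added in this patch.
 * UP_ROUND, C2NUM and C16NUM are assumed to match nnacl's macros.
 */
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#define C2NUM 2
#define C16NUM 16
#define UP_ROUND(x, y) (((x) + (y) - 1) / (y) * (y))

static void PackCol16x2(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
  int row16 = UP_ROUND(row, C16NUM);
  int stride = C16NUM * C2NUM; /* one 16x2 tile holds 32 bytes */
  for (int r = 0; r < row; ++r) {
    for (int c = 0; c < col; ++c) {
      /* Tiles of a column pair are stored back to back over all row blocks,
       * then the next column pair follows; inside a tile the layout is
       * column-major: 16 rows of column 0, then 16 rows of column 1. */
      int stride_idx = c / C2NUM * (row16 / C16NUM) + r / C16NUM;
      int dst_idx = stride * stride_idx + c % C2NUM * C16NUM + r % C16NUM;
      dst_ptr[dst_idx] = src_ptr[r * col + c];
    }
  }
}

int main(void) {
  enum { kDeep = 20, kCol = 3 }; /* forces padding in both dimensions */
  int8_t src[kDeep * kCol];
  int8_t dst[UP_ROUND(kDeep, C16NUM) * UP_ROUND(kCol, C2NUM)];
  for (int i = 0; i < kDeep * kCol; ++i) {
    src[i] = (int8_t)i;
  }
  memset(dst, 0, sizeof(dst)); /* padded slots stay zero */
  PackCol16x2(src, dst, kDeep, kCol);
  /* Element (r = 17, c = 1): column pair 0, row block 1 -> tile 1, and within
   * that tile column 1, lane 17 % 16 = 1, i.e. dst index 32 * 1 + 16 + 1 = 49. */
  printf("src(17,1) = %d, dst[49] = %d\n", src[17 * kCol + 1], dst[49]);
  return 0;
}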