[MSLITE] Fix the int8 matmul bug for arm32

pull/14112/head
zhanyuan 4 years ago
parent 1aa16fb431
commit 0e60235107

@ -32,6 +32,19 @@ void RowMajor2Row2x16MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int co
}
}
// Packs a row-major [row x col] int8 matrix into 16x2 tiles for the arm32
// int8 matmul kernel. Tiles are laid out column-strip major: strip tc covers
// columns [tc*2, tc*2+2), and within a strip the 16-row tiles are stored
// top to bottom. Inside a tile, each of the 2 columns is 16 consecutive
// row entries. Rows are padded up to a multiple of C16NUM for addressing;
// the padding bytes themselves are never written by this function.
void RowMajor2Col16x2MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
  const int tiles_per_strip = UP_ROUND(row, C16NUM) / C16NUM;  // 16-row tiles in one column strip
  const int tile_bytes = sizeof(int8_t) * C16NUM * C2NUM;      // one 16x2 tile = 32 bytes
  for (int c = 0; c < col; ++c) {
    const int strip = c / C2NUM;
    const int col_in_tile = c % C2NUM;
    for (int r = 0; r < row; ++r) {
      const int tile_idx = strip * tiles_per_strip + r / C16NUM;
      const int offset_in_tile = col_in_tile * C16NUM + r % C16NUM;
      dst_ptr[tile_idx * tile_bytes + offset_in_tile] = src_ptr[r * col + c];
    }
  }
}
void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
int col4 = UP_ROUND(col, C4NUM);
for (int r = 0; r < row; r++) {

@ -48,6 +48,7 @@ void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
/* 4x16 16x2 -> 4x2 */
/* arm32 conv1x1 */
void RowMajor2Row2x16MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void RowMajor2Col16x2MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,

@ -33,7 +33,7 @@ int MatmulBaseInt8Run(void *cdata, int task_id) {
}
int MatmulBaseInt8CPUKernel::RunImpl(int task_id) {
int stride = thread_stride_ * C4NUM;
int stride = thread_stride_ * col_tile_;
int cur_stride = task_id * stride;
int res_stride = param_->col_ - cur_stride;
int cur_oc = MSMIN(stride, res_stride);
@ -155,16 +155,23 @@ void MatmulBaseInt8CPUKernel::InitQuantParam() {
// Records whether each input is a constant (its tensor data is already set)
// and selects the pack tile sizes for the target: arm32 uses the 4x2 int8
// kernel, every other target uses the 4x4 one.
void MatmulBaseInt8CPUKernel::InitParameter() {
  param_->a_const_ = (in_tensors_[0]->data_c() != nullptr);
  param_->b_const_ = (in_tensors_[1]->data_c() != nullptr);

  row_tile_ = C4NUM;  // row tiling is 4 on all targets
#ifdef ENABLE_ARM32
  col_tile_ = C2NUM;
#else
  col_tile_ = C4NUM;
#endif
}
// Aligns the matmul dimensions to the pack tile sizes chosen in
// InitParameter() and derives the column-wise thread partition.
// Fix: the rendered diff left the superseded C4NUM-based assignments in
// place ahead of the tile-parameterized ones; those stale duplicates are
// dead stores that contradict the arm32 fix and are removed here. Only the
// row_tile_/col_tile_ versions remain, so arm32 (col_tile_ == C2NUM) gets
// the correct alignment and thread stride.
void MatmulBaseInt8CPUKernel::ResizeParameter() {
  param_->row_align_ = UP_ROUND(param_->row_, row_tile_);
  param_->col_align_ = UP_ROUND(param_->col_, col_tile_);
  param_->deep_16_ = UP_ROUND(param_->deep_, C16NUM);

  // Threads split the output columns in units of col_tile_.
  thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(param_->col_align_, col_tile_));
  thread_stride_ = UP_DIV(UP_DIV(param_->col_align_, col_tile_), thread_count_);
}
@ -195,11 +202,19 @@ void MatmulBaseInt8CPUKernel::TransferB() {
auto current_b_pack = pack_b_ptr_ + i * param_->col_align_ * param_->deep_16_;
auto current_sums = weight_bias_sums_ + i * param_->col_align_;
if (param_->b_transpose_) {
#ifdef ENABLE_ARM32
RowMajor2Row2x16MajorInt8(current_weight, current_b_pack, param_->col_, param_->deep_);
#else
RowMajor2Row16x4MajorInt8(current_weight, current_b_pack, param_->col_, param_->deep_);
#endif
CalcWeightBiasSums(current_weight, param_->deep_, param_->col_, quant_.input_.zp_, quant_.filter_zp_, bias_ptr_,
current_sums, ColMajor, filter_per_channel_);
} else {
#ifdef ENABLE_ARM32
RowMajor2Col16x2MajorInt8(current_weight, current_b_pack, param_->deep_, param_->col_);
#else
RowMajor2Col16x4MajorInt8(current_weight, param_->deep_, param_->col_, current_b_pack);
#endif
CalcWeightBiasSums(current_weight, param_->deep_, param_->col_, quant_.input_.zp_, quant_.filter_zp_, bias_ptr_,
current_sums, RowMajor, false);
}

@ -75,6 +75,8 @@ class MatmulBaseInt8CPUKernel : public LiteKernel {
int8_t *batch_b_ptr_ = nullptr;
int8_t *batch_c_ptr_ = nullptr;
int *batch_sums_ = nullptr;
int row_tile_ = C4NUM;
int col_tile_ = C4NUM;
};
} // namespace mindspore::kernel

Loading…
Cancel
Save