diff --git a/mindspore/lite/nnacl/int8/conv_int8.c b/mindspore/lite/nnacl/int8/conv_int8.c index 9c0b2fc7fb..756bd13850 100644 --- a/mindspore/lite/nnacl/int8/conv_int8.c +++ b/mindspore/lite/nnacl/int8/conv_int8.c @@ -371,26 +371,265 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight } } -void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, - const int32_t *bias, int row, int col, int deep16, ConvParameter *conv_param, - MATMUL_OPT_R_FUNC matmul_func) { - if (matmul_func != NULL) { - matmul_func(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias, - conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_, - conv_param->conv_quant_arg_.quant_multiplier_, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, - conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], - (conv_param->conv_quant_arg_.filter_arg_num_ > 1)); +void Conv1x1PreOpt(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, + size_t output_channel, size_t plane_size, ConvParameter *conv_param) { + int ic4 = UP_ROUND(input_channel, C4NUM); + size_t hw_8div = plane_size / C8NUM * C8NUM; + size_t hw_8res = plane_size - hw_8div; + size_t ic_4div = input_channel / C4NUM * C4NUM; + int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_; + + if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) { + const int8_t *src_r = src_input; + int8_t *pack_r = packed_input; + /* per layer */ + for (int hwi = 0; hwi < hw_8div; hwi += C8NUM) { + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + int32_t *input_sum_r = input_sum + hwi; +#ifdef ENABLE_ARM64 + size_t src_stride = input_channel; + size_t ic_4res = input_channel - ic_4div; + asm volatile( + "dup v10.4s, wzr \n" + "dup v11.4s, wzr \n" + "mov x20, %[input_sum_r] \n" + "dup v20.4s, %w[filter_zp] \n" + + "mov x10, %[src_ic] \n" + "mov x11, %[pack_ic] \n" + + "mov x0, #0 \n" + "1: \n" + "cmp x0, %[ic_4div] \n" + "add x0, x0, #4\n" + "mov x12, x10 \n" + "add x10, x10, #4\n" + "blt 2f \n" + "cmp %[ic_4res], #0\n" + "beq 6f \n" + "cmp %[ic_4res], #1\n" + "beq 3f \n" + "cmp %[ic_4res], #2\n" + "beq 4f \n" + "cmp %[ic_4res], #3\n" + "beq 5f \n" + + "2: \n" + "ld1 {v0.s}[0], [x12], %[src_stride]\n" + "ld1 {v0.s}[1], [x12], %[src_stride]\n" + "ld1 {v0.s}[2], [x12], %[src_stride]\n" + "ld1 {v0.s}[3], [x12], %[src_stride]\n" + "ld1 {v1.s}[0], [x12], %[src_stride]\n" + "ld1 {v1.s}[1], [x12], %[src_stride]\n" + "ld1 {v1.s}[2], [x12], %[src_stride]\n" + "ld1 {v1.s}[3], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + + "add v10.4s, v10.4s, v0.4s \n" + "add v11.4s, v11.4s, v1.4s \n" + "b 1b \n" + + "3: \n" /* col res 1 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + + "ld1 {v0.b}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[8], [x12], %[src_stride]\n" + "ld1 {v0.b}[12], [x12], %[src_stride]\n" + "ld1 {v1.b}[0], [x12], %[src_stride]\n" + "ld1 {v1.b}[4], [x12], %[src_stride]\n" + "ld1 {v1.b}[8], [x12], %[src_stride]\n" + "ld1 {v1.b}[12], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v10.4s, v10.4s, v0.4s \n" + "add v11.4s, v11.4s, v1.4s \n" + "b 6f \n" + + "4: \n" /* col res 2 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v1.h}[0], [x12], %[src_stride]\n" + "ld1 {v1.h}[2], [x12], %[src_stride]\n" + "ld1 {v1.h}[4], [x12], %[src_stride]\n" + "ld1 {v1.h}[6], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v10.4s, v10.4s, v0.4s \n" + "add v11.4s, v11.4s, v1.4s \n" + "b 6f \n" + + "5: \n" /* col res 3 */ + "dup v0.4s, wzr \n" + "dup v1.4s, wzr \n" + "add x13, x12, #2 \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[2], [x13], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.b}[6], [x13], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[10], [x13], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v0.b}[14], [x13], %[src_stride]\n" + "ld1 {v1.h}[0], [x12], %[src_stride]\n" + "ld1 {v1.b}[2], [x13], %[src_stride]\n" + "ld1 {v1.h}[2], [x12], %[src_stride]\n" + "ld1 {v1.b}[6], [x13], %[src_stride]\n" + "ld1 {v1.h}[4], [x12], %[src_stride]\n" + "ld1 {v1.b}[10], [x13], %[src_stride]\n" + "ld1 {v1.h}[6], [x12], %[src_stride]\n" + "ld1 {v1.b}[14], [x13], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "st1 {v1.16b}, [x11], #16\n" + "saddlp v4.8h, v0.16b \n" + "saddlp v5.8h, v1.16b \n" + "saddlp v0.4s, v4.8h \n" + "saddlp v1.4s, v5.8h \n" + "add v10.4s, v10.4s, v0.4s \n" + "add v11.4s, v11.4s, v1.4s \n" + "b 6f \n" + + "6: \n" + "mul v10.4s, v10.4s, v20.4s \n" + "mul v11.4s, v11.4s, v20.4s \n" + + "st1 {v10.4s}, [x20], #16 \n" + "st1 {v11.4s}, [x20], #16 \n" + + : + : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r), + [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), + [ filter_zp ] "r"(filter_zp) + : "x0", "x1", "x10", "x11", "x12", "x13", "x20", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v11", + "v20"); +#else + int32_t tmp_sum_value[8] = {0}; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + for (int i = 0; i < C8NUM; i++) { + tmp_sum_value[i] += src_ic[0 + i * input_channel]; + tmp_sum_value[i] += src_ic[1 + i * input_channel]; + tmp_sum_value[i] += src_ic[2 + i * input_channel]; + tmp_sum_value[i] += src_ic[3 + i * input_channel]; + pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; + pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel]; + pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; + pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; + } + src_ic += C4NUM; + pack_ic += C4NUM * C8NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + for (int i = 0; i < C8NUM; i++) { + tmp_sum_value[i] += src_ic[i * input_channel]; + pack_ic[i * C4NUM] = src_ic[i * input_channel]; + } + src_ic += 1; + pack_ic += 1; + } + + for (int i = 0; i < C8NUM; i++) { + input_sum_r[i] = tmp_sum_value[i] * filter_zp; + } +#endif + src_r += input_channel * C8NUM; + pack_r += ic4 * C8NUM; + } + + if (hw_8div != plane_size) { + memset(pack_r, 0, C8NUM * ic4); + for (int hwi = hw_8div; hwi < plane_size; hwi += 1) { + int32_t tmp_sum_value = 0; + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + tmp_sum_value += src_ic[0]; + tmp_sum_value += src_ic[1]; + tmp_sum_value += src_ic[2]; + tmp_sum_value += src_ic[3]; + pack_ic[0] = src_ic[0]; + pack_ic[1] = src_ic[1]; + pack_ic[2] = src_ic[2]; + pack_ic[3] = src_ic[3]; + src_ic += C4NUM; + pack_ic += C4NUM * C8NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + tmp_sum_value += src_ic[0]; + pack_ic[0] = src_ic[0]; + src_ic += 1; + pack_ic += 1; + } + input_sum[hwi] = tmp_sum_value * filter_zp; + src_r += input_channel; + pack_r += C4NUM; + } + for (int hwi = plane_size; hwi < plane_size + hw_8res; hwi++) { + input_sum[hwi] = 0; + } + } } else { - MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias, - conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_, - conv_param->conv_quant_arg_.quant_multiplier_, - conv_param->conv_quant_arg_.output_quant_args_[0].zp_, - conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], - (conv_param->conv_quant_arg_.filter_arg_num_ > 1)); + /* per channel */ + RowMajor2Row4x8MajorInt8(src_input, packed_input, plane_size, input_channel); + PackInputSum8x4Int8(packed_input, input_sum, input_channel, output_channel, plane_size, conv_param); } return; } +void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, + const int32_t *bias, int row, int col, int deep4, ConvParameter *conv_param, + MATMUL_OPT_R_FUNC matmul_func) { + matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias, + conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_, + conv_param->conv_quant_arg_.quant_multiplier_, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], false); + return; +} + +void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, + const int32_t *bias, int row, int col, int deep16, ConvParameter *conv_param) { +#ifdef ENABLE_ARM64 + MatmulInt8Neon64(packed_input, packed_weight, dst, UP_ROUND(row, C4NUM), UP_ROUND(col, C4NUM), deep16, input_sum, + bias, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0], + conv_param->conv_quant_arg_.right_shift_[0], row, col, conv_param->output_channel_); +#else + MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias, + conv_param->conv_quant_arg_.left_shift_, conv_param->conv_quant_arg_.right_shift_, + conv_param->conv_quant_arg_.quant_multiplier_, + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0], + conv_param->conv_quant_arg_.out_act_max_[0], false); +#endif + return; +} + // int8 convolution 3x3 void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data, int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out, diff --git a/mindspore/lite/nnacl/int8/conv_int8.h b/mindspore/lite/nnacl/int8/conv_int8.h index 101978953c..5741ee3117 100644 --- a/mindspore/lite/nnacl/int8/conv_int8.h +++ b/mindspore/lite/nnacl/int8/conv_int8.h @@ -54,9 +54,13 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight ConvParameter *conv_param, GEMM_FUNC gemm_func); // int8 convolution 1x1 +void Conv1x1PreOpt(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel, + size_t output_channel, size_t plane_size, ConvParameter *conv_param); void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, - const int32_t *bias, int row, int col, int deep16, ConvParameter *conv_param, - MATMUL_OPT_R_FUNC matmul_func); + const int32_t *bias, int row, int col, int deep16, ConvParameter *conv_param); +void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, + const int32_t *bias, int row, int col, int deep4, ConvParameter *conv_param, + MATMUL_OPT_R_FUNC matmul_func); // int8 convolution 3x3 void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data, diff --git a/mindspore/lite/nnacl/int8/deconv.c b/mindspore/lite/nnacl/int8/deconv.c index 0317ce5ea8..2195f2728b 100644 --- a/mindspore/lite/nnacl/int8/deconv.c +++ b/mindspore/lite/nnacl/int8/deconv.c @@ -172,7 +172,7 @@ void DeConvPackWeightSum(int8_t *weight, int32_t *weight_sum, int32_t input_zp, void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16, bool suppport_opt) { /* optimize normal -> same layout */ - PackInputSum16x4PerLater(src, dst, filter_zp, row4, col16); + PackInputSum16x4PerLayer(src, dst, filter_zp, row4, col16); return; } diff --git a/mindspore/lite/nnacl/int8/matmul_int8.c b/mindspore/lite/nnacl/int8/matmul_int8.c index 13b8b8fdb9..1135cc5e09 100644 --- a/mindspore/lite/nnacl/int8/matmul_int8.c +++ b/mindspore/lite/nnacl/int8/matmul_int8.c @@ -36,7 +36,24 @@ void RowMajor2Row4x16MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int co for (int c = 0; c < col; c++) { int cd16 = c / C16NUM; int cm16 = c % C16NUM; - dst_ptr[cd16 * col16 * C4NUM + rd4 * C4NUM * C16NUM + rm4 * C16NUM + cm16] = src_ptr[r * col16 + c]; + int dst_index = rd4 * col16 * C4NUM + cd16 * C4NUM * C16NUM + rm4 * C16NUM + cm16; + int src_index = r * col + c; + dst_ptr[dst_index] = src_ptr[src_index]; + } + } +} + +void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { + int col4 = UP_ROUND(col, C4NUM); + for (int r = 0; r < row; r++) { + int rd8 = r / C8NUM; + int rm8 = r % C8NUM; + for (int c = 0; c < col; c++) { + int cd4 = c / C4NUM; + int cm4 = c % C4NUM; + int dst_index = rd8 * col4 * C8NUM + cd4 * C8NUM * C4NUM + rm8 * C4NUM + cm4; + int src_index = r * col + c; + dst_ptr[dst_index] = src_ptr[src_index]; } } } @@ -50,6 +67,29 @@ void MatrixPack4x16UnitInt8(int8_t *src, int8_t *dst, int row, int col, int stri return; } +void MatrixEmptyInt8(int8_t *dst, int row, int col) { + for (int r = 0; r < row; r++) { + int8_t *dst_r = dst + r * C16NUM; + memset(dst_r, 0, col * sizeof(int8_t)); + } + return; +} + +void RowMajor2Row4x8MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { + /* Row-major to row16x4-major (block row-major) */ + int col4 = UP_ROUND(col, C4NUM); + for (int r = 0; r < row; r++) { + int rd8 = r / C8NUM, rm8 = r % C8NUM; + for (int c = 0; c < col; c++) { + int cd4 = c / C4NUM, cm4 = c % C4NUM; + int src_index = r * col + c; + int dst_index = rd8 * col4 * C8NUM + cd4 * C4NUM * C8NUM + rm8 * C4NUM + cm4; + dst_ptr[dst_index] = src_ptr[src_index]; + } + } + return; +} + void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) { /* Row-major to row16x4-major (block row-major) */ int col16 = UP_ROUND(col, C16NUM); @@ -90,12 +130,15 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) { if (col != col_16div) { MatrixPack4x16UnitInt8(src_r + col_16div, dst_r + col_16div * C4NUM, C4NUM, col_16res, col); + MatrixEmptyInt8(dst_r + col_16div * C4NUM + col_16res, C4NUM, C16NUM - col_16res); } src_r += C4NUM * col; dst_r += C4NUM * col16; } if (row != row_4div) { + memset(dst_r, 0, C4NUM * col16); + for (int ci = 0; ci < col_16div; ci += C16NUM) { MatrixPack4x16UnitInt8(src_r + ci, dst_r + ci * C4NUM, row_4res, C16NUM, col); } @@ -172,6 +215,38 @@ void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row return; } +void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, + size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, + int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi, + bool per_channel) { + /* row8x4-major * row4x8-major => (int8)row-major */ + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r8div = r / C8NUM, r8mod = r % C8NUM; + int c8div = c / C8NUM, c8mod = c % C8NUM; + size_t ci = r * stride + c; + int32_t value = 0; + for (int d = 0; d < deep_4; d++) { + int d4div = d / C4NUM, d4mod = d % C4NUM; + size_t ai = r8div * deep_4 * C8NUM + d4div * C8NUM * C4NUM + r8mod * C4NUM + d4mod; + size_t bi = c8div * deep_4 * C8NUM + d4div * C8NUM * C4NUM + c8mod * C4NUM + d4mod; + value = value + a[ai] * b[bi]; + } + int32_t cur_input_sum = per_channel ? input_sum[c8div * UP_ROUND(row, C8NUM) + r * C8NUM + c8mod] : input_sum[r]; + value -= cur_input_sum; + value += bias[c]; + int32_t cur_left_shift = per_channel ? left_shift[c] : left_shift[0]; + int32_t cur_right_shift = per_channel ? right_shift[c] : right_shift[0]; + int32_t cur_multiplier = per_channel ? multiplier[c] : multiplier[0]; + value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + output_zp; + value = MSMIN(maxi, value); + value = MSMAX(mini, value); + dst[ci] = (int8_t)value; + } + } + return; +} + /* row4x16-major * col16x4-major => row4x4-major */ void MatmulInt8(const int8_t *a, const int8_t *b, int8_t *dst, const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, int multiplier, int left_shift, int right_shift, int row, int col, int deep16, diff --git a/mindspore/lite/nnacl/int8/matmul_int8.h b/mindspore/lite/nnacl/int8/matmul_int8.h index c9e1d01873..03028c49ec 100644 --- a/mindspore/lite/nnacl/int8/matmul_int8.h +++ b/mindspore/lite/nnacl/int8/matmul_int8.h @@ -35,6 +35,13 @@ void RowMajor2Row4x16MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int co void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col); void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col); +void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, + size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, + int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi, + bool per_channel); +void RowMajor2Row8x4MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col); +void RowMajor2Row4x8MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col); + void RowMajor2Row4x16Major(int8_t *src, int row, int col, int8_t *dst, int col_16); void RowMajor2Col16x4Major(int8_t *src, int row, int col, int8_t *dst, int row_16); void CalcInputSums(int8_t *input, int row, int col, int weight_zp, int *dst, DataOrder order); diff --git a/mindspore/lite/nnacl/matmul_parameter.h b/mindspore/lite/nnacl/matmul_parameter.h index 783d3631fe..7be90402c8 100644 --- a/mindspore/lite/nnacl/matmul_parameter.h +++ b/mindspore/lite/nnacl/matmul_parameter.h @@ -22,7 +22,7 @@ typedef void (*MATMUL_OPT_R4_FUNC)(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16, const int *input_sum, const int *bias); -typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16, +typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi, bool per_channel); @@ -35,11 +35,15 @@ typedef struct MatMulParameter { OpParameter op_parameter_; int row_; int col_; + int row_4_; int row_8_; int row_12_; int row_16_; + int col_4_; int col_8_; int deep_; + int deep_4_; + int deep_16_; bool has_bias_; int batch; bool a_transpose_; /* false : row-major */ diff --git a/mindspore/lite/nnacl/opt_op_handler.c b/mindspore/lite/nnacl/opt_op_handler.c index 52c7767eeb..f866081779 100644 --- a/mindspore/lite/nnacl/opt_op_handler.c +++ b/mindspore/lite/nnacl/opt_op_handler.c @@ -37,7 +37,7 @@ void IndirectGemmInt8_optimize_handler(int8_t *dst, const int8_t *src, const int size_t ksize, size_t ic4, size_t output_channel, size_t offset, const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after, - size_t asymmetric, size_t per_channel) { + size_t asymmetric, size_t per_channel) { return IndirectGemmInt8_24x4_dp(dst, src, weight, bias, ksize, ic4, output_channel, offset, input_sum, act_min, act_max, out_zp, out_multiplier, shift_before, shift_after, asymmetric, per_channel); } @@ -47,7 +47,7 @@ void MatMulR4Int8_optimize_handler(const int8_t *a, const int8_t *b, int *dst, i return MatMulOptR4Int8Neon64(a, b, dst, row4, col4, deep16, input_sum, bias); } -void MatMulRInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16, +void MatMulRInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi, bool per_channel) { diff --git a/mindspore/lite/nnacl/pack.c b/mindspore/lite/nnacl/pack.c index 3f511af347..fb103815e0 100644 --- a/mindspore/lite/nnacl/pack.c +++ b/mindspore/lite/nnacl/pack.c @@ -194,7 +194,7 @@ void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParam return; } -void PackInputSum16x4PerLater(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16) { +void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16) { /* optimize normal -> same layout */ #ifdef ENABLE_ARM64 asm volatile( @@ -267,12 +267,12 @@ void PackInputSum16x4PerLater(const int8_t *src, int32_t *dst, int32_t filter_zp return; } -void PackInputSum16x4Int8(int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel, +void PackInputSum16x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel, size_t plane_size, ConvParameter *conv_param) { size_t hw4 = UP_ROUND(plane_size, C4NUM); size_t ic16 = UP_ROUND(input_channel, C16NUM); if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) { - PackInputSum16x4PerLater(input_value, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16); + PackInputSum16x4PerLayer(input_value, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16); } else { for (int ri = 0; ri < plane_size; ri++) { int ri4div = ri / C4NUM, ri4mod = ri % C4NUM; @@ -293,6 +293,40 @@ void PackInputSum16x4Int8(int8_t *input_value, int32_t *input_sum, size_t input_ return; } +void PackInputSum8x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel, + size_t plane_size, ConvParameter *conv_param) { + size_t hw8 = UP_ROUND(plane_size, C8NUM); + size_t ic4 = UP_ROUND(input_channel, C4NUM); + if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) { + for (int r = 0; r < hw8; r++) { + int32_t tmp_value = 0; + for (int c = 0; c < ic4; c++) { + int r8div = r / C8NUM, r8mod = r % C8NUM, c4div = c / C4NUM, c4mod = c % C4NUM; + int src_index = r8div * C8NUM * ic4 + c4div * C8NUM * C4NUM + r8mod * C4NUM + c4mod; + tmp_value += input_value[src_index]; + } + input_sum[r] = tmp_value * conv_param->conv_quant_arg_.filter_quant_args_[0].zp_; + } + } else { + for (int ri = 0; ri < plane_size; ri++) { + int ri8div = ri / C8NUM, ri8mod = ri % C8NUM; + for (int ci = 0; ci < output_channel; ci++) { + int32_t tmp_sum_value = 0; + int ci8div = ci / C8NUM, ci8mod = ci % C8NUM; + int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[ci].zp_; + for (int di = 0; di < input_channel; di++) { + size_t di4div = di / C4NUM, di4mod = di % C4NUM; + int src_index = ri8div * C8NUM * ic4 + di4div * C8NUM * C4NUM + ri8mod * C4NUM + di4mod; + tmp_sum_value += input_value[src_index]; + } + int dst_index = ci8div * C8NUM * hw8 + ri * C8NUM + ci8mod; + input_sum[dst_index] = tmp_sum_value * filter_zp; + } + } + } + return; +} + void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, float *packed_input, int real_cal_num, int block_index) { // input format : nhwc diff --git a/mindspore/lite/nnacl/pack.h b/mindspore/lite/nnacl/pack.h index dff8d79558..b05083c52d 100644 --- a/mindspore/lite/nnacl/pack.h +++ b/mindspore/lite/nnacl/pack.h @@ -35,15 +35,18 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index, int32_t *input_sum, ConvParameter *conv_param); -void PackInputSum16x4PerLater(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16); +void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16); void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size); void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParameter *conv_param); -void PackInputSum16x4Int8(int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel, +void PackInputSum16x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel, size_t plane_size, ConvParameter *conv_param); +void PackInputSum8x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel, + size_t plane_size, ConvParameter *conv_param); + void MatrixPack(const float *src, float *dst, int row, int ic4, int stride); void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param); diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc index a64d94d1f8..39bc0ff65b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc @@ -22,6 +22,15 @@ using mindspore::lite::RET_MEMORY_FAILED; using mindspore::lite::RET_OK; namespace mindspore::kernel { +int Convolution1x1Int8Pre(void *cdata, int task_id) { + auto conv = reinterpret_cast(cdata); + auto error_code = conv->RunPre(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "conv1x1 Int8 RunPre error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; + } + return RET_OK; +} Convolution1x1Int8CPUKernel::~Convolution1x1Int8CPUKernel() { if (matmul_param_ != nullptr) { @@ -37,20 +46,16 @@ Convolution1x1Int8CPUKernel::~Convolution1x1Int8CPUKernel() { } void Convolution1x1Int8CPUKernel::FreeResizeBuf() { - if (packed_input_ != nullptr) { - free(packed_input_); - packed_input_ = nullptr; - } - if (input_sum_ != nullptr) { - free(input_sum_); - input_sum_ = nullptr; + if (pre_trans_input_ && input_ptr_ != nullptr) { + free(input_ptr_); + input_ptr_ = nullptr; } return; } void Convolution1x1Int8CPUKernel::CheckSupportOptimize() { - support_optimize_ = false; - matmul_func_ = MatMulInt8_16x4_r; + support_optimize_ = true; + matmul_func_ = MatMulInt8_8x8_r; #ifdef ENABLE_ARM64 void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_; if (optimize_op_handler != nullptr) { @@ -63,14 +68,13 @@ void Convolution1x1Int8CPUKernel::CheckSupportOptimize() { matmul_func_ = nullptr; } else { support_optimize_ = true; + matmul_func_ = MatMulInt8_8x8_r; } } else { support_optimize_ = false; matmul_func_ = nullptr; } #endif - - matmul_func_ = MatMulInt8_16x4_r; return; } @@ -80,24 +84,32 @@ int Convolution1x1Int8CPUKernel::InitWeightBias() { auto output_channel = filter_tensor->Batch(); /* weight */ - size_t size = UP_ROUND(input_channel, C16NUM) * UP_ROUND(output_channel, C4NUM) * sizeof(int8_t); + size_t size = support_optimize_ ? UP_ROUND(input_channel, C4NUM) * UP_ROUND(output_channel, C8NUM) * sizeof(int8_t) + : UP_ROUND(input_channel, C16NUM) * UP_ROUND(output_channel, C4NUM) * sizeof(int8_t); packed_weight_ = reinterpret_cast(malloc(size)); if (packed_weight_ == nullptr) { MS_LOG(ERROR) << "Conv1x1 int8 Malloc weight error!"; return RET_ERROR; } memset(packed_weight_, 0, size); - RowMajor2Row4x16MajorInt8(reinterpret_cast(filter_tensor->Data()), packed_weight_, output_channel, - input_channel); + if (support_optimize_) { + RowMajor2Row8x4MajorInt8(reinterpret_cast(filter_tensor->Data()), packed_weight_, output_channel, + input_channel); + } else { + RowMajor2Row4x16MajorInt8(reinterpret_cast(filter_tensor->Data()), packed_weight_, output_channel, + input_channel); + } /* bias = bias - v2 x zp1 + zp1 x zp2 */ int col4 = UP_ROUND(output_channel, C4NUM); - bias_data_ = malloc(col4 * sizeof(int32_t)); + int col8 = UP_ROUND(output_channel, C8NUM); + size = support_optimize_ ? col8 * sizeof(int32_t) : col4 * sizeof(int32_t); + bias_data_ = malloc(size); if (bias_data_ == nullptr) { MS_LOG(ERROR) << "Conv1x1 int8 Malloc bias_ptr_ error!"; return RET_ERROR; } - memset(bias_data_, 0, col4 * sizeof(int32_t)); + memset(bias_data_, 0, size); if (in_tensors_.size() == 3) { memcpy(bias_data_, in_tensors_[kBiasIndex]->Data(), output_channel * sizeof(int32_t)); } @@ -119,9 +131,6 @@ int Convolution1x1Int8CPUKernel::InitWeightBias() { } int Convolution1x1Int8CPUKernel::Init() { - if (!InferShapeDone()) { - return RET_OK; - } matmul_param_ = new (std::nothrow) MatMulParameter(); if (matmul_param_ == nullptr) { MS_LOG(ERROR) << "Init matmul_param_ failed."; @@ -142,6 +151,9 @@ int Convolution1x1Int8CPUKernel::Init() { return ret; } + if (!InferShapeDone()) { + return RET_OK; + } return ReSize(); } @@ -152,30 +164,52 @@ int Convolution1x1Int8CPUKernel::InitParam() { matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_; matmul_param_->deep_ = conv_param_->input_channel_; matmul_param_->col_ = conv_param_->output_channel_; - - thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C4NUM)); - thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C4NUM), thread_count_); - - size_t size = UP_ROUND(matmul_param_->row_, C4NUM) * UP_ROUND(matmul_param_->deep_, C16NUM); - packed_input_ = reinterpret_cast(malloc(size * sizeof(int8_t))); - if (packed_input_ == nullptr) { - MS_LOG(ERROR) << "conv1x1 int8 Malloc packed_input_ error!"; - return RET_ERROR; + matmul_param_->col_4_ = UP_ROUND(matmul_param_->col_, C4NUM); + matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM); + matmul_param_->row_4_ = UP_ROUND(matmul_param_->row_, C4NUM); + matmul_param_->row_8_ = UP_ROUND(matmul_param_->row_, C8NUM); + matmul_param_->deep_4_ = UP_ROUND(matmul_param_->deep_, C4NUM); + matmul_param_->deep_16_ = UP_ROUND(matmul_param_->deep_, C16NUM); + + /* init input sum size */ + if (support_optimize_) { + if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) { + input_sum_size = UP_ROUND(conv_param_->output_channel_, C8NUM) * UP_ROUND(matmul_param_->row_, C8NUM); + } else { + input_sum_size = UP_ROUND(matmul_param_->row_, C8NUM); + } + } else { + if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) { + input_sum_size = UP_ROUND(conv_param_->output_channel_, C4NUM) * UP_ROUND(matmul_param_->row_, C4NUM); + } else { + input_sum_size = UP_ROUND(matmul_param_->row_, C4NUM); + } } - memset(packed_input_, 0, size * sizeof(int8_t)); - if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) { - size = UP_ROUND(conv_param_->output_channel_, C4NUM) * UP_ROUND(matmul_param_->row_, C4NUM); + if (support_optimize_) { + thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C8NUM)); + thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C8NUM), thread_count_); } else { - size = UP_ROUND(matmul_param_->row_, C4NUM); + thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, C4NUM)); + thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, C4NUM), thread_count_); } - input_sum_ = reinterpret_cast(malloc(size * sizeof(int32_t))); - if (input_sum_ == nullptr) { - MS_LOG(ERROR) << "malloc input_sum_ failed."; - return RET_ERROR; + + if (support_optimize_) { + thread_count_hw_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, C8NUM)); + thread_stride_hw_ = UP_DIV(UP_DIV(matmul_param_->row_, C8NUM), thread_count_hw_); + } else { + thread_count_hw_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, C4NUM)); + thread_stride_hw_ = UP_DIV(UP_DIV(matmul_param_->row_, C4NUM), thread_count_hw_); } - memset(input_sum_, 0, size * sizeof(int32_t)); + if (pre_trans_input_) { + input_ptr_ = reinterpret_cast(malloc(matmul_param_->row_ * matmul_param_->deep_ * sizeof(int8_t))); + if (input_ptr_ == nullptr) { + MS_LOG(ERROR) << "Conv1x1 int8 Malloc input_ptr_ error!"; + return RET_MEMORY_FAILED; + } + memset(input_ptr_, 0, matmul_param_->row_ * matmul_param_->deep_ * sizeof(int8_t)); + } return RET_OK; } @@ -199,21 +233,54 @@ void Convolution1x1Int8CPUKernel::Pre1x1Trans(int8_t *src_input, int8_t *src_out } else { input_ptr_ = src_input; } - RowMajor2Row16x4MajorInt8(input_ptr_, packed_input_, matmul_param_->row_, matmul_param_->deep_); + + if (support_optimize_) { + ParallelLaunch(THREAD_POOL_DEFAULT, Convolution1x1Int8Pre, this, thread_count_hw_); + } else { + RowMajor2Row16x4MajorInt8(input_ptr_, packed_input_, matmul_param_->row_, matmul_param_->deep_); + PackInputSum16x4Int8(packed_input_, input_sum_, matmul_param_->deep_, matmul_param_->col_, matmul_param_->row_, + conv_param_); + } + return; } int Convolution1x1Int8CPUKernel::RunImpl(int task_id) { - int cur_oc = MSMIN(thread_stride_ * C4NUM, matmul_param_->col_ - task_id * thread_stride_ * C4NUM); - if (cur_oc <= 0) { - return RET_OK; + if (support_optimize_) { + int cur_stride = thread_stride_ * C8NUM; + int res_stride = matmul_param_->col_ - task_id * thread_stride_ * C8NUM; + int cur_oc = MSMIN(cur_stride, res_stride); + if (cur_oc <= 0) { + return RET_OK; + } + Conv1x1Int8Opt(packed_input_, packed_weight_ + task_id * thread_stride_ * C8NUM * matmul_param_->deep_4_, + output_ptr_ + task_id * thread_stride_ * C8NUM, input_sum_, + reinterpret_cast(bias_data_) + task_id * thread_stride_ * C8NUM, matmul_param_->row_, + cur_oc, matmul_param_->deep_4_, conv_param_, matmul_func_); + } else { + int cur_stride = thread_stride_ * C4NUM; + int res_stride = matmul_param_->col_ - task_id * thread_stride_ * C4NUM; + int cur_oc = MSMIN(cur_stride, res_stride); + if (cur_oc <= 0) { + return RET_OK; + } + Conv1x1Int8(packed_input_, packed_weight_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_16_, + output_ptr_ + task_id * thread_stride_ * C4NUM, input_sum_, + reinterpret_cast(bias_data_) + task_id * thread_stride_ * C4NUM, matmul_param_->row_, cur_oc, + matmul_param_->deep_16_, conv_param_); } + return RET_OK; +} - int32_t *bias = reinterpret_cast(bias_data_) + thread_stride_ * C4NUM * task_id; - - Conv1x1Int8(packed_input_, packed_weight_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_, - output_ptr_ + task_id * thread_stride_ * C4NUM, input_sum_, bias + task_id * thread_stride_ * C4NUM, - matmul_param_->row_, cur_oc, UP_ROUND(matmul_param_->deep_, C16NUM), conv_param_, matmul_func_); +int Convolution1x1Int8CPUKernel::RunPre(int task_id) { + int cur_hw = MSMIN(thread_stride_hw_ * C8NUM, matmul_param_->row_ - task_id * thread_stride_hw_ * C8NUM); + if (cur_hw <= 0) { + return RET_OK; + } + Conv1x1PreOpt(input_ptr_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_, + packed_input_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_4_, + input_sum_ + task_id * thread_stride_hw_ * C8NUM, matmul_param_->deep_, matmul_param_->col_, cur_hw, + conv_param_); return RET_OK; } @@ -227,6 +294,35 @@ int Convolution1x1Int8Impl(void *cdata, int task_id) { return RET_OK; } +int Convolution1x1Int8CPUKernel::InitRunBuf() { + input_sum_ = reinterpret_cast(malloc(input_sum_size * sizeof(int32_t))); + if (input_sum_ == nullptr) { + MS_LOG(ERROR) << "malloc input_sum_ failed."; + return RET_ERROR; + } + + size_t size = support_optimize_ ? UP_ROUND(matmul_param_->row_, C8NUM) * UP_ROUND(matmul_param_->deep_, C4NUM) + : UP_ROUND(matmul_param_->row_, C4NUM) * UP_ROUND(matmul_param_->deep_, C16NUM); + packed_input_ = reinterpret_cast(ctx_->allocator->Malloc(size * sizeof(int8_t))); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "conv1x1 int8 Malloc packed_input_ error!"; + return RET_ERROR; + } + return RET_OK; +} + +void Convolution1x1Int8CPUKernel::FreeRunBuf() { + if (packed_input_ != nullptr) { + ctx_->allocator->Free(packed_input_); + packed_input_ = nullptr; + } + if (input_sum_ != nullptr) { + ctx_->allocator->Free(input_sum_); + input_sum_ = nullptr; + } + return; +} + int Convolution1x1Int8CPUKernel::Run() { auto ret = Prepare(); if (ret != RET_OK) { @@ -234,13 +330,10 @@ int Convolution1x1Int8CPUKernel::Run() { return RET_ERROR; } - if (pre_trans_input_) { - input_ptr_ = - reinterpret_cast(ctx_->allocator->Malloc(matmul_param_->row_ * matmul_param_->deep_ * sizeof(int8_t))); - if (input_ptr_ == nullptr) { - MS_LOG(ERROR) << "Conv1x1 int8 Malloc input_ptr_ error!"; - return RET_MEMORY_FAILED; - } + int error_code = InitRunBuf(); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "conv1x1 int8 InitRunBuf error_code[" << error_code << "]"; + return RET_ERROR; } int8_t *src_in = reinterpret_cast(in_tensors_[0]->Data()); @@ -249,21 +342,10 @@ int Convolution1x1Int8CPUKernel::Run() { for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) { Pre1x1Trans(src_in + batch_index * conv_param_->input_h_ * conv_param_->input_w_ * conv_param_->input_channel_, src_out + batch_index * matmul_param_->row_ * matmul_param_->col_); - - PackInputSum16x4Int8(packed_input_, input_sum_, matmul_param_->deep_, matmul_param_->col_, matmul_param_->row_, - conv_param_); - - int error_code = ParallelLaunch(THREAD_POOL_DEFAULT, Convolution1x1Int8Impl, this, thread_count_); - if (error_code != RET_OK) { - MS_LOG(ERROR) << "conv1x1 fp16 error error_code[" << error_code << "]"; - return RET_ERROR; - } + ParallelLaunch(THREAD_POOL_DEFAULT, Convolution1x1Int8Impl, this, thread_count_); } - if (pre_trans_input_ && input_ptr_ != nullptr) { - ctx_->allocator->Free(input_ptr_); - input_ptr_ = nullptr; - } + FreeRunBuf(); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h index f3e2018852..6ffd5aa4a7 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h @@ -40,8 +40,13 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel { int ReSize() override; int Run() override; + private: + int InitRunBuf(); + void FreeRunBuf(); + public: int RunImpl(int task_id); + int RunPre(int task_id); private: void FreeResizeBuf(); @@ -58,7 +63,10 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel { int8_t *output_ptr_ = nullptr; size_t thread_count_ = 1; size_t thread_stride_ = 0; + size_t thread_count_hw_ = 1; + size_t thread_stride_hw_ = 0; bool pre_trans_input_ = false; + size_t input_sum_size = 0; MatMulParameter *matmul_param_ = nullptr; MATMUL_OPT_R_FUNC matmul_func_ = nullptr; bool support_optimize_ = false; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc index 7b94444ae4..f0c5f33da5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_int8.cc @@ -398,11 +398,11 @@ kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vectordilation_h_; int dilation_w = conv_param->dilation_w_; kernel::LiteKernel *kernel; + auto filter_quant_size = inputs[kWeightIndex]->GetQuantParams().size(); if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive); - } else if (kernel_h == 1 && kernel_w == 1) { - /* Convolution1x1Int8CPUKernel */ - kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); + } else if (kernel_h == 1 && kernel_w == 1 && filter_quant_size == 1) { + kernel = new (std::nothrow) kernel::Convolution1x1Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive); } else { kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); }