From b1dbfa643b129c2296a9542cd8b96641066e1de6 Mon Sep 17 00:00:00 2001
From: yangruoqi713
Date: Thu, 22 Oct 2020 15:17:29 +0800
Subject: [PATCH] [MSLITE][Develop] optimize arm cpu int8 op conv dw 3x3: add assembly arm32

---
Note: this copy was re-flowed from a source whose line breaks and indentation
had been lost. Indentation and blank context lines are reconstructed, so apply
it with whitespace tolerance (e.g. `git apply --ignore-whitespace`). The blob
ids on the `index` lines are those of the original commit and do not reflect
the stack-pointer and scratch-register fixes made to the two assembly files.

 .../assembly/arm32/ConvDw3x3BorderPixelInt8.S | 116 +++++++++++
 .../assembly/arm64/ConvDw3x3BorderPixelInt8.S | 180 ++++++++++++------
 mindspore/lite/nnacl/int8/common_func.h       |   8 +-
 .../lite/nnacl/int8/conv_depthwise_int8.c     |   9 +-
 .../arm/int8/convolution_depthwise_int8.cc    |  16 +-
 5 files changed, 259 insertions(+), 70 deletions(-)
 create mode 100644 mindspore/lite/nnacl/assembly/arm32/ConvDw3x3BorderPixelInt8.S

diff --git a/mindspore/lite/nnacl/assembly/arm32/ConvDw3x3BorderPixelInt8.S b/mindspore/lite/nnacl/assembly/arm32/ConvDw3x3BorderPixelInt8.S
new file mode 100644
index 0000000000..1913b1a8e8
--- /dev/null
+++ b/mindspore/lite/nnacl/assembly/arm32/ConvDw3x3BorderPixelInt8.S
@@ -0,0 +1,116 @@
+#ifdef __arm__
+#ifndef __aarch64__
+
+.text
+.align 5
+.global ConvDw3x3BorderPixelInt8
+#ifndef __APPLE__
+.type ConvDw3x3BorderPixelInt8, %function
+#endif
+
+// void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height,
+//                               size_t width, size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp,
+//                               size_t out_multiplier, size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max) {
+
+// r0: dst, r1: src, r2: weight, r3: bias; every other argument is passed on the stack, 4 bytes each, in this order:
+// height, width, in_kh_step, in_kw_step, channel, in_zp, out_zp, out_multiplier,
+// left_shift, right_shift, acc_min, acc_max
+ConvDw3x3BorderPixelInt8:
+    // at return, clang generates "push {lr}, pop {pc}" while gcc will generate "bx lr"
+    // according to https://stackoverflow.com/questions/53625807
+    // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway
+    // clang's rule seems more simple, though there are no subroutine calls here
+    // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf
+
+    push {r4-r8, r9-r12, lr}
+    vpush {q4-q7}
+    // 10 core registers (40 bytes) + q4-q7 (64 bytes) are now below the stack arguments, which start at sp + 104
+
+    ldr r4, [sp, #104]    // height
+    ldr r5, [sp, #108]    // width
+    ldr r6, [sp, #112]    // in_kh_step
+    ldr r7, [sp, #116]    // in_kw_step
+    ldr r8, [sp, #120]    // channel
+
+    ldrb r10, [sp, #124]  // in_zp
+    vdup.8 d18, r10
+    ldr r10, [sp, #128]   // out_zp
+    vdup.32 q15, r10
+    ldr r10, [sp, #132]   // out_multiplier
+    vdup.32 q14, r10
+    ldr r10, [sp, #136]   // left_shift
+    vdup.32 q13, r10
+    ldr r10, [sp, #140]   // right_shift
+    vdup.32 q12, r10
+    ldr r10, [sp, #144]   // acc_min
+    vdup.32 q11, r10
+    ldr r10, [sp, #148]   // acc_max
+    vdup.32 q10, r10
+
+    mov r4, #2
+    mul lr, r8, r4        // lr = channel * sizeof(int16_t): byte step between two weight taps
+
+    LoopC:
+        mov r9, r1
+        mov r10, r2
+        ldr r4, [sp, #104]    // reload height for this group of 8 channels
+
+        vld1.32 {q3}, [r3]!   // accumulators start from the bias of channels 0-3
+        vld1.32 {q4}, [r3]!   // ... and of channels 4-7
+        LoopH:
+            mov r11, r9
+            mov r12, r10
+            ldr r5, [sp, #108]    // reload width for this row
+            LoopW:
+                vld1.8 {d0}, [r11], r7
+                vld1.16 {d2, d3}, [r12], lr    // weight
+                vsubl.s8 q2, d0, d18    // -zp
+
+                vmlal.s16 q3, d4, d2
+                vmlal.s16 q4, d5, d3
+
+                subs r5, r5, #1
+                bne LoopW
+            subs r4, r4, #1       // the add/mov/mul below do not touch the flags tested by bne LoopH
+            add r9, r9, r6
+            mov r11, #3
+            mul r5, lr, r11       // 3 taps per weight row
+            add r10, r10, r5
+            bne LoopH
+
+        vshl.s32 q3, q3, q13
+        vqrdmulh.s32 q3, q3, q14
+        vand q5, q3, q12
+        vshr.s32 q5, q5, #31
+        vqadd.s32 q3, q3, q5
+        vrshl.s32 q3, q3, q12
+        vadd.i32 q3, q3, q15
+        vmax.s32 q3, q3, q11
+        vmin.s32 q3, q3, q10
+        vqmovn.s32 d14, q3
+
+        vshl.s32 q4, q4, q13
+        vqrdmulh.s32 q4, q4, q14
+        vand q6, q4, q12
+        vshr.s32 q6, q6, #31
+        vqadd.s32 q4, q4, q6
+        vrshl.s32 q4, q4, q12
+        vadd.i32 q4, q4, q15
+        vmax.s32 q4, q4, q11
+        vmin.s32 q4, q4, q10
+        vqmovn.s32 d15, q4
+        vqmovn.s16 d16, q7    // q7 = d14:d15, narrowed to 8 int8 outputs
+
+        vst1.8 {d16}, [r0]!
+        add r1, r1, #8
+        add r2, r2, #16
+
+        sub r8, r8, #8
+        cmp r8, #8
+        bge LoopC
+
+    // sp still points at the saved registers, so they can be restored directly
+    vpop {q4-q7}
+    pop {r4-r8, r9-r12, pc}
+#endif
+#endif
diff --git a/mindspore/lite/nnacl/assembly/arm64/ConvDw3x3BorderPixelInt8.S b/mindspore/lite/nnacl/assembly/arm64/ConvDw3x3BorderPixelInt8.S
index c1985a4c5c..80132ac658 100644
--- a/mindspore/lite/nnacl/assembly/arm64/ConvDw3x3BorderPixelInt8.S
+++ b/mindspore/lite/nnacl/assembly/arm64/ConvDw3x3BorderPixelInt8.S
@@ -41,66 +41,128 @@ ConvDw3x3BorderPixelInt8:
     mul x14, x13, x9    // x8 * 3 * 2

     LoopC:
+        ld1 {v23.4s}, [x3], #16
+        ld1 {v24.4s}, [x3], #16
+
         mov x9, x1
         mov x10, x2
-        mov x17, x4    // height
-
-        ld1 {v5.4s}, [x3], #16
-        mov v3.16b, v5.16b
-        ld1 {v6.4s}, [x3], #16
-        mov v4.16b, v6.16b
-        LoopH:
-            mov x11, x9
-            mov x12, x10
-            mov x18, x5    // width
-            LoopW:
-                ld1 {v0.8b}, [x11], x7
-                ssubl v1.8h, v0.8b, v25.8b
-
-                ld1 {v2.8h}, [x12], x13    // weight
-                smlal v3.4s, v1.4h, v2.4h
-                smlal2 v4.4s, v1.8h, v2.8h
-
-                subs x18, x18, #1
-                bne LoopW
-            subs x17, x17, #1
-            add x9, x9, x6
-            add x10, x10, x14
-            bne LoopH
-
-        sqshl v3.4s, v3.4s, v28.4s
-        sqshl v4.4s, v4.4s, v28.4s
-        sqrdmulh v3.4s, v3.4s, v27.4s
-        sqrdmulh v4.4s, v4.4s, v27.4s
-
-        and v12.16b, v29.16b, v3.16b
-        sshr v12.4s, v12.4s, #31
-        sqadd v3.4s, v3.4s, v12.4s
-        srshl v3.4s, v3.4s, v29.4s
-
-        and v11.16b, v29.16b, v4.16b
-        sshr v11.4s, v11.4s, #31
-        sqadd v4.4s, v4.4s, v11.4s
-        srshl v4.4s, v4.4s, v29.4s
-
-        add v3.4s, v3.4s, v26.4s
-        add v4.4s, v4.4s, v26.4s
-        smax v3.4s, v3.4s, v30.4s
-        smax v4.4s, v4.4s, v30.4s
-        smin v3.4s, v3.4s, v31.4s
-        smin v4.4s, v4.4s, v31.4s
-
-        sqxtn v3.4h, v3.4s
-        sqxtn v4.4h, v4.4s
-        sqxtn v3.8b, v3.8h
-        sqxtn v4.8b, v4.8h
-
-        st1 {v3.s}[0], [x0], #4
-        st1 {v4.s}[0], [x0], #4
-        add x1, x1, #8
-        add x2, x2, #16
-        sub x8, x8, #8
-        cmp x8, #8
-        bge LoopC
+        cmp x4, #2
+        blt LoopHW
+        LoopH2W2:
+            cmp x5, #2
+            blt LoopHW
+            ld1 {v0.8b}, [x9], x7
+            ssubl v0.8h, v0.8b, v25.8b
+            add x11, x1, x6
+            ld1 {v4.8h}, [x10], x13    // weight
+            smlal v23.4s, v0.4h, v4.4h
+            smlal2 v24.4s, v0.8h, v4.8h
+            add x12, x2, x14
+            ld1 {v1.8b}, [x9], x7
+            ssubl v1.8h, v1.8b, v25.8b
+            ld1 {v5.8h}, [x10], x13
+            smlal v23.4s, v1.4h, v5.4h
+            smlal2 v24.4s, v1.8h, v5.8h
+            add x15, x11, x6
+            ld1 {v2.8b}, [x11], x7
+            ssubl v2.8h, v2.8b, v25.8b
+            add x16, x12, x14
+            ld1 {v6.8h}, [x12], x13
+            smlal v23.4s, v2.4h, v6.4h
+            smlal2 v24.4s, v2.8h, v6.8h
+            ld1 {v3.8b}, [x11], x7
+            ssubl v3.8h, v3.8b, v25.8b
+            ld1 {v7.8h}, [x12], x13
+            smlal v23.4s, v3.4h, v7.4h
+            smlal2 v24.4s, v3.8h, v7.8h
+            cmp x5, #3
+            beq LoopH2W3
+            cmp x4, #3
+            beq LoopH3W2
+            b Post
+
+        LoopH2W3:
+            ld1 {v16.8b}, [x9], x7
+            ssubl v16.8h, v16.8b, v25.8b
+            ld1 {v17.8h}, [x10], x13
+            smlal v23.4s, v16.4h, v17.4h
+            smlal2 v24.4s, v16.8h, v17.8h
+            ld1 {v18.8b}, [x11], x7
+            ssubl v18.8h, v18.8b, v25.8b
+            ld1 {v19.8h}, [x12], x13
+            smlal v23.4s, v18.4h, v19.4h
+            smlal2 v24.4s, v18.8h, v19.8h
+            b Post
+
+        LoopH3W2:
+            ld1 {v16.8b}, [x15], x7
+            ssubl v16.8h, v16.8b, v25.8b
+            ld1 {v17.8h}, [x16], x13
+            smlal v23.4s, v16.4h, v17.4h
+            smlal2 v24.4s, v16.8h, v17.8h
+            ld1 {v18.8b}, [x15], x7
+            ssubl v18.8h, v18.8b, v25.8b
+            ld1 {v19.8h}, [x16], x13
+            smlal v23.4s, v18.4h, v19.4h
+            smlal2 v24.4s, v18.8h, v19.8h
+            b Post
+
+        LoopHW:
+            mov x9, x1
+            mov x10, x2
+            mov x17, x4    // height
+            LoopH:
+                mov x11, x9
+                mov x12, x10
+                mov x15, x5    // width (x15 is free on this path; x18 is the reserved platform register)
+                LoopW:
+                    ld1 {v0.8b}, [x11], x7
+                    ssubl v1.8h, v0.8b, v25.8b
+
+                    ld1 {v2.8h}, [x12], x13    // weight
+                    smlal v23.4s, v1.4h, v2.4h
+                    smlal2 v24.4s, v1.8h, v2.8h
+
+                    subs x15, x15, #1
+                    bne LoopW
+                subs x17, x17, #1
+                add x9, x9, x6
+                add x10, x10, x14
+                bne LoopH
+        Post:
+            sqshl v23.4s, v23.4s, v28.4s
+            sqshl v24.4s, v24.4s, v28.4s
+            sqrdmulh v23.4s, v23.4s, v27.4s
+            sqrdmulh v24.4s, v24.4s, v27.4s
+
+            and v16.16b, v29.16b, v23.16b    // v16/v17 as scratch: v8-v15 are callee-saved and never restored here
+            sshr v16.4s, v16.4s, #31
+            sqadd v23.4s, v23.4s, v16.4s
+            srshl v23.4s, v23.4s, v29.4s
+
+            and v17.16b, v29.16b, v24.16b
+            sshr v17.4s, v17.4s, #31
+            sqadd v24.4s, v24.4s, v17.4s
+            srshl v24.4s, v24.4s, v29.4s
+
+            add v23.4s, v23.4s, v26.4s
+            add v24.4s, v24.4s, v26.4s
+            smax v23.4s, v23.4s, v30.4s
+            smax v24.4s, v24.4s, v30.4s
+            smin v23.4s, v23.4s, v31.4s
+            smin v24.4s, v24.4s, v31.4s
+
+            sqxtn v23.4h, v23.4s
+            sqxtn v24.4h, v24.4s
+            sqxtn v23.8b, v23.8h
+            sqxtn v24.8b, v24.8h
+
+            st1 {v23.s}[0], [x0], #4
+            st1 {v24.s}[0], [x0], #4
+            add x1, x1, #8
+            add x2, x2, #16
+            sub x8, x8, #8
+            cmp x8, #8
+            bge LoopC
     ret
 #endif
diff --git a/mindspore/lite/nnacl/int8/common_func.h b/mindspore/lite/nnacl/int8/common_func.h
index 77acc93aef..2819049739 100644
--- a/mindspore/lite/nnacl/int8/common_func.h
+++ b/mindspore/lite/nnacl/int8/common_func.h
@@ -47,6 +47,10 @@ void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, con
                       size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int8_t *in_zp,
                       int32_t *out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift,
                       int32_t *acc_min, int32_t *acc_max);
+void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height,
+                              size_t width, size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp,
+                              size_t out_zp, size_t out_multiplier, size_t left_shift, size_t right_shift,
+                              size_t acc_min, size_t acc_max);
 #endif

 #ifdef ENABLE_ARM32
@@ -67,10 +71,6 @@ void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *wei
 void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width,
                         size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
                         size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
-void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height,
-                              size_t width, size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp,
-                              size_t out_zp, size_t out_multiplier, size_t left_shift, size_t right_shift,
-                              size_t acc_min, size_t acc_max);
 #endif
 #ifdef __cplusplus
 }
diff --git a/mindspore/lite/nnacl/int8/conv_depthwise_int8.c b/mindspore/lite/nnacl/int8/conv_depthwise_int8.c
index 5f92bafe59..e1046d3eeb 100644
--- a/mindspore/lite/nnacl/int8/conv_depthwise_int8.c
+++ b/mindspore/lite/nnacl/int8/conv_depthwise_int8.c
@@ -140,11 +140,8 @@ void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_da

 /*conv depthwise 3x3 int8 begin*/
 bool CheckIfUse3X3(const ConvParameter *conv_param, int channel) {
-  bool use_3x3 = conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 &&
-                 (conv_param->stride_h_ == 1 || conv_param->stride_h_ == 2) &&
-                 (conv_param->stride_w_ == 1 || conv_param->stride_w_ == 2) &&
-                 conv_param->stride_h_ == conv_param->stride_w_ &&
-                 (conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) &&
+  bool use_3x3 = conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && conv_param->stride_h_ == 1 &&
+                 conv_param->stride_w_ == 1 && (conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) &&
                  (conv_param->pad_l_ == 0 || conv_param->pad_l_ == 1) && conv_param->pad_u_ == conv_param->pad_l_ &&
                  conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && (channel % 8 == 0);
   return use_3x3;
@@ -303,7 +300,7 @@ void ConvDw3x3Int8(int8_t *output_data, int8_t *buffer, const int8_t *input_data
   }
 }

-#ifndef ENABLE_ARM64
+#ifndef ENABLE_ARM
 void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height,
                               int width, int in_kh_step, int in_kw_step, int channel, int8_t in_zp, int32_t out_zp,
                               int out_multiplier, int left_shift, int right_shift, int32_t acc_min, int32_t acc_max) {
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc
index be919fb0b9..706f286524 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc
@@ -172,7 +172,21 @@ kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector<lite::Tensor *>
   auto act_quant_size =
     MSMAX(inputs[kInputIndex]->GetQuantParams().size(), outputs[kOutputIndex]->GetQuantParams().size());
   if (act_quant_size == 1) {  // per tensor
-    kernel = new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
+    auto conv_parm = reinterpret_cast<ConvParameter *>(opParameter);
+    auto channel = inputs[kWeightIndex]->shape()[0];
+    auto weight_quant_size = inputs[kWeightIndex]->GetQuantParams().size();
+    if (CheckIfUse3X3(conv_parm, channel) && weight_quant_size == 1) {
+#ifdef ENABLE_ARM64
+      kernel =
+        new (std::nothrow) kernel::ConvolutionDepthwise3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
+#else
+      kernel =
+        new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
+#endif
+    } else {
+      kernel =
+        new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
+    }
   } else {  // per channel
     kernel =
       new (std::nothrow) kernel::ConvolutionDepthwiseSWInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);