From 4f58c85d94fb791554e5def3ca542a62eaa83e7a Mon Sep 17 00:00:00 2001
From: yangruoqi713
Date: Sat, 12 Sep 2020 17:15:07 +0800
Subject: [PATCH] [MSLITE][Develop] arm cpu int8 conv depthwise support arm32

---
 .../assembly/arm32/ConvDwInt8PostAlign4.S     | 110 ++++++++++++++
 .../lite/nnacl/assembly/arm32/ConvDwInt8Row.S | 134 ++++++++++++++++++
 mindspore/lite/nnacl/int8/common_func.h       |  11 +-
 .../lite/nnacl/int8/conv_depthwise_int8.c     |   4 +-
 4 files changed, 253 insertions(+), 6 deletions(-)
 create mode 100644 mindspore/lite/nnacl/assembly/arm32/ConvDwInt8PostAlign4.S
 create mode 100644 mindspore/lite/nnacl/assembly/arm32/ConvDwInt8Row.S

diff --git a/mindspore/lite/nnacl/assembly/arm32/ConvDwInt8PostAlign4.S b/mindspore/lite/nnacl/assembly/arm32/ConvDwInt8PostAlign4.S
new file mode 100644
index 0000000000..b9d0e9b92a
--- /dev/null
+++ b/mindspore/lite/nnacl/assembly/arm32/ConvDwInt8PostAlign4.S
@@ -0,0 +1,110 @@
+#ifdef __arm__
+#ifndef __aarch64__
+
+.text
+.align 5
+.global ConvDwInt8PostAlign4
+#ifndef __APPLE__
+.type ConvDwInt8PostAlign4, %function
+#endif
+
+// void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier,
+//                           int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max);
+// r0: dst, r1: buffer, r2: num_pixels, r3: output_zp; loaded from the stack into r4: out_multiplier,
+// r5: left_shift, r6: right_shift, r7: acc_min, r8: acc_max
+
+ConvDwInt8PostAlign4:
+    // at return, clang generates "push {lr}, pop {pc}" while gcc will generate "bx lr"
+    // according to https://stackoverflow.com/questions/53625807
+    // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway
+    // clang's rule seems more simple, though there are no subroutine calls here
+    // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf
+    push {r4-r8, r10}
+    vpush {q4-q7}
+    // sp stays below the saved registers; stack arguments start 24 (push) + 64 (vpush) = 88 bytes above it
+
+    vdup.32 q15, r3        // output_zp
+
+    ldr r4, [sp, #88]      // out_multiplier
+    vdup.32 q14, r4
+
+    ldr r5, [sp, #92]      // left_shift
+    vdup.32 q13, r5
+
+    ldr r6, [sp, #96]      // right_shift
+    vdup.32 q12, r6
+
+    ldr r7, [sp, #100]     // acc_min
+    vdup.32 q11, r7
+
+    ldr r8, [sp, #104]     // acc_max
+    vdup.32 q10, r8
+
+    mov r10, r0            // r10: running dst pointer
+
+    LoopDepth8:
+        cmp r2, #8
+        blt LoopDepth4     // fewer than 8 values left: let the 4-wide loop handle the tail
+        vld1.32 {q0}, [r1]!
+        vshl.s32 q0, q0, q13       // apply left_shift
+        vqrdmulh.s32 q0, q0, q14   // saturating rounding doubling multiply-high by out_multiplier
+        vand q4, q0, q12
+        vshr.s32 q4, q4, #31       // q4 = -1 where value and right_shift both have the sign bit set, else 0
+        vqadd.s32 q0, q0, q4       // rounding fix-up ahead of the rounding shift
+        vrshl.s32 q0, q0, q12      // rounding shift by right_shift (presumably <= 0, i.e. a right shift)
+        vadd.i32 q0, q0, q15       // + output_zp
+        vmax.s32 q0, q0, q11       // clamp to [acc_min, acc_max]
+        vmin.s32 q0, q0, q10
+        vqmovn.s32 d4, q0          // int32 -> int16, low half of q2
+
+        vld1.32 {q1}, [r1]!
+        vshl.s32 q1, q1, q13
+        vqrdmulh.s32 q1, q1, q14
+        vand q4, q1, q12
+        vshr.s32 q4, q4, #31
+        vqadd.s32 q1, q1, q4
+        vrshl.s32 q1, q1, q12
+        vadd.i32 q1, q1, q15
+        vmax.s32 q1, q1, q11
+        vmin.s32 q1, q1, q10
+        vqmovn.s32 d5, q1          // int32 -> int16, high half of q2
+        vqmovn.s16 d4, q2          // int16 -> int8, 8 results in d4
+
+        vst1.8 {d4}, [r10]!
+
+        sub r2, r2, #8
+        b LoopDepth8
+
+    LoopDepth4:
+        cmp r2, #4
+        blt End
+        vld1.32 {q0}, [r1]!
+
+        vshl.s32 q0, q0, q13
+        vqrdmulh.s32 q0, q0, q14
+        vand q4, q0, q12
+        vshr.s32 q4, q4, #31
+        vqadd.s32 q0, q0, q4
+        vrshl.s32 q0, q0, q12
+        vadd.i32 q0, q0, q15
+        vmax.s32 q0, q0, q11
+        vmin.s32 q0, q0, q10
+
+        vqmovn.s32 d0, q0
+        vqmovn.s16 d0, q0          // only the low 4 bytes of d0 are results
+
+        vst1.8 {d0[0]}, [r10]!
+        vst1.8 {d0[1]}, [r10]!
+        vst1.8 {d0[2]}, [r10]!
+        vst1.8 {d0[3]}, [r10]!
+
+        sub r2, r2, #4
+        b LoopDepth4
+    End:
+    // sp was never moved, so the saved registers can be popped directly
+    vpop {q4-q7}
+    pop {r4-r8, r10}
+    bx lr
+
+#endif
+#endif
diff --git a/mindspore/lite/nnacl/assembly/arm32/ConvDwInt8Row.S b/mindspore/lite/nnacl/assembly/arm32/ConvDwInt8Row.S
new file mode 100644
index 0000000000..9b5bfa1242
--- /dev/null
+++ b/mindspore/lite/nnacl/assembly/arm32/ConvDwInt8Row.S
@@ -0,0 +1,134 @@
+#ifdef __arm__
+#ifndef __aarch64__
+
+.text
+.align 5
+.global ConvDwInt8Row
+#ifndef __APPLE__
+.type ConvDwInt8Row, %function
+#endif
+
+// void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels,
+//                    int output_channel, int input_step, int8_t input_zp)
+// r0: output_ptr, r1: input_ptr, r2: weight_ptr, r3: num_pixels; loaded from the stack into
+// r4: output_channel, r5: input_step, r6: input_zp
+
+ConvDwInt8Row:
+    // at return, clang generates "push {lr}, pop {pc}" while gcc will generate "bx lr"
+    // according to https://stackoverflow.com/questions/53625807
+    // even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway
+    // clang's rule seems more simple, though there are no subroutine calls here
+    // r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf
+    push {r4-r8, r9-r12, lr}
+    vpush {q4-q7}
+    // sp stays below the saved registers; stack arguments start 40 (push) + 64 (vpush) = 104 bytes above it
+
+    cmp r3, #0
+    beq End
+
+    ldr r4, [sp, #104]     // channel
+    ldr r5, [sp, #108]     // input_step
+    ldr r6, [sp, #112]     // input_zp
+    vdup.8 d30, r6
+
+    mov r7, r0             // r0 reads the accumulators, r7 writes them back
+
+    LoopPixel:
+        mov r8, r1         // input
+        mov r10, r2        // weight
+        mov r11, r4        // channels left for this pixel
+
+    LoopDepth16In:
+        cmp r11, #16
+        blt L8
+        sub r11, r11, #16
+
+        vld1.8 {q0}, [r8]!
+        vld1.16 {q1, q2}, [r10]!   // weight
+
+        vsubl.s8 q3, d0, d30       // -zp
+        vld1.32 {q4, q5}, [r0]!
+        vmlal.s16 q4, d6, d2
+        vmlal.s16 q5, d7, d3
+
+        cmp r11, #16
+        blt LoopDepth16Out
+    LoopDepth16:
+        vst1.32 {q4, q5}, [r7]!    // store channels 0-7 of the 16 loaded before entering this label
+
+        vsubl.s8 q6, d1, d30
+        vld1.32 {q7, q8}, [r0]!
+        vmlal.s16 q7, d12, d4
+        vmlal.s16 q8, d13, d5
+        vst1.32 {q7, q8}, [r7]!
+
+        vld1.8 {q0}, [r8]!
+        vld1.16 {q1, q2}, [r10]!   // weight
+
+        vsubl.s8 q3, d0, d30       // -zp
+        vld1.32 {q4, q5}, [r0]!
+        vmlal.s16 q4, d6, d2
+        vmlal.s16 q5, d7, d3
+
+        sub r11, r11, #16
+        cmp r11, #16
+        bge LoopDepth16
+
+    LoopDepth16Out:
+        vst1.32 {q4, q5}, [r7]!    // finish the last 16 channels that were already loaded
+
+        vsubl.s8 q6, d1, d30
+        vld1.32 {q7, q8}, [r0]!
+        vmlal.s16 q7, d12, d4
+        vmlal.s16 q8, d13, d5
+        vst1.32 {q7, q8}, [r7]!
+
+    L8:
+        cmp r11, #8
+        blt L0
+
+    LoopDepth8:
+        vld1.8 {d0}, [r8]!
+        vld1.16 {d2, d3}, [r10]!   // weight
+
+        vsubl.s8 q2, d0, d30       // -zp
+
+        vld1.32 {q3}, [r0]!
+        vmlal.s16 q3, d4, d2
+        vst1.32 {q3}, [r7]!
+
+        vld1.32 {q4}, [r0]!
+        vmlal.s16 q4, d5, d3
+        vst1.32 {q4}, [r7]!
+
+        sub r11, r11, #8
+        cmp r11, #8
+        bge LoopDepth8
+
+    L0:
+        cmp r11, #0
+        beq LoopDepthEnd
+
+    LoopDepth0:
+        ldrsb r12, [r8], #1
+        ldrsh r9, [r10], #2
+        sub r12, r12, r6           // input - zp
+
+        ldr lr, [r0], #4
+        smlabb r12, r12, r9, lr    // acc + (input - zp) * weight
+        str r12, [r7], #4
+
+        subs r11, r11, #1
+        bne L0
+
+    LoopDepthEnd:
+        add r1, r1, r5             // advance input by input_step
+        subs r3, r3, #1
+        bne LoopPixel
+
+    End:
+    // sp was never moved, so the saved registers can be popped directly
+    vpop {q4-q7}
+    pop {r4-r8, r9-r12, pc}
+#endif
+#endif
diff --git a/mindspore/lite/nnacl/int8/common_func.h b/mindspore/lite/nnacl/int8/common_func.h
index eabb22c6f4..56f2bb9e42 100644
--- a/mindspore/lite/nnacl/int8/common_func.h
+++ b/mindspore/lite/nnacl/int8/common_func.h
@@ -32,6 +32,13 @@ void PostFuncInt8C8(const int32_t *in, const int32_t *bias, int8_t *out, size_t
 void PostFuncInt8C4(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, size_t stride,
                     int32_t multiplier, int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini, int32_t maxi);
 
+#ifdef ENABLE_ARM
+void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels,
+                   int output_channel, int input_step, int8_t input_zp);
+void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier,
+                          int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max);
+#endif
+
 #ifdef ENABLE_ARM64
 void PostFuncInt8C4Neon64(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc4div, size_t oc4res,
                           size_t plane, size_t stride, int32_t multiplier, int32_t left_shift, int32_t right_shift,
@@ -50,10 +57,6 @@ void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, con
                       size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int8_t *in_zp,
                       int32_t *out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift,
                       int32_t *acc_min, int32_t *acc_max);
-void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels,
-                   int output_channel, int input_step, int8_t input_zp);
-void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier,
-                          int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max);
 void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp,
                                     int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t acc_min,
                                     int32_t acc_max);
diff --git a/mindspore/lite/nnacl/int8/conv_depthwise_int8.c b/mindspore/lite/nnacl/int8/conv_depthwise_int8.c
index b84bc58357..c5590e1951 100644
--- a/mindspore/lite/nnacl/int8/conv_depthwise_int8.c
+++ b/mindspore/lite/nnacl/int8/conv_depthwise_int8.c
@@ -20,7 +20,7 @@
 #include "nnacl/int8/common_func.h"
 
 /*conv depthwise int8 begin*/
-#ifndef ENABLE_ARM64
+#ifndef ENABLE_ARM
 void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels,
                    int output_channel, int input_step, int8_t input_zp) {
   for (int i = 0; i < num_pixels; i++) {
@@ -59,7 +59,7 @@ void ConvDwInt8Post(int8_t *dst, int32_t *buffer, int output_w, int channel, int
   } else {
     int num_pixels = output_w * channel;
     int align_num = 0;
-#ifdef ENABLE_ARM64
+#ifdef ENABLE_ARM
     align_num = num_pixels / 4 * 4;
     ConvDwInt8PostAlign4(dst, buffer, align_num, output_zp, out_multiplier[0], left_shift[0], right_shift[0], acc_min,
                          acc_max);