!6202 [MSLITE][Develop] arm cpu int8 conv depthwise support arm32

Merge pull request !6202 from yangruoqi713/int8_arm32
pull/6202/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 07422d5438

@ -0,0 +1,110 @@
#ifdef __arm__
#ifndef __aarch64__
.text
.align 5
.global ConvDwInt8PostAlign4
#ifndef __APPLE__
.type ConvDwInt8PostAlign4, %function
#endif
// void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier,
// int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max);
// r0: dst, r1: buffer, r2: num_pixels, r3: output_zp, r4: out_multiplier,
// r5: left_shift, r6: right_shift, r7: acc_min, r8: acc_max
ConvDwInt8PostAlign4:
// at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr"
// according to https://stackoverflow.com/questions/53625807
// even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway
// clang's rule seems more simple, though there are no subroutine calls here
// r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf
push {r4-r8, r10}
vpush {q4-q7}
add sp, sp, #88
vdup.32 q15, r3 // output_zp
ldr r4, [sp] // out_multiplier
vdup.32 q14, r4
ldr r5, [sp, #4] // left_shift
vdup.32 q13, r5
ldr r6, [sp, #8] // right_shift
vdup.32 q12, r6
ldr r7, [sp, #12] // acc_min
vdup.32 q11, r7
ldr r8, [sp, #16] // acc_max
vdup.32 q10, r8
mov r10, r0
LoopDepth8:
cmp r2, #8
blt End
vld1.32 {q0}, [r1]!
vshl.s32 q0, q0, q13
vqrdmulh.s32 q0, q0, q14
vand q4, q0, q12
vshr.s32 q4, q4, #31
vqadd.s32 q0, q0, q4
vrshl.s32 q0, q0, q12
vadd.i32 q0, q0, q15
vmax.s32 q0, q0, q11
vmin.s32 q0, q0, q10
vqmovn.s32 d4, q0
vld1.32 {q1}, [r1]!
vshl.s32 q1, q1, q13
vqrdmulh.s32 q1, q1, q14
vand q4, q1, q12
vshr.s32 q4, q4, #31
vqadd.s32 q1, q1, q4
vrshl.s32 q1, q1, q12
vadd.i32 q1, q1, q15
vmax.s32 q1, q1, q11
vmin.s32 q1, q1, q10
vqmovn.s32 d5, q1
vqmovn.s16 d4, q2
vst1.8 {d4}, [r10]!
sub r2, r2, #8
b LoopDepth8
LoopDepth4:
cmp r2, #4
blt End
vld1.32 {q0}, [r1]!
vshl.s32 q0, q0, q13
vqrdmulh.s32 q0, q0, q14
vand q4, q0, q12
vshr.s32 q4, q4, #31
vqadd.s32 q0, q0, q4
vrshl.s32 q0, q0, q12
vadd.i32 q0, q0, q15
vmax.s32 q0, q0, q11
vmin.s32 q0, q0, q10
vqmovn.s32 d0, q0
vqmovn.s16 d0, q0
vst1.8 {d0[0]}, [r10]!
vst1.8 {d0[1]}, [r10]!
vst1.8 {d0[2]}, [r10]!
vst1.8 {d0[3]}, [r10]!
sub r2, r2, #4
b LoopDepth4
End:
sub sp, sp, #88
vpop {q4-q7}
pop {r4-r8, r10}
bx lr
#endif
#endif

@ -0,0 +1,134 @@
#ifdef __arm__
#ifndef __aarch64__
.text
.align 5
.global ConvDwInt8Row
#ifndef __APPLE__
.type ConvDwInt8Row, %function
#endif
// void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels,
// int output_channel, int input_step, int8_t input_zp)
// r0: output_ptr, r1: input_ptr, r2: weight_ptr, r3: num_pixels,
// r4: output_channel, r5: input_step, r6: input_zp,
ConvDwInt8Row:
// at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr"
// according to https://stackoverflow.com/questions/53625807
// even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway
// clang's rule seems more simple, though there are no subroutine calls here
// r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf
push {r4-r8, r9-r12, lr}
vpush {q4-q7}
add sp, sp, #104
cmp r3, #0
beq End
ldr r4, [sp] // channel
ldr r5, [sp, #4] // input_step
ldr r6, [sp, #8] // input_zp
vdup.8 d30, r6
mov r7, r0
LoopPixel:
mov r8, r1 // input
mov r10, r2 // weight
mov r11, r4
LoopDepth16In:
cmp r11, #16
blt L8
sub r11, r11, #16
vld1.8 {q0}, [r8]!
vld1.16 {q1, q2}, [r10]! // weight
vsubl.s8 q3, d0, d30 // -zp
vld1.32 {q4, q5}, [r0]!
vmlal.s16 q4, d6, d2
vmlal.s16 q5, d7, d3
cmp r11, #16
blt LoopDepth16Out
LoopDepth16:
vst1.32 {q4, q5}, [r7]!
vsubl.s8 q6, d1, d30
vld1.32 {q7, q8}, [r0]!
vmlal.s16 q7, d12, d4
vmlal.s16 q8, d13, d5
vst1.32 {q7, q8}, [r7]!
vld1.8 {q0}, [r8]!
vld1.16 {q1, q2}, [r10]! // weight
vsubl.s8 q3, d0, d30 // -zp
vld1.32 {q4, q5}, [r0]!
vmlal.s16 q4, d6, d2
vmlal.s16 q5, d7, d3
sub r11, r11, #16
cmp r11, #16
bge LoopDepth16
LoopDepth16Out:
vst1.32 {q4, q5}, [r7]!
vsubl.s8 q6, d1, d30
vld1.32 {q7, q8}, [r0]!
vmlal.s16 q7, d12, d4
vmlal.s16 q8, d13, d5
vst1.32 {q7, q8}, [r7]!
L8:
cmp r11, #8
blt L0
LoopDepth8:
vld1.8 {d0}, [r8]!
vld1.16 {d2, d3}, [r10]! // weight
vsubl.s8 q2, d0, d30 // -zp
vld1.32 {q3}, [r0]!
vmlal.s16 q3, d4, d2
vst1.32 {q3}, [r7]!
vld1.32 {q4}, [r0]!
vmlal.s16 q4, d5, d3
vst1.32 {q4}, [r7]!
sub r11, r11, #8
cmp r11, #8
bge LoopDepth8
L0:
cmp r11, #0
beq LoopDepthEnd
LoopDepth0:
ldrsb r12, [r8], #1
ldrsh r9, [r10], #2
sub r12, r12, r6
ldr lr, [r0], #4
smlabb r12, r12, r9, lr
str r12, [r7], #4
subs r11, r11, #1
bne L0
LoopDepthEnd:
add r1, r1, r5
subs r3, r3, #1
bne LoopPixel
End:
sub sp, sp, #104
vpop {q4-q7}
pop {r4-r8, r9-r12, pc}
#endif
#endif

@ -32,6 +32,13 @@ void PostFuncInt8C8(const int32_t *in, const int32_t *bias, int8_t *out, size_t
void PostFuncInt8C4(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, size_t stride,
int32_t multiplier, int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini,
int32_t maxi);
#ifdef ENABLE_ARM
void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels,
int output_channel, int input_step, int8_t input_zp);
void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier,
int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max);
#endif
#ifdef ENABLE_ARM64
void PostFuncInt8C4Neon64(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc4div, size_t oc4res,
size_t plane, size_t stride, int32_t multiplier, int32_t left_shift, int32_t right_shift,
@ -50,10 +57,6 @@ void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, con
size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int8_t *in_zp,
int32_t *out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift,
int32_t *acc_min, int32_t *acc_max);
void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels,
int output_channel, int input_step, int8_t input_zp);
void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier,
int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max);
void ConvDwInt8PostAlign4PerChannel(int8_t *dst, int32_t *buffer, int channel4, int32_t output_zp,
int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t acc_min,
int32_t acc_max);

@ -20,7 +20,7 @@
#include "nnacl/int8/common_func.h"
/*conv depthwise int8 begin*/
#ifndef ENABLE_ARM64
#ifndef ENABLE_ARM
void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels,
int output_channel, int input_step, int8_t input_zp) {
for (int i = 0; i < num_pixels; i++) {
@ -59,7 +59,7 @@ void ConvDwInt8Post(int8_t *dst, int32_t *buffer, int output_w, int channel, int
} else {
int num_pixels = output_w * channel;
int align_num = 0;
#ifdef ENABLE_ARM64
#ifdef ENABLE_ARM
align_num = num_pixels / 4 * 4;
ConvDwInt8PostAlign4(dst, buffer, align_num, output_zp, out_multiplier[0], left_shift[0], right_shift[0], acc_min,
acc_max);

Loading…
Cancel
Save