!7643 [MSLITE][Develop] optimize arm cpu int8 op conv dw 3x3: add assembly arm32

Merge pull request !7643 from yangruoqi713/conv_dw
pull/7643/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 9d9f98768b

@ -0,0 +1,116 @@
#ifdef __arm__
#ifndef __aarch64__
.text
.align 5
.global ConvDw3x3BorderPixelInt8
#ifndef __APPLE__
.type ConvDw3x3BorderPixelInt8, %function
#endif
// void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height,
// size_t width, size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp,
// size_t out_multiplier, size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max) {
// r0: dst, r1: src, r2: weight, r3: bias, r4: height, r5: width, r6: in_kh_step, r7: in_kw_step,
// r8: channel, r9: in_zp, r10: out_zp, r11: out_multiplier, r12: left_shift, r13: right_shift
// r14: acc_min, r15: acc_max
ConvDw3x3BorderPixelInt8:
// at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr"
// according to https://stackoverflow.com/questions/53625807
// even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway
// clang's rule seems more simple, though there are no subroutine calls here
// r4-r8 and q4-q7 must be saved according to https://static.docs.arm.com/ihi0042/i/aapcs32.pdf
push {r4-r8, r9-r12, lr}
vpush {q4-q7}
add sp, sp, #104
ldr r4, [sp]
ldr r5, [sp, #4]
ldr r6, [sp, #8]
ldr r7, [sp, #12]
ldr r8, [sp, #16]
ldrb r10, [sp, #20] // in_zp
vdup.8 d18, r10
ldr r10, [sp, #24] // out_zp
vdup.32 q15, r10
ldr r10, [sp, #28] // out_multiplier
vdup.32 q14, r10
ldr r10, [sp, #32] // left_shift
vdup.32 q13, r10
ldr r10, [sp, #36] // right_shift
vdup.32 q12, r10
ldr r10, [sp, #40] // acc_min
vdup.32 q11, r10
ldr r10, [sp, #44] // acc_max
vdup.32 q10, r10
mov r4, #2
mul lr, r8, r4
LoopC:
mov r9, r1
mov r10, r2
ldr r4, [sp]
vld1.32 {q3}, [r3]!
vld1.32 {q4}, [r3]!
LoopH:
mov r11, r9
mov r12, r10
ldr r5, [sp, #4]
LoopW:
vld1.8 {d0}, [r11], r7
vld1.16 {d2, d3}, [r12], lr // weight
vsubl.s8 q2, d0, d18 // -zp
vmlal.s16 q3, d4, d2
vmlal.s16 q4, d5, d3
subs r5, r5, #1
bne LoopW
subs r4, r4, #1
add r9, r9, r6
mov r11, #3
mul r5, lr, r11
add r10, r10, r5
bne LoopH
vshl.s32 q3, q3, q13
vqrdmulh.s32 q3, q3, q14
vand q5, q3, q12
vshr.s32 q5, q5, #31
vqadd.s32 q3, q3, q5
vrshl.s32 q3, q3, q12
vadd.i32 q3, q3, q15
vmax.s32 q3, q3, q11
vmin.s32 q3, q3, q10
vqmovn.s32 d14, q3
vshl.s32 q4, q4, q13
vqrdmulh.s32 q4, q4, q14
vand q6, q4, q12
vshr.s32 q6, q6, #31
vqadd.s32 q4, q4, q6
vrshl.s32 q4, q4, q12
vadd.i32 q4, q4, q15
vmax.s32 q4, q4, q11
vmin.s32 q4, q4, q10
vqmovn.s32 d15, q4
vqmovn.s16 d16, q7
vst1.8 {d16}, [r0]!
add r1, r1, #8
add r2, r2, #16
sub r8, r8, #8
cmp r8, #8
bge LoopC
sub sp, sp, #104
vpop {q4-q7}
pop {r4-r8, r9-r12, pc}
#endif
#endif

@ -41,66 +41,128 @@ ConvDw3x3BorderPixelInt8:
mul x14, x13, x9 // x8 * 3 * 2
LoopC:
ld1 {v23.4s}, [x3], #16
ld1 {v24.4s}, [x3], #16
mov x9, x1
mov x10, x2
mov x17, x4 // height
ld1 {v5.4s}, [x3], #16
mov v3.16b, v5.16b
ld1 {v6.4s}, [x3], #16
mov v4.16b, v6.16b
LoopH:
mov x11, x9
mov x12, x10
mov x18, x5 // width
LoopW:
ld1 {v0.8b}, [x11], x7
ssubl v1.8h, v0.8b, v25.8b
ld1 {v2.8h}, [x12], x13 // weight
smlal v3.4s, v1.4h, v2.4h
smlal2 v4.4s, v1.8h, v2.8h
subs x18, x18, #1
bne LoopW
subs x17, x17, #1
add x9, x9, x6
add x10, x10, x14
bne LoopH
sqshl v3.4s, v3.4s, v28.4s
sqshl v4.4s, v4.4s, v28.4s
sqrdmulh v3.4s, v3.4s, v27.4s
sqrdmulh v4.4s, v4.4s, v27.4s
and v12.16b, v29.16b, v3.16b
sshr v12.4s, v12.4s, #31
sqadd v3.4s, v3.4s, v12.4s
srshl v3.4s, v3.4s, v29.4s
and v11.16b, v29.16b, v4.16b
sshr v11.4s, v11.4s, #31
sqadd v4.4s, v4.4s, v11.4s
srshl v4.4s, v4.4s, v29.4s
add v3.4s, v3.4s, v26.4s
add v4.4s, v4.4s, v26.4s
smax v3.4s, v3.4s, v30.4s
smax v4.4s, v4.4s, v30.4s
smin v3.4s, v3.4s, v31.4s
smin v4.4s, v4.4s, v31.4s
sqxtn v3.4h, v3.4s
sqxtn v4.4h, v4.4s
sqxtn v3.8b, v3.8h
sqxtn v4.8b, v4.8h
st1 {v3.s}[0], [x0], #4
st1 {v4.s}[0], [x0], #4
add x1, x1, #8
add x2, x2, #16
sub x8, x8, #8
cmp x8, #8
bge LoopC
cmp x4, #2
blt LoopHW
LoopH2W2:
cmp x5, #2
blt LoopHW
ld1 {v0.8b}, [x9], x7
ssubl v0.8h, v0.8b, v25.8b
add x11, x1, x6
ld1 {v4.8h}, [x10], x13 // weight
smlal v23.4s, v0.4h, v4.4h
smlal2 v24.4s, v0.8h, v4.8h
add x12, x2, x14
ld1 {v1.8b}, [x9], x7
ssubl v1.8h, v1.8b, v25.8b
ld1 {v5.8h}, [x10], x13
smlal v23.4s, v1.4h, v5.4h
smlal2 v24.4s, v1.8h, v5.8h
add x15, x11, x6
ld1 {v2.8b}, [x11], x7
ssubl v2.8h, v2.8b, v25.8b
add x16, x12, x14
ld1 {v6.8h}, [x12], x13
smlal v23.4s, v2.4h, v6.4h
smlal2 v24.4s, v2.8h, v6.8h
ld1 {v3.8b}, [x11], x7
ssubl v3.8h, v3.8b, v25.8b
ld1 {v7.8h}, [x12], x13
smlal v23.4s, v3.4h, v7.4h
smlal2 v24.4s, v3.8h, v7.8h
cmp x5, #3
beq LoopH2W3
cmp x4, #3
beq LoopH3W2
b Post
LoopH2W3:
ld1 {v16.8b}, [x9], x7
ssubl v16.8h, v16.8b, v25.8b
ld1 {v17.8h}, [x10], x13
smlal v23.4s, v16.4h, v17.4h
smlal2 v24.4s, v16.8h, v17.8h
ld1 {v18.8b}, [x11], x7
ssubl v18.8h, v18.8b, v25.8b
ld1 {v19.8h}, [x12], x13
smlal v23.4s, v18.4h, v19.4h
smlal2 v24.4s, v18.8h, v19.8h
b Post
LoopH3W2:
ld1 {v16.8b}, [x15], x7
ssubl v16.8h, v16.8b, v25.8b
ld1 {v17.8h}, [x16], x13
smlal v23.4s, v16.4h, v17.4h
smlal2 v24.4s, v16.8h, v17.8h
ld1 {v18.8b}, [x15], x7
ssubl v18.8h, v18.8b, v25.8b
ld1 {v19.8h}, [x16], x13
smlal v23.4s, v18.4h, v19.4h
smlal2 v24.4s, v18.8h, v19.8h
b Post
LoopHW:
mov x9, x1
mov x10, x2
mov x17, x4 // height
LoopH:
mov x11, x9
mov x12, x10
mov x18, x5 // width
LoopW:
ld1 {v0.8b}, [x11], x7
ssubl v1.8h, v0.8b, v25.8b
ld1 {v2.8h}, [x12], x13 // weight
smlal v23.4s, v1.4h, v2.4h
smlal2 v24.4s, v1.8h, v2.8h
subs x18, x18, #1
bne LoopW
subs x17, x17, #1
add x9, x9, x6
add x10, x10, x14
bne LoopH
Post:
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
and v12.16b, v29.16b, v23.16b
sshr v12.4s, v12.4s, #31
sqadd v23.4s, v23.4s, v12.4s
srshl v23.4s, v23.4s, v29.4s
and v11.16b, v29.16b, v24.16b
sshr v11.4s, v11.4s, #31
sqadd v24.4s, v24.4s, v11.4s
srshl v24.4s, v24.4s, v29.4s
add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s
smax v24.4s, v24.4s, v30.4s
smin v23.4s, v23.4s, v31.4s
smin v24.4s, v24.4s, v31.4s
sqxtn v23.4h, v23.4s
sqxtn v24.4h, v24.4s
sqxtn v23.8b, v23.8h
sqxtn v24.8b, v24.8h
st1 {v23.s}[0], [x0], #4
st1 {v24.s}[0], [x0], #4
add x1, x1, #8
add x2, x2, #16
sub x8, x8, #8
cmp x8, #8
bge LoopC
ret
#endif

@ -47,6 +47,10 @@ void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, con
size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int8_t *in_zp,
int32_t *out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift,
int32_t *acc_min, int32_t *acc_max);
void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height,
size_t width, size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp,
size_t out_zp, size_t out_multiplier, size_t left_shift, size_t right_shift,
size_t acc_min, size_t acc_max);
#endif
#ifdef ENABLE_ARM32
@ -67,10 +71,6 @@ void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *wei
void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width,
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height,
size_t width, size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp,
size_t out_zp, size_t out_multiplier, size_t left_shift, size_t right_shift,
size_t acc_min, size_t acc_max);
#endif
#ifdef __cplusplus
}

@ -140,11 +140,8 @@ void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_da
/*conv depthwise 3x3 int8 begin*/
bool CheckIfUse3X3(const ConvParameter *conv_param, int channel) {
bool use_3x3 = conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 &&
(conv_param->stride_h_ == 1 || conv_param->stride_h_ == 2) &&
(conv_param->stride_w_ == 1 || conv_param->stride_w_ == 2) &&
conv_param->stride_h_ == conv_param->stride_w_ &&
(conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) &&
bool use_3x3 = conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && conv_param->stride_h_ == 1 &&
conv_param->stride_w_ == 1 && (conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) &&
(conv_param->pad_l_ == 0 || conv_param->pad_l_ == 1) && conv_param->pad_u_ == conv_param->pad_l_ &&
conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && (channel % 8 == 0);
return use_3x3;
@ -303,7 +300,7 @@ void ConvDw3x3Int8(int8_t *output_data, int8_t *buffer, const int8_t *input_data
}
}
#ifndef ENABLE_ARM64
#ifndef ENABLE_ARM
void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height,
int width, int in_kh_step, int in_kw_step, int channel, int8_t in_zp, int32_t out_zp,
int out_multiplier, int left_shift, int right_shift, int32_t acc_min, int32_t acc_max) {

@ -172,7 +172,21 @@ kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector<lite::Tensor *>
auto act_quant_size =
MSMAX(inputs[kInputIndex]->GetQuantParams().size(), outputs[kOutputIndex]->GetQuantParams().size());
if (act_quant_size == 1) { // per tensor
kernel = new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
auto conv_parm = reinterpret_cast<ConvParameter *>(opParameter);
auto channel = inputs[kWeightIndex]->shape()[0];
auto weight_quant_size = inputs[kWeightIndex]->GetQuantParams().size();
if (CheckIfUse3X3(conv_parm, channel) && weight_quant_size == 1) {
#ifdef ENABLE_ARM64
kernel =
new (std::nothrow) kernel::ConvolutionDepthwise3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
#else
kernel =
new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
#endif
} else {
kernel =
new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
}
} else { // per channel
kernel =
new (std::nothrow) kernel::ConvolutionDepthwiseSWInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);

Loading…
Cancel
Save