[MSLITE][Develop] optimize arm cpu int8 depthwise: 3x3 support perchannel

pull/8619/head
yangruoqi713 4 years ago
parent 51a57b9243
commit bcb21050f4

@ -10,8 +10,10 @@
// void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height,
// size_t width, size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp,
// size_t out_multiplier, size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max) {
// size_t out_multiplier, size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max
// size_t per_channel) {
// todo: support per channel
// r0: dst, r1: src, r2: weight, r3: bias, r4: height, r5: width, r6: in_kh_step, r7: in_kw_step,
// r8: channel, r9: in_zp, r10: out_zp, r11: out_multiplier, r12: left_shift, r13: right_shift
// r14: acc_min, r15: acc_max

@ -10,7 +10,8 @@
// void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias, int input_col_size,
// int input_row_size, int channel, int output_h, int output_w, int8_t in_zp, int32_t out_zp,
// int out_multiplier, int left_shift, int right_shift, int32_t acc_min, int32_t acc_max)
// int *out_multiplier, int *left_shift, int *right_shift, int32_t acc_min, int32_t acc_max,
// size_t per_channel)
//
// x0: output
// x1: input
@ -28,33 +29,46 @@
// w13: right_shift
// w14: acc_min
// w15: acc_max
// w16: per_channel
ConvDw3x3Int8Neon64:
sub sp, sp, #160
sub sp, sp, #176
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
stp x19, x20, [sp], #16
stp x21, x22, [sp], #16
ldr w8, [sp]
ldr w9, [sp, #8]
ldr w10, [sp, #16]
ldr w11, [sp, #24]
ldr w12, [sp, #32]
ldr w13, [sp, #40]
ldr w14, [sp, #48]
ldr w15, [sp, #56]
stp x23, x24, [sp], #16
ldr x8, [sp]
ldr x9, [sp, #8]
ldr x10, [sp, #16]
ldr x11, [sp, #24]
ldr x12, [sp, #32]
ldr x13, [sp, #40]
ldr x14, [sp, #48]
ldr x15, [sp, #56]
ldr x23, [sp, #64] // per_channel
add x19, x3, #16
add w20, w6, w6 // channel * 2
add w21, w4, w4 // col_size * 2
dup v25.8b, w9
dup v26.4s, w12
dup v27.4s, w11
dup v28.4s, w13
cbnz w23, PER_CHANNEL_DUMP
PER_LAYER_DUMP:
ld1r {v27.4s}, [x11] // out_multiplier
ld1r {v26.4s}, [x12] // left_shift
ld1r {v28.4s}, [x13] // right_shift
b MAIN_FUC
PER_CHANNEL_DUMP:
ld1 {v27.4s}, [x11]
ld1 {v26.4s}, [x12]
ld1 {v28.4s}, [x13]
MAIN_FUC:
dup v29.4s, w10
dup v30.4s, w14
dup v31.4s, w15
ldr w24, [x12]
// Load weights
ld1 {v0.8h}, [x2], x20
@ -158,7 +172,8 @@ HEIGHT1_LOOP:
smlal v23.4s, v8.4h, v20.4h
smlal2 v24.4s, v8.8h, v20.8h
cbz w12, SKIP_LEFTSHIFT1
cbnz w23, PER_CHANNEL_POST1
cbz w24, SKIP_LEFTSHIFT1
sqshl v21.4s, v21.4s, v26.4s
sqshl v22.4s, v22.4s, v26.4s
sqshl v23.4s, v23.4s, v26.4s
@ -178,6 +193,27 @@ SKIP_LEFTSHIFT1:
sqrshl v22.4s, v22.4s, v28.4s
sqrshl v23.4s, v23.4s, v28.4s
sqrshl v24.4s, v24.4s, v28.4s
b OUTZP1
PER_CHANNEL_POST1:
sqshl v21.4s, v21.4s, v26.4s
sqshl v23.4s, v23.4s, v26.4s
sqrdmulh v21.4s, v21.4s, v27.4s
sqrdmulh v23.4s, v23.4s, v27.4s
ldr q26, [x12, #16]
sqrshl v21.4s, v21.4s, v28.4s
sqrshl v23.4s, v23.4s, v28.4s
ldr q27, [x11, #16]
sqshl v22.4s, v22.4s, v26.4s
sqshl v24.4s, v24.4s, v26.4s
ldr q28, [x13, #16]
sqrdmulh v22.4s, v22.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
ld1 {v26.4s}, [x12]
sqrshl v22.4s, v22.4s, v28.4s
sqrshl v24.4s, v24.4s, v28.4s
ld1 {v27.4s}, [x11]
ld1 {v28.4s}, [x13]
OUTZP1:
// Add output zero point
@ -271,7 +307,8 @@ WIDTH2_LEFT:
smlal v23.4s, v8.4h, v20.4h
smlal2 v24.4s, v8.8h, v20.8h
cbz w12, SKIP_LEFTSHIFT2
cbnz w23, PER_CHANNEL_POST2
cbz w24, SKIP_LEFTSHIFT2
sqshl v21.4s, v21.4s, v26.4s
sqshl v22.4s, v22.4s, v26.4s
sqshl v23.4s, v23.4s, v26.4s
@ -291,6 +328,24 @@ SKIP_LEFTSHIFT2:
sqrshl v22.4s, v22.4s, v28.4s
sqrshl v23.4s, v23.4s, v28.4s
sqrshl v24.4s, v24.4s, v28.4s
b OUTZP2
PER_CHANNEL_POST2:
sqshl v21.4s, v21.4s, v26.4s
sqshl v23.4s, v23.4s, v26.4s
sqrdmulh v21.4s, v21.4s, v27.4s
sqrdmulh v23.4s, v23.4s, v27.4s
ldr q26, [x12, #16]
sqrshl v21.4s, v21.4s, v28.4s
sqrshl v23.4s, v23.4s, v28.4s
ldr q27, [x11, #16]
sqshl v22.4s, v22.4s, v26.4s
sqshl v24.4s, v24.4s, v26.4s
ldr q28, [x13, #16]
sqrdmulh v22.4s, v22.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
sqrshl v22.4s, v22.4s, v28.4s
sqrshl v24.4s, v24.4s, v28.4s
OUTZP2:
// Add output zero point
@ -342,7 +397,8 @@ WIDTH1_LEFT:
smlal v21.4s, v8.4h, v19.4h
smlal2 v22.4s, v8.8h, v19.8h
cbz w12, SKIP_LEFTSHIFT3
cbnz w23, PER_CHANNEL_POST3
cbz w24, SKIP_LEFTSHIFT3
sqshl v21.4s, v21.4s, v26.4s
sqshl v22.4s, v22.4s, v26.4s
sqrdmulh v21.4s, v21.4s, v27.4s
@ -354,6 +410,18 @@ SKIP_LEFTSHIFT3:
sqrdmulh v22.4s, v22.4s, v27.4s
sqrshl v21.4s, v21.4s, v28.4s
sqrshl v22.4s, v22.4s, v28.4s
b OUTZP3
PER_CHANNEL_POST3:
sqshl v21.4s, v21.4s, v26.4s
sqrdmulh v21.4s, v21.4s, v27.4s
ldr q26, [x12, #16]
sqrshl v21.4s, v21.4s, v28.4s
ldr q27, [x11, #16]
sqshl v22.4s, v22.4s, v26.4s
ldr q28, [x13, #16]
sqrdmulh v22.4s, v22.4s, v27.4s
sqrshl v22.4s, v22.4s, v28.4s
OUTZP3:
// Add output zero point
@ -374,11 +442,12 @@ OUTZP3:
st1 {v21.8b}, [x0], x6
End:
sub sp, sp, #160
sub sp, sp, #176
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ldp x19, x20, [sp], #16
ldp x21, x22, [sp], #16
ldp x23, x24, [sp], #16
ret
#endif

@ -8,55 +8,68 @@
#endif
// void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step,
// size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, size_t out_multiplier,
// size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max)
// size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, int32_t *out_multiplier,
// int32_t *left_shift, int32_t *right_shift, size_t acc_min, size_t acc_max, size_t per_channel)
// x0: dst, x1: src, x2: weight, x3: bias, x4: in_kh_step, x5: in_kw_step,
// x6: channel, x7: in_zp, x8: out_zp, x9: out_multiplier, x10: left_shift, x11: right_shift
// x11: acc_min, x13: acc_max
// x12: acc_min, x13: acc_max, x14: per_channel
ConvDw3x3Int8Corner:
// registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
// x19 ~ x29 should be also preserved
// whereas our coding style do not permit such amount of parameters
dup v25.8b, w7 // in_zp
ldr x9, [sp]
dup v26.4s, w9 // out_zp
ldr x9, [sp, #8]
dup v27.4s, w9 // out_multiplier
ldr x8, [sp, #16]
dup v28.4s, w8 // left_shift
ldr x9, [sp, #24]
dup v29.4s, w9 // right_shift
ldr x9, [sp, #32]
dup v30.4s, w9 // acc_min
ldr x9, [sp, #40]
dup v31.4s, w9 // acc_max
mov x9, #2
mul x13, x6, x9 // x6 * 2
mov x9, #3
mul x14, x13, x9 // x6 * 3 * 2
sub sp, sp, #32
stp x19, x20, [sp], #16
stp x21, x22, [sp], #16
dup v25.8b, w7 // in_zp
ldr x8, [sp]
dup v26.4s, w8 // out_zp
ldr x9, [sp, #8] // out_multiplier
ldr x10, [sp, #16] // left_shift
ldr x11, [sp, #24] // right_shift
ldr x12, [sp, #32]
dup v30.4s, w12 // acc_min
ldr x13, [sp, #40]
dup v31.4s, w13 // acc_max
ldr x14, [sp, #48] // per_channel
cbnz x14, PerChannelDump
PerLayerDump:
ld1r {v27.4s}, [x9]
ld1r {v28.4s}, [x10]
ld1r {v29.4s}, [x11]
b ContinueFunc
PerChannelDump:
ld1 {v27.4s}, [x9], #16
ld1 {v28.4s}, [x10], #16
ld1 {v29.4s}, [x11], #16
ContinueFunc:
mov x12, #2
mul x21, x6, x12 // x6 * 2
mov x12, #3
mul x22, x21, x12 // x6 * 3 * 2
ld1 {v23.4s}, [x3], #16
ld1 {v24.4s}, [x3], #16
mov x9, x1
mov x10, x2
mov x12, x1
mov x13, x2
ld1 {v0.8b}, [x9], x5
ld1 {v0.8b}, [x12], x5
ssubl v0.8h, v0.8b, v25.8b
add x11, x1, x4
ld1 {v4.8h}, [x10], x13 // weight
add x12, x2, x14
ld1 {v1.8b}, [x9], x5
add x19, x1, x4
ld1 {v4.8h}, [x13], x21 // weight
add x20, x2, x22
ld1 {v1.8b}, [x12], x5
ssubl v1.8h, v1.8b, v25.8b
ld1 {v5.8h}, [x10], x13
ld1 {v2.8b}, [x11], x5
ld1 {v5.8h}, [x13], x21
ld1 {v2.8b}, [x19], x5
ssubl v2.8h, v2.8b, v25.8b
ld1 {v6.8h}, [x12], x13
ld1 {v3.8b}, [x11], x5
ld1 {v6.8h}, [x20], x21
ld1 {v3.8b}, [x19], x5
ssubl v3.8h, v3.8b, v25.8b
ld1 {v7.8h}, [x12], x13
ld1 {v7.8h}, [x20], x21
cmp x6, #8
ble LoopC8Post
@ -66,41 +79,54 @@ ConvDw3x3Int8Corner:
add x2, x2, #16
smlal v23.4s, v0.4h, v4.4h
smlal2 v24.4s, v0.8h, v4.8h
mov x9, x1
mov x10, x2
ld1 {v0.8b}, [x9], x5
mov x12, x1
mov x13, x2
ld1 {v0.8b}, [x12], x5
ssubl v0.8h, v0.8b, v25.8b
ld1 {v4.8h}, [x10], x13 // weight
add x11, x1, x4
ld1 {v4.8h}, [x13], x21 // weight
add x19, x1, x4
smlal v23.4s, v1.4h, v5.4h
smlal2 v24.4s, v1.8h, v5.8h
add x12, x2, x14
ld1 {v1.8b}, [x9], x5
add x20, x2, x22
ld1 {v1.8b}, [x12], x5
ssubl v1.8h, v1.8b, v25.8b
smlal v23.4s, v2.4h, v6.4h
ld1 {v5.8h}, [x10], x13
ld1 {v5.8h}, [x13], x21
smlal2 v24.4s, v2.8h, v6.8h
ld1 {v2.8b}, [x11], x5
ld1 {v2.8b}, [x19], x5
ssubl v2.8h, v2.8b, v25.8b
smlal v23.4s, v3.4h, v7.4h
ld1 {v6.8h}, [x12], x13
ld1 {v6.8h}, [x20], x21
smlal2 v24.4s, v3.8h, v7.8h
ld1 {v3.8b}, [x11], x5
ld1 {v3.8b}, [x19], x5
ssubl v3.8h, v3.8b, v25.8b
ld1 {v7.8h}, [x12], x13
cbz w8, RightShiftLoop
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
b AddZpLoop
RightShiftLoop:
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
ld1 {v7.8h}, [x20], x21
cbnz x14, PerChannelPostLoop
ldr w8, [x10]
cbz w8, RightShiftLoop
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
b AddZpLoop
RightShiftLoop:
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
b AddZpLoop
PerChannelPostLoop:
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
ld1 {v28.4s}, [x10], #16
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
ld1 {v27.4s}, [x9], #16
sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
ld1 {v29.4s}, [x11], #16
AddZpLoop:
add v23.4s, v23.4s, v26.4s
@ -119,6 +145,11 @@ ConvDw3x3Int8Corner:
st1 {v24.s}[0], [x0], #4
ld1 {v23.4s}, [x3], #16
ld1 {v24.4s}, [x3], #16
cbz x14, NEXT_LOOP
ld1 {v27.4s}, [x9], #16
ld1 {v28.4s}, [x10], #16
ld1 {v29.4s}, [x11], #16
NEXT_LOOP:
sub x6, x6, #8
cmp x6, #8
bgt LoopC8
@ -133,18 +164,31 @@ ConvDw3x3Int8Corner:
smlal v23.4s, v3.4h, v7.4h
smlal2 v24.4s, v3.8h, v7.8h
cbz w8, RightShift
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
b AddZp
RightShift:
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
cbnz x14, PerChannelPost
ldr w8, [x10]
cbz w8, RightShift
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
b AddZp
RightShift:
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
b AddZp
PerChannelPost:
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
ld1 {v28.4s}, [x10], #16
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
ld1 {v27.4s}, [x9], #16
sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
ld1 {v29.4s}, [x11], #16
AddZp:
add v23.4s, v23.4s, v26.4s
@ -161,5 +205,9 @@ ConvDw3x3Int8Corner:
st1 {v23.s}[0], [x0], #4
st1 {v24.s}[0], [x0], #4
sub sp, sp, #32
ldp x19, x20, [sp], #16
ldp x21, x22, [sp], #16
ret
#endif

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -68,7 +68,8 @@ void IndirectGemmInt8_2x4(int8_t *output, const int8_t *input, const int8_t *wei
int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset);
void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height,
int width, int in_kh_step, int in_kw_step, int channel, int8_t in_zp, int32_t out_zp,
int out_multiplier, int left_shift, int right_shift, int32_t acc_min, int32_t acc_max);
int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t acc_min,
int32_t acc_max, size_t per_channel);
#endif
#ifdef ENABLE_ARM64
@ -81,23 +82,23 @@ void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *wei
int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset);
void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias,
int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp,
int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min,
int32_t acc_max);
int32_t out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift,
int32_t acc_min, int32_t acc_max, size_t per_channel);
void ConvDw3x3Int8Stride2(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias,
int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp,
int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min,
int32_t acc_max);
int32_t out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift,
int32_t acc_min, int32_t acc_max, size_t per_channel);
void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step,
size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, size_t out_multiplier,
size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max);
size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, int32_t *out_multiplier,
int32_t *left_shift, int32_t *right_shift, size_t acc_min, size_t acc_max, size_t per_channel);
void ConvDw3x3Int8Vertical(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias,
size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp,
size_t out_multiplier, size_t left_shift, size_t right_shift, size_t acc_min,
size_t acc_max);
int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, size_t acc_min,
size_t acc_max, size_t per_channel);
void ConvDw3x3Int8Horizontal(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias,
size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp,
size_t out_multiplier, size_t left_shift, size_t right_shift, size_t acc_min,
size_t acc_max);
int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, size_t acc_min,
size_t acc_max, size_t per_channel);
#endif
#ifdef __cplusplus
}

File diff suppressed because it is too large Load Diff

@ -180,8 +180,7 @@ kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector<lite::Tensor *>
conv_param->output_h_ = outputs[kOutputIndex]->Height();
conv_param->output_w_ = outputs[kOutputIndex]->Width();
}
auto weight_quant_size = inputs[kWeightIndex]->GetQuantParams().size();
if (CheckConvDwUse3X3(conv_param) && conv_param->input_channel_ % C8NUM == 0 && weight_quant_size == 1) {
if (CheckConvDwUse3X3(conv_param) && conv_param->input_channel_ % C8NUM == 0) {
#ifdef ENABLE_ARM64
kernel =
new (std::nothrow) kernel::ConvolutionDepthwise3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive);

Loading…
Cancel
Save