enable int8 kernel on arm32

pull/6323/head
lixian 4 years ago
parent 19bdba56f9
commit dcaf76a800

@ -9,8 +9,8 @@
#endif
// void IndirectGemmInt8_2x4(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4,
// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, size_t out_multiplier,
// size_t shift_before, size_t shift_after);
// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, int32_t *out_multiplier,
// int32_t *shift_before, int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset);
// r0: output, r1: input, r2: weight, r3: bias, r4: kSize, r5: ic4, r6: oc, r7: offset
// r8: input_sum, r10: act_min, r11: act_max, r10: out_zp, r11: out_multiplier, r10: shift_before, r11: shift_after
IndirectGemmInt8_2x4:
@ -24,7 +24,7 @@ IndirectGemmInt8_2x4:
veor q15, q15, q15
.endm
// at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr"
// at return, clang generates "push {lr}, pop {pc}"" while gcc will generate "bx lr"
// according to https://stackoverflow.com/questions/53625807
// even if we jump to link register instead of saving it, we still have to save it in subroutine calls anyway
// clang's rule seems more simple, though there are no subroutine calls here
@ -127,10 +127,6 @@ IndirectGemmInt8_2x4:
vpadal.s16 q14, q6
vpadal.s16 q15, q7
// load sum
ldr r10, [sp, #16]
vld1.32 q0[], [r10]!
vld1.32 q1[], [r10]!
// pairwise add
vpadd.i32 d16, d16, d17
vpadd.i32 d18, d18, d19
@ -145,8 +141,27 @@ IndirectGemmInt8_2x4:
vpadd.i32 d17, d20, d22
vpadd.i32 d24, d24, d26
vpadd.i32 d25, d28, d30
// load sum
ldr lr, [sp, #44]
cmp lr, #0
beq NoSum
ldr r10, [sp, #16]
ldr lr, [sp, #48]
cmp lr, #0
beq SymSum
ldr lr, [sp, #52]
vld1.32 q0, [r10]
add r10, r10, lr
vld1.32 q1, [r10]
b AddSum
SymSum:
vld1.32 q0[], [r10]!
vld1.32 q1[], [r10]!
AddSum:
vsub.i32 q8, q8, q0
vsub.i32 q12, q12, q1
NoSum:
cmp r3, #0
beq NoBias
vld1.32 q2, [r3]
@ -154,18 +169,30 @@ IndirectGemmInt8_2x4:
vadd.i32 q12, q12, q2
NoBias:
ldr r10, [sp, #36]
vdup.32 q3, r10
ldr lr, [sp, #48]
cmp lr, #0
bne PerChannel
ldr lr, [sp, #36]
vld1.32 q3[], [lr]
ldr lr, [sp, #32]
vld1.32 q4[], [lr]
ldr lr, [sp, #40]
vld1.32 q5[], [lr]
b QuantizeStart
PerChannel:
ldr lr, [sp, #36]
vld1.32 q3, [lr]
ldr lr, [sp, #32]
vld1.32 q4, [lr]
ldr lr, [sp, #40]
vld1.32 q5, [lr]
QuantizeStart:
vshl.s32 q8, q8, q3
vshl.s32 q12, q12, q3
ldr r10, [sp, #32]
vdup.32 q4, r10
vqrdmulh.s32 q8, q8, q4
vqrdmulh.s32 q12, q12, q4
ldr r10, [sp, #40]
vdup.32 q5, r10
vand q3, q5, q8
vshr.s32 q3, q3, #31
vqadd.s32 q8, q8, q3
@ -192,7 +219,7 @@ IndirectGemmInt8_2x4:
vqmovn.s32 d30, q8
vqmovn.s32 d31, q12
vqmovn.s16 d0, q14
vqmovn.s16 d0, q15
// prefetching is not prefered while writing results in spite of cache missings
// you could try prfm pstl2strm
@ -234,6 +261,26 @@ IndirectGemmInt8_2x4:
cmp r6, #4
ble LoopOcEnd
ldr lr, [sp, #48]
cmp lr, #0
beq NoChannelForward
ldr lr, [sp, #44]
cmp lr, #0
beq NoSumForward
ldr lr, [sp, #16]
add lr, lr, #16
str lr, [sp, #16]
NoSumForward:
ldr lr, [sp, #36]
add lr, lr, #16
str lr, [sp, #36]
ldr lr, [sp, #32]
add lr, lr, #16
str lr, [sp, #32]
ldr lr, [sp, #40]
add lr, lr, #16
str lr, [sp, #40]
NoChannelForward:
sub r6, r6, #4
cmp r3, #0
beq NoStepFowrard

@ -8,8 +8,8 @@
#endif
// void IndirectGemmInt8_4x4(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4,
// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
// int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after, size_t asymmetric, size_t per_channel);
// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, int32_t *out_multiplier,
// int32_t *shift_before, int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset);
// x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset
IndirectGemmInt8_4x4:
@ -52,10 +52,7 @@ IndirectGemmInt8_4x4:
ldr x19, [sp, #48]
ldr x20, [sp, #56]
ldr x21, [sp, #64]
add x24, x6, #3
mov x23, #4
sdiv x23, x24, x23
ldr x23, [sp, #72]
mul x5, x4, x5
mov x4, #1
@ -218,10 +215,10 @@ IndirectGemmInt8_4x4:
// load sum
mov x22, x15
cbz x21, SymSum
ld1r {v8.4s}, [x22], x23
ld1r {v9.4s}, [x22], x23
ld1r {v10.4s}, [x22], x23
ld1r {v11.4s}, [x22]
ld1 {v8.4s}, [x22], x23
ld1 {v9.4s}, [x22], x23
ld1 {v10.4s}, [x22], x23
ld1 {v11.4s}, [x22]
b AddSum
SymSum:
ld1r {v8.4s}, [x22], #4

@ -9,7 +9,7 @@
// void IndirectGemmInt8_24x4_dp(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4,
// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, int32_t *out_multiplier,
// int32_t *shift_before, int32_t *shift_after);
// int32_t *shift_before, int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset);
// x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset
// we use sdot intrinsic on cores that supports dotprod(Armv8.2-A w/dp or later)
// mrs intrinsic could read system register ID_AA64ISAR0_EL1(or s3_0_c0_c6_0 on Armv8.2-A)
@ -148,10 +148,7 @@ IndirectGemmInt8_24x4_dp:
ldr x19, [sp, #48]
ldr x20, [sp, #56]
ldr x21, [sp, #64]
add x24, x6, #3
mov x23, #4
sdiv x23, x24, x23
ldr x23, [sp, #72]
mul x5, x4, x5
mov x4, #1

@ -37,18 +37,25 @@ void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *
int output_channel, int input_step, int8_t input_zp);
void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier,
int32_t left_shift, int32_t right_shift, int32_t acc_min, int32_t acc_max);
void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t *weight, size_t ksize, size_t ic8,
size_t oc4, size_t offset);
#endif
#ifdef ENABLE_ARM32
void IndirectGemmInt8_2x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize,
size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min,
size_t act_max, size_t out_zp, int32_t *out_multiplier, int32_t *shift_before,
int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset);
#endif
#ifdef ENABLE_ARM64
void PostFuncInt8C4Neon64(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc4div, size_t oc4res,
size_t plane, size_t stride, int32_t multiplier, int32_t left_shift, int32_t right_shift,
int32_t zp, int32_t mini, int32_t maxi);
void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t *weight, size_t ksize, size_t ic8,
size_t oc4, size_t offset);
void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize,
size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min,
size_t act_max, size_t out_zp, int32_t *out_multiplier, int32_t *shift_before,
int32_t *shift_after, size_t asymmetric, size_t per_channel);
int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset);
void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width,
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);

@ -28,15 +28,21 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const in
int32_t out_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_;
int32_t act_min = conv_param->conv_quant_arg_.out_act_min_[0];
int32_t act_max = conv_param->conv_quant_arg_.out_act_max_[0];
int oc4 = UP_DIV(output_channel, C4NUM);
#ifdef ENABLE_ARM64
size_t asymmetric = conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC;
size_t per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
IndirectGemmInt8_4x4(dst, src, weight, bias, UP_DIV(kernel_plane, C4NUM), ic4, output_channel,
output_channel * sizeof(int8_t), input_sum, act_min, act_max, out_zp, out_multiplier,
shift_before, shift_after, asymmetric, per_channel);
shift_before, shift_after, asymmetric, per_channel, oc4 * C4NUM * sizeof(int32_t));
#elif ENABLE_ARM32
size_t asymmetric = conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC;
size_t per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
IndirectGemmInt8_2x4(dst, src, weight, bias, UP_DIV(kernel_plane, C4NUM), ic4, output_channel,
output_channel * sizeof(int8_t), input_sum, act_min, act_max, out_zp, out_multiplier,
shift_before, shift_after, asymmetric, per_channel, oc4 * C4NUM * sizeof(int32_t));
#else
int oc4 = UP_DIV(output_channel, C4NUM);
int tile_num = conv_param->tile_num_;
int plane_c4 = UP_DIV(kernel_plane, C4NUM);
for (int oc = 0; oc < output_channel; oc++) {
@ -201,7 +207,7 @@ void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const
void Conv3x3Int8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, int oc, int ic8, size_t real_cal_num) {
int oc4 = UP_DIV(oc, C4NUM);
#ifdef ENABLE_ARM64
#ifdef ENABLE_ARM
IndirectGemmInt16to32_8x4(dst, src, weight, 16, ic8, oc4, oc4 * 4 * 16 * sizeof(int32_t));
#else
const int input_unit_square = 16;

@ -33,10 +33,10 @@ using mindspore::schema::PrimitiveType_Conv2D;
namespace mindspore::kernel {
void ConvolutionInt8CPUKernel::CheckSupportOptimize() {
tile_num_ = 24;
// #ifdef ENABLE_ARM32
// tile_num_ = 2;
// support_optimize_ = false;
// #endif
#ifdef ENABLE_ARM32
tile_num_ = 2;
support_optimize_ = false;
#endif
#ifdef ENABLE_ARM64
void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_;
@ -380,7 +380,11 @@ kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector<lite::Tensor *> &
int dilation_w = conv_param->dilation_w_;
kernel::LiteKernel *kernel;
if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
#ifdef ENABLE_ARM32
kernel = new (std::nothrow) kernel::Convolution3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
#else
kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
#endif
} else if (kernel_h == 1 && kernel_w == 1) {
kernel = new (std::nothrow) kernel::Convolution1x1Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
} else {

Loading…
Cancel
Save