diff --git a/mindspore/lite/nnacl/assembly/arm64/ConvDwInt8_3x3.S b/mindspore/lite/nnacl/assembly/arm64/ConvDwInt8_3x3.S
new file mode 100644
index 0000000000..a31be4f7ee
--- /dev/null
+++ b/mindspore/lite/nnacl/assembly/arm64/ConvDwInt8_3x3.S
@@ -0,0 +1,314 @@
+#ifdef __aarch64__
+
+.text
+.align 5
+.global ConvDw3x3Int8Neon64
+#ifndef __APPLE__
+.type ConvDw3x3Int8Neon64, %function
+#endif
+
+
+// void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias, int input_col_size,
+//                          int input_row_size, int channel, int output_h, int output_w, int8_t in_zp, int32_t out_zp,
+//                          int out_multiplier, int left_shift, int right_shift, int32_t acc_min, int32_t acc_max)
+//
+// 3x3 int8 depthwise convolution kernel for one group of 8 channels. Each pass of
+// Loop produces two output pixels (Width2) or the final single pixel (Width1) and
+// advances the input by col_size bytes per output pixel; output_w counts the
+// pixels to produce. output_h (w7) is not read by this routine.
+//
+// Register arguments:
+// x0: output
+// x1: input
+// x2: weight
+// x3: bias
+// w4: col_size
+// w5: row_size
+// w6: channel
+// w7: output_h
+//
+// Stack arguments, loaded into:
+// w8: output_w
+// w9: in_zp
+// w10: out_zp
+// w11: out_multiplier
+// w12: left_shift
+// w13: right_shift
+// w14: acc_min
+// w15: acc_max
+//
+// Scratch: x16/x17/x22 are the three input row pointers, x19 = bias + 16 bytes,
+// x20 = channel * 2 (weight tap stride in bytes), x21 = col_size * 2.
+// x18 is deliberately left untouched: it is the platform register and is reserved
+// on several AArch64 platforms.
+//
+// NOTE(review): the stack-argument offsets assume one 8-byte slot per argument.
+// Apple's arm64 calling convention packs stack arguments instead; confirm the
+// offsets before building this for Apple targets.
+
+ConvDw3x3Int8Neon64:
+    // Reserve the 160-byte save area and keep sp below it for the whole call so
+    // the saved registers are never stored beneath the stack pointer.
+    sub sp, sp, #160
+    mov x16, sp
+    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x16], #64
+    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x16]
+    stp x19, x20, [sp, #128]
+    stp x21, x22, [sp, #144]
+
+    // Stack arguments now start 160 bytes above sp.
+    ldr w8, [sp, #160]
+    ldr w9, [sp, #168]
+    ldr w10, [sp, #176]
+    ldr w11, [sp, #184]
+    ldr w12, [sp, #192]
+    ldr w13, [sp, #200]
+    ldr w14, [sp, #208]
+    ldr w15, [sp, #216]
+
+    // Nothing to do for a non-positive width; without this check the width
+    // counter would never reach zero.
+    cmp w8, #1
+    blt End
+
+    // The upper halves of x4-x6 are unspecified for int arguments; sign-extend
+    // before using them as 64-bit address offsets.
+    sxtw x4, w4
+    sxtw x5, w5
+    sxtw x6, w6
+
+    add x19, x3, #16
+    add x20, x6, x6  // channel * 2
+    add x21, x4, x4  // col_size * 2
+    dup v25.8b, w9
+    dup v26.4s, w12
+    dup v27.4s, w11
+    dup v28.4s, w13
+    dup v29.4s, w10
+    dup v30.4s, w14
+    dup v31.4s, w15
+
+    // Load weights
+    ld1 {v0.8h}, [x2], x20
+    ld1 {v1.8h}, [x2], x20
+    ld1 {v2.8h}, [x2], x20
+    ld1 {v3.8h}, [x2], x20
+    ld1 {v4.8h}, [x2], x20
+    ld1 {v5.8h}, [x2], x20
+    ld1 {v6.8h}, [x2], x20
+    ld1 {v7.8h}, [x2], x20
+    ld1 {v8.8h}, [x2], x20
+
+Loop:
+    mov x16, x1
+    add x17, x16, x5
+    add x22, x17, x5
+    ld1 {v9.8b}, [x16], x4
+    ld1 {v10.8b}, [x16], x4
+    ld1 {v11.8b}, [x16], x4
+    ld1 {v13.8b}, [x17], x4
+    ld1 {v14.8b}, [x17], x4
+    ld1 {v15.8b}, [x17], x4
+    ld1 {v17.8b}, [x22], x4
+    ld1 {v18.8b}, [x22], x4
+    ld1 {v19.8b}, [x22], x4
+
+    ld1 {v21.4s}, [x3]
+    ld1 {v22.4s}, [x19]
+
+    // subtract input zp
+    ssubl v9.8h, v9.8b, v25.8b
+    ssubl v10.8h, v10.8b, v25.8b
+    ssubl v11.8h, v11.8b, v25.8b
+    ssubl v13.8h, v13.8b, v25.8b
+    ssubl v14.8h, v14.8b, v25.8b
+    ssubl v15.8h, v15.8b, v25.8b
+    ssubl v17.8h, v17.8b, v25.8b
+    ssubl v18.8h, v18.8b, v25.8b
+    ssubl v19.8h, v19.8b, v25.8b
+
+    cmp w8, #1
+    beq Width1
+
+Width2:
+    ld1 {v12.8b}, [x16]
+    ld1 {v16.8b}, [x17]
+    ld1 {v20.8b}, [x22]
+
+    ld1 {v23.4s}, [x3]
+    ld1 {v24.4s}, [x19]
+
+    ssubl v12.8h, v12.8b, v25.8b
+    ssubl v16.8h, v16.8b, v25.8b
+    ssubl v20.8h, v20.8b, v25.8b
+
+    smlal v21.4s, v0.4h, v9.4h
+    smlal2 v22.4s, v0.8h, v9.8h
+    smlal v23.4s, v0.4h, v10.4h
+    smlal2 v24.4s, v0.8h, v10.8h
+    smlal v21.4s, v1.4h, v10.4h
+    smlal2 v22.4s, v1.8h, v10.8h
+    smlal v23.4s, v1.4h, v11.4h
+    smlal2 v24.4s, v1.8h, v11.8h
+    smlal v21.4s, v2.4h, v11.4h
+    smlal2 v22.4s, v2.8h, v11.8h
+    smlal v23.4s, v2.4h, v12.4h
+    smlal2 v24.4s, v2.8h, v12.8h
+    smlal v21.4s, v3.4h, v13.4h
+    smlal2 v22.4s, v3.8h, v13.8h
+    smlal v23.4s, v3.4h, v14.4h
+    smlal2 v24.4s, v3.8h, v14.8h
+    smlal v21.4s, v4.4h, v14.4h
+    smlal2 v22.4s, v4.8h, v14.8h
+    smlal v23.4s, v4.4h, v15.4h
+    smlal2 v24.4s, v4.8h, v15.8h
+    smlal v21.4s, v5.4h, v15.4h
+    smlal2 v22.4s, v5.8h, v15.8h
+    smlal v23.4s, v5.4h, v16.4h
+    smlal2 v24.4s, v5.8h, v16.8h
+    smlal v21.4s, v6.4h, v17.4h
+    smlal2 v22.4s, v6.8h, v17.8h
+    smlal v23.4s, v6.4h, v18.4h
+    smlal2 v24.4s, v6.8h, v18.8h
+    smlal v21.4s, v7.4h, v18.4h
+    smlal2 v22.4s, v7.8h, v18.8h
+    smlal v23.4s, v7.4h, v19.4h
+    smlal2 v24.4s, v7.8h, v19.8h
+    smlal v21.4s, v8.4h, v19.4h
+    smlal2 v22.4s, v8.8h, v19.8h
+    smlal v23.4s, v8.4h, v20.4h
+    smlal2 v24.4s, v8.8h, v20.8h
+
+    // Apply left shift
+    sqshl v21.4s, v21.4s, v26.4s
+    sqshl v22.4s, v22.4s, v26.4s
+    sqshl v23.4s, v23.4s, v26.4s
+    sqshl v24.4s, v24.4s, v26.4s
+
+    // Apply the fixed-point part of the multiplier.
+    sqrdmulh v21.4s, v21.4s, v27.4s
+    sqrdmulh v22.4s, v22.4s, v27.4s
+    sqrdmulh v23.4s, v23.4s, v27.4s
+    sqrdmulh v24.4s, v24.4s, v27.4s
+
+    // Apply right shift
+    and v9.16b, v28.16b, v21.16b
+    sshr v9.4s, v9.4s, #31
+    sqadd v21.4s, v21.4s, v9.4s
+    srshl v21.4s, v21.4s, v28.4s
+    and v10.16b, v28.16b, v22.16b
+    sshr v10.4s, v10.4s, #31
+    sqadd v22.4s, v22.4s, v10.4s
+    srshl v22.4s, v22.4s, v28.4s
+    and v11.16b, v28.16b, v23.16b
+    sshr v11.4s, v11.4s, #31
+    sqadd v23.4s, v23.4s, v11.4s
+    srshl v23.4s, v23.4s, v28.4s
+    and v12.16b, v28.16b, v24.16b
+    sshr v12.4s, v12.4s, #31
+    sqadd v24.4s, v24.4s, v12.4s
+    srshl v24.4s, v24.4s, v28.4s
+
+    // Add output zero point
+    sqadd v21.4s, v21.4s, v29.4s
+    sqadd v22.4s, v22.4s, v29.4s
+    sqadd v23.4s, v23.4s, v29.4s
+    sqadd v24.4s, v24.4s, v29.4s
+
+    // Apply min bound
+    smax v21.4s, v21.4s, v30.4s
+    smax v22.4s, v22.4s, v30.4s
+    smax v23.4s, v23.4s, v30.4s
+    smax v24.4s, v24.4s, v30.4s
+
+    // Apply max bound
+    smin v21.4s, v21.4s, v31.4s
+    smin v22.4s, v22.4s, v31.4s
+    smin v23.4s, v23.4s, v31.4s
+    smin v24.4s, v24.4s, v31.4s
+
+    sqxtn v21.4h, v21.4s
+    sqxtn2 v21.8h, v22.4s
+    sqxtn v23.4h, v23.4s
+    sqxtn2 v23.8h, v24.4s
+    sqxtn v21.8b, v21.8h
+    sqxtn2 v21.16b, v23.8h
+
+    st1 {v21.8b}, [x0], x6
+    mov v23.d[0], v21.d[1]
+    st1 {v23.8b}, [x0], x6
+    sub w8, w8, #2
+    cbz w8, End
+    add x1, x1, x21
+    b Loop
+
+Width1:
+    smlal v21.4s, v0.4h, v9.4h
+    smlal2 v22.4s, v0.8h, v9.8h
+    smlal v21.4s, v1.4h, v10.4h
+    smlal2 v22.4s, v1.8h, v10.8h
+    smlal v21.4s, v2.4h, v11.4h
+    smlal2 v22.4s, v2.8h, v11.8h
+    smlal v21.4s, v3.4h, v13.4h
+    smlal2 v22.4s, v3.8h, v13.8h
+    smlal v21.4s, v4.4h, v14.4h
+    smlal2 v22.4s, v4.8h, v14.8h
+    smlal v21.4s, v5.4h, v15.4h
+    smlal2 v22.4s, v5.8h, v15.8h
+    smlal v21.4s, v6.4h, v17.4h
+    smlal2 v22.4s, v6.8h, v17.8h
+    smlal v21.4s, v7.4h, v18.4h
+    smlal2 v22.4s, v7.8h, v18.8h
+    smlal v21.4s, v8.4h, v19.4h
+    smlal2 v22.4s, v8.8h, v19.8h
+
+    // Apply left shift
+    sqshl v21.4s, v21.4s, v26.4s
+    sqshl v22.4s, v22.4s, v26.4s
+
+    // Apply the fixed-point part of the multiplier.
+    sqrdmulh v21.4s, v21.4s, v27.4s
+    sqrdmulh v22.4s, v22.4s, v27.4s
+
+    // Apply right shift
+    and v9.16b, v28.16b, v21.16b
+    sshr v9.4s, v9.4s, #31
+    sqadd v21.4s, v21.4s, v9.4s
+    srshl v21.4s, v21.4s, v28.4s
+    and v10.16b, v28.16b, v22.16b
+    sshr v10.4s, v10.4s, #31
+    sqadd v22.4s, v22.4s, v10.4s
+    srshl v22.4s, v22.4s, v28.4s
+
+    // Add output zero point
+    sqadd v21.4s, v21.4s, v29.4s
+    sqadd v22.4s, v22.4s, v29.4s
+
+    // Apply min bound
+    smax v21.4s, v21.4s, v30.4s
+    smax v22.4s, v22.4s, v30.4s
+
+    // Apply max bound
+    smin v21.4s, v21.4s, v31.4s
+    smin v22.4s, v22.4s, v31.4s
+
+    sqxtn v21.4h, v21.4s
+    sqxtn2 v21.8h, v22.4s
+    sqxtn v21.8b, v21.8h
+
+    st1 {v21.8b}, [x0], x6
+    sub w8, w8, #1
+    cbz w8, End
+    add x1, x1, x4
+    b Loop
+
+End:
+    mov x16, sp
+    ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x16], #64
+    ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x16]
+    ldp x19, x20, [sp, #128]
+    ldp x21, x22, [sp, #144]
+    add sp, sp, #160
+    ret
+
+#endif
diff --git a/mindspore/lite/nnacl/int8/conv_depthwise_int8.c b/mindspore/lite/nnacl/int8/conv_depthwise_int8.c
index 510e1530a8..5f92bafe59 100644
--- a/mindspore/lite/nnacl/int8/conv_depthwise_int8.c
+++ b/mindspore/lite/nnacl/int8/conv_depthwise_int8.c
@@ -208,8 +208,13 @@ void ConvDw3x3Int8Block(int8_t *output, const int8_t *buffer, const int16_t *wei
                         int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min,
                         int32_t acc_max, int stride) {
   for (; start_c <= end_c - 8; start_c += 8) {
+#ifdef ENABLE_ARM64
+    ConvDw3x3Int8Neon64(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp,
+                        out_multiplier, left_shift, right_shift, acc_min, acc_max);
+#else
     ConvDw3x3Int8Window(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp,
                         out_multiplier, left_shift, right_shift, acc_min, acc_max, stride);
+#endif
     output += 8;
     buffer += 8;
     weight += 8;
diff --git a/mindspore/lite/nnacl/int8/conv_depthwise_int8.h b/mindspore/lite/nnacl/int8/conv_depthwise_int8.h
index 3696746881..e9c21e03ec 100644
--- a/mindspore/lite/nnacl/int8/conv_depthwise_int8.h
+++ b/mindspore/lite/nnacl/int8/conv_depthwise_int8.h
@@ -43,6 +43,14 @@ void ConvDwSWInt8(int8_t *output_data, const int8_t *input_data, const int16_t *
 void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *input_data, const int16_t *weight_data,
                   const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding,
                   int task_id);
+
+#ifdef ENABLE_ARM64
+void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias,
+                         int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp,
+                         int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min,
+                         int32_t acc_max);
+#endif
+
 #ifdef __cplusplus
 }
 #endif