From 7f9d65cce07719487a8625b3597b9cb5684080d2 Mon Sep 17 00:00:00 2001 From: lixian Date: Mon, 16 Nov 2020 16:36:43 +0800 Subject: [PATCH] apply int8 4x16 kernel --- .../lite/nnacl/assembly/opt/MatmulDpInt8Opt.S | 1083 +++++++++++++++++ mindspore/lite/nnacl/int8/conv_int8.c | 13 +- mindspore/lite/nnacl/int8/conv_int8.h | 2 +- mindspore/lite/nnacl/int8/matmul_int8.c | 232 ++++ mindspore/lite/nnacl/int8/matmul_int8.h | 9 + mindspore/lite/nnacl/matmul_parameter.h | 6 + mindspore/lite/nnacl/pack.h | 6 + .../kernel/arm/int8/convolution_1x1_int8.cc | 274 +++-- .../kernel/arm/int8/convolution_1x1_int8.h | 16 +- .../runtime/kernel/arm/int8/opt_op_handler.cc | 10 + .../runtime/kernel/arm/int8/opt_op_handler.h | 4 + 11 files changed, 1507 insertions(+), 148 deletions(-) create mode 100644 mindspore/lite/nnacl/assembly/opt/MatmulDpInt8Opt.S diff --git a/mindspore/lite/nnacl/assembly/opt/MatmulDpInt8Opt.S b/mindspore/lite/nnacl/assembly/opt/MatmulDpInt8Opt.S new file mode 100644 index 0000000000..ee276f01bc --- /dev/null +++ b/mindspore/lite/nnacl/assembly/opt/MatmulDpInt8Opt.S @@ -0,0 +1,1083 @@ +#ifdef __aarch64__ + .text + .align 5 + .global MatmulInt8DpOpt +#ifndef __APPLE__ + .type MatmulInt8DpOpt, %function +#endif + +//void MatmulInt8DpOpt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep4, const int *a_sums, +// const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift, +// int32_t *right_shift, size_t stride, size_t filter_peroc, int32_t *filter_zp) + +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// x3: row +// x4: col +// x5: deep4 +// x6: a_sums +// x7: bias +// w8: act_min +// w9: act_max +// w10: out_zp +// x11: multiplier +// x12: left_shift +// x13: right_shift +// x14: stride +// x15: filter_peroc +// x28: filter_zp + +MatmulInt8DpOpt: + sub sp, sp, #208 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + stp x21, x22, [sp], #16 + stp x23, x24, [sp], #16 + stp x25, x26, [sp], #16 + stp x27, x28, [sp], #16 + + ldr w8, [sp] + ldr w9, [sp, #8] + ldr w10, [sp, #16] + ldr x11, [sp, #24] + ldr x12, [sp, #32] + ldr x13, [sp, #40] + ldr x14, [sp, #48] + ldr x15, [sp, #56] + + mov x23, #4 + mul x23, x23, x5 // lhs step + mov x24, #4 + mul x24, x24, x14 // dst step + +LoopRow: + mov x16, x1 // reload rhs ptr + mov x17, x4 // reload rhs col + mov x18, x7 // reload bias ptr + mov x25, x6 // reload input_sum ptr + mov x27, x2 // reload dst ptr + ldr x28, [sp, #64] // reload filter_zp + + LoopCol: + mov x19, x27 // reload dst ptr + mov x20, x0 // reload lhs ptr + mov x21, x5 // reload depth + + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr + + cmp x17, #4 + ble LoopDepthQuarter + cmp x17, #8 + ble LoopDepthHalf + + LoopDepth: + ld1 {v0.16b}, [x20], #16 + ld1 {v1.16b, v2.16b, v3.16b, v4.16b}, [x16], #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v17.4s, v2.16b, v0.4b[0] + sdot v18.4s, v3.16b, v0.4b[0] + sdot v19.4s, v4.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v21.4s, v2.16b, v0.4b[1] + sdot v22.4s, v3.16b, v0.4b[1] + sdot v23.4s, v4.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v25.4s, v2.16b, v0.4b[2] + sdot v26.4s, v3.16b, v0.4b[2] + sdot v27.4s, v4.16b, v0.4b[2] + sdot v28.4s, v1.16b, 
v0.4b[3] + sdot v29.4s, v2.16b, v0.4b[3] + sdot v30.4s, v3.16b, v0.4b[3] + sdot v31.4s, v4.16b, v0.4b[3] + + subs x21, x21, #4 + bgt LoopDepth + + Bias: + cbz x7, NoReadBias + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x18], #64 + add v16.4s, v16.4s, v0.4s + add v17.4s, v17.4s, v1.4s + add v18.4s, v18.4s, v2.4s + add v19.4s, v19.4s, v3.4s + add v20.4s, v20.4s, v0.4s + add v21.4s, v21.4s, v1.4s + add v22.4s, v22.4s, v2.4s + add v23.4s, v23.4s, v3.4s + add v24.4s, v24.4s, v0.4s + add v25.4s, v25.4s, v1.4s + add v26.4s, v26.4s, v2.4s + add v27.4s, v27.4s, v3.4s + add v28.4s, v28.4s, v0.4s + add v29.4s, v29.4s, v1.4s + add v30.4s, v30.4s, v2.4s + add v31.4s, v31.4s, v3.4s + + NoReadBias: + ld1r {v12.4s}, [x25], #4 + ld1r {v13.4s}, [x25], #4 + ld1r {v14.4s}, [x25], #4 + ld1r {v15.4s}, [x25], #4 + cbnz x15, PerChannelSum + + PerTensorSum: + sub v16.4s, v16.4s, v12.4s + sub v17.4s, v17.4s, v12.4s + sub v18.4s, v18.4s, v12.4s + sub v19.4s, v19.4s, v12.4s + sub v20.4s, v20.4s, v13.4s + sub v21.4s, v21.4s, v13.4s + sub v22.4s, v22.4s, v13.4s + sub v23.4s, v23.4s, v13.4s + sub v24.4s, v24.4s, v14.4s + sub v25.4s, v25.4s, v14.4s + sub v26.4s, v26.4s, v14.4s + sub v27.4s, v27.4s, v14.4s + sub v28.4s, v28.4s, v15.4s + sub v29.4s, v29.4s, v15.4s + sub v30.4s, v30.4s, v15.4s + sub v31.4s, v31.4s, v15.4s + + b PerTensor + + PerChannelSum: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x28], #64 + mul v0.4s, v8.4s, v12.4s + mul v1.4s, v9.4s, v12.4s + mul v2.4s, v10.4s, v12.4s + mul v3.4s, v11.4s, v12.4s + mul v4.4s, v8.4s, v13.4s + mul v5.4s, v9.4s, v13.4s + mul v6.4s, v10.4s, v13.4s + mul v7.4s, v11.4s, v13.4s + sub v16.4s, v16.4s, v0.4s + sub v17.4s, v17.4s, v1.4s + sub v18.4s, v18.4s, v2.4s + sub v19.4s, v19.4s, v3.4s + sub v20.4s, v20.4s, v4.4s + sub v21.4s, v21.4s, v5.4s + sub v22.4s, v22.4s, v6.4s + sub v23.4s, v23.4s, v7.4s + mul v0.4s, v8.4s, v14.4s + mul v1.4s, v9.4s, v14.4s + mul v2.4s, v10.4s, v14.4s + mul v3.4s, v11.4s, v14.4s + mul v4.4s, v8.4s, v15.4s + mul v5.4s, v9.4s, v15.4s + mul v6.4s, v10.4s, v15.4s + mul v7.4s, v11.4s, v15.4s + sub v24.4s, v24.4s, v0.4s + sub v25.4s, v25.4s, v1.4s + sub v26.4s, v26.4s, v2.4s + sub v27.4s, v27.4s, v3.4s + sub v28.4s, v28.4s, v4.4s + sub v29.4s, v29.4s, v5.4s + sub v30.4s, v30.4s, v6.4s + sub v31.4s, v31.4s, v7.4s + + PerTensor: + cbnz x15, PerChannel + ld1r {v0.4s}, [x12] + mov v1.16b, v0.16b + mov v2.16b, v0.16b + mov v3.16b, v0.16b + ld1r {v4.4s}, [x11] + mov v5.16b, v4.16b + mov v6.16b, v4.16b + mov v7.16b, v4.16b + ld1r {v8.4s}, [x13] + mov v9.16b, v8.16b + mov v10.16b, v8.16b + mov v11.16b, v8.16b + + b Quantization + + PerChannel: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x12], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x11], #64 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x13], #64 + + Quantization: + sqshl v16.4s, v16.4s, v0.4s + sqshl v17.4s, v17.4s, v1.4s + sqshl v18.4s, v18.4s, v2.4s + sqshl v19.4s, v19.4s, v3.4s + sqshl v20.4s, v20.4s, v0.4s + sqshl v21.4s, v21.4s, v1.4s + sqshl v22.4s, v22.4s, v2.4s + sqshl v23.4s, v23.4s, v3.4s + sqshl v24.4s, v24.4s, v0.4s + sqshl v25.4s, v25.4s, v1.4s + sqshl v26.4s, v26.4s, v2.4s + sqshl v27.4s, v27.4s, v3.4s + sqshl v28.4s, v28.4s, v0.4s + sqshl v29.4s, v29.4s, v1.4s + sqshl v30.4s, v30.4s, v2.4s + sqshl v31.4s, v31.4s, v3.4s + + sqrdmulh v16.4s, v16.4s, v4.4s + sqrdmulh v17.4s, v17.4s, v5.4s + sqrdmulh v18.4s, v18.4s, v6.4s + sqrdmulh v19.4s, v19.4s, v7.4s + sqrdmulh v20.4s, v20.4s, v4.4s + sqrdmulh v21.4s, v21.4s, v5.4s + sqrdmulh v22.4s, v22.4s, v6.4s + sqrdmulh v23.4s, v23.4s, v7.4s + sqrdmulh v24.4s, v24.4s, v4.4s + sqrdmulh 
v25.4s, v25.4s, v5.4s + sqrdmulh v26.4s, v26.4s, v6.4s + sqrdmulh v27.4s, v27.4s, v7.4s + sqrdmulh v28.4s, v28.4s, v4.4s + sqrdmulh v29.4s, v29.4s, v5.4s + sqrdmulh v30.4s, v30.4s, v6.4s + sqrdmulh v31.4s, v31.4s, v7.4s + + and v0.16b, v8.16b, v16.16b + sshr v0.4s, v0.4s, #31 + sqadd v16.4s, v16.4s, v0.4s + srshl v16.4s, v16.4s, v8.4s + and v1.16b, v9.16b, v17.16b + sshr v1.4s, v1.4s, #31 + sqadd v17.4s, v17.4s, v1.4s + srshl v17.4s, v17.4s, v9.4s + and v2.16b, v10.16b, v18.16b + sshr v2.4s, v2.4s, #31 + sqadd v18.4s, v18.4s, v2.4s + srshl v18.4s, v18.4s, v10.4s + and v3.16b, v11.16b, v19.16b + sshr v3.4s, v3.4s, #31 + sqadd v19.4s, v19.4s, v3.4s + srshl v19.4s, v19.4s, v11.4s + + and v0.16b, v8.16b, v20.16b + sshr v0.4s, v0.4s, #31 + sqadd v20.4s, v20.4s, v0.4s + srshl v20.4s, v20.4s, v8.4s + and v1.16b, v9.16b, v21.16b + sshr v1.4s, v1.4s, #31 + sqadd v21.4s, v21.4s, v1.4s + srshl v21.4s, v21.4s, v9.4s + and v2.16b, v10.16b, v22.16b + sshr v2.4s, v2.4s, #31 + sqadd v22.4s, v22.4s, v2.4s + srshl v22.4s, v22.4s, v10.4s + and v3.16b, v11.16b, v23.16b + sshr v3.4s, v3.4s, #31 + sqadd v23.4s, v23.4s, v3.4s + srshl v23.4s, v23.4s, v11.4s + + and v0.16b, v8.16b, v24.16b + sshr v0.4s, v0.4s, #31 + sqadd v24.4s, v24.4s, v0.4s + srshl v24.4s, v24.4s, v8.4s + and v1.16b, v9.16b, v25.16b + sshr v1.4s, v1.4s, #31 + sqadd v25.4s, v25.4s, v1.4s + srshl v25.4s, v25.4s, v9.4s + and v2.16b, v10.16b, v26.16b + sshr v2.4s, v2.4s, #31 + sqadd v26.4s, v26.4s, v2.4s + srshl v26.4s, v26.4s, v10.4s + and v3.16b, v11.16b, v27.16b + sshr v3.4s, v3.4s, #31 + sqadd v27.4s, v27.4s, v3.4s + srshl v27.4s, v27.4s, v11.4s + + and v0.16b, v8.16b, v28.16b + sshr v0.4s, v0.4s, #31 + sqadd v28.4s, v28.4s, v0.4s + srshl v28.4s, v28.4s, v8.4s + and v1.16b, v9.16b, v29.16b + sshr v1.4s, v1.4s, #31 + sqadd v29.4s, v29.4s, v1.4s + srshl v29.4s, v29.4s, v9.4s + and v2.16b, v10.16b, v30.16b + sshr v2.4s, v2.4s, #31 + sqadd v30.4s, v30.4s, v2.4s + srshl v30.4s, v30.4s, v10.4s + and v3.16b, v11.16b, v31.16b + sshr v3.4s, v3.4s, #31 + sqadd v31.4s, v31.4s, v3.4s + srshl v31.4s, v31.4s, v11.4s + + // zp + dup v6.4s, w10 + add v16.4s, v16.4s, v6.4s + add v17.4s, v17.4s, v6.4s + add v18.4s, v18.4s, v6.4s + add v19.4s, v19.4s, v6.4s + add v20.4s, v20.4s, v6.4s + add v21.4s, v21.4s, v6.4s + add v22.4s, v22.4s, v6.4s + add v23.4s, v23.4s, v6.4s + add v24.4s, v24.4s, v6.4s + add v25.4s, v25.4s, v6.4s + add v26.4s, v26.4s, v6.4s + add v27.4s, v27.4s, v6.4s + add v28.4s, v28.4s, v6.4s + add v29.4s, v29.4s, v6.4s + add v30.4s, v30.4s, v6.4s + add v31.4s, v31.4s, v6.4s + + // min + dup v0.4s, w8 + smax v16.4s, v16.4s, v0.4s + smax v17.4s, v17.4s, v0.4s + smax v18.4s, v18.4s, v0.4s + smax v19.4s, v19.4s, v0.4s + smax v20.4s, v20.4s, v0.4s + smax v21.4s, v21.4s, v0.4s + smax v22.4s, v22.4s, v0.4s + smax v23.4s, v23.4s, v0.4s + smax v24.4s, v24.4s, v0.4s + smax v25.4s, v25.4s, v0.4s + smax v26.4s, v26.4s, v0.4s + smax v27.4s, v27.4s, v0.4s + smax v28.4s, v28.4s, v0.4s + smax v29.4s, v29.4s, v0.4s + smax v30.4s, v30.4s, v0.4s + smax v31.4s, v31.4s, v0.4s + + // max + dup v1.4s, w9 + smin v16.4s, v16.4s, v1.4s + smin v17.4s, v17.4s, v1.4s + smin v18.4s, v18.4s, v1.4s + smin v19.4s, v19.4s, v1.4s + smin v20.4s, v20.4s, v1.4s + smin v21.4s, v21.4s, v1.4s + smin v22.4s, v22.4s, v1.4s + smin v23.4s, v23.4s, v1.4s + smin v24.4s, v24.4s, v1.4s + smin v25.4s, v25.4s, v1.4s + smin v26.4s, v26.4s, v1.4s + smin v27.4s, v27.4s, v1.4s + smin v28.4s, v28.4s, v1.4s + smin v29.4s, v29.4s, v1.4s + smin v30.4s, v30.4s, v1.4s + smin v31.4s, v31.4s, v1.4s + + sqxtn 
v16.4h, v16.4s + sqxtn2 v16.8h, v17.4s + sqxtn v0.8b, v16.8h + sqxtn v18.4h, v18.4s + sqxtn2 v18.8h, v19.4s + sqxtn2 v0.16b, v18.8h + + sqxtn v20.4h, v20.4s + sqxtn2 v20.8h, v21.4s + sqxtn v1.8b, v20.8h + sqxtn v22.4h, v22.4s + sqxtn2 v22.8h, v23.4s + sqxtn2 v1.16b, v22.8h + + sqxtn v24.4h, v24.4s + sqxtn2 v24.8h, v25.4s + sqxtn v2.8b, v24.8h + sqxtn v26.4h, v26.4s + sqxtn2 v26.8h, v27.4s + sqxtn2 v2.16b, v26.8h + + sqxtn v28.4h, v28.4s + sqxtn2 v28.8h, v29.4s + sqxtn v3.8b, v28.8h + sqxtn v30.4h, v30.4s + sqxtn2 v30.8h, v31.4s + sqxtn2 v3.16b, v30.8h + + b WriteStart + + LoopDepthHalf: + ld1 {v0.16b}, [x20], #16 + ld1 {v1.16b, v2.16b}, [x16] + add x16, x16, #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v17.4s, v2.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v21.4s, v2.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v25.4s, v2.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + sdot v29.4s, v2.16b, v0.4b[3] + + subs x21, x21, #4 + bgt LoopDepthHalf + + BiasHalf: + cbz x7, NoReadBiasHalf + ld1 {v0.4s, v1.4s}, [x18] + add x18, x18, #64 + add v16.4s, v16.4s, v0.4s + add v17.4s, v17.4s, v1.4s + add v20.4s, v20.4s, v0.4s + add v21.4s, v21.4s, v1.4s + add v24.4s, v24.4s, v0.4s + add v25.4s, v25.4s, v1.4s + add v28.4s, v28.4s, v0.4s + add v29.4s, v29.4s, v1.4s + + NoReadBiasHalf: + ld1r {v12.4s}, [x25], #4 + ld1r {v13.4s}, [x25], #4 + ld1r {v14.4s}, [x25], #4 + ld1r {v15.4s}, [x25], #4 + cbnz x15, PerChannelSumHalf + + PerTensorSumHalf: + sub v16.4s, v16.4s, v12.4s + sub v17.4s, v17.4s, v12.4s + sub v20.4s, v20.4s, v13.4s + sub v21.4s, v21.4s, v13.4s + sub v24.4s, v24.4s, v14.4s + sub v25.4s, v25.4s, v14.4s + sub v28.4s, v28.4s, v15.4s + sub v29.4s, v29.4s, v15.4s + + b PerTensorHalf + + PerChannelSumHalf: + ld1 {v8.4s, v9.4s}, [x28] + add x28, x28, #64 + mul v0.4s, v8.4s, v12.4s + mul v1.4s, v9.4s, v12.4s + mul v4.4s, v8.4s, v13.4s + mul v5.4s, v9.4s, v13.4s + sub v16.4s, v16.4s, v0.4s + sub v17.4s, v17.4s, v1.4s + sub v20.4s, v20.4s, v4.4s + sub v21.4s, v21.4s, v5.4s + mul v2.4s, v8.4s, v14.4s + mul v3.4s, v9.4s, v14.4s + mul v6.4s, v8.4s, v15.4s + mul v7.4s, v9.4s, v15.4s + sub v24.4s, v24.4s, v2.4s + sub v25.4s, v25.4s, v3.4s + sub v28.4s, v28.4s, v6.4s + sub v29.4s, v29.4s, v7.4s + + PerTensorHalf: + cbnz x15, PerChannelHalf + ld1r {v0.4s}, [x12] + mov v1.16b, v0.16b + ld1r {v4.4s}, [x11] + mov v5.16b, v4.16b + ld1r {v8.4s}, [x13] + mov v9.16b, v8.16b + + b QuantizationHalf + + PerChannelHalf: + ld1 {v0.4s, v1.4s}, [x12] + add x12, x12, #64 + ld1 {v4.4s, v5.4s}, [x11] + add x11, x11, #64 + ld1 {v8.4s, v9.4s}, [x13] + add x13, x13, #64 + + QuantizationHalf: + sqshl v16.4s, v16.4s, v0.4s + sqshl v17.4s, v17.4s, v1.4s + sqshl v20.4s, v20.4s, v0.4s + sqshl v21.4s, v21.4s, v1.4s + sqshl v24.4s, v24.4s, v0.4s + sqshl v25.4s, v25.4s, v1.4s + sqshl v28.4s, v28.4s, v0.4s + sqshl v29.4s, v29.4s, v1.4s + + sqrdmulh v16.4s, v16.4s, v4.4s + sqrdmulh v17.4s, v17.4s, v5.4s + sqrdmulh v20.4s, v20.4s, v4.4s + sqrdmulh v21.4s, v21.4s, v5.4s + sqrdmulh v24.4s, v24.4s, v4.4s + sqrdmulh v25.4s, v25.4s, v5.4s + sqrdmulh v28.4s, v28.4s, v4.4s + sqrdmulh v29.4s, v29.4s, v5.4s + + and v0.16b, v8.16b, v16.16b + sshr v0.4s, v0.4s, #31 + sqadd v16.4s, v16.4s, v0.4s + srshl v16.4s, v16.4s, v8.4s + and v1.16b, v9.16b, v17.16b + sshr v1.4s, v1.4s, #31 + sqadd v17.4s, v17.4s, v1.4s + srshl v17.4s, v17.4s, v9.4s + + and v0.16b, v8.16b, v20.16b + sshr v0.4s, v0.4s, #31 + sqadd v20.4s, v20.4s, v0.4s + srshl v20.4s, v20.4s, v8.4s + and v1.16b, v9.16b, v21.16b + sshr v1.4s, v1.4s, #31 + sqadd v21.4s, v21.4s, v1.4s + srshl 
v21.4s, v21.4s, v9.4s + + and v0.16b, v8.16b, v24.16b + sshr v0.4s, v0.4s, #31 + sqadd v24.4s, v24.4s, v0.4s + srshl v24.4s, v24.4s, v8.4s + and v1.16b, v9.16b, v25.16b + sshr v1.4s, v1.4s, #31 + sqadd v25.4s, v25.4s, v1.4s + srshl v25.4s, v25.4s, v9.4s + + and v0.16b, v8.16b, v28.16b + sshr v0.4s, v0.4s, #31 + sqadd v28.4s, v28.4s, v0.4s + srshl v28.4s, v28.4s, v8.4s + and v1.16b, v9.16b, v29.16b + sshr v1.4s, v1.4s, #31 + sqadd v29.4s, v29.4s, v1.4s + srshl v29.4s, v29.4s, v9.4s + + // zp + dup v6.4s, w10 + add v16.4s, v16.4s, v6.4s + add v17.4s, v17.4s, v6.4s + add v20.4s, v20.4s, v6.4s + add v21.4s, v21.4s, v6.4s + add v24.4s, v24.4s, v6.4s + add v25.4s, v25.4s, v6.4s + add v28.4s, v28.4s, v6.4s + add v29.4s, v29.4s, v6.4s + + // min + dup v0.4s, w8 + smax v16.4s, v16.4s, v0.4s + smax v17.4s, v17.4s, v0.4s + smax v20.4s, v20.4s, v0.4s + smax v21.4s, v21.4s, v0.4s + smax v24.4s, v24.4s, v0.4s + smax v25.4s, v25.4s, v0.4s + smax v28.4s, v28.4s, v0.4s + smax v29.4s, v29.4s, v0.4s + + // max + dup v1.4s, w9 + smin v16.4s, v16.4s, v1.4s + smin v17.4s, v17.4s, v1.4s + smin v20.4s, v20.4s, v1.4s + smin v21.4s, v21.4s, v1.4s + smin v24.4s, v24.4s, v1.4s + smin v25.4s, v25.4s, v1.4s + smin v28.4s, v28.4s, v1.4s + smin v29.4s, v29.4s, v1.4s + + sqxtn v16.4h, v16.4s + sqxtn2 v16.8h, v17.4s + sqxtn v0.8b, v16.8h + + sqxtn v20.4h, v20.4s + sqxtn2 v20.8h, v21.4s + sqxtn v1.8b, v20.8h + + sqxtn v24.4h, v24.4s + sqxtn2 v24.8h, v25.4s + sqxtn v2.8b, v24.8h + + sqxtn v28.4h, v28.4s + sqxtn2 v28.8h, v29.4s + sqxtn v3.8b, v28.8h + + b WriteStart + + LoopDepthQuarter: + ld1 {v0.16b}, [x20], #16 + ld1 {v1.16b}, [x16] + add x16, x16, #64 + sdot v16.4s, v1.16b, v0.4b[0] + sdot v20.4s, v1.16b, v0.4b[1] + sdot v24.4s, v1.16b, v0.4b[2] + sdot v28.4s, v1.16b, v0.4b[3] + + subs x21, x21, #4 + bgt LoopDepthQuarter + + BiasQuarter: + cbz x7, NoReadBiasQuarter + ld1 {v0.4s}, [x18] + add x18, x18, #64 + add v16.4s, v16.4s, v0.4s + add v20.4s, v20.4s, v0.4s + add v24.4s, v24.4s, v0.4s + add v28.4s, v28.4s, v0.4s + + NoReadBiasQuarter: + ld1r {v12.4s}, [x25], #4 + ld1r {v13.4s}, [x25], #4 + ld1r {v14.4s}, [x25], #4 + ld1r {v15.4s}, [x25], #4 + cbnz x15, PerChannelSumQuarter + + PerTensorSumQuarter: + sub v16.4s, v16.4s, v12.4s + sub v20.4s, v20.4s, v13.4s + sub v24.4s, v24.4s, v14.4s + sub v28.4s, v28.4s, v15.4s + + b PerTensorQuarter + + PerChannelSumQuarter: + ld1 {v8.4s}, [x28] + add x28, x28, #64 + mul v0.4s, v8.4s, v12.4s + mul v4.4s, v8.4s, v13.4s + sub v16.4s, v16.4s, v0.4s + sub v20.4s, v20.4s, v4.4s + mul v2.4s, v8.4s, v14.4s + mul v6.4s, v8.4s, v15.4s + sub v24.4s, v24.4s, v2.4s + sub v28.4s, v28.4s, v6.4s + + PerTensorQuarter: + cbnz x15, PerChannelQuarter + ld1r {v0.4s}, [x12] + ld1r {v4.4s}, [x11] + ld1r {v8.4s}, [x13] + + b QuantizationHalf + + PerChannelQuarter: + ld1 {v0.4s}, [x12] + add x12, x12, #64 + ld1 {v4.4s}, [x11] + add x11, x11, #64 + ld1 {v8.4s}, [x13] + add x13, x13, #64 + + QuantizationQuarter: + sqshl v16.4s, v16.4s, v0.4s + sqshl v20.4s, v20.4s, v0.4s + sqshl v24.4s, v24.4s, v0.4s + sqshl v28.4s, v28.4s, v0.4s + + sqrdmulh v16.4s, v16.4s, v4.4s + sqrdmulh v20.4s, v20.4s, v4.4s + sqrdmulh v24.4s, v24.4s, v4.4s + sqrdmulh v28.4s, v28.4s, v4.4s + + and v0.16b, v8.16b, v16.16b + sshr v0.4s, v0.4s, #31 + sqadd v16.4s, v16.4s, v0.4s + srshl v16.4s, v16.4s, v8.4s + + and v0.16b, v8.16b, v20.16b + sshr v0.4s, v0.4s, #31 + sqadd v20.4s, v20.4s, v0.4s + srshl v20.4s, v20.4s, v8.4s + + and v0.16b, v8.16b, v24.16b + sshr v0.4s, v0.4s, #31 + sqadd v24.4s, v24.4s, v0.4s + srshl v24.4s, v24.4s, v8.4s + + 
and v0.16b, v8.16b, v28.16b + sshr v0.4s, v0.4s, #31 + sqadd v28.4s, v28.4s, v0.4s + srshl v28.4s, v28.4s, v8.4s + + // zp + dup v6.4s, w10 + add v16.4s, v16.4s, v6.4s + add v20.4s, v20.4s, v6.4s + add v24.4s, v24.4s, v6.4s + add v28.4s, v28.4s, v6.4s + + // min + dup v0.4s, w8 + smax v16.4s, v16.4s, v0.4s + smax v20.4s, v20.4s, v0.4s + smax v24.4s, v24.4s, v0.4s + smax v28.4s, v28.4s, v0.4s + + // max + dup v1.4s, w9 + smin v16.4s, v16.4s, v1.4s + smin v20.4s, v20.4s, v1.4s + smin v24.4s, v24.4s, v1.4s + smin v28.4s, v28.4s, v1.4s + + sqxtn v16.4h, v16.4s + sqxtn v0.8b, v16.8h + + sqxtn v20.4h, v20.4s + sqxtn v1.8b, v20.8h + + sqxtn v24.4h, v24.4s + sqxtn v2.8b, v24.8h + + sqxtn v28.4h, v28.4s + sqxtn v3.8b, v28.8h + + b WriteStart + + WriteStart: + cmp x17, #1 + beq Write1 + cmp x17, #2 + beq Write2 + cmp x17, #3 + beq Write3 + cmp x17, #4 + beq Write4 + cmp x17, #5 + beq Write5 + cmp x17, #6 + beq Write6 + cmp x17, #7 + beq Write7 + cmp x17, #8 + beq Write8 + cmp x17, #9 + beq Write9 + cmp x17, #10 + beq Write10 + cmp x17, #11 + beq Write11 + cmp x17, #12 + beq Write12 + cmp x17, #13 + beq Write13 + cmp x17, #14 + beq Write14 + cmp x17, #15 + beq Write15 + b Write16 + + Write1: + add x27, x27, #1 + st1 {v0.b}[0], [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.b}[0], [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.b}[0], [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.b}[0], [x19], x14 + b WriteEnd + Write2: + add x27, x27, #2 + st1 {v0.h}[0], [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.h}[0], [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.h}[0], [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.h}[0], [x19], x14 + b WriteEnd + Write3: + add x27, x27, #3 + add x22, x19, #2 + st1 {v0.h}[0], [x19], x14 + st1 {v0.b}[2], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.h}[0], [x19], x14 + st1 {v1.b}[2], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.h}[0], [x19], x14 + st1 {v2.b}[2], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.h}[0], [x19], x14 + st1 {v3.b}[2], [x22], x14 + b WriteEnd + Write4: + add x27, x27, #4 + st1 {v0.s}[0], [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.s}[0], [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.s}[0], [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.s}[0], [x19], x14 + b WriteEnd + Write5: + add x27, x27, #5 + add x22, x19, #4 + st1 {v0.s}[0], [x19], x14 + st1 {v0.b}[4], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.s}[0], [x19], x14 + st1 {v1.b}[4], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.s}[0], [x19], x14 + st1 {v2.b}[4], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.s}[0], [x19], x14 + st1 {v3.b}[4], [x22], x14 + b WriteEnd + Write6: + add x27, x27, #6 + add x22, x19, #4 + st1 {v0.s}[0], [x19], x14 + st1 {v0.h}[2], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.s}[0], [x19], x14 + st1 {v1.h}[2], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.s}[0], [x19], x14 + st1 {v2.h}[2], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.s}[0], [x19], x14 + st1 {v3.h}[2], [x22], x14 + b WriteEnd + Write7: + add x27, x27, #7 + add x22, x19, #4 + add x26, x19, #6 + st1 {v0.s}[0], [x19], x14 + st1 {v0.h}[2], [x22], x14 + st1 {v0.b}[6], [x26], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.s}[0], [x19], x14 + st1 {v1.h}[2], [x22], x14 + st1 {v1.b}[6], [x26], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.s}[0], [x19], x14 + st1 {v2.h}[2], [x22], x14 + st1 {v2.b}[6], [x26], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.s}[0], [x19], x14 + st1 {v3.h}[2], [x22], x14 + st1 {v3.b}[6], [x26], x14 + b WriteEnd + Write8: + add x27, x27, #8 + st1 {v0.8b}, 
[x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + b WriteEnd + Write9: + add x27, x27, #9 + add x22, x19, #8 + st1 {v0.8b}, [x19], x14 + st1 {v0.b}[8], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.b}[8], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.b}[8], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.b}[8], [x22], x14 + b WriteEnd + Write10: + add x27, x27, #10 + add x22, x19, #8 + st1 {v0.8b}, [x19], x14 + st1 {v0.h}[4], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.h}[4], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.h}[4], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.h}[4], [x22], x14 + b WriteEnd + Write11: + add x27, x27, #11 + add x22, x19, #8 + add x26, x19, #10 + st1 {v0.8b}, [x19], x14 + st1 {v0.h}[4], [x22], x14 + st1 {v0.b}[10], [x26], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.h}[4], [x22], x14 + st1 {v1.b}[10], [x26], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.h}[4], [x22], x14 + st1 {v2.b}[10], [x26], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.h}[4], [x22], x14 + st1 {v3.b}[10], [x26], x14 + b WriteEnd + Write12: + add x27, x27, #12 + add x22, x19, #8 + st1 {v0.8b}, [x19], x14 + st1 {v0.s}[2], [x22], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.s}[2], [x22], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.s}[2], [x22], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.s}[2], [x22], x14 + b WriteEnd + Write13: + add x27, x27, #13 + add x22, x19, #8 + add x26, x19, #12 + st1 {v0.8b}, [x19], x14 + st1 {v0.s}[2], [x22], x14 + st1 {v0.b}[12], [x26], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.s}[2], [x22], x14 + st1 {v1.b}[12], [x26], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.s}[2], [x22], x14 + st1 {v2.b}[12], [x26], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.s}[2], [x22], x14 + st1 {v3.b}[12], [x26], x14 + b WriteEnd + Write14: + add x27, x27, #14 + add x22, x19, #8 + add x26, x19, #12 + st1 {v0.8b}, [x19], x14 + st1 {v0.s}[2], [x22], x14 + st1 {v0.h}[6], [x26], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.s}[2], [x22], x14 + st1 {v1.h}[6], [x26], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.s}[2], [x22], x14 + st1 {v2.h}[6], [x26], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.s}[2], [x22], x14 + st1 {v3.h}[6], [x26], x14 + b WriteEnd + Write15: + add x27, x27, #15 + add x22, x19, #8 + add x26, x19, #12 + add x21, x19, #14 + st1 {v0.8b}, [x19], x14 + st1 {v0.s}[2], [x22], x14 + st1 {v0.h}[6], [x26], x14 + st1 {v0.b}[14], [x21], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.8b}, [x19], x14 + st1 {v1.s}[2], [x22], x14 + st1 {v1.h}[6], [x26], x14 + st1 {v1.b}[14], [x21], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.8b}, [x19], x14 + st1 {v2.s}[2], [x22], x14 + st1 {v2.h}[6], [x26], x14 + st1 {v2.b}[14], [x21], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.8b}, [x19], x14 + st1 {v3.s}[2], [x22], x14 + st1 {v3.h}[6], [x26], x14 + st1 {v3.b}[14], [x21], x14 + b WriteEnd + Write16: + add x27, x27, #16 + st1 {v0.16b}, [x19], x14 + cmp x3, #1 + beq WriteEnd + st1 {v1.16b}, [x19], x14 + cmp x3, #2 + beq WriteEnd + st1 {v2.16b}, 
[x19], x14 + cmp x3, #3 + beq WriteEnd + st1 {v3.16b}, [x19], x14 + + WriteEnd: + subs x17, x17, #16 + ble LoopColEnd + mov x25, x6 + b LoopCol + +LoopColEnd: + subs x3, x3, #4 + ble LoopRowEnd + ldr x11, [sp, #24] + ldr x12, [sp, #32] + ldr x13, [sp, #40] + add x6, x6, #16 + add x0, x0, x23 + add x2, x2, x24 + b LoopRow + +LoopRowEnd: + sub sp, sp, #208 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ret +#endif diff --git a/mindspore/lite/nnacl/int8/conv_int8.c b/mindspore/lite/nnacl/int8/conv_int8.c index 2e99904c89..742019ce0b 100644 --- a/mindspore/lite/nnacl/int8/conv_int8.c +++ b/mindspore/lite/nnacl/int8/conv_int8.c @@ -378,11 +378,11 @@ void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t * "b 16f \n" "10: \n" - "ld1 {v16.h}[0], [x10] \n" + "ld1 {v16.d}[0], [x10] \n" "b 16f \n" "11: \n" - "ld1 {v16.h}[0], [x10] \n" + "ld1 {v16.d}[0], [x10] \n" "add x10, x10, #8 \n" "ld1 {v16.s}[2], [x10] \n" "b 16f \n" @@ -802,11 +802,12 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift, - int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func) { + int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int *filter_zp) { int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1; - matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias, - left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, - conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], is_per_oc); + matmul_func(packed_input, packed_weight, dst, row, col, deep4, col, input_sum, bias, left_shift, right_shift, + multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], is_per_oc, + filter_zp); return; } diff --git a/mindspore/lite/nnacl/int8/conv_int8.h b/mindspore/lite/nnacl/int8/conv_int8.h index f116d5fb82..2792122002 100644 --- a/mindspore/lite/nnacl/int8/conv_int8.h +++ b/mindspore/lite/nnacl/int8/conv_int8.h @@ -46,7 +46,7 @@ void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t int32_t *multiplier, ConvParameter *conv_param); void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift, - int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func); + int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int32_t *filter_zp); void Conv1x1Int8Arm32(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, int32_t *multiplier, ConvParameter *conv_param); diff --git a/mindspore/lite/nnacl/int8/matmul_int8.c b/mindspore/lite/nnacl/int8/matmul_int8.c index c0cc5c2cfa..e681abf104 100644 --- a/mindspore/lite/nnacl/int8/matmul_int8.c +++ 
b/mindspore/lite/nnacl/int8/matmul_int8.c @@ -64,6 +64,21 @@ void MatrixEmptyInt8(int8_t *dst, int row, int col) { return; } +void RowMajor2Row4x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { + int col4 = UP_ROUND(col, C4NUM); + for (int r = 0; r < row; r++) { + int rd16 = r / C16NUM; + int rm16 = r % C16NUM; + for (int c = 0; c < col; c++) { + int cd4 = c / C4NUM; + int cm4 = c % C4NUM; + int dst_index = rd16 * col4 * C16NUM + cd4 * C16NUM * C4NUM + rm16 * C4NUM + cm4; + int src_index = r * col + c; + dst_ptr[dst_index] = src_ptr[src_index]; + } + } +} + void RowMajor2Row16x4MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) { /* Row-major to row16x4-major (block row-major) */ int col16 = UP_ROUND(col, C16NUM); @@ -268,6 +283,223 @@ void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, return; } +void MatMulInt8_4x16_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, + size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, + int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi, + size_t per_channel, int32_t *filter_zp) { + /* row4x4-major * row4x16-major => (int8)row-major */ + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM; + int c16div = c / C16NUM, c16mod = c % C16NUM; + size_t ci = r * col + c; + int32_t value = 0; + for (int d = 0; d < deep_4; d++) { + int d4div = d / C4NUM, d4mod = d % C4NUM; + size_t ai = r4div * deep_4 * C4NUM + d4div * C4NUM * C4NUM + r4mod * C4NUM + d4mod; + size_t bi = c16div * deep_4 * C16NUM + d4div * C16NUM * C4NUM + c16mod * C4NUM + d4mod; + value = value + a[ai] * b[bi]; + } + int32_t cur_input_sum = per_channel ? input_sum[r] * filter_zp[c] : input_sum[r]; + value -= cur_input_sum; + value += bias[c]; + int32_t cur_left_shift = per_channel ? left_shift[c] : left_shift[0]; + int32_t cur_right_shift = per_channel ? right_shift[c] : right_shift[0]; + int32_t cur_multiplier = per_channel ? 
multiplier[c] : multiplier[0]; + value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + output_zp; + value = MSMIN(maxi, value); + value = MSMAX(mini, value); + dst[ci] = (int8_t)value; + } + } + return; +} + +void PackInput4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, + size_t input_channel, size_t plane_size, int32_t filter_zp) { + int ic4 = UP_ROUND(input_channel, C4NUM); + int hw4 = UP_ROUND(plane_size, C4NUM); + size_t hw_4div = plane_size / C4NUM * C4NUM; + size_t ic_4div = input_channel / C4NUM * C4NUM; + + const int8_t *src_r = src_input; + int8_t *pack_r = packed_input; + /* per layer */ + for (int hwi = 0; hwi < hw_4div; hwi += C4NUM) { + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + int32_t *input_sum_r = input_sum + hwi; +#ifdef ENABLE_ARM64 + size_t src_stride = input_channel; + size_t ic_4res = input_channel - ic_4div; + asm volatile( + "dup v2.4s, wzr \n" + "mov x14, %[input_sum_r] \n" + "dup v3.4s, %w[filter_zp] \n" + + "mov x10, %[src_ic] \n" + "mov x11, %[pack_ic] \n" + + "mov x15, #0 \n" + "1: \n" + "cmp x15, %[ic_4div] \n" + "add x15, x15, #4\n" + "mov x12, x10 \n" + "add x10, x10, #4\n" + "blt 2f \n" + "cmp %[ic_4res], #0\n" + "beq 6f \n" + "cmp %[ic_4res], #1\n" + "beq 3f \n" + "cmp %[ic_4res], #2\n" + "beq 4f \n" + "cmp %[ic_4res], #3\n" + "beq 5f \n" + + "2: \n" + "ld1 {v0.s}[0], [x12], %[src_stride]\n" + "ld1 {v0.s}[1], [x12], %[src_stride]\n" + "ld1 {v0.s}[2], [x12], %[src_stride]\n" + "ld1 {v0.s}[3], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + + "saddlp v1.8h, v0.16b \n" + "saddlp v0.4s, v1.8h \n" + "add v2.4s, v2.4s, v0.4s \n" + "b 1b \n" + + "3: \n" /* ic res 1 */ + "dup v0.4s, wzr \n" + + "ld1 {v0.b}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[8], [x12], %[src_stride]\n" + "ld1 {v0.b}[12], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "saddlp v1.8h, v0.16b \n" + "saddlp v0.4s, v1.8h \n" + "add v2.4s, v2.4s, v0.4s \n" + "b 6f \n" + + "4: \n" /* ic res 2 */ + "dup v0.4s, wzr \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "saddlp v1.8h, v0.16b \n" + "saddlp v0.4s, v1.8h \n" + "add v2.4s, v2.4s, v0.4s \n" + "b 6f \n" + + "5: \n" /* ic res 3 */ + "dup v0.4s, wzr \n" + "add x13, x12, #2 \n" + + "ld1 {v0.h}[0], [x12], %[src_stride]\n" + "ld1 {v0.b}[2], [x13], %[src_stride]\n" + "ld1 {v0.h}[2], [x12], %[src_stride]\n" + "ld1 {v0.b}[6], [x13], %[src_stride]\n" + "ld1 {v0.h}[4], [x12], %[src_stride]\n" + "ld1 {v0.b}[10], [x13], %[src_stride]\n" + "ld1 {v0.h}[6], [x12], %[src_stride]\n" + "ld1 {v0.b}[14], [x13], %[src_stride]\n" + + "st1 {v0.16b}, [x11], #16\n" + "saddlp v1.8h, v0.16b \n" + "saddlp v0.4s, v1.8h \n" + "add v2.4s, v2.4s, v0.4s \n" + "b 6f \n" + + "6: \n" + "mul v2.4s, v2.4s, v3.4s \n" + + "st1 {v2.4s}, [x14], #16 \n" + + : + : [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r), + [ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp) + : "x10", "x11", "x12", "x13", "x14", "x15", "v0", "v1", "v2", "v3"); +#else + int32_t tmp_sum_value[4] = {0}; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + for (int i = 0; i < C4NUM; i++) { + tmp_sum_value[i] += src_ic[0 + i * input_channel]; + tmp_sum_value[i] += src_ic[1 + i * input_channel]; + 
tmp_sum_value[i] += src_ic[2 + i * input_channel]; + tmp_sum_value[i] += src_ic[3 + i * input_channel]; + pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel]; + pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel]; + pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel]; + pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel]; + } + src_ic += C4NUM; + pack_ic += C4NUM * C4NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + for (int i = 0; i < C4NUM; i++) { + tmp_sum_value[i] += src_ic[i * input_channel]; + pack_ic[i * C4NUM] = src_ic[i * input_channel]; + } + src_ic += 1; + pack_ic += 1; + } + + for (int ici = input_channel; ici < ic4; ici += 1) { + for (int i = 0; i < C4NUM; i++) { + pack_ic[i * C4NUM] = 0; + } + pack_ic += 1; + } + + for (int i = 0; i < C4NUM; i++) { + input_sum_r[i] = tmp_sum_value[i] * filter_zp; + } +#endif + src_r += input_channel * C4NUM; + pack_r += ic4 * C4NUM; + } + + if (hw_4div != plane_size) { + memset(pack_r, 0, C4NUM * ic4); + for (int hwi = hw_4div; hwi < plane_size; hwi += 1) { + int32_t tmp_sum_value = 0; + const int8_t *src_ic = src_r; + int8_t *pack_ic = pack_r; + for (int ici = 0; ici < ic_4div; ici += C4NUM) { + tmp_sum_value += src_ic[0]; + tmp_sum_value += src_ic[1]; + tmp_sum_value += src_ic[2]; + tmp_sum_value += src_ic[3]; + pack_ic[0] = src_ic[0]; + pack_ic[1] = src_ic[1]; + pack_ic[2] = src_ic[2]; + pack_ic[3] = src_ic[3]; + src_ic += C4NUM; + pack_ic += C4NUM * C4NUM; + } + for (int ici = ic_4div; ici < input_channel; ici += 1) { + tmp_sum_value += src_ic[0]; + pack_ic[0] = src_ic[0]; + src_ic += 1; + pack_ic += 1; + } + input_sum[hwi] = tmp_sum_value * filter_zp; + src_r += input_channel; + pack_r += C4NUM; + } + for (int hwi = plane_size; hwi < hw4; hwi++) { + input_sum[hwi] = 0; + } + } + return; +} + void RowMajor2Col16x4MajorInt8(int8_t *src, int row, int col, int8_t *dst) { int row_16 = UP_ROUND(row, C16NUM); int stride = sizeof(int8_t) * 16 * 4; diff --git a/mindspore/lite/nnacl/int8/matmul_int8.h b/mindspore/lite/nnacl/int8/matmul_int8.h index 4df4e6be5b..2be58e4370 100644 --- a/mindspore/lite/nnacl/int8/matmul_int8.h +++ b/mindspore/lite/nnacl/int8/matmul_int8.h @@ -52,6 +52,15 @@ void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi, bool peroc); +/* 4x4 4x16 -> 4x16 */ +void RowMajor2Row4x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col); +void PackInput4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, + size_t input_channel, size_t plane_size, int32_t filter_zp); +void MatMulInt8_4x16_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, + size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, + int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi, + size_t per_channel, int32_t *filter_zp); + #ifdef ENABLE_ARM64 void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift, diff --git a/mindspore/lite/nnacl/matmul_parameter.h b/mindspore/lite/nnacl/matmul_parameter.h index ae4e5dbc31..51aadae35e 100644 --- a/mindspore/lite/nnacl/matmul_parameter.h +++ b/mindspore/lite/nnacl/matmul_parameter.h @@ -27,6 +27,11 @@ typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, 
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi, size_t per_channel); +typedef void (*MATMUL_OPT_DP_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, + size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, + int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, + int32_t maxi, size_t per_channel, int *filter_zp); + typedef enum OutType { OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2 } OutType; typedef struct MatMulParameter { @@ -40,6 +45,7 @@ typedef struct MatMulParameter { int col_2_; int col_4_; int col_8_; + int col_16_; int deep_; int deep_4_; int deep_16_; diff --git a/mindspore/lite/nnacl/pack.h b/mindspore/lite/nnacl/pack.h index 5631659f03..a3bdfab486 100644 --- a/mindspore/lite/nnacl/pack.h +++ b/mindspore/lite/nnacl/pack.h @@ -38,6 +38,12 @@ void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_ void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16); +void PackInputSum16x4PerChannelArm32(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr, + size_t plane_size, size_t input_channel, size_t output_channel); + +void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr, + size_t plane_size, size_t input_channel, size_t output_channel); + void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size); void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParameter *conv_param); diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc index 3b3aa30cb9..b49c92e5e6 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc @@ -26,16 +26,6 @@ using mindspore::lite::RET_MEMORY_FAILED; using mindspore::lite::RET_OK; namespace mindspore::kernel { -int Convolution1x1Int8Pre(void *cdata, int task_id) { - auto conv = reinterpret_cast(cdata); - auto error_code = conv->RunPre(task_id); - if (error_code != RET_OK) { - MS_LOG(ERROR) << "conv1x1 Int8 RunPre error task_id[" << task_id << "] error_code[" << error_code << "]"; - return RET_ERROR; - } - return RET_OK; -} - Convolution1x1Int8CPUKernel::~Convolution1x1Int8CPUKernel() { if (matmul_param_ != nullptr) { delete matmul_param_; @@ -73,13 +63,42 @@ void Convolution1x1Int8CPUKernel::FreeResizeBuf() { return; } +int Convolution1x1Int8CPUKernel::InitRunBuf() { + input_sum_ = reinterpret_cast(ctx_->allocator->Malloc(input_sum_size_ * sizeof(int32_t))); + if (input_sum_ == nullptr) { + MS_LOG(ERROR) << "malloc input_sum_ failed."; + return RET_ERROR; + } + + size_t size = support_optimize_ ? 
UP_ROUND(matmul_param_->row_, C8NUM) * UP_ROUND(matmul_param_->deep_, C4NUM) + : UP_ROUND(matmul_param_->row_, C4NUM) * UP_ROUND(matmul_param_->deep_, C16NUM); + packed_input_ = reinterpret_cast(ctx_->allocator->Malloc(size * sizeof(int8_t))); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "conv1x1 int8 Malloc packed_input_ error!"; + return RET_ERROR; + } + return RET_OK; +} + +void Convolution1x1Int8CPUKernel::FreeRunBuf() { + if (packed_input_ != nullptr) { + ctx_->allocator->Free(packed_input_); + packed_input_ = nullptr; + } + if (input_sum_ != nullptr) { + ctx_->allocator->Free(input_sum_); + input_sum_ = nullptr; + } + return; +} + void Convolution1x1Int8CPUKernel::CheckSupportOptimize() { support_optimize_ = false; - matmul_func_ = MatMulInt8_8x8_r; + matmul_func_ = MatMulInt8_4x16_r; #ifdef ENABLE_ARM64 if (mindspore::lite::IsSupportSDot()) { support_optimize_ = true; - matmul_func_ = MatMulRInt8_optimize_handler; + matmul_func_ = MatMulDpInt8_optimize_handler; } else { support_optimize_ = false; matmul_func_ = nullptr; @@ -104,7 +123,8 @@ int Convolution1x1Int8CPUKernel::InitBiasByzp(void *src_weight, int input_channe } if (filter_peroc_) { - filter_zp_ptr_ = reinterpret_cast(malloc(output_channel * sizeof(int32_t))); + /* filter zp */ + filter_zp_ptr_ = reinterpret_cast(malloc(round_oc * sizeof(int32_t))); if (filter_zp_ptr_ == nullptr) { return RET_ERROR; } @@ -112,24 +132,33 @@ int Convolution1x1Int8CPUKernel::InitBiasByzp(void *src_weight, int input_channe filter_zp_ptr_[fi] = conv_param_->conv_quant_arg_.filter_quant_args_[fi].zp_; } + /* left shift */ left_shift_ = reinterpret_cast(malloc(round_oc * sizeof(int32_t))); if (left_shift_ == nullptr) { return RET_ERROR; } memset(left_shift_, 0, round_oc * sizeof(int32_t)); memcpy(left_shift_, conv_param_->conv_quant_arg_.left_shift_, output_channel * sizeof(int32_t)); + + /* right shift */ right_shift_ = reinterpret_cast(malloc(round_oc * sizeof(int32_t))); if (right_shift_ == nullptr) { return RET_ERROR; } memset(right_shift_, 0, round_oc * sizeof(int32_t)); memcpy(right_shift_, conv_param_->conv_quant_arg_.right_shift_, output_channel * sizeof(int32_t)); + + /* multiplier */ multiplier_ = reinterpret_cast(malloc(round_oc * sizeof(int32_t))); if (multiplier_ == nullptr) { return RET_ERROR; } memset(multiplier_, 0, round_oc * sizeof(int32_t)); memcpy(multiplier_, conv_param_->conv_quant_arg_.quant_multiplier_, output_channel * sizeof(int32_t)); + } else { + right_shift_ = conv_param_->conv_quant_arg_.right_shift_; + left_shift_ = conv_param_->conv_quant_arg_.left_shift_; + multiplier_ = conv_param_->conv_quant_arg_.quant_multiplier_; } return RET_OK; } @@ -140,7 +169,7 @@ int Convolution1x1Int8CPUKernel::InitWeightBias() { auto output_channel = filter_tensor->Batch(); /* weight */ - size_t size = support_optimize_ ? UP_ROUND(input_channel, C4NUM) * UP_ROUND(output_channel, C8NUM) * sizeof(int8_t) + size_t size = support_optimize_ ? 
UP_ROUND(input_channel, C4NUM) * UP_ROUND(output_channel, C16NUM) * sizeof(int8_t) : UP_ROUND(input_channel, C16NUM) * UP_ROUND(output_channel, C4NUM) * sizeof(int8_t); packed_weight_ = reinterpret_cast(malloc(size)); if (packed_weight_ == nullptr) { @@ -149,16 +178,14 @@ int Convolution1x1Int8CPUKernel::InitWeightBias() { } memset(packed_weight_, 0, size); if (support_optimize_) { - RowMajor2Row8x4MajorInt8(reinterpret_cast(filter_tensor->MutableData()), packed_weight_, output_channel, - input_channel); + RowMajor2Row4x16MajorInt8(reinterpret_cast(filter_tensor->MutableData()), packed_weight_, output_channel, + input_channel); } else { RowMajor2Row16x4MajorInt8(reinterpret_cast(filter_tensor->MutableData()), packed_weight_, output_channel, input_channel); } - int col4 = UP_ROUND(output_channel, C4NUM); - int col8 = UP_ROUND(output_channel, C8NUM); - size = support_optimize_ ? col8 : col4; + size = support_optimize_ ? UP_ROUND(output_channel, C16NUM) : UP_ROUND(output_channel, C4NUM); bias_data_ = malloc(size * sizeof(int32_t)); if (bias_data_ == nullptr) { MS_LOG(ERROR) << "Conv1x1 int8 Malloc bias_ptr_ error!"; @@ -166,10 +193,10 @@ int Convolution1x1Int8CPUKernel::InitWeightBias() { } memset(bias_data_, 0, size * sizeof(int32_t)); if (in_tensors_.size() == 3) { - memcpy(bias_data_, in_tensors_[kBiasIndex]->MutableData(), output_channel * sizeof(int32_t)); + memcpy(bias_data_, in_tensors_[kBiasIndex]->data_c(), output_channel * sizeof(int32_t)); } - InitBiasByzp(filter_tensor->MutableData(), input_channel, output_channel, size); + InitBiasByzp(filter_tensor->data_c(), input_channel, output_channel, size); return RET_OK; } @@ -198,7 +225,7 @@ int Convolution1x1Int8CPUKernel::InitWeightBiasArm32() { } memset(bias_data_, 0, col2 * sizeof(int32_t)); if (in_tensors_.size() == 3) { - memcpy(bias_data_, in_tensors_[kBiasIndex]->MutableData(), output_channel * sizeof(int32_t)); + memcpy(bias_data_, in_tensors_[kBiasIndex]->data_c(), output_channel * sizeof(int32_t)); } InitBiasByzp(filter_tensor->MutableData(), input_channel, output_channel, col2); @@ -248,6 +275,7 @@ int Convolution1x1Int8CPUKernel::InitParam() { matmul_param_->col_2_ = UP_ROUND(matmul_param_->col_, C2NUM); matmul_param_->col_4_ = UP_ROUND(matmul_param_->col_, C4NUM); matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM); + matmul_param_->col_16_ = UP_ROUND(matmul_param_->col_, C16NUM); matmul_param_->row_4_ = UP_ROUND(matmul_param_->row_, C4NUM); matmul_param_->row_8_ = UP_ROUND(matmul_param_->row_, C8NUM); matmul_param_->deep_4_ = UP_ROUND(matmul_param_->deep_, C4NUM); @@ -255,13 +283,14 @@ int Convolution1x1Int8CPUKernel::InitParam() { int row_pack_count = 0; int col_pack_count = 0; + #ifdef ENABLE_ARM32 row_pack_count = C4NUM; col_pack_count = C2NUM; #else if (support_optimize_) { - row_pack_count = C8NUM; - col_pack_count = C8NUM; + row_pack_count = C4NUM; + col_pack_count = C16NUM; } else { row_pack_count = C4NUM; col_pack_count = C4NUM; @@ -269,17 +298,18 @@ int Convolution1x1Int8CPUKernel::InitParam() { #endif /* init input sum size */ - if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) { - input_sum_size_ = UP_ROUND(matmul_param_->col_, col_pack_count) * UP_ROUND(matmul_param_->row_, row_pack_count); - } else { + if (support_optimize_) { input_sum_size_ = UP_ROUND(matmul_param_->row_, row_pack_count); + } else { + if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) { + input_sum_size_ = UP_ROUND(matmul_param_->col_, col_pack_count) * UP_ROUND(matmul_param_->row_, row_pack_count); + } else { + 
input_sum_size_ = UP_ROUND(matmul_param_->row_, row_pack_count); + } } - thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->col_, col_pack_count)); - thread_stride_ = UP_DIV(UP_DIV(matmul_param_->col_, col_pack_count), thread_count_); - - thread_count_hw_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, row_pack_count)); - thread_stride_hw_ = UP_DIV(UP_DIV(matmul_param_->row_, row_pack_count), thread_count_hw_); + thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(matmul_param_->row_, row_pack_count)); + thread_stride_ = UP_DIV(UP_DIV(matmul_param_->row_, row_pack_count), thread_count_); if (pre_trans_input_) { input_ptr_ = reinterpret_cast(malloc(matmul_param_->row_ * matmul_param_->deep_ * sizeof(int8_t))); @@ -306,145 +336,121 @@ int Convolution1x1Int8CPUKernel::ReSize() { } void Convolution1x1Int8CPUKernel::Pre1x1Trans(int8_t *src_input, int8_t *src_output) { + /* deal with pad and stride */ output_ptr_ = src_output; if (pre_trans_input_) { Conv1x1InputPack(src_input, input_ptr_, conv_param_, sizeof(int8_t)); } else { input_ptr_ = src_input; } - - if (support_optimize_) { - ParallelLaunch(this->context_->thread_pool_, Convolution1x1Int8Pre, this, thread_count_hw_); - } else { - RowMajor2Row16x4MajorInt8(input_ptr_, packed_input_, matmul_param_->row_, matmul_param_->deep_); - PackInputSum16x4Int8(packed_input_, input_sum_, filter_zp_ptr_, conv_param_); - } return; } -int Convolution1x1Int8CPUKernel::RunImpl(int task_id) { - int32_t *cur_input_sum = input_sum_; - int32_t *cur_left_shift = conv_param_->conv_quant_arg_.left_shift_; - int32_t *cur_right_shift = conv_param_->conv_quant_arg_.right_shift_; - int32_t *cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_; - -#ifdef ENABLE_ARM32 - int cur_stride = thread_stride_ * C2NUM; - int res_stride = matmul_param_->col_ - task_id * thread_stride_ * C2NUM; - int cur_oc = MSMIN(cur_stride, res_stride); - if (cur_oc <= 0) { +int Convolution1x1Int8CPUKernel::RunArm64(int task_id) { + int cur_stride = thread_stride_ * C4NUM; + int res_stride = matmul_param_->row_ - task_id * thread_stride_ * C4NUM; + int cur_hw = MSMIN(cur_stride, res_stride); + if (cur_hw <= 0) { return RET_OK; } + + int8_t *hw_in = input_ptr_ + task_id * thread_stride_ * C4NUM * conv_param_->input_channel_; + int8_t *hw_out = output_ptr_ + task_id * thread_stride_ * C4NUM * conv_param_->output_channel_; + int8_t *hw_packed_in = packed_input_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_16_; + int32_t *hw_input_sum = filter_peroc_ ? 
input_sum_ + task_id * thread_stride_ * C4NUM * matmul_param_->col_4_ + : input_sum_ + task_id * thread_stride_ * C4NUM; + + RowMajor2Row16x4MajorInt8(hw_in, hw_packed_in, cur_hw, matmul_param_->deep_); + if (filter_peroc_) { - cur_input_sum = input_sum_ + task_id * matmul_param_->row_4_ * thread_stride_ * C2NUM; - cur_left_shift = left_shift_ + task_id * thread_stride_ * C2NUM; - cur_right_shift = right_shift_ + task_id * thread_stride_ * C2NUM; - cur_multiplier = multiplier_ + task_id * thread_stride_ * C2NUM; - } - Conv1x1Int8Arm32(packed_input_, packed_weight_ + task_id * thread_stride_ * C2NUM * matmul_param_->deep_16_, - output_ptr_ + task_id * thread_stride_ * C2NUM, cur_input_sum, - reinterpret_cast(bias_data_) + task_id * thread_stride_ * C2NUM, matmul_param_->row_, - cur_oc, matmul_param_->deep_16_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_); -#else - if (support_optimize_) { - int cur_stride = thread_stride_ * C8NUM; - int res_stride = matmul_param_->col_ - task_id * thread_stride_ * C8NUM; - int cur_oc = MSMIN(cur_stride, res_stride); - if (cur_oc <= 0) { - return RET_OK; - } - if (filter_peroc_) { - cur_input_sum = input_sum_ + task_id * matmul_param_->row_8_ * thread_stride_ * C8NUM; - cur_left_shift = left_shift_ + task_id * thread_stride_ * C8NUM; - cur_right_shift = right_shift_ + task_id * thread_stride_ * C8NUM; - cur_multiplier = multiplier_ + task_id * thread_stride_ * C8NUM; - } - Conv1x1Int8Opt(packed_input_, packed_weight_ + task_id * thread_stride_ * C8NUM * matmul_param_->deep_4_, - output_ptr_ + task_id * thread_stride_ * C8NUM, cur_input_sum, - reinterpret_cast(bias_data_) + task_id * thread_stride_ * C8NUM, matmul_param_->row_, - cur_oc, matmul_param_->deep_4_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_, - matmul_func_); + PackInputSum16x4PerChannel(hw_packed_in, hw_input_sum, filter_zp_ptr_, cur_hw, matmul_param_->deep_, + matmul_param_->col_); } else { - int cur_stride = thread_stride_ * C4NUM; - int res_stride = matmul_param_->col_ - task_id * thread_stride_ * C4NUM; - int cur_oc = MSMIN(cur_stride, res_stride); - if (cur_oc <= 0) { - return RET_OK; - } - if (filter_peroc_) { - cur_input_sum = input_sum_ + task_id * matmul_param_->row_4_ * thread_stride_ * C4NUM; - cur_left_shift = left_shift_ + task_id * thread_stride_ * C4NUM; - cur_right_shift = right_shift_ + task_id * thread_stride_ * C4NUM; - cur_multiplier = multiplier_ + task_id * thread_stride_ * C4NUM; - } - Conv1x1Int8(packed_input_, packed_weight_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_16_, - output_ptr_ + task_id * thread_stride_ * C4NUM, cur_input_sum, - reinterpret_cast(bias_data_) + task_id * thread_stride_ * C4NUM, matmul_param_->row_, cur_oc, - matmul_param_->deep_16_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_); + PackInputSum16x4PerLayer(hw_packed_in, hw_input_sum, conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_, + UP_ROUND(cur_hw, C4NUM), matmul_param_->deep_16_); } -#endif + + Conv1x1Int8(hw_packed_in, packed_weight_, hw_out, hw_input_sum, reinterpret_cast(bias_data_), cur_hw, + matmul_param_->col_, matmul_param_->deep_16_, left_shift_, right_shift_, multiplier_, conv_param_); return RET_OK; } -int Convolution1x1Int8CPUKernel::RunPre(int task_id) { - int cur_stride = thread_stride_hw_ * C8NUM; - int res_stride = matmul_param_->row_ - task_id * thread_stride_hw_ * C8NUM; +int Convolution1x1Int8CPUKernel::RunArm32(int task_id) { + int cur_stride = thread_stride_ * C4NUM; + int res_stride = matmul_param_->row_ 
- task_id * thread_stride_ * C4NUM; int cur_hw = MSMIN(cur_stride, res_stride); if (cur_hw <= 0) { return RET_OK; } + int8_t *hw_in = input_ptr_ + task_id * thread_stride_ * C4NUM * conv_param_->input_channel_; + int8_t *hw_out = output_ptr_ + task_id * thread_stride_ * C4NUM * conv_param_->output_channel_; + int8_t *hw_packed_in = packed_input_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_16_; + int32_t *hw_input_sum = filter_peroc_ ? input_sum_ + task_id * thread_stride_ * C4NUM * matmul_param_->col_2_ + : input_sum_ + task_id * thread_stride_ * C4NUM; + + RowMajor2Row16x4MajorInt8(hw_in, hw_packed_in, cur_hw, matmul_param_->deep_); + if (filter_peroc_) { - Conv1x1PreOptPeroc(input_ptr_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_, - packed_input_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_4_, - input_sum_ + task_id * thread_stride_hw_ * C8NUM * C8NUM, matmul_param_->deep_, - matmul_param_->col_, cur_hw, filter_zp_ptr_, matmul_param_->row_8_ * C8NUM); + PackInputSum16x4PerChannelArm32(hw_packed_in, hw_input_sum, filter_zp_ptr_, cur_hw, conv_param_->input_channel_, + conv_param_->output_channel_); } else { - Conv1x1PreOptPert(input_ptr_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_, - packed_input_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_4_, - input_sum_ + task_id * thread_stride_hw_ * C8NUM, matmul_param_->deep_, cur_hw, conv_param_); + PackInputSum16x4PerLayer(hw_packed_in, hw_input_sum, conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_, + UP_ROUND(cur_hw, C4NUM), matmul_param_->deep_16_); } - return RET_OK; -} + Conv1x1Int8Arm32(hw_packed_in, packed_weight_, hw_out, hw_input_sum, reinterpret_cast(bias_data_), cur_hw, + matmul_param_->col_, matmul_param_->deep_16_, left_shift_, right_shift_, multiplier_, conv_param_); -int Convolution1x1Int8Impl(void *cdata, int task_id) { - auto conv = reinterpret_cast(cdata); - auto error_code = conv->RunImpl(task_id); - if (error_code != RET_OK) { - MS_LOG(ERROR) << "conv1x1 Int8 Run error task_id[" << task_id << "] error_code[" << error_code << "]"; - return RET_ERROR; - } return RET_OK; } -int Convolution1x1Int8CPUKernel::InitRunBuf() { - input_sum_ = reinterpret_cast(ctx_->allocator->Malloc(input_sum_size_ * sizeof(int32_t))); - if (input_sum_ == nullptr) { - MS_LOG(ERROR) << "malloc input_sum_ failed."; - return RET_ERROR; +int Convolution1x1Int8CPUKernel::RunArm64Opt(int task_id) { + int cur_stride = thread_stride_ * C4NUM; + int res_stride = matmul_param_->row_ - task_id * thread_stride_ * C4NUM; + int cur_hw = MSMIN(cur_stride, res_stride); + if (cur_hw <= 0) { + return RET_OK; } + int8_t *hw_in = input_ptr_ + task_id * thread_stride_ * C4NUM * conv_param_->input_channel_; + int8_t *hw_out = output_ptr_ + task_id * thread_stride_ * C4NUM * conv_param_->output_channel_; + int8_t *hw_packed_in = packed_input_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_4_; + int32_t *hw_input_sum = input_sum_ + task_id * thread_stride_ * C4NUM; - size_t size = support_optimize_ ? 
UP_ROUND(matmul_param_->row_, C8NUM) * UP_ROUND(matmul_param_->deep_, C4NUM) - : UP_ROUND(matmul_param_->row_, C4NUM) * UP_ROUND(matmul_param_->deep_, C16NUM); - packed_input_ = reinterpret_cast(ctx_->allocator->Malloc(size * sizeof(int8_t))); - if (packed_input_ == nullptr) { - MS_LOG(ERROR) << "conv1x1 int8 Malloc packed_input_ error!"; - return RET_ERROR; + if (filter_peroc_) { + PackInput4x4AndInputSumPert(hw_in, hw_packed_in, hw_input_sum, matmul_param_->deep_, cur_hw, 1); + } else { + PackInput4x4AndInputSumPert(hw_in, hw_packed_in, hw_input_sum, matmul_param_->deep_, cur_hw, + conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_); } + + Conv1x1Int8Opt(hw_packed_in, packed_weight_, hw_out, hw_input_sum, reinterpret_cast(bias_data_), cur_hw, + matmul_param_->col_, matmul_param_->deep_4_, left_shift_, right_shift_, multiplier_, conv_param_, + matmul_func_, filter_zp_ptr_); + return RET_OK; } -void Convolution1x1Int8CPUKernel::FreeRunBuf() { - if (packed_input_ != nullptr) { - ctx_->allocator->Free(packed_input_); - packed_input_ = nullptr; +int Convolution1x1Int8CPUKernel::DoRun(int task_id) { +#ifdef ENABLE_ARM32 + return RunArm32(task_id); +#else + if (support_optimize_) { + return RunArm64Opt(task_id); + } else { + return RunArm64(task_id); } - if (input_sum_ != nullptr) { - ctx_->allocator->Free(input_sum_); - input_sum_ = nullptr; +#endif +} + +int Convolution1x1Int8Run(void *cdata, int task_id) { + auto conv = reinterpret_cast(cdata); + auto error_code = conv->DoRun(task_id); + if (error_code != RET_OK) { + MS_LOG(ERROR) << "conv1x1 Int8 Run error task_id[" << task_id << "] error_code[" << error_code << "]"; + return RET_ERROR; } - return; + return RET_OK; } int Convolution1x1Int8CPUKernel::Run() { @@ -461,7 +467,7 @@ int Convolution1x1Int8CPUKernel::Run() { for (int batch_index = 0; batch_index < conv_param_->input_batch_; batch_index++) { Pre1x1Trans(src_in + batch_index * conv_param_->input_h_ * conv_param_->input_w_ * conv_param_->input_channel_, src_out + batch_index * matmul_param_->row_ * matmul_param_->col_); - auto ret = ParallelLaunch(this->context_->thread_pool_, Convolution1x1Int8Impl, this, thread_count_); + auto ret = ParallelLaunch(this->context_->thread_pool_, Convolution1x1Int8Run, this, thread_count_); if (ret != RET_OK) { MS_LOG(ERROR) << "ParallelLaunch run error error_code[" << ret << "]"; FreeRunBuf(); diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h index e224cb9803..d2f6d512b9 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h @@ -45,8 +45,12 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel { void FreeRunBuf(); public: - int RunImpl(int task_id); - int RunPre(int task_id); + int DoRun(int task_id); + + private: + int RunArm32(int task_id); + int RunArm64(int task_id); + int RunArm64Opt(int task_id); private: void FreeResizeBuf(); @@ -58,8 +62,8 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel { int InitBiasByzp(void *src_weight, int input_channel, int output_channel, int round_oc); private: - int32_t *input_sum_ = nullptr; /* per-oc: oc4 format */ - int32_t *filter_zp_ptr_ = nullptr; /* per-oc */ + int32_t *input_sum_ = nullptr; /* per-oc */ + int32_t *filter_zp_ptr_ = nullptr; /* per-oc up round */ int32_t *left_shift_ = nullptr; /* per-oc up round */ int32_t *right_shift_ = nullptr; /* per-oc up round */ int32_t 
*multiplier_ = nullptr; /* per-oc up round */ @@ -69,12 +73,10 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel { int8_t *output_ptr_ = nullptr; size_t thread_count_ = 1; size_t thread_stride_ = 0; - size_t thread_count_hw_ = 1; - size_t thread_stride_hw_ = 0; bool pre_trans_input_ = false; size_t input_sum_size_ = 0; MatMulParameter *matmul_param_ = nullptr; - MATMUL_OPT_R_FUNC matmul_func_ = nullptr; + MATMUL_OPT_DP_FUNC matmul_func_ = nullptr; bool support_optimize_ = false; bool filter_peroc_ = false; }; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/opt_op_handler.cc b/mindspore/lite/src/runtime/kernel/arm/int8/opt_op_handler.cc index 46e242bdec..ce8ff62226 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/opt_op_handler.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/opt_op_handler.cc @@ -33,6 +33,9 @@ extern void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, in const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, int *multiplier, int *left_shift, int *right_shift, int row, int col, int stride, size_t peroc); +extern void MatmulInt8DpOpt(const int8_t *a, const int8_t *b, int8_t *dst, size_t row8, size_t col8, size_t deep4, + const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, int *multiplier, + int *left_shift, int *right_shift, size_t stride, size_t peroc, int *filter_zp); #ifdef ENABLE_ARM64 void IndirectGemmInt8_optimize_handler(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias, @@ -57,6 +60,13 @@ void MatMulRInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst, return MatmulInt8DpNeon64(a, b, dst, UP_ROUND(row, 8), UP_ROUND(col, 8), deep_4, input_sum, bias, mini, maxi, output_zp, multiplier, left_shift, right_shift, row, col, stride, per_channel); } +void MatMulDpInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, + size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, + int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, + int32_t maxi, size_t per_channel, int32_t *filter_zp) { + return MatmulInt8DpOpt(a, b, dst, row, col, deep_4, input_sum, bias, mini, maxi, output_zp, multiplier, left_shift, + right_shift, stride, per_channel, filter_zp); +} #endif #ifdef __cplusplus diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/opt_op_handler.h b/mindspore/lite/src/runtime/kernel/arm/int8/opt_op_handler.h index 1e273706e1..022378327f 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/opt_op_handler.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/opt_op_handler.h @@ -33,6 +33,10 @@ void MatMulRInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst, size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi, size_t per_channel); +void MatMulDpInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4, + size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift, + int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, + int32_t maxi, size_t per_channel, int32_t *filter_zp); #endif #ifdef __cplusplus
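
Notes on the new 4x16 path (reference sketches; the C below is illustrative and not part of the diff).

The weight layout produced by RowMajor2Row4x16MajorInt8 stores the (output_channel x input_channel) matrix in blocks of 16 rows by 4 columns, keeping the 4 depth bytes of one output channel adjacent. That is exactly the operand shape each sdot consumes in MatmulDpInt8Opt.S: one ld1 {v1.16b, v2.16b, v3.16b, v4.16b} loads 16 output channels x 4 depth values, which together with v0 (4 rows x 4 depth) yields one 4x16 partial tile per depth step. A minimal, self-contained sketch of the index mapping (the demo sizes and driver are made up for illustration):

#include <stdio.h>
#include <string.h>

#define C4NUM 4
#define C16NUM 16
#define UP_ROUND(x, y) ((((x) + (y)-1) / (y)) * (y))

/* row = output channel, col = depth; mirrors RowMajor2Row4x16MajorInt8 */
static void Row2Row4x16(const signed char *src, signed char *dst, int row, int col) {
  int col4 = UP_ROUND(col, C4NUM);
  for (int r = 0; r < row; r++) {
    for (int c = 0; c < col; c++) {
      int dst_index = (r / C16NUM) * col4 * C16NUM + (c / C4NUM) * C16NUM * C4NUM +
                      (r % C16NUM) * C4NUM + (c % C4NUM);
      dst[dst_index] = src[r * col + c];
    }
  }
}

int main(void) {
  enum { OC = 5, IC = 6 }; /* deliberately unaligned to show the zero padding */
  signed char w[OC * IC], packed[UP_ROUND(OC, C16NUM) * UP_ROUND(IC, C4NUM)];
  for (int i = 0; i < OC * IC; i++) w[i] = (signed char)i;
  memset(packed, 0, sizeof(packed)); /* pad rows/cols stay zero */
  Row2Row4x16(w, packed, OC, IC);
  /* element (r=1, c=5): block (0,1) -> offset 64, in-block 1*4+1 -> index 69 */
  printf("w[1][5]=%d packed[69]=%d\n", w[1 * IC + 5], packed[69]);
  return 0;
}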
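
The per-value requantization in MatmulDpInt8Opt.S is the usual gemmlowp-style fixed-point pipeline: saturating left shift (sqshl), rounded doubling high multiply (sqrdmulh), a rounding right shift in which the and/sshr/sqadd fixup makes ties round away from zero before srshl, then the output zero point (w10) and the act_min/act_max clamp (w8/w9). Below is a scalar model of one lane, written from the assembly; treat the exact saturation and tie-breaking as an assumption rather than a documented contract of nnacl's MultiplyByQuantizedMultiplier, and note that sqshl saturation is elided for brevity:

#include <stdint.h>

/* left_shift >= 0; right_shift <= 0, stored negated as srshl expects */
static int8_t RequantOne(int32_t acc, int32_t left_shift, int32_t multiplier,
                         int32_t right_shift, int32_t out_zp,
                         int32_t act_min, int32_t act_max) {
  int64_t v = (int64_t)acc << left_shift;                   /* sqshl (saturation elided) */
  int64_t prod = v * (int64_t)multiplier;
  int32_t high = (int32_t)((2 * prod + (1LL << 31)) >> 32); /* sqrdmulh */
  int exponent = -right_shift;
  if (exponent > 0) {
    if (high < 0) high -= 1;                                /* and + sshr #31 + sqadd */
    high = (int32_t)(((int64_t)high + (1LL << (exponent - 1))) >> exponent); /* srshl */
  }
  high += out_zp;                                           /* add of dup'd w10 */
  if (high < act_min) high = act_min;                       /* smax */
  if (high > act_max) high = act_max;                       /* smin */
  return (int8_t)high;                                      /* sqxtn/sqxtn2 chain */
}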
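
On the kernel side the pieces compose as in Convolution1x1Int8CPUKernel::RunArm64Opt: each 4-row tile of activations is packed 4x4 by PackInput4x4AndInputSumPert, which in the per-tensor case also emits the row sums already multiplied by the filter zero point, and the matmul takes UP_ROUND(deep, 4) as its depth. A shape-level sketch using the portable fallback MatMulInt8_4x16_r from this patch; the quantization constants are placeholders and allocation error handling is trimmed:

#include <stdint.h>
#include <stdlib.h>
#include "nnacl/int8/matmul_int8.h" /* declarations added by this patch */

/* hw = plane size (rows), ic = depth, oc = output channels (cols); dst holds hw*oc */
void Conv1x1RefSketch(const int8_t *src, const int8_t *weight_oc_ic, int8_t *dst,
                      int hw, int ic, int oc) {
  int hw4 = UP_ROUND(hw, C4NUM), ic4 = UP_ROUND(ic, C4NUM), oc16 = UP_ROUND(oc, C16NUM);
  int8_t *packed_a = (int8_t *)calloc(hw4 * ic4, 1);
  int8_t *packed_b = (int8_t *)calloc(ic4 * oc16, 1);
  int32_t *input_sum = (int32_t *)calloc(hw4, sizeof(int32_t));
  int32_t *bias = (int32_t *)calloc(oc, sizeof(int32_t)); /* zeros: the reference adds bias[c] unconditionally */

  int32_t filter_zp = 0; /* placeholder per-tensor weight zero point */
  RowMajor2Row4x16MajorInt8(weight_oc_ic, packed_b, oc, ic);
  PackInput4x4AndInputSumPert(src, packed_a, input_sum, ic, hw, filter_zp);

  /* per-tensor: one-entry quant arrays, per_channel = 0 (placeholder values) */
  int32_t left_shift = 0, right_shift = 0, multiplier = 1 << 30;
  MatMulInt8_4x16_r(packed_a, packed_b, dst, hw, oc, ic4, /*stride*/ oc,
                    input_sum, bias, &left_shift, &right_shift, &multiplier,
                    /*output_zp*/ 0, /*mini*/ -128, /*maxi*/ 127,
                    /*per_channel*/ 0, &filter_zp);

  free(packed_a); free(packed_b); free(input_sum); free(bias);
}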