From 7175e1921e808b482dd95fdadb3435b376f4bc67 Mon Sep 17 00:00:00 2001 From: yangruoqi713 Date: Tue, 8 Sep 2020 11:10:51 +0800 Subject: [PATCH] [MSLITE][Develop] arm cpu int8 conv depthwise support activation per channel --- .../nnacl/assembly/arm64/ConvDwInt8Center.S | 723 +++++------------- mindspore/lite/nnacl/int8/common_func.h | 7 +- .../lite/nnacl/int8/conv_depthwise_int8.c | 184 ++--- .../lite/nnacl/int8/conv_depthwise_int8.h | 5 +- mindspore/lite/nnacl/pack.c | 58 ++ mindspore/lite/nnacl/pack.h | 7 + .../arm/int8/convolution_depthwise_int8.cc | 13 +- .../convolution_depthwise_slidewindow_int8.cc | 187 ++++- .../convolution_depthwise_slidewindow_int8.h | 12 +- .../arm/int8/deconvolution_depthwise_int8.cc | 4 +- 10 files changed, 543 insertions(+), 657 deletions(-) diff --git a/mindspore/lite/nnacl/assembly/arm64/ConvDwInt8Center.S b/mindspore/lite/nnacl/assembly/arm64/ConvDwInt8Center.S index c2705a32c4..f80318165f 100644 --- a/mindspore/lite/nnacl/assembly/arm64/ConvDwInt8Center.S +++ b/mindspore/lite/nnacl/assembly/arm64/ConvDwInt8Center.S @@ -7,13 +7,15 @@ .type ConvDwInt8Center, %function #endif -// void ConvDwInt8Center(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, size_t height, size_t width, -// size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, -// size_t in_kh_step, size_t in_kw_step, int out_multiplier, int left_shift, -// int right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max); +// void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height, +// size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, +// size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int8_t *in_zp, +// int32_t *out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, +// int32_t *acc_min, int32_t *acc_max) + // x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: weight, x6: kernel_h, x7: kernel_w, // x8: out_h_step, x9: block_channel, x10: in_sh_step, x11: in_sw_step, x12: in_kh_step, x13: in_kw_step -// x14: out_multiplier, #56: left_shift, #64: right_shift, #72:out_zp, #80: acc_min, #88: acc_max +// x14: in_zp, #56: out_zp, #64: out_multiplier, #72:left_shift, #80: right_shift, #88: acc_min, #96: acc_max ConvDwInt8Center: // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers @@ -33,489 +35,174 @@ ConvDwInt8Center: ldr x12, [sp, #32] ldr x13, [sp, #40] - ldr w14, [sp, #56] - dup v26.4s, w14 + ldr x14, [sp, #48] // input_zp + ld1 {v19.8b}, [x14], #8 + + ldr x15, [sp, #56] // output_zp + ld1 {v20.4s}, [x15], #16 + ld1 {v21.4s}, [x15], #16 - ldr x15, [sp, #48] - dup v27.4s, w15 + ldr x16, [sp, #64] // out_multiplier + ld1 {v22.4s}, [x16], #16 + ld1 {v23.4s}, [x16], #16 - ldr w16, [sp, #64] - dup v28.4s, w16 + ldr x17, [sp, #72] // left_shift + ld1 {v24.4s}, [x17], #16 + ld1 {v25.4s}, [x17], #16 - ldr w17, [sp, #72] - dup v29.4s, w17 - - ldr w18, [sp, #80] - dup v30.4s, w18 + ldr x18, [sp, #80] // right shift + ld1 {v26.4s}, [x18], #16 + ld1 {v27.4s}, [x18], #16 - ldr w19, [sp, #88] - dup v31.4s, w19 + ldr x19, [sp, #88] // acc_min + ld1 {v28.4s}, [x19], #16 + ld1 {v29.4s}, [x19], #16 - ld1 {v24.4s}, [x3] + ldr x20, [sp, #96] // acc_max + ld1 {v30.4s}, [x20], #16 + ld1 {v31.4s}, [x20], #16 + + ld1 {v17.4s}, [x3], #16 + ld1 {v18.4s}, [x3], #16 LoopH: mov x23, x1 mov x24, x5 mov x3, x0 - cmp x24, #8 - blt LoopW - cmp x24, #16 - blt LoopW8 - LoopW16: - mov x19, #16 + LoopW4: + mov x19, #4 mul x19, x19, x11 + mov x25, #4 + mul x25, x25, x9 + mov x16, x23 mov x17, x2 mov x20, x6 - mov v0.16b, v24.16b - mov v1.16b, v24.16b - mov v2.16b, v24.16b - mov v3.16b, v24.16b - mov v4.16b, v24.16b - mov v5.16b, v24.16b - mov v6.16b, v24.16b - mov v7.16b, v24.16b - mov v8.16b, v24.16b - mov v9.16b, v24.16b - mov v10.16b, v24.16b - mov v11.16b, v24.16b - mov v12.16b, v24.16b - mov v13.16b, v24.16b - mov v14.16b, v24.16b - mov v15.16b, v24.16b - LoopKh16: + + mov v0.16b, v17.16b + mov v1.16b, v18.16b + mov v2.16b, v17.16b + mov v3.16b, v18.16b + mov v4.16b, v17.16b + mov v5.16b, v18.16b + mov v6.16b, v17.16b + mov v7.16b, v18.16b + LoopKh4: mov x18, x7 mov x21, x16 - LoopKw16: + LoopKw4: mov x22, x21 - ld1 {v25.4h}, [x17], #8 - ld1 {v16.4h}, [x22], x11 - ld1 {v17.4h}, [x22], x11 - smlal v0.4s, v16.4h, v25.4h - smlal v1.4s, v17.4h, v25.4h - ld1 {v18.4h}, [x22], x11 - ld1 {v19.4h}, [x22], x11 - smlal v2.4s, v18.4h, v25.4h - smlal v3.4s, v19.4h, v25.4h - ld1 {v20.4h}, [x22], x11 - ld1 {v21.4h}, [x22], x11 - smlal v4.4s, v20.4h, v25.4h - smlal v5.4s, v21.4h, v25.4h - ld1 {v22.4h}, [x22], x11 - ld1 {v23.4h}, [x22], x11 - smlal v6.4s, v22.4h, v25.4h - smlal v7.4s, v23.4h, v25.4h - ld1 {v16.4h}, [x22], x11 - ld1 {v17.4h}, [x22], x11 - smlal v8.4s, v16.4h, v25.4h - smlal v9.4s, v17.4h, v25.4h - ld1 {v18.4h}, [x22], x11 - ld1 {v19.4h}, [x22], x11 - smlal v10.4s, v18.4h, v25.4h - smlal v11.4s, v19.4h, v25.4h - ld1 {v20.4h}, [x22], x11 - ld1 {v21.4h}, [x22], x11 - smlal v12.4s, v20.4h, v25.4h - smlal v13.4s, v21.4h, v25.4h - ld1 {v22.4h}, [x22], x11 - ld1 {v23.4h}, [x22], x11 - smlal v14.4s, v22.4h, v25.4h - smlal v15.4s, v23.4h, v25.4h - subs x18, x18, #1 - add x21, x21, x13 - bne LoopKw16 - add x16, x16, x12 - subs x20, x20, #1 - bne LoopKh16 - - sqshl v0.4s, v0.4s, v26.4s - sqshl v1.4s, v1.4s, v26.4s - sqshl v2.4s, v2.4s, v26.4s - sqshl v3.4s, v3.4s, v26.4s - sqshl v4.4s, v4.4s, v26.4s - sqshl v5.4s, v5.4s, v26.4s - sqshl v6.4s, v6.4s, v26.4s - sqshl v7.4s, v7.4s, v26.4s - sqshl v8.4s, v8.4s, v26.4s - sqshl v9.4s, v9.4s, v26.4s - sqshl v10.4s, v10.4s, v26.4s - sqshl v11.4s, v11.4s, v26.4s - sqshl v12.4s, v12.4s, v26.4s - sqshl v13.4s, v13.4s, v26.4s - sqshl v14.4s, v14.4s, v26.4s - sqshl v15.4s, v15.4s, v26.4s - sqrdmulh v0.4s, v0.4s, v27.4s - sqrdmulh v1.4s, v1.4s, v27.4s - sqrdmulh v2.4s, v2.4s, v27.4s - sqrdmulh v3.4s, v3.4s, v27.4s - sqrdmulh v4.4s, v4.4s, v27.4s - sqrdmulh v5.4s, v5.4s, v27.4s - sqrdmulh v6.4s, v6.4s, v27.4s - sqrdmulh v7.4s, v7.4s, v27.4s - sqrdmulh v8.4s, v8.4s, v27.4s - sqrdmulh v9.4s, v9.4s, v27.4s - sqrdmulh v10.4s, v10.4s, v27.4s - sqrdmulh v11.4s, v11.4s, v27.4s - sqrdmulh v12.4s, v12.4s, v27.4s - sqrdmulh v13.4s, v13.4s, v27.4s - sqrdmulh v14.4s, v14.4s, v27.4s - sqrdmulh v15.4s, v15.4s, v27.4s - - and v16.16b, v28.16b, v0.16b - sshr v16.4s, v16.4s, #31 - sqadd v0.4s, v0.4s, v16.4s - srshl v0.4s, v0.4s, v28.4s - and v17.16b, v28.16b, v1.16b - sshr v17.4s, v17.4s, #31 - sqadd v1.4s, v1.4s, v17.4s - srshl v1.4s, v1.4s, v28.4s - and v18.16b, v28.16b, v2.16b - sshr v18.4s, v18.4s, #31 - sqadd v2.4s, v2.4s, v18.4s - srshl v2.4s, v2.4s, v28.4s - and v19.16b, v28.16b, v3.16b - sshr v19.4s, v19.4s, #31 - sqadd v3.4s, v3.4s, v19.4s - srshl v3.4s, v3.4s, v28.4s - and v20.16b, v28.16b, v4.16b - sshr v20.4s, v20.4s, #31 - sqadd v4.4s, v4.4s, v20.4s - srshl v4.4s, v4.4s, v28.4s - and v21.16b, v28.16b, v5.16b - sshr v21.4s, v21.4s, #31 - sqadd v5.4s, v5.4s, v21.4s - srshl v5.4s, v5.4s, v28.4s - and v22.16b, v28.16b, v6.16b - sshr v22.4s, v22.4s, #31 - sqadd v6.4s, v6.4s, v22.4s - srshl v6.4s, v6.4s, v28.4s - and v23.16b, v28.16b, v7.16b - sshr v23.4s, v23.4s, #31 - sqadd v7.4s, v7.4s, v23.4s - srshl v7.4s, v7.4s, v28.4s - and v16.16b, v28.16b, v8.16b - sshr v16.4s, v16.4s, #31 - sqadd v8.4s, v8.4s, v16.4s - srshl v8.4s, v8.4s, v28.4s - and v17.16b, v28.16b, v9.16b - sshr v17.4s, v17.4s, #31 - sqadd v9.4s, v9.4s, v17.4s - srshl v9.4s, v9.4s, v28.4s - and v18.16b, v28.16b, v10.16b - sshr v18.4s, v18.4s, #31 - sqadd v10.4s, v10.4s, v18.4s - srshl v10.4s, v10.4s, v28.4s - and v19.16b, v28.16b, v11.16b - sshr v19.4s, v19.4s, #31 - sqadd v11.4s, v11.4s, v19.4s - srshl v11.4s, v11.4s, v28.4s - and v20.16b, v28.16b, v12.16b - sshr v20.4s, v20.4s, #31 - sqadd v12.4s, v12.4s, v20.4s - srshl v12.4s, v12.4s, v28.4s - and v21.16b, v28.16b, v13.16b - sshr v21.4s, v21.4s, #31 - sqadd v13.4s, v13.4s, v21.4s - srshl v13.4s, v13.4s, v28.4s - and v22.16b, v28.16b, v14.16b - sshr v22.4s, v22.4s, #31 - sqadd v14.4s, v14.4s, v22.4s - srshl v14.4s, v14.4s, v28.4s - and v23.16b, v28.16b, v15.16b - sshr v23.4s, v23.4s, #31 - sqadd v15.4s, v15.4s, v23.4s - srshl v15.4s, v15.4s, v28.4s - - add v0.4s, v0.4s, v29.4s - add v1.4s, v1.4s, v29.4s - add v2.4s, v2.4s, v29.4s - add v3.4s, v3.4s, v29.4s - add v4.4s, v4.4s, v29.4s - add v5.4s, v5.4s, v29.4s - add v6.4s, v6.4s, v29.4s - add v7.4s, v7.4s, v29.4s - add v8.4s, v8.4s, v29.4s - add v9.4s, v9.4s, v29.4s - add v10.4s, v10.4s, v29.4s - add v11.4s, v11.4s, v29.4s - add v12.4s, v12.4s, v29.4s - add v13.4s, v13.4s, v29.4s - add v14.4s, v14.4s, v29.4s - add v15.4s, v15.4s, v29.4s - smax v0.4s, v0.4s, v30.4s - smax v1.4s, v1.4s, v30.4s - smax v2.4s, v2.4s, v30.4s - smax v3.4s, v3.4s, v30.4s - smax v4.4s, v4.4s, v30.4s - smax v5.4s, v5.4s, v30.4s - smax v6.4s, v6.4s, v30.4s - smax v7.4s, v7.4s, v30.4s - smax v8.4s, v8.4s, v30.4s - smax v9.4s, v9.4s, v30.4s - smax v10.4s, v10.4s, v30.4s - smax v11.4s, v11.4s, v30.4s - smax v12.4s, v12.4s, v30.4s - smax v13.4s, v13.4s, v30.4s - smax v14.4s, v14.4s, v30.4s - smax v15.4s, v15.4s, v30.4s - smin v0.4s, v0.4s, v31.4s - smin v1.4s, v1.4s, v31.4s - smin v2.4s, v2.4s, v31.4s - smin v3.4s, v3.4s, v31.4s - smin v4.4s, v4.4s, v31.4s - smin v5.4s, v5.4s, v31.4s - smin v6.4s, v6.4s, v31.4s - smin v7.4s, v7.4s, v31.4s - smin v8.4s, v8.4s, v31.4s - smin v9.4s, v9.4s, v31.4s - smin v10.4s, v10.4s, v31.4s - smin v11.4s, v11.4s, v31.4s - smin v12.4s, v12.4s, v31.4s - smin v13.4s, v13.4s, v31.4s - smin v14.4s, v14.4s, v31.4s - smin v15.4s, v15.4s, v31.4s + ld1 {v16.8h}, [x17], #16 - sqxtn v0.4h, v0.4s - sqxtn v1.4h, v1.4s - sqxtn v2.4h, v2.4s - sqxtn v3.4h, v3.4s - sqxtn v4.4h, v4.4s - sqxtn v5.4h, v5.4s - sqxtn v6.4h, v6.4s - sqxtn v7.4h, v7.4s - sqxtn v8.4h, v8.4s - sqxtn v9.4h, v9.4s - sqxtn v10.4h, v10.4s - sqxtn v11.4h, v11.4s - sqxtn v12.4h, v12.4s - sqxtn v13.4h, v13.4s - sqxtn v14.4h, v14.4s - sqxtn v15.4h, v15.4s - sqxtn v0.8b, v0.8h - sqxtn v1.8b, v1.8h - sqxtn v2.8b, v2.8h - sqxtn v3.8b, v3.8h - sqxtn v4.8b, v4.8h - sqxtn v5.8b, v5.8h - sqxtn v6.8b, v6.8h - sqxtn v7.8b, v7.8h - sqxtn v8.8b, v8.8h - sqxtn v9.8b, v9.8h - sqxtn v10.8b, v10.8h - sqxtn v11.8b, v11.8h - sqxtn v12.8b, v12.8h - sqxtn v13.8b, v13.8h - sqxtn v14.8b, v14.8h - sqxtn v15.8b, v15.8h - - add x17, x3, #1 - add x18, x3, #2 - add x21, x3, #3 - st1 {v0.b}[0], [x3], x9 - st1 {v0.b}[1], [x17], x9 - st1 {v0.b}[2], [x18], x9 - st1 {v0.b}[3], [x21], x9 - - st1 {v1.b}[0], [x3], x9 - st1 {v1.b}[1], [x17], x9 - st1 {v1.b}[2], [x18], x9 - st1 {v1.b}[3], [x21], x9 - - st1 {v2.b}[0], [x3], x9 - st1 {v2.b}[1], [x17], x9 - st1 {v2.b}[2], [x18], x9 - st1 {v2.b}[3], [x21], x9 - - st1 {v3.b}[0], [x3], x9 - st1 {v3.b}[1], [x17], x9 - st1 {v3.b}[2], [x18], x9 - st1 {v3.b}[3], [x21], x9 - - st1 {v4.b}[0], [x3], x9 - st1 {v4.b}[1], [x17], x9 - st1 {v4.b}[2], [x18], x9 - st1 {v4.b}[3], [x21], x9 - - st1 {v5.b}[0], [x3], x9 - st1 {v5.b}[1], [x17], x9 - st1 {v5.b}[2], [x18], x9 - st1 {v5.b}[3], [x21], x9 - - st1 {v6.b}[0], [x3], x9 - st1 {v6.b}[1], [x17], x9 - st1 {v6.b}[2], [x18], x9 - st1 {v6.b}[3], [x21], x9 - - st1 {v7.b}[0], [x3], x9 - st1 {v7.b}[1], [x17], x9 - st1 {v7.b}[2], [x18], x9 - st1 {v7.b}[3], [x21], x9 - - st1 {v8.b}[0], [x3], x9 - st1 {v8.b}[1], [x17], x9 - st1 {v8.b}[2], [x18], x9 - st1 {v8.b}[3], [x21], x9 - - st1 {v9.b}[0], [x3], x9 - st1 {v9.b}[1], [x17], x9 - st1 {v9.b}[2], [x18], x9 - st1 {v9.b}[3], [x21], x9 - - st1 {v10.b}[0], [x3], x9 - st1 {v10.b}[1], [x17], x9 - st1 {v10.b}[2], [x18], x9 - st1 {v10.b}[3], [x21], x9 - - st1 {v11.b}[0], [x3], x9 - st1 {v11.b}[1], [x17], x9 - st1 {v11.b}[2], [x18], x9 - st1 {v11.b}[3], [x21], x9 - - st1 {v12.b}[0], [x3], x9 - st1 {v12.b}[1], [x17], x9 - st1 {v12.b}[2], [x18], x9 - st1 {v12.b}[3], [x21], x9 - - st1 {v13.b}[0], [x3], x9 - st1 {v13.b}[1], [x17], x9 - st1 {v13.b}[2], [x18], x9 - st1 {v13.b}[3], [x21], x9 - - st1 {v14.b}[0], [x3], x9 - st1 {v14.b}[1], [x17], x9 - st1 {v14.b}[2], [x18], x9 - st1 {v14.b}[3], [x21], x9 - - st1 {v15.b}[0], [x3], x9 - st1 {v15.b}[1], [x17], x9 - st1 {v15.b}[2], [x18], x9 - st1 {v15.b}[3], [x21], x9 + ld1 {v15.8b}, [x22], x11 + ssubl v14.8h, v15.8b, v19.8b + smlal v0.4s, v14.4h, v16.4h + smlal2 v1.4s, v14.8h, v16.8h + + ld1 {v13.8b}, [x22], x11 + ssubl v12.8h, v13.8b, v19.8b + smlal v2.4s, v12.4h, v16.4h + smlal2 v3.4s, v12.8h, v16.8h + + ld1 {v11.8b}, [x22], x11 + ssubl v10.8h, v11.8b, v19.8b + smlal v4.4s, v10.4h, v16.4h + smlal2 v5.4s, v10.8h, v16.8h + + ld1 {v9.8b}, [x22], x11 + ssubl v8.8h, v9.8b, v19.8b + smlal v6.4s, v8.4h, v16.4h + smlal2 v7.4s, v8.8h, v16.8h - add x23, x23, x19 - sub x24, x24, #16 - cmp x24, #0 - ble LoopWEnd - cmp x24, #8 - blt LoopW - cmp x24, #16 - bge LoopW16 - LoopW8: - mov x19, #8 - mul x19, x19, x11 - mov x16, x23 - mov x17, x2 - mov x20, x6 - mov v0.16b, v24.16b - mov v1.16b, v24.16b - mov v2.16b, v24.16b - mov v3.16b, v24.16b - mov v4.16b, v24.16b - mov v5.16b, v24.16b - mov v6.16b, v24.16b - mov v7.16b, v24.16b - LoopKh8: - mov x18, x7 - mov x21, x16 - LoopKw8: - mov x22, x21 - ld1 {v25.4h}, [x17], #8 - ld1 {v16.4h}, [x22], x11 - ld1 {v17.4h}, [x22], x11 - smlal v0.4s, v16.4h, v25.4h - smlal v1.4s, v17.4h, v25.4h - ld1 {v18.4h}, [x22], x11 - ld1 {v19.4h}, [x22], x11 - smlal v2.4s, v18.4h, v25.4h - smlal v3.4s, v19.4h, v25.4h - ld1 {v20.4h}, [x22], x11 - ld1 {v21.4h}, [x22], x11 - smlal v4.4s, v20.4h, v25.4h - smlal v5.4s, v21.4h, v25.4h - ld1 {v22.4h}, [x22], x11 - ld1 {v23.4h}, [x22], x11 - smlal v6.4s, v22.4h, v25.4h - smlal v7.4s, v23.4h, v25.4h subs x18, x18, #1 add x21, x21, x13 - bne LoopKw8 + bne LoopKw4 add x16, x16, x12 subs x20, x20, #1 - bne LoopKh8 - - sqshl v0.4s, v0.4s, v26.4s - sqshl v1.4s, v1.4s, v26.4s - sqshl v2.4s, v2.4s, v26.4s - sqshl v3.4s, v3.4s, v26.4s - sqshl v4.4s, v4.4s, v26.4s - sqshl v5.4s, v5.4s, v26.4s - sqshl v6.4s, v6.4s, v26.4s - sqshl v7.4s, v7.4s, v26.4s - sqrdmulh v0.4s, v0.4s, v27.4s - sqrdmulh v1.4s, v1.4s, v27.4s - sqrdmulh v2.4s, v2.4s, v27.4s - sqrdmulh v3.4s, v3.4s, v27.4s - sqrdmulh v4.4s, v4.4s, v27.4s - sqrdmulh v5.4s, v5.4s, v27.4s - sqrdmulh v6.4s, v6.4s, v27.4s - sqrdmulh v7.4s, v7.4s, v27.4s - - and v16.16b, v28.16b, v0.16b - sshr v16.4s, v16.4s, #31 - sqadd v0.4s, v0.4s, v16.4s - srshl v0.4s, v0.4s, v28.4s - and v17.16b, v28.16b, v1.16b - sshr v17.4s, v17.4s, #31 - sqadd v1.4s, v1.4s, v17.4s - srshl v1.4s, v1.4s, v28.4s - and v18.16b, v28.16b, v2.16b - sshr v18.4s, v18.4s, #31 - sqadd v2.4s, v2.4s, v18.4s - srshl v2.4s, v2.4s, v28.4s - and v19.16b, v28.16b, v3.16b - sshr v19.4s, v19.4s, #31 - sqadd v3.4s, v3.4s, v19.4s - srshl v3.4s, v3.4s, v28.4s - and v20.16b, v28.16b, v4.16b - sshr v20.4s, v20.4s, #31 - sqadd v4.4s, v4.4s, v20.4s - srshl v4.4s, v4.4s, v28.4s - and v21.16b, v28.16b, v5.16b - sshr v21.4s, v21.4s, #31 - sqadd v5.4s, v5.4s, v21.4s - srshl v5.4s, v5.4s, v28.4s - and v22.16b, v28.16b, v6.16b - sshr v22.4s, v22.4s, #31 - sqadd v6.4s, v6.4s, v22.4s - srshl v6.4s, v6.4s, v28.4s - and v23.16b, v28.16b, v7.16b - sshr v23.4s, v23.4s, #31 - sqadd v7.4s, v7.4s, v23.4s - srshl v7.4s, v7.4s, v28.4s - - add v0.4s, v0.4s, v29.4s - add v1.4s, v1.4s, v29.4s - add v2.4s, v2.4s, v29.4s - add v3.4s, v3.4s, v29.4s - add v4.4s, v4.4s, v29.4s - add v5.4s, v5.4s, v29.4s - add v6.4s, v6.4s, v29.4s - add v7.4s, v7.4s, v29.4s - smax v0.4s, v0.4s, v30.4s - smax v1.4s, v1.4s, v30.4s - smax v2.4s, v2.4s, v30.4s - smax v3.4s, v3.4s, v30.4s - smax v4.4s, v4.4s, v30.4s - smax v5.4s, v5.4s, v30.4s - smax v6.4s, v6.4s, v30.4s - smax v7.4s, v7.4s, v30.4s - smin v0.4s, v0.4s, v31.4s + bne LoopKh4 + + sqshl v0.4s, v0.4s, v24.4s + sqshl v1.4s, v1.4s, v25.4s + sqshl v2.4s, v2.4s, v24.4s + sqshl v3.4s, v3.4s, v25.4s + sqshl v4.4s, v4.4s, v24.4s + sqshl v5.4s, v5.4s, v25.4s + sqshl v6.4s, v6.4s, v24.4s + sqshl v7.4s, v7.4s, v25.4s + + sqrdmulh v0.4s, v0.4s, v22.4s + sqrdmulh v1.4s, v1.4s, v23.4s + sqrdmulh v2.4s, v2.4s, v22.4s + sqrdmulh v3.4s, v3.4s, v23.4s + sqrdmulh v4.4s, v4.4s, v22.4s + sqrdmulh v5.4s, v5.4s, v23.4s + sqrdmulh v6.4s, v6.4s, v22.4s + sqrdmulh v7.4s, v7.4s, v23.4s + + and v15.16b, v26.16b, v0.16b + sshr v15.4s, v15.4s, #31 + sqadd v0.4s, v0.4s, v15.4s + srshl v0.4s, v0.4s, v26.4s + + and v14.16b, v27.16b, v1.16b + sshr v14.4s, v14.4s, #31 + sqadd v1.4s, v1.4s, v14.4s + srshl v1.4s, v1.4s, v27.4s + + and v13.16b, v26.16b, v2.16b + sshr v13.4s, v13.4s, #31 + sqadd v2.4s, v2.4s, v13.4s + srshl v2.4s, v2.4s, v26.4s + + and v12.16b, v27.16b, v3.16b + sshr v12.4s, v12.4s, #31 + sqadd v3.4s, v3.4s, v12.4s + srshl v3.4s, v3.4s, v27.4s + + and v11.16b, v26.16b, v4.16b + sshr v11.4s, v11.4s, #31 + sqadd v4.4s, v4.4s, v11.4s + srshl v4.4s, v4.4s, v26.4s + + and v10.16b, v27.16b, v5.16b + sshr v10.4s, v10.4s, #31 + sqadd v5.4s, v5.4s, v10.4s + srshl v5.4s, v5.4s, v27.4s + + and v9.16b, v26.16b, v6.16b + sshr v9.4s, v9.4s, #31 + sqadd v6.4s, v6.4s, v9.4s + srshl v6.4s, v6.4s, v26.4s + + and v8.16b, v27.16b, v7.16b + sshr v8.4s, v8.4s, #31 + sqadd v7.4s, v7.4s, v8.4s + srshl v7.4s, v7.4s, v27.4s + + add v0.4s, v0.4s, v20.4s + add v1.4s, v1.4s, v21.4s + add v2.4s, v2.4s, v20.4s + add v3.4s, v3.4s, v21.4s + add v4.4s, v4.4s, v20.4s + add v5.4s, v5.4s, v21.4s + add v6.4s, v6.4s, v20.4s + add v7.4s, v7.4s, v21.4s + smax v0.4s, v0.4s, v28.4s + smax v1.4s, v1.4s, v29.4s + smax v2.4s, v2.4s, v28.4s + smax v3.4s, v3.4s, v29.4s + smax v4.4s, v4.4s, v28.4s + smax v5.4s, v5.4s, v29.4s + smax v6.4s, v6.4s, v28.4s + smax v7.4s, v7.4s, v29.4s + smin v0.4s, v0.4s, v30.4s smin v1.4s, v1.4s, v31.4s - smin v2.4s, v2.4s, v31.4s + smin v2.4s, v2.4s, v30.4s smin v3.4s, v3.4s, v31.4s - smin v4.4s, v4.4s, v31.4s + smin v4.4s, v4.4s, v30.4s smin v5.4s, v5.4s, v31.4s - smin v6.4s, v6.4s, v31.4s + smin v6.4s, v6.4s, v30.4s smin v7.4s, v7.4s, v31.4s sqxtn v0.4h, v0.4s @@ -535,93 +222,81 @@ ConvDwInt8Center: sqxtn v6.8b, v6.8h sqxtn v7.8b, v7.8h - add x17, x3, #1 - add x18, x3, #2 - add x21, x3, #3 - st1 {v0.b}[0], [x3], x9 - st1 {v0.b}[1], [x17], x9 - st1 {v0.b}[2], [x18], x9 - st1 {v0.b}[3], [x21], x9 - - st1 {v1.b}[0], [x3], x9 - st1 {v1.b}[1], [x17], x9 - st1 {v1.b}[2], [x18], x9 - st1 {v1.b}[3], [x21], x9 - - st1 {v2.b}[0], [x3], x9 - st1 {v2.b}[1], [x17], x9 - st1 {v2.b}[2], [x18], x9 - st1 {v2.b}[3], [x21], x9 - - st1 {v3.b}[0], [x3], x9 - st1 {v3.b}[1], [x17], x9 - st1 {v3.b}[2], [x18], x9 - st1 {v3.b}[3], [x21], x9 - - st1 {v4.b}[0], [x3], x9 - st1 {v4.b}[1], [x17], x9 - st1 {v4.b}[2], [x18], x9 - st1 {v4.b}[3], [x21], x9 - - st1 {v5.b}[0], [x3], x9 - st1 {v5.b}[1], [x17], x9 - st1 {v5.b}[2], [x18], x9 - st1 {v5.b}[3], [x21], x9 - - st1 {v6.b}[0], [x3], x9 - st1 {v6.b}[1], [x17], x9 - st1 {v6.b}[2], [x18], x9 - st1 {v6.b}[3], [x21], x9 - - st1 {v7.b}[0], [x3], x9 - st1 {v7.b}[1], [x17], x9 - st1 {v7.b}[2], [x18], x9 - st1 {v7.b}[3], [x21], x9 - + mov x16, x3 + add x17, x16, x9 + add x18, x17, x9 + add x21, x18, x9 + + st1 {v0.s}[0], [x16], #4 + st1 {v1.s}[0], [x16], #4 + st1 {v2.s}[0], [x17], #4 + st1 {v3.s}[0], [x17], #4 + st1 {v4.s}[0], [x18], #4 + st1 {v5.s}[0], [x18], #4 + st1 {v6.s}[0], [x21], #4 + st1 {v7.s}[0], [x21], #4 + + add x3, x3, x25 add x23, x23, x19 - sub x24, x24, #8 + sub x24, x24, #4 cmp x24, #0 ble LoopWEnd - cmp x24, #8 - bge LoopW8 + cmp x24, #4 + bge LoopW4 + LoopW: mov x16, x23 mov x17, x2 mov x20, x6 - mov v0.16b, v24.16b + mov v0.16b, v17.16b + mov v1.16b, v18.16b LoopKh: mov x18, x7 mov x22, x16 LoopKw: - ld1 {v16.4h}, [x22], x13 - ld1 {v25.4h}, [x17], #8 - smlal v0.4s, v16.4h, v25.4h + ld1 {v15.8b}, [x22], x13 + ssubl v14.8h, v15.8b, v19.8b + ld1 {v16.8h}, [x17], #16 + smlal v0.4s, v14.4h, v16.4h + smlal2 v1.4s, v14.8h, v16.8h subs x18, x18, #1 bne LoopKw add x16, x16, x12 subs x20, x20, #1 bne LoopKh - sqshl v0.4s, v0.4s, v26.4s - sqrdmulh v0.4s, v0.4s, v27.4s + sqshl v0.4s, v0.4s, v24.4s + sqrdmulh v0.4s, v0.4s, v22.4s + sqshl v1.4s, v1.4s, v25.4s + sqrdmulh v1.4s, v1.4s, v23.4s - and v16.16b, v28.16b, v0.16b - sshr v16.4s, v16.4s, #31 - sqadd v0.4s, v0.4s, v16.4s - srshl v0.4s, v0.4s, v28.4s + and v15.16b, v26.16b, v0.16b + sshr v15.4s, v15.4s, #31 + sqadd v0.4s, v0.4s, v15.4s + srshl v0.4s, v0.4s, v26.4s - add v0.4s, v0.4s, v29.4s - smax v0.4s, v0.4s, v30.4s - smin v0.4s, v0.4s, v31.4s + and v14.16b, v27.16b, v1.16b + sshr v14.4s, v14.4s, #31 + sqadd v1.4s, v1.4s, v14.4s + srshl v1.4s, v1.4s, v27.4s + + add v0.4s, v0.4s, v20.4s + smax v0.4s, v0.4s, v28.4s + smin v0.4s, v0.4s, v30.4s sqxtn v0.4h, v0.4s sqxtn v0.8b, v0.8h + add v1.4s, v1.4s, v21.4s + smax v1.4s, v1.4s, v29.4s + smin v1.4s, v1.4s, v31.4s + + sqxtn v1.4h, v1.4s + sqxtn v1.8b, v1.8h + mov x17, x3 - st1 {v0.b}[0], [x17], #1 - st1 {v0.b}[1], [x17], #1 - st1 {v0.b}[2], [x17], #1 - st1 {v0.b}[3], [x17], #1 + st1 {v0.s}[0], [x17], #4 + st1 {v1.s}[0], [x17], #4 add x3, x3, x9 add x23, x23, x11 diff --git a/mindspore/lite/nnacl/int8/common_func.h b/mindspore/lite/nnacl/int8/common_func.h index 3e79180d19..eabb22c6f4 100644 --- a/mindspore/lite/nnacl/int8/common_func.h +++ b/mindspore/lite/nnacl/int8/common_func.h @@ -45,10 +45,11 @@ void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *wei void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step); -void ConvDwInt8Center(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, size_t height, +void ConvDwInt8Center(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height, size_t width, size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, - size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int out_multiplier, - int left_shift, int right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max); + size_t in_sh_step, size_t in_sw_step, size_t in_kh_step, size_t in_kw_step, int8_t *in_zp, + int32_t *out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, + int32_t *acc_min, int32_t *acc_max); void ConvDwInt8Row(int32_t *output_ptr, const int8_t *input_ptr, const int16_t *weight_ptr, int num_pixels, int output_channel, int input_step, int8_t input_zp); void ConvDwInt8PostAlign4(int8_t *dst, int32_t *buffer, int num_pixels, int32_t output_zp, int32_t out_multiplier, diff --git a/mindspore/lite/nnacl/int8/conv_depthwise_int8.c b/mindspore/lite/nnacl/int8/conv_depthwise_int8.c index 2514bd1bb4..b84bc58357 100644 --- a/mindspore/lite/nnacl/int8/conv_depthwise_int8.c +++ b/mindspore/lite/nnacl/int8/conv_depthwise_int8.c @@ -138,75 +138,67 @@ void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_da } /*conv depthwise int8 end*/ -/*conv depthwise sliding window int8 begin*/ -void DepthwiseBorderPixelInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int height, - int width, int in_kh_step, int in_kw_step, int kernel_w, int *out_multiplier, - int *left_shift, int *right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max, - bool per_channel) { - int tmp_buffer[C4NUM]; - for (int i = 0; i < C4NUM; i++) { +/*conv depthwise sliding window perchannel int8 begin*/ +void DepthwiseBorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, + int width, int in_kh_step, int in_kw_step, int kernel_w, int8_t *input_zp, + int32_t *out_zp, int *out_multiplier, int *left_shift, int *right_shift, int32_t *acc_min, + int32_t *acc_max) { + int tmp_buffer[C8NUM]; + for (int i = 0; i < C8NUM; i++) { tmp_buffer[i] = 0; } - const int16_t *src_kh = src; + const int8_t *src_kh = src; const int16_t *weight_kh = weight; for (int kh = 0; kh < height; kh++) { - const int16_t *src_kw = src_kh; + const int8_t *src_kw = src_kh; const int16_t *weight_kw = weight_kh; for (int kw = 0; kw < width; kw++) { - for (int c = 0; c < C4NUM; c++) { - tmp_buffer[c] += src_kw[c] * weight_kw[c]; + for (int c = 0; c < C8NUM; c++) { + tmp_buffer[c] += (src_kw[c] - input_zp[c]) * weight_kw[c]; } src_kw += in_kw_step; - weight_kw += C4NUM; + weight_kw += C8NUM; } // kernel_w loop src_kh += in_kh_step; - weight_kh += kernel_w * C4NUM; + weight_kh += kernel_w * C8NUM; } // kernel_h loop - int32_t left = left_shift[0]; - int32_t right = right_shift[0]; - int32_t multiplier = out_multiplier[0]; - for (int c = 0; c < C4NUM; c++) { - if (per_channel) { - left = left_shift[c]; - right = right_shift[c]; - multiplier = out_multiplier[c]; - } + + for (int c = 0; c < C8NUM; c++) { tmp_buffer[c] += bias[c]; tmp_buffer[c] = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left), multiplier), -right); - tmp_buffer[c] += out_zp; - tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min); - tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max); + SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]), + -right_shift[c]); + tmp_buffer[c] += out_zp[c]; + tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min[c]); + tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max[c]); dst[c] = (tmp_buffer[c]); } } -void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int top, +void DepthwiseBorderInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int top, int bottom, int left, int right, const ConvParameter *conv_param, - const SlidingWindowParam *sliding, int *out_multiplier, int *left_shift, int *right_shift, - bool per_channel) { + const SlidingWindowParam *sliding, int8_t *in_zp, int32_t *out_zp, int *out_multiplier, + int *left_shift, int *right_shift, int32_t *acc_min, int32_t *acc_max) { int8_t *dst_h = dst + top * sliding->out_h_step_; for (int oh = top; oh < bottom; oh++) { int ih = oh * conv_param->stride_h_ - conv_param->pad_u_; int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_)); - const int16_t *src_h = src + ih * sliding->in_h_step_; + const int8_t *src_h = src + ih * sliding->in_h_step_; int8_t *dst_kernel = dst_h + left * sliding->block_channel_; for (int ow = left; ow < right; ow++) { int iw = ow * conv_param->stride_w_ - conv_param->pad_l_; int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_)); int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_)); - const int16_t *src_w = src_h + iw * sliding->block_channel_; + const int8_t *src_w = src_h + iw * sliding->block_channel_; - const int16_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; - const int16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C4NUM; + const int8_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; + const int16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * C8NUM; DepthwiseBorderPixelInt8(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, - sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_, out_multiplier, - left_shift, right_shift, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, - conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], - per_channel); + sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_w_, in_zp, out_zp, + out_multiplier, left_shift, right_shift, acc_min, acc_max); dst_kernel += sliding->block_channel_; } // width loop @@ -215,52 +207,46 @@ void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight, } #ifndef ENABLE_ARM64 -void DepthwiseCenterInt8(int8_t *dst, const int16_t *src, const int16_t *weight, const int32_t *bias, int height, +void DepthwiseCenterInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, - int in_sw_step, int in_kh_step, int in_kw_step, int *out_multiplier, int *left_shift, - int *right_shift, int32_t out_zp, int32_t acc_min, int32_t acc_max, bool per_channel) { - int tmp_buffer[C4NUM]; + int in_sw_step, int in_kh_step, int in_kw_step, int8_t *in_zp, int32_t *out_zp, + int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t *acc_min, + int32_t *acc_max) { + int tmp_buffer[C8NUM]; int8_t *dst_h = dst; - const int16_t *src_h = src; + const int8_t *src_h = src; for (int oh = 0; oh < height; oh++) { int8_t *dst_w = dst_h; - const int16_t *src_w = src_h; + const int8_t *src_w = src_h; for (int ow = 0; ow < width; ow++) { - const int16_t *src_kh = src_w; + const int8_t *src_kh = src_w; const int16_t *weight_kh = weight; - for (int i = 0; i < C4NUM; i++) { + for (int i = 0; i < C8NUM; i++) { tmp_buffer[i] = 0; } for (int kh = 0; kh < kernel_h; kh++) { - const int16_t *src_kw = src_kh; + const int8_t *src_kw = src_kh; const int16_t *weight_kw = weight_kh; for (int kw = 0; kw < kernel_w; kw++) { - for (int c = 0; c < C4NUM; c++) { - tmp_buffer[c] += src_kw[c] * weight_kw[c]; + for (int c = 0; c < C8NUM; c++) { + tmp_buffer[c] += (src_kw[c] - in_zp[c]) * weight_kw[c]; } src_kw += in_kw_step; - weight_kw += C4NUM; + weight_kw += C8NUM; } // kernel_w loop src_kh += in_kh_step; - weight_kh += kernel_w * C4NUM; + weight_kh += kernel_w * C8NUM; } // kernel_h loop // add bias relu - int32_t left = left_shift[0]; - int32_t right = right_shift[0]; - int32_t multiplier = out_multiplier[0]; - for (int c = 0; c < C4NUM; c++) { - if (per_channel) { - left = left_shift[c]; - right = right_shift[c]; - multiplier = out_multiplier[c]; - } + for (int c = 0; c < C8NUM; c++) { tmp_buffer[c] += bias[c]; tmp_buffer[c] = RoundingDivideByPOT( - SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left), multiplier), -right); - tmp_buffer[c] += out_zp; - tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min); - tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max); + SaturatingRoundingDoublingHighMul(tmp_buffer[c] * (1 << (unsigned int)left_shift[c]), out_multiplier[c]), + -right_shift[c]); + tmp_buffer[c] += out_zp[c]; + tmp_buffer[c] = MSMAX(tmp_buffer[c], acc_min[c]); + tmp_buffer[c] = MSMIN(tmp_buffer[c], acc_max[c]); dst_w[c] = (tmp_buffer[c]); } dst_w += block_channel; @@ -272,69 +258,65 @@ void DepthwiseCenterInt8(int8_t *dst, const int16_t *src, const int16_t *weight, } #endif -void ConvDwSWInt8(int8_t *output_data, const int16_t *input_data, const int16_t *weight_data, const int32_t *bias_data, - const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id) { - const int16_t *src = input_data; +void ConvDwSWInt8(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, const int32_t *bias_data, + int8_t *input_zp, int32_t *output_zp, const ConvParameter *conv_param, + const SlidingWindowParam *sliding, int task_id) { + const int8_t *src = input_data; int8_t *dst = output_data; - bool per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL; - int *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_; - int *left_shift = conv_param->conv_quant_arg_.left_shift_; - int *right_shift = conv_param->conv_quant_arg_.right_shift_; for (int b = 0; b < conv_param->output_batch_; b++) { for (int oc = task_id; oc < sliding->c_block_; oc += conv_param->thread_num_) { - const int16_t *src_data = src + oc * C4NUM; - int8_t *dst_data = dst + oc * C4NUM; + const int8_t *src_data = src + oc * C8NUM; + int8_t *dst_data = dst + oc * C8NUM; const int16_t *weight = weight_data + oc * sliding->kernel_step_; - const int32_t *bias = bias_data + oc * C4NUM; + const int32_t *bias = bias_data + oc * C8NUM; - if (per_channel) { - out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_ + oc * C4NUM; - left_shift = conv_param->conv_quant_arg_.left_shift_ + oc * C4NUM; - right_shift = conv_param->conv_quant_arg_.right_shift_ + oc * C4NUM; - } + int *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_ + oc * C8NUM; + int *left_shift = conv_param->conv_quant_arg_.left_shift_ + oc * C8NUM; + int *right_shift = conv_param->conv_quant_arg_.right_shift_ + oc * C8NUM; + int *acc_min = conv_param->conv_quant_arg_.out_act_min_ + oc * C8NUM; + int *acc_max = conv_param->conv_quant_arg_.out_act_max_ + oc * C8NUM; + int8_t *in_zp = input_zp + oc * C8NUM; + int32_t *out_zp = output_zp + oc * C8NUM; DepthwiseBorderInt8(dst_data, src_data, weight, bias, 0, sliding->top_, 0, conv_param->output_w_, conv_param, - sliding, out_multiplier, left_shift, right_shift, per_channel); + sliding, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max); DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->bottom_, conv_param->output_h_, 0, - conv_param->output_w_, conv_param, sliding, out_multiplier, left_shift, right_shift, - per_channel); + conv_param->output_w_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift, + right_shift, acc_min, acc_max); DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, 0, sliding->left_, - conv_param, sliding, out_multiplier, left_shift, right_shift, per_channel); + conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, + acc_max); DepthwiseBorderInt8(dst_data, src_data, weight, bias, sliding->top_, sliding->bottom_, sliding->right_, - conv_param->output_w_, conv_param, sliding, out_multiplier, left_shift, right_shift, - per_channel); + conv_param->output_w_, conv_param, sliding, in_zp, out_zp, out_multiplier, left_shift, + right_shift, acc_min, acc_max); if (sliding->right_ > sliding->left_ && sliding->bottom_ > sliding->top_) { int in_h_start = sliding->top_ * conv_param->stride_h_ - conv_param->pad_u_; int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_l_; - const int16_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_; + const int8_t *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_; int8_t *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; #ifdef ENABLE_ARM64 ConvDwInt8Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(int8_t), - sliding->block_channel_ * sizeof(int8_t), sliding->in_sh_step_ * sizeof(int16_t), - sliding->in_sw_step_ * sizeof(int16_t), sliding->in_kh_step_ * sizeof(int16_t), - sliding->in_kw_step_ * sizeof(int16_t), conv_param->conv_quant_arg_.quant_multiplier_[0], - conv_param->conv_quant_arg_.left_shift_[0], conv_param->conv_quant_arg_.right_shift_[0], - conv_param->conv_quant_arg_.output_quant_args_[0].zp_, - conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]); + sliding->block_channel_ * sizeof(int8_t), sliding->in_sh_step_ * sizeof(int8_t), + sliding->in_sw_step_ * sizeof(int8_t), sliding->in_kh_step_ * sizeof(int8_t), + sliding->in_kw_step_ * sizeof(int8_t), in_zp, out_zp, out_multiplier, left_shift, right_shift, + acc_min, acc_max); #else - - DepthwiseCenterInt8( - out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, - conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_, - sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, out_multiplier, - left_shift, right_shift, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, - conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], per_channel); + DepthwiseCenterInt8(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, + sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_, + sliding->out_h_step_, sliding->block_channel_, sliding->in_sh_step_, sliding->in_sw_step_, + sliding->in_kh_step_, sliding->in_kw_step_, in_zp, out_zp, out_multiplier, left_shift, + right_shift, acc_min, acc_max); #endif } - } // output C4 loop + } // output C8 loop src += sliding->in_step_; dst += sliding->out_step_; } // batch loop - // output nhwc4 + // output nhwc8 } -/*conv depthwise sliding window int8 end*/ +/*conv depthwise sliding window perchannel int8 end*/ /*deconv depthwise int8 begin*/ void DeconvDepthwiseBorderPixelInt8(int32_t *dst, const int16_t *src, const int16_t *weight, int height, int width, diff --git a/mindspore/lite/nnacl/int8/conv_depthwise_int8.h b/mindspore/lite/nnacl/int8/conv_depthwise_int8.h index 004b9dff27..016f4fd5b2 100644 --- a/mindspore/lite/nnacl/int8/conv_depthwise_int8.h +++ b/mindspore/lite/nnacl/int8/conv_depthwise_int8.h @@ -27,8 +27,9 @@ extern "C" { void ConvDwInt8(int8_t *output_data, int32_t *output_row, const int8_t *input_data, const int16_t *weight_data, const int32_t *bias_data, const ConvParameter *conv_param, int task_id); -void ConvDwSWInt8(int8_t *output_data, const int16_t *input_data, const int16_t *weight_data, const int32_t *bias_data, - const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); +void ConvDwSWInt8(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, const int32_t *bias_data, + int8_t *input_zp, int32_t *output_zp, const ConvParameter *conv_param, + const SlidingWindowParam *sliding, int task_id); void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *input_data, const int16_t *weight_data, const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, diff --git a/mindspore/lite/nnacl/pack.c b/mindspore/lite/nnacl/pack.c index 83d7caa928..d5e4fac038 100644 --- a/mindspore/lite/nnacl/pack.c +++ b/mindspore/lite/nnacl/pack.c @@ -869,6 +869,45 @@ void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int c } } +void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + int nhwc8_batch_unit_offset = c8 * C8NUM * plane; + int ic_remainder_ = channel % C8NUM; + if (ic_remainder_ != 0) { + int nhwc8_batch_offset = 0; + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + for (int i = 0; i < plane; i++) { + memcpy((int8_t *)dst + nhwc8_batch_offset + i * c8 * C8NUM, (int8_t *)src + batch_offset + i * channel, + channel); + } + nhwc8_batch_offset += nhwc8_batch_unit_offset; + } + } else { + size_t ori_input_size = batch * plane * channel; + memcpy((int8_t *)dst, (int8_t *)src, ori_input_size); + } +} + +void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { + int c8 = UP_DIV(channel, C8NUM); + int nhwc8_batch_unit_offset = c8 * C8NUM * plane; + int ic_remainder_ = channel % C8NUM; + if (ic_remainder_ != 0) { + for (int b = 0; b < batch; b++) { + int batch_offset = b * channel * plane; + int nhwc8_batch_offset = b * nhwc8_batch_unit_offset; + for (int i = 0; i < plane; i++) { + memcpy((int8_t *)dst + batch_offset + i * channel, (int8_t *)src + nhwc8_batch_offset + i * c8 * C8NUM, + channel); + } + } + } else { + size_t ori_input_size = batch * plane * channel; + memcpy((int8_t *)dst, (int8_t *)src, ori_input_size); + } +} + void PackNCHWToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) { int nhwc4_batch_offset = 0; int c4 = UP_DIV(channel, C4NUM); @@ -1174,6 +1213,25 @@ void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, ConvQuantArg *quant_qrg) { int weight_zp = quant_qrg->filter_quant_args_[0].zp_; + for (int c = 0; c < channel; c++) { + if (quant_qrg->per_channel_ & FILTER_PER_CHANNEL) { + weight_zp = quant_qrg->filter_quant_args_[c].zp_; + } + int c8_block_num = c / C8NUM; + int c8_block_rem = c % C8NUM; + const int8_t *src_c = origin_weight + c * plane; + int16_t *dst_c = packed_weight_ + c8_block_num * plane * C8NUM; + for (int k = 0; k < plane; k++) { + const int8_t *src_kernel = src_c + k; + int16_t *dst_kernel = dst_c + C8NUM * k + c8_block_rem; + *dst_kernel = (int16_t)(src_kernel[0] - weight_zp); + } + } +} + +void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, + ConvQuantArg *quant_qrg) { + int weight_zp = quant_qrg->filter_quant_args_[0].zp_; for (int c = 0; c < channel; c++) { if (quant_qrg->per_channel_ & FILTER_PER_CHANNEL) { weight_zp = quant_qrg->filter_quant_args_[c].zp_; diff --git a/mindspore/lite/nnacl/pack.h b/mindspore/lite/nnacl/pack.h index b05083c52d..c582bedcd0 100644 --- a/mindspore/lite/nnacl/pack.h +++ b/mindspore/lite/nnacl/pack.h @@ -96,6 +96,10 @@ void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int c void PackNHWC4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); +void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int channel); + +void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); + void PackNCHWToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel); void PackNC4HW4ToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel); @@ -114,6 +118,9 @@ void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, ConvQuantArg *quant_qrg); + +void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel, + ConvQuantArg *quant_qrg); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc index e8a28ae3df..972fc1e99b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_int8.cc @@ -177,8 +177,17 @@ kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector const mindspore::lite::PrimitiveC *primitive) { MS_ASSERT(opParameter != nullptr); MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); - auto kernel = - new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); + + kernel::LiteKernel *kernel; + auto act_quant_size = + MSMAX(inputs[kInputIndex]->GetQuantParams().size(), outputs[kOutputIndex]->GetQuantParams().size()); + if (act_quant_size == 1) { // per tensor + kernel = new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); + } else { // per channel + kernel = + new (std::nothrow) kernel::ConvolutionDepthwiseSWInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); + } + if (kernel == nullptr) { MS_LOG(ERROR) << "kernel is nullptr."; return nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_slidewindow_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_slidewindow_int8.cc index b251eed3c9..ec4ee2aaa9 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_slidewindow_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_slidewindow_int8.cc @@ -37,6 +37,7 @@ ConvolutionDepthwiseSWInt8CPUKernel::~ConvolutionDepthwiseSWInt8CPUKernel() { free(packed_weight_); packed_weight_ = nullptr; } + FreeTmpQuant(); FreeQuantParam(); } @@ -45,8 +46,8 @@ int ConvolutionDepthwiseSWInt8CPUKernel::InitWeightBias() { // o, h, w, i -> o/8, h, w, i, 8; o == group, i == 1 auto weight_tensor = in_tensors_[kWeightIndex]; auto origin_weight = reinterpret_cast(weight_tensor->MutableData()); - int OC4 = UP_DIV(weight_tensor->Batch(), C4NUM); - int pack_weight_size = C4NUM * OC4 * weight_tensor->Height() * weight_tensor->Width(); + int OC8 = UP_DIV(weight_tensor->Batch(), C8NUM); + int pack_weight_size = C8NUM * OC8 * weight_tensor->Height() * weight_tensor->Width(); packed_weight_ = reinterpret_cast(malloc(pack_weight_size * sizeof(int16_t))); if (packed_weight_ == nullptr) { MS_LOG(ERROR) << "Malloc buffer failed."; @@ -55,35 +56,36 @@ int ConvolutionDepthwiseSWInt8CPUKernel::InitWeightBias() { PackDepthwiseInt8Weight(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(), weight_tensor->Batch(), &(conv_param_->conv_quant_arg_)); - bias_data_ = reinterpret_cast(malloc(C4NUM * OC4 * sizeof(int32_t))); + bias_data_ = reinterpret_cast(malloc(C8NUM * OC8 * sizeof(int32_t))); if (bias_data_ == nullptr) { MS_LOG(ERROR) << "Malloc buffer failed."; return RET_ERROR; } - memset(bias_data_, 0, C4NUM * OC4 * sizeof(int32_t)); + memset(bias_data_, 0, C8NUM * OC8 * sizeof(int32_t)); if (in_tensors_.size() == kInputSize2) { auto bias_tensor = in_tensors_.at(kBiasIndex); auto ori_bias = reinterpret_cast(bias_tensor->MutableData()); memcpy(bias_data_, ori_bias, bias_tensor->ElementsNum() * sizeof(int32_t)); } - conv_param_->thread_num_ = MSMIN(thread_count_, OC4); + conv_param_->thread_num_ = MSMIN(thread_count_, OC8); return RET_OK; } int ConvolutionDepthwiseSWInt8CPUKernel::InitBuffer() { - int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C4NUM * - UP_DIV(conv_param_->input_channel_, 4); - packed_input_ = reinterpret_cast(context_->allocator->Malloc(pack_input_size * sizeof(int16_t))); - if (packed_input_ == nullptr) { - MS_LOG(ERROR) << "Malloc buffer failed."; - return RET_ERROR; - } - - if (conv_param_->input_channel_ % C4NUM != 0) { + if (conv_param_->input_channel_ % C8NUM != 0) { need_align_ = true; - int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C4NUM * - UP_DIV(conv_param_->output_channel_, C4NUM); + + int pack_input_size = conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * C8NUM * + UP_DIV(conv_param_->input_channel_, C8NUM); + packed_input_ = reinterpret_cast(context_->allocator->Malloc(pack_input_size * sizeof(int8_t))); + if (packed_input_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } + + int pack_output_size = conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * C8NUM * + UP_DIV(conv_param_->output_channel_, C8NUM); packed_output_ = reinterpret_cast(context_->allocator->Malloc(pack_output_size * sizeof(int8_t))); if (packed_input_ == nullptr) { MS_LOG(ERROR) << "Malloc buffer failed."; @@ -93,6 +95,136 @@ int ConvolutionDepthwiseSWInt8CPUKernel::InitBuffer() { return RET_OK; } +void ConvolutionDepthwiseSWInt8CPUKernel::FreeTmpQuant() { + if (input_scale_ != nullptr) { + free(input_scale_); + input_scale_ = nullptr; + } + if (input_zp_ != nullptr) { + free(input_zp_); + input_zp_ = nullptr; + } + if (weight_scale_ != nullptr) { + free(weight_scale_); + weight_scale_ = nullptr; + } + if (output_scale_ != nullptr) { + free(output_scale_); + output_scale_ = nullptr; + } + if (output_zp_ != nullptr) { + free(output_zp_); + output_zp_ = nullptr; + } +} + +int ConvolutionDepthwiseSWInt8CPUKernel::ReinitFreeBefore() { + FreeTmpQuant(); + if (conv_quant_arg_->real_multiplier_ != nullptr) { + free(conv_quant_arg_->real_multiplier_); + conv_quant_arg_->real_multiplier_ = nullptr; + } + if (conv_quant_arg_->left_shift_ != nullptr) { + free(conv_quant_arg_->left_shift_); + conv_quant_arg_->left_shift_ = nullptr; + } + if (conv_quant_arg_->right_shift_ != nullptr) { + free(conv_quant_arg_->right_shift_); + conv_quant_arg_->right_shift_ = nullptr; + } + if (conv_quant_arg_->quant_multiplier_ != nullptr) { + free(conv_quant_arg_->quant_multiplier_); + conv_quant_arg_->quant_multiplier_ = nullptr; + } + if (conv_quant_arg_->out_act_min_ != nullptr) { + free(conv_quant_arg_->out_act_min_); + conv_quant_arg_->out_act_min_ = nullptr; + } + if (conv_quant_arg_->out_act_max_ != nullptr) { + free(conv_quant_arg_->out_act_max_); + conv_quant_arg_->out_act_max_ = nullptr; + } + return RET_OK; +} + +int ConvolutionDepthwiseSWInt8CPUKernel::ReinitQuantParam() { + ReinitFreeBefore(); // remalloc quant param buffer + + auto input_tensor = in_tensors_.at(kInputIndex); + auto channel = conv_param_->input_channel_; + input_scale_ = reinterpret_cast(malloc(channel * sizeof(float))); + input_zp_ = reinterpret_cast(malloc(channel * sizeof(int8_t))); + if (input_tensor->GetQuantParams().size() == kPerTensor) { + for (int i = 0; i < channel; i++) { + auto input_quant_arg = input_tensor->GetQuantParams().front(); + input_zp_[i] = input_quant_arg.zeroPoint; + input_scale_[i] = input_quant_arg.scale; + } + } else { + for (int i = 0; i < channel; i++) { + auto input_quant_arg = input_tensor->GetQuantParams()[i]; + input_zp_[i] = input_quant_arg.zeroPoint; + input_scale_[i] = input_quant_arg.scale; + } + } + + auto output_tensor = out_tensors_.at(kOutputIndex); + output_scale_ = reinterpret_cast(malloc(channel * sizeof(float))); + output_zp_ = reinterpret_cast(malloc(channel * sizeof(int32_t))); + if (output_tensor->GetQuantParams().size() == kPerTensor) { + for (int i = 0; i < channel; i++) { + auto output_quant_arg = output_tensor->GetQuantParams().front(); + output_zp_[i] = output_quant_arg.zeroPoint; + output_scale_[i] = output_quant_arg.scale; + } + } else { + for (int i = 0; i < channel; i++) { + auto output_quant_arg = output_tensor->GetQuantParams()[i]; + output_zp_[i] = output_quant_arg.zeroPoint; + output_scale_[i] = output_quant_arg.scale; + } + } + + conv_quant_arg_->real_multiplier_ = reinterpret_cast(malloc(channel * sizeof(double))); + conv_quant_arg_->left_shift_ = reinterpret_cast(malloc(channel * sizeof(int32_t))); + conv_quant_arg_->right_shift_ = reinterpret_cast(malloc(channel * sizeof(int32_t))); + conv_quant_arg_->quant_multiplier_ = reinterpret_cast(malloc(channel * sizeof(int32_t))); + conv_quant_arg_->out_act_min_ = reinterpret_cast(malloc(channel * sizeof(int32_t))); + conv_quant_arg_->out_act_max_ = reinterpret_cast(malloc(channel * sizeof(int32_t))); + + weight_scale_ = reinterpret_cast(malloc(channel * sizeof(float))); + auto weight_tensor = in_tensors_.at(kWeightIndex); + if (weight_tensor->GetQuantParams().size() == kPerTensor) { + for (int i = 0; i < channel; i++) { + auto weight_quant_arg = weight_tensor->GetQuantParams().front(); + weight_scale_[i] = weight_quant_arg.scale; + } + } else { + for (int i = 0; i < channel; i++) { + auto weight_quant_arg = weight_tensor->GetQuantParams()[i]; + weight_scale_[i] = weight_quant_arg.scale; + } + } + + for (int i = 0; i < channel; ++i) { + const double in_scale = static_cast(input_scale_[i] * weight_scale_[i]); + double real_multiplier = in_scale / static_cast(output_scale_[i]); + conv_quant_arg_->real_multiplier_[i] = real_multiplier; + QuantizeRoundParameter(real_multiplier, &conv_quant_arg_->quant_multiplier_[i], &conv_quant_arg_->left_shift_[i], + &conv_quant_arg_->right_shift_[i]); + } + + // now only consider per tensor for output + bool relu = conv_param_->act_type_ == ActType_Relu; + bool relu6 = conv_param_->act_type_ == ActType_Relu6; + for (int i = 0; i < channel; ++i) { + CalculateActivationRangeQuantized(relu, relu6, output_zp_[i], output_scale_[i], + &conv_param_->conv_quant_arg_.out_act_min_[i], + &conv_param_->conv_quant_arg_.out_act_max_[i]); + } + return RET_OK; +} + int ConvolutionDepthwiseSWInt8CPUKernel::Init() { sliding = new (std::nothrow) SlidingWindowParam; if (sliding == nullptr) { @@ -107,13 +239,19 @@ int ConvolutionDepthwiseSWInt8CPUKernel::Init() { int ConvolutionDepthwiseSWInt8CPUKernel::ReSize() { ConvolutionBaseCPUKernel::Init(); - InitSlidingParamConvDw(sliding, conv_param_, C4NUM); + InitSlidingParamConvDw(sliding, conv_param_, C8NUM); auto ret = ConvolutionBaseCPUKernel::SetQuantParam(); if (ret != RET_OK) { MS_LOG(ERROR) << "Set quant param failed."; return ret; } + ret = ReinitQuantParam(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "reinit quant param failed."; + return ret; + } + ret = InitWeightBias(); if (ret != RET_OK) { MS_LOG(ERROR) << "Depthwise int8 InitWeightBias error!"; @@ -123,8 +261,8 @@ int ConvolutionDepthwiseSWInt8CPUKernel::ReSize() { } int ConvolutionDepthwiseSWInt8CPUKernel::Execute(int task_id) { - ConvDwSWInt8(packed_output_, packed_input_, packed_weight_, reinterpret_cast(bias_data_), conv_param_, - sliding, task_id); + ConvDwSWInt8(packed_output_, packed_input_, packed_weight_, reinterpret_cast(bias_data_), input_zp_, + output_zp_, conv_param_, sliding, task_id); return RET_OK; } @@ -157,7 +295,12 @@ int ConvolutionDepthwiseSWInt8CPUKernel::Run() { auto input_tensor = in_tensors_.at(kInputIndex); auto input_addr = reinterpret_cast(input_tensor->MutableData()); - PackDepthwiseInt8Input(input_addr, packed_input_, conv_param_); + if (need_align_) { + PackNHWCToNHWC8Int8(input_addr, packed_input_, conv_param_->output_batch_, + conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_); + } else { + packed_input_ = input_addr; + } auto output_addr = reinterpret_cast(out_tensors_.at(kOutputIndex)->MutableData()); if (!need_align_) { @@ -171,11 +314,11 @@ int ConvolutionDepthwiseSWInt8CPUKernel::Run() { } if (need_align_) { - PackNHWC4ToNHWCInt8(packed_output_, output_addr, conv_param_->output_batch_, + PackNHWC8ToNHWCInt8(packed_output_, output_addr, conv_param_->output_batch_, conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_); + context_->allocator->Free(packed_input_); context_->allocator->Free(packed_output_); } - context_->allocator->Free(packed_input_); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_slidewindow_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_slidewindow_int8.h index 634812dc70..ad01fe0668 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_slidewindow_int8.h +++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_depthwise_slidewindow_int8.h @@ -40,11 +40,21 @@ class ConvolutionDepthwiseSWInt8CPUKernel : public ConvolutionBaseCPUKernel { int Execute(int task_id); private: + int ReinitQuantParam(); + int ReinitFreeBefore(); + void FreeTmpQuant(); + SlidingWindowParam *sliding = nullptr; int16_t *packed_weight_ = nullptr; - int16_t *packed_input_ = nullptr; + int8_t *packed_input_ = nullptr; int8_t *packed_output_ = nullptr; bool need_align_ = false; + + int8_t *input_zp_ = nullptr; + float *input_scale_ = nullptr; + float *weight_scale_ = nullptr; + int32_t *output_zp_ = nullptr; + float *output_scale_ = nullptr; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.cc index 535f41b025..8bf54c7b9b 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/deconvolution_depthwise_int8.cc @@ -52,8 +52,8 @@ int DeconvolutionDepthwiseInt8CPUKernel::InitWeightBias() { MS_LOG(ERROR) << "Malloc buffer failed."; return RET_ERROR; } - PackDepthwiseInt8Weight(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(), - weight_tensor->Batch(), &(conv_param_->conv_quant_arg_)); + PackDeconvDepthwiseInt8Weight(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(), + weight_tensor->Batch(), &(conv_param_->conv_quant_arg_)); bias_data_ = reinterpret_cast(malloc(C4NUM * OC4 * sizeof(int32_t))); if (bias_data_ == nullptr) {