diff --git a/mindspore/lite/nnacl/assembly/arm32/MatmulInt8Opt.S b/mindspore/lite/nnacl/assembly/arm32/MatmulInt8Opt.S
index 58edd8e9fe..ba8fb6ed47 100644
--- a/mindspore/lite/nnacl/assembly/arm32/MatmulInt8Opt.S
+++ b/mindspore/lite/nnacl/assembly/arm32/MatmulInt8Opt.S
@@ -3,9 +3,9 @@

 .text
 .align 5
-.global MatmulInt8Neon32Opt
+.global MatmulInt8Opt
 #ifndef __APPLE__
-.type MatmulInt8Neon32Opt, %function
+.type MatmulInt8Opt, %function
 #endif

 //void MatmulInt8Neon32Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16,
@@ -16,7 +16,7 @@
 // #0: col, #4: deep16, #8: input_sums, #12: weight_bias, #16: act_min, #20: act_max, #24: out_zp
 // #28: multiplier, #32: left_shift, #36: right_shift, #40: stride, #44: per_channel, #48: filter_zp

-MatmulInt8Neon32Opt:
+MatmulInt8Opt:
     push {r0-r11, lr}
     vpush {q4-q7}
     add sp, sp, #116
diff --git a/mindspore/lite/nnacl/assembly/arm64/MatmulInt8Opt.S b/mindspore/lite/nnacl/assembly/arm64/MatmulInt8Opt.S
index 7ad867385d..90da4924ac 100644
--- a/mindspore/lite/nnacl/assembly/arm64/MatmulInt8Opt.S
+++ b/mindspore/lite/nnacl/assembly/arm64/MatmulInt8Opt.S
@@ -1,21 +1,21 @@
 #ifdef __aarch64__
 .text
 .align 5
-    .global MatmulInt8Neon64Opt
+    .global MatmulInt8Opt
 #ifndef __APPLE__
-    .type MatmulInt8Neon64Opt, %function
+    .type MatmulInt8Opt, %function
 #endif

-//void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
+//void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
 //                   const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift,
-//                   int32_t *right_shift, int row, int col, int stride, int filter_peroc, int32_t *filter_zp)
+//                   int32_t *right_shift, int stride, int filter_peroc, int32_t *filter_zp)

 // x0: a(left matrix ptr)
 // x1: b(right matrix ptr)
 // x2: out ptr
-// w3: row4
-// w4: col4
-// w5: deep16
+// x3: row4
+// x4: col4
+// x5: deep16
 // x6: a_sums
 // x7: bias
 // w8: act_min
@@ -24,385 +24,318 @@
 // x11: multiplier
 // x12: left_shift
 // x13: right_shift
-// w14: row
-// w15: col
-// w24: stride
-// w27: filter_peroc
+// x14: stride
+// x15: filter_peroc
 // x28: filter_zp

-MatmulInt8Neon64Opt:
-    sub sp, sp, #208
-    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
-    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
-    stp x19, x20, [sp], #16
-    stp x21, x22, [sp], #16
-    stp x23, x24, [sp], #16
-    stp x25, x26, [sp], #16
-    stp x27, x28, [sp], #16
-
-    ldr w8, [sp]
-    ldr w9, [sp, #8]
-    ldr w10, [sp, #16]
-    ldr x11, [sp, #24]
-    ldr x12, [sp, #32]
-    ldr x13, [sp, #40]
-    ldr w14, [sp, #48]
-    ldr w15, [sp, #56]
-    ldr w24, [sp, #64]
-    ldr w27, [sp, #72]
-    ldr x28, [sp, #80]
-
-    mov w17, #4 // sizeof(int8)*4
-    mul w21, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16
-    mov w17, #1
-    mov x25, x2
-L1:
-    cmp w4, #0 // if at the end of col4
-    beq End1
-
-    mov w16, w3 // reset a row4 counter
-    mov w23, w14 // reset a row counter
-    mov x17, x0 // reload a ptr
-    mov x22, x6 // reload a_sums ptr
-L2:
-    cmp w16, #0
-    beq End2
-
-    mov x18, x1 // reload b ptr
-    mov x19, x7 // reload bias ptr
-    mov w20, w5 // reload depth
-    dup v16.4s, wzr
-    dup v17.4s, wzr
-    dup v18.4s, wzr
-    dup v19.4s, wzr
-    dup v20.4s, wzr
-    dup v21.4s, wzr
-    dup v22.4s, wzr
-    dup v23.4s, wzr
-    dup v24.4s, wzr
-    dup v25.4s, wzr
-    dup v26.4s, wzr
-    dup v27.4s, wzr
-    dup v28.4s, wzr
-    dup v29.4s, wzr
-    dup v30.4s, wzr
-    dup v31.4s, wzr
-L3:
-    cmp w20, #0
-    beq End3
-
-    ld1 {v0.16b}, [x17], #16
-    ld1 {v1.16b}, [x17], #16
-    ld1 {v2.16b}, [x17], #16
-    ld1 {v3.16b}, [x17], #16
-    ld1 {v4.16b}, [x18], #16
-    ld1 {v5.16b}, [x18], #16
-    ld1 {v6.16b}, [x18], #16
-    ld1 {v7.16b}, [x18], #16
-
-    smull v8.8h, v4.8b, v0.8b
-    smull v9.8h, v5.8b, v0.8b
-    smull v10.8h, v6.8b, v0.8b
-    smull v11.8h, v7.8b, v0.8b
-    smull v12.8h, v4.8b, v1.8b
-    smull v13.8h, v5.8b, v1.8b
-    smull v14.8h, v6.8b, v1.8b
-    smull v15.8h, v7.8b, v1.8b
-
-    smlal2 v8.8h, v4.16b, v0.16b
-    smlal2 v9.8h, v5.16b, v0.16b
-    smlal2 v10.8h, v6.16b, v0.16b
-    smlal2 v11.8h, v7.16b, v0.16b
-    smlal2 v12.8h, v4.16b, v1.16b
-    smlal2 v13.8h, v5.16b, v1.16b
-    smlal2 v14.8h, v6.16b, v1.16b
-    smlal2 v15.8h, v7.16b, v1.16b
-
-    sadalp v16.4s, v8.8h
-    sadalp v17.4s, v9.8h
-    sadalp v18.4s, v10.8h
-    sadalp v19.4s, v11.8h
-    sadalp v20.4s, v12.8h
-    sadalp v21.4s, v13.8h
-    sadalp v22.4s, v14.8h
-    sadalp v23.4s, v15.8h
-
-    smull v8.8h, v4.8b, v2.8b
-    smull v9.8h, v5.8b, v2.8b
-    smull v10.8h, v6.8b, v2.8b
-    smull v11.8h, v7.8b, v2.8b
-    smull v12.8h, v4.8b, v3.8b
-    smull v13.8h, v5.8b, v3.8b
-    smull v14.8h, v6.8b, v3.8b
-    smull v15.8h, v7.8b, v3.8b
-
-    smlal2 v8.8h, v4.16b, v2.16b
-    smlal2 v9.8h, v5.16b, v2.16b
-    smlal2 v10.8h, v6.16b, v2.16b
-    smlal2 v11.8h, v7.16b, v2.16b
-    smlal2 v12.8h, v4.16b, v3.16b
-    smlal2 v13.8h, v5.16b, v3.16b
-    smlal2 v14.8h, v6.16b, v3.16b
-    smlal2 v15.8h, v7.16b, v3.16b
-
-    sadalp v24.4s, v8.8h
-    sadalp v25.4s, v9.8h
-    sadalp v26.4s, v10.8h
-    sadalp v27.4s, v11.8h
-    sadalp v28.4s, v12.8h
-    sadalp v29.4s, v13.8h
-    sadalp v30.4s, v14.8h
-    sadalp v31.4s, v15.8h
-    subs w20, w20, #16 // depth + 16
-    b L3
-
-End3:
-    addp v16.4s, v16.4s, v17.4s
-    addp v18.4s, v18.4s, v19.4s
-    addp v20.4s, v20.4s, v21.4s
-    addp v22.4s, v22.4s, v23.4s
-    addp v24.4s, v24.4s, v25.4s
-    addp v26.4s, v26.4s, v27.4s
-    addp v28.4s, v28.4s, v29.4s
-    addp v30.4s, v30.4s, v31.4s
-
-    addp v16.4s, v16.4s, v18.4s
-    addp v17.4s, v20.4s, v22.4s
-    addp v18.4s, v24.4s, v26.4s
-    addp v19.4s, v28.4s, v30.4s
-
-    // Add (Bias+Depth*Za*Zb-Za*Bsums)
-    ld1 {v15.4s}, [x19], #16
-    add v16.4s, v16.4s, v15.4s
-    add v17.4s, v17.4s, v15.4s
-    add v18.4s, v18.4s, v15.4s
-    add v19.4s, v19.4s, v15.4s
-
-    ld1r {v20.4s}, [x22], #4
-    ld1r {v21.4s}, [x22], #4
-    ld1r {v22.4s}, [x22], #4
-    ld1r {v23.4s}, [x22], #4
-    cmp w27, #0
-    beq Apply
-    ld1 {v14.4s}, [x28]
-    mul v20.4s, v20.4s, v14.4s
-    mul v21.4s, v21.4s, v14.4s
-    mul v22.4s, v22.4s, v14.4s
-    mul v23.4s, v23.4s, v14.4s
-
-Apply:
-    // Subtract (Asums*Zb)
-    sub v16.4s, v16.4s, v20.4s
-    sub v17.4s, v17.4s, v21.4s
-    sub v18.4s, v18.4s, v22.4s
-    sub v19.4s, v19.4s, v23.4s
-
-    cmp w27, #1
-    beq PerCLoad
-
-    ld1r {v13.4s}, [x12]
-    ld1r {v12.4s}, [x11]
-    ld1r {v11.4s}, [x13]
-    b Quantize
-
-PerCLoad:
-    ld1 {v13.4s}, [x12]
-    ld1 {v12.4s}, [x11]
-    ld1 {v11.4s}, [x13]
-Quantize:
-
-    // Apply left shift
-    sqshl v16.4s, v16.4s, v13.4s
-    sqshl v17.4s, v17.4s, v13.4s
-    sqshl v18.4s, v18.4s, v13.4s
-    sqshl v19.4s, v19.4s, v13.4s
-
-    // Apply the fixed-point part of the multiplier.
-    sqrdmulh v16.4s, v16.4s, v12.4s
-    sqrdmulh v17.4s, v17.4s, v12.4s
-    sqrdmulh v18.4s, v18.4s, v12.4s
-    sqrdmulh v19.4s, v19.4s, v12.4s
-
-    // Apply right shift
-    and v20.16b, v11.16b, v16.16b
-    sshr v20.4s, v20.4s, #31
-    sqadd v16.4s, v16.4s, v20.4s
-    srshl v16.4s, v16.4s, v11.4s
-    and v21.16b, v11.16b, v17.16b
-    sshr v21.4s, v21.4s, #31
-    sqadd v17.4s, v17.4s, v21.4s
-    srshl v17.4s, v17.4s, v11.4s
-    and v22.16b, v11.16b, v18.16b
-    sshr v22.4s, v22.4s, #31
-    sqadd v18.4s, v18.4s, v22.4s
-    srshl v18.4s, v18.4s, v11.4s
-    and v23.16b, v11.16b, v19.16b
-    sshr v23.4s, v23.4s, #31
-    sqadd v19.4s, v19.4s, v23.4s
-    srshl v19.4s, v19.4s, v11.4s
-
-    // Add the destination zero point
-    dup v10.4s, w10
-    add v16.4s, v16.4s, v10.4s
-    add v17.4s, v17.4s, v10.4s
-    add v18.4s, v18.4s, v10.4s
-    add v19.4s, v19.4s, v10.4s
-
-    // Apply the act_min bound
-    dup v9.4s, w8
-    smax v16.4s, v16.4s, v9.4s
-    smax v17.4s, v17.4s, v9.4s
-    smax v18.4s, v18.4s, v9.4s
-    smax v19.4s, v19.4s, v9.4s
-
-    // Apply the act_min bound
-    dup v8.4s, w9
-    smin v16.4s, v16.4s, v8.4s
-    smin v17.4s, v17.4s, v8.4s
-    smin v18.4s, v18.4s, v8.4s
-    smin v19.4s, v19.4s, v8.4s
-
-    // int32 -> int16
-    sqxtn v13.4h, v16.4s
-    sqxtn2 v13.8h, v17.4s
-    sqxtn v14.4h, v18.4s
-    sqxtn2 v14.8h, v19.4s
-
-    // int16 -> int8
-    sqxtn v15.8b, v13.8h
-    sqxtn2 v15.16b, v14.8h
-
-    cmp w23, #4
-    blt Write // if rows < 4
-    cmp w15, #4
-    blt Write // if cols < 4
-
-    st1 {v15.s}[0], [x2], x24
-    st1 {v15.s}[1], [x2], x24
-    st1 {v15.s}[2], [x2], x24
-    st1 {v15.s}[3], [x2], x24
-    b Endwrite
-
-Write:
-    cmp w15, #4
-    beq WriteCol4
-    cmp w15, #3
-    beq WriteCol3
-    cmp w15, #2
-    beq WriteCol2
-    cmp w15, #1
-    beq WriteCol1
-
-WriteCol4:
-    st1 {v15.s}[0], [x2], x24
-    cmp w23, #1
-    beq Endwrite
-    st1 {v15.s}[1], [x2], x24
-    cmp w23, #2
-    beq Endwrite
-    st1 {v15.s}[2], [x2], x24
-    cmp w23, #3
-    beq Endwrite
-    st1 {v15.s}[3], [x2], x24
-    b Endwrite
-
-WriteCol3:
-    mov x26, x2
-    st1 {v15.b}[0], [x26], #1
-    st1 {v15.b}[1], [x26], #1
-    st1 {v15.b}[2], [x26], #1
-    add x2, x2, x24
-    cmp w23, #1
-    beq Endwrite
-    mov x26, x2
-    st1 {v15.b}[4], [x26], #1
-    st1 {v15.b}[5], [x26], #1
-    st1 {v15.b}[6], [x26], #1
-    add x2, x2, x24
-    cmp w23, #2
-    beq Endwrite
-    mov x26, x2
-    st1 {v15.b}[8], [x26], #1
-    st1 {v15.b}[9], [x26], #1
-    st1 {v15.b}[10], [x26], #1
-    add x2, x2, x24
-    cmp w23, #3
-    beq Endwrite
-    mov x26, x2
-    st1 {v15.b}[12], [x26], #1
-    st1 {v15.b}[13], [x26], #1
-    st1 {v15.b}[14], [x26], #1
-    add x2, x2, x24
-    b Endwrite
-
-WriteCol2:
-    mov x26, x2
-    st1 {v15.b}[0], [x26], #1
-    st1 {v15.b}[1], [x26], #1
-    add x2, x2, x24
-    cmp w23, #1
-    beq Endwrite
-    mov x26, x2
-    st1 {v15.b}[4], [x26], #1
-    st1 {v15.b}[5], [x26], #1
-    add x2, x2, x24
-    cmp w23, #2
-    beq Endwrite
-    mov x26, x2
-    st1 {v15.b}[8], [x26], #1
-    st1 {v15.b}[9], [x26], #1
-    add x2, x2, x24
-    cmp w23, #3
-    beq Endwrite
-    mov x26, x2
-    st1 {v15.b}[12], [x26], #1
-    st1 {v15.b}[13], [x26], #1
-    add x2, x2, x24
-    b Endwrite
-
-WriteCol1:
-    st1 {v15.b}[0], [x2], x24
-    cmp w23, #1
-    beq Endwrite
-    st1 {v15.b}[4], [x2], x24
-    cmp w23, #2
-    beq Endwrite
-    st1 {v15.b}[8], [x2], x24
-    cmp w23, #3
-    beq Endwrite
-    st1 {v15.b}[12], [x2], x24
-    b Endwrite
-
-Endwrite:
-    sub w16, w16, #4 // a row4 counter - 4
-    sub w23, w23, #4 // a row counter - 4
-    b L2
-
-End2:
-    sub w4, w4, #4 // b col4 counter - 4
-    sub w15, w15, #4 // b col counter - 4
-    add x1, x1, x21 // b ptr + stride
-    add x7, x7, #16 // bias ptr + stride
-    add x25, x25, #4 // output + stride(4 * sizeof(int8))
-    mov x2, x25
-
-    cmp w27, #0
-    beq PerTEnd2
-    add x12, x12, #16
-    add x11, x11, #16
-    add x13, x13, #16
-    add x28, x28, #16
-PerTEnd2:
-    b L1
-
-End1:
-    sub sp, sp, #208
-    ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
-    ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
-    ldp x19, x20, [sp], #16
-    ldp x21, x22, [sp], #16
-    ldp x23, x24, [sp], #16
-    ldp x25, x26, [sp], #16
-    ldp x27, x28, [sp], #16
-    ret
+MatmulInt8Opt:
+    sub sp, sp, #208
+    st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
+    st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
+    stp x19, x20, [sp], #16
+    stp x21, x22, [sp], #16
+    stp x23, x24, [sp], #16
+    stp x25, x26, [sp], #16
+    stp x27, x28, [sp], #16
+
+    ldr w8, [sp]
+    ldr w9, [sp, #8]
+    ldr w10, [sp, #16]
+    ldr x11, [sp, #24]
+    ldr x12, [sp, #32]
+    ldr x13, [sp, #40]
+    ldr x14, [sp, #48]
+    ldr x15, [sp, #56]
+
+    mov x23, #4
+    mul x23, x23, x5 // lhs step
+    mov x24, #4
+    mul x24, x24, x14 // dst step
+LoopRow:
+    mov x16, x1 // reload rhs ptr
+    mov x17, x4 // reload rhs col
+    mov x18, x7 // reload bias ptr
+    mov x27, x2 // reload dst ptr
+    ldr x28, [sp, #64] // reload filter_zp
+
+    LoopCol:
+        mov x25, x6 // reload a_sums ptr
+        mov x19, x27 // reload dst ptr
+        mov x20, x0 // reload lhs ptr
+        mov x21, x5 // reload depth
+
+        dup v16.4s, wzr
+        dup v17.4s, wzr
+        dup v18.4s, wzr
+        dup v19.4s, wzr
+        dup v20.4s, wzr
+        dup v21.4s, wzr
+        dup v22.4s, wzr
+        dup v23.4s, wzr
+        dup v24.4s, wzr
+        dup v25.4s, wzr
+        dup v26.4s, wzr
+        dup v27.4s, wzr
+        dup v28.4s, wzr
+        dup v29.4s, wzr
+        dup v30.4s, wzr
+        dup v31.4s, wzr
+
+        LoopDepth:
+            ld1 {v0.16b, v1.16b}, [x20], #32
+            ld1 {v4.16b, v5.16b}, [x16], #32
+            smull v8.8h, v4.8b, v0.8b
+            smull v9.8h, v5.8b, v0.8b
+            smull v12.8h, v4.8b, v1.8b
+            smull v13.8h, v5.8b, v1.8b
+            ld1 {v6.16b, v7.16b}, [x16], #32
+            smlal2 v8.8h, v4.16b, v0.16b
+            smlal2 v9.8h, v5.16b, v0.16b
+            smlal2 v12.8h, v4.16b, v1.16b
+            smlal2 v13.8h, v5.16b, v1.16b
+            ld1 {v2.16b, v3.16b}, [x20], #32
+            smull v10.8h, v6.8b, v0.8b
+            smull v11.8h, v7.8b, v0.8b
+            smull v14.8h, v6.8b, v1.8b
+            smull v15.8h, v7.8b, v1.8b
+            smlal2 v10.8h, v6.16b, v0.16b
+            smlal2 v11.8h, v7.16b, v0.16b
+            smlal2 v14.8h, v6.16b, v1.16b
+            smlal2 v15.8h, v7.16b, v1.16b
+
+            sadalp v16.4s, v8.8h
+            sadalp v17.4s, v9.8h
+            sadalp v18.4s, v10.8h
+            sadalp v19.4s, v11.8h
+            sadalp v20.4s, v12.8h
+            sadalp v21.4s, v13.8h
+            sadalp v22.4s, v14.8h
+            sadalp v23.4s, v15.8h
+
+            smull v8.8h, v4.8b, v2.8b
+            smull v9.8h, v5.8b, v2.8b
+            smull v10.8h, v6.8b, v2.8b
+            smull v11.8h, v7.8b, v2.8b
+            smull v12.8h, v4.8b, v3.8b
+            smull v13.8h, v5.8b, v3.8b
+            smull v14.8h, v6.8b, v3.8b
+            smull v15.8h, v7.8b, v3.8b
+
+            smlal2 v8.8h, v4.16b, v2.16b
+            smlal2 v9.8h, v5.16b, v2.16b
+            smlal2 v10.8h, v6.16b, v2.16b
+            smlal2 v11.8h, v7.16b, v2.16b
+            smlal2 v12.8h, v4.16b, v3.16b
+            smlal2 v13.8h, v5.16b, v3.16b
+            smlal2 v14.8h, v6.16b, v3.16b
+            smlal2 v15.8h, v7.16b, v3.16b
+
+            sadalp v24.4s, v8.8h
+            sadalp v25.4s, v9.8h
+            sadalp v26.4s, v10.8h
+            sadalp v27.4s, v11.8h
+            sadalp v28.4s, v12.8h
+            sadalp v29.4s, v13.8h
+            sadalp v30.4s, v14.8h
+            sadalp v31.4s, v15.8h
+            subs x21, x21, #16 // depth - 16
+            bgt LoopDepth
+
+        addp v16.4s, v16.4s, v17.4s
+        addp v18.4s, v18.4s, v19.4s
+        addp v20.4s, v20.4s, v21.4s
+        addp v22.4s, v22.4s, v23.4s
+        addp v24.4s, v24.4s, v25.4s
+        addp v26.4s, v26.4s, v27.4s
+        addp v28.4s, v28.4s, v29.4s
+        addp v30.4s, v30.4s, v31.4s
+
+        addp v16.4s, v16.4s, v18.4s
+        addp v17.4s, v20.4s, v22.4s
+        addp v18.4s, v24.4s, v26.4s
+        addp v19.4s, v28.4s, v30.4s
+
+        Bias:
+            cbz x7, NoBias
+            ld1 {v15.4s}, [x18], #16
+            add v16.4s, v16.4s, v15.4s
+            add v17.4s, v17.4s, v15.4s
+            add v18.4s, v18.4s, v15.4s
+            add v19.4s, v19.4s, v15.4s
+
+        NoBias:
+            ld1r {v20.4s}, [x25], #4
+            ld1r {v21.4s}, [x25], #4
+            ld1r {v22.4s}, [x25], #4
+            ld1r {v23.4s}, [x25], #4
+            cbz x15, ApplySum
+
+            ld1 {v14.4s}, [x28], #16
+            mul v20.4s, v20.4s, v14.4s
+            mul v21.4s, v21.4s, v14.4s
+            mul v22.4s, v22.4s, v14.4s
+            mul v23.4s, v23.4s, v14.4s
+
+        ApplySum:
+            sub v16.4s, v16.4s, v20.4s
+            sub v17.4s, v17.4s, v21.4s
+            sub v18.4s, v18.4s, v22.4s
+            sub v19.4s, v19.4s, v23.4s
+
+            cbnz x15, PerCLoad
+
+            ld1r {v13.4s}, [x12]
+            ld1r {v12.4s}, [x11]
+            ld1r {v11.4s}, [x13]
+            b Quantize
+
+        PerCLoad:
+            ld1 {v13.4s}, [x12], #16
+            ld1 {v12.4s}, [x11], #16
+            ld1 {v11.4s}, [x13], #16
+
+        Quantize:
+            sqshl v16.4s, v16.4s, v13.4s
+            sqshl v17.4s, v17.4s, v13.4s
+            sqshl v18.4s, v18.4s, v13.4s
+            sqshl v19.4s, v19.4s, v13.4s
+
+            sqrdmulh v16.4s, v16.4s, v12.4s
+            sqrdmulh v17.4s, v17.4s, v12.4s
+            sqrdmulh v18.4s, v18.4s, v12.4s
+            sqrdmulh v19.4s, v19.4s, v12.4s
+
+            and v20.16b, v11.16b, v16.16b
+            sshr v20.4s, v20.4s, #31
+            sqadd v16.4s, v16.4s, v20.4s
+            srshl v16.4s, v16.4s, v11.4s
+            and v21.16b, v11.16b, v17.16b
+            sshr v21.4s, v21.4s, #31
+            sqadd v17.4s, v17.4s, v21.4s
+            srshl v17.4s, v17.4s, v11.4s
+            and v22.16b, v11.16b, v18.16b
+            sshr v22.4s, v22.4s, #31
+            sqadd v18.4s, v18.4s, v22.4s
+            srshl v18.4s, v18.4s, v11.4s
+            and v23.16b, v11.16b, v19.16b
+            sshr v23.4s, v23.4s, #31
+            sqadd v19.4s, v19.4s, v23.4s
+            srshl v19.4s, v19.4s, v11.4s
+
+            dup v10.4s, w10
+            add v16.4s, v16.4s, v10.4s
+            add v17.4s, v17.4s, v10.4s
+            add v18.4s, v18.4s, v10.4s
+            add v19.4s, v19.4s, v10.4s
+
+            dup v9.4s, w8
+            smax v16.4s, v16.4s, v9.4s
+            smax v17.4s, v17.4s, v9.4s
+            smax v18.4s, v18.4s, v9.4s
+            smax v19.4s, v19.4s, v9.4s
+
+            dup v8.4s, w9
+            smin v16.4s, v16.4s, v8.4s
+            smin v17.4s, v17.4s, v8.4s
+            smin v18.4s, v18.4s, v8.4s
+            smin v19.4s, v19.4s, v8.4s
+
+            sqxtn v13.4h, v16.4s
+            sqxtn2 v13.8h, v17.4s
+            sqxtn v14.4h, v18.4s
+            sqxtn2 v14.8h, v19.4s
+
+            sqxtn v15.8b, v13.8h
+            sqxtn2 v15.16b, v14.8h
+
+            cmp x17, #1
+            beq Write1
+            cmp x17, #2
+            beq Write2
+            cmp x17, #3
+            beq Write3
+            b Write4
+
+        Write1:
+            add x27, x27, #1
+            st1 {v15.b}[0], [x19], x14
+            cmp x3, #1
+            beq WriteEnd
+            st1 {v15.b}[4], [x19], x14
+            cmp x3, #2
+            beq WriteEnd
+            st1 {v15.b}[8], [x19], x14
+            cmp x3, #3
+            beq WriteEnd
+            st1 {v15.b}[12], [x19], x14
+            b WriteEnd
+        Write2:
+            add x27, x27, #2
+            st1 {v15.h}[0], [x19], x14
+            cmp x3, #1
+            beq WriteEnd
+            st1 {v15.h}[2], [x19], x14
+            cmp x3, #2
+            beq WriteEnd
+            st1 {v15.h}[4], [x19], x14
+            cmp x3, #3
+            beq WriteEnd
+            st1 {v15.h}[6], [x19], x14
+            b WriteEnd
+        Write3:
+            add x27, x27, #3
+            add x22, x19, #2
+            st1 {v15.h}[0], [x19], x14
+            st1 {v15.b}[2], [x22], x14
+            cmp x3, #1
+            beq WriteEnd
+            st1 {v15.h}[2], [x19], x14
+            st1 {v15.b}[6], [x22], x14
+            cmp x3, #2
+            beq WriteEnd
+            st1 {v15.h}[4], [x19], x14
+            st1 {v15.b}[10], [x22], x14
+            cmp x3, #3
+            beq WriteEnd
+            st1 {v15.h}[6], [x19], x14
+            st1 {v15.b}[14], [x22], x14
+            b WriteEnd
+        Write4:
+            add x27, x27, #4
+            st1 {v15.s}[0], [x19], x14
+            cmp x3, #1
+            beq WriteEnd
+            st1 {v15.s}[1], [x19], x14
+            cmp x3, #2
+            beq WriteEnd
+            st1 {v15.s}[2], [x19], x14
+            cmp x3, #3
+            beq WriteEnd
+            st1 {v15.s}[3], [x19], x14
+
+        WriteEnd:
+            subs x17, x17, #4
+            bgt LoopCol
+
+LoopColEnd:
+    subs x3, x3, #4
+    ble LoopRowEnd
+    ldr x11, [sp, #24]
+    ldr x12, [sp, #32]
+    ldr x13, [sp, #40]
+    add x6, x6, #16
+    add x0, x0, x23
+    add x2, x2, x24
+    b LoopRow
+
+LoopRowEnd:
+    sub sp, sp, #208
+    ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
+    ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
+    ldp x19, x20, [sp], #16
+    ldp x21, x22, [sp], #16
+    ldp x23, x24, [sp], #16
+    ldp x25, x26, [sp], #16
+    ldp x27, x28, [sp], #16
+    ret
 #endif
diff --git a/mindspore/lite/nnacl/int8/conv_int8.c b/mindspore/lite/nnacl/int8/conv_int8.c
index 07dbeeb25f..b54f7e4507 100644
--- a/mindspore/lite/nnacl/int8/conv_int8.c
+++ b/mindspore/lite/nnacl/int8/conv_int8.c
@@ -815,22 +815,10 @@ void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t
                  const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
                  int32_t *multiplier, ConvParameter *conv_param, int32_t *filter_zp) {
   int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1;
-#ifdef ENABLE_ARM32
-  MatmulInt8Neon32Opt(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias,
-                      conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
-                      conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift,
-                      conv_param->output_channel_, is_per_oc, filter_zp);
-#elif ENABLE_ARM64
-  MatmulInt8Neon64Opt(packed_input, packed_weight, dst, UP_ROUND(row, C4NUM), UP_ROUND(col, C4NUM), deep16, input_sum,
-                      bias, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
-                      conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift, row,
-                      col, conv_param->output_channel_, is_per_oc, filter_zp);
-#else
   MatmulInt8Opt(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias,
                 conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],
                 conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift,
                 conv_param->output_channel_, is_per_oc, filter_zp);
-#endif
   return;
 }

diff --git a/mindspore/lite/nnacl/int8/matmul_int8.c b/mindspore/lite/nnacl/int8/matmul_int8.c
index abf898a8db..0382fb53f6 100644
--- a/mindspore/lite/nnacl/int8/matmul_int8.c
+++ b/mindspore/lite/nnacl/int8/matmul_int8.c
@@ -253,7 +253,7 @@ void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
 #ifndef ENABLE_ARM
 void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
                    const int *bias, int mini, int maxi, int out_zp, int32_t *multiplier, int32_t *left_shift,
-                   int32_t *right_shift, int stride, int filter_peroc, int32_t *filter_zp) {
+                   int32_t *right_shift, size_t stride, size_t filter_peroc, int32_t *filter_zp) {
   int col_tile = C4NUM;
   /* support per-layer && weight per-channel */
   /* row4x16-major * row16x2-major => (int8)row-major*/
diff --git a/mindspore/lite/nnacl/int8/matmul_int8.h b/mindspore/lite/nnacl/int8/matmul_int8.h
index 2ea3b303e4..3a0b008fbd 100644
--- a/mindspore/lite/nnacl/int8/matmul_int8.h
+++ b/mindspore/lite/nnacl/int8/matmul_int8.h
@@ -62,7 +62,7 @@ void MatMulInt8_4x16_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row
                        size_t per_channel, int32_t *filter_zp);
 void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
                    const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift,
-                   int32_t *right_shift, int stride, int filter_peroc, int32_t *filter_zp);
+                   int32_t *right_shift, size_t stride, size_t filter_peroc, int32_t *filter_zp);

 #ifdef ENABLE_ARM64
 void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
@@ -71,18 +71,11 @@ void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, i
 void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16,
                         const int *input_sum, const int *bias);
-void MatmulInt8Neon64Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16,
-                         const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier,
-                         int32_t *left_shift, int32_t *right_shift, int row, int col, int stride, int filter_peroc,
-                         int32_t *filter_zp);
 #endif

 #ifdef ENABLE_ARM32
 void MatmulInt8Neon32(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16,
                       const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp,
                       int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel);
-void MatmulInt8Neon32Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
-                         const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier,
-                         int32_t *left_shift, int32_t *right_shift, int stride, int filter_peroc, int32_t *filter_zp);
 #endif

 #ifdef __cplusplus
 }
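
Note: after this patch, Conv1x1Int8 calls the single MatmulInt8Opt entry point on every platform; the C fallback in matmul_int8.c and the renamed ARM32/ARM64 assembly kernels share the one prototype declared in matmul_int8.h. A minimal call-site sketch for the per-layer quantized case follows (the wrapper name and the act_min/act_max/out_zp values are illustrative placeholders, not taken from the patch; the include path assumes the usual nnacl layout):

#include "nnacl/int8/matmul_int8.h"

/* Hypothetical wrapper: multiplies a packed 4x16-tiled int8 LHS by a packed RHS and writes an
 * int8 row-major result with per-layer quantization (filter_peroc = 0). "stride" is the dst
 * row stride in elements, mirroring conv_param->output_channel_ in Conv1x1Int8 above. */
void MatmulInt8Example(const int8_t *packed_a, const int8_t *packed_b, int8_t *dst, const int *a_sums,
                       const int *bias, int row, int col, int deep16, int32_t *multiplier,
                       int32_t *left_shift, int32_t *right_shift, int out_channel, int32_t *filter_zp) {
  MatmulInt8Opt(packed_a, packed_b, dst, row, col, deep16, a_sums, bias,
                /*act_min=*/-128, /*act_max=*/127, /*out_zp=*/0, multiplier, left_shift, right_shift,
                /*stride=*/out_channel, /*filter_peroc=*/0, filter_zp);
}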