diff --git a/mindspore/lite/nnacl/assembly/arm32/MatmulInt8Opt.S b/mindspore/lite/nnacl/assembly/arm32/MatmulInt8Opt.S new file mode 100644 index 0000000000..a46d615b5b --- /dev/null +++ b/mindspore/lite/nnacl/assembly/arm32/MatmulInt8Opt.S @@ -0,0 +1,299 @@ +#ifdef __arm__ +#ifndef __aarch64__ + +.text +.align 5 +.global MatmulInt8Neon32Opt +#ifndef __APPLE__ +.type MatmulInt8Neon32Opt, %function +#endif + +//void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, +// const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp, +// int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel, +// int *filter_zp); +// #-52: a, #-48: b, #-44: dst, #-40: row +// #0: col, #4: deep16, #8: input_sums, #12: weight_bias, #16: act_min, #20: act_max, #24: out_zp +// #28: multiplier, #32: left_shift, #36: right_shift, #40: stride, #44: per_channel, #48: filter_zp + +MatmulInt8Neon32Opt: + push {r0-r11, lr} + vpush {q4-q7} + add sp, sp, #116 + + ldr r4, [sp] // col + ldr r7, [sp, #40] // output stride + mov r8, #0 // output channels offset + ldr r10, [sp, #44] + cmp r10, #0 + beq L1 + ldr r6, [sp, #8] // load intpu_sums ptr if per_channel +L1: + cmp r4, #0 // if at the end of col + ble End1 + + ldr r0, [sp, #-52] // reload a ptr + ldr r3, [sp, #-40] // reset row counter + ldr r6, [sp, #8] // reload intpu_sums ptr if per_tensor +L2: + cmp r3, #0 // if at the end of row + ble End2 + + ldr r1, [sp, #-48] // reload b ptr + ldr r5, [sp, #4] // reset deep16 + vmov.i32 q6, #0 + vmov.i32 q7, #0 + vmov.i32 q8, #0 + vmov.i32 q9, #0 + vmov.i32 q10, #0 + vmov.i32 q11, #0 + vmov.i32 q12, #0 + vmov.i32 q13, #0 +L3: + cmp r5, #0 + beq End3 + + vld1.8 {d0, d1, d2, d3}, [r0]! + vld1.8 {d8, d9, d10, d11}, [r1]! + vmull.s8 q14, d0, d8 + vmull.s8 q2, d0, d10 + vmull.s8 q15, d2, d8 + vmull.s8 q3, d2, d10 + vmlal.s8 q14, d1, d9 + vmlal.s8 q2, d1, d11 + vmlal.s8 q15, d3, d9 + vmlal.s8 q3, d3, d11 + + vpadal.s16 q6, q14 + vpadal.s16 q7, q2 + vpadal.s16 q8, q15 + vpadal.s16 q9, q3 + + vld1.8 {d0, d1, d2, d3}, [r0]! + vmull.s8 q14, d0, d8 + vmull.s8 q2, d0, d10 + vmull.s8 q15, d2, d8 + vmull.s8 q3, d2, d10 + vmlal.s8 q14, d1, d9 + vmlal.s8 q2, d1, d11 + vmlal.s8 q15, d3, d9 + vmlal.s8 q3, d3, d11 + + vpadal.s16 q10, q14 + vpadal.s16 q11, q2 + vpadal.s16 q12, q15 + vpadal.s16 q13, q3 + sub r5, r5, #16 // deep16 -= 16 + b L3 + +End3: + vpadd.i32 d0, d12, d13 + vpadd.i32 d1, d14, d15 + vpadd.i32 d2, d16, d17 + vpadd.i32 d3, d18, d19 + vpadd.i32 d4, d20, d21 + vpadd.i32 d5, d22, d23 + vpadd.i32 d6, d24, d25 + vpadd.i32 d7, d26, d27 + + vpadd.i32 d28, d0, d1 + vpadd.i32 d29, d2, d3 + vpadd.i32 d30, d4, d5 + vpadd.i32 d31, d6, d7 + + // Add weight_bias + ldr r9, [sp, #12] // reload weight_bias ptr + add r9, r9, r8 + vld1.32 {d26}, [r9]! + vadd.i32 d28, d28, d26 + vadd.i32 d29, d29, d26 + vadd.i32 d30, d30, d26 + vadd.i32 d31, d31, d26 + + ldr r10, [sp, #44] + cmp r10, #0 + bgt PerChannel + +PerTensor: + // Substract input_sums + vld1.32 {d24, d25}, [r6]! + vdup.32 d20, d24[0] + vdup.32 d21, d24[1] + vdup.32 d22, d25[0] + vdup.32 d23, d25[1] + vsub.s32 d28, d28, d20 + vsub.s32 d29, d29, d21 + vsub.s32 d30, d30, d22 + vsub.s32 d31, d31, d23 + + // Apply left shift + ldr r10, [sp, #32] + ldr r11, [r10]! 
+ vdup.32 q9, r11 + vshl.s32 q14, q14, q9 + vshl.s32 q15, q15, q9 + + // Apply the fixed-point part of the multiplier + ldr r10, [sp, #28] + ldr r11, [r10] + vdup.32 q8, r11 + vqrdmulh.s32 q14, q14, q8 + vqrdmulh.s32 q15, q15, q8 + + // Apply right shift + ldr r10, [sp, #36] + ldr r11, [r10] + vdup.32 q7, r11 + vand q6, q7, q14 + vshr.s32 q6, q6, #31 + vqadd.s32 q14, q14, q6 + vrshl.s32 q14, q14, q7 + vand q5, q7, q15 + vshr.s32 q5, q5, #31 + vqadd.s32 q15, q15, q5 + vrshl.s32 q15, q15, q7 + b AddDstZP + +PerChannel: + // Substract input_sums + vld1.32 {d24, d25}, [r6]! + vdup.32 d20, d24[0] + vdup.32 d21, d24[1] + vdup.32 d22, d25[0] + vdup.32 d23, d25[1] + ldr r10, [sp, #48] + vld1.32 {d19}, [r10] + vmul.s32 d24, d20, d19 + vmul.s32 d25, d21, d19 + vmul.s32 d26, d22, d19 + vmul.s32 d27, d23, d19 + vsub.s32 d28, d28, d24 + vsub.s32 d29, d29, d25 + vsub.s32 d30, d30, d26 + vsub.s32 d31, d31, d27 + + // Apply left shift + ldr r10, [sp, #32] + add r10, r10, r8 + vld1.32 {d23}, [r10] + vshl.s32 d28, d28, d23 + vshl.s32 d29, d29, d23 + vshl.s32 d30, d30, d23 + vshl.s32 d31, d31, d23 + + // Apply the fixed-point part of the multiplier + ldr r10, [sp, #28] + add r10, r10, r8 + vld1.32 {d22}, [r10] + vqrdmulh.s32 d28, d28, d22 + vqrdmulh.s32 d29, d29, d22 + vqrdmulh.s32 d30, d30, d22 + vqrdmulh.s32 d31, d31, d22 + + // Apply right shift + ldr r10, [sp, #36] + add r10, r10, r8 + vld1.32 {d21}, [r10] + vand d20, d21, d28 + vshr.s32 d20, d20, #31 + vqadd.s32 d28, d28, d20 + vrshl.s32 d28, d28, d21 + vand d19, d21, d29 + vshr.s32 d19, d19, #31 + vqadd.s32 d29, d29, d19 + vrshl.s32 d29, d29, d21 + vand d18, d21, d30 + vshr.s32 d18, d18, #31 + vqadd.s32 d30, d30, d18 + vrshl.s32 d30, d30, d21 + vand d17, d21, d31 + vshr.s32 d17, d17, #31 + vqadd.s32 d31, d31, d17 + vrshl.s32 d31, d31, d21 + +AddDstZP: + // Add the destination zero point + ldr r10, [sp, #24] + vdup.32 q4, r10 + vadd.i32 q14, q14, q4 + vadd.i32 q15, q15, q4 + + // Apply the act_min bound + ldr r10, [sp, #16] + vdup.32 q3, r10 + vmax.s32 q14, q14, q3 + vmax.s32 q15, q15, q3 + + // Apply the act_max bound + ldr r10, [sp, #20] + vdup.32 q2, r10 + vmin.s32 q14, q14, q2 + vmin.s32 q15, q15, q2 + + // Cast-and-saturate from int32 to int16 + vqmovn.s32 d28, q14 + vqmovn.s32 d29, q15 + + // Cast-and-saturate from int16 to int8 + vqmovn.s16 d30, q14 + + // start to write + cmp r4, #2 + bge WriteCol2 + cmp r4, #1 + beq WriteCol1 + b EndWrite + +WriteCol2: + vst1.16 {d30[0]}, [r2], r7 + cmp r3, #1 + beq EndWrite + vst1.16 {d30[1]}, [r2], r7 + cmp r3, #2 + beq EndWrite + vst1.16 {d30[2]}, [r2], r7 + cmp r3, #3 + beq EndWrite + vst1.16 {d30[3]}, [r2], r7 + b EndWrite + +WriteCol1: + vst1.8 {d30[0]}, [r2], r7 + cmp r3, #1 + beq EndWrite + vst1.8 {d30[2]}, [r2], r7 + cmp r3, #2 + beq EndWrite + vst1.8 {d30[4]}, [r2], r7 + cmp r3, #3 + beq EndWrite + vst1.8 {d30[6]}, [r2], r7 + b EndWrite + +EndWrite: + sub r3, r3, #4 // a row counter -= 4 + b L2 + +End2: + sub r4, r4, #2 // b col counter -= 2 + ldr r1, [sp, #-48] // load b ptr + ldr r9, [sp, #4] + mov r10, #2 + mul r9, r9, r10 // the stride of b + add r1, r1, r9 // b ptr + stride + str r1, [sp, #-48] + ldr r2, [sp, #-44] // load dst ptr + add r2, r2, #2 // dst ptr + offset + str r2, [sp, #-44] + ldr r10, [sp, #48] + add r10, r10, #8 + str r10, [sp, #48] + add r8, r8, #8 // output channels offset + 2*sizeof(int) + b L1 + +End1: + sub sp, sp, #116 + vpop {q4-q7} + pop {r0-r11, pc} +#endif +#endif diff --git a/mindspore/lite/nnacl/assembly/arm64/IndirectGemmInt8_4x4.S 
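Both the AArch32 kernel above and the AArch64 kernel added later in this patch perform the rounding right shift with the same four-instruction idiom (vand, vshr #31, vqadd, vrshl, or and/sshr/sqadd/srshl on AArch64). Below is a minimal scalar sketch of what that sequence computes, assuming the shift values are stored as non-positive numbers as vrshl/srshl expect and that >> on a negative value is an arithmetic shift; the helper name is illustrative and is not part of the patch.

#include <stdint.h>

// Scalar model of:  and tmp, shift, acc ; sshr tmp, tmp, #31 ;
//                   sqadd acc, acc, tmp ; srshl acc, acc, shift
// with `shift` holding a non-positive value (shift right by -shift bits).
// The and/sshr/sqadd steps subtract 1 from negative accumulators so the
// final rounding shift rounds halves away from zero instead of toward +inf.
static int32_t RoundingRightShift(int32_t acc, int32_t shift) {
  int exponent = -shift;
  if (exponent <= 0) {
    return acc;                               // shift of 0: nothing to do
  }
  if (acc < 0) {
    acc -= 1;                                 // fixup produced by and + sshr #31 + sqadd
  }
  int64_t rounded = (int64_t)acc + ((int64_t)1 << (exponent - 1));  // vrshl rounding constant
  return (int32_t)(rounded >> exponent);
}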
b/mindspore/lite/nnacl/assembly/arm64/IndirectGemmInt8_4x4.S deleted file mode 100644 index bdbfa738b2..0000000000 --- a/mindspore/lite/nnacl/assembly/arm64/IndirectGemmInt8_4x4.S +++ /dev/null @@ -1,371 +0,0 @@ -#ifdef __aarch64__ - -.text -.align 5 -.global IndirectGemmInt8_4x4 -#ifndef __APPLE__ -.type IndirectGemmInt8_4x4, %function -#endif - -// void IndirectGemmInt8_4x4(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4, -// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, int32_t *out_multiplier, -// int32_t *shift_before, int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset); -// x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset -IndirectGemmInt8_4x4: - - .macro INIT_BIAS - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr - dup v20.4s, wzr - dup v21.4s, wzr - dup v22.4s, wzr - dup v23.4s, wzr - dup v24.4s, wzr - dup v25.4s, wzr - dup v26.4s, wzr - dup v27.4s, wzr - dup v28.4s, wzr - dup v29.4s, wzr - dup v30.4s, wzr - dup v31.4s, wzr - .endm - - // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to - // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers - // r19 ~ r29 should be also preserved - // whereas our coding style do not permit such amount of parameters - sub sp, sp, #176 - st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 - st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 - stp x19, x20, [sp], #16 - stp x21, x22, [sp], #16 - stp x23, x24, [sp], #16 - - ldr x15, [sp] - ldr w8, [sp, #8] - ldr w9, [sp, #16] - ldr w16, [sp, #24] - ldr x17, [sp, #32] - ldr x18, [sp, #40] - ldr x19, [sp, #48] - ldr x20, [sp, #56] - ldr x21, [sp, #64] - ldr x23, [sp, #72] - - mul x5, x4, x5 - mov x4, #1 - - LoopOc: - - mov x10, x4 - mov x12, x1 - - LoopKsize: - INIT_BIAS - mov x11, x0 - - // as some processors do not support sdot intrinsic, we use instruction word - // dp support is stilled judged dymaticly, instruction word is just used to ensure compilation - // according to https://static.docs.arm.com/ddi0596/g/ISA_A64_xml_v86A-2020-03_OPT.pdf - // the instruction word of sdot vd.4s, vn.16b, vm.4b[index] is - // 0100 1111 10Lm mmmm 1110 H0nn nnnd dddd - // mmmmm/nnnnn/ddddd is the number of neon register, HL is the high/low bit of index - - // load input for output 1-8 - ld1 {v0.16b, v1.16b}, [x12], #32 - // load weight - ld1 {v4.16b, v5.16b}, [x2], #32 - // step for output 1-4 - smull v8.8h, v0.8b, v4.8b - smull v9.8h, v0.8b, v5.8b - smlal2 v8.8h, v0.16b, v4.16b - smlal2 v9.8h, v0.16b, v5.16b - // load input for output 9-16 - ld1 {v6.16b, v7.16b}, [x2], #32 - // another step for output 5-8 - smull v12.8h, v1.8b, v4.8b - smull v13.8h, v1.8b, v5.8b - smlal2 v12.8h, v1.16b, v4.16b - smlal2 v13.8h, v1.16b, v5.16b - ld1 {v2.16b, v3.16b}, [x12], #32 - smull v10.8h, v0.8b, v6.8b - smull v11.8h, v0.8b, v7.8b - saddlp v16.4s, v8.8h - smlal2 v10.8h, v0.16b, v6.16b - smlal2 v11.8h, v0.16b, v7.16b - saddlp v17.4s, v9.8h - smull v14.8h, v1.8b, v6.8b - smull v15.8h, v1.8b, v7.8b - saddlp v18.4s, v10.8h - smlal2 v14.8h, v1.16b, v6.16b - smlal2 v15.8h, v1.16b, v7.16b - - subs x13, x5, #1 - beq LoopIcEnd - - LoopIc: - // load input for output 1-8 - ld1 {v0.16b, v1.16b}, [x12], #32 - sadalp v19.4s, v11.8h - smull v8.8h, v2.8b, v4.8b - smull v9.8h, v2.8b, v5.8b - sadalp v20.4s, v12.8h - smlal2 v8.8h, v2.16b, v4.16b - smlal2 v9.8h, v2.16b, v5.16b - sadalp v21.4s, 
v13.8h - smull v10.8h, v2.8b, v6.8b - smull v11.8h, v2.8b, v7.8b - sadalp v22.4s, v14.8h - smlal2 v10.8h, v2.16b, v6.16b - smlal2 v11.8h, v2.16b, v7.16b - sadalp v23.4s, v15.8h - smull v12.8h, v3.8b, v4.8b - smull v13.8h, v3.8b, v5.8b - sadalp v24.4s, v8.8h - smlal2 v12.8h, v3.16b, v4.16b - smlal2 v13.8h, v3.16b, v5.16b - ld1 {v4.16b, v5.16b}, [x2], #32 - sadalp v25.4s, v9.8h - smull v14.8h, v3.8b, v6.8b - smull v15.8h, v3.8b, v7.8b - sadalp v26.4s, v10.8h - smlal2 v14.8h, v3.16b, v6.16b - smlal2 v15.8h, v3.16b, v7.16b - ld1 {v6.16b, v7.16b}, [x2], #32 - sadalp v27.4s, v11.8h - smull v8.8h, v0.8b, v4.8b - smull v9.8h, v0.8b, v5.8b - sadalp v28.4s, v12.8h - smlal2 v8.8h, v0.16b, v4.16b - smlal2 v9.8h, v0.16b, v5.16b - ld1 {v2.16b, v3.16b}, [x12], #32 - sadalp v29.4s, v13.8h - smull v12.8h, v1.8b, v4.8b - smull v13.8h, v1.8b, v5.8b - sadalp v30.4s, v14.8h - smlal2 v12.8h, v1.16b, v4.16b - smlal2 v13.8h, v1.16b, v5.16b - sadalp v31.4s, v15.8h - smull v10.8h, v0.8b, v6.8b - smull v11.8h, v0.8b, v7.8b - sadalp v16.4s, v8.8h - smlal2 v10.8h, v0.16b, v6.16b - smlal2 v11.8h, v0.16b, v7.16b - sadalp v17.4s, v9.8h - smull v14.8h, v1.8b, v6.8b - smull v15.8h, v1.8b, v7.8b - sadalp v18.4s, v10.8h - smlal2 v14.8h, v1.16b, v6.16b - smlal2 v15.8h, v1.16b, v7.16b - - subs x13, x13, #1 - bne LoopIc - - LoopIcEnd: - sadalp v19.4s, v11.8h - smull v8.8h, v2.8b, v4.8b - smull v9.8h, v2.8b, v5.8b - sadalp v20.4s, v12.8h - smlal2 v8.8h, v2.16b, v4.16b - smlal2 v9.8h, v2.16b, v5.16b - sadalp v21.4s, v13.8h - smull v10.8h, v2.8b, v6.8b - smull v11.8h, v2.8b, v7.8b - sadalp v22.4s, v14.8h - smlal2 v10.8h, v2.16b, v6.16b - smlal2 v11.8h, v2.16b, v7.16b - sadalp v23.4s, v15.8h - smull v12.8h, v3.8b, v4.8b - smull v13.8h, v3.8b, v5.8b - sadalp v24.4s, v8.8h - smlal2 v12.8h, v3.16b, v4.16b - smlal2 v13.8h, v3.16b, v5.16b - sadalp v25.4s, v9.8h - smull v14.8h, v3.8b, v6.8b - smull v15.8h, v3.8b, v7.8b - sadalp v26.4s, v10.8h - smlal2 v14.8h, v3.16b, v6.16b - smlal2 v15.8h, v3.16b, v7.16b - sadalp v27.4s, v11.8h - sadalp v28.4s, v12.8h - sadalp v29.4s, v13.8h - sadalp v30.4s, v14.8h - sadalp v31.4s, v15.8h - - // pairwise add - addp v16.4s, v16.4s, v17.4s - addp v18.4s, v18.4s, v19.4s - addp v20.4s, v20.4s, v21.4s - addp v22.4s, v22.4s, v23.4s - addp v24.4s, v24.4s, v25.4s - addp v26.4s, v26.4s, v27.4s - addp v28.4s, v28.4s, v29.4s - addp v30.4s, v30.4s, v31.4s - dup v12.4s, wzr - cbz x3, NoReadBias - ld1 {v12.4s}, [x3] - NoReadBias: - addp v16.4s, v16.4s, v18.4s - addp v20.4s, v20.4s, v22.4s - addp v24.4s, v24.4s, v26.4s - addp v28.4s, v28.4s, v30.4s - cbz x20, NoSum - // load sum - mov x22, x15 - cbz x21, SymSum - ld1 {v8.4s}, [x22], x23 - ld1 {v9.4s}, [x22], x23 - ld1 {v10.4s}, [x22], x23 - ld1 {v11.4s}, [x22] - b AddSum - SymSum: - ld1r {v8.4s}, [x22], #4 - ld1r {v9.4s}, [x22], #4 - ld1r {v10.4s}, [x22], #4 - ld1r {v11.4s}, [x22] - AddSum: - sub v16.4s, v16.4s, v8.4s - sub v20.4s, v20.4s, v9.4s - sub v24.4s, v24.4s, v10.4s - sub v28.4s, v28.4s, v11.4s - NoSum: - add v16.4s, v16.4s, v12.4s - add v20.4s, v20.4s, v12.4s - add v24.4s, v24.4s, v12.4s - add v28.4s, v28.4s, v12.4s - - cbnz x21, PerChannel - ld1r {v2.4s}, [x18] - ld1r {v3.4s}, [x17] - ld1r {v4.4s}, [x19] - b QuantizeStart - PerChannel: - ld1 {v2.4s}, [x18] - ld1 {v3.4s}, [x17] - ld1 {v4.4s}, [x19] - QuantizeStart: - sqshl v16.4s, v16.4s, v2.4s - sqshl v20.4s, v20.4s, v2.4s - sqshl v24.4s, v24.4s, v2.4s - sqshl v28.4s, v28.4s, v2.4s - - sqrdmulh v16.4s, v16.4s, v3.4s - sqrdmulh v20.4s, v20.4s, v3.4s - sqrdmulh v24.4s, v24.4s, v3.4s - sqrdmulh v28.4s, v28.4s, 
v3.4s - - and v0.16b, v4.16b, v16.16b - sshr v0.4s, v0.4s, #31 - sqadd v16.4s, v16.4s, v0.4s - srshl v16.4s, v16.4s, v4.4s - and v1.16b, v4.16b, v20.16b - sshr v1.4s, v1.4s, #31 - sqadd v20.4s, v20.4s, v1.4s - srshl v20.4s, v20.4s, v4.4s - and v2.16b, v4.16b, v24.16b - sshr v2.4s, v2.4s, #31 - sqadd v24.4s, v24.4s, v2.4s - srshl v24.4s, v24.4s, v4.4s - and v3.16b, v4.16b, v28.16b - sshr v3.4s, v3.4s, #31 - sqadd v28.4s, v28.4s, v3.4s - srshl v28.4s, v28.4s, v4.4s - - dup v5.4s, w16 - add v16.4s, v16.4s, v5.4s - add v20.4s, v20.4s, v5.4s - add v24.4s, v24.4s, v5.4s - add v28.4s, v28.4s, v5.4s - - dup v0.4s, w8 - smax v16.4s, v16.4s, v0.4s - smax v20.4s, v20.4s, v0.4s - smax v24.4s, v24.4s, v0.4s - smax v28.4s, v28.4s, v0.4s - - dup v1.4s, w9 - smin v16.4s, v16.4s, v1.4s - smin v20.4s, v20.4s, v1.4s - smin v24.4s, v24.4s, v1.4s - smin v28.4s, v28.4s, v1.4s - - sqxtn v13.4h, v16.4s - sqxtn2 v13.8h, v20.4s - sqxtn v15.8b, v13.8h - sqxtn v14.4h, v24.4s - sqxtn2 v14.8h, v28.4s - sqxtn2 v15.16b, v14.8h - - // prefetching is not prefered while writing results in spite of cache missings - // you could try prfm pstl2strm - WriteStart: - cmp x6, #1 - beq Write1 - cmp x6, #2 - beq Write2 - cmp x6, #3 - beq Write3 - b Write4 - Write1: - st1 {v15.b}[0], [x11], x7 - st1 {v15.b}[4], [x11], x7 - st1 {v15.b}[8], [x11], x7 - st1 {v15.b}[12], [x11] - add x0, x0, #1 - b WriteEnd - Write2: - st1 {v15.h}[0], [x11], x7 - st1 {v15.h}[2], [x11], x7 - st1 {v15.h}[4], [x11], x7 - st1 {v15.h}[6], [x11] - add x0, x0, #2 - b WriteEnd - Write3: - add x14, x11, #2 - st1 {v15.h}[0], [x11], x7 - st1 {v15.b}[2], [x14], x7 - st1 {v15.h}[2], [x11], x7 - st1 {v15.b}[6], [x14], x7 - st1 {v15.h}[4], [x11], x7 - st1 {v15.b}[10], [x14], x7 - st1 {v15.h}[6], [x11] - st1 {v15.b}[14], [x14] - add x0, x0, #3 - b WriteEnd - Write4: - st1 {v15.s}[0], [x11], x7 - st1 {v15.s}[1], [x11], x7 - st1 {v15.s}[2], [x11], x7 - st1 {v15.s}[3], [x11] - add x0, x0, #4 - - WriteEnd: - - subs x10, x10, #1 - bne LoopKsize - - subs x6, x6, #4 - cbz x21, NoChannelForward - cbz x20, NoSumForward - add x15, x15, #16 - NoSumForward: - add x17, x17, #16 - add x18, x18, #16 - add x19, x19, #16 - NoChannelForward: - cbz x3, NoStepFowrard - add x3, x3, #16 - NoStepFowrard: - bgt LoopOc - - sub sp, sp, #176 - ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 - ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 - ldp x19, x20, [sp], #16 - ldp x21, x22, [sp], #16 - ldp x23, x24, [sp], #16 - ret -#endif - diff --git a/mindspore/lite/nnacl/assembly/arm64/MatmulInt8Opt.S b/mindspore/lite/nnacl/assembly/arm64/MatmulInt8Opt.S new file mode 100644 index 0000000000..7ad867385d --- /dev/null +++ b/mindspore/lite/nnacl/assembly/arm64/MatmulInt8Opt.S @@ -0,0 +1,408 @@ +#ifdef __aarch64__ + .text + .align 5 + .global MatmulInt8Neon64Opt +#ifndef __APPLE__ + .type MatmulInt8Neon64Opt, %function +#endif + +//void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums, +// const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift, +// int32_t *right_shift, int row, int col, int stride, int filter_peroc, int32_t *filter_zp) + +// x0: a(left matrix ptr) +// x1: b(right matrix ptr) +// x2: out ptr +// w3: row4 +// w4: col4 +// w5: deep16 +// x6: a_sums +// x7: bias +// w8: act_min +// w9: act_max +// w10: out_zp +// x11: multiplier +// x12: left_shift +// x13: right_shift +// w14: row +// w15: col +// w24: stride +// w27: filter_peroc +// x28: filter_zp + +MatmulInt8Neon64Opt: + sub sp, sp, #208 
+ st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + stp x19, x20, [sp], #16 + stp x21, x22, [sp], #16 + stp x23, x24, [sp], #16 + stp x25, x26, [sp], #16 + stp x27, x28, [sp], #16 + + ldr w8, [sp] + ldr w9, [sp, #8] + ldr w10, [sp, #16] + ldr x11, [sp, #24] + ldr x12, [sp, #32] + ldr x13, [sp, #40] + ldr w14, [sp, #48] + ldr w15, [sp, #56] + ldr w24, [sp, #64] + ldr w27, [sp, #72] + ldr x28, [sp, #80] + + mov w17, #4 // sizeof(int8)*4 + mul w21, w5, w17 // the stride of a/b: sizeof(int8)*4*deep16 + mov w17, #1 + mov x25, x2 +L1: + cmp w4, #0 // if at the end of col4 + beq End1 + + mov w16, w3 // reset a row4 counter + mov w23, w14 // reset a row counter + mov x17, x0 // reload a ptr + mov x22, x6 // reload a_sums ptr +L2: + cmp w16, #0 + beq End2 + + mov x18, x1 // reload b ptr + mov x19, x7 // reload bias ptr + mov w20, w5 // reload depth + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + dup v24.4s, wzr + dup v25.4s, wzr + dup v26.4s, wzr + dup v27.4s, wzr + dup v28.4s, wzr + dup v29.4s, wzr + dup v30.4s, wzr + dup v31.4s, wzr +L3: + cmp w20, #0 + beq End3 + + ld1 {v0.16b}, [x17], #16 + ld1 {v1.16b}, [x17], #16 + ld1 {v2.16b}, [x17], #16 + ld1 {v3.16b}, [x17], #16 + ld1 {v4.16b}, [x18], #16 + ld1 {v5.16b}, [x18], #16 + ld1 {v6.16b}, [x18], #16 + ld1 {v7.16b}, [x18], #16 + + smull v8.8h, v4.8b, v0.8b + smull v9.8h, v5.8b, v0.8b + smull v10.8h, v6.8b, v0.8b + smull v11.8h, v7.8b, v0.8b + smull v12.8h, v4.8b, v1.8b + smull v13.8h, v5.8b, v1.8b + smull v14.8h, v6.8b, v1.8b + smull v15.8h, v7.8b, v1.8b + + smlal2 v8.8h, v4.16b, v0.16b + smlal2 v9.8h, v5.16b, v0.16b + smlal2 v10.8h, v6.16b, v0.16b + smlal2 v11.8h, v7.16b, v0.16b + smlal2 v12.8h, v4.16b, v1.16b + smlal2 v13.8h, v5.16b, v1.16b + smlal2 v14.8h, v6.16b, v1.16b + smlal2 v15.8h, v7.16b, v1.16b + + sadalp v16.4s, v8.8h + sadalp v17.4s, v9.8h + sadalp v18.4s, v10.8h + sadalp v19.4s, v11.8h + sadalp v20.4s, v12.8h + sadalp v21.4s, v13.8h + sadalp v22.4s, v14.8h + sadalp v23.4s, v15.8h + + smull v8.8h, v4.8b, v2.8b + smull v9.8h, v5.8b, v2.8b + smull v10.8h, v6.8b, v2.8b + smull v11.8h, v7.8b, v2.8b + smull v12.8h, v4.8b, v3.8b + smull v13.8h, v5.8b, v3.8b + smull v14.8h, v6.8b, v3.8b + smull v15.8h, v7.8b, v3.8b + + smlal2 v8.8h, v4.16b, v2.16b + smlal2 v9.8h, v5.16b, v2.16b + smlal2 v10.8h, v6.16b, v2.16b + smlal2 v11.8h, v7.16b, v2.16b + smlal2 v12.8h, v4.16b, v3.16b + smlal2 v13.8h, v5.16b, v3.16b + smlal2 v14.8h, v6.16b, v3.16b + smlal2 v15.8h, v7.16b, v3.16b + + sadalp v24.4s, v8.8h + sadalp v25.4s, v9.8h + sadalp v26.4s, v10.8h + sadalp v27.4s, v11.8h + sadalp v28.4s, v12.8h + sadalp v29.4s, v13.8h + sadalp v30.4s, v14.8h + sadalp v31.4s, v15.8h + subs w20, w20, #16 // depth + 16 + b L3 + +End3: + addp v16.4s, v16.4s, v17.4s + addp v18.4s, v18.4s, v19.4s + addp v20.4s, v20.4s, v21.4s + addp v22.4s, v22.4s, v23.4s + addp v24.4s, v24.4s, v25.4s + addp v26.4s, v26.4s, v27.4s + addp v28.4s, v28.4s, v29.4s + addp v30.4s, v30.4s, v31.4s + + addp v16.4s, v16.4s, v18.4s + addp v17.4s, v20.4s, v22.4s + addp v18.4s, v24.4s, v26.4s + addp v19.4s, v28.4s, v30.4s + + // Add (Bias+Depth*Za*Zb-Za*Bsums) + ld1 {v15.4s}, [x19], #16 + add v16.4s, v16.4s, v15.4s + add v17.4s, v17.4s, v15.4s + add v18.4s, v18.4s, v15.4s + add v19.4s, v19.4s, v15.4s + + ld1r {v20.4s}, [x22], #4 + ld1r {v21.4s}, [x22], #4 + ld1r {v22.4s}, [x22], #4 + ld1r {v23.4s}, [x22], #4 + cmp w27, #0 + beq Apply + ld1 {v14.4s}, [x28] + mul 
v20.4s, v20.4s, v14.4s + mul v21.4s, v21.4s, v14.4s + mul v22.4s, v22.4s, v14.4s + mul v23.4s, v23.4s, v14.4s + +Apply: + // Subtract (Asums*Zb) + sub v16.4s, v16.4s, v20.4s + sub v17.4s, v17.4s, v21.4s + sub v18.4s, v18.4s, v22.4s + sub v19.4s, v19.4s, v23.4s + + cmp w27, #1 + beq PerCLoad + + ld1r {v13.4s}, [x12] + ld1r {v12.4s}, [x11] + ld1r {v11.4s}, [x13] + b Quantize + +PerCLoad: + ld1 {v13.4s}, [x12] + ld1 {v12.4s}, [x11] + ld1 {v11.4s}, [x13] +Quantize: + + // Apply left shift + sqshl v16.4s, v16.4s, v13.4s + sqshl v17.4s, v17.4s, v13.4s + sqshl v18.4s, v18.4s, v13.4s + sqshl v19.4s, v19.4s, v13.4s + + // Apply the fixed-point part of the multiplier. + sqrdmulh v16.4s, v16.4s, v12.4s + sqrdmulh v17.4s, v17.4s, v12.4s + sqrdmulh v18.4s, v18.4s, v12.4s + sqrdmulh v19.4s, v19.4s, v12.4s + + // Apply right shift + and v20.16b, v11.16b, v16.16b + sshr v20.4s, v20.4s, #31 + sqadd v16.4s, v16.4s, v20.4s + srshl v16.4s, v16.4s, v11.4s + and v21.16b, v11.16b, v17.16b + sshr v21.4s, v21.4s, #31 + sqadd v17.4s, v17.4s, v21.4s + srshl v17.4s, v17.4s, v11.4s + and v22.16b, v11.16b, v18.16b + sshr v22.4s, v22.4s, #31 + sqadd v18.4s, v18.4s, v22.4s + srshl v18.4s, v18.4s, v11.4s + and v23.16b, v11.16b, v19.16b + sshr v23.4s, v23.4s, #31 + sqadd v19.4s, v19.4s, v23.4s + srshl v19.4s, v19.4s, v11.4s + + // Add the destination zero point + dup v10.4s, w10 + add v16.4s, v16.4s, v10.4s + add v17.4s, v17.4s, v10.4s + add v18.4s, v18.4s, v10.4s + add v19.4s, v19.4s, v10.4s + + // Apply the act_min bound + dup v9.4s, w8 + smax v16.4s, v16.4s, v9.4s + smax v17.4s, v17.4s, v9.4s + smax v18.4s, v18.4s, v9.4s + smax v19.4s, v19.4s, v9.4s + + // Apply the act_min bound + dup v8.4s, w9 + smin v16.4s, v16.4s, v8.4s + smin v17.4s, v17.4s, v8.4s + smin v18.4s, v18.4s, v8.4s + smin v19.4s, v19.4s, v8.4s + + // int32 -> int16 + sqxtn v13.4h, v16.4s + sqxtn2 v13.8h, v17.4s + sqxtn v14.4h, v18.4s + sqxtn2 v14.8h, v19.4s + + // int16 -> int8 + sqxtn v15.8b, v13.8h + sqxtn2 v15.16b, v14.8h + + cmp w23, #4 + blt Write // if rows < 4 + cmp w15, #4 + blt Write // if cols < 4 + + st1 {v15.s}[0], [x2], x24 + st1 {v15.s}[1], [x2], x24 + st1 {v15.s}[2], [x2], x24 + st1 {v15.s}[3], [x2], x24 + b Endwrite + +Write: + cmp w15, #4 + beq WriteCol4 + cmp w15, #3 + beq WriteCol3 + cmp w15, #2 + beq WriteCol2 + cmp w15, #1 + beq WriteCol1 + +WriteCol4: + st1 {v15.s}[0], [x2], x24 + cmp w23, #1 + beq Endwrite + st1 {v15.s}[1], [x2], x24 + cmp w23, #2 + beq Endwrite + st1 {v15.s}[2], [x2], x24 + cmp w23, #3 + beq Endwrite + st1 {v15.s}[3], [x2], x24 + b Endwrite + +WriteCol3: + mov x26, x2 + st1 {v15.b}[0], [x26], #1 + st1 {v15.b}[1], [x26], #1 + st1 {v15.b}[2], [x26], #1 + add x2, x2, x24 + cmp w23, #1 + beq Endwrite + mov x26, x2 + st1 {v15.b}[4], [x26], #1 + st1 {v15.b}[5], [x26], #1 + st1 {v15.b}[6], [x26], #1 + add x2, x2, x24 + cmp w23, #2 + beq Endwrite + mov x26, x2 + st1 {v15.b}[8], [x26], #1 + st1 {v15.b}[9], [x26], #1 + st1 {v15.b}[10], [x26], #1 + add x2, x2, x24 + cmp w23, #3 + beq Endwrite + mov x26, x2 + st1 {v15.b}[12], [x26], #1 + st1 {v15.b}[13], [x26], #1 + st1 {v15.b}[14], [x26], #1 + add x2, x2, x24 + b Endwrite + +WriteCol2: + mov x26, x2 + st1 {v15.b}[0], [x26], #1 + st1 {v15.b}[1], [x26], #1 + add x2, x2, x24 + cmp w23, #1 + beq Endwrite + mov x26, x2 + st1 {v15.b}[4], [x26], #1 + st1 {v15.b}[5], [x26], #1 + add x2, x2, x24 + cmp w23, #2 + beq Endwrite + mov x26, x2 + st1 {v15.b}[8], [x26], #1 + st1 {v15.b}[9], [x26], #1 + add x2, x2, x24 + cmp w23, #3 + beq Endwrite + mov x26, x2 + st1 {v15.b}[12], [x26], #1 + 
st1 {v15.b}[13], [x26], #1 + add x2, x2, x24 + b Endwrite + +WriteCol1: + st1 {v15.b}[0], [x2], x24 + cmp w23, #1 + beq Endwrite + st1 {v15.b}[4], [x2], x24 + cmp w23, #2 + beq Endwrite + st1 {v15.b}[8], [x2], x24 + cmp w23, #3 + beq Endwrite + st1 {v15.b}[12], [x2], x24 + b Endwrite + +Endwrite: + sub w16, w16, #4 // a row4 counter - 4 + sub w23, w23, #4 // a row counter - 4 + b L2 + +End2: + sub w4, w4, #4 // b col4 counter - 4 + sub w15, w15, #4 // b col counter - 4 + add x1, x1, x21 // b ptr + stride + add x7, x7, #16 // bias ptr + stride + add x25, x25, #4 // output + stride(4 * sizeof(int8)) + mov x2, x25 + + cmp w27, #0 + beq PerTEnd2 + add x12, x12, #16 + add x11, x11, #16 + add x13, x13, #16 + add x28, x28, #16 +PerTEnd2: + b L1 + +End1: + sub sp, sp, #208 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ldp x23, x24, [sp], #16 + ldp x25, x26, [sp], #16 + ldp x27, x28, [sp], #16 + ret +#endif diff --git a/mindspore/lite/nnacl/assembly/opt/IndirectGemmInt8_24x4_dp.S b/mindspore/lite/nnacl/assembly/opt/IndirectGemmInt8_24x4_dp.S deleted file mode 100644 index 2c43efb982..0000000000 --- a/mindspore/lite/nnacl/assembly/opt/IndirectGemmInt8_24x4_dp.S +++ /dev/null @@ -1,785 +0,0 @@ -#ifdef __aarch64__ - -.text -.align 5 -.global IndirectGemmInt8_24x4_dp -#ifndef __APPLE__ -.type IndirectGemmInt8_24x4_dp, %function -#endif - -// void IndirectGemmInt8_24x4_dp(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4, -// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, int32_t *out_multiplier, -// int32_t *shift_before, int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset); -// x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset -// we use sdot intrinsic on cores that supports dotprod(Armv8.2-A w/dp or later) -// mrs intrinsic could read system register ID_AA64ISAR0_EL1(or s3_0_c0_c6_0 on Armv8.2-A) -// the 44-48 bits indicates whether dotprod is supported -IndirectGemmInt8_24x4_dp: - - .macro INIT_BIAS - dup v7.4s, wzr - cbz x3, InitBias - ld1 {v7.4s}, [x3] - InitBias: - cbz x20, NoSum - mov x22, x15 - cbz x21, SymSum - ld1 {v8.4s}, [x22], x23 - ld1 {v9.4s}, [x22], x23 - ld1 {v10.4s}, [x22], x23 - ld1 {v11.4s}, [x22], x23 - ld1 {v12.4s}, [x22], x23 - ld1 {v13.4s}, [x22], x23 - ld1 {v14.4s}, [x22], x23 - ld1 {v15.4s}, [x22], x23 - ld1 {v16.4s}, [x22], x23 - ld1 {v17.4s}, [x22], x23 - ld1 {v18.4s}, [x22], x23 - ld1 {v19.4s}, [x22], x23 - ld1 {v20.4s}, [x22], x23 - ld1 {v21.4s}, [x22], x23 - ld1 {v22.4s}, [x22], x23 - ld1 {v23.4s}, [x22], x23 - ld1 {v24.4s}, [x22], x23 - ld1 {v25.4s}, [x22], x23 - ld1 {v26.4s}, [x22], x23 - ld1 {v27.4s}, [x22], x23 - ld1 {v28.4s}, [x22], x23 - ld1 {v29.4s}, [x22], x23 - ld1 {v30.4s}, [x22], x23 - ld1 {v31.4s}, [x22], x23 - b AddSum - SymSum: - ld1r {v8.4s}, [x22], #4 - ld1r {v9.4s}, [x22], #4 - ld1r {v10.4s}, [x22], #4 - ld1r {v11.4s}, [x22], #4 - ld1r {v12.4s}, [x22], #4 - ld1r {v13.4s}, [x22], #4 - ld1r {v14.4s}, [x22], #4 - ld1r {v15.4s}, [x22], #4 - ld1r {v16.4s}, [x22], #4 - ld1r {v17.4s}, [x22], #4 - ld1r {v18.4s}, [x22], #4 - ld1r {v19.4s}, [x22], #4 - ld1r {v20.4s}, [x22], #4 - ld1r {v21.4s}, [x22], #4 - ld1r {v22.4s}, [x22], #4 - ld1r {v23.4s}, [x22], #4 - ld1r {v24.4s}, [x22], #4 - ld1r {v25.4s}, [x22], #4 - ld1r {v26.4s}, [x22], #4 - ld1r {v27.4s}, [x22], #4 - ld1r {v28.4s}, [x22], #4 - ld1r {v29.4s}, [x22], 
#4 - ld1r {v30.4s}, [x22], #4 - ld1r {v31.4s}, [x22], #4 - AddSum: - sub v8.4s, v7.4s, v8.4s - sub v9.4s, v7.4s, v9.4s - sub v10.4s, v7.4s, v10.4s - sub v11.4s, v7.4s, v11.4s - sub v12.4s, v7.4s, v12.4s - sub v13.4s, v7.4s, v13.4s - sub v14.4s, v7.4s, v14.4s - sub v15.4s, v7.4s, v15.4s - sub v16.4s, v7.4s, v16.4s - sub v17.4s, v7.4s, v17.4s - sub v18.4s, v7.4s, v18.4s - sub v19.4s, v7.4s, v19.4s - sub v20.4s, v7.4s, v20.4s - sub v21.4s, v7.4s, v21.4s - sub v22.4s, v7.4s, v22.4s - sub v23.4s, v7.4s, v23.4s - sub v24.4s, v7.4s, v24.4s - sub v25.4s, v7.4s, v25.4s - sub v26.4s, v7.4s, v26.4s - sub v27.4s, v7.4s, v27.4s - sub v28.4s, v7.4s, v28.4s - sub v29.4s, v7.4s, v29.4s - sub v30.4s, v7.4s, v30.4s - sub v31.4s, v7.4s, v31.4s - b InitBiasEnd - NoSum: - mov v8.16b, v7.16b - mov v9.16b, v7.16b - mov v10.16b, v7.16b - mov v11.16b, v7.16b - mov v12.16b, v7.16b - mov v13.16b, v7.16b - mov v14.16b, v7.16b - mov v15.16b, v7.16b - mov v16.16b, v7.16b - mov v17.16b, v7.16b - mov v18.16b, v7.16b - mov v19.16b, v7.16b - mov v20.16b, v7.16b - mov v21.16b, v7.16b - mov v22.16b, v7.16b - mov v23.16b, v7.16b - mov v24.16b, v7.16b - mov v25.16b, v7.16b - mov v26.16b, v7.16b - mov v27.16b, v7.16b - mov v28.16b, v7.16b - mov v29.16b, v7.16b - mov v30.16b, v7.16b - mov v31.16b, v7.16b - InitBiasEnd: - .endm - - // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to - // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers - // r19 ~ r29 should be also preserved - // whereas our coding style do not permit such amount of parameters - sub sp, sp, #176 - st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 - st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 - stp x19, x20, [sp], #16 - stp x21, x22, [sp], #16 - stp x23, x24, [sp], #16 - - ldr x15, [sp] - ldr w8, [sp, #8] - ldr w9, [sp, #16] - ldr w16, [sp, #24] - ldr x17, [sp, #32] - ldr x18, [sp, #40] - ldr x19, [sp, #48] - ldr x20, [sp, #56] - ldr x21, [sp, #64] - ldr x23, [sp, #72] - - mul x5, x4, x5 - mov x4, #1 - - LoopOc: - - mov x10, x4 - mov x12, x1 - - LoopKsize: - INIT_BIAS - mov x11, x0 - - // as some processors do not support sdot intrinsic, we use instruction word - // dp support is stilled judged dymaticly, instruction word is just used to ensure compilation - // according to https://static.docs.arm.com/ddi0596/g/ISA_A64_xml_v86A-2020-03_OPT.pdf - // the instruction word of sdot vd.4s, vn.16b, vm.4b[index] is - // 0100 1111 10Lm mmmm 1110 H0nn nnnd dddd - // mmmmm/nnnnn/ddddd is the number of neon register, HL is the high/low bit of index - - // load input for output 1-8 - ld1 {v0.16b, v1.16b}, [x12], #32 - // load weight - ld1 {v6.16b}, [x2], #16 - // step for output 1-4 - .inst 0x4f80e0c8 // sdot v8.4s, v6.16b, v0.4b[0] - .inst 0x4fa0e0c9 // sdot v9.4s, v6.16b, v0.4b[1] - .inst 0x4f80e8ca // sdot v10.4s, v6.16b, v0.4b[2] - .inst 0x4fa0e8cb // sdot v11.4s, v6.16b, v0.4b[3] - // load input for output 9-16 - ld1 {v2.16b, v3.16b, v4.16b, v5.16b}, [x12], #64 - // another step for output 5-8 - .inst 0x4f81e0cc // sdot v12.4s, v6.16b, v1.4b[0] - .inst 0x4fa1e0cd // sdot v13.4s, v6.16b, v1.4b[1] - .inst 0x4f81e8ce // sdot v14.4s, v6.16b, v1.4b[2] - .inst 0x4fa1e8cf // sdot v15.4s, v6.16b, v1.4b[3] - - subs x13, x5, #1 - beq LoopIcEndOne - // load weight - ld1 {v7.16b}, [x2], #16 - cmp x13, #1 - beq LoopIcEnd - - LoopIc: - // load input for output 1-8 - ld1 {v0.16b, v1.16b}, [x12], #32 - .inst 0x4f82e0d0 // sdot v16.4s, v6.16b, v2.4b[0] - .inst 0x4fa2e0d1 // sdot v17.4s, v6.16b, 
v2.4b[1] - .inst 0x4f82e8d2 // sdot v18.4s, v6.16b, v2.4b[2] - .inst 0x4fa2e8d3 // sdot v19.4s, v6.16b, v2.4b[3] - .inst 0x4f83e0d4 // sdot v20.4s, v6.16b, v3.4b[0] - .inst 0x4fa3e0d5 // sdot v21.4s, v6.16b, v3.4b[1] - .inst 0x4f83e8d6 // sdot v22.4s, v6.16b, v3.4b[2] - .inst 0x4fa3e8d7 // sdot v23.4s, v6.16b, v3.4b[3] - ld1 {v2.16b, v3.16b}, [x12], #32 - .inst 0x4f84e0d8 // sdot v24.4s, v6.16b, v4.4b[0] - .inst 0x4fa4e0d9 // sdot v25.4s, v6.16b, v4.4b[1] - .inst 0x4f84e8da // sdot v26.4s, v6.16b, v4.4b[2] - .inst 0x4fa4e8db // sdot v27.4s, v6.16b, v4.4b[3] - .inst 0x4f85e0dc // sdot v28.4s, v6.16b, v5.4b[0] - .inst 0x4fa5e0dd // sdot v29.4s, v6.16b, v5.4b[1] - .inst 0x4f85e8de // sdot v30.4s, v6.16b, v5.4b[2] - .inst 0x4fa5e8df // sdot v31.4s, v6.16b, v5.4b[3] - // load input for output 9-16 - ld1 {v4.4s, v5.4s}, [x12], #32 - .inst 0x4f80e0e8 // sdot v8.4s, v7.16b, v0.4b[0] - .inst 0x4fa0e0e9 // sdot v9.4s, v7.16b, v0.4b[1] - .inst 0x4f80e8ea // sdot v10.4s, v7.16b, v0.4b[2] - .inst 0x4fa0e8eb // sdot v11.4s, v7.16b, v0.4b[3] - // another step for output 5-8 - .inst 0x4f81e0ec // sdot v12.4s, v7.16b, v1.4b[0] - .inst 0x4fa1e0ed // sdot v13.4s, v7.16b, v1.4b[1] - .inst 0x4f81e8ee // sdot v14.4s, v7.16b, v1.4b[2] - .inst 0x4fa1e8ef // sdot v15.4s, v7.16b, v1.4b[3] - // load input for output 1-8 - ld1 {v0.16b, v1.16b}, [x12], #32 - .inst 0x4f82e0f0 // sdot v16.4s, v7.16b, v2.4b[0] - .inst 0x4fa2e0f1 // sdot v17.4s, v7.16b, v2.4b[1] - .inst 0x4f82e8f2 // sdot v18.4s, v7.16b, v2.4b[2] - .inst 0x4fa2e8f3 // sdot v19.4s, v7.16b, v2.4b[3] - .inst 0x4f83e0f4 // sdot v20.4s, v7.16b, v3.4b[0] - .inst 0x4fa3e0f5 // sdot v21.4s, v7.16b, v3.4b[1] - .inst 0x4f83e8f6 // sdot v22.4s, v7.16b, v3.4b[2] - .inst 0x4fa3e8f7 // sdot v23.4s, v7.16b, v3.4b[3] - // load weight - ld1 {v6.16b}, [x2], #16 - .inst 0x4f84e0f8 // sdot v24.4s, v7.16b, v4.4b[0] - .inst 0x4fa4e0f9 // sdot v25.4s, v7.16b, v4.4b[1] - .inst 0x4f84e8fa // sdot v26.4s, v7.16b, v4.4b[2] - .inst 0x4fa4e8fb // sdot v27.4s, v7.16b, v4.4b[3] - .inst 0x4f85e0fc // sdot v28.4s, v7.16b, v5.4b[0] - .inst 0x4fa5e0fd // sdot v29.4s, v7.16b, v5.4b[1] - .inst 0x4f85e8fe // sdot v30.4s, v7.16b, v5.4b[2] - .inst 0x4fa5e8ff // sdot v31.4s, v7.16b, v5.4b[3] - // load input for output 9-16 - ld1 {v2.4s, v3.4s}, [x12], #32 - .inst 0x4f80e0c8 // sdot v8.4s, v6.16b, v0.4b[0] - .inst 0x4fa0e0c9 // sdot v9.4s, v6.16b, v0.4b[1] - .inst 0x4f80e8ca // sdot v10.4s, v6.16b, v0.4b[2] - .inst 0x4fa0e8cb // sdot v11.4s, v6.16b, v0.4b[3] - // another step for output 5-8 - .inst 0x4f81e0cc // sdot v12.4s, v6.16b, v1.4b[0] - .inst 0x4fa1e0cd // sdot v13.4s, v6.16b, v1.4b[1] - .inst 0x4f81e8ce // sdot v14.4s, v6.16b, v1.4b[2] - .inst 0x4fa1e8cf // sdot v15.4s, v6.16b, v1.4b[3] - // load input for output 9-16 - ld1 {v4.4s, v5.4s}, [x12], #32 - - subs x13, x13, #2 - beq LoopIcEndOne - // load weight - ld1 {v7.16b}, [x2], #16 - cmp x13, #1 - beq LoopIcEnd - b LoopIc - - LoopIcEnd: - mov x22, x15 - // load input for output 1-8 - ld1 {v0.16b, v1.16b}, [x12], #32 - .inst 0x4f82e0d0 // sdot v16.4s, v6.16b, v2.4b[0] - .inst 0x4fa2e0d1 // sdot v17.4s, v6.16b, v2.4b[1] - .inst 0x4f82e8d2 // sdot v18.4s, v6.16b, v2.4b[2] - .inst 0x4fa2e8d3 // sdot v19.4s, v6.16b, v2.4b[3] - .inst 0x4f83e0d4 // sdot v20.4s, v6.16b, v3.4b[0] - .inst 0x4fa3e0d5 // sdot v21.4s, v6.16b, v3.4b[1] - .inst 0x4f83e8d6 // sdot v22.4s, v6.16b, v3.4b[2] - .inst 0x4fa3e8d7 // sdot v23.4s, v6.16b, v3.4b[3] - ld1 {v2.16b, v3.16b}, [x12], #32 - .inst 0x4f84e0d8 // sdot v24.4s, v6.16b, v4.4b[0] - .inst 0x4fa4e0d9 // sdot 
v25.4s, v6.16b, v4.4b[1] - .inst 0x4f84e8da // sdot v26.4s, v6.16b, v4.4b[2] - .inst 0x4fa4e8db // sdot v27.4s, v6.16b, v4.4b[3] - .inst 0x4f85e0dc // sdot v28.4s, v6.16b, v5.4b[0] - .inst 0x4fa5e0dd // sdot v29.4s, v6.16b, v5.4b[1] - .inst 0x4f85e8de // sdot v30.4s, v6.16b, v5.4b[2] - .inst 0x4fa5e8df // sdot v31.4s, v6.16b, v5.4b[3] - // load input for output 9-16 - ld1 {v4.4s, v5.4s}, [x12], #32 - .inst 0x4f80e0e8 // sdot v8.4s, v7.16b, v0.4b[0] - .inst 0x4fa0e0e9 // sdot v9.4s, v7.16b, v0.4b[1] - .inst 0x4f80e8ea // sdot v10.4s, v7.16b, v0.4b[2] - .inst 0x4fa0e8eb // sdot v11.4s, v7.16b, v0.4b[3] - .inst 0x4f81e0ec // sdot v12.4s, v7.16b, v1.4b[0] - .inst 0x4fa1e0ed // sdot v13.4s, v7.16b, v1.4b[1] - .inst 0x4f81e8ee // sdot v14.4s, v7.16b, v1.4b[2] - .inst 0x4fa1e8ef // sdot v15.4s, v7.16b, v1.4b[3] - - .inst 0x4f82e0f0 // sdot v16.4s, v7.16b, v2.4b[0] - .inst 0x4fa2e0f1 // sdot v17.4s, v7.16b, v2.4b[1] - .inst 0x4f82e8f2 // sdot v18.4s, v7.16b, v2.4b[2] - .inst 0x4fa2e8f3 // sdot v19.4s, v7.16b, v2.4b[3] - .inst 0x4f83e0f4 // sdot v20.4s, v7.16b, v3.4b[0] - .inst 0x4fa3e0f5 // sdot v21.4s, v7.16b, v3.4b[1] - .inst 0x4f83e8f6 // sdot v22.4s, v7.16b, v3.4b[2] - .inst 0x4fa3e8f7 // sdot v23.4s, v7.16b, v3.4b[3] - - .inst 0x4f84e0f8 // sdot v24.4s, v7.16b, v4.4b[0] - .inst 0x4fa4e0f9 // sdot v25.4s, v7.16b, v4.4b[1] - .inst 0x4f84e8fa // sdot v26.4s, v7.16b, v4.4b[2] - .inst 0x4fa4e8fb // sdot v27.4s, v7.16b, v4.4b[3] - .inst 0x4f85e0fc // sdot v28.4s, v7.16b, v5.4b[0] - .inst 0x4fa5e0fd // sdot v29.4s, v7.16b, v5.4b[1] - .inst 0x4f85e8fe // sdot v30.4s, v7.16b, v5.4b[2] - .inst 0x4fa5e8ff // sdot v31.4s, v7.16b, v5.4b[3] - b Quantization - - LoopIcEndOne: - .inst 0x4f82e0d0 // sdot v16.4s, v6.16b, v2.4b[0] - .inst 0x4fa2e0d1 // sdot v17.4s, v6.16b, v2.4b[1] - .inst 0x4f82e8d2 // sdot v18.4s, v6.16b, v2.4b[2] - .inst 0x4fa2e8d3 // sdot v19.4s, v6.16b, v2.4b[3] - .inst 0x4f83e0d4 // sdot v20.4s, v6.16b, v3.4b[0] - .inst 0x4fa3e0d5 // sdot v21.4s, v6.16b, v3.4b[1] - .inst 0x4f83e8d6 // sdot v22.4s, v6.16b, v3.4b[2] - .inst 0x4fa3e8d7 // sdot v23.4s, v6.16b, v3.4b[3] - - .inst 0x4f84e0d8 // sdot v24.4s, v6.16b, v4.4b[0] - .inst 0x4fa4e0d9 // sdot v25.4s, v6.16b, v4.4b[1] - .inst 0x4f84e8da // sdot v26.4s, v6.16b, v4.4b[2] - .inst 0x4fa4e8db // sdot v27.4s, v6.16b, v4.4b[3] - .inst 0x4f85e0dc // sdot v28.4s, v6.16b, v5.4b[0] - .inst 0x4fa5e0dd // sdot v29.4s, v6.16b, v5.4b[1] - .inst 0x4f85e8de // sdot v30.4s, v6.16b, v5.4b[2] - .inst 0x4fa5e8df // sdot v31.4s, v6.16b, v5.4b[3] - - Quantization: - cbnz x21, PerChannel - ld1r {v2.4s}, [x18] - ld1r {v3.4s}, [x17] - ld1r {v4.4s}, [x19] - b QuantizeStart - PerChannel: - ld1 {v2.4s}, [x18] - ld1 {v3.4s}, [x17] - ld1 {v4.4s}, [x19] - QuantizeStart: - sqshl v8.4s, v8.4s, v2.4s - sqshl v9.4s, v9.4s, v2.4s - sqshl v10.4s, v10.4s, v2.4s - sqshl v11.4s, v11.4s, v2.4s - sqshl v12.4s, v12.4s, v2.4s - sqshl v13.4s, v13.4s, v2.4s - sqshl v14.4s, v14.4s, v2.4s - sqshl v15.4s, v15.4s, v2.4s - sqshl v16.4s, v16.4s, v2.4s - sqshl v17.4s, v17.4s, v2.4s - sqshl v18.4s, v18.4s, v2.4s - sqshl v19.4s, v19.4s, v2.4s - sqshl v20.4s, v20.4s, v2.4s - sqshl v21.4s, v21.4s, v2.4s - sqshl v22.4s, v22.4s, v2.4s - sqshl v23.4s, v23.4s, v2.4s - sqshl v24.4s, v24.4s, v2.4s - sqshl v25.4s, v25.4s, v2.4s - sqshl v26.4s, v26.4s, v2.4s - sqshl v27.4s, v27.4s, v2.4s - sqshl v28.4s, v28.4s, v2.4s - sqshl v29.4s, v29.4s, v2.4s - sqshl v30.4s, v30.4s, v2.4s - sqshl v31.4s, v31.4s, v2.4s - - sqrdmulh v8.4s, v8.4s, v3.4s - sqrdmulh v9.4s, v9.4s, v3.4s - sqrdmulh v10.4s, v10.4s, v3.4s - 
sqrdmulh v11.4s, v11.4s, v3.4s - sqrdmulh v12.4s, v12.4s, v3.4s - sqrdmulh v13.4s, v13.4s, v3.4s - sqrdmulh v14.4s, v14.4s, v3.4s - sqrdmulh v15.4s, v15.4s, v3.4s - sqrdmulh v16.4s, v16.4s, v3.4s - sqrdmulh v17.4s, v17.4s, v3.4s - sqrdmulh v18.4s, v18.4s, v3.4s - sqrdmulh v19.4s, v19.4s, v3.4s - sqrdmulh v20.4s, v20.4s, v3.4s - sqrdmulh v21.4s, v21.4s, v3.4s - sqrdmulh v22.4s, v22.4s, v3.4s - sqrdmulh v23.4s, v23.4s, v3.4s - sqrdmulh v24.4s, v24.4s, v3.4s - sqrdmulh v25.4s, v25.4s, v3.4s - sqrdmulh v26.4s, v26.4s, v3.4s - sqrdmulh v27.4s, v27.4s, v3.4s - sqrdmulh v28.4s, v28.4s, v3.4s - sqrdmulh v29.4s, v29.4s, v3.4s - sqrdmulh v30.4s, v30.4s, v3.4s - sqrdmulh v31.4s, v31.4s, v3.4s - - and v0.16b, v4.16b, v8.16b - sshr v0.4s, v0.4s, #31 - sqadd v8.4s, v8.4s, v0.4s - srshl v8.4s, v8.4s, v4.4s - and v1.16b, v4.16b, v9.16b - sshr v1.4s, v1.4s, #31 - sqadd v9.4s, v9.4s, v1.4s - srshl v9.4s, v9.4s, v4.4s - and v2.16b, v4.16b, v10.16b - sshr v2.4s, v2.4s, #31 - sqadd v10.4s, v10.4s, v2.4s - srshl v10.4s, v10.4s, v4.4s - and v3.16b, v4.16b, v11.16b - sshr v3.4s, v3.4s, #31 - sqadd v11.4s, v11.4s, v3.4s - srshl v11.4s, v11.4s, v4.4s - and v0.16b, v4.16b, v12.16b - sshr v0.4s, v0.4s, #31 - sqadd v12.4s, v12.4s, v0.4s - srshl v12.4s, v12.4s, v4.4s - and v1.16b, v4.16b, v13.16b - sshr v1.4s, v1.4s, #31 - sqadd v13.4s, v13.4s, v1.4s - srshl v13.4s, v13.4s, v4.4s - and v2.16b, v4.16b, v14.16b - sshr v2.4s, v2.4s, #31 - sqadd v14.4s, v14.4s, v2.4s - srshl v14.4s, v14.4s, v4.4s - and v3.16b, v4.16b, v15.16b - sshr v3.4s, v3.4s, #31 - sqadd v15.4s, v15.4s, v3.4s - srshl v15.4s, v15.4s, v4.4s - and v0.16b, v4.16b, v16.16b - sshr v0.4s, v0.4s, #31 - sqadd v16.4s, v16.4s, v0.4s - srshl v16.4s, v16.4s, v4.4s - and v1.16b, v4.16b, v17.16b - sshr v1.4s, v1.4s, #31 - sqadd v17.4s, v17.4s, v1.4s - srshl v17.4s, v17.4s, v4.4s - and v2.16b, v4.16b, v18.16b - sshr v2.4s, v2.4s, #31 - sqadd v18.4s, v18.4s, v2.4s - srshl v18.4s, v18.4s, v4.4s - and v3.16b, v4.16b, v19.16b - sshr v3.4s, v3.4s, #31 - sqadd v19.4s, v19.4s, v3.4s - srshl v19.4s, v19.4s, v4.4s - and v0.16b, v4.16b, v20.16b - sshr v0.4s, v0.4s, #31 - sqadd v20.4s, v20.4s, v0.4s - srshl v20.4s, v20.4s, v4.4s - and v1.16b, v4.16b, v21.16b - sshr v1.4s, v1.4s, #31 - sqadd v21.4s, v21.4s, v1.4s - srshl v21.4s, v21.4s, v4.4s - and v2.16b, v4.16b, v22.16b - sshr v2.4s, v2.4s, #31 - sqadd v22.4s, v22.4s, v2.4s - srshl v22.4s, v22.4s, v4.4s - and v3.16b, v4.16b, v23.16b - sshr v3.4s, v3.4s, #31 - sqadd v23.4s, v23.4s, v3.4s - srshl v23.4s, v23.4s, v4.4s - and v0.16b, v4.16b, v24.16b - sshr v0.4s, v0.4s, #31 - sqadd v24.4s, v24.4s, v0.4s - srshl v24.4s, v24.4s, v4.4s - and v1.16b, v4.16b, v25.16b - sshr v1.4s, v1.4s, #31 - sqadd v25.4s, v25.4s, v1.4s - srshl v25.4s, v25.4s, v4.4s - and v2.16b, v4.16b, v26.16b - sshr v2.4s, v2.4s, #31 - sqadd v26.4s, v26.4s, v2.4s - srshl v26.4s, v26.4s, v4.4s - and v3.16b, v4.16b, v27.16b - sshr v3.4s, v3.4s, #31 - sqadd v27.4s, v27.4s, v3.4s - srshl v27.4s, v27.4s, v4.4s - and v0.16b, v4.16b, v28.16b - sshr v0.4s, v0.4s, #31 - sqadd v28.4s, v28.4s, v0.4s - srshl v28.4s, v28.4s, v4.4s - and v1.16b, v4.16b, v29.16b - sshr v1.4s, v1.4s, #31 - sqadd v29.4s, v29.4s, v1.4s - srshl v29.4s, v29.4s, v4.4s - and v2.16b, v4.16b, v30.16b - sshr v2.4s, v2.4s, #31 - sqadd v30.4s, v30.4s, v2.4s - srshl v30.4s, v30.4s, v4.4s - and v3.16b, v4.16b, v31.16b - sshr v3.4s, v3.4s, #31 - sqadd v31.4s, v31.4s, v3.4s - srshl v31.4s, v31.4s, v4.4s - - dup v5.4s, w16 - add v8.4s, v8.4s, v5.4s - add v9.4s, v9.4s, v5.4s - add v10.4s, v10.4s, v5.4s - add 
v11.4s, v11.4s, v5.4s - add v12.4s, v12.4s, v5.4s - add v13.4s, v13.4s, v5.4s - add v14.4s, v14.4s, v5.4s - add v15.4s, v15.4s, v5.4s - add v16.4s, v16.4s, v5.4s - add v17.4s, v17.4s, v5.4s - add v18.4s, v18.4s, v5.4s - add v19.4s, v19.4s, v5.4s - add v20.4s, v20.4s, v5.4s - add v21.4s, v21.4s, v5.4s - add v22.4s, v22.4s, v5.4s - add v23.4s, v23.4s, v5.4s - add v24.4s, v24.4s, v5.4s - add v25.4s, v25.4s, v5.4s - add v26.4s, v26.4s, v5.4s - add v27.4s, v27.4s, v5.4s - add v28.4s, v28.4s, v5.4s - add v29.4s, v29.4s, v5.4s - add v30.4s, v30.4s, v5.4s - add v31.4s, v31.4s, v5.4s - - dup v0.4s, w8 - smax v8.4s, v8.4s, v0.4s - smax v9.4s, v9.4s, v0.4s - smax v10.4s, v10.4s, v0.4s - smax v11.4s, v11.4s, v0.4s - smax v12.4s, v12.4s, v0.4s - smax v13.4s, v13.4s, v0.4s - smax v14.4s, v14.4s, v0.4s - smax v15.4s, v15.4s, v0.4s - smax v16.4s, v16.4s, v0.4s - smax v17.4s, v17.4s, v0.4s - smax v18.4s, v18.4s, v0.4s - smax v19.4s, v19.4s, v0.4s - smax v20.4s, v20.4s, v0.4s - smax v21.4s, v21.4s, v0.4s - smax v22.4s, v22.4s, v0.4s - smax v23.4s, v23.4s, v0.4s - smax v24.4s, v24.4s, v0.4s - smax v25.4s, v25.4s, v0.4s - smax v26.4s, v26.4s, v0.4s - smax v27.4s, v27.4s, v0.4s - smax v28.4s, v28.4s, v0.4s - smax v29.4s, v29.4s, v0.4s - smax v30.4s, v30.4s, v0.4s - smax v31.4s, v31.4s, v0.4s - - dup v1.4s, w9 - smin v8.4s, v8.4s, v1.4s - smin v9.4s, v9.4s, v1.4s - smin v10.4s, v10.4s, v1.4s - smin v11.4s, v11.4s, v1.4s - smin v12.4s, v12.4s, v1.4s - smin v13.4s, v13.4s, v1.4s - smin v14.4s, v14.4s, v1.4s - smin v15.4s, v15.4s, v1.4s - smin v16.4s, v16.4s, v1.4s - smin v17.4s, v17.4s, v1.4s - smin v18.4s, v18.4s, v1.4s - smin v19.4s, v19.4s, v1.4s - smin v20.4s, v20.4s, v1.4s - smin v21.4s, v21.4s, v1.4s - smin v22.4s, v22.4s, v1.4s - smin v23.4s, v23.4s, v1.4s - smin v24.4s, v24.4s, v1.4s - smin v25.4s, v25.4s, v1.4s - smin v26.4s, v26.4s, v1.4s - smin v27.4s, v27.4s, v1.4s - smin v28.4s, v28.4s, v1.4s - smin v29.4s, v29.4s, v1.4s - smin v30.4s, v30.4s, v1.4s - smin v31.4s, v31.4s, v1.4s - - sqxtn v6.4h, v8.4s - sqxtn2 v6.8h, v9.4s - sqxtn v0.8b, v6.8h - sqxtn v7.4h, v10.4s - sqxtn2 v7.8h, v11.4s - sqxtn2 v0.16b, v7.8h - - sqxtn v6.4h, v12.4s - sqxtn2 v6.8h, v13.4s - sqxtn v1.8b, v6.8h - sqxtn v7.4h, v14.4s - sqxtn2 v7.8h, v15.4s - sqxtn2 v1.16b, v7.8h - - sqxtn v6.4h, v16.4s - sqxtn2 v6.8h, v17.4s - sqxtn v2.8b, v6.8h - sqxtn v7.4h, v18.4s - sqxtn2 v7.8h, v19.4s - sqxtn2 v2.16b, v7.8h - - sqxtn v6.4h, v20.4s - sqxtn2 v6.8h, v21.4s - sqxtn v3.8b, v6.8h - sqxtn v7.4h, v22.4s - sqxtn2 v7.8h, v23.4s - sqxtn2 v3.16b, v7.8h - - sqxtn v6.4h, v24.4s - sqxtn2 v6.8h, v25.4s - sqxtn v4.8b, v6.8h - sqxtn v7.4h, v26.4s - sqxtn2 v7.8h, v27.4s - sqxtn2 v4.16b, v7.8h - - sqxtn v6.4h, v28.4s - sqxtn2 v6.8h, v29.4s - sqxtn v5.8b, v6.8h - sqxtn v7.4h, v30.4s - sqxtn2 v7.8h, v31.4s - sqxtn2 v5.16b, v7.8h - // prefetching is not prefered while writing results in spite of cache missings - // you could try prfm pstl2strm - WriteStart: - cmp x6, #1 - beq Write1 - cmp x6, #2 - beq Write2 - cmp x6, #3 - beq Write3 - b Write4 - Write1: - st1 {v0.b}[0], [x11], x7 - st1 {v0.b}[4], [x11], x7 - st1 {v0.b}[8], [x11], x7 - st1 {v0.b}[12], [x11], x7 - st1 {v1.b}[0], [x11], x7 - st1 {v1.b}[4], [x11], x7 - st1 {v1.b}[8], [x11], x7 - st1 {v1.b}[12], [x11], x7 - st1 {v2.b}[0], [x11], x7 - st1 {v2.b}[4], [x11], x7 - st1 {v2.b}[8], [x11], x7 - st1 {v2.b}[12], [x11], x7 - st1 {v3.b}[0], [x11], x7 - st1 {v3.b}[4], [x11], x7 - st1 {v3.b}[8], [x11], x7 - st1 {v3.b}[12], [x11], x7 - st1 {v4.b}[0], [x11], x7 - st1 {v4.b}[4], [x11], x7 - st1 {v4.b}[8], 
[x11], x7 - st1 {v4.b}[12], [x11], x7 - st1 {v5.b}[0], [x11], x7 - st1 {v5.b}[4], [x11], x7 - st1 {v5.b}[8], [x11], x7 - st1 {v5.b}[12], [x11] - add x0, x0, #1 - b WriteEnd - Write2: - st1 {v0.h}[0], [x11], x7 - st1 {v0.h}[2], [x11], x7 - st1 {v0.h}[4], [x11], x7 - st1 {v0.h}[6], [x11], x7 - st1 {v1.h}[0], [x11], x7 - st1 {v1.h}[2], [x11], x7 - st1 {v1.h}[4], [x11], x7 - st1 {v1.h}[6], [x11], x7 - st1 {v2.h}[0], [x11], x7 - st1 {v2.h}[2], [x11], x7 - st1 {v2.h}[4], [x11], x7 - st1 {v2.h}[6], [x11], x7 - st1 {v3.h}[0], [x11], x7 - st1 {v3.h}[2], [x11], x7 - st1 {v3.h}[4], [x11], x7 - st1 {v3.h}[6], [x11], x7 - st1 {v4.h}[0], [x11], x7 - st1 {v4.h}[2], [x11], x7 - st1 {v4.h}[4], [x11], x7 - st1 {v4.h}[6], [x11], x7 - st1 {v5.h}[0], [x11], x7 - st1 {v5.h}[2], [x11], x7 - st1 {v5.h}[4], [x11], x7 - st1 {v5.h}[6], [x11] - add x0, x0, #2 - b WriteEnd - Write3: - add x14, x11, #2 - st1 {v0.h}[0], [x11], x7 - st1 {v0.b}[2], [x14], x7 - st1 {v0.h}[2], [x11], x7 - st1 {v0.b}[6], [x14], x7 - st1 {v0.h}[4], [x11], x7 - st1 {v0.b}[10], [x14], x7 - st1 {v0.h}[6], [x11], x7 - st1 {v0.b}[14], [x14], x7 - st1 {v1.h}[0], [x11], x7 - st1 {v1.b}[2], [x14], x7 - st1 {v1.h}[2], [x11], x7 - st1 {v1.b}[6], [x14], x7 - st1 {v1.h}[4], [x11], x7 - st1 {v1.b}[10], [x14], x7 - st1 {v1.h}[6], [x11], x7 - st1 {v1.b}[14], [x14], x7 - st1 {v2.h}[0], [x11], x7 - st1 {v2.b}[2], [x14], x7 - st1 {v2.h}[2], [x11], x7 - st1 {v2.b}[6], [x14], x7 - st1 {v2.h}[4], [x11], x7 - st1 {v2.b}[10], [x14], x7 - st1 {v2.h}[6], [x11], x7 - st1 {v2.b}[14], [x14], x7 - st1 {v3.h}[0], [x11], x7 - st1 {v3.b}[2], [x14], x7 - st1 {v3.h}[2], [x11], x7 - st1 {v3.b}[6], [x14], x7 - st1 {v3.h}[4], [x11], x7 - st1 {v3.b}[10], [x14], x7 - st1 {v3.h}[6], [x11], x7 - st1 {v3.b}[14], [x14], x7 - st1 {v4.h}[0], [x11], x7 - st1 {v4.b}[2], [x14], x7 - st1 {v4.h}[2], [x11], x7 - st1 {v4.b}[6], [x14], x7 - st1 {v4.h}[4], [x11], x7 - st1 {v4.b}[10], [x14], x7 - st1 {v4.h}[6], [x11], x7 - st1 {v4.b}[14], [x14], x7 - st1 {v5.h}[0], [x11], x7 - st1 {v5.b}[2], [x14], x7 - st1 {v5.h}[2], [x11], x7 - st1 {v5.b}[6], [x14], x7 - st1 {v5.h}[4], [x11], x7 - st1 {v5.b}[10], [x14], x7 - st1 {v5.h}[6], [x11], x7 - st1 {v5.b}[14], [x14], x7 - add x0, x0, #3 - b WriteEnd - Write4: - st1 {v0.s}[0], [x11], x7 - st1 {v0.s}[1], [x11], x7 - st1 {v0.s}[2], [x11], x7 - st1 {v0.s}[3], [x11], x7 - st1 {v1.s}[0], [x11], x7 - st1 {v1.s}[1], [x11], x7 - st1 {v1.s}[2], [x11], x7 - st1 {v1.s}[3], [x11], x7 - st1 {v2.s}[0], [x11], x7 - st1 {v2.s}[1], [x11], x7 - st1 {v2.s}[2], [x11], x7 - st1 {v2.s}[3], [x11], x7 - st1 {v3.s}[0], [x11], x7 - st1 {v3.s}[1], [x11], x7 - st1 {v3.s}[2], [x11], x7 - st1 {v3.s}[3], [x11], x7 - st1 {v4.s}[0], [x11], x7 - st1 {v4.s}[1], [x11], x7 - st1 {v4.s}[2], [x11], x7 - st1 {v4.s}[3], [x11], x7 - st1 {v5.s}[0], [x11], x7 - st1 {v5.s}[1], [x11], x7 - st1 {v5.s}[2], [x11], x7 - st1 {v5.s}[3], [x11] - add x0, x0, #4 - - WriteEnd: - - subs x10, x10, #1 - bne LoopKsize - - subs x6, x6, #4 - cbz x21, NoChannelForward - cbz x20, NoSumForward - add x15, x15, #16 - NoSumForward: - add x17, x17, #16 - add x18, x18, #16 - add x19, x19, #16 - NoChannelForward: - cbz x3, NoStepFowrard - add x3, x3, #16 - NoStepFowrard: - bgt LoopOc - - sub sp, sp, #176 - ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 - ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 - ldp x19, x20, [sp], #16 - ldp x21, x22, [sp], #16 - ldp x23, x24, [sp], #16 - ret -#endif diff --git a/mindspore/lite/nnacl/int8/common_func_int8.h b/mindspore/lite/nnacl/int8/common_func_int8.h index 55b646918e..cd3ed70b02 
100644 --- a/mindspore/lite/nnacl/int8/common_func_int8.h +++ b/mindspore/lite/nnacl/int8/common_func_int8.h @@ -62,10 +62,6 @@ int32x4_t ClacScaledInput(int32x4_t input, int32x4_t left_shift_result_vec, int3 #endif #ifdef ENABLE_ARM32 -void IndirectGemmInt8_2x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize, - size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min, - size_t act_max, size_t out_zp, int32_t *out_multiplier, int32_t *shift_before, - int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset); void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height, int width, int in_kh_step, int in_kw_step, int channel, int8_t in_zp, int32_t out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, int32_t acc_min, @@ -76,10 +72,6 @@ void ConvDw3x3Int8BorderPixel(int8_t *dst, const int8_t *src, const int16_t *wei void PostFuncInt8C4Neon64(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc4div, size_t oc4res, size_t plane, size_t stride, int32_t multiplier, int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini, int32_t maxi); -void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize, - size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min, - size_t act_max, size_t out_zp, int32_t *out_multiplier, int32_t *shift_before, - int32_t *shift_after, size_t asymmetric, size_t per_channel, size_t per_channel_offset); void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias, int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp, int32_t out_zp, int32_t *out_multiplier, int32_t *left_shift, int32_t *right_shift, diff --git a/mindspore/lite/nnacl/int8/conv_int8.c b/mindspore/lite/nnacl/int8/conv_int8.c index 5bfb94fbf6..07dbeeb25f 100644 --- a/mindspore/lite/nnacl/int8/conv_int8.c +++ b/mindspore/lite/nnacl/int8/conv_int8.c @@ -811,38 +811,25 @@ void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int return; } -void Conv1x1Int8Arm32(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, - const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, - int32_t *multiplier, ConvParameter *conv_param) { - int is_per_channel = conv_param->conv_quant_arg_.filter_arg_num_ != 1 ? 
true : false; -#ifdef ENABLE_ARM32 - MatmulInt8Neon32(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias, - conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], - conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift, - conv_param->output_channel_, is_per_channel); -#else - MatMulInt8_4x2_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias, - left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, - conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], - is_per_channel); -#endif - return; -} - void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, - int32_t *multiplier, ConvParameter *conv_param) { + int32_t *multiplier, ConvParameter *conv_param, int32_t *filter_zp) { int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1; -#ifdef ENABLE_ARM64 - MatmulInt8Neon64(packed_input, packed_weight, dst, UP_ROUND(row, C4NUM), UP_ROUND(col, C4NUM), deep16, input_sum, - bias, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], - conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift, row, col, - conv_param->output_channel_, is_per_oc); +#ifdef ENABLE_ARM32 + MatmulInt8Neon32Opt(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift, + conv_param->output_channel_, is_per_oc, filter_zp); +#elif ENABLE_ARM64 + MatmulInt8Neon64Opt(packed_input, packed_weight, dst, UP_ROUND(row, C4NUM), UP_ROUND(col, C4NUM), deep16, input_sum, + bias, conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift, row, + col, conv_param->output_channel_, is_per_oc, filter_zp); #else - MatMulInt8_16x4_r(packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias, - left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_, - conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], - is_per_oc); + MatmulInt8Opt(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias, + conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], + conv_param->conv_quant_arg_.output_quant_args_[0].zp_, multiplier, left_shift, right_shift, + conv_param->output_channel_, is_per_oc, filter_zp); #endif return; } diff --git a/mindspore/lite/nnacl/int8/conv_int8.h b/mindspore/lite/nnacl/int8/conv_int8.h index 2792122002..1f84f86ae3 100644 --- a/mindspore/lite/nnacl/int8/conv_int8.h +++ b/mindspore/lite/nnacl/int8/conv_int8.h @@ -43,13 +43,10 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i size_t plane_size, ConvParameter *conv_param); void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, - int32_t *multiplier, ConvParameter *conv_param); + int32_t *multiplier, 
ConvParameter *conv_param, int32_t *filter_zp); void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift, int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int32_t *filter_zp); -void Conv1x1Int8Arm32(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum, - const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift, - int32_t *multiplier, ConvParameter *conv_param); // int8 convolution 3x3 void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data, diff --git a/mindspore/lite/nnacl/int8/matmul_int8.c b/mindspore/lite/nnacl/int8/matmul_int8.c index fdfcce39c6..abf898a8db 100644 --- a/mindspore/lite/nnacl/int8/matmul_int8.c +++ b/mindspore/lite/nnacl/int8/matmul_int8.c @@ -250,6 +250,41 @@ void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, return; } +#ifndef ENABLE_ARM +void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums, + const int *bias, int mini, int maxi, int out_zp, int32_t *multiplier, int32_t *left_shift, + int32_t *right_shift, int stride, int filter_peroc, int32_t *filter_zp) { + int col_tile = C4NUM; + /* support per-layer && weight per-channel */ + /* row4x16-major * row16x2-major => (int8)row-major*/ + for (int r = 0; r < row; r++) { + for (int c = 0; c < col; c++) { + int r4div = r / C4NUM, r4mod = r % C4NUM; + int c4div = c / col_tile, c4mod = c % col_tile; + size_t ci = r * stride + c; + int32_t value = 0; + for (int d = 0; d < deep16; d++) { + int d16div = d / C16NUM, d16mod = d % C16NUM; + size_t ai = r4div * deep16 * C4NUM + d16div * C4NUM * C16NUM + r4mod * C16NUM + d16mod; + size_t bi = c4div * deep16 * col_tile + d16div * col_tile * C16NUM + c4mod * C16NUM + d16mod; + value = value + a[ai] * b[bi]; + } + int32_t cur_input_sum = filter_peroc ? a_sums[r] * filter_zp[c] : a_sums[r]; + value -= cur_input_sum; + value += bias[c]; + int32_t cur_left_shift = filter_peroc ? left_shift[c] : left_shift[0]; + int32_t cur_right_shift = filter_peroc ? right_shift[c] : right_shift[0]; + int32_t cur_multiplier = filter_peroc ? 
diff --git a/mindspore/lite/nnacl/int8/matmul_int8.h b/mindspore/lite/nnacl/int8/matmul_int8.h
index 2be58e4370..2ea3b303e4 100644
--- a/mindspore/lite/nnacl/int8/matmul_int8.h
+++ b/mindspore/lite/nnacl/int8/matmul_int8.h
@@ -60,6 +60,9 @@ void MatMulInt8_4x16_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row
                        size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
                        int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
                        size_t per_channel, int32_t *filter_zp);
+void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
+                   const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift,
+                   int32_t *right_shift, int stride, int filter_peroc, int32_t *filter_zp);
 
 #ifdef ENABLE_ARM64
 void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
@@ -68,11 +71,18 @@ void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, i
 
 void MatMulR4Int8Neon64(const int8_t *a, const int8_t *b, int32_t *dst, int row4, int col4, int deep16,
                         const int *input_sum, const int *bias);
+void MatmulInt8Neon64Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16,
+                         const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier,
+                         int32_t *left_shift, int32_t *right_shift, int row, int col, int stride, int filter_peroc,
+                         int32_t *filter_zp);
 #endif
 #ifdef ENABLE_ARM32
 void MatmulInt8Neon32(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16,
                       const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp,
                       int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel);
+void MatmulInt8Neon32Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, const int *a_sums,
+                         const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier,
+                         int32_t *left_shift, int32_t *right_shift, int stride, int filter_peroc, int32_t *filter_zp);
 #endif
 #ifdef __cplusplus
 }
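Both the C fallback and the assembly assume the left-hand operand has already been repacked into 4-row by 16-depth tiles; that is the layout the ai index arithmetic in MatmulInt8Opt walks, and what RowMajor2Row16x4MajorInt8 produces before the kernels below are invoked. The following is a simplified reference of that packing, assuming the destination has been zero-filled to UP_ROUND(row, 4) * deep16 bytes; the real packer in nnacl's pack code is more involved, so treat this as an illustration of the layout only.

#include <stdint.h>

/* Pack row-major a[row][deep] into 4x16 tiles: tile-major over (row/4, deep/16),
 * row index inside the tile, then depth index inside the tile.  Mirrors the formula
 *   ai = (r/4) * deep16 * 4 + (d/16) * 4 * 16 + (r%4) * 16 + (d%16)                 */
static void PackRowMajorToRow4x16Ref(const int8_t *src, int8_t *dst, int row, int deep, int deep16) {
  for (int r = 0; r < row; r++) {
    for (int d = 0; d < deep; d++) {
      int dst_idx = (r / 4) * deep16 * 4 + (d / 16) * 4 * 16 + (r % 4) * 16 + (d % 16);
      dst[dst_idx] = src[r * deep + d];
    }
  }
}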
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc
index 4e459545c8..dda5957e87 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.cc
@@ -92,27 +92,19 @@ int Convolution1x1Int8OcOptPre(void *cdata, int task_id) {
 }
 
 int Convolution1x1Int8CPUKernel::OcRun(int task_id) {
-#ifdef ENABLE_ARM32
-  return RunArm32Oc(task_id);
-#else
   if (support_optimize_) {
    return RunArm64OptOc(task_id);
   } else {
-    return RunArm64Oc(task_id);
+    return RunArmOc(task_id);
   }
-#endif
 }
 
 int Convolution1x1Int8CPUKernel::HwRun(int task_id) {
-#ifdef ENABLE_ARM32
-  return RunArm32Hw(task_id);
-#else
   if (support_optimize_) {
    return RunArm64OptHw(task_id);
   } else {
-    return RunArm64Hw(task_id);
+    return RunArmHw(task_id);
   }
-#endif
 }
 
 int Convolution1x1Int8CPUKernel::InitRunBuf() {
@@ -124,6 +116,7 @@ int Convolution1x1Int8CPUKernel::InitRunBuf() {
 
   size_t size = support_optimize_ ? UP_ROUND(matmul_param_->row_, C8NUM) * UP_ROUND(matmul_param_->deep_, C4NUM)
                                   : UP_ROUND(matmul_param_->row_, C4NUM) * UP_ROUND(matmul_param_->deep_, C16NUM);
+
   packed_input_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(size * sizeof(int8_t)));
   if (packed_input_ == nullptr) {
     MS_LOG(ERROR) << "conv1x1 int8 Malloc packed_input_ error!";
@@ -333,8 +326,8 @@ int Convolution1x1Int8CPUKernel::InitParam() {
   matmul_param_->deep_4_ = UP_ROUND(matmul_param_->deep_, C4NUM);
   matmul_param_->deep_16_ = UP_ROUND(matmul_param_->deep_, C16NUM);
 
-  int row_pack_count = 0;
-  int col_pack_count = 0;
+  int row_pack_count;
+  int col_pack_count;
 #ifdef ENABLE_ARM32
   row_pack_count = C4NUM;
@@ -350,15 +343,7 @@ int Convolution1x1Int8CPUKernel::InitParam() {
 #endif
 
   /* init input sum size */
-  if (support_optimize_) {
-    input_sum_size_ = UP_ROUND(matmul_param_->row_, row_pack_count);
-  } else {
-    if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
-      input_sum_size_ = UP_ROUND(matmul_param_->col_, col_pack_count) * UP_ROUND(matmul_param_->row_, row_pack_count);
-    } else {
-      input_sum_size_ = UP_ROUND(matmul_param_->row_, row_pack_count);
-    }
-  }
+  input_sum_size_ = UP_ROUND(matmul_param_->row_, row_pack_count);
 
   if (pre_trans_input_) {
     input_ptr_ = reinterpret_cast<int8_t *>(malloc(matmul_param_->row_ * matmul_param_->deep_ * sizeof(int8_t)));
@@ -404,7 +389,7 @@ void Convolution1x1Int8CPUKernel::Pre1x1Trans(int8_t *src_input, int8_t *src_out
   return;
 }
 
-int Convolution1x1Int8CPUKernel::RunArm64Hw(int task_id) {
+int Convolution1x1Int8CPUKernel::RunArmHw(int task_id) {
   int cur_stride = thread_stride_hw_ * C4NUM;
   int res_stride = matmul_param_->row_ - task_id * thread_stride_hw_ * C4NUM;
   int cur_hw = MSMIN(cur_stride, res_stride);
@@ -415,51 +400,20 @@ int Convolution1x1Int8CPUKernel::RunArm64Hw(int task_id) {
   int8_t *hw_in = input_ptr_ + task_id * thread_stride_hw_ * C4NUM * conv_param_->input_channel_;
   int8_t *hw_out = output_ptr_ + task_id * thread_stride_hw_ * C4NUM * conv_param_->output_channel_;
   int8_t *hw_packed_in = packed_input_ + task_id * thread_stride_hw_ * C4NUM * matmul_param_->deep_16_;
-  int32_t *hw_input_sum = filter_peroc_ ? input_sum_ + task_id * thread_stride_hw_ * C4NUM * matmul_param_->col_4_
-                                        : input_sum_ + task_id * thread_stride_hw_ * C4NUM;
+  int32_t *hw_input_sum = input_sum_ + task_id * thread_stride_hw_ * C4NUM;
 
   RowMajor2Row16x4MajorInt8(hw_in, hw_packed_in, cur_hw, matmul_param_->deep_);
 
   if (filter_peroc_) {
-    PackInputSum16x4PerChannel(hw_packed_in, hw_input_sum, filter_zp_ptr_, cur_hw, matmul_param_->deep_,
-                               matmul_param_->col_);
+    PackInputSum16x4PerLayer(hw_packed_in, hw_input_sum, 1, UP_ROUND(cur_hw, C4NUM), matmul_param_->deep_16_);
   } else {
     PackInputSum16x4PerLayer(hw_packed_in, hw_input_sum, conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_,
                              UP_ROUND(cur_hw, C4NUM), matmul_param_->deep_16_);
   }
 
   Conv1x1Int8(hw_packed_in, packed_weight_, hw_out, hw_input_sum, reinterpret_cast<int32_t *>(bias_data_), cur_hw,
-              matmul_param_->col_, matmul_param_->deep_16_, left_shift_, right_shift_, multiplier_, conv_param_);
-  return RET_OK;
-}
-
-int Convolution1x1Int8CPUKernel::RunArm32Hw(int task_id) {
-  int cur_stride = thread_stride_hw_ * C4NUM;
-  int res_stride = matmul_param_->row_ - task_id * thread_stride_hw_ * C4NUM;
-  int cur_hw = MSMIN(cur_stride, res_stride);
-  if (cur_hw <= 0) {
-    return RET_OK;
-  }
-
-  int8_t *hw_in = input_ptr_ + task_id * thread_stride_hw_ * C4NUM * conv_param_->input_channel_;
-  int8_t *hw_out = output_ptr_ + task_id * thread_stride_hw_ * C4NUM * conv_param_->output_channel_;
-  int8_t *hw_packed_in = packed_input_ + task_id * thread_stride_hw_ * C4NUM * matmul_param_->deep_16_;
-  int32_t *hw_input_sum = filter_peroc_ ? input_sum_ + task_id * thread_stride_hw_ * C4NUM * matmul_param_->col_2_
-                                        : input_sum_ + task_id * thread_stride_hw_ * C4NUM;
-
-  RowMajor2Row16x4MajorInt8(hw_in, hw_packed_in, cur_hw, matmul_param_->deep_);
-
-  if (filter_peroc_) {
-    PackInputSum16x4PerChannelArm32(hw_packed_in, hw_input_sum, filter_zp_ptr_, cur_hw, conv_param_->input_channel_,
-                                    conv_param_->output_channel_);
-  } else {
-    PackInputSum16x4PerLayer(hw_packed_in, hw_input_sum, conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_,
-                             UP_ROUND(cur_hw, C4NUM), matmul_param_->deep_16_);
-  }
-
-  Conv1x1Int8Arm32(hw_packed_in, packed_weight_, hw_out, hw_input_sum, reinterpret_cast<int32_t *>(bias_data_), cur_hw,
-                   matmul_param_->col_, matmul_param_->deep_16_, left_shift_, right_shift_, multiplier_, conv_param_);
-
+              matmul_param_->col_, matmul_param_->deep_16_, left_shift_, right_shift_, multiplier_, conv_param_,
+              filter_zp_ptr_);
   return RET_OK;
 }
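With the per-channel special cases gone, hw_input_sum always holds one 32-bit sum per (padded) row. The snippet below is a rough reference for what PackInputSum16x4PerLayer is assumed to compute over a 4x16-packed buffer: each entry is that row's sum across deep16, scaled by the filter zero point; the per-channel branch passes 1 so the kernels can apply filter_zp[c] per output channel instead. This is an assumption about the packer's contract for illustration, not its actual code.

#include <stdint.h>

/* Reference only: per-row input sums over a 4x16-packed buffer (row4 = UP_ROUND(row, 4)). */
static void PackInputSumPerLayerRef(const int8_t *packed_a, int32_t *dst, int32_t filter_zp, int row4, int deep16) {
  for (int r = 0; r < row4; r++) {
    int32_t sum = 0;
    for (int d = 0; d < deep16; d++) {
      int idx = (r / 4) * deep16 * 4 + (d / 16) * 4 * 16 + (r % 4) * 16 + (d % 16);
      sum += packed_a[idx];
    }
    dst[r] = sum * filter_zp;  /* per-channel path: filter_zp == 1, scaling happens inside the matmul */
  }
}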
@@ -489,26 +443,6 @@ int Convolution1x1Int8CPUKernel::RunArm64OptHw(int task_id) {
   return RET_OK;
 }
 
-int Convolution1x1Int8CPUKernel::RunArm32Oc(int task_id) {
-  int stride = thread_stride_oc_ * C2NUM;
-  int cur_stride = task_id * stride;
-  int res_stride = matmul_param_->col_ - cur_stride;
-  int cur_oc = MSMIN(stride, res_stride);
-  if (cur_oc <= 0) {
-    return RET_OK;
-  }
-
-  int32_t *cur_input_sum = filter_peroc_ ? input_sum_ + cur_stride * matmul_param_->row_4_ : input_sum_;
-  int32_t *cur_left_shift = filter_peroc_ ? left_shift_ + cur_stride : conv_param_->conv_quant_arg_.left_shift_;
-  int32_t *cur_right_shift = filter_peroc_ ? right_shift_ + cur_stride : conv_param_->conv_quant_arg_.right_shift_;
-  int32_t *cur_multiplier = filter_peroc_ ? multiplier_ + cur_stride : conv_param_->conv_quant_arg_.quant_multiplier_;
-
-  Conv1x1Int8Arm32(packed_input_, packed_weight_ + cur_stride * matmul_param_->deep_16_, output_ptr_ + cur_stride,
-                   cur_input_sum, reinterpret_cast<int32_t *>(bias_data_) + cur_stride, matmul_param_->row_, cur_oc,
-                   matmul_param_->deep_16_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_);
-  return RET_OK;
-}
-
 int Convolution1x1Int8CPUKernel::RunArm64OptOc(int task_id) {
   int stride = thread_stride_oc_ * C16NUM;
   int cur_stride = task_id * stride;
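RunArm64OptOc above and the merged RunArmOc below slice the output channels the same way; only the tile width differs (C16NUM for the optimized ARM64 path, C2NUM on ARM32, C4NUM otherwise). A standalone illustration of that split with made-up sizes:

#include <stdio.h>

static int MinInt(int a, int b) { return a < b ? a : b; }

int main(void) {
  const int col = 37;              /* output channels         */
  const int thread_stride_oc = 2;  /* tiles per task          */
  const int col_tile = 4;          /* C4NUM-style tile width  */
  for (int task_id = 0; task_id < 6; task_id++) {
    int stride = thread_stride_oc * col_tile;
    int cur_stride = task_id * stride;
    int cur_oc = MinInt(stride, col - cur_stride);
    if (cur_oc <= 0) {
      continue;  /* the kernel returns RET_OK here */
    }
    printf("task %d handles channels [%d, %d)\n", task_id, cur_stride, cur_stride + cur_oc);
  }
  return 0;
}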
@@ -531,8 +465,13 @@ int Convolution1x1Int8CPUKernel::RunArm64OptOc(int task_id) {
   return RET_OK;
 }
 
-int Convolution1x1Int8CPUKernel::RunArm64Oc(int task_id) {
-  int stride = thread_stride_oc_ * C4NUM;
+int Convolution1x1Int8CPUKernel::RunArmOc(int task_id) {
+#ifdef ENABLE_ARM32
+  int col_tile = C2NUM;
+#else
+  int col_tile = C4NUM;
+#endif
+  int stride = thread_stride_oc_ * col_tile;
   int cur_stride = task_id * stride;
   int res_stride = matmul_param_->col_ - cur_stride;
   int cur_oc = MSMIN(stride, res_stride);
@@ -540,14 +479,14 @@ int Convolution1x1Int8CPUKernel::RunArm64Oc(int task_id) {
     return RET_OK;
   }
 
-  int32_t *cur_input_sum = filter_peroc_ ? input_sum_ + cur_stride * matmul_param_->row_4_ : input_sum_;
   int32_t *cur_left_shift = filter_peroc_ ? left_shift_ + cur_stride : conv_param_->conv_quant_arg_.left_shift_;
   int32_t *cur_right_shift = filter_peroc_ ? right_shift_ + cur_stride : conv_param_->conv_quant_arg_.right_shift_;
   int32_t *cur_multiplier = filter_peroc_ ? multiplier_ + cur_stride : conv_param_->conv_quant_arg_.quant_multiplier_;
+  int32_t *cur_zp = filter_peroc_ ? filter_zp_ptr_ + cur_stride : filter_zp_ptr_;
 
   Conv1x1Int8(packed_input_, packed_weight_ + cur_stride * matmul_param_->deep_16_, output_ptr_ + cur_stride,
-              cur_input_sum, reinterpret_cast<int32_t *>(bias_data_) + cur_stride, matmul_param_->row_, cur_oc,
-              matmul_param_->deep_16_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_);
+              input_sum_, reinterpret_cast<int32_t *>(bias_data_) + cur_stride, matmul_param_->row_, cur_oc,
+              matmul_param_->deep_16_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_, cur_zp);
   return RET_OK;
 }
 
@@ -592,7 +531,12 @@ int Convolution1x1Int8CPUKernel::Run() {
     ParallelLaunch(this->context_->thread_pool_, Convolution1x1Int8OcOptPre, this, thread_count_hw_);
   } else {
     RowMajor2Row16x4MajorInt8(input_ptr_, packed_input_, matmul_param_->row_, matmul_param_->deep_);
-    PackInputSum16x4Int8(packed_input_, input_sum_, filter_zp_ptr_, conv_param_);
+    if (filter_peroc_) {
+      PackInputSum16x4PerLayer(packed_input_, input_sum_, 1, matmul_param_->row_4_, matmul_param_->deep_16_);
+    } else {
+      PackInputSum16x4PerLayer(packed_input_, input_sum_, conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_,
+                               matmul_param_->row_4_, matmul_param_->deep_16_);
+    }
   }
   /* matmul parallel by oc */
   error_code = ParallelLaunch(this->context_->thread_pool_, Convolution1x1Int8OcRun, this, thread_count_oc_);
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h
index 8ae9c41464..3eda1e9e7a 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/convolution_1x1_int8.h
@@ -50,11 +50,9 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel {
   int OcOptPre(int task_id);
 
  private:
-  int RunArm32Oc(int task_id);
-  int RunArm64Oc(int task_id);
+  int RunArmOc(int task_id);
   int RunArm64OptOc(int task_id);
-  int RunArm32Hw(int task_id);
-  int RunArm64Hw(int task_id);
+  int RunArmHw(int task_id);
   int RunArm64OptHw(int task_id);
 
  private:
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/opt_op_handler.cc b/mindspore/lite/src/runtime/kernel/arm/int8/opt_op_handler.cc
index ce8ff62226..eee84c75b7 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/opt_op_handler.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/opt_op_handler.cc
@@ -21,11 +21,6 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
-extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
-                                     size_t ksize, size_t ic4, size_t output_channel, size_t offset,
-                                     const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
-                                     int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after,
-                                     size_t asymmetric, size_t per_channel, size_t per_channel_offset);
 extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
                                   const int *input_sum, const int *bias);
 
@@ -38,15 +33,6 @@ extern void MatmulInt8DpOpt(const int8_t *a, const int8_t *b, int8_t *dst, size_
                             int *left_shift, int *right_shift, size_t stride, size_t peroc, int *filter_zp);
 
 #ifdef ENABLE_ARM64
-void IndirectGemmInt8_optimize_handler(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
-                                       size_t ksize, size_t ic4, size_t output_channel, size_t offset,
-                                       const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
-                                       int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after,
-                                       size_t asymmetric, size_t per_channel, size_t per_channel_offset) {
-  return IndirectGemmInt8_24x4_dp(dst, src, weight, bias, ksize, ic4, output_channel, offset, input_sum, act_min,
-                                  act_max, out_zp, out_multiplier, shift_before, shift_after, asymmetric, per_channel,
-                                  per_channel_offset);
-}
 void MatMulR4Int8_optimize_handler(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
                                    const int *input_sum, const int *bias) {