diff --git a/mindspore/lite/nnacl/assembly/arm32/MatmulInt8Opt.S b/mindspore/lite/nnacl/assembly/arm32/MatmulInt8Opt.S index a46d615b5b..58edd8e9fe 100644 --- a/mindspore/lite/nnacl/assembly/arm32/MatmulInt8Opt.S +++ b/mindspore/lite/nnacl/assembly/arm32/MatmulInt8Opt.S @@ -8,10 +8,10 @@ .type MatmulInt8Neon32Opt, %function #endif -//void MatmulInt8Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, -// const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp, -// int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel, -// int *filter_zp); +//void MatmulInt8Neon32Opt(const int8_t *a, const int8_t *b, int8_t *dst, int row, int col, int deep16, +// const int *input_sums, const int *weight_bias, int act_min, int act_max, int out_zp, +// int *multiplier, int *left_shift, int *right_shift, int stride, int per_channel, +// int *filter_zp); // #-52: a, #-48: b, #-44: dst, #-40: row // #0: col, #4: deep16, #8: input_sums, #12: weight_bias, #16: act_min, #20: act_max, #24: out_zp // #28: multiplier, #32: left_shift, #36: right_shift, #40: stride, #44: per_channel, #48: filter_zp @@ -21,6 +21,12 @@ MatmulInt8Neon32Opt: vpush {q4-q7} add sp, sp, #116 + ldr r0, [sp, #-52] // load a ptr + vld1.8 {d0, d1, d2, d3}, [r0]! + + ldr r1, [sp, #-48] // load b ptr + vld1.8 {d8, d9, d10, d11}, [r1]! + ldr r4, [sp] // col ldr r7, [sp, #40] // output stride mov r8, #0 // output channels offset @@ -32,15 +38,14 @@ L1: cmp r4, #0 // if at the end of col ble End1 - ldr r0, [sp, #-52] // reload a ptr ldr r3, [sp, #-40] // reset row counter ldr r6, [sp, #8] // reload intpu_sums ptr if per_tensor L2: cmp r3, #0 // if at the end of row ble End2 - ldr r1, [sp, #-48] // reload b ptr ldr r5, [sp, #4] // reset deep16 + sub r5, r5, #16 vmov.i32 q6, #0 vmov.i32 q7, #0 vmov.i32 q8, #0 @@ -49,12 +54,44 @@ L2: vmov.i32 q11, #0 vmov.i32 q12, #0 vmov.i32 q13, #0 -L3: + cmp r5, #0 - beq End3 + beq L3Tail +L3: + vmull.s8 q14, d0, d8 + vmull.s8 q2, d0, d10 + vmull.s8 q15, d2, d8 + vmull.s8 q3, d2, d10 + vmlal.s8 q14, d1, d9 + vmlal.s8 q2, d1, d11 + vmlal.s8 q15, d3, d9 + vmlal.s8 q3, d3, d11 + vld1.8 {d0, d1, d2, d3}, [r0]! + + vpadal.s16 q6, q14 + vpadal.s16 q7, q2 + vpadal.s16 q8, q15 + vpadal.s16 q9, q3 + vmull.s8 q14, d0, d8 + vmull.s8 q2, d0, d10 + vmull.s8 q15, d2, d8 + vmull.s8 q3, d2, d10 + vmlal.s8 q14, d1, d9 + vmlal.s8 q2, d1, d11 + vmlal.s8 q15, d3, d9 + vmlal.s8 q3, d3, d11 vld1.8 {d0, d1, d2, d3}, [r0]! + + vpadal.s16 q10, q14 vld1.8 {d8, d9, d10, d11}, [r1]! + vpadal.s16 q11, q2 + vpadal.s16 q12, q15 + vpadal.s16 q13, q3 + sub r5, r5, #16 // deep16 -= 16 + cmp r5, #0 + bgt L3 +L3Tail: vmull.s8 q14, d0, d8 vmull.s8 q2, d0, d10 vmull.s8 q15, d2, d8 @@ -63,13 +100,13 @@ L3: vmlal.s8 q2, d1, d11 vmlal.s8 q15, d3, d9 vmlal.s8 q3, d3, d11 + vld1.8 {d0, d1, d2, d3}, [r0]! vpadal.s16 q6, q14 vpadal.s16 q7, q2 vpadal.s16 q8, q15 vpadal.s16 q9, q3 - vld1.8 {d0, d1, d2, d3}, [r0]! vmull.s8 q14, d0, d8 vmull.s8 q2, d0, d10 vmull.s8 q15, d2, d8 @@ -83,10 +120,7 @@ L3: vpadal.s16 q11, q2 vpadal.s16 q12, q15 vpadal.s16 q13, q3 - sub r5, r5, #16 // deep16 -= 16 - b L3 -End3: vpadd.i32 d0, d12, d13 vpadd.i32 d1, d14, d15 vpadd.i32 d2, d16, d17 @@ -101,7 +135,26 @@ End3: vpadd.i32 d30, d4, d5 vpadd.i32 d31, d6, d7 - // Add weight_bias + cmp r3, #4 + ble LAST_ROW + + vld1.8 {d0, d1, d2, d3}, [r0]! + ldr r1, [sp, #-48] // reload b ptr + vld1.8 {d8, d9, d10, d11}, [r1]! + b AddWeightBias + +LAST_ROW: + ldr r0, [sp, #-52] // reload a ptr + vld1.8 {d0, d1, d2, d3}, [r0]! + ldr r1, [sp, #-48] // reload b ptr + ldr r9, [sp, #4] + mov r10, #2 + mul r9, r9, r10 // the stride of b + add r1, r1, r9 // b ptr + stride + str r1, [sp, #-48] + vld1.8 {d8, d9, d10, d11}, [r1]! + +AddWeightBias: ldr r9, [sp, #12] // reload weight_bias ptr add r9, r9, r8 vld1.32 {d26}, [r9]! @@ -148,9 +201,9 @@ PerTensor: vshr.s32 q6, q6, #31 vqadd.s32 q14, q14, q6 vrshl.s32 q14, q14, q7 - vand q5, q7, q15 - vshr.s32 q5, q5, #31 - vqadd.s32 q15, q15, q5 + vand q3, q7, q15 + vshr.s32 q3, q3, #31 + vqadd.s32 q15, q15, q3 vrshl.s32 q15, q15, q7 b AddDstZP @@ -214,9 +267,9 @@ PerChannel: AddDstZP: // Add the destination zero point ldr r10, [sp, #24] - vdup.32 q4, r10 - vadd.i32 q14, q14, q4 - vadd.i32 q15, q15, q4 + vdup.32 q2, r10 + vadd.i32 q14, q14, q2 + vadd.i32 q15, q15, q2 // Apply the act_min bound ldr r10, [sp, #16] @@ -276,12 +329,6 @@ EndWrite: End2: sub r4, r4, #2 // b col counter -= 2 - ldr r1, [sp, #-48] // load b ptr - ldr r9, [sp, #4] - mov r10, #2 - mul r9, r9, r10 // the stride of b - add r1, r1, r9 // b ptr + stride - str r1, [sp, #-48] ldr r2, [sp, #-44] // load dst ptr add r2, r2, #2 // dst ptr + offset str r2, [sp, #-44]