diff --git a/mindspore/lite/nnacl/assembly/arm32/TiledC4MatmulFp32.S b/mindspore/lite/nnacl/assembly/arm32/TiledC4MatmulFp32.S new file mode 100644 index 0000000000..239ef022bb --- /dev/null +++ b/mindspore/lite/nnacl/assembly/arm32/TiledC4MatmulFp32.S @@ -0,0 +1,198 @@ +#ifdef ENABLE_ARM32 + .text + .align 5 + .global TiledC4MatmulFp32 +#ifndef __APPLE__ + .type TiledC4MatmulFp32, %function +#endif + +TiledC4MatmulFp32: +//void TiledC4MatmulFp32(float* dst, const float* src, const float* weight, size_t cal_num, size_t ic4, size_t oc4) +//x0: dst +//x1: src +//x2: weight +//x3: cal_num +//x4: ic4 +//x5: oc4 + +push {r4-r8, lr} +ldr r4, [sp, #24] +ldr r5, [sp, #28] +//step multi by sizeof(float) +mov r8, #4 +mul r3, r8, r3 + +vpush {q4-q7} + +LoopOc: + mov r6, r1 + mov r8, r0 + subs r7, r4, #1 + vld1.32 {q0, q1}, [r1]! + vld1.32 {q2, q3}, [r1]! + vld1.32 {q4, q5}, [r2]! + vld1.32 {q6, q7}, [r2]! + + vmul.f32 q8, q4, d0[0] + vmul.f32 q9, q4, d2[0] + vmul.f32 q10, q4, d4[0] + vmul.f32 q11, q4, d6[0] + + vmla.f32 q8, q5, d0[1] + vmla.f32 q9, q5, d2[1] + vmla.f32 q10, q5, d4[1] + vmla.f32 q11, q5, d6[1] + + vmla.f32 q8, q6, d1[0] + vmla.f32 q9, q6, d3[0] + vmla.f32 q10, q6, d5[0] + vmla.f32 q11, q6, d7[0] + + vmla.f32 q8, q7, d1[1] + vmla.f32 q9, q7, d3[1] + vmla.f32 q10, q7, d5[1] + vmla.f32 q11, q7, d7[1] + + vld1.32 {q0, q1}, [r1]! + vld1.32 {q2, q3}, [r1]! + + vmul.f32 q12, q4, d0[0] + vmul.f32 q13, q4, d2[0] + vmul.f32 q14, q4, d4[0] + vmul.f32 q15, q4, d6[0] + + vmla.f32 q12, q5, d0[1] + vmla.f32 q13, q5, d2[1] + vmla.f32 q14, q5, d4[1] + vmla.f32 q15, q5, d6[1] + + vmla.f32 q12, q6, d1[0] + vmla.f32 q13, q6, d3[0] + vmla.f32 q14, q6, d5[0] + vmla.f32 q15, q6, d7[0] + + vmla.f32 q12, q7, d1[1] + vmla.f32 q13, q7, d3[1] + vmla.f32 q14, q7, d5[1] + vmla.f32 q15, q7, d7[1] + beq LoopIcEnd + + subs r7, r7, #1 + + vld1.32 {q4, q5}, [r2]! + vld1.32 {q0, q1}, [r1]! + vld1.32 {q2, q3}, [r1]! + + vmla.f32 q8, q4, d0[0] + vmla.f32 q9, q4, d2[0] + beq LoopIcEndHalf + + LoopIc: + vmla.f32 q10, q4, d4[0] + vmla.f32 q11, q4, d6[0] + + vmla.f32 q8, q5, d0[1] + vmla.f32 q9, q5, d2[1] + vld1.32 {q6, q7}, [r2]! + vmla.f32 q10, q5, d4[1] + vmla.f32 q11, q5, d6[1] + + vmla.f32 q8, q6, d1[0] + vmla.f32 q9, q6, d3[0] + vmla.f32 q10, q6, d5[0] + vmla.f32 q11, q6, d7[0] + + vmla.f32 q8, q7, d1[1] + vmla.f32 q9, q7, d3[1] + vmla.f32 q10, q7, d5[1] + vld1.32 {q0, q1}, [r1]! + vmla.f32 q11, q7, d7[1] + + vld1.32 {q2, q3}, [r1]! + + vmla.f32 q12, q4, d0[0] + vmla.f32 q13, q4, d2[0] + vmla.f32 q14, q4, d4[0] + vmla.f32 q15, q4, d6[0] + + vmla.f32 q12, q5, d0[1] + vmla.f32 q13, q5, d2[1] + vmla.f32 q14, q5, d4[1] + vmla.f32 q15, q5, d6[1] + + vmla.f32 q12, q6, d1[0] + vmla.f32 q13, q6, d3[0] + vmla.f32 q14, q6, d5[0] + vld1.32 {q4, q5}, [r2]! + vmla.f32 q15, q6, d7[0] + + vmla.f32 q12, q7, d1[1] + vmla.f32 q13, q7, d3[1] + vmla.f32 q14, q7, d5[1] + vld1.32 {q0, q1}, [r1]! + vmla.f32 q15, q7, d7[1] + + vld1.32 {q2, q3}, [r1]! + + vmla.f32 q8, q4, d0[0] + vmla.f32 q9, q4, d2[0] + + subs r7, r7, #1 + bne LoopIc + LoopIcEndHalf: + vmla.f32 q10, q4, d4[0] + vmla.f32 q11, q4, d6[0] + + vmla.f32 q8, q5, d0[1] + vmla.f32 q9, q5, d2[1] + vld1.32 {q6, q7}, [r2]! + vmla.f32 q10, q5, d4[1] + vmla.f32 q11, q5, d6[1] + + vmla.f32 q8, q6, d1[0] + vmla.f32 q9, q6, d3[0] + vmla.f32 q10, q6, d5[0] + vmla.f32 q11, q6, d7[0] + + vmla.f32 q8, q7, d1[1] + vmla.f32 q9, q7, d3[1] + vmla.f32 q10, q7, d5[1] + vld1.32 {q0, q1}, [r1]! + vmla.f32 q11, q7, d7[1] + + vld1.32 {q2, q3}, [r1]! + + vmla.f32 q12, q4, d0[0] + vmla.f32 q13, q4, d2[0] + vmla.f32 q14, q4, d4[0] + vmla.f32 q15, q4, d6[0] + + vmla.f32 q12, q5, d0[1] + vmla.f32 q13, q5, d2[1] + vmla.f32 q14, q5, d4[1] + vmla.f32 q15, q5, d6[1] + + vmla.f32 q12, q6, d1[0] + vmla.f32 q13, q6, d3[0] + vmla.f32 q14, q6, d5[0] + vmla.f32 q15, q6, d7[0] + + vmla.f32 q12, q7, d1[1] + vmla.f32 q13, q7, d3[1] + vmla.f32 q14, q7, d5[1] + vmla.f32 q15, q7, d7[1] + LoopIcEnd: + vst1.32 {q8, q9}, [r0]! + vst1.32 {q10, q11}, [r0]! + vst1.32 {q12, q13}, [r0]! + vst1.32 {q14, q15}, [r0]! + mov r1, r6 + + subs r5, r5, #1 + add r0, r8, r3 + bne LoopOc + + vpop {q4-q7} + pop {r4-r8, pc} + +#endif diff --git a/mindspore/lite/nnacl/assembly/arm32/WinogradTransLeft.S b/mindspore/lite/nnacl/assembly/arm32/WinogradTransLeft.S new file mode 100644 index 0000000000..3ca05a5583 --- /dev/null +++ b/mindspore/lite/nnacl/assembly/arm32/WinogradTransLeft.S @@ -0,0 +1,218 @@ +#ifdef ENABLE_ARM32 + + .text + .align 5 + .global WinogradTransLeft +#ifndef __APPLE__ + .type WinogradTransLeft, %function +#endif + +//void WinogradTransLeft(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length); +//x0: S +//x1: B +//x2: M +//x3: w +//x4: h +//x5: k +//x6: length +WinogradTransLeft: + push {r4-r11, lr} + ldr r4, [sp, #36] + ldr r5, [sp, #40] + ldr r6, [sp, #44] + + mov r8, #16 // 4 * sizeof(float) + mul r8, r6, r8 + mul r9, r3, r8 + sub r9, r9, r8 + add r7, r9, r8 // step for S + mov r10, #4 + mul r10, r4, r10 // step for B + +LoopH: + push {r0, r3} + LoopW: + push {r0, r1} + vmov.i32 q14, #0 + mov r11, r6 + InitZero: + vst1.32 {q14}, [r2]! + subs r11, r11, #1 + bne InitZero + + sub r2, r2, r8 + mov r12, r5 + + LoopKStart7: + cmp r12, #7 + blt LoopKStart4 + push {r3-r7} + LoopK7: + vld1.32 {d0[0]}, [r1], r10 + vld1.32 {d0[1]}, [r1], r10 + vld1.32 {d1[0]}, [r1], r10 + vld1.32 {d1[1]}, [r1], r10 + vld1.32 {d2[0]}, [r1], r10 + vld1.32 {d2[1]}, [r1], r10 + vld1.32 {d3[0]}, [r1], r10 + mov r11, r6 + vmov.32 d30[0], r1 + + add r1, r0, r7 + add r3, r1, r7 + add r4, r3, r7 + add r5, r4, r7 + add r6, r5, r7 + add r7, r6, r7 + + LoopLength7: + vld1.32 {q8}, [r2] + vld1.32 {q12}, [r0]! + vmla.f32 q8, q12, d0[0] + vld1.32 {q13}, [r1]! + vmul.f32 q9, q13, d0[1] + vld1.32 {q12}, [r3]! + vmla.f32 q8, q12, d1[0] + vld1.32 {q13}, [r4]! + vmla.f32 q9, q13, d1[1] + vld1.32 {q12}, [r5]! + vmla.f32 q8, q12, d2[0] + vld1.32 {q13}, [r6]! + vmla.f32 q9, q13, d2[1] + vld1.32 {q12}, [r7]! + vmla.f32 q8, q12, d3[0] + + vadd.f32 q9, q8, q9 + vst1.32 {q9}, [r2]! + subs r11, r11, #1 + bne LoopLength7 + + sub r2, r2, r8 + sub r12, r12, #7 + add r0, r7, r9 + vmov.32 r1, d30[0] + cmp r12, #7 + bge LoopK7 + + pop {r3-r7} + + LoopKStart4: + cmp r12, #4 + blt LoopKStart3 + vmov.32 d30[1], r3 + vmov.32 d31[0], r4 + LoopK4: + vld1.32 {d0[0]}, [r1], r10 + vld1.32 {d0[1]}, [r1], r10 + vld1.32 {d1[0]}, [r1], r10 + vld1.32 {d1[1]}, [r1], r10 + mov r11, r6 + vmov.32 d30[0], r1 + + add r1, r0, r7 + add r3, r1, r7 + add r4, r3, r7 + + LoopLength4: + vld1.32 {q8}, [r2] + vld1.32 {q12}, [r0]! + vmla.f32 q8, q12, d0[0] + vld1.32 {q13}, [r1]! + vmul.f32 q9, q13, d0[1] + vld1.32 {q12}, [r3]! + vmla.f32 q8, q12, d1[0] + vld1.32 {q13}, [r4]! + vmla.f32 q9, q13, d1[1] + + vadd.f32 q9, q8, q9 + vst1.32 {q9}, [r2]! + subs r11, r11, #1 + bne LoopLength4 + + sub r2, r2, r8 + sub r12, r12, #4 + add r0, r4, r9 + vmov.32 r1, d30[0] + cmp r12, #4 + bge LoopK4 + + vmov.32 r3, d30[1] + vmov.32 r4, d31[0] + + LoopKStart3: + cmp r12, #3 + blt LoopKStart + vmov.32 d30[1], r3 + vmov.32 d31[0], r4 + LoopK3: + vld1.32 {d0[0]}, [r1], r10 + vld1.32 {d0[1]}, [r1], r10 + vld1.32 {d1[0]}, [r1], r10 + mov r11, r6 + vmov.32 d30[0], r1 + + add r1, r0, r7 + add r3, r1, r7 + + LoopLength3: + vld1.32 {q8}, [r2] + vld1.32 {q12}, [r0]! + vmla.f32 q8, q12, d0[0] + vld1.32 {q13}, [r1]! + vmul.f32 q9, q13, d0[1] + vld1.32 {q12}, [r3]! + vmla.f32 q8, q12, d1[0] + + vadd.f32 q9, q8, q9 + vst1.32 {q9}, [r2]! + subs r11, r11, #1 + bne LoopLength3 + + sub r2, r2, r8 + sub r12, r12, #3 + add r0, r3, r9 + vmov.32 r1, d30[0] + cmp r12, #3 + bge LoopK3 + + vmov.32 r3, d30[1] + vmov.32 r4, d31[0] + + LoopKStart: + cmp r12, #0 + beq LoopKEnd + + LoopK: + vld1.32 {d30[0]}, [r1], r10 + + vdup.32 q15, d30[0] + mov r11, r6 + LoopLength: + vld1.32 {q0}, [r2] + vld1.32 {q1}, [r0]! + vmla.f32 q0, q1, q15 + + vst1.32 {q0}, [r2]! + subs r11, r11, #1 + bne LoopLength + subs r12, r12, #1 + + sub r2, r2, r8 + add r0, r0, r9 + bne LoopK + + LoopKEnd: + pop {r0, r1} + subs r3, r3, #1 + add r0, r0, r8 + add r2, r2, r8 + bne LoopW + + pop {r0, r3} + add r1, r1, #4 //sizeof(float) + subs r4, r4, #1 + bne LoopH + + pop {r4-r11, pc} + +#endif diff --git a/mindspore/lite/nnacl/assembly/arm32/WinogradTransRight.S b/mindspore/lite/nnacl/assembly/arm32/WinogradTransRight.S new file mode 100644 index 0000000000..4d1d172911 --- /dev/null +++ b/mindspore/lite/nnacl/assembly/arm32/WinogradTransRight.S @@ -0,0 +1,208 @@ +#ifdef ENABLE_ARM32 + + .text + .align 5 + .global WinogradTransRight +#ifndef __APPLE__ + .type WinogradTransRight, %function +#endif + +//void WinogradTransRight(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length); +//x0: S +//x1: B +//x2: M +//x3: w +//x4: h +//x5: k +//x6: length +WinogradTransRight: + push {r4-r11, lr} + ldr r4, [sp, #36] + ldr r5, [sp, #40] + ldr r6, [sp, #44] + + mov r8, #16 // 4 * sizeof(float) + mul r8, r6, r8 + mul r9, r5, r8 // step for S + mov r10, #4 + mul r10, r4, r10 // step for B + +LoopH: + push {r1, r3} + LoopW: + push {r0, r1} + vmov.i32 q14, #0 + mov r11, r6 + InitZero: + vst1.32 {q14}, [r2]! + subs r11, r11, #1 + bne InitZero + + sub r2, r2, r8 + mov r12, r5 + LoopKStart7: + cmp r12, #7 + blt LoopKStart4 + push {r3-r7} + LoopK7: + vld1.32 {d0[0]}, [r1], r10 + vld1.32 {d0[1]}, [r1], r10 + vld1.32 {d1[0]}, [r1], r10 + vld1.32 {d1[1]}, [r1], r10 + vld1.32 {d2[0]}, [r1], r10 + vld1.32 {d2[1]}, [r1], r10 + vld1.32 {d3[0]}, [r1], r10 + mov r11, r6 + vmov.32 d30[0], r1 + + add r1, r0, r8 + add r3, r1, r8 + add r4, r3, r8 + add r5, r4, r8 + add r6, r5, r8 + add r7, r6, r8 + LoopLength7: + vld1.32 {q8}, [r2] + vld1.32 {q12}, [r0]! + vmla.f32 q8, q12, d0[0] + vld1.32 {q13}, [r1]! + vmul.f32 q9, q13, d0[1] + vld1.32 {q12}, [r3]! + vmla.f32 q8, q12, d1[0] + vld1.32 {q13}, [r4]! + vmla.f32 q9, q13, d1[1] + vld1.32 {q12}, [r5]! + vmla.f32 q8, q12, d2[0] + vld1.32 {q13}, [r6]! + vmla.f32 q9, q13, d2[1] + vld1.32 {q12}, [r7]! + vmla.f32 q8, q12, d3[0] + + vadd.f32 q9, q8, q9 + vst1.32 {q9}, [r2]! + subs r11, r11, #1 + bne LoopLength7 + + sub r2, r2, r8 + sub r12, r12, #7 + mov r0, r7 + vmov.32 r1, d30[0] + cmp r12, #7 + bge LoopK7 + + pop {r3-r7} + + LoopKStart4: + cmp r12, #4 + blt LoopKStart3 + vmov.32 d30[1], r3 + vmov.32 d31[0], r4 + LoopK4: + vld1.32 {d0[0]}, [r1], r10 + vld1.32 {d0[1]}, [r1], r10 + vld1.32 {d1[0]}, [r1], r10 + vld1.32 {d1[1]}, [r1], r10 + mov r11, r6 + vmov.32 d30[0], r1 + + add r1, r0, r8 + add r3, r1, r8 + add r4, r3, r8 + + LoopLength4: + vld1.32 {q8}, [r2] + vld1.32 {q12}, [r0]! + vmla.f32 q8, q12, d0[0] + vld1.32 {q13}, [r1]! + vmul.f32 q9, q13, d0[1] + vld1.32 {q12}, [r3]! + vmla.f32 q8, q12, d1[0] + vld1.32 {q13}, [r4]! + vmla.f32 q9, q13, d1[1] + + vadd.f32 q9, q8, q9 + vst1.32 {q9}, [r2]! + subs r11, r11, #1 + bne LoopLength4 + + sub r2, r2, r8 + sub r12, r12, #4 + mov r0, r4 + vmov.32 r1, d30[0] + cmp r12, #4 + bge LoopK4 + + vmov.32 r3, d30[1] + vmov.32 r4, d31[0] + + LoopKStart3: + cmp r12, #3 + blt LoopKStart + vmov.32 d30[1], r3 + LoopK3: + vld1.32 {d0[0]}, [r1], r10 + vld1.32 {d0[1]}, [r1], r10 + vld1.32 {d1[0]}, [r1], r10 + mov r11, r6 + vmov.32 d30[0], r1 + + add r1, r0, r8 + add r3, r1, r8 + + LoopLength3: + vld1.32 {q8}, [r2] + vld1.32 {q12}, [r0]! + vmla.f32 q8, q12, d0[0] + vld1.32 {q13}, [r1]! + vmul.f32 q9, q13, d0[1] + vld1.32 {q12}, [r3]! + vmla.f32 q8, q12, d1[0] + + vadd.f32 q9, q8, q9 + vst1.32 {q9}, [r2]! + subs r11, r11, #1 + bne LoopLength3 + + sub r2, r2, r8 + sub r12, r12, #3 + mov r0, r3 + vmov.32 r1, d30[0] + cmp r12, #3 + bge LoopK3 + + vmov.32 r3, d30[1] + + LoopKStart: + cmp r12, #0 + beq LoopKEnd + LoopK: + vld1.32 {d30[0]}, [r1], r10 + vdup.32 q15, d30[0] + mov r11, r6 + LoopLength: + vld1.32 {q0}, [r2] + vld1.32 {q1}, [r0]! + vmla.f32 q0, q1, q15 + + vst1.32 {q0}, [r2]! + subs r11, r11, #1 + bne LoopLength + + subs r12, r12, #1 + sub r2, r2, r8 + bne LoopK + LoopKEnd: + pop {r0, r1} + subs r3, r3, #1 + add r2, r2, r8 + add r1, r1, #4 //sizeof(float) + bne LoopW + + pop {r1, r3} + add r0, r0, r9 + subs r4, r4, #1 + bne LoopH + + pop {r4-r11, pc} + +#endif diff --git a/mindspore/lite/nnacl/assembly/arm64/TiledC4MatmulFp32.S b/mindspore/lite/nnacl/assembly/arm64/TiledC4MatmulFp32.S new file mode 100644 index 0000000000..c964366975 --- /dev/null +++ b/mindspore/lite/nnacl/assembly/arm64/TiledC4MatmulFp32.S @@ -0,0 +1,267 @@ +#ifdef __aarch64__ + + .text + .align 5 + .global TiledC4MatmulFp32 +#ifndef __APPLE__ + .type TiledC4MatmulFp32, %function +#endif + +TiledC4MatmulFp32: +//void TiledC4MatmulFp32(float* dst, const float* src, const float* weight, size_t ic4, size_t cal_num, size_t oc4) +//x0: dst +//x1: src +//x2: weight +//x3: cal_num +//x4: ic4 +//x5: oc4 + +sub sp, sp, #128 +st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 +st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + +mov x7, #4 //sizeof(float) +mul x3, x3, x7 +mov x7, #64 +mul x10, x4, x7 + +cmp x5, #2 +blt LoopOcHalf +LoopOc: + mov x8, x1 + subs x9, x4, #1 + + add x6, x2, x10 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x8], #64 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 + fmul v16.4s, v8.4s, v0.s[0] + fmul v17.4s, v8.4s, v1.s[0] + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64 + fmul v18.4s, v8.4s, v2.s[0] + fmul v19.4s, v8.4s, v3.s[0] + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], #64 + fmul v20.4s, v8.4s, v4.s[0] + fmul v21.4s, v8.4s, v5.s[0] + fmul v22.4s, v8.4s, v6.s[0] + fmul v23.4s, v8.4s, v7.s[0] + fmul v24.4s, v12.4s, v0.s[0] + fmul v25.4s, v12.4s, v1.s[0] + fmul v26.4s, v12.4s, v2.s[0] + fmul v27.4s, v12.4s, v3.s[0] + fmul v28.4s, v12.4s, v4.s[0] + fmul v29.4s, v12.4s, v5.s[0] + fmul v30.4s, v12.4s, v6.s[0] + fmul v31.4s, v12.4s, v7.s[0] + + beq LoopIcEnd + LoopIc: + add x2, x2, #128 + prfm pldl1keep, [x2] + prfm pldl1keep, [x2, x10] + sub x2, x2, #128 + prfm pldl1keep, [x8, #128] + prfm pldl1keep, [x8, #192] + + fmla v16.4s, v9.4s, v0.s[1] + fmla v17.4s, v9.4s, v1.s[1] + fmla v18.4s, v9.4s, v2.s[1] + fmla v19.4s, v9.4s, v3.s[1] + fmla v20.4s, v9.4s, v4.s[1] + fmla v21.4s, v9.4s, v5.s[1] + fmla v22.4s, v9.4s, v6.s[1] + fmla v23.4s, v9.4s, v7.s[1] + fmla v24.4s, v13.4s, v0.s[1] + fmla v25.4s, v13.4s, v1.s[1] + fmla v26.4s, v13.4s, v2.s[1] + fmla v27.4s, v13.4s, v3.s[1] + fmla v28.4s, v13.4s, v4.s[1] + fmla v29.4s, v13.4s, v5.s[1] + fmla v30.4s, v13.4s, v6.s[1] + fmla v31.4s, v13.4s, v7.s[1] + + fmla v16.4s, v10.4s, v0.s[2] + fmla v17.4s, v10.4s, v1.s[2] + fmla v18.4s, v10.4s, v2.s[2] + fmla v19.4s, v10.4s, v3.s[2] + fmla v20.4s, v10.4s, v4.s[2] + fmla v21.4s, v10.4s, v5.s[2] + fmla v22.4s, v10.4s, v6.s[2] + fmla v23.4s, v10.4s, v7.s[2] + fmla v24.4s, v14.4s, v0.s[2] + fmla v25.4s, v14.4s, v1.s[2] + fmla v26.4s, v14.4s, v2.s[2] + fmla v27.4s, v14.4s, v3.s[2] + fmla v28.4s, v14.4s, v4.s[2] + fmla v29.4s, v14.4s, v5.s[2] + fmla v30.4s, v14.4s, v6.s[2] + fmla v31.4s, v14.4s, v7.s[2] + + fmla v16.4s, v11.4s, v0.s[3] + fmla v17.4s, v11.4s, v1.s[3] + fmla v18.4s, v11.4s, v2.s[3] + fmla v19.4s, v11.4s, v3.s[3] + fmla v20.4s, v11.4s, v4.s[3] + fmla v21.4s, v11.4s, v5.s[3] + fmla v22.4s, v11.4s, v6.s[3] + fmla v23.4s, v11.4s, v7.s[3] + fmla v24.4s, v15.4s, v0.s[3] + fmla v25.4s, v15.4s, v1.s[3] + fmla v26.4s, v15.4s, v2.s[3] + fmla v27.4s, v15.4s, v3.s[3] + fmla v28.4s, v15.4s, v4.s[3] + fmla v29.4s, v15.4s, v5.s[3] + fmla v30.4s, v15.4s, v6.s[3] + fmla v31.4s, v15.4s, v7.s[3] + + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x8], #64 + fmla v16.4s, v8.4s, v0.s[0] + fmla v17.4s, v8.4s, v1.s[0] + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64 + fmla v18.4s, v8.4s, v2.s[0] + fmla v19.4s, v8.4s, v3.s[0] + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x6], #64 + fmla v20.4s, v8.4s, v4.s[0] + fmla v21.4s, v8.4s, v5.s[0] + fmla v22.4s, v8.4s, v6.s[0] + fmla v23.4s, v8.4s, v7.s[0] + fmla v24.4s, v12.4s, v0.s[0] + fmla v25.4s, v12.4s, v1.s[0] + fmla v26.4s, v12.4s, v2.s[0] + fmla v27.4s, v12.4s, v3.s[0] + fmla v28.4s, v12.4s, v4.s[0] + fmla v29.4s, v12.4s, v5.s[0] + fmla v30.4s, v12.4s, v6.s[0] + fmla v31.4s, v12.4s, v7.s[0] + + subs x9, x9, #1 + bne LoopIc + + LoopIcEnd: + fmla v16.4s, v9.4s, v0.s[1] + fmla v17.4s, v9.4s, v1.s[1] + fmla v18.4s, v9.4s, v2.s[1] + fmla v19.4s, v9.4s, v3.s[1] + fmla v20.4s, v9.4s, v4.s[1] + fmla v21.4s, v9.4s, v5.s[1] + fmla v22.4s, v9.4s, v6.s[1] + fmla v23.4s, v9.4s, v7.s[1] + fmla v24.4s, v13.4s, v0.s[1] + fmla v25.4s, v13.4s, v1.s[1] + fmla v26.4s, v13.4s, v2.s[1] + fmla v27.4s, v13.4s, v3.s[1] + fmla v28.4s, v13.4s, v4.s[1] + fmla v29.4s, v13.4s, v5.s[1] + fmla v30.4s, v13.4s, v6.s[1] + fmla v31.4s, v13.4s, v7.s[1] + + fmla v16.4s, v10.4s, v0.s[2] + fmla v17.4s, v10.4s, v1.s[2] + fmla v18.4s, v10.4s, v2.s[2] + fmla v19.4s, v10.4s, v3.s[2] + fmla v20.4s, v10.4s, v4.s[2] + fmla v21.4s, v10.4s, v5.s[2] + fmla v22.4s, v10.4s, v6.s[2] + fmla v23.4s, v10.4s, v7.s[2] + fmla v24.4s, v14.4s, v0.s[2] + fmla v25.4s, v14.4s, v1.s[2] + fmla v26.4s, v14.4s, v2.s[2] + fmla v27.4s, v14.4s, v3.s[2] + fmla v28.4s, v14.4s, v4.s[2] + fmla v29.4s, v14.4s, v5.s[2] + fmla v30.4s, v14.4s, v6.s[2] + fmla v31.4s, v14.4s, v7.s[2] + + add x7, x0, #64 + + fmla v16.4s, v11.4s, v0.s[3] + fmla v17.4s, v11.4s, v1.s[3] + fmla v18.4s, v11.4s, v2.s[3] + fmla v19.4s, v11.4s, v3.s[3] + fmla v20.4s, v11.4s, v4.s[3] + fmla v21.4s, v11.4s, v5.s[3] + fmla v22.4s, v11.4s, v6.s[3] + fmla v23.4s, v11.4s, v7.s[3] + fmla v24.4s, v15.4s, v0.s[3] + fmla v25.4s, v15.4s, v1.s[3] + fmla v26.4s, v15.4s, v2.s[3] + fmla v27.4s, v15.4s, v3.s[3] + fmla v28.4s, v15.4s, v4.s[3] + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], x3 + fmla v29.4s, v15.4s, v5.s[3] + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x7], x3 + fmla v30.4s, v15.4s, v6.s[3] + st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], x3 + mov x2, x6 + fmla v31.4s, v15.4s, v7.s[3] + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x7] + + subs x5, x5, #2 + beq LoopOcEnd + cmp x5, #2 + bge LoopOc + +LoopOcHalf: + mov x8, x1 + mov x9, x4 + dup v16.4s, wzr + dup v17.4s, wzr + dup v18.4s, wzr + dup v19.4s, wzr + dup v20.4s, wzr + dup v21.4s, wzr + dup v22.4s, wzr + dup v23.4s, wzr + + LoopIcHalf: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x8], #64 + fmla v16.4s, v8.4s, v0.s[0] + fmla v17.4s, v8.4s, v1.s[0] + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x8], #64 + fmla v18.4s, v8.4s, v2.s[0] + fmla v19.4s, v8.4s, v3.s[0] + fmla v20.4s, v8.4s, v4.s[0] + fmla v21.4s, v8.4s, v5.s[0] + fmla v22.4s, v8.4s, v6.s[0] + fmla v23.4s, v8.4s, v7.s[0] + + fmla v16.4s, v9.4s, v0.s[1] + fmla v17.4s, v9.4s, v1.s[1] + fmla v18.4s, v9.4s, v2.s[1] + fmla v19.4s, v9.4s, v3.s[1] + fmla v20.4s, v9.4s, v4.s[1] + fmla v21.4s, v9.4s, v5.s[1] + fmla v22.4s, v9.4s, v6.s[1] + fmla v23.4s, v9.4s, v7.s[1] + + fmla v16.4s, v10.4s, v0.s[2] + fmla v17.4s, v10.4s, v1.s[2] + fmla v18.4s, v10.4s, v2.s[2] + fmla v19.4s, v10.4s, v3.s[2] + fmla v20.4s, v10.4s, v4.s[2] + fmla v21.4s, v10.4s, v5.s[2] + fmla v22.4s, v10.4s, v6.s[2] + fmla v23.4s, v10.4s, v7.s[2] + + fmla v16.4s, v11.4s, v0.s[3] + fmla v17.4s, v11.4s, v1.s[3] + fmla v18.4s, v11.4s, v2.s[3] + fmla v19.4s, v11.4s, v3.s[3] + fmla v20.4s, v11.4s, v4.s[3] + fmla v21.4s, v11.4s, v5.s[3] + fmla v22.4s, v11.4s, v6.s[3] + fmla v23.4s, v11.4s, v7.s[3] + + subs x9, x9, #1 + bne LoopIcHalf + + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 + +LoopOcEnd: + sub sp, sp, #128 + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64 + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64 + ret +#endif diff --git a/mindspore/lite/nnacl/assembly/arm64/WinogradTransLeft.S b/mindspore/lite/nnacl/assembly/arm64/WinogradTransLeft.S new file mode 100644 index 0000000000..ec3a30e7c1 --- /dev/null +++ b/mindspore/lite/nnacl/assembly/arm64/WinogradTransLeft.S @@ -0,0 +1,147 @@ +#ifdef __aarch64__ + + .text + .align 5 + .global WinogradTransLeft +#ifndef __APPLE__ + .type WinogradTransLeft, %function +#endif + +WinogradTransLeft: +//void WinogradTransLeft(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length); +//x0: S +//x1: B +//x2: M +//x3: w +//x4: h +//x5: k +//x6:length + +sub sp, sp, #32 +stp x19, x20, [sp], #32 + +mov x8, #16 // 4 * sizeof(float) +mul x8, x6, x8 +mul x9, x3, x8 +sub x9, x9, x8 +add x7, x9, x8 // step for S +mov x10, #4 +mul x10, x4, x10 // step for B + +LoopH: + mov x13, x0 + mov x15, x3 + LoopW: + mov x14, x13 + mov x17, x1 + dup v30.4s, wzr + mov x11, x6 + InitZero: + st1 {v30.4s}, [x2], #16 + subs x11, x11, #1 + bne InitZero + + sub x2, x2, x8 + mov x12, x5 + LoopKStart4: + cmp x12, #4 + blt LoopKStart3 + mov x16, x15 + mov x19, x4 + LoopK4: + ld1 {v0.s}[0], [x17], x10 + ld1 {v0.s}[1], [x17], x10 + ld1 {v0.s}[2], [x17], x10 + ld1 {v0.s}[3], [x17], x10 + mov x11, x6 + mov x18, x17 + add x18, x14, x7 + add x16, x18, x7 + add x19, x16, x7 + + LoopLength4: + ld1 {v16.4s}, [x2] + ld1 {v20.4s}, [x14], #16 + fmla v16.4s, v20.4s, v0.s[0] + ld1 {v21.4s}, [x18], #16 + fmul v17.4s, v21.4s, v0.s[1] + ld1 {v20.4s}, [x16], #16 + fmla v16.4s, v20.4s, v0.s[2] + ld1 {v21.4s}, [x19], #16 + fmla v17.4s, v21.4s, v0.s[3] + fadd v17.4s, v16.4s, v17.4s + st1 {v17.4s}, [x2], #16 + subs x11, x11, #1 + bne LoopLength4 + + sub x2, x2, x8 + sub x12, x12, #4 + add x14, x19, x9 + cmp x12, #4 + bge LoopK4 + + LoopKStart3: + cmp x12, #3 + blt LoopKStart + mov x16, x15 + LoopK3: + ld1 {v0.s}[0], [x17], x10 + ld1 {v0.s}[1], [x17], x10 + ld1 {v0.s}[2], [x17], x10 + mov x11, x6 + mov x18, x17 + add x18, x14, x7 + add x16, x18, x7 + LoopLength3: + ld1 {v16.4s}, [x2] + ld1 {v20.4s}, [x14], #16 + fmla v16.4s, v20.4s, v0.s[0] + ld1 {v21.4s}, [x18], #16 + fmul v17.4s, v21.4s, v0.s[1] + ld1 {v20.4s}, [x16], #16 + fmla v16.4s, v20.4s, v0.s[2] + fadd v17.4s, v16.4s, v17.4s + st1 {v17.4s}, [x2], #16 + subs x11, x11, #1 + bne LoopLength3 + + sub x2, x2, x8 + sub x12, x12, #3 + add x14, x16, x9 + cmp x12, #3 + bge LoopK3 + + LoopKStart: + cmp x12, #0 + beq LKEnd + LoopK: + ld1r {v31.4s}, [x17], x10 + mov x11, x6 + LoopLength: + ld1 {v0.4s}, [x2] + ld1 {v1.4s}, [x14], #16 + fmla v0.4s, v1.4s, v31.4s + st1 {v0.4s}, [x2], #16 + subs x11, x11, #1 + bne LoopLength + + subs x12, x12, #1 + sub x2, x2, x8 + add x14, x14, x9 + bne LoopK + + LKEnd: + subs x15, x15, #1 + add x13, x13, x8 + add x2, x2, x8 + bne LoopW + + add x1, x1, #4 //sizeof(float) + subs x4, x4, #1 + bne LoopH + + sub sp, sp, #32 + ldp x19, x20, [sp], #32 + ret + +#endif diff --git a/mindspore/lite/nnacl/assembly/arm64/WinogradTransRight.S b/mindspore/lite/nnacl/assembly/arm64/WinogradTransRight.S new file mode 100644 index 0000000000..ff65ef0122 --- /dev/null +++ b/mindspore/lite/nnacl/assembly/arm64/WinogradTransRight.S @@ -0,0 +1,144 @@ +#ifdef __aarch64__ + + .text + .align 5 + .global WinogradTransRight +#ifndef __APPLE__ + .type WinogradTransRight, %function +#endif + +WinogradTransRight: +//void WinogradTransRight(const float* S, const float* B, float* M, size_t w, size_t h, size_t k, size_t length); +//x0: S +//x1: B +//x2: M +//x3: w +//x4: h +//x5: k +//x6: length + +mov x8, #16 // 4 * sizeof(float) +mul x8, x6, x8 +mul x9, x5, x8 // step for S +mov x10, #4 +mul x10, x4, x10 // step for B + +LoopH: + mov x7, x1 + mov x15, x3 + LoopW: + mov x17, x0 + mov x13, x7 + dup v30.4s, wzr + mov x11, x6 + InitZero: + st1 {v30.4s}, [x2], #16 + subs x11, x11, #1 + bne InitZero + sub x2, x2, x8 + mov x12, x5 + + LoopKStart4: + cmp x12, #4 + blt LoopKStart3 + mov x16, x15 + mov x18, x4 + LoopK4: + ld1 {v0.s}[0], [x13], x10 + ld1 {v0.s}[1], [x13], x10 + ld1 {v0.s}[2], [x13], x10 + ld1 {v0.s}[3], [x13], x10 + mov x11, x6 + mov x14, x13 + + add x14, x17, x8 + add x16, x14, x8 + add x18, x16, x8 + + LoopLength4: + ld1 {v16.4s}, [x2] + ld1 {v20.4s}, [x17], #16 + fmla v16.4s, v20.4s, v0.s[0] + ld1 {v21.4s}, [x14], #16 + fmul v17.4s, v21.4s, v0.s[1] + ld1 {v20.4s}, [x16], #16 + fmla v16.4s, v20.4s, v0.s[2] + ld1 {v21.4s}, [x18], #16 + fmla v17.4s, v21.4s, v0.s[3] + + fadd v17.4s, v16.4s, v17.4s + st1 {v17.4s}, [x2], #16 + subs x11, x11, #1 + bne LoopLength4 + sub x2, x2, x8 + sub x12, x12, #4 + mov x17, x18 + + cmp x12, #4 + bge LoopK4 + + LoopKStart3: + cmp x12, #3 + blt LoopKStart + mov x16, x15 + LoopK3: + ld1 {v0.s}[0], [x13], x10 + ld1 {v0.s}[1], [x13], x10 + ld1 {v0.s}[2], [x13], x10 + mov x11, x6 + mov x14, x13 + + add x14, x17, x8 + add x16, x14, x8 + + LoopLength3: + ld1 {v16.4s}, [x2] + ld1 {v20.4s}, [x17], #16 + fmla v16.4s, v20.4s, v0.s[0] + ld1 {v21.4s}, [x14], #16 + fmul v17.4s, v21.4s, v0.s[1] + ld1 {v20.4s}, [x16], #16 + fmla v16.4s, v20.4s, v0.s[2] + + fadd v17.4s, v16.4s, v17.4s + st1 {v17.4s}, [x2], #16 + subs x11, x11, #1 + bne LoopLength3 + sub x2, x2, x8 + sub x12, x12, #3 + mov x17, x18 + cmp x12, #3 + bge LoopK3 + + LoopKStart: + cmp x12, #0 + beq LoopKEnd + + LoopK: + ld1r {v31.4s}, [x13], x10 + + mov x11, x6 + LoopLength: + ld1 {v0.4s}, [x2] + ld1 {v1.4s}, [x17], #16 + fmla v0.4s, v1.4s, v31.4s + + st1 {v0.4s}, [x2], #16 + subs x11, x11, #1 + bne LoopLength + subs x12, x12, #1 + + sub x2, x2, x8 + bne LoopK + LoopKEnd: + subs x15, x15, #1 + add x2, x2, x8 + add x7, x7, #4 //sizeof(float) + bne LoopW + + add x0, x0, x9 + subs x4, x4, #1 + bne LoopH + + ret +#endif diff --git a/mindspore/lite/nnacl/fp32/common_func.c b/mindspore/lite/nnacl/fp32/common_func.c index c320a1da1b..b429d7ab6a 100644 --- a/mindspore/lite/nnacl/fp32/common_func.c +++ b/mindspore/lite/nnacl/fp32/common_func.c @@ -68,7 +68,8 @@ void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bi return; } -void WinogradMatrixProductLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { +#ifndef ENABLE_ARM +void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { int unitStep = 4 * length; for (int y = 0; y < h; ++y) { float *dstY = M + y * w * unitStep; @@ -91,7 +92,7 @@ void WinogradMatrixProductLeft(const float *S, const float *B, float *M, size_t } // M = S * B , M = w*h * l, S = k*h * l, B = w*k -void WinogradMatrixProductRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { +void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) { int unitStep = 4 * length; for (int y = 0; y < h; ++y) { float *dstY = M + y * w * unitStep; @@ -113,6 +114,7 @@ void WinogradMatrixProductRight(const float *S, const float *B, float *M, size_t } } } +#endif union float32_bits { unsigned int u; diff --git a/mindspore/lite/nnacl/fp32/common_func.h b/mindspore/lite/nnacl/fp32/common_func.h index 55759c2958..0f2bac7e9f 100644 --- a/mindspore/lite/nnacl/fp32/common_func.h +++ b/mindspore/lite/nnacl/fp32/common_func.h @@ -32,8 +32,8 @@ void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bi void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel, size_t plane_size, size_t plane_stride, size_t relu_type); -void WinogradMatrixProductLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length); -void WinogradMatrixProductRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length); +void WinogradTransLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length); +void WinogradTransRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length); float ShortToFloat32(uint16_t src_value); diff --git a/mindspore/lite/nnacl/fp32/deconv_winograd.c b/mindspore/lite/nnacl/fp32/deconv_winograd.c index 5eccd59ed6..7228c84215 100644 --- a/mindspore/lite/nnacl/fp32/deconv_winograd.c +++ b/mindspore/lite/nnacl/fp32/deconv_winograd.c @@ -130,21 +130,21 @@ void DeConvWgInputPack(float *src_ptr, float *dst_ptr, int channel, int stride) return; } -void MSGemmFloatCommon_4(float *dst, const float *src, const float *weight, size_t src_depth_quad, size_t dst_step, - size_t dst_depth_quad, size_t width, size_t weight_depth_offset) { +#ifndef ENABLE_ARM +void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t cal_num, size_t ic4, size_t oc4) { int dx, sz, dz; - int src_depth_step = 4 * width; - for (dz = 0; dz < dst_depth_quad; ++dz) { - float *dst_z = dst + dz * dst_step; - const float *weight_dz = weight + dz * (src_depth_quad * 16 + weight_depth_offset); - for (dx = 0; dx < width; ++dx) { + int src_depth_step = 4 * DECONV_WINOGRAD_DEFAULT_TILE; + for (dz = 0; dz < oc4; ++dz) { + float *dst_z = dst + dz * cal_num; + const float *weight_dz = weight + dz * ic4 * 16; + for (dx = 0; dx < DECONV_WINOGRAD_DEFAULT_TILE; ++dx) { float *dst_x = dst_z + dx * 4; dst_x[0] = 0.0f; dst_x[1] = 0.0f; dst_x[2] = 0.0f; dst_x[3] = 0.0f; const float *src_dx = src + 4 * dx; - for (sz = 0; sz < src_depth_quad; ++sz) { + for (sz = 0; sz < ic4; ++sz) { const float *src_z = src_dx + sz * src_depth_step; const float *weight_z = weight_dz + sz * 16; for (int i = 0; i < 4; ++i) { @@ -156,12 +156,7 @@ void MSGemmFloatCommon_4(float *dst, const float *src, const float *weight, size } } } - -void MSGemmFloatUnit_4(float *dstOrigin, const float *src, const float *weight, size_t src_depth_quad, size_t dst_step, - size_t dst_depth_quad, size_t weight_depth_offset) { - MSGemmFloatCommon_4(dstOrigin, src, weight, src_depth_quad, dst_step, dst_depth_quad, DECONV_WINOGRAD_DEFAULT_TILE, - weight_depth_offset); -} +#endif void DeConvWgMerge(const float *src, float *dst, size_t src_stride, size_t dst_stride, size_t count) { for (int i = 0; i < count; ++i) { @@ -179,10 +174,10 @@ void _deConvWinograd(float *tile_in, float *tile_out, float *weight_buf, float * int unit_size, int w_start, int h_start, ConvParameter *conv_param, DeConvParam *deconv_param) { int winograd_plane = unit_size * unit_size; if (!transfered[unit_size]) { - WinogradMatrixProductLeft(tile_in, at_buf, a_mid_buf, DECONV_WINOGRAD_DEFAULT_UNIT, unit_size, - DECONV_WINOGRAD_DEFAULT_UNIT, deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); - WinogradMatrixProductRight(a_mid_buf, at_buf, trans_a_buf, unit_size, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT, - deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); + WinogradTransLeft(tile_in, at_buf, a_mid_buf, DECONV_WINOGRAD_DEFAULT_UNIT, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT, + deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); + WinogradTransRight(a_mid_buf, at_buf, trans_a_buf, unit_size, unit_size, DECONV_WINOGRAD_DEFAULT_UNIT, + deconv_param->ic_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); transfered[unit_size] = true; } @@ -190,14 +185,14 @@ void _deConvWinograd(float *tile_in, float *tile_out, float *weight_buf, float * float *src = trans_a_buf + index * DECONV_WINOGRAD_DEFAULT_TILE * deconv_param->ic_up4_; float *dst = tmp_buf + index * deconv_param->oc_up4_ * DECONV_WINOGRAD_DEFAULT_TILE; float *weight = weight_buf + index * deconv_param->ic_up4_ * deconv_param->oc_up4_; - MSGemmFloatUnit_4(dst, src, weight, deconv_param->ic_div4_, DECONV_WINOGRAD_DEFAULT_TILE * C4NUM, - deconv_param->oc_div4_, 0); + TiledC4MatmulFp32(dst, src, weight, DECONV_WINOGRAD_DEFAULT_TILE * C4NUM, deconv_param->ic_div4_, + deconv_param->oc_div4_); } - WinogradMatrixProductLeft(tmp_buf, bt_buf, b_tmp_buf, unit_size, unit_size, unit_size, - deconv_param->oc_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); - WinogradMatrixProductRight(b_tmp_buf, bt_buf, tmp_buf, unit_size, unit_size, unit_size, - deconv_param->oc_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); + WinogradTransLeft(tmp_buf, bt_buf, b_tmp_buf, unit_size, unit_size, unit_size, + deconv_param->oc_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); + WinogradTransRight(b_tmp_buf, bt_buf, tmp_buf, unit_size, unit_size, unit_size, + deconv_param->oc_div4_ * DECONV_WINOGRAD_DEFAULT_TILE); // Add to dest for (int uhi = 0; uhi < unit_size; uhi++) { @@ -223,7 +218,7 @@ void _deConvCommon(float *tile_in, float *tile_out, float *weight, float *tmp_bu for (int hi = 0; hi < DECONV_WINOGRAD_DEFAULT_UNIT; hi++) { for (int wi = 0; wi < DECONV_WINOGRAD_DEFAULT_UNIT; wi++) { float *src_in = tile_in + (wi + hi * DECONV_WINOGRAD_DEFAULT_UNIT) * in_stride; - MSGemmFloatUnit_4(tmp_buf, src_in, weight, deconv_param->ic_div4_, DECONV_WINOGRAD_DEFAULT_TILE * 4, count, 0); + TiledC4MatmulFp32(tmp_buf, src_in, weight, DECONV_WINOGRAD_DEFAULT_TILE * 4, deconv_param->ic_div4_, count); for (int uhi = 0; uhi < h_size; uhi++) { for (int uwi = 0; uwi < w_size; uwi++) { diff --git a/mindspore/lite/nnacl/fp32/deconv_winograd.h b/mindspore/lite/nnacl/fp32/deconv_winograd.h index 576b772d7d..47c1993f02 100644 --- a/mindspore/lite/nnacl/fp32/deconv_winograd.h +++ b/mindspore/lite/nnacl/fp32/deconv_winograd.h @@ -34,6 +34,7 @@ void DeconvWg(float *nhwc_input_, float *tile_in, float *tile_out, int start_ind ConvParameter *conv_param, DeConvParam *deconv_param, int task_id); void DeconvWgPost(float *tile_out, float *nc4hw4_output, ConvParameter *conv_param, DeConvParam *deconv_param, int calculate_count, int tile_index); +void TiledC4MatmulFp32(float *dst, const float *src, const float *weight, size_t ic4, size_t cal_num, size_t oc4); #ifdef __cplusplus } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc index 584ab0da8a..07a431e2c5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution.cc @@ -254,7 +254,7 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector if ((conv_param->stride_h_ != 1 || conv_param->stride_w_ != 1) && (conv_param->dilation_w_ == 1 && conv_param->dilation_h_ == 1)) { /* DeConvolutionWinogradCPUKernel */ - kernel = new (std::nothrow) kernel::DeConvolutionCPUKernel(opParameter, inputs, outputs, ctx, primitive); + kernel = new (std::nothrow) kernel::DeConvolutionWinogradCPUKernel(opParameter, inputs, outputs, ctx, primitive); } else { kernel = new (std::nothrow) kernel::DeConvolutionCPUKernel(opParameter, inputs, outputs, ctx, primitive); } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_winograd.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_winograd.cc index 63ded03375..71b8098cad 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_winograd.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/deconvolution_winograd.cc @@ -258,10 +258,10 @@ int DeConvolutionWinogradCPUKernel::InitDataParam() { } /* bias */ - auto bias_tensor = in_tensors_.at(kBiasIndex); bias_data_ = malloc(deconv_param_->oc_up4_ * sizeof(float)); memset(bias_data_, 0, deconv_param_->oc_up4_ * sizeof(float)); if (in_tensors_.size() == 3) { + auto bias_tensor = in_tensors_.at(kBiasIndex); memcpy(bias_data_, bias_tensor->data_c(), conv_param_->output_channel_ * sizeof(float)); } return RET_OK;