|
|
|
@ -16,6 +16,18 @@
|
|
|
|
|
|
|
|
|
|
#include "nnacl/fp32/matmul.h"
|
|
|
|
|
|
|
|
|
|
void RowMajor2Row4Major(float *src_ptr, float *dst_ptr, int row, int col) {
|
|
|
|
|
for (int r = 0; r < row; r++) {
|
|
|
|
|
float *src = src_ptr + r * col;
|
|
|
|
|
for (int c = 0; c < col; c++) {
|
|
|
|
|
int cd8 = c / 4;
|
|
|
|
|
int cm8 = c % 4;
|
|
|
|
|
dst_ptr[cd8 * 4 * row + r * 4 + cm8] = src[c];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col) {
|
|
|
|
|
for (int r = 0; r < row; r++) {
|
|
|
|
|
float *src = src_ptr + r * col;
|
|
|
|
@ -115,6 +127,61 @@ void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col)
|
|
|
|
|
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
|
|
|
|
|
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
|
|
|
|
|
"v30", "v31");
|
|
|
|
|
#elif ENABLE_ARM32
|
|
|
|
|
size_t stride = col * sizeof(float);
|
|
|
|
|
asm volatile(
|
|
|
|
|
"mov r10, %[src_c]\n"
|
|
|
|
|
"mov r12, %[dst_c]\n"
|
|
|
|
|
|
|
|
|
|
"vld1.32 {q0}, [r10], %[stride]\n"
|
|
|
|
|
"vld1.32 {q3}, [r10], %[stride]\n"
|
|
|
|
|
"vld1.32 {q10}, [r10], %[stride]\n"
|
|
|
|
|
"vld1.32 {q13}, [r10], %[stride]\n"
|
|
|
|
|
|
|
|
|
|
"vtrn.32 d0, d6\n"
|
|
|
|
|
"vtrn.32 d1, d7\n"
|
|
|
|
|
"vtrn.32 d20, d26\n"
|
|
|
|
|
"vtrn.32 d21, d27\n"
|
|
|
|
|
|
|
|
|
|
"vld1.32 {q1}, [r10], %[stride]\n"
|
|
|
|
|
"vld1.32 {q8}, [r10], %[stride]\n"
|
|
|
|
|
"vld1.32 {q11}, [r10], %[stride]\n"
|
|
|
|
|
"vld1.32 {q14}, [r10], %[stride]\n"
|
|
|
|
|
|
|
|
|
|
"vswp d1, d20\n"
|
|
|
|
|
"vswp d7, d26\n"
|
|
|
|
|
|
|
|
|
|
"vld1.32 {q2}, [r10], %[stride]\n"
|
|
|
|
|
"vld1.32 {q9}, [r10], %[stride]\n"
|
|
|
|
|
"vld1.32 {q12}, [r10], %[stride]\n"
|
|
|
|
|
"vld1.32 {q15}, [r10], %[stride]\n"
|
|
|
|
|
|
|
|
|
|
"vtrn.32 d2, d16\n"
|
|
|
|
|
"vtrn.32 d3, d17\n"
|
|
|
|
|
"vtrn.32 d22, d28\n"
|
|
|
|
|
"vtrn.32 d23, d29\n"
|
|
|
|
|
|
|
|
|
|
"vswp d3, d22\n"
|
|
|
|
|
"vswp d17, d28\n"
|
|
|
|
|
|
|
|
|
|
"vtrn.32 d4, d18\n"
|
|
|
|
|
"vtrn.32 d5, d19\n"
|
|
|
|
|
"vtrn.32 d24, d30\n"
|
|
|
|
|
"vtrn.32 d25, d31\n"
|
|
|
|
|
|
|
|
|
|
"vswp d5, d24\n"
|
|
|
|
|
"vswp d19, d30\n"
|
|
|
|
|
|
|
|
|
|
"vst1.32 {q0, q1}, [r12]!\n"
|
|
|
|
|
"vst1.32 {q2, q3}, [r12]!\n"
|
|
|
|
|
"vst1.32 {q8, q9}, [r12]!\n"
|
|
|
|
|
"vst1.32 {q10, q11}, [r12]!\n"
|
|
|
|
|
"vst1.32 {q12, q13}, [r12]!\n"
|
|
|
|
|
"vst1.32 {q14, q15}, [r12]!\n"
|
|
|
|
|
|
|
|
|
|
:
|
|
|
|
|
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
|
|
|
|
|
: "r10", "r12", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
|
|
|
|
|
#else
|
|
|
|
|
for (int tr = 0; tr < C12NUM; tr++) {
|
|
|
|
|
for (int tc = 0; tc < C4NUM; tc++) {
|
|
|
|
@ -242,6 +309,75 @@ void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col)
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RowMajor2Col4Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) {
|
|
|
|
|
size_t row8 = row / C4NUM * C4NUM;
|
|
|
|
|
size_t col4 = col / C4NUM * C4NUM;
|
|
|
|
|
float *src_r = src_ptr;
|
|
|
|
|
float *dst_r = dst_ptr;
|
|
|
|
|
|
|
|
|
|
size_t ri = 0;
|
|
|
|
|
for (; ri < row8; ri += C4NUM) {
|
|
|
|
|
size_t ci = 0;
|
|
|
|
|
for (; ci < col4; ci += C4NUM) {
|
|
|
|
|
float *src_c = src_r + ci;
|
|
|
|
|
float *dst_c = dst_r + ci * C4NUM;
|
|
|
|
|
|
|
|
|
|
/* 4x4 row-major to col-major */
|
|
|
|
|
#ifdef ENABLE_ARM32
|
|
|
|
|
size_t stride = col * 4;
|
|
|
|
|
asm volatile(
|
|
|
|
|
"mov r10, %[src_c]\n"
|
|
|
|
|
"mov r12, %[dst_c]\n"
|
|
|
|
|
|
|
|
|
|
"vld1.32 {q0}, [r10], %[stride]\n"
|
|
|
|
|
"vld1.32 {q1}, [r10], %[stride]\n"
|
|
|
|
|
"vld1.32 {q2}, [r10], %[stride]\n"
|
|
|
|
|
"vld1.32 {q3}, [r10], %[stride]\n"
|
|
|
|
|
|
|
|
|
|
"vtrn.32 d0, d2\n"
|
|
|
|
|
"vtrn.32 d1, d3\n"
|
|
|
|
|
"vtrn.32 d4, d6\n"
|
|
|
|
|
"vtrn.32 d5, d7\n"
|
|
|
|
|
|
|
|
|
|
"vswp d1, d4\n"
|
|
|
|
|
"vswp d3, d6\n"
|
|
|
|
|
|
|
|
|
|
"vst1.32 {q0}, [r12]!\n"
|
|
|
|
|
"vst1.32 {q1}, [r12]!\n"
|
|
|
|
|
"vst1.32 {q2}, [r12]!\n"
|
|
|
|
|
"vst1.32 {q3}, [r12]!\n"
|
|
|
|
|
|
|
|
|
|
:
|
|
|
|
|
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
|
|
|
|
|
: "r10", "r12", "q0", "q1", "q2", "q3");
|
|
|
|
|
#else
|
|
|
|
|
for (int tr = 0; tr < C4NUM; tr++) {
|
|
|
|
|
for (int tc = 0; tc < C4NUM; tc++) {
|
|
|
|
|
dst_c[tc * C4NUM + tr] = src_c[tr * col + tc];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
for (; ci < col; ci++) {
|
|
|
|
|
float *src_c = src_r + ci;
|
|
|
|
|
float *dst_c = dst_r + ci * C4NUM;
|
|
|
|
|
for (size_t i = 0; i < C4NUM; i++) {
|
|
|
|
|
dst_c[i] = src_c[i * col];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
src_r += C4NUM * col;
|
|
|
|
|
dst_r += C4NUM * col;
|
|
|
|
|
}
|
|
|
|
|
for (; ri < row; ri++) {
|
|
|
|
|
for (size_t i = 0; i < col; i++) {
|
|
|
|
|
dst_r[i * C4NUM] = src_r[i];
|
|
|
|
|
}
|
|
|
|
|
src_r += col;
|
|
|
|
|
dst_r += 1;
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void MatrixUnPackUnit(const void *src, void *dst, size_t row, size_t col, size_t src_stride, size_t dst_stride,
|
|
|
|
|
size_t data_lenth) {
|
|
|
|
|
size_t copy_size = col * data_lenth;
|
|
|
|
@ -418,6 +554,9 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT
|
|
|
|
|
MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc),
|
|
|
|
|
(int)(out_type == OutType_TileC8));
|
|
|
|
|
}
|
|
|
|
|
#elif ENABLE_ARM32
|
|
|
|
|
MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc),
|
|
|
|
|
(int)(out_type == OutType_TileC8));
|
|
|
|
|
#else
|
|
|
|
|
MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
|
|
|
|
|
#endif
|
|
|
|
|