|
|
|
@ -16,12 +16,11 @@
|
|
|
|
|
|
|
|
|
|
#include "nnacl/fp16/matmul_fp16.h"
|
|
|
|
|
|
|
|
|
|
void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col, bool src_float16) {
|
|
|
|
|
static void Col2Row8SrcFromFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
|
|
|
|
|
int row_c8 = row / C8NUM * C8NUM;
|
|
|
|
|
int col_c8 = col / C8NUM * C8NUM;
|
|
|
|
|
int ci = 0;
|
|
|
|
|
if (src_float16) {
|
|
|
|
|
const float16_t *src = (const float16_t *)src_ptr;
|
|
|
|
|
int ci = 0;
|
|
|
|
|
for (; ci < col_c8; ci += C8NUM) {
|
|
|
|
|
int ri = 0;
|
|
|
|
|
for (; ri < row_c8; ri += C8NUM) {
|
|
|
|
@ -107,7 +106,12 @@ void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row,
|
|
|
|
|
dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[tc * row + r];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static void Col2Row8SrcFromFp32(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
|
|
|
|
|
int row_c8 = row / C8NUM * C8NUM;
|
|
|
|
|
int col_c8 = col / C8NUM * C8NUM;
|
|
|
|
|
int ci = 0;
|
|
|
|
|
const float *src = (const float *)src_ptr;
|
|
|
|
|
for (; ci < col_c8; ci += C8NUM) {
|
|
|
|
|
int ri = 0;
|
|
|
|
@ -211,6 +215,13 @@ void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row,
|
|
|
|
|
dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = (float16_t)(src[tc * row + r]);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col, bool src_float16) {
|
|
|
|
|
if (src_float16) {
|
|
|
|
|
Col2Row8SrcFromFp16(src_ptr, dst_ptr, row, col);
|
|
|
|
|
} else {
|
|
|
|
|
Col2Row8SrcFromFp32(src_ptr, dst_ptr, row, col);
|
|
|
|
|
}
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
@ -274,21 +285,7 @@ void MatVecMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const f
|
|
|
|
|
MatVecMulFp16Neon64(a, b, c, bias, (int)act_type, depth, col);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
|
|
|
|
|
size_t row_up_16 = UP_ROUND(row, C16NUM);
|
|
|
|
|
size_t row16 = row / C16NUM * C16NUM;
|
|
|
|
|
size_t col8 = col / C8NUM * C8NUM;
|
|
|
|
|
const float16_t *src_r = src_ptr;
|
|
|
|
|
float16_t *dst_r = dst_ptr;
|
|
|
|
|
|
|
|
|
|
size_t ri = 0;
|
|
|
|
|
for (; ri < row16; ri += C16NUM) {
|
|
|
|
|
size_t ci = 0;
|
|
|
|
|
for (; ci < col8; ci += C8NUM) {
|
|
|
|
|
const float16_t *src_c = src_r + ci;
|
|
|
|
|
float16_t *dst_c = dst_r + ci * C16NUM;
|
|
|
|
|
|
|
|
|
|
#ifdef ENABLE_ARM64
|
|
|
|
|
static void Row2Col16Block16(const float16_t *src_ptr, float16_t *dst_ptr, size_t col) {
|
|
|
|
|
size_t stride = col * 2;
|
|
|
|
|
asm volatile(
|
|
|
|
|
"mov x10, %[src_c]\n"
|
|
|
|
@ -390,10 +387,27 @@ void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, si
|
|
|
|
|
"st1 {v27.8h}, [x11], #16\n"
|
|
|
|
|
"st1 {v31.8h}, [x11], #16\n"
|
|
|
|
|
:
|
|
|
|
|
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
|
|
|
|
|
: [ dst_c ] "r"(dst_ptr), [ src_c ] "r"(src_ptr), [ stride ] "r"(stride)
|
|
|
|
|
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
|
|
|
|
|
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
|
|
|
|
|
"v30", "v31");
|
|
|
|
|
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
|
|
|
|
|
"v31");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
|
|
|
|
|
size_t row_up_16 = UP_ROUND(row, C16NUM);
|
|
|
|
|
size_t row16 = row / C16NUM * C16NUM;
|
|
|
|
|
size_t col8 = col / C8NUM * C8NUM;
|
|
|
|
|
const float16_t *src_r = src_ptr;
|
|
|
|
|
float16_t *dst_r = dst_ptr;
|
|
|
|
|
size_t ri = 0;
|
|
|
|
|
// find 16 block unit
|
|
|
|
|
for (; ri < row16; ri += C16NUM) {
|
|
|
|
|
size_t ci = 0;
|
|
|
|
|
for (; ci < col8; ci += C8NUM) {
|
|
|
|
|
const float16_t *src_c = src_r + ci;
|
|
|
|
|
float16_t *dst_c = dst_r + ci * C16NUM;
|
|
|
|
|
#ifdef ENABLE_ARM64
|
|
|
|
|
Row2Col16Block16(src_c, dst_c, col);
|
|
|
|
|
#else
|
|
|
|
|
for (int tr = 0; tr < C16NUM; tr++) {
|
|
|
|
|
for (int tc = 0; tc < C8NUM; tc++) {
|
|
|
|
@ -413,7 +427,7 @@ void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, si
|
|
|
|
|
dst_r += C16NUM * col;
|
|
|
|
|
}
|
|
|
|
|
for (; ri < row; ri++) {
|
|
|
|
|
for (size_t i = 0; i < col; i++) {
|
|
|
|
|
for (size_t i = 0; i < col; ++i) {
|
|
|
|
|
dst_r[i * C16NUM] = src_r[i];
|
|
|
|
|
}
|
|
|
|
|
src_r += col;
|
|
|
|
|