|
|
|
@ -16,16 +16,16 @@
|
|
|
|
|
|
|
|
|
|
#include "nnacl/fp16/matmul_fp16.h"
|
|
|
|
|
|
|
|
|
|
void ColMajor2Row8MajorFp16(void *src_ptr, float16_t *dst_ptr, size_t row, size_t col, bool src_float16) {
|
|
|
|
|
void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col, bool src_float16) {
|
|
|
|
|
int row_c8 = row / C8NUM * C8NUM;
|
|
|
|
|
int col_c8 = col / C8NUM * C8NUM;
|
|
|
|
|
int ci = 0;
|
|
|
|
|
if (src_float16) {
|
|
|
|
|
float16_t *src = (float16_t *)src_ptr;
|
|
|
|
|
const float16_t *src = (const float16_t *)src_ptr;
|
|
|
|
|
for (; ci < col_c8; ci += C8NUM) {
|
|
|
|
|
int ri = 0;
|
|
|
|
|
for (; ri < row_c8; ri += C8NUM) {
|
|
|
|
|
float16_t *src_ptr1 = src + ci * row + ri;
|
|
|
|
|
const float16_t *src_ptr1 = src + ci * row + ri;
|
|
|
|
|
float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM;
|
|
|
|
|
#ifdef ENABLE_ARM64
|
|
|
|
|
size_t strid_row = row * 2;
|
|
|
|
@ -93,7 +93,7 @@ void ColMajor2Row8MajorFp16(void *src_ptr, float16_t *dst_ptr, size_t row, size_
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
for (; ri < row; ++ri) {
|
|
|
|
|
float16_t *src_ptr1 = src + ci * row;
|
|
|
|
|
const float16_t *src_ptr1 = src + ci * row;
|
|
|
|
|
float16_t *dst_ptr1 = dst_ptr + ci * row;
|
|
|
|
|
for (int tc = 0; tc < C8NUM; ++tc) {
|
|
|
|
|
dst_ptr1[ri * C8NUM + tc] = src_ptr1[tc * row + ri];
|
|
|
|
@ -108,11 +108,11 @@ void ColMajor2Row8MajorFp16(void *src_ptr, float16_t *dst_ptr, size_t row, size_
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
float *src = (float *)src_ptr;
|
|
|
|
|
const float *src = (const float *)src_ptr;
|
|
|
|
|
for (; ci < col_c8; ci += C8NUM) {
|
|
|
|
|
int ri = 0;
|
|
|
|
|
for (; ri < row_c8; ri += C8NUM) {
|
|
|
|
|
float *src_ptr1 = src + ci * row + ri;
|
|
|
|
|
const float *src_ptr1 = src + ci * row + ri;
|
|
|
|
|
float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM;
|
|
|
|
|
#ifdef ENABLE_ARM64
|
|
|
|
|
size_t strid_row = row * 4;
|
|
|
|
@ -197,7 +197,7 @@ void ColMajor2Row8MajorFp16(void *src_ptr, float16_t *dst_ptr, size_t row, size_
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
for (; ri < row; ++ri) {
|
|
|
|
|
float *src_ptr1 = src + ci * row;
|
|
|
|
|
const float *src_ptr1 = src + ci * row;
|
|
|
|
|
float16_t *dst_ptr1 = dst_ptr + ci * row;
|
|
|
|
|
for (int tc = 0; tc < C8NUM; ++tc) {
|
|
|
|
|
dst_ptr1[ri * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + ri]);
|
|
|
|
@ -274,18 +274,18 @@ void MatVecMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const f
|
|
|
|
|
MatVecMulFp16Neon64(a, b, c, bias, (int)act_type, depth, col);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void RowMajor2Col16MajorFp16Opt(float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
|
|
|
|
|
void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) {
|
|
|
|
|
size_t row_up_16 = UP_ROUND(row, C16NUM);
|
|
|
|
|
size_t row16 = row / C16NUM * C16NUM;
|
|
|
|
|
size_t col8 = col / C8NUM * C8NUM;
|
|
|
|
|
float16_t *src_r = src_ptr;
|
|
|
|
|
const float16_t *src_r = src_ptr;
|
|
|
|
|
float16_t *dst_r = dst_ptr;
|
|
|
|
|
|
|
|
|
|
size_t ri = 0;
|
|
|
|
|
for (; ri < row16; ri += C16NUM) {
|
|
|
|
|
size_t ci = 0;
|
|
|
|
|
for (; ci < col8; ci += C8NUM) {
|
|
|
|
|
float16_t *src_c = src_r + ci;
|
|
|
|
|
const float16_t *src_c = src_r + ci;
|
|
|
|
|
float16_t *dst_c = dst_r + ci * C16NUM;
|
|
|
|
|
|
|
|
|
|
#ifdef ENABLE_ARM64
|
|
|
|
@ -403,7 +403,7 @@ void RowMajor2Col16MajorFp16Opt(float16_t *src_ptr, float16_t *dst_ptr, size_t r
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
for (; ci < col; ci++) {
|
|
|
|
|
float16_t *src_c = src_r + ci;
|
|
|
|
|
const float16_t *src_c = src_r + ci;
|
|
|
|
|
float16_t *dst_c = dst_r + ci * C16NUM;
|
|
|
|
|
for (size_t i = 0; i < C16NUM; i++) {
|
|
|
|
|
dst_c[i] = src_c[i * col];
|
|
|
|
@ -428,57 +428,57 @@ void RowMajor2Col16MajorFp16Opt(float16_t *src_ptr, float16_t *dst_ptr, size_t r
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/*
 * Repack a row-major `row` x `col` matrix into tiles of 16 rows, converting
 * to float16_t on the way when the source holds float values.
 *
 * Source element (r, c) is read from src[r * col + c] and stored at
 *   dst[(r / 16) * 16 * col + c * 16 + (r % 16)]
 * so that, inside one 16-row tile, the 16 values of a column are contiguous.
 *
 * src          read-only input; interpreted as `const float *` when
 *              is_fp32_src is true, otherwise as `const float16_t *`.
 * dst          output buffer; only positions that have a source element are
 *              written, so slots of a partial last tile keep their old content.
 * row, col     logical dimensions of the source matrix.
 * is_fp32_src  selects the element type of src.
 */
void RowMajor2Col16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) {
  for (int r = 0; r < row; r++) {
    /* Tile index and in-tile lane depend on the row only. */
    int r_div16 = r / 16;
    int r_mod16 = r % 16;
    for (int c = 0; c < col; c++) {
      int dst_index = r_div16 * 16 * col + c * 16 + r_mod16;
      if (is_fp32_src) {
        dst[dst_index] = (float16_t)(((const float *)src)[r * col + c]);
      } else {
        dst[dst_index] = ((const float16_t *)src)[r * col + c];
      }
    }
  }
}
|
|
|
|
|
|
|
|
|
|
/*
 * Repack a row-major `row` x `col` matrix into tiles of 16 columns,
 * converting to float16_t on the way when the source holds float values.
 *
 * Source element (r, c) is read from src[r * col + c] and stored at
 *   dst[(c / 16) * 16 * row + r * 16 + (c % 16)]
 * so that, inside one 16-column tile, the 16 values of a row are contiguous.
 *
 * src          read-only input; interpreted as `const float *` when
 *              is_fp32_src is true, otherwise as `const float16_t *`.
 * dst          output buffer; only positions that have a source element are
 *              written, so slots of a partial last tile keep their old content.
 * row, col     logical dimensions of the source matrix.
 * is_fp32_src  selects the element type of src.
 */
void RowMajor2Row16MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) {
  for (int r = 0; r < row; r++) {
    for (int c = 0; c < col; c++) {
      int c_div16 = c / 16;
      int c_mod16 = c % 16;
      int dst_index = c_div16 * 16 * row + r * 16 + c_mod16;
      if (is_fp32_src) {
        dst[dst_index] = (float16_t)(((const float *)src)[r * col + c]);
      } else {
        dst[dst_index] = ((const float16_t *)src)[r * col + c];
      }
    }
  }
}
|
|
|
|
|
|
|
|
|
|
/*
 * Repack a row-major `row` x `col` matrix into tiles of 8 columns,
 * converting to float16_t on the way when the source holds float values.
 *
 * Source element (r, c) is read from src[r * col + c] and stored at
 *   dst[(c / 8) * 8 * row + r * 8 + (c % 8)]
 * so that, inside one 8-column tile, the 8 values of a row are contiguous.
 *
 * src          read-only input; interpreted as `const float *` when
 *              is_fp32_src is true, otherwise as `const float16_t *`.
 * dst          output buffer; only positions that have a source element are
 *              written, so slots of a partial last tile keep their old content.
 * row, col     logical dimensions of the source matrix.
 * is_fp32_src  selects the element type of src.
 */
void RowMajor2Row8MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) {
  for (int r = 0; r < row; r++) {
    for (int c = 0; c < col; c++) {
      int c_div8 = c / 8;
      int c_mod8 = c % 8;
      int dst_index = c_div8 * 8 * row + r * 8 + c_mod8;
      if (is_fp32_src) {
        dst[dst_index] = (float16_t)(((const float *)src)[r * col + c]);
      } else {
        dst[dst_index] = ((const float16_t *)src)[r * col + c];
      }
    }
  }
}
|
|
|
|
|
|
|
|
|
|
void RowMajor2Col8MajorFp16(void *src, float16_t *dst, int row, int col, bool is_fp32_src) {
|
|
|
|
|
void RowMajor2Col8MajorFp16(const void *src, float16_t *dst, int row, int col, bool is_fp32_src) {
|
|
|
|
|
for (int r = 0; r < row; r++) {
|
|
|
|
|
for (int c = 0; c < col; c++) {
|
|
|
|
|
int r_div8 = r / 8;
|
|
|
|
|
int r_mod8 = r % 8;
|
|
|
|
|
if (is_fp32_src) {
|
|
|
|
|
dst[r_div8 * 8 * col + c * 8 + r_mod8] = (float16_t)(((float *)src)[r * col + c]);
|
|
|
|
|
dst[r_div8 * 8 * col + c * 8 + r_mod8] = (float16_t)(((const float *)src)[r * col + c]);
|
|
|
|
|
} else {
|
|
|
|
|
dst[r_div8 * 8 * col + c * 8 + r_mod8] = ((float16_t *)src)[r * col + c];
|
|
|
|
|
dst[r_div8 * 8 * col + c * 8 + r_mod8] = ((const float16_t *)src)[r * col + c];
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|