From 73f3bf11762d56ac0bd6f1702362010ec89f91cc Mon Sep 17 00:00:00 2001
From: liuzhongkai
Date: Wed, 14 Oct 2020 15:40:14 +0800
Subject: [PATCH] fp16 conv1x1 init: ARM64 asm optimization

---
 mindspore/lite/nnacl/fp16/matmul_fp16.c | 192 +++++++++++++++++++++++-
 1 file changed, 184 insertions(+), 8 deletions(-)

diff --git a/mindspore/lite/nnacl/fp16/matmul_fp16.c b/mindspore/lite/nnacl/fp16/matmul_fp16.c
index 85d7998b94..eb9bef6cbf 100644
--- a/mindspore/lite/nnacl/fp16/matmul_fp16.c
+++ b/mindspore/lite/nnacl/fp16/matmul_fp16.c
@@ -17,22 +17,198 @@
 #include "nnacl/fp16/matmul_fp16.h"
 
 void ColMajor2Row8MajorFp16(void *src_ptr, float16_t *dst_ptr, size_t row, size_t col, bool src_float16) {
+  // Extent of the rows/columns covered by complete 8x8 tiles.
+  int row_c8 = row / C8NUM * C8NUM;
+  int col_c8 = col / C8NUM * C8NUM;
+  int ci = 0;
   if (src_float16) {
     float16_t *src = (float16_t *)src_ptr;
+    for (; ci < col_c8; ci += C8NUM) {
+      int ri = 0;
+      for (; ri < row_c8; ri += C8NUM) {
+        float16_t *src_ptr1 = src + ci * row + ri;
+        float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM;
+#ifdef ENABLE_ARM64
+        // Byte stride between adjacent columns (one column holds `row` fp16 elements).
+        size_t stride_row = row * 2;
+        // Load an 8-row slice from each of 8 columns, transpose the 8x8 fp16
+        // tile in-register (zip1/zip2 + trn1/trn2), and store it row-major.
+        asm volatile(
+          "mov x10, %[src_ptr1]\n"
+          "mov x11, %[dst_ptr1]\n"
+          "mov x12, %[stride_row]\n"
+          "ld1 {v0.8h}, [x10], x12\n"
+          "ld1 {v1.8h}, [x10], x12\n"
+          "ld1 {v2.8h}, [x10], x12\n"
+          "ld1 {v3.8h}, [x10], x12\n"
+          "ld1 {v4.8h}, [x10], x12\n"
+          "ld1 {v5.8h}, [x10], x12\n"
+          "ld1 {v6.8h}, [x10], x12\n"
+          "ld1 {v7.8h}, [x10], x12\n"
+
+          "zip1 v8.8h, v0.8h, v1.8h\n"
+          "zip1 v9.8h, v2.8h, v3.8h\n"
+          "zip1 v10.8h, v4.8h, v5.8h\n"
+          "zip1 v11.8h, v6.8h, v7.8h\n"
+
+          "trn1 v12.4s, v8.4s, v9.4s\n"
+          "trn1 v14.4s, v10.4s, v11.4s\n"
+          "trn2 v13.4s, v8.4s, v9.4s\n"
+          "trn2 v15.4s, v10.4s, v11.4s\n"
+
+          "trn1 v16.2d, v12.2d, v14.2d\n"
+          "trn2 v18.2d, v12.2d, v14.2d\n"
+          "trn1 v17.2d, v13.2d, v15.2d\n"
+          "trn2 v19.2d, v13.2d, v15.2d\n"
+
+          "zip2 v8.8h, v0.8h, v1.8h\n"
+          "zip2 v9.8h, v2.8h, v3.8h\n"
+          "zip2 v10.8h, v4.8h, v5.8h\n"
+          "zip2 v11.8h, v6.8h, v7.8h\n"
+
+          "trn1 v12.4s, v8.4s, v9.4s\n"
+          "trn1 v14.4s, v10.4s, v11.4s\n"
+          "trn2 v13.4s, v8.4s, v9.4s\n"
+          "trn2 v15.4s, v10.4s, v11.4s\n"
+
+          "trn1 v20.2d, v12.2d, v14.2d\n"
+          "trn2 v22.2d, v12.2d, v14.2d\n"
+          "trn1 v21.2d, v13.2d, v15.2d\n"
+          "trn2 v23.2d, v13.2d, v15.2d\n"
+
+          "st1 {v16.8h}, [x11], #16\n"
+          "st1 {v17.8h}, [x11], #16\n"
+          "st1 {v18.8h}, [x11], #16\n"
+          "st1 {v19.8h}, [x11], #16\n"
+          "st1 {v20.8h}, [x11], #16\n"
+          "st1 {v21.8h}, [x11], #16\n"
+          "st1 {v22.8h}, [x11], #16\n"
+          "st1 {v23.8h}, [x11], #16\n"
+          :
+          : [ dst_ptr1 ] "r"(dst_ptr1), [ src_ptr1 ] "r"(src_ptr1), [ stride_row ] "r"(stride_row)
+          : "memory", "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+            "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
+#else
+        for (int tr = 0; tr < C8NUM; ++tr) {
+          for (int tc = 0; tc < C8NUM; ++tc) {
+            dst_ptr1[tr * C8NUM + tc] = src_ptr1[tc * row + tr];
+          }
+        }
+#endif
+      }
+      // Leftover rows of this column block (less than a full 8x8 tile).
+      for (; ri < row; ++ri) {
+        float16_t *src_ptr1 = src + ci * row;
+        float16_t *dst_ptr1 = dst_ptr + ci * row;
+        for (int tc = 0; tc < C8NUM; ++tc) {
+          dst_ptr1[ri * C8NUM + tc] = src_ptr1[tc * row + ri];
+        }
+      }
+    }
     for (int r = 0; r < row; r++) {
-      for (int c = 0; c < col; c++) {
-        int cd8 = c / 8;
-        int cm8 = c % 8;
-        dst_ptr[cd8 * 8 * row + r * 8 + cm8] = (float16_t)(src[c * row + r]);
+      for (int tc = ci; tc < col; tc++) {
+        int cd8 = tc / C8NUM;
+        int cm8 = tc % C8NUM;
+        dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[tc * row + r];
       }
     }
   } else {
     float *src = (float *)src_ptr;
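+    // fp32 source: fcvtn/fcvtn2 narrow each 8-row column slice to fp16
+    // in-register, then the same 8x8 zip/trn transpose as above is applied.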
+    for (; ci < col_c8; ci += C8NUM) {
+      int ri = 0;
+      for (; ri < row_c8; ri += C8NUM) {
+        float *src_ptr1 = src + ci * row + ri;
+        float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM;
+#ifdef ENABLE_ARM64
+        // Byte stride between adjacent columns (one column holds `row` fp32 elements).
+        size_t stride_row = row * 4;
+        asm volatile(
+          "mov x10, %[src_ptr1]\n"
+          "mov x11, %[dst_ptr1]\n"
+          "mov x12, %[stride_row]\n"
+          "ld1 {v8.4s, v9.4s}, [x10], x12\n"
+          "ld1 {v10.4s, v11.4s}, [x10], x12\n"
+          "ld1 {v12.4s, v13.4s}, [x10], x12\n"
+          "ld1 {v14.4s, v15.4s}, [x10], x12\n"
+          "ld1 {v16.4s, v17.4s}, [x10], x12\n"
+          "ld1 {v18.4s, v19.4s}, [x10], x12\n"
+          "ld1 {v20.4s, v21.4s}, [x10], x12\n"
+          "ld1 {v22.4s, v23.4s}, [x10], x12\n"
+
+          "fcvtn v0.4h, v8.4s\n"
+          "fcvtn2 v0.8h, v9.4s\n"
+          "fcvtn v1.4h, v10.4s\n"
+          "fcvtn2 v1.8h, v11.4s\n"
+          "fcvtn v2.4h, v12.4s\n"
+          "fcvtn2 v2.8h, v13.4s\n"
+          "fcvtn v3.4h, v14.4s\n"
+          "fcvtn2 v3.8h, v15.4s\n"
+          "fcvtn v4.4h, v16.4s\n"
+          "fcvtn2 v4.8h, v17.4s\n"
+          "fcvtn v5.4h, v18.4s\n"
+          "fcvtn2 v5.8h, v19.4s\n"
+          "fcvtn v6.4h, v20.4s\n"
+          "fcvtn2 v6.8h, v21.4s\n"
+          "fcvtn v7.4h, v22.4s\n"
+          "fcvtn2 v7.8h, v23.4s\n"
+
+          "zip1 v8.8h, v0.8h, v1.8h\n"
+          "zip1 v9.8h, v2.8h, v3.8h\n"
+          "zip1 v10.8h, v4.8h, v5.8h\n"
+          "zip1 v11.8h, v6.8h, v7.8h\n"
+
+          "trn1 v12.4s, v8.4s, v9.4s\n"
+          "trn1 v14.4s, v10.4s, v11.4s\n"
+          "trn2 v13.4s, v8.4s, v9.4s\n"
+          "trn2 v15.4s, v10.4s, v11.4s\n"
+
+          "trn1 v16.2d, v12.2d, v14.2d\n"
+          "trn2 v18.2d, v12.2d, v14.2d\n"
+          "trn1 v17.2d, v13.2d, v15.2d\n"
+          "trn2 v19.2d, v13.2d, v15.2d\n"
+
+          "zip2 v8.8h, v0.8h, v1.8h\n"
+          "zip2 v9.8h, v2.8h, v3.8h\n"
+          "zip2 v10.8h, v4.8h, v5.8h\n"
+          "zip2 v11.8h, v6.8h, v7.8h\n"
+
+          "trn1 v12.4s, v8.4s, v9.4s\n"
+          "trn1 v14.4s, v10.4s, v11.4s\n"
+          "trn2 v13.4s, v8.4s, v9.4s\n"
+          "trn2 v15.4s, v10.4s, v11.4s\n"
+
+          "trn1 v20.2d, v12.2d, v14.2d\n"
+          "trn2 v22.2d, v12.2d, v14.2d\n"
+          "trn1 v21.2d, v13.2d, v15.2d\n"
+          "trn2 v23.2d, v13.2d, v15.2d\n"
+
+          "st1 {v16.8h}, [x11], #16\n"
+          "st1 {v17.8h}, [x11], #16\n"
+          "st1 {v18.8h}, [x11], #16\n"
+          "st1 {v19.8h}, [x11], #16\n"
+          "st1 {v20.8h}, [x11], #16\n"
+          "st1 {v21.8h}, [x11], #16\n"
+          "st1 {v22.8h}, [x11], #16\n"
+          "st1 {v23.8h}, [x11], #16\n"
+          :
+          : [ dst_ptr1 ] "r"(dst_ptr1), [ src_ptr1 ] "r"(src_ptr1), [ stride_row ] "r"(stride_row)
+          : "memory", "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10",
+            "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23");
+#else
+        for (int tr = 0; tr < C8NUM; ++tr) {
+          for (int tc = 0; tc < C8NUM; ++tc) {
+            dst_ptr1[tr * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + tr]);
+          }
+        }
+#endif
+      }
+      // Leftover rows of this column block (less than a full 8x8 tile).
+      for (; ri < row; ++ri) {
+        float *src_ptr1 = src + ci * row;
+        float16_t *dst_ptr1 = dst_ptr + ci * row;
+        for (int tc = 0; tc < C8NUM; ++tc) {
+          dst_ptr1[ri * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + ri]);
+        }
+      }
+    }
     for (int r = 0; r < row; r++) {
-      for (int c = 0; c < col; c++) {
-        int cd8 = c / 8;
-        int cm8 = c % 8;
-        dst_ptr[cd8 * 8 * row + r * 8 + cm8] = (float16_t)(src[c * row + r]);
+      for (int tc = ci; tc < col; tc++) {
+        int cd8 = tc / C8NUM;
+        int cm8 = tc % C8NUM;
+        dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = (float16_t)(src[tc * row + r]);
       }
     }
   }
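
Reviewer note (not part of the patch): below is a minimal standalone cross-check of the packing layout this change vectorizes. It fills a column-major fp32 matrix whose dimensions are deliberately not multiples of 8 and verifies that every element lands at dst[(c / 8) * 8 * row + r * 8 + c % 8] == (float16_t)src[c * row + r], the layout produced by the original scalar loops. The float16_t typedef, the test sizes, and linking against nnacl are assumptions of this sketch.

// col2row8_check.c -- cross-check harness for ColMajor2Row8MajorFp16 (sketch).
// Assumes an AArch64 toolchain where __fp16 matches nnacl's float16_t and
// that this file is linked against the nnacl fp16 objects.
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>

typedef __fp16 float16_t;  // assumption: same representation as nnacl's float16_t

void ColMajor2Row8MajorFp16(void *src_ptr, float16_t *dst_ptr, size_t row, size_t col, bool src_float16);

int main(void) {
  const size_t row = 19, col = 21;          // deliberately not multiples of 8
  const size_t col8 = (col + 7) / 8 * 8;    // dst holds col rounded up to 8 columns
  float *src = malloc(row * col * sizeof(float));
  float16_t *dst = calloc(col8 * row, sizeof(float16_t));
  // Values <= 24 in 0.25 steps are exactly representable in fp16, so the
  // comparison below is exact regardless of rounding mode.
  for (size_t i = 0; i < row * col; ++i) {
    src[i] = (float)(i % 97) * 0.25f;
  }

  ColMajor2Row8MajorFp16(src, dst, row, col, false);

  // Reference layout from the pre-patch scalar code:
  // dst[(c / 8) * 8 * row + r * 8 + c % 8] == (float16_t)src[c * row + r]
  for (size_t c = 0; c < col; ++c) {
    for (size_t r = 0; r < row; ++r) {
      float16_t want = (float16_t)src[c * row + r];
      float16_t got = dst[(c / 8) * 8 * row + r * 8 + (c % 8)];
      if (want != got) {
        printf("mismatch at r=%zu c=%zu\n", r, c);
        return 1;
      }
    }
  }
  printf("ok\n");
  free(src);
  free(dst);
  return 0;
}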