From f6f9d3915cb531435345ff0825ee77488f67e63a Mon Sep 17 00:00:00 2001
From: liuzhongkai
Date: Wed, 30 Sep 2020 15:04:30 +0800
Subject: [PATCH] fp16 winograd init optimize

---
 .../nnacl/assembly/fp16/MatmulWinogradFp16.S  | 206 ++++++++++++++++++
 mindspore/lite/nnacl/fp16/matrix_fp16.c       |  18 ++
 mindspore/lite/nnacl/fp16/matrix_fp16.h       |   2 +
 mindspore/lite/nnacl/fp16/pack_fp16.c         |   8 +
 mindspore/lite/nnacl/fp16/pack_fp16.h         |   2 +
 .../arm/fp16/convolution_winograd_fp16.cc     | 111 ++++++----
 6 files changed, 299 insertions(+), 48 deletions(-)
 create mode 100644 mindspore/lite/nnacl/assembly/fp16/MatmulWinogradFp16.S

diff --git a/mindspore/lite/nnacl/assembly/fp16/MatmulWinogradFp16.S b/mindspore/lite/nnacl/assembly/fp16/MatmulWinogradFp16.S
new file mode 100644
index 0000000000..882a33d851
--- /dev/null
+++ b/mindspore/lite/nnacl/assembly/fp16/MatmulWinogradFp16.S
@@ -0,0 +1,206 @@
+#ifdef __aarch64__
+
+.text
+.align 5
+.global MatrixMultiplyWinogradFp16
+#ifndef __APPLE__
+.type MatrixMultiplyWinogradFp16, %function
+#endif
+
+// MatrixMultiplyWinogradFp16(float16_t *matrix_a, float16_t *matrix_b, float16_t *matrix_c, int m, int k, int n, int in_channel)
+// x0: matrix_a, x1: matrix_b, x2: matrix_c (stored transposed: matrix_c[n][m][in_channel]), x3: m, x4: k, x5: n, x6: in_channel
+MatrixMultiplyWinogradFp16:
+    // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
+    // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
+    // x19 ~ x29 should also be preserved,
+    // whereas our coding style does not permit such an amount of parameters
+    sub sp, sp, #48
+    st1 {v8.8h}, [sp], #16
+    stp x19, x20, [sp], #16
+    stp x21, x22, [sp], #16
+
+    mov x8, #2
+    mul x10, x5, x8      // n * 2
+    mov x17, x3          // m
+    mul x13, x6, x8      // in_channel * 2
+    mul x21, x13, x4     // in_channel * k * 2
+
+LoopM:
+    mov x15, x5          // n
+    mov x14, x1          // mat_b
+
+LoopN:
+    mov x16, x0          // mat_a_m
+    sub x18, x5, x15     // ni
+    sub x19, x17, x3     // mi
+    mul x18, x18, x17    // ni * m
+    mov x11, x6          // in_channel
+    add x18, x18, x19    // (ni * m) + mi
+    mul x18, x18, x13    // ((ni * m) + mi) * in_channel * 2
+    add x20, x2, x18     // dst + offset
+    cmp x11, #32
+    bge LoopC32
+    cmp x11, #16
+    bge LoopC16
+    cmp x11, #8
+    bge LoopC8
+    cmp x11, #4
+    bge LoopC4
+    cmp x11, #1
+    bge LoopC
+    b EndLoopC
+
+LoopC32:
+    mov x12, x14
+    mov x9, x4           // new_k
+    dup v5.8h, wzr
+    dup v6.8h, wzr
+    dup v7.8h, wzr
+    dup v8.8h, wzr
+LoopK32:
+    ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x16], x13
+    ldr h4, [x12]
+    add x12, x12, x10
+    fmla v5.8h, v0.8h, v4.h[0]
+    fmla v6.8h, v1.8h, v4.h[0]
+    fmla v7.8h, v2.8h, v4.h[0]
+    fmla v8.8h, v3.8h, v4.h[0]
+    subs x9, x9, #1
+    bne LoopK32
+Write32:
+    st1 {v5.8h}, [x20], #16
+    st1 {v6.8h}, [x20], #16
+    st1 {v7.8h}, [x20], #16
+    st1 {v8.8h}, [x20], #16
+
+    sub x16, x16, x21    // back x13 * k
+    add x16, x16, #64    // add 64B
+    subs x11, x11, #32
+    beq EndLoopC
+    cmp x11, #32
+    bge LoopC32
+    cmp x11, #16
+    bge LoopC16
+    cmp x11, #8
+    bge LoopC8
+    cmp x11, #4
+    bge LoopC4
+    cmp x11, #1
+    bge LoopC
+
+LoopC16:
+    dup v5.8h, wzr
+    dup v6.8h, wzr
+    mov x9, x4           // new_k
+    mov x12, x14
+LoopK16:
+    ld1 {v0.8h, v1.8h}, [x16], x13
+    ldr h4, [x12]
+    add x12, x12, x10
+    fmla v5.8h, v0.8h, v4.h[0]
+    fmla v6.8h, v1.8h, v4.h[0]
+    subs x9, x9, #1
+    bne LoopK16
+Write16:
+    st1 {v5.8h}, [x20], #16
+    st1 {v6.8h}, [x20], #16
+
+    sub x16, x16, x21    // back x13 * k
+    add x16, x16, #32    // add 32B
+    subs x11, x11, #16
+    beq EndLoopC
+    cmp x11, #16
+    bge LoopC16
+    cmp x11, #8
+    bge LoopC8
+    cmp x11, #4
+    bge LoopC4
+    cmp x11, #1
+    bge LoopC
+
+LoopC8:
+    dup v5.8h, wzr
+    mov x9, x4           // new_k
+    mov x12, x14
+LoopK8:
+    ld1 {v0.8h}, [x16], x13
+    ldr h4, [x12]
+    add x12, x12, x10
+    fmla v5.8h, v0.8h, v4.h[0]
+    subs x9, x9, #1
+    bne LoopK8
+Write8:
+    st1 {v5.8h}, [x20], #16
+
+    sub x16, x16, x21    // ptr back x13 * k
+    add x16, x16, #16    // add 16B
+    subs x11, x11, #8
+    beq EndLoopC
+    cmp x11, #8
+    bge LoopC8
+    cmp x11, #4
+    bge LoopC4
+    cmp x11, #1
+    bge LoopC
+
+LoopC4:
+    dup v5.4h, wzr
+    mov x9, x4           // new_k
+    mov x12, x14
+LoopK4:
+    ld1 {v0.4h}, [x16], x13
+    ldr h4, [x12]
+    add x12, x12, x10
+    fmla v5.4h, v0.4h, v4.h[0]
+    subs x9, x9, #1
+    bne LoopK4
+Write4:
+    st1 {v5.4h}, [x20], #8
+
+    sub x16, x16, x21    // ptr back x13 * k
+    add x16, x16, #8     // add 8B
+    subs x11, x11, #4
+    beq EndLoopC
+    cmp x11, #4
+    bge LoopC4
+    cmp x11, #1
+    bge LoopC
+
+LoopC:
+    dup v5.8h, wzr
+    mov x9, x4           // new_k
+    mov x12, x14
+LoopK:
+    ldr h0, [x16]
+    add x16, x16, x13
+    ldr h4, [x12]
+    add x12, x12, x10
+    fmul h0, h0, h4
+    fadd h5, h5, h0
+    subs x9, x9, #1
+    bne LoopK
+Write:
+    str h5, [x20], #2
+
+    sub x16, x16, x21    // ptr back x13 * k
+    add x16, x16, #2     // ptr add 2B
+    subs x11, x11, #1
+    beq EndLoopC
+    b LoopC
+
+EndLoopC:
+    add x14, x14, #2
+    subs x15, x15, #1
+    beq EndLoopN
+    b LoopN
+
+EndLoopN:
+    subs x3, x3, #1
+    beq EndLoopM
+    add x0, x0, x21
+    b LoopM
+
+EndLoopM:
+    sub sp, sp, #48
+    ld1 {v8.8h}, [sp], #16
+    ldp x19, x20, [sp], #16
+    ldp x21, x22, [sp], #16
+    ret
+#endif
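
Note on the kernel's contract: the assembly stores its result transposed, as the dst offset ((ni * m) + mi) * in_channel * 2 above shows. A minimal portable-C sketch of that contract, for review only (not part of the patch; float stands in for float16_t, and ReferenceMatMulWinogradFp16 is a hypothetical name):

    // C analogue of the AArch64 kernel: matrix_a is m x k with in_channel
    // values per element, matrix_b is a plain k x n matrix, and the product
    // is stored transposed, i.e. matrix_c[n][m][in_channel].
    static void ReferenceMatMulWinogradFp16(const float *matrix_a, const float *matrix_b, float *matrix_c,
                                            int m, int k, int n, int in_channel) {
      for (int mi = 0; mi < m; ++mi) {
        for (int ni = 0; ni < n; ++ni) {
          for (int ci = 0; ci < in_channel; ++ci) {
            float acc = 0.0f;
            for (int ki = 0; ki < k; ++ki) {
              acc += matrix_a[(mi * k + ki) * in_channel + ci] * matrix_b[ki * n + ni];
            }
            // transposed store, matching the ((ni * m) + mi) * in_channel offset above
            matrix_c[(ni * m + mi) * in_channel + ci] = acc;
          }
        }
      }
    }

This transposed layout is what lets the AArch64 path of the filter transform below avoid any explicit transpose step.
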
diff --git a/mindspore/lite/nnacl/fp16/matrix_fp16.c b/mindspore/lite/nnacl/fp16/matrix_fp16.c
index feeb76ce49..ce8acfba75 100644
--- a/mindspore/lite/nnacl/fp16/matrix_fp16.c
+++ b/mindspore/lite/nnacl/fp16/matrix_fp16.c
@@ -32,6 +32,24 @@ void MatrixMultiplyFp16(const float16_t *matrix_a, const float16_t *matrix_b, fl
   }
 }
 
+#ifndef ENABLE_ARM64
+void MatrixMultiplyWinogradFp16(const float16_t *matrix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k,
+                                int n, int in_channel) {
+  int cnt = 0;
+  for (int i = 0; i < m; ++i) {
+    for (int j = 0; j < n; ++j) {
+      for (int y = 0; y < in_channel; ++y) {
+        float16_t tmp = 0;
+        for (int z = 0; z < k; ++z) {
+          tmp += matrix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n];
+        }
+        matrix_c[cnt++] = tmp;
+      }
+    }
+  }
+}
+#endif
+
 void MatrixMultiplyVecFp16(const float16x8_t *matrix_a, const float16x8_t *matrix_b, float16x8_t *matrix_c,
                            const float16_t *bias, int m, int k, int n) {
   if (bias == NULL) {
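
Unlike the assembly, this fallback writes matrix_c in plain [m][n][in_channel] order; the caller compensates with PackHWCToWHCFp16 (see convolution_winograd_fp16.cc below). A tiny self-contained check of the fallback's indexing, using float in place of float16_t (MatMulWinogradRef and the test values are illustrative, not part of the patch): multiplying by an identity B must reproduce A in the same order.

    #include <assert.h>
    #include <string.h>

    // float stand-in for the fp16 fallback above, same indexing
    static void MatMulWinogradRef(const float *matrix_a, const float *matrix_b, float *matrix_c,
                                  int m, int k, int n, int in_channel) {
      int cnt = 0;
      for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
          for (int y = 0; y < in_channel; ++y) {
            float tmp = 0.0f;
            for (int z = 0; z < k; ++z) {
              tmp += matrix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n];
            }
            matrix_c[cnt++] = tmp;
          }
        }
      }
    }

    int main(void) {
      const float a[8] = {1, 2, 3, 4, 5, 6, 7, 8};  // A: 2x2 elements, 2 channels each
      const float b[4] = {1, 0, 0, 1};              // B: 2x2 identity
      float c[8] = {0};
      MatMulWinogradRef(a, b, c, 2, 2, 2, 2);
      assert(memcmp(c, a, sizeof(a)) == 0);  // B = I keeps A in [m][n][channel] order
      return 0;
    }
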
diff --git a/mindspore/lite/nnacl/fp16/matrix_fp16.h b/mindspore/lite/nnacl/fp16/matrix_fp16.h
index 6834fc8d6c..dfc5f24523 100644
--- a/mindspore/lite/nnacl/fp16/matrix_fp16.h
+++ b/mindspore/lite/nnacl/fp16/matrix_fp16.h
@@ -26,6 +26,8 @@ void MatrixMultiplyFp16(const float16_t *matrix_a, const float16_t *matrix_b, fl
 void MatrixMultiplyVecFp16(const float16x8_t *matrix_a, const float16x8_t *matrix_b, float16x8_t *matrix_c,
                            const float16_t *bias, int m, int k, int n);
 
+void MatrixMultiplyWinogradFp16(const float16_t *matrix_a, const float16_t *matrix_b, float16_t *matrix_c, int m, int k,
+                                int n, int in_channel);
 #ifdef __cplusplus
 }
 #endif
diff --git a/mindspore/lite/nnacl/fp16/pack_fp16.c b/mindspore/lite/nnacl/fp16/pack_fp16.c
index 27f98cbfdf..3baecff7cf 100644
--- a/mindspore/lite/nnacl/fp16/pack_fp16.c
+++ b/mindspore/lite/nnacl/fp16/pack_fp16.c
@@ -65,6 +65,14 @@ void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float1
   }  // tile num loop
 }
 
+void PackHWCToWHCFp16(const float16_t *src, float16_t *dst, int height, int width, int channel) {
+  for (int i = 0; i < height; ++i) {
+    for (int j = 0; j < width; ++j) {
+      memcpy(dst + (j * height + i) * channel, src + (i * width + j) * channel, channel * sizeof(float16_t));
+    }
+  }
+}
+
 void PackWeightToC8Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param) {
   // origin weight format : ohwi
   int input_channel = conv_param->input_channel_;
diff --git a/mindspore/lite/nnacl/fp16/pack_fp16.h b/mindspore/lite/nnacl/fp16/pack_fp16.h
index 2f1ad6eebe..fc82ff66a3 100644
--- a/mindspore/lite/nnacl/fp16/pack_fp16.h
+++ b/mindspore/lite/nnacl/fp16/pack_fp16.h
@@ -31,6 +31,8 @@ void Im2ColPackUnitFp16(float16_t *input_data, ConvParameter *conv_param, float1
 void PackWeightToC8Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param);
 
+void PackHWCToWHCFp16(const float16_t *src, float16_t *dst, int height, int width, int channel);
+
 void PackWeightToC4Fp16(const float16_t *origin_weight_data, float16_t *packed_weight_data, ConvParameter *conv_param);
 
 void PackNHWCToNC4HW4Fp16(const void *src, void *dst, int batch, int plane, int channel);
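
PackHWCToWHCFp16 transposes the height/width grid while keeping each element's channel vector contiguous: dst[w][h][c] = src[h][w][c]. A float sketch with a worked layout, for illustration only (PackHWCToWHCRef is a hypothetical name, not part of the patch):

    #include <string.h>

    // float analogue of PackHWCToWHCFp16: dst[w][h][c] = src[h][w][c]
    static void PackHWCToWHCRef(const float *src, float *dst, int height, int width, int channel) {
      for (int i = 0; i < height; ++i) {
        for (int j = 0; j < width; ++j) {
          memcpy(dst + (j * height + i) * channel, src + (i * width + j) * channel, channel * sizeof(float));
        }
      }
    }

    // e.g. height = 2, width = 3, channel = 1:
    // src (HWC): {a, b, c,     dst (WHC): {a, d,
    //             d, e, f} ->              b, e,
    //                                      c, f}

In the filter transform below, this supplies the two transposes of the [(g * (G)T)T * (G)T]T factorization on platforms without the AArch64 kernel, whose output is already transposed.
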
MS_LOG(ERROR) << "malloc tmp_data failed."; return RET_ERROR; } - auto trans_out_data = reinterpret_cast(malloc(input_unit_ * input_unit_ * sizeof(float16_t))); + auto trans_out_data = + reinterpret_cast(malloc(channel_in * input_unit_ * input_unit_ * sizeof(float16_t))); if (trans_out_data == nullptr) { free(tmp_data); - free(tmp_weight_data); - free(matrix_g_data_fp16); free(matrix_gt_data_fp16); MS_LOG(ERROR) << "malloc trans_out_data failed."; return RET_ERROR; } - if (oc_block == 0) { - MS_LOG(ERROR) << "Divide by zero"; - free(tmp_weight_data); +#ifndef ENABLE_ARM64 + auto tmp_data1 = reinterpret_cast(malloc(channel_in * input_unit_ * kernel_unit_ * sizeof(float16_t))); + if (tmp_data1 == nullptr) { free(tmp_data); + free(matrix_gt_data_fp16); free(trans_out_data); - free(matrix_g_data_fp16); + MS_LOG(ERROR) << "malloc tmp_data1 failed."; + return RET_ERROR; + } + auto trans_out_data1 = + reinterpret_cast(malloc(channel_in * input_unit_ * input_unit_ * sizeof(float16_t))); + if (trans_out_data1 == nullptr) { + free(tmp_data); + free(tmp_data1); free(matrix_gt_data_fp16); + free(trans_out_data); + MS_LOG(ERROR) << "malloc trans_out_data1 failed."; return RET_ERROR; } - int stride1 = channel_in * oc_block; +#endif + + int input_oz_offset = kernel_unit_ * kernel_unit_ * channel_in; for (int i = 0; i < channel_out; i++) { int out_c_block = i / oc_block; int out_c_res = i % oc_block; - int input_oz_offset = i * kernel_unit_ * kernel_unit_ * channel_in; - int output_oz_offset = out_c_block * stride1 + out_c_res; - for (int j = 0; j < channel_in; j++) { - int input_iz_offset = input_oz_offset + j; - int output_iz_offset = output_oz_offset + j * oc_block; - for (int k = 0; k < kernel_unit_ * kernel_unit_; k++) { - int input_xy_offset = input_iz_offset + k * channel_in; - tmp_weight_data[k] = *(weight_data + input_xy_offset); - } - // now we only support row-major matrix-multiply - // tmp = G * g - MatrixMultiplyFp16(matrix_g_data_fp16, tmp_weight_data, tmp_data, input_unit_, kernel_unit_, kernel_unit_); - // out = tmp * GT - MatrixMultiplyFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit_, kernel_unit_, input_unit_); - - for (int z = 0; z < input_unit_square; z++) { - int output_xy_offset = output_iz_offset + z * oc_block_num * stride1; - trans_weight_[output_xy_offset] = trans_out_data[z]; + int output_oz_offset = out_c_block * block_stride + out_c_res; + +#ifndef ENABLE_ARM64 + // tmp_data = g * GT + MatrixMultiplyWinogradFp16(weight_data + i * input_oz_offset, matrix_gt_data_fp16, tmp_data, kernel_unit_, + kernel_unit_, input_unit_, channel_in); + // tmp_data1 = (tmp_data)T + PackHWCToWHCFp16(tmp_data, tmp_data1, kernel_unit_, input_unit_, channel_in); + // trans_out_data1 = tmp * GT + MatrixMultiplyWinogradFp16(tmp_data1, matrix_gt_data_fp16, trans_out_data1, input_unit_, kernel_unit_, input_unit_, + channel_in); + // trans_out_data = (trans_out_data1)T + PackHWCToWHCFp16(trans_out_data1, trans_out_data, input_unit_, input_unit_, channel_in); +#else + // tmp = (g * GT)T + MatrixMultiplyWinogradFp16(weight_data + i * input_oz_offset, matrix_gt_data_fp16, tmp_data, kernel_unit_, + kernel_unit_, input_unit_, channel_in); + // trans = (tmp * GT)T + MatrixMultiplyWinogradFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit_, kernel_unit_, input_unit_, + channel_in); +#endif + + int in_offset = 0; + for (int j = 0; j < input_unit_; ++j) { + for (int k = 0; k < input_unit_; ++k) { + for (int c = 0; c < channel_in; ++c) { + *(trans_weight_ + output_oz_offset + c * 
+        }
+        in_offset += channel_in;
+        output_oz_offset += block_num_stride;
       }
     }
   }
-  free(tmp_weight_data);
+
+#ifndef ENABLE_ARM64
+  free(tmp_data1);
+  free(trans_out_data1);
+#endif
   free(tmp_data);
   free(trans_out_data);
-  free(matrix_g_data_fp16);
   free(matrix_gt_data_fp16);
   return RET_OK;
 }
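
Overall, the new transform computes trans = G * g * GT per output channel through the identity G * g * GT = [(g * (G)T)T * (G)T]T, so only GT is ever materialized and the old per-channel repack into tmp_weight_data disappears. A condensed float sketch of the portable path, for illustration only (FilterTransformChannelRef, MatMulWG, and PackHWCToWHC are hypothetical names; buffers are caller-provided as in the patch):

    #include <string.h>

    // plain [m][n][ci] product of an m x k multi-channel A with a k x n B
    static void MatMulWG(const float *a, const float *b, float *c, int m, int k, int n, int ci) {
      for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j)
          for (int y = 0; y < ci; ++y) {
            float t = 0.0f;
            for (int z = 0; z < k; ++z) t += a[(i * k + z) * ci + y] * b[z * n + j];
            c[(i * n + j) * ci + y] = t;
          }
    }

    // dst[w][h][c] = src[h][w][c]
    static void PackHWCToWHC(const float *src, float *dst, int h, int w, int ci) {
      for (int i = 0; i < h; ++i)
        for (int j = 0; j < w; ++j)
          memcpy(dst + (j * h + i) * ci, src + (i * w + j) * ci, ci * sizeof(float));
    }

    // trans = G * g * GT for one output channel, realized without matrix_g;
    // ku = kernel_unit_, iu = input_unit_, ci = channel_in, gt = GT (ku x iu, row-major)
    static void FilterTransformChannelRef(const float *g, const float *gt, float *trans,
                                          int ku, int iu, int ci,
                                          float *tmp, float *tmp_t, float *out) {
      MatMulWG(g, gt, tmp, ku, ku, iu, ci);      // tmp   = g * GT       (ku x iu)
      PackHWCToWHC(tmp, tmp_t, ku, iu, ci);      // tmp_t = (g * GT)T    (iu x ku)
      MatMulWG(tmp_t, gt, out, iu, ku, iu, ci);  // out   = tmp_t * GT   (iu x iu)
      PackHWCToWHC(out, trans, iu, iu, ci);      // trans = (out)T = G * g * GT
    }

On AArch64 the two PackHWCToWHC steps fold into the assembly kernel's transposed stores, so the same result comes from just two MatrixMultiplyWinogradFp16 calls and half the scratch buffers.
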