From d994d3430f536879b511159c31c7f73834ce5ed7 Mon Sep 17 00:00:00 2001 From: liuzhongkai Date: Mon, 28 Sep 2020 15:16:51 +0800 Subject: [PATCH] windgrad init optimize --- .../nnacl/assembly/arm64/MatmulWinogradFp32.S | 172 ++++++++++++++++++ .../lite/nnacl/minimal_filtering_generator.c | 19 ++ .../lite/nnacl/minimal_filtering_generator.h | 2 + mindspore/lite/nnacl/pack.c | 8 + mindspore/lite/nnacl/pack.h | 2 + .../kernel/arm/fp32/convolution_winograd.cc | 106 ++++++----- 6 files changed, 264 insertions(+), 45 deletions(-) create mode 100644 mindspore/lite/nnacl/assembly/arm64/MatmulWinogradFp32.S diff --git a/mindspore/lite/nnacl/assembly/arm64/MatmulWinogradFp32.S b/mindspore/lite/nnacl/assembly/arm64/MatmulWinogradFp32.S new file mode 100644 index 0000000000..0c4356697f --- /dev/null +++ b/mindspore/lite/nnacl/assembly/arm64/MatmulWinogradFp32.S @@ -0,0 +1,172 @@ +#ifdef __aarch64__ + +.text +.align 5 +.global MatrixMultiplyWinograd +#ifndef __APPLE__ +.type MatrixMultiplyWinograd, %function +#endif + +// MatrixMultiplyWinograd(float *matix_a, float *matrix_b, float *matrix_c, int m, int k, int n, int in_channel, int c4_channel) + // x0: matrix_a, x1: matrix_b, x2: matrix_c, x3: m, x4: k, x5: n, x6: in_channel, x7: c4_channel +MatrixMultiplyWinograd: + // registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to + // https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers + // x19 ~ x29 should be also preserved + // whereas our coding style do not permit such amount of parameters + sub sp, sp, #48 + st1 {v8.4s}, [sp], #16 + stp x19, x20, [sp], #16 + stp x21, x22, [sp], #16 + mov x8, #4 + mul x10, x5, x8 + mov x17, x3 // m + mul x13, x6, x8 // in_channel * 4 + mul x21, x13, x4 // in_channel * k + + LoopM: + mov x15, x5 // n + mov x14, x1 // mat_b + LoopN: + mov x16, x0 // mat_a_m + sub x18, x5, x15 // ni + sub x19, x17, x3 // mi + mul x18, x18, x17 // ni * m + mov x11, x6 // in_channel + add x18, x18, x19 // (ni * m) + mi + mul x18, x18, x7 // x18 * c4_channel + add x20, x2, x18 // dst + offset + cmp x11, #16 + bge LoopC16 + cmp x11, #8 + bge LoopC8 + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + b EndLoopC + LoopC16: + mov x12, x14 + mov x9, x4 // new_k + dup v5.4s, wzr + dup v6.4s, wzr + dup v7.4s, wzr + dup v8.4s, wzr + LoopK16: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x16], x13 + ldr s4, [x12] + add x12, x12, x10 + fmla v5.4s, v0.4s, v4.s[0] + fmla v6.4s, v1.4s, v4.s[0] + fmla v7.4s, v2.4s, v4.s[0] + fmla v8.4s, v3.4s, v4.s[0] + subs x9, x9, #1 + bne LoopK16 + Write16: + st1 {v5.4s}, [x20], #16 + st1 {v6.4s}, [x20], #16 + st1 {v7.4s}, [x20], #16 + st1 {v8.4s}, [x20], #16 + + sub x16, x16, x21 // back x13 * k + add x16, x16, #64 // add 64B + subs x11, x11, #16 + beq EndLoopC + cmp x11, #16 + bge LoopC16 + cmp x11, #8 + bge LoopC8 + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + + LoopC8: + dup v5.4s, wzr + dup v6.4s, wzr + mov x9, x4 // new_k + mov x12, x14 + LoopK8: + ld1 {v0.4s, v1.4s}, [x16], x13 + ldr s4, [x12] + add x12, x12, x10 + fmla v5.4s, v0.4s, v4.s[0] + fmla v6.4s, v1.4s, v4.s[0] + subs x9, x9, #1 + bne LoopK8 + Write8: + st1 {v5.4s}, [x20], #16 + st1 {v6.4s}, [x20], #16 + + sub x16, x16, x21 // back x13 * k + add x16, x16, #32 // add 64B + subs x11, x11, #8 + beq EndLoopC + cmp x11, #8 + bge LoopC8 + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + + LoopC4: + dup v5.4s, wzr + mov x9, x4 // new_k + mov x12, x14 + LoopK4: + ld1 {v0.4s}, [x16], x13 + ldr s4, [x12] + add x12, x12, x10 + fmla v5.4s, v0.4s, v4.s[0] + subs x9, x9, #1 + bne LoopK4 + Write4: + st1 {v5.4s}, [x20], #16 + + sub x16, x16, x21 // ptr back x13 * k + add x16, x16, #16 // add 16B + subs x11, x11, #4 + beq EndLoopC + cmp x11, #4 + bge LoopC4 + cmp x11, #1 + bge LoopC + + LoopC: + dup v5.4s, wzr + mov x9, x4 // new_k + mov x12, x14 + LoopK: + ldr s0, [x16] + add x16, x16, x13 + ldr s4, [x12] + add x12, x12, x10 + fmla v5.4s, v0.4s, v4.s[0] + subs x9, x9, #1 + bne LoopK + Write1: + str s5, [x20], #4 + + sub x16, x16, x21 // ptr back x13 * k + add x16, x16, #4 // ptr add 4B + subs x11, x11, #1 + beq EndLoopC + b LoopC + + EndLoopC: + add x14, x14, #4 + subs x15, x15, #1 + beq EndLoopN + b LoopN + EndLoopN: + subs x3, x3, #1 + beq EndLoopM + add x0, x0, x21 + b LoopM + EndLoopM: + sub sp, sp, #48 + st1 {v8.4s}, [sp], #16 + ldp x19, x20, [sp], #16 + ldp x21, x22, [sp], #16 + ret +#endif diff --git a/mindspore/lite/nnacl/minimal_filtering_generator.c b/mindspore/lite/nnacl/minimal_filtering_generator.c index ade01d7b16..91b27df177 100644 --- a/mindspore/lite/nnacl/minimal_filtering_generator.c +++ b/mindspore/lite/nnacl/minimal_filtering_generator.c @@ -121,6 +121,25 @@ int B(float *poly_array, float *matrix_b, int in_unit) { return NNACL_OK; } +#ifndef ENABLE_ARM64 +void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, + int in_channel, int c4_channel) { + int cnt = 0; + for (int i = 0; i < m; ++i) { + for (int j = 0; j < n; ++j) { + for (int y = 0; y < in_channel; ++y) { + float tmp = 0; + for (int z = 0; z < k; ++z) { + tmp += matix_a[z * in_channel + y + i * in_channel * k] * matrix_b[j + z * n]; + } + matrix_c[cnt++] = tmp; + } + cnt += c4_channel / 4 - in_channel; + } + } +} +#endif + void GenerateIntervalArray(float *array, float interval, int degree) { array[0] = 0; for (int i = 1; i < degree; ++i) { diff --git a/mindspore/lite/nnacl/minimal_filtering_generator.h b/mindspore/lite/nnacl/minimal_filtering_generator.h index 95098d5b4d..2936bbedc0 100644 --- a/mindspore/lite/nnacl/minimal_filtering_generator.h +++ b/mindspore/lite/nnacl/minimal_filtering_generator.h @@ -44,6 +44,8 @@ void MatrixMultiply(const float *matrix_a, const float *matrix_b, float *matrix_ int CookToomFilter(float *matrix_a, float *matrix_at, float *matrix_b, float *matrix_bt, float *matrix_g, float *matrix_gt, float coefficient, int out_unit, int filter_size); +void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *matrix_c, int m, int k, int n, + int in_channel, int c4_channel); #ifdef ENABLE_ARM void MatrixMultiplyVec(const float32x4_t *matrix_a, const float32x4_t *matrix_b, float32x4_t *matrix_c, diff --git a/mindspore/lite/nnacl/pack.c b/mindspore/lite/nnacl/pack.c index 70db272187..fad8eccb44 100644 --- a/mindspore/lite/nnacl/pack.c +++ b/mindspore/lite/nnacl/pack.c @@ -22,6 +22,14 @@ void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel) return PackNCHWToNHWCFp32(src, dst, 1, plane, channel); } +void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel) { + for (int i = 0; i < height; ++i) { + for (int j = 0; j < width; ++j) { + memcpy(dst + (j * height + i) * channel, src + (i * width + j) * channel, channel * sizeof(float)); + } + } +} + void PackWeightInt8(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum) { // original weight format : ohwi int kernel_h = conv_param->kernel_h_; diff --git a/mindspore/lite/nnacl/pack.h b/mindspore/lite/nnacl/pack.h index a72cac90b6..d4178d74f0 100644 --- a/mindspore/lite/nnacl/pack.h +++ b/mindspore/lite/nnacl/pack.h @@ -30,6 +30,8 @@ extern "C" { void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, float *packed_input, int real_cal_num, int block_index); +void PackHWCToWHC(const float *src, float *dst, int height, int width, int channel); + void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index, int32_t *input_sum, ConvParameter *conv_param); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc index e0f483d77e..b1fdd2e48d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc @@ -16,6 +16,7 @@ #include "src/runtime/kernel/arm/fp32/convolution_winograd.h" #include "nnacl/fp32/conv.h" +#include "nnacl/pack.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" #include "include/errorcode.h" @@ -31,78 +32,93 @@ using mindspore::schema::PrimitiveType_Conv2D; namespace mindspore::kernel { int ConvolutionWinogradCPUKernel::WinogradFilterTransform(const float *weight_data, float *matrix_g, float *matrix_gt, int oc_block) { + if (oc_block == 0) { + MS_LOG(ERROR) << "Divide by zero"; + return RET_ERROR; + } // original weight format : ohwi auto channel_in = conv_param_->input_channel_; auto channel_out = conv_param_->output_channel_; - int input_unit_square = input_unit_ * input_unit_; int ic4 = UP_DIV(channel_in, C4NUM); int oc_block_num = UP_DIV(channel_out, oc_block); + int c4_channel = ic4 * C4NUM; + int block_stride = c4_channel * oc_block; + int block_num_stride = block_stride * oc_block_num; // trans_filter = G*g*GT (g represents weight_data) - // separate into two steps ===> tmp = G*g ===> out = tmp * GT - auto tmp_weight_data = reinterpret_cast(malloc(kernel_unit_ * kernel_unit_ * sizeof(float))); - if (tmp_weight_data == nullptr) { - MS_LOG(ERROR) << "malloc tmp_weight_data failed."; - return RET_MEMORY_FAILED; - } - auto tmp_data = reinterpret_cast(malloc(input_unit_ * kernel_unit_ * sizeof(float))); + // separate into two steps ===> tmp = (g * GT)T ===> trans = (tmp * GT)T use same function:MatrixMultiplyWinograd + auto tmp_data = reinterpret_cast(malloc(c4_channel * input_unit_ * kernel_unit_ * sizeof(float))); if (tmp_data == nullptr) { - free(tmp_weight_data); MS_LOG(ERROR) << "malloc tmp_data failed."; return RET_MEMORY_FAILED; } - auto trans_out_data = reinterpret_cast(malloc(input_unit_ * input_unit_ * sizeof(float))); + memset(tmp_data, 0, c4_channel * input_unit_ * kernel_unit_ * sizeof(float)); + auto trans_out_data = reinterpret_cast(malloc(c4_channel * input_unit_ * input_unit_ * sizeof(float))); if (trans_out_data == nullptr) { free(tmp_data); - free(tmp_weight_data); MS_LOG(ERROR) << "malloc trans_out_data failed."; return RET_MEMORY_FAILED; } - std::vector shape{input_unit_ * input_unit_, oc_block_num, ic4, C4NUM, oc_block}; - std::vector strides; - for (int i = 0; i < 4; i++) { - int stride = 1; - for (int j = i + 1; j < 5; j++) { - stride *= shape[j]; - } - strides.push_back(stride); - } - int kernel_plane_stride = channel_in; - if (oc_block == 0) { - MS_LOG(ERROR) << "Divide by zero"; - free(tmp_weight_data); +#ifndef ENABLE_ARM64 + auto tmp_data1 = reinterpret_cast(malloc(c4_channel * input_unit_ * kernel_unit_ * sizeof(float))); + if (tmp_data1 == nullptr) { free(tmp_data); free(trans_out_data); - return RET_ERROR; + MS_LOG(ERROR) << "malloc tmp_data1 failed."; + return RET_MEMORY_FAILED; } + auto trans_out_data1 = reinterpret_cast(malloc(c4_channel * input_unit_ * input_unit_ * sizeof(float))); + if (trans_out_data1 == nullptr) { + free(tmp_data); + free(tmp_data1); + free(trans_out_data); + MS_LOG(ERROR) << "malloc trans_out_data1 failed."; + return RET_MEMORY_FAILED; + } +#endif + + int input_oz_offset = kernel_unit_ * kernel_unit_ * channel_in; for (int i = 0; i < channel_out; i++) { int out_c_block = i / oc_block; int out_c_res = i % oc_block; - int input_oz_offset = i * kernel_unit_ * kernel_unit_ * channel_in; - int output_oz_offset = out_c_block * strides[1] + out_c_res; - for (int j = 0; j < channel_in; j++) { - int ic4_block = j / C4NUM; - int ic4_res = j % C4NUM; - int input_iz_offset = input_oz_offset + j; - int output_iz_offset = output_oz_offset + ic4_block * strides[2] + ic4_res * strides[3]; - for (int k = 0; k < kernel_unit_ * kernel_unit_; k++) { - int input_xy_offset = input_iz_offset + k * kernel_plane_stride; - tmp_weight_data[k] = *(weight_data + input_xy_offset); - } - // now we only support row-major matrix-multiply - // tmp = G * g - MatrixMultiply(matrix_g, tmp_weight_data, tmp_data, input_unit_, kernel_unit_, kernel_unit_); - // out = tmp * GT - MatrixMultiply(tmp_data, matrix_gt, trans_out_data, input_unit_, kernel_unit_, input_unit_); + int output_oz_offset = out_c_block * block_stride + out_c_res; - for (int z = 0; z < input_unit_square; z++) { - int output_xy_offset = output_iz_offset + z * strides[0]; - *(trans_weight_ + output_xy_offset) = trans_out_data[z]; +#ifndef ENABLE_ARM64 + // tmp_data = g * GT + MatrixMultiplyWinograd(weight_data + i * input_oz_offset, matrix_gt, tmp_data, kernel_unit_, kernel_unit_, + input_unit_, channel_in, c4_channel * 4); + // tmp_data1 = (tmp_data)T + PackHWCToWHC(tmp_data, tmp_data1, kernel_unit_, input_unit_, c4_channel); + // trans_out_data1 = tmp * GT + MatrixMultiplyWinograd(tmp_data1, matrix_gt, trans_out_data1, input_unit_, kernel_unit_, input_unit_, c4_channel, + c4_channel * 4); + // trans_out_data = (trans_out_data1)T + PackHWCToWHC(trans_out_data1, trans_out_data, input_unit_, input_unit_, c4_channel); +#else + // tmp = (g * GT)T + MatrixMultiplyWinograd(weight_data + i * input_oz_offset, matrix_gt, tmp_data, kernel_unit_, kernel_unit_, + input_unit_, channel_in, c4_channel * 4); + // trans = (tmp * GT)T + MatrixMultiplyWinograd(tmp_data, matrix_gt, trans_out_data, input_unit_, kernel_unit_, input_unit_, c4_channel, + c4_channel * 4); +#endif + + int in_offset = 0; + for (int j = 0; j < input_unit_; ++j) { + for (int k = 0; k < input_unit_; ++k) { + for (int c = 0; c < c4_channel; ++c) { + *(trans_weight_ + output_oz_offset + c * oc_block) = trans_out_data[in_offset + c]; + } + in_offset += c4_channel; + output_oz_offset += block_num_stride; } } } - free(tmp_weight_data); +#ifndef ENABLE_ARM64 + free(tmp_data1); + free(trans_out_data1); +#endif free(tmp_data); free(trans_out_data); return RET_OK;