From 1f9a122f17aeaf4fd104843ba0bf5c9538295f77 Mon Sep 17 00:00:00 2001 From: fuzhiye Date: Tue, 29 Sep 2020 16:26:46 +0800 Subject: [PATCH] replace gemm with matmul for fp16 conv winograd --- mindspore/lite/nnacl/fp16/conv_fp16.c | 18 +- mindspore/lite/nnacl/fp16/matmul_fp16.c | 44 +- mindspore/lite/nnacl/fp16/matmul_fp16.h | 3 + .../lite/nnacl/fp16/winograd_transform_fp16.c | 32 +- .../lite/nnacl/fp16/winograd_utils_fp16.c | 1487 +++++++++++++---- .../lite/nnacl/fp16/winograd_utils_fp16.h | 193 ++- .../kernel/arm/fp16/convolution_fp16.cc | 30 +- .../arm/fp16/convolution_winograd_fp16.cc | 38 +- .../arm/fp16/convolution_winograd_fp16.h | 7 +- 9 files changed, 1435 insertions(+), 417 deletions(-) diff --git a/mindspore/lite/nnacl/fp16/conv_fp16.c b/mindspore/lite/nnacl/fp16/conv_fp16.c index e7a5291c48..91eab4bb15 100644 --- a/mindspore/lite/nnacl/fp16/conv_fp16.c +++ b/mindspore/lite/nnacl/fp16/conv_fp16.c @@ -170,7 +170,6 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa int input_unit = conv_param->input_unit_; int in_batch = conv_param->input_batch_; int in_channel = conv_param->input_channel_; - int ic8 = UP_DIV(in_channel, C8NUM); int out_unit = conv_param->output_unit_; int out_w_block = UP_DIV(conv_param->output_w_, out_unit); int out_h_block = UP_DIV(conv_param->output_h_, out_unit); @@ -179,18 +178,19 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa int out_channel = conv_param->output_channel_; int oc8 = UP_DIV(out_channel, C8NUM); int input_unit_square = input_unit * input_unit; - size_t output_offset = oc8 * C8NUM * input_unit_square * sizeof(float16_t); float16_t *trans_input = buffer_list[0]; float16_t *gemm_out = buffer_list[1]; float16_t *tmp_data = buffer_list[2]; - int trans_input_offset = tile_num * input_unit_square * ic8 * C8NUM; + float16_t *col_buffer = buffer_list[3]; + int trans_input_offset = tile_num * input_unit_square * in_channel; int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM; int tmp_data_offset = input_unit_square * C8NUM; + int col_buffer_offset = tile_num * in_channel; // step 1 : filter transform (pre-processed offline) // step 2 : input transform (online) for (int b = 0; b < in_batch; b++) { - int in_batch_offset = b * ic8 * C8NUM * conv_param->input_h_ * conv_param->input_w_; + int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_; int out_batch_offset = b * out_channel * conv_param->output_h_ * conv_param->output_w_; for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) { int out_tile_index = thread_id * tile_num; @@ -200,8 +200,14 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param, in_func); // step 3 : gemm - IndirectGemmFp16_16x8(gemm_out + task_id * gemm_out_offset, trans_input + task_id * trans_input_offset, - trans_weight, NULL, input_unit_square, ic8 * 2, oc8 * C8NUM, output_offset, 1, 1, 0, 0); + float16_t *src_ptr = trans_input + task_id * trans_input_offset; + float16_t *dst_ptr = gemm_out + task_id * gemm_out_offset; + float16_t *tmp_col_ptr = col_buffer + task_id * col_buffer_offset; + for (int i = 0; i < input_unit_square; ++i) { + RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, tile_num, in_channel); + MatMul16x8(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel, + cal_num, oc8 * C8NUM, input_unit_square, false); + } // step 4 : output transform WinogradOutputTransformFp16(gemm_out + task_id * gemm_out_offset, output_data + out_batch_offset, bias_data, diff --git a/mindspore/lite/nnacl/fp16/matmul_fp16.c b/mindspore/lite/nnacl/fp16/matmul_fp16.c index 095b829518..beb62bb043 100644 --- a/mindspore/lite/nnacl/fp16/matmul_fp16.c +++ b/mindspore/lite/nnacl/fp16/matmul_fp16.c @@ -41,8 +41,8 @@ void ColMajor2Row8MajorFp16(void *src_ptr, float16_t *dst_ptr, size_t row, size_ void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, int deep, int row, int col, int stride, bool write_nhwc) { - int row_16 = UP_ROUND(row, C16NUM); - int col_8 = UP_ROUND(col, C8NUM); + // int row_16 = UP_ROUND(row, C16NUM); + // int col_8 = UP_ROUND(col, C8NUM); if (write_nhwc) { /* col16-major * row8-major => col-major */ for (int r = 0; r < row; r++) { @@ -63,24 +63,42 @@ void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const fl } } } else { - /* col16-major * row8-major => row16x8-major */ - for (int r = 0; r < row_16; r++) { - for (int c = 0; c < col_8; c++) { - int r16div = r / C16NUM, r16mod = r % C16NUM; - int c8div = c / C8NUM, c8mod = c % C8NUM; - size_t ci = c8div * row_16 * C8NUM + r * C8NUM + c8mod; - float16_t value = 0; - for (int d = 0; d < deep; d++) { - size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod; - size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod; + for (int i = 0; i < row; ++i) { + int src_r_offset = i; + int dst_r_offset = i * col * stride; + for (int j = 0; j < col; ++j) { + int c8div = j / 8, c8mod = j % 8; + size_t ci = dst_r_offset + c8div * 8 * stride + c8mod; + float value = 0; + for (int d = 0; d < deep; ++d) { + size_t ai = src_r_offset + d * C16NUM; + size_t bi = c8div * deep * 8 + d * 8 + c8mod; value = value + a[ai] * b[bi]; } - if (bias != NULL) value += bias[col]; + if (bias != NULL) value += bias[j]; if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); if (act_type != ActType_No) value = MSMAX(0.0f, value); dst[ci] = value; } } + // /* col16-major * row8-major => row16x8-major */ + // for (int r = 0; r < row_16; r++) { + // for (int c = 0; c < col_8; c++) { + // int r16div = r / C16NUM, r16mod = r % C16NUM; + // int c8div = c / C8NUM, c8mod = c % C8NUM; + // size_t ci = c8div * row_16 * C8NUM + r * C8NUM + c8mod; + // float16_t value = 0; + // for (int d = 0; d < deep; d++) { + // size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod; + // size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod; + // value = value + a[ai] * b[bi]; + // } + // if (bias != NULL) value += bias[col]; + // if (act_type == ActType_Relu6) value = MSMIN(6.0f, value); + // if (act_type != ActType_No) value = MSMAX(0.0f, value); + // dst[ci] = value; + // } + // } } return; } diff --git a/mindspore/lite/nnacl/fp16/matmul_fp16.h b/mindspore/lite/nnacl/fp16/matmul_fp16.h index 11c101d7d0..d7503fff61 100644 --- a/mindspore/lite/nnacl/fp16/matmul_fp16.h +++ b/mindspore/lite/nnacl/fp16/matmul_fp16.h @@ -29,6 +29,9 @@ #ifdef __cplusplus extern "C" { #endif +void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type, + int deep, int row, int col, int stride, bool write_nhwc); + void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type, int depth, int row, int col, int stride, bool write_nhwc); diff --git a/mindspore/lite/nnacl/fp16/winograd_transform_fp16.c b/mindspore/lite/nnacl/fp16/winograd_transform_fp16.c index a284fb58f8..e10a76b45b 100644 --- a/mindspore/lite/nnacl/fp16/winograd_transform_fp16.c +++ b/mindspore/lite/nnacl/fp16/winograd_transform_fp16.c @@ -594,7 +594,7 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s); int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s); - int dst_plane_offset = c * C4NUM; + int dst_plane_offset = c * in_channel; for (int ic = 0; ic < ic8; ic++) { // clear tmp buffer memset(tmp_data, 0, input_unit * input_unit * C8NUM * sizeof(float16_t)); @@ -622,6 +622,30 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in #endif } } + } else if (real_c < 8 && real_c >= 4) { + for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = src_ic8_offset + (interval * input_w + interval_x_s) * in_channel; + int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * in_channel; + int dst_x_offset = dst_y_offset + j * C8NUM; + const float16_t *src_addr = input_data + src_x_offset; + float16_t *dst_addr = tmp_data + dst_x_offset; + int rc = real_c - 4; +#ifdef ENABLE_NEON + vst1_f16(dst_addr, vld1_f16(src_addr)); +#else + for (int k = 0; k < C4NUM; k++) { + dst_addr[k] = src_addr[k]; + } +#endif + src_addr += 4; + dst_addr += 4; + for (int i = 0; i < rc; ++i) { + dst_addr[i] = src_addr[i]; + } + } + } } else { for (int interval = interval_y_s; interval < interval_y_e; interval++) { int src_y_offset = src_ic8_offset + (interval * input_w + interval_x_s) * in_channel; @@ -639,10 +663,10 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in } // input transform - int dst_ic8_offset = dst_plane_offset + ic * tile_num * C8NUM; - size_t dst_step = ic8 * C8NUM * tile_num; + int dst_ic8_offset = dst_plane_offset + ic * C8NUM; + size_t dst_step = in_channel * tile_num; float16_t *trans_input_ptr = trans_input + dst_ic8_offset; - func(tmp_data, trans_input_ptr, C8NUM, dst_step); + func(tmp_data, trans_input_ptr, C8NUM, dst_step, real_c); } out_tile_index++; } // cal_tile_num loop diff --git a/mindspore/lite/nnacl/fp16/winograd_utils_fp16.c b/mindspore/lite/nnacl/fp16/winograd_utils_fp16.c index f3d6ab8086..ec7c7e96e3 100644 --- a/mindspore/lite/nnacl/fp16/winograd_utils_fp16.c +++ b/mindspore/lite/nnacl/fp16/winograd_utils_fp16.c @@ -138,122 +138,351 @@ static OutputTransFp16Func OutputTransFp16FuncRelu6List8[] = {NULL, InputTransFp16Func GetInputTransFp16Func(int input_unit) { return InputTransFp16FuncList[input_unit]; } -void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step) { - float16x8_t src[16]; - float16x8_t t[16]; - float16x8_t m[16]; - Load16DataFp16; - for (int l = 0; l < 4; ++l) { - int offset = l * 4; - t[l] = vsubq_f16(src[offset], src[2 + offset]); - t[4 + l] = vaddq_f16(src[1 + offset], src[2 + offset]); - t[8 + l] = vsubq_f16(src[2 + offset], src[1 + offset]); - t[12 + l] = vsubq_f16(src[3 + offset], src[1 + offset]); - } - for (int l = 0; l < 4; ++l) { - int offset = l * 4; - m[l] = vsubq_f16(t[offset], t[2 + offset]); - m[4 + l] = vaddq_f16(t[1 + offset], t[2 + offset]); - m[8 + l] = vsubq_f16(t[2 + offset], t[1 + offset]); - m[12 + l] = vsubq_f16(t[3 + offset], t[1 + offset]); - } - for (int i = 0; i < 16; i++) { - int dst_offset = i * dst_step; - vst1_f16(dst_data + dst_offset, vget_low_f16(m[i])); - vst1_f16(dst_data + dst_offset + 64, vget_high_f16(m[i])); +void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int j = 0; + if (real_c == 8) { + float16x8_t src[16]; + float16x8_t t[16]; + float16x8_t m[16]; + Load16DataFp16; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vsubq_f16(src[offset], src[2 + offset]); + t[4 + l] = vaddq_f16(src[1 + offset], src[2 + offset]); + t[8 + l] = vsubq_f16(src[2 + offset], src[1 + offset]); + t[12 + l] = vsubq_f16(src[3 + offset], src[1 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = vsubq_f16(t[offset], t[2 + offset]); + m[4 + l] = vaddq_f16(t[1 + offset], t[2 + offset]); + m[8 + l] = vsubq_f16(t[2 + offset], t[1 + offset]); + m[12 + l] = vsubq_f16(t[3 + offset], t[1 + offset]); + } + for (int i = 0; i < 16; i++) { + int dst_offset = i * dst_step; + vst1q_f16(dst_data + dst_offset, m[i]); + } + real_c -= 8; + } else if (real_c < 8 && real_c >= 4) { + float16x4_t src[16]; + float16x4_t t[16]; + float16x4_t m[16]; + Load16DataC4Fp16; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vsub_f16(src[offset], src[2 + offset]); + t[4 + l] = vadd_f16(src[1 + offset], src[2 + offset]); + t[8 + l] = vsub_f16(src[2 + offset], src[1 + offset]); + t[12 + l] = vsub_f16(src[3 + offset], src[1 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = vsub_f16(t[offset], t[2 + offset]); + m[4 + l] = vadd_f16(t[1 + offset], t[2 + offset]); + m[8 + l] = vsub_f16(t[2 + offset], t[1 + offset]); + m[12 + l] = vsub_f16(t[3 + offset], t[1 + offset]); + } + for (int i = 0; i < 16; i++) { + int dst_offset = i * dst_step; + vst1_f16(dst_data + dst_offset, m[i]); + } + j = 4; + } + for (; j < real_c; ++j) { + float16_t src[16]; + float16_t t[16]; + float16_t m[16]; + for (int k = 0; k < 16; ++k) { + src[k] = src_data[j + k * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] - src[2 + offset]; + t[4 + l] = src[1 + offset] + src[2 + offset]; + t[8 + l] = src[2 + offset] - src[1 + offset]; + t[12 + l] = src[3 + offset] - src[1 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = t[offset] - t[2 + offset]; + m[4 + l] = t[1 + offset] + t[2 + offset]; + m[8 + l] = t[2 + offset] - t[1 + offset]; + m[12 + l] = t[3 + offset] - t[1 + offset]; + } + for (int i = 0; i < 16; i++) { + int dst_offset = i * dst_step; + dst_data[j + dst_offset] = m[i]; + } } } -void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step) { - float16x8_t src[36]; - float16x8_t t[36]; - float16x8_t m[36]; - Load36DataFp16; - for (int l = 0; l < 6; ++l) { - int offset = l * 6; - float16x8_t tmp1 = vsubq_f16(src[3 + offset], src[1 + offset]); - float16x8_t tmp2 = vsubq_f16(src[4 + offset], src[2 + offset]); - t[l] = vaddq_f16(vsubq_f16(vmulq_n_f16(src[offset], 4), vmulq_n_f16(src[2 + offset], 5)), src[4 + offset]); - t[6 + l] = vaddq_f16(vmulq_n_f16(vaddq_f16(src[1 + offset], src[2 + offset]), -4), - vaddq_f16(src[3 + offset], src[4 + offset])); - t[12 + l] = vaddq_f16(vmulq_n_f16(vsubq_f16(src[1 + offset], src[2 + offset]), 4), - vsubq_f16(src[4 + offset], src[3 + offset])); - t[18 + l] = vaddq_f16(vmulq_n_f16(tmp1, 2), tmp2); - t[24 + l] = vaddq_f16(vmulq_n_f16(tmp1, -2), tmp2); - t[30 + l] = vaddq_f16(vsubq_f16(vmulq_n_f16(src[1 + offset], 4), vmulq_n_f16(src[3 + offset], 5)), src[5 + offset]); - } - for (int l = 0; l < 6; ++l) { - int offset = l * 6; - float16x8_t tmp1 = vsubq_f16(t[3 + offset], t[1 + offset]); - float16x8_t tmp2 = vsubq_f16(t[4 + offset], t[2 + offset]); - m[l] = vaddq_f16(vsubq_f16(vmulq_n_f16(t[offset], 4), vmulq_n_f16(t[2 + offset], 5)), t[4 + offset]); - m[6 + l] = - vaddq_f16(vmulq_n_f16(vaddq_f16(t[1 + offset], t[2 + offset]), -4), vaddq_f16(t[3 + offset], t[4 + offset])); - m[12 + l] = - vaddq_f16(vmulq_n_f16(vsubq_f16(t[1 + offset], t[2 + offset]), 4), vsubq_f16(t[4 + offset], t[3 + offset])); - m[18 + l] = vaddq_f16(vmulq_n_f16(tmp1, 2), tmp2); - m[24 + l] = vaddq_f16(vmulq_n_f16(tmp1, -2), tmp2); - m[30 + l] = vaddq_f16(vsubq_f16(vmulq_n_f16(t[1 + offset], 4), vmulq_n_f16(t[3 + offset], 5)), t[5 + offset]); - } - for (int i = 0; i < 36; i++) { - int dst_offset = i * dst_step; - vst1_f16(dst_data + dst_offset, vget_low_f16(m[i])); - vst1_f16(dst_data + dst_offset + 64, vget_high_f16(m[i])); +void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int j = 0; + if (real_c == 8) { + float16x8_t src[36]; + float16x8_t t[36]; + float16x8_t m[36]; + Load36DataFp16; + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vsubq_f16(src[3 + offset], src[1 + offset]); + float16x8_t tmp2 = vsubq_f16(src[4 + offset], src[2 + offset]); + t[l] = vaddq_f16(vsubq_f16(vmulq_n_f16(src[offset], 4), vmulq_n_f16(src[2 + offset], 5)), src[4 + offset]); + t[6 + l] = vaddq_f16(vmulq_n_f16(vaddq_f16(src[1 + offset], src[2 + offset]), -4), + vaddq_f16(src[3 + offset], src[4 + offset])); + t[12 + l] = vaddq_f16(vmulq_n_f16(vsubq_f16(src[1 + offset], src[2 + offset]), 4), + vsubq_f16(src[4 + offset], src[3 + offset])); + t[18 + l] = vaddq_f16(vmulq_n_f16(tmp1, 2), tmp2); + t[24 + l] = vaddq_f16(vmulq_n_f16(tmp1, -2), tmp2); + t[30 + l] = + vaddq_f16(vsubq_f16(vmulq_n_f16(src[1 + offset], 4), vmulq_n_f16(src[3 + offset], 5)), src[5 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vsubq_f16(t[3 + offset], t[1 + offset]); + float16x8_t tmp2 = vsubq_f16(t[4 + offset], t[2 + offset]); + m[l] = vaddq_f16(vsubq_f16(vmulq_n_f16(t[offset], 4), vmulq_n_f16(t[2 + offset], 5)), t[4 + offset]); + m[6 + l] = + vaddq_f16(vmulq_n_f16(vaddq_f16(t[1 + offset], t[2 + offset]), -4), vaddq_f16(t[3 + offset], t[4 + offset])); + m[12 + l] = + vaddq_f16(vmulq_n_f16(vsubq_f16(t[1 + offset], t[2 + offset]), 4), vsubq_f16(t[4 + offset], t[3 + offset])); + m[18 + l] = vaddq_f16(vmulq_n_f16(tmp1, 2), tmp2); + m[24 + l] = vaddq_f16(vmulq_n_f16(tmp1, -2), tmp2); + m[30 + l] = vaddq_f16(vsubq_f16(vmulq_n_f16(t[1 + offset], 4), vmulq_n_f16(t[3 + offset], 5)), t[5 + offset]); + } + for (int i = 0; i < 36; i++) { + int dst_offset = i * dst_step; + vst1q_f16(dst_data + dst_offset, m[i]); + } + real_c -= 8; + } else if (real_c < 8 && real_c >= 4) { + float16x4_t src[36]; + float16x4_t t[36]; + float16x4_t m[36]; + Load36DataC4Fp16; + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x4_t tmp1 = vsub_f16(src[3 + offset], src[1 + offset]); + float16x4_t tmp2 = vsub_f16(src[4 + offset], src[2 + offset]); + t[l] = vadd_f16(vsub_f16(vmul_n_f16(src[offset], 4), vmul_n_f16(src[2 + offset], 5)), src[4 + offset]); + t[6 + l] = vadd_f16(vmul_n_f16(vadd_f16(src[1 + offset], src[2 + offset]), -4), + vadd_f16(src[3 + offset], src[4 + offset])); + t[12 + l] = + vadd_f16(vmul_n_f16(vsub_f16(src[1 + offset], src[2 + offset]), 4), vsub_f16(src[4 + offset], src[3 + offset])); + t[18 + l] = vadd_f16(vmul_n_f16(tmp1, 2), tmp2); + t[24 + l] = vadd_f16(vmul_n_f16(tmp1, -2), tmp2); + t[30 + l] = vadd_f16(vsub_f16(vmul_n_f16(src[1 + offset], 4), vmul_n_f16(src[3 + offset], 5)), src[5 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x4_t tmp1 = vsub_f16(t[3 + offset], t[1 + offset]); + float16x4_t tmp2 = vsub_f16(t[4 + offset], t[2 + offset]); + m[l] = vadd_f16(vsub_f16(vmul_n_f16(t[offset], 4), vmul_n_f16(t[2 + offset], 5)), t[4 + offset]); + m[6 + l] = + vadd_f16(vmul_n_f16(vadd_f16(t[1 + offset], t[2 + offset]), -4), vadd_f16(t[3 + offset], t[4 + offset])); + m[12 + l] = + vadd_f16(vmul_n_f16(vsub_f16(t[1 + offset], t[2 + offset]), 4), vsub_f16(t[4 + offset], t[3 + offset])); + m[18 + l] = vadd_f16(vmul_n_f16(tmp1, 2), tmp2); + m[24 + l] = vadd_f16(vmul_n_f16(tmp1, -2), tmp2); + m[30 + l] = vadd_f16(vsub_f16(vmul_n_f16(t[1 + offset], 4), vmul_n_f16(t[3 + offset], 5)), t[5 + offset]); + } + for (int i = 0; i < 36; i++) { + int dst_offset = i * dst_step; + vst1_f16(dst_data + dst_offset, m[i]); + } + j = 4; + } + for (; j < real_c; ++j) { + float16_t src[36]; + float16_t t[36]; + float16_t m[36]; + for (int k = 0; k < 36; ++k) { + src[k] = src_data[j + k * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16_t tmp1 = src[3 + offset] - src[1 + offset]; + float16_t tmp2 = src[4 + offset] - src[2 + offset]; + t[l] = src[offset] * 4 - src[2 + offset] * 5 + src[4 + offset]; + t[6 + l] = (src[1 + offset] + src[2 + offset]) * -4 + (src[3 + offset] + src[4 + offset]); + t[12 + l] = (src[1 + offset] - src[2 + offset]) * 4 + (src[4 + offset] - src[3 + offset]); + t[18 + l] = tmp1 * 2 + tmp2; + t[24 + l] = tmp1 * -2 + tmp2; + t[30 + l] = src[1 + offset] * 4 - src[3 + offset] * 5 + src[5 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16_t tmp1 = t[3 + offset] - t[1 + offset]; + float16_t tmp2 = t[4 + offset] - t[2 + offset]; + m[l] = t[offset] * 4 - t[2 + offset] * 5 + t[4 + offset]; + m[6 + l] = (t[1 + offset] + t[2 + offset]) * -4 + (t[3 + offset] + t[4 + offset]); + m[12 + l] = (t[1 + offset] - t[2 + offset]) * 4 + (t[4 + offset] - t[3 + offset]); + m[18 + l] = tmp1 * 2 + tmp2; + m[24 + l] = tmp1 * -2 + tmp2; + m[30 + l] = t[1 + offset] * 4 - t[3 + offset] * 5 + t[5 + offset]; + } + for (int i = 0; i < 36; i++) { + int dst_offset = i * dst_step; + dst_data[j + dst_offset] = m[i]; + } } } -void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step) { - float16x8_t src[64]; - float16x8_t t[64]; - float16x8_t m[64]; - Load64DataFp16; - for (int l = 0; l < 8; ++l) { - int offset = l * 8; - t[l] = vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(src[offset], 0.5625), vmulq_n_f16(src[2 + offset], 3.0625)), - vmulq_n_f16(src[4 + offset], 3.5)), - src[6 + offset]); - float16x8_t tmp1 = vaddq_f16(vmulq_n_f16(src[1 + offset], 1.125), vmulq_n_f16(src[5 + offset], 0.5)); - float16x8_t tmp2 = vsubq_f16(vmulq_n_f16(src[2 + offset], 2.25), vmulq_n_f16(src[4 + offset], 3.25)); - t[8 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(src[3 + offset], 1.625)), src[6 + offset]); - t[16 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(src[3 + offset], 1.625)), src[6 + offset]); - tmp1 = vaddq_f16(vmulq_n_f16(src[1 + offset], 0.5625), src[5 + offset]); - tmp2 = vsubq_f16(vmulq_n_f16(src[2 + offset], 0.5625), vmulq_n_f16(src[4 + offset], 2.5)); - t[24 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(src[3 + offset], 2.5)), src[6 + offset]); - t[32 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(src[3 + offset], 2.5)), src[6 + offset]); - tmp1 = vaddq_f16(vmulq_n_f16(src[1 + offset], 0.375), vmulq_n_f16(src[5 + offset], 1.5)); - tmp2 = vsubq_f16(vmulq_n_f16(src[2 + offset], 0.25), vmulq_n_f16(src[4 + offset], 1.25)); - t[40 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(src[3 + offset], 1.875)), src[6 + offset]); - t[48 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(src[3 + offset], 1.875)), src[6 + offset]); - t[56 + l] = - vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(src[1 + offset], -0.5625), vmulq_n_f16(src[3 + offset], 3.0625)), - vmulq_n_f16(src[5 + offset], 3.5)), - src[7 + offset]); - } - for (int l = 0; l < 8; ++l) { - int offset = l * 8; - m[l] = vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(t[offset], 0.5625), vmulq_n_f16(t[2 + offset], 3.0625)), - vmulq_n_f16(t[4 + offset], 3.5)), - t[6 + offset]); - float16x8_t tmp1 = vaddq_f16(vmulq_n_f16(t[1 + offset], 1.125), vmulq_n_f16(t[5 + offset], 0.5)); - float16x8_t tmp2 = vsubq_f16(vmulq_n_f16(t[2 + offset], 2.25), vmulq_n_f16(t[4 + offset], 3.25)); - m[8 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(t[3 + offset], 1.625)), t[6 + offset]); - m[16 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(t[3 + offset], 1.625)), t[6 + offset]); - tmp1 = vaddq_f16(vmulq_n_f16(t[1 + offset], 0.5625), t[5 + offset]); - tmp2 = vsubq_f16(vmulq_n_f16(t[2 + offset], 0.5625), vmulq_n_f16(t[4 + offset], 2.5)); - m[24 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(t[3 + offset], 2.5)), t[6 + offset]); - m[32 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(t[3 + offset], 2.5)), t[6 + offset]); - tmp1 = vaddq_f16(vmulq_n_f16(t[1 + offset], 0.375), vmulq_n_f16(t[5 + offset], 1.5)); - tmp2 = vsubq_f16(vmulq_n_f16(t[2 + offset], 0.25), vmulq_n_f16(t[4 + offset], 1.25)); - m[40 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(t[3 + offset], 1.875)), t[6 + offset]); - m[48 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(t[3 + offset], 1.875)), t[6 + offset]); - m[56 + l] = vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(t[1 + offset], -0.5625), vmulq_n_f16(t[3 + offset], 3.0625)), - vmulq_n_f16(t[5 + offset], 3.5)), - t[7 + offset]); - } - for (int i = 0; i < 64; i++) { - int dst_offset = i * dst_step; - vst1_f16(dst_data + dst_offset, vget_low_f16(m[i])); - vst1_f16(dst_data + dst_offset + 64, vget_high_f16(m[i])); +void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c) { + int j = 0; + if (real_c == 8) { + float16x8_t src[64]; + float16x8_t t[64]; + float16x8_t m[64]; + Load64DataFp16; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(src[offset], 0.5625), vmulq_n_f16(src[2 + offset], 3.0625)), + vmulq_n_f16(src[4 + offset], 3.5)), + src[6 + offset]); + float16x8_t tmp1 = vaddq_f16(vmulq_n_f16(src[1 + offset], 1.125), vmulq_n_f16(src[5 + offset], 0.5)); + float16x8_t tmp2 = vsubq_f16(vmulq_n_f16(src[2 + offset], 2.25), vmulq_n_f16(src[4 + offset], 3.25)); + t[8 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(src[3 + offset], 1.625)), src[6 + offset]); + t[16 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(src[3 + offset], 1.625)), src[6 + offset]); + tmp1 = vaddq_f16(vmulq_n_f16(src[1 + offset], 0.5625), src[5 + offset]); + tmp2 = vsubq_f16(vmulq_n_f16(src[2 + offset], 0.5625), vmulq_n_f16(src[4 + offset], 2.5)); + t[24 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(src[3 + offset], 2.5)), src[6 + offset]); + t[32 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(src[3 + offset], 2.5)), src[6 + offset]); + tmp1 = vaddq_f16(vmulq_n_f16(src[1 + offset], 0.375), vmulq_n_f16(src[5 + offset], 1.5)); + tmp2 = vsubq_f16(vmulq_n_f16(src[2 + offset], 0.25), vmulq_n_f16(src[4 + offset], 1.25)); + t[40 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(src[3 + offset], 1.875)), src[6 + offset]); + t[48 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(src[3 + offset], 1.875)), src[6 + offset]); + t[56 + l] = + vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(src[1 + offset], -0.5625), vmulq_n_f16(src[3 + offset], 3.0625)), + vmulq_n_f16(src[5 + offset], 3.5)), + src[7 + offset]); + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(t[offset], 0.5625), vmulq_n_f16(t[2 + offset], 3.0625)), + vmulq_n_f16(t[4 + offset], 3.5)), + t[6 + offset]); + float16x8_t tmp1 = vaddq_f16(vmulq_n_f16(t[1 + offset], 1.125), vmulq_n_f16(t[5 + offset], 0.5)); + float16x8_t tmp2 = vsubq_f16(vmulq_n_f16(t[2 + offset], 2.25), vmulq_n_f16(t[4 + offset], 3.25)); + m[8 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(t[3 + offset], 1.625)), t[6 + offset]); + m[16 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(t[3 + offset], 1.625)), t[6 + offset]); + tmp1 = vaddq_f16(vmulq_n_f16(t[1 + offset], 0.5625), t[5 + offset]); + tmp2 = vsubq_f16(vmulq_n_f16(t[2 + offset], 0.5625), vmulq_n_f16(t[4 + offset], 2.5)); + m[24 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(t[3 + offset], 2.5)), t[6 + offset]); + m[32 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(t[3 + offset], 2.5)), t[6 + offset]); + tmp1 = vaddq_f16(vmulq_n_f16(t[1 + offset], 0.375), vmulq_n_f16(t[5 + offset], 1.5)); + tmp2 = vsubq_f16(vmulq_n_f16(t[2 + offset], 0.25), vmulq_n_f16(t[4 + offset], 1.25)); + m[40 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(t[3 + offset], 1.875)), t[6 + offset]); + m[48 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(t[3 + offset], 1.875)), t[6 + offset]); + m[56 + l] = + vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(t[1 + offset], -0.5625), vmulq_n_f16(t[3 + offset], 3.0625)), + vmulq_n_f16(t[5 + offset], 3.5)), + t[7 + offset]); + } + for (int i = 0; i < 64; i++) { + int dst_offset = i * dst_step; + vst1q_f16(dst_data + dst_offset, m[i]); + } + real_c -= 8; + } else if (real_c < 8 && real_c >= 4) { + float16x4_t src[64]; + float16x4_t t[64]; + float16x4_t m[64]; + Load64DataC4Fp16; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = vsub_f16(vadd_f16(vsub_f16(vmul_n_f16(src[offset], 0.5625), vmul_n_f16(src[2 + offset], 3.0625)), + vmul_n_f16(src[4 + offset], 3.5)), + src[6 + offset]); + float16x4_t tmp1 = vadd_f16(vmul_n_f16(src[1 + offset], 1.125), vmul_n_f16(src[5 + offset], 0.5)); + float16x4_t tmp2 = vsub_f16(vmul_n_f16(src[2 + offset], 2.25), vmul_n_f16(src[4 + offset], 3.25)); + t[8 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(src[3 + offset], 1.625)), src[6 + offset]); + t[16 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(src[3 + offset], 1.625)), src[6 + offset]); + tmp1 = vadd_f16(vmul_n_f16(src[1 + offset], 0.5625), src[5 + offset]); + tmp2 = vsub_f16(vmul_n_f16(src[2 + offset], 0.5625), vmul_n_f16(src[4 + offset], 2.5)); + t[24 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(src[3 + offset], 2.5)), src[6 + offset]); + t[32 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(src[3 + offset], 2.5)), src[6 + offset]); + tmp1 = vadd_f16(vmul_n_f16(src[1 + offset], 0.375), vmul_n_f16(src[5 + offset], 1.5)); + tmp2 = vsub_f16(vmul_n_f16(src[2 + offset], 0.25), vmul_n_f16(src[4 + offset], 1.25)); + t[40 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(src[3 + offset], 1.875)), src[6 + offset]); + t[48 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(src[3 + offset], 1.875)), src[6 + offset]); + t[56 + l] = vadd_f16(vsub_f16(vadd_f16(vmul_n_f16(src[1 + offset], -0.5625), vmul_n_f16(src[3 + offset], 3.0625)), + vmul_n_f16(src[5 + offset], 3.5)), + src[7 + offset]); + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = vsub_f16(vadd_f16(vsub_f16(vmul_n_f16(t[offset], 0.5625), vmul_n_f16(t[2 + offset], 3.0625)), + vmul_n_f16(t[4 + offset], 3.5)), + t[6 + offset]); + float16x4_t tmp1 = vadd_f16(vmul_n_f16(t[1 + offset], 1.125), vmul_n_f16(t[5 + offset], 0.5)); + float16x4_t tmp2 = vsub_f16(vmul_n_f16(t[2 + offset], 2.25), vmul_n_f16(t[4 + offset], 3.25)); + m[8 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(t[3 + offset], 1.625)), t[6 + offset]); + m[16 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(t[3 + offset], 1.625)), t[6 + offset]); + tmp1 = vadd_f16(vmul_n_f16(t[1 + offset], 0.5625), t[5 + offset]); + tmp2 = vsub_f16(vmul_n_f16(t[2 + offset], 0.5625), vmul_n_f16(t[4 + offset], 2.5)); + m[24 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(t[3 + offset], 2.5)), t[6 + offset]); + m[32 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(t[3 + offset], 2.5)), t[6 + offset]); + tmp1 = vadd_f16(vmul_n_f16(t[1 + offset], 0.375), vmul_n_f16(t[5 + offset], 1.5)); + tmp2 = vsub_f16(vmul_n_f16(t[2 + offset], 0.25), vmul_n_f16(t[4 + offset], 1.25)); + m[40 + l] = vadd_f16(vsub_f16(vadd_f16(tmp1, tmp2), vmul_n_f16(t[3 + offset], 1.875)), t[6 + offset]); + m[48 + l] = vadd_f16(vadd_f16(vsub_f16(tmp2, tmp1), vmul_n_f16(t[3 + offset], 1.875)), t[6 + offset]); + m[56 + l] = vadd_f16(vsub_f16(vadd_f16(vmul_n_f16(t[1 + offset], -0.5625), vmul_n_f16(t[3 + offset], 3.0625)), + vmul_n_f16(t[5 + offset], 3.5)), + t[7 + offset]); + } + for (int i = 0; i < 64; i++) { + int dst_offset = i * dst_step; + vst1_f16(dst_data + dst_offset, m[i]); + } + j = 4; + } + for (; j < real_c; ++j) { + float16_t src[64]; + float16_t t[64]; + float16_t m[64]; + for (int k = 0; k < 64; ++k) { + src[k] = src_data[j + k * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = src[offset] * 0.5625f - src[2 + offset] * 3.0625f + src[4 + offset] * 3.5f - src[6 + offset]; + float16_t tmp1 = src[1 + offset] * 1.125f + src[5 + offset] * 0.5f; + float16_t tmp2 = src[2 + offset] * 2.25f - src[4 + offset] * 3.25f; + t[8 + l] = tmp1 + tmp2 - src[3 + offset] * 1.625f + src[6 + offset]; + t[16 + l] = tmp2 - tmp1 + src[3 + offset] * 1.625f + src[6 + offset]; + tmp1 = src[1 + offset] * 0.5625f + src[5 + offset]; + tmp2 = src[2 + offset] * 0.5625f - src[4 + offset] * 2.5f; + t[24 + l] = tmp1 + tmp2 - src[3 + offset] * 2.5f + src[6 + offset]; + t[32 + l] = tmp2 - tmp1 + src[3 + offset] * 2.5f + src[6 + offset]; + tmp1 = src[1 + offset] * 0.375f + src[5 + offset] * 1.5f; + tmp2 = src[2 + offset] * 0.25f - src[4 + offset] * 1.25f; + t[40 + l] = tmp1 + tmp2 - src[3 + offset] * 1.875f + src[6 + offset]; + t[48 + l] = tmp2 - tmp1 + src[3 + offset] * 1.875f + src[6 + offset]; + t[56 + l] = src[1 + offset] * -0.5625 + src[3 + offset] * 3.0625f - src[5 + offset] * 3.5f + src[7 + offset]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = t[offset] * 0.5625f - t[2 + offset] * 3.0625f + t[4 + offset] * 3.5f - t[6 + offset]; + float16_t tmp1 = t[1 + offset] * 1.125f + t[5 + offset] * 0.5f; + float16_t tmp2 = t[2 + offset] * 2.25f - t[4 + offset] * 3.25f; + m[8 + l] = tmp1 + tmp2 - t[3 + offset] * 1.625f + t[6 + offset]; + m[16 + l] = tmp2 - tmp1 + t[3 + offset] * 1.625f + t[6 + offset]; + tmp1 = t[1 + offset] * 0.5625f + t[5 + offset]; + tmp2 = t[2 + offset] * 0.5625f - t[4 + offset] * 2.5f; + m[24 + l] = tmp1 + tmp2 - t[3 + offset] * 2.5f + t[6 + offset]; + m[32 + l] = tmp2 - tmp1 + t[3 + offset] * 2.5f + t[6 + offset]; + tmp1 = t[1 + offset] * 0.375f + t[5 + offset] * 1.5f; + tmp2 = t[2 + offset] * 0.25f - t[4 + offset] * 1.25f; + m[40 + l] = tmp1 + tmp2 - t[3 + offset] * 1.875f + t[6 + offset]; + m[48 + l] = tmp2 - tmp1 + t[3 + offset] * 1.875f + t[6 + offset]; + m[56 + l] = t[1 + offset] * -0.5625 + t[3 + offset] * 3.0625f - t[5 + offset] * 3.5f + t[7 + offset]; + } + for (int i = 0; i < 64; i++) { + int dst_offset = i * dst_step; + dst_data[j + dst_offset] = m[i]; + } } } @@ -289,106 +518,295 @@ OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActT void OutputTransform4x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { - float16x8_t src[16]; - float16x8_t t[8]; - float16x8_t m[4]; - Load16DataFp16; - float16x8_t bias_ptr = vld1q_f16(bias_data); - for (int l = 0; l < 4; ++l) { - int offset = l * 4; - t[l] = vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]); - t[l + 4] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); - } - for (int l = 0; l < 2; ++l) { - int offset = l * 4; - m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); - m[l + 2] = vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); - } - if (r_c == C8NUM && r_h == 2 && r_w == 2) { - Store4DataFp16; - } else { - for (int i = 0; i < r_c; i++) { + int z = 0; + if (r_c == 8) { + float16x8_t src[16]; + float16x8_t t[8]; + float16x8_t m[4]; + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + } + if (r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { for (int j = 0; j < r_h; j++) { int dst_k_offset = j * dst_step * out_c; int m_k_offset = j * 2; for (int k = 0; k < r_w; k++) { - dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + r_c -= 8; + } else if (r_c < 8 && r_c >= 4) { + float16x4_t src[16]; + float16x4_t t[8]; + float16x4_t m[4]; + Load16DataC4Fp16; + float16x4_t bias_ptr = vld1_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vadd_f16(vadd_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vadd_f16(vsub_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vadd_f16(vadd_f16(vadd_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vadd_f16(vadd_f16(vsub_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + } + if (r_h == 2 && r_w == 2) { + Store4DataC4Fp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } } } } + z = 4; + } + for (; z < r_c; ++z) { + float16_t src[16]; + float16_t t[8]; + float16_t m[4]; + for (int k = 0; k < 16; ++k) { + src[k] = src_data[z + k * src_step]; + } + float16_t bias_ptr = bias_data[z]; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + bias_ptr; + m[l + 2] = t[1 + offset] - t[2 + offset] + t[3 + offset] + bias_ptr; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } } } void OutputTransform4x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { - float16x8_t src[16]; - float16x8_t t[8]; - float16x8_t m[4]; - float16x8_t zero = vdupq_n_f16(0); - Load16DataFp16; - float16x8_t bias_ptr = vld1q_f16(bias_data); - for (int l = 0; l < 4; ++l) { - int offset = l * 4; - t[l] = vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]); - t[l + 4] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); - } - for (int l = 0; l < 2; ++l) { - int offset = l * 4; - m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); - m[l + 2] = vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); - m[l] = vmaxq_f16(zero, m[l]); - m[l + 2] = vmaxq_f16(zero, m[l + 2]); - } - if (r_c == C8NUM && r_h == 2 && r_w == 2) { - Store4DataFp16; - } else { - for (int i = 0; i < r_c; i++) { + int z = 0; + if (r_c == 8) { + float16x8_t src[16]; + float16x8_t t[8]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + } + if (r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { for (int j = 0; j < r_h; j++) { int dst_k_offset = j * dst_step * out_c; int m_k_offset = j * 2; for (int k = 0; k < r_w; k++) { - dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); } } } + r_c -= 8; + } else if (r_c < 8 && r_c >= 4) { + float16x4_t src[16]; + float16x4_t t[8]; + float16x4_t m[4]; + float16x4_t zero = vdup_n_f16(0); + Load16DataC4Fp16; + float16x4_t bias_ptr = vld1_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vadd_f16(vadd_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vadd_f16(vsub_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vadd_f16(vadd_f16(vadd_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vadd_f16(vadd_f16(vsub_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = vmax_f16(zero, m[l]); + m[l + 2] = vmax_f16(zero, m[l + 2]); + } + if (r_h == 2 && r_w == 2) { + Store4DataC4Fp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } + z = 4; + } + for (; z < r_c; ++z) { + float16_t src[16]; + float16_t t[8]; + float16_t m[4]; + for (int k = 0; k < 16; ++k) { + src[k] = src_data[z + k * src_step]; + } + float16_t bias_ptr = bias_data[z]; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + bias_ptr; + m[l + 2] = t[1 + offset] - t[2 + offset] + t[3 + offset] + bias_ptr; + m[l] = m[l] > 0 ? m[l] : 0; + m[l + 2] = m[l + 2] > 0 ? m[l + 2] : 0; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } } } void OutputTransform4x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { - float16x8_t src[16]; - float16x8_t t[8]; - float16x8_t m[4]; - float16x8_t zero = vdupq_n_f16(0); - float16x8_t six = vdupq_n_f16(6); - Load16DataFp16; - float16x8_t bias_ptr = vld1q_f16(bias_data); - for (int l = 0; l < 4; ++l) { - int offset = l * 4; - t[l] = vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]); - t[l + 4] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); - } - for (int l = 0; l < 2; ++l) { - int offset = l * 4; - m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); - m[l + 2] = vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); - m[l] = vmaxq_f16(zero, m[l]); - m[l] = vminq_f16(six, m[l]); - m[l + 2] = vmaxq_f16(zero, m[l + 2]); - m[l + 2] = vminq_f16(six, m[l + 2]); - } - if (r_c == C8NUM && r_h == 2 && r_w == 2) { - Store4DataFp16; - } else { - for (int i = 0; i < r_c; i++) { + int z = 0; + if (r_c == 8) { + float16x8_t src[16]; + float16x8_t t[8]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + m[l + 2] = vminq_f16(six, m[l + 2]); + } + if (r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { for (int j = 0; j < r_h; j++) { int dst_k_offset = j * dst_step * out_c; int m_k_offset = j * 2; for (int k = 0; k < r_w; k++) { - dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + r_c -= 8; + } else if (r_c < 8 && r_c >= 4) { + float16x4_t src[16]; + float16x4_t t[8]; + float16x4_t m[4]; + float16x4_t zero = vdup_n_f16(0); + float16x4_t six = vdup_n_f16(6); + Load16DataC4Fp16; + float16x4_t bias_ptr = vld1_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vadd_f16(vadd_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vadd_f16(vsub_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vadd_f16(vadd_f16(vadd_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vadd_f16(vadd_f16(vsub_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = vmax_f16(zero, m[l]); + m[l] = vmin_f16(six, m[l]); + m[l + 2] = vmax_f16(zero, m[l + 2]); + m[l + 2] = vmin_f16(six, m[l + 2]); + } + if (r_h == 2 && r_w == 2) { + Store4DataC4Fp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } } } } + z = 4; + } + for (; z < r_c; ++z) { + float16_t src[16]; + float16_t t[8]; + float16_t m[4]; + for (int k = 0; k < 16; ++k) { + src[k] = src_data[z + k * src_step]; + } + float16_t bias_ptr = bias_data[z]; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] + src[1 + offset] + src[2 + offset]; + t[l + 4] = src[1 + offset] - src[2 + offset] + src[3 + offset]; + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = t[offset] + t[1 + offset] + t[2 + offset] + bias_ptr; + m[l + 2] = t[1 + offset] - t[2 + offset] + t[3 + offset] + bias_ptr; + m[l] = m[l] > 0 ? m[l] : 0; + m[l] = m[l] < 6 ? m[l] : 6; + m[l + 2] = m[l + 2] > 0 ? m[l + 2] : 0; + m[l + 2] = m[l + 2] < 6 ? m[l + 2] : 6; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } } } @@ -1726,214 +2144,577 @@ void OutputTransform8x5Relu6UnitFp16(const float16_t *src_data, float16_t *dst_d void OutputTransform8x6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { - float16x8_t src[64]; - float16x8_t t[48]; - float16x8_t m[36]; - Load64DataFp16; - float16x8_t bias_ptr = vld1q_f16(bias_data); - for (int l = 0; l < 8; ++l) { - int offset = l * 8; - float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); - float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); - float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); - float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); - float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); - float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); - t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); - t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); - t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); - t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); - t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); - t[l + 40] = - vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), src[7 + offset]); - } - for (int l = 0; l < 6; ++l) { - int offset = l * 8; - float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); - float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); - float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); - float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); - float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); - float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); - m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); - m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); - m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); - m[l + 18] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); - m[l + 24] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); - m[l + 30] = vaddq_f16( - vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), t[7 + offset]), - bias_ptr); - } - if (r_c == C8NUM && r_h == 6 && r_w == 6) { - for (int i = 0; i < 6; i++) { - int dst_k_offset = i * dst_step * out_c; - int m_k_offset = i * 6; - vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); - vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); - vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); - vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); - vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); - vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + int z = 0; + if (r_c == 8) { + float16x8_t src[64]; + float16x8_t t[48]; + float16x8_t m[36]; + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), src[7 + offset]); } - } else { - for (int i = 0; i < r_c; i++) { + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + } + if (r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { for (int j = 0; j < r_h; j++) { int dst_k_offset = j * dst_step * out_c; int m_k_offset = j * 6; for (int k = 0; k < r_w; k++) { - dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); } } } + r_c -= 8; + } else if (r_c < 8 && r_c >= 4) { + float16x4_t src[64]; + float16x4_t t[48]; + float16x4_t m[36]; + Load64DataC4Fp16; + float16x4_t bias_ptr = vld1_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x4_t tmp1 = vadd_f16(src[1 + offset], src[2 + offset]); + float16x4_t tmp2 = vadd_f16(src[3 + offset], src[4 + offset]); + float16x4_t tmp3 = vadd_f16(src[5 + offset], src[6 + offset]); + float16x4_t tmp4 = vsub_f16(src[1 + offset], src[2 + offset]); + float16x4_t tmp5 = vsub_f16(src[3 + offset], src[4 + offset]); + float16x4_t tmp6 = vsub_f16(src[5 + offset], src[6 + offset]); + t[l] = vadd_f16(vadd_f16(vadd_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5)); + t[l + 16] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25)); + t[l + 24] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375)); + t[l + 32] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625)); + t[l + 40] = + vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), vmul_n_f16(tmp6, 7.59375)), src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x4_t tmp1 = vadd_f16(t[1 + offset], t[2 + offset]); + float16x4_t tmp2 = vadd_f16(t[3 + offset], t[4 + offset]); + float16x4_t tmp3 = vadd_f16(t[5 + offset], t[6 + offset]); + float16x4_t tmp4 = vsub_f16(t[1 + offset], t[2 + offset]); + float16x4_t tmp5 = vsub_f16(t[3 + offset], t[4 + offset]); + float16x4_t tmp6 = vsub_f16(t[5 + offset], t[6 + offset]); + m[l] = vadd_f16(vadd_f16(vadd_f16(vadd_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vadd_f16( + vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), vmul_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + } + if (r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + vst1_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + z = 4; + } + for (; z < r_c; ++z) { + float16_t src[64]; + float16_t t[48]; + float16_t m[36]; + for (int k = 0; k < 16; ++k) { + src[k] = src_data[z + k * src_step]; + } + float16_t bias_ptr = bias_data[z]; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16_t tmp1 = src[1 + offset] + src[2 + offset]; + float16_t tmp2 = src[3 + offset] + src[4 + offset]; + float16_t tmp3 = src[5 + offset] + src[6 + offset]; + float16_t tmp4 = src[1 + offset] - src[2 + offset]; + float16_t tmp5 = src[3 + offset] - src[4 + offset]; + float16_t tmp6 = src[5 + offset] - src[6 + offset]; + t[l] = src[offset] + tmp1 + tmp2 + tmp3; + t[l + 8] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f; + t[l + 16] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f; + t[l + 24] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f; + t[l + 32] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f; + t[l + 40] = tmp4 * 0.03125f + tmp5 + tmp6 * 7.59375f + src[7 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16_t tmp1 = t[1 + offset] + t[2 + offset]; + float16_t tmp2 = t[3 + offset] + t[4 + offset]; + float16_t tmp3 = t[5 + offset] + t[6 + offset]; + float16_t tmp4 = t[1 + offset] - t[2 + offset]; + float16_t tmp5 = t[3 + offset] - t[4 + offset]; + float16_t tmp6 = t[5 + offset] - t[6 + offset]; + m[l] = t[offset] + tmp1 + tmp2 + tmp3 + bias_ptr; + m[l + 6] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f + bias_ptr; + m[l + 12] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f + bias_ptr; + m[l + 18] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f + bias_ptr; + m[l + 24] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f + bias_ptr; + m[l + 30] = tmp4 * 0.03125f + tmp5 + tmp6 * 7.59375f + t[7 + offset] + bias_ptr; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } } } void OutputTransform8x6ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { - float16x8_t src[64]; - float16x8_t t[48]; - float16x8_t m[36]; - float16x8_t zero = vdupq_n_f16(0); - Load64DataFp16; - float16x8_t bias_ptr = vld1q_f16(bias_data); - for (int l = 0; l < 8; ++l) { - int offset = l * 8; - float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); - float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); - float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); - float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); - float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); - float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); - t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); - t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); - t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); - t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); - t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); - t[l + 40] = - vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), src[7 + offset]); - } - for (int l = 0; l < 6; ++l) { - int offset = l * 8; - float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); - float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); - float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); - float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); - float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); - float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); - m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); - m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); - m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); - m[l + 18] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); - m[l + 24] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); - m[l + 30] = vaddq_f16( - vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), t[7 + offset]), - bias_ptr); - m[l] = vmaxq_f16(zero, m[l]); - m[l + 6] = vmaxq_f16(zero, m[l + 6]); - m[l + 12] = vmaxq_f16(zero, m[l + 12]); - m[l + 18] = vmaxq_f16(zero, m[l + 18]); - m[l + 24] = vmaxq_f16(zero, m[l + 24]); - m[l + 30] = vmaxq_f16(zero, m[l + 30]); - } - if (r_c == C8NUM && r_h == 6 && r_w == 6) { - for (int i = 0; i < 6; i++) { - int dst_k_offset = i * dst_step * out_c; - int m_k_offset = i * 6; - vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); - vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); - vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); - vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); - vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); - vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + int z = 0; + if (r_c == 8) { + float16x8_t src[64]; + float16x8_t t[48]; + float16x8_t m[36]; + float16x8_t zero = vdupq_n_f16(0); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), src[7 + offset]); } - } else { - for (int i = 0; i < r_c; i++) { + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + m[l + 12] = vmaxq_f16(zero, m[l + 12]); + m[l + 18] = vmaxq_f16(zero, m[l + 18]); + m[l + 24] = vmaxq_f16(zero, m[l + 24]); + m[l + 30] = vmaxq_f16(zero, m[l + 30]); + } + if (r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { for (int j = 0; j < r_h; j++) { int dst_k_offset = j * dst_step * out_c; int m_k_offset = j * 6; for (int k = 0; k < r_w; k++) { - dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); } } } + r_c -= 8; + } else if (r_c < 8 && r_c >= 4) { + float16x4_t src[64]; + float16x4_t t[48]; + float16x4_t m[36]; + float16x4_t zero = vdup_n_f16(0); + Load64DataC4Fp16; + float16x4_t bias_ptr = vld1_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x4_t tmp1 = vadd_f16(src[1 + offset], src[2 + offset]); + float16x4_t tmp2 = vadd_f16(src[3 + offset], src[4 + offset]); + float16x4_t tmp3 = vadd_f16(src[5 + offset], src[6 + offset]); + float16x4_t tmp4 = vsub_f16(src[1 + offset], src[2 + offset]); + float16x4_t tmp5 = vsub_f16(src[3 + offset], src[4 + offset]); + float16x4_t tmp6 = vsub_f16(src[5 + offset], src[6 + offset]); + t[l] = vadd_f16(vadd_f16(vadd_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5)); + t[l + 16] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25)); + t[l + 24] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375)); + t[l + 32] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625)); + t[l + 40] = + vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), vmul_n_f16(tmp6, 7.59375)), src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x4_t tmp1 = vadd_f16(t[1 + offset], t[2 + offset]); + float16x4_t tmp2 = vadd_f16(t[3 + offset], t[4 + offset]); + float16x4_t tmp3 = vadd_f16(t[5 + offset], t[6 + offset]); + float16x4_t tmp4 = vsub_f16(t[1 + offset], t[2 + offset]); + float16x4_t tmp5 = vsub_f16(t[3 + offset], t[4 + offset]); + float16x4_t tmp6 = vsub_f16(t[5 + offset], t[6 + offset]); + m[l] = vadd_f16(vadd_f16(vadd_f16(vadd_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vadd_f16( + vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), vmul_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + m[l] = vmax_f16(zero, m[l]); + m[l + 6] = vmax_f16(zero, m[l + 6]); + m[l + 12] = vmax_f16(zero, m[l + 12]); + m[l + 18] = vmax_f16(zero, m[l + 18]); + m[l + 24] = vmax_f16(zero, m[l + 24]); + m[l + 30] = vmax_f16(zero, m[l + 30]); + } + if (r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + vst1_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + z = 4; + } + for (; z < r_c; ++z) { + float16_t src[64]; + float16_t t[48]; + float16_t m[36]; + for (int k = 0; k < 16; ++k) { + src[k] = src_data[z + k * src_step]; + } + float16_t bias_ptr = bias_data[z]; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16_t tmp1 = src[1 + offset] + src[2 + offset]; + float16_t tmp2 = src[3 + offset] + src[4 + offset]; + float16_t tmp3 = src[5 + offset] + src[6 + offset]; + float16_t tmp4 = src[1 + offset] - src[2 + offset]; + float16_t tmp5 = src[3 + offset] - src[4 + offset]; + float16_t tmp6 = src[5 + offset] - src[6 + offset]; + t[l] = src[offset] + tmp1 + tmp2 + tmp3; + t[l + 8] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f; + t[l + 16] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f; + t[l + 24] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f; + t[l + 32] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f; + t[l + 40] = tmp4 * 0.03125f + tmp5 + tmp6 * 7.59375f + src[7 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16_t tmp1 = t[1 + offset] + t[2 + offset]; + float16_t tmp2 = t[3 + offset] + t[4 + offset]; + float16_t tmp3 = t[5 + offset] + t[6 + offset]; + float16_t tmp4 = t[1 + offset] - t[2 + offset]; + float16_t tmp5 = t[3 + offset] - t[4 + offset]; + float16_t tmp6 = t[5 + offset] - t[6 + offset]; + m[l] = t[offset] + tmp1 + tmp2 + tmp3 + bias_ptr; + m[l + 6] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f + bias_ptr; + m[l + 12] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f + bias_ptr; + m[l + 18] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f + bias_ptr; + m[l + 24] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f + bias_ptr; + m[l + 30] = tmp4 * 0.03125f + tmp5 + tmp6 * 7.59375f + t[7 + offset] + bias_ptr; + m[l] = m[l] > 0 ? m[l] : 0; + m[l + 6] = m[l + 6] > 0 ? m[l + 6] : 0; + m[l + 12] = m[l + 12] > 0 ? m[l + 12] : 0; + m[l + 18] = m[l + 18] > 0 ? m[l + 18] : 0; + m[l + 24] = m[l + 24] > 0 ? m[l + 24] : 0; + m[l + 30] = m[l + 30] > 0 ? m[l + 30] : 0; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } } } void OutputTransform8x6Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { - float16x8_t src[64]; - float16x8_t t[48]; - float16x8_t m[36]; - float16x8_t zero = vdupq_n_f16(0); - float16x8_t six = vdupq_n_f16(6); - Load64DataFp16; - float16x8_t bias_ptr = vld1q_f16(bias_data); - for (int l = 0; l < 8; ++l) { - int offset = l * 8; - float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); - float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); - float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); - float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); - float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); - float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); - t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); - t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); - t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); - t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); - t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); - t[l + 40] = - vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), src[7 + offset]); - } - for (int l = 0; l < 6; ++l) { - int offset = l * 8; - float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); - float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); - float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); - float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); - float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); - float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); - m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); - m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); - m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); - m[l + 18] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); - m[l + 24] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); - m[l + 30] = vaddq_f16( - vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), t[7 + offset]), - bias_ptr); - m[l] = vmaxq_f16(zero, m[l]); - m[l] = vminq_f16(six, m[l]); - m[l + 6] = vmaxq_f16(zero, m[l + 6]); - m[l + 6] = vminq_f16(six, m[l + 6]); - m[l + 12] = vmaxq_f16(zero, m[l + 12]); - m[l + 12] = vminq_f16(six, m[l + 12]); - m[l + 18] = vmaxq_f16(zero, m[l + 18]); - m[l + 18] = vminq_f16(six, m[l + 18]); - m[l + 24] = vmaxq_f16(zero, m[l + 24]); - m[l + 24] = vminq_f16(six, m[l + 24]); - m[l + 30] = vmaxq_f16(zero, m[l + 30]); - m[l + 30] = vminq_f16(six, m[l + 30]); - } - if (r_c == C8NUM && r_h == 6 && r_w == 6) { - for (int i = 0; i < 6; i++) { - int dst_k_offset = i * dst_step * out_c; - int m_k_offset = i * 6; - vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); - vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); - vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); - vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); - vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); - vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + int z = 0; + if (r_c == 8) { + float16x8_t src[64]; + float16x8_t t[48]; + float16x8_t m[36]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), src[7 + offset]); } - } else { - for (int i = 0; i < r_c; i++) { + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + m[l + 6] = vminq_f16(six, m[l + 6]); + m[l + 12] = vmaxq_f16(zero, m[l + 12]); + m[l + 12] = vminq_f16(six, m[l + 12]); + m[l + 18] = vmaxq_f16(zero, m[l + 18]); + m[l + 18] = vminq_f16(six, m[l + 18]); + m[l + 24] = vmaxq_f16(zero, m[l + 24]); + m[l + 24] = vminq_f16(six, m[l + 24]); + m[l + 30] = vmaxq_f16(zero, m[l + 30]); + m[l + 30] = vminq_f16(six, m[l + 30]); + } + if (r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { for (int j = 0; j < r_h; j++) { int dst_k_offset = j * dst_step * out_c; int m_k_offset = j * 6; for (int k = 0; k < r_w; k++) { - dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + vst1q_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); + } + } + } + r_c -= 8; + } else if (r_c < 8 && r_c >= 4) { + float16x4_t src[64]; + float16x4_t t[48]; + float16x4_t m[36]; + float16x4_t zero = vdup_n_f16(0); + float16x4_t six = vdup_n_f16(6); + Load64DataC4Fp16; + float16x4_t bias_ptr = vld1_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x4_t tmp1 = vadd_f16(src[1 + offset], src[2 + offset]); + float16x4_t tmp2 = vadd_f16(src[3 + offset], src[4 + offset]); + float16x4_t tmp3 = vadd_f16(src[5 + offset], src[6 + offset]); + float16x4_t tmp4 = vsub_f16(src[1 + offset], src[2 + offset]); + float16x4_t tmp5 = vsub_f16(src[3 + offset], src[4 + offset]); + float16x4_t tmp6 = vsub_f16(src[5 + offset], src[6 + offset]); + t[l] = vadd_f16(vadd_f16(vadd_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5)); + t[l + 16] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25)); + t[l + 24] = vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375)); + t[l + 32] = vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625)); + t[l + 40] = + vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), vmul_n_f16(tmp6, 7.59375)), src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x4_t tmp1 = vadd_f16(t[1 + offset], t[2 + offset]); + float16x4_t tmp2 = vadd_f16(t[3 + offset], t[4 + offset]); + float16x4_t tmp3 = vadd_f16(t[5 + offset], t[6 + offset]); + float16x4_t tmp4 = vsub_f16(t[1 + offset], t[2 + offset]); + float16x4_t tmp5 = vsub_f16(t[3 + offset], t[4 + offset]); + float16x4_t tmp6 = vsub_f16(t[5 + offset], t[6 + offset]); + m[l] = vadd_f16(vadd_f16(vadd_f16(vadd_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.5), tmp5), vmul_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.25), tmp2), vmul_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.125), tmp5), vmul_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp1, 0.0625), tmp2), vmul_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vadd_f16( + vadd_f16(vadd_f16(vadd_f16(vmul_n_f16(tmp4, 0.03125), tmp5), vmul_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + m[l] = vmax_f16(zero, m[l]); + m[l] = vmin_f16(six, m[l]); + m[l + 6] = vmax_f16(zero, m[l + 6]); + m[l + 6] = vmin_f16(six, m[l + 6]); + m[l + 12] = vmax_f16(zero, m[l + 12]); + m[l + 12] = vmin_f16(six, m[l + 12]); + m[l + 18] = vmax_f16(zero, m[l + 18]); + m[l + 18] = vmin_f16(six, m[l + 18]); + m[l + 24] = vmax_f16(zero, m[l + 24]); + m[l + 24] = vmin_f16(six, m[l + 24]); + m[l + 30] = vmax_f16(zero, m[l + 30]); + m[l + 30] = vmin_f16(six, m[l + 30]); + } + if (r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + vst1_f16(dst_data + dst_k_offset + k * out_c, m[k + m_k_offset]); } } } + z = 4; + } + for (; z < r_c; ++z) { + float16_t src[64]; + float16_t t[48]; + float16_t m[36]; + for (int k = 0; k < 16; ++k) { + src[k] = src_data[z + k * src_step]; + } + float16_t bias_ptr = bias_data[z]; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16_t tmp1 = src[1 + offset] + src[2 + offset]; + float16_t tmp2 = src[3 + offset] + src[4 + offset]; + float16_t tmp3 = src[5 + offset] + src[6 + offset]; + float16_t tmp4 = src[1 + offset] - src[2 + offset]; + float16_t tmp5 = src[3 + offset] - src[4 + offset]; + float16_t tmp6 = src[5 + offset] - src[6 + offset]; + t[l] = src[offset] + tmp1 + tmp2 + tmp3; + t[l + 8] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f; + t[l + 16] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f; + t[l + 24] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f; + t[l + 32] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f; + t[l + 40] = tmp4 * 0.03125f + tmp5 + tmp6 * 7.59375f + src[7 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16_t tmp1 = t[1 + offset] + t[2 + offset]; + float16_t tmp2 = t[3 + offset] + t[4 + offset]; + float16_t tmp3 = t[5 + offset] + t[6 + offset]; + float16_t tmp4 = t[1 + offset] - t[2 + offset]; + float16_t tmp5 = t[3 + offset] - t[4 + offset]; + float16_t tmp6 = t[5 + offset] - t[6 + offset]; + m[l] = t[offset] + tmp1 + tmp2 + tmp3 + bias_ptr; + m[l + 6] = tmp4 * 0.5f + tmp5 + tmp6 * 1.5f + bias_ptr; + m[l + 12] = tmp1 * 0.25f + tmp2 + tmp3 * 2.25f + bias_ptr; + m[l + 18] = tmp4 * 0.125f + tmp5 + tmp6 * 3.375f + bias_ptr; + m[l + 24] = tmp1 * 0.0625f + tmp2 + tmp3 * 5.0625f + bias_ptr; + m[l + 30] = tmp4 * 0.03125f + tmp5 + tmp6 * 7.59375f + t[7 + offset] + bias_ptr; + m[l] = m[l] > 0 ? m[l] : 0; + m[l] = m[l] > 0 ? m[l] : 0; + m[l + 6] = m[l + 6] > 0 ? m[l + 6] : 0; + m[l + 6] = m[l + 6] < 6 ? m[l + 6] : 6; + m[l + 12] = m[l + 12] > 0 ? m[l + 12] : 0; + m[l + 12] = m[l + 12] < 6 ? m[l + 12] : 6; + m[l + 18] = m[l + 18] > 0 ? m[l + 18] : 0; + m[l + 18] = m[l + 18] < 6 ? m[l + 18] : 6; + m[l + 24] = m[l + 24] > 0 ? m[l + 24] : 0; + m[l + 24] = m[l + 24] < 6 ? m[l + 24] : 6; + m[l + 30] = m[l + 30] > 0 ? m[l + 30] : 0; + m[l + 30] = m[l + 30] < 6 ? m[l + 30] : 6; + } + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[z + dst_k_offset + k * out_c] = m[k + m_k_offset]; + } + } } } diff --git a/mindspore/lite/nnacl/fp16/winograd_utils_fp16.h b/mindspore/lite/nnacl/fp16/winograd_utils_fp16.h index ac057e4b79..fda291e2dc 100644 --- a/mindspore/lite/nnacl/fp16/winograd_utils_fp16.h +++ b/mindspore/lite/nnacl/fp16/winograd_utils_fp16.h @@ -26,7 +26,8 @@ #ifdef __cplusplus extern "C" { #endif -typedef void (*InputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step); +typedef void (*InputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, + int real_c); typedef void (*OutputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); @@ -56,6 +57,24 @@ void GeneralOutputTransformUnitFp16(const float16_t *src_data, float16_t *dst_da src[14] = vld1q_f16(src_data + 14 * src_step); \ src[15] = vld1q_f16(src_data + 15 * src_step); +#define Load16DataC4Fp16 \ + src[0] = vld1_f16(src_data + 0 * src_step); \ + src[1] = vld1_f16(src_data + 1 * src_step); \ + src[2] = vld1_f16(src_data + 2 * src_step); \ + src[3] = vld1_f16(src_data + 3 * src_step); \ + src[4] = vld1_f16(src_data + 4 * src_step); \ + src[5] = vld1_f16(src_data + 5 * src_step); \ + src[6] = vld1_f16(src_data + 6 * src_step); \ + src[7] = vld1_f16(src_data + 7 * src_step); \ + src[8] = vld1_f16(src_data + 8 * src_step); \ + src[9] = vld1_f16(src_data + 9 * src_step); \ + src[10] = vld1_f16(src_data + 10 * src_step); \ + src[11] = vld1_f16(src_data + 11 * src_step); \ + src[12] = vld1_f16(src_data + 12 * src_step); \ + src[13] = vld1_f16(src_data + 13 * src_step); \ + src[14] = vld1_f16(src_data + 14 * src_step); \ + src[15] = vld1_f16(src_data + 15 * src_step); + #define Load36DataFp16 \ src[0] = vld1q_f16(src_data + 0 * src_step); \ src[1] = vld1q_f16(src_data + 1 * src_step); \ @@ -94,6 +113,44 @@ void GeneralOutputTransformUnitFp16(const float16_t *src_data, float16_t *dst_da src[34] = vld1q_f16(src_data + 34 * src_step); \ src[35] = vld1q_f16(src_data + 35 * src_step); +#define Load36DataC4Fp16 \ + src[0] = vld1_f16(src_data + 0 * src_step); \ + src[1] = vld1_f16(src_data + 1 * src_step); \ + src[2] = vld1_f16(src_data + 2 * src_step); \ + src[3] = vld1_f16(src_data + 3 * src_step); \ + src[4] = vld1_f16(src_data + 4 * src_step); \ + src[5] = vld1_f16(src_data + 5 * src_step); \ + src[6] = vld1_f16(src_data + 6 * src_step); \ + src[7] = vld1_f16(src_data + 7 * src_step); \ + src[8] = vld1_f16(src_data + 8 * src_step); \ + src[9] = vld1_f16(src_data + 9 * src_step); \ + src[10] = vld1_f16(src_data + 10 * src_step); \ + src[11] = vld1_f16(src_data + 11 * src_step); \ + src[12] = vld1_f16(src_data + 12 * src_step); \ + src[13] = vld1_f16(src_data + 13 * src_step); \ + src[14] = vld1_f16(src_data + 14 * src_step); \ + src[15] = vld1_f16(src_data + 15 * src_step); \ + src[16] = vld1_f16(src_data + 16 * src_step); \ + src[17] = vld1_f16(src_data + 17 * src_step); \ + src[18] = vld1_f16(src_data + 18 * src_step); \ + src[19] = vld1_f16(src_data + 19 * src_step); \ + src[20] = vld1_f16(src_data + 20 * src_step); \ + src[21] = vld1_f16(src_data + 21 * src_step); \ + src[22] = vld1_f16(src_data + 22 * src_step); \ + src[23] = vld1_f16(src_data + 23 * src_step); \ + src[24] = vld1_f16(src_data + 24 * src_step); \ + src[25] = vld1_f16(src_data + 25 * src_step); \ + src[26] = vld1_f16(src_data + 26 * src_step); \ + src[27] = vld1_f16(src_data + 27 * src_step); \ + src[28] = vld1_f16(src_data + 28 * src_step); \ + src[29] = vld1_f16(src_data + 29 * src_step); \ + src[30] = vld1_f16(src_data + 30 * src_step); \ + src[31] = vld1_f16(src_data + 31 * src_step); \ + src[32] = vld1_f16(src_data + 32 * src_step); \ + src[33] = vld1_f16(src_data + 33 * src_step); \ + src[34] = vld1_f16(src_data + 34 * src_step); \ + src[35] = vld1_f16(src_data + 35 * src_step); + #define Load64DataFp16 \ src[0] = vld1q_f16(src_data + 0 * src_step); \ src[1] = vld1q_f16(src_data + 1 * src_step); \ @@ -160,13 +217,79 @@ void GeneralOutputTransformUnitFp16(const float16_t *src_data, float16_t *dst_da src[62] = vld1q_f16(src_data + 62 * src_step); \ src[63] = vld1q_f16(src_data + 63 * src_step); +#define Load64DataC4Fp16 \ + src[0] = vld1_f16(src_data + 0 * src_step); \ + src[1] = vld1_f16(src_data + 1 * src_step); \ + src[2] = vld1_f16(src_data + 2 * src_step); \ + src[3] = vld1_f16(src_data + 3 * src_step); \ + src[4] = vld1_f16(src_data + 4 * src_step); \ + src[5] = vld1_f16(src_data + 5 * src_step); \ + src[6] = vld1_f16(src_data + 6 * src_step); \ + src[7] = vld1_f16(src_data + 7 * src_step); \ + src[8] = vld1_f16(src_data + 8 * src_step); \ + src[9] = vld1_f16(src_data + 9 * src_step); \ + src[10] = vld1_f16(src_data + 10 * src_step); \ + src[11] = vld1_f16(src_data + 11 * src_step); \ + src[12] = vld1_f16(src_data + 12 * src_step); \ + src[13] = vld1_f16(src_data + 13 * src_step); \ + src[14] = vld1_f16(src_data + 14 * src_step); \ + src[15] = vld1_f16(src_data + 15 * src_step); \ + src[16] = vld1_f16(src_data + 16 * src_step); \ + src[17] = vld1_f16(src_data + 17 * src_step); \ + src[18] = vld1_f16(src_data + 18 * src_step); \ + src[19] = vld1_f16(src_data + 19 * src_step); \ + src[20] = vld1_f16(src_data + 20 * src_step); \ + src[21] = vld1_f16(src_data + 21 * src_step); \ + src[22] = vld1_f16(src_data + 22 * src_step); \ + src[23] = vld1_f16(src_data + 23 * src_step); \ + src[24] = vld1_f16(src_data + 24 * src_step); \ + src[25] = vld1_f16(src_data + 25 * src_step); \ + src[26] = vld1_f16(src_data + 26 * src_step); \ + src[27] = vld1_f16(src_data + 27 * src_step); \ + src[28] = vld1_f16(src_data + 28 * src_step); \ + src[29] = vld1_f16(src_data + 29 * src_step); \ + src[30] = vld1_f16(src_data + 30 * src_step); \ + src[31] = vld1_f16(src_data + 31 * src_step); \ + src[32] = vld1_f16(src_data + 32 * src_step); \ + src[33] = vld1_f16(src_data + 33 * src_step); \ + src[34] = vld1_f16(src_data + 34 * src_step); \ + src[35] = vld1_f16(src_data + 35 * src_step); \ + src[36] = vld1_f16(src_data + 36 * src_step); \ + src[37] = vld1_f16(src_data + 37 * src_step); \ + src[38] = vld1_f16(src_data + 38 * src_step); \ + src[39] = vld1_f16(src_data + 39 * src_step); \ + src[40] = vld1_f16(src_data + 40 * src_step); \ + src[41] = vld1_f16(src_data + 41 * src_step); \ + src[42] = vld1_f16(src_data + 42 * src_step); \ + src[43] = vld1_f16(src_data + 43 * src_step); \ + src[44] = vld1_f16(src_data + 44 * src_step); \ + src[45] = vld1_f16(src_data + 45 * src_step); \ + src[46] = vld1_f16(src_data + 46 * src_step); \ + src[47] = vld1_f16(src_data + 47 * src_step); \ + src[48] = vld1_f16(src_data + 48 * src_step); \ + src[49] = vld1_f16(src_data + 49 * src_step); \ + src[50] = vld1_f16(src_data + 50 * src_step); \ + src[51] = vld1_f16(src_data + 51 * src_step); \ + src[52] = vld1_f16(src_data + 52 * src_step); \ + src[53] = vld1_f16(src_data + 53 * src_step); \ + src[54] = vld1_f16(src_data + 54 * src_step); \ + src[55] = vld1_f16(src_data + 55 * src_step); \ + src[56] = vld1_f16(src_data + 56 * src_step); \ + src[57] = vld1_f16(src_data + 57 * src_step); \ + src[58] = vld1_f16(src_data + 58 * src_step); \ + src[59] = vld1_f16(src_data + 59 * src_step); \ + src[60] = vld1_f16(src_data + 60 * src_step); \ + src[61] = vld1_f16(src_data + 61 * src_step); \ + src[62] = vld1_f16(src_data + 62 * src_step); \ + src[63] = vld1_f16(src_data + 63 * src_step); + InputTransFp16Func GetInputTransFp16Func(int input_unit); -void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step); +void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); -void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step); +void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); -void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step); +void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c); OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActType act_type); @@ -176,6 +299,12 @@ OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActT vst1q_f16(dst_data + dst_step * out_c, m[2]); \ vst1q_f16(dst_data + dst_step * out_c + out_c, m[3]); +#define Store4DataC4Fp16 \ + vst1_f16(dst_data, m[0]); \ + vst1_f16(dst_data + out_c, m[1]); \ + vst1_f16(dst_data + dst_step * out_c, m[2]); \ + vst1_f16(dst_data + dst_step * out_c + out_c, m[3]); + #define Store9DataFp16 \ vst1q_f16(dst_data, m[0]); \ vst1q_f16(dst_data + out_c, m[1]); \ @@ -187,6 +316,17 @@ OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActT vst1q_f16(dst_data + 2 * dst_step * out_c + out_c, m[7]); \ vst1q_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]); +#define Store9DataC4Fp16 \ + vst1_f16(dst_data, m[0]); \ + vst1_f16(dst_data + out_c, m[1]); \ + vst1_f16(dst_data + 2 * out_c, m[2]); \ + vst1_f16(dst_data + dst_step * out_c, m[3]); \ + vst1_f16(dst_data + dst_step * out_c + out_c, m[4]); \ + vst1_f16(dst_data + dst_step * out_c + 2 * out_c, m[5]); \ + vst1_f16(dst_data + 2 * dst_step * out_c, m[6]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + out_c, m[7]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]); + #define Store16DataFp16 \ vst1q_f16(dst_data, m[0]); \ vst1q_f16(dst_data + out_c, m[1]); \ @@ -205,6 +345,24 @@ OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActT vst1q_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \ vst1q_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]); +#define Store16DataC4Fp16 \ + vst1_f16(dst_data, m[0]); \ + vst1_f16(dst_data + out_c, m[1]); \ + vst1_f16(dst_data + 2 * out_c, m[2]); \ + vst1_f16(dst_data + 3 * out_c, m[3]); \ + vst1_f16(dst_data + dst_step * out_c, m[4]); \ + vst1_f16(dst_data + dst_step * out_c + out_c, m[5]); \ + vst1_f16(dst_data + dst_step * out_c + 2 * out_c, m[6]); \ + vst1_f16(dst_data + dst_step * out_c + 3 * out_c, m[7]); \ + vst1_f16(dst_data + 2 * dst_step * out_c, m[8]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + out_c, m[9]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]); \ + vst1_f16(dst_data + 3 * dst_step * out_c, m[12]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + out_c, m[13]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]); + #define Store25DataFp16 \ vst1q_f16(dst_data, m[0]); \ vst1q_f16(dst_data + out_c, m[1]); \ @@ -232,6 +390,33 @@ OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActT vst1q_f16(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \ vst1q_f16(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]); +#define Store25DataC4Fp16 \ + vst1_f16(dst_data, m[0]); \ + vst1_f16(dst_data + out_c, m[1]); \ + vst1_f16(dst_data + 2 * out_c, m[2]); \ + vst1_f16(dst_data + 3 * out_c, m[3]); \ + vst1_f16(dst_data + 4 * out_c, m[4]); \ + vst1_f16(dst_data + dst_step * out_c, m[5]); \ + vst1_f16(dst_data + dst_step * out_c + out_c, m[6]); \ + vst1_f16(dst_data + dst_step * out_c + 2 * out_c, m[7]); \ + vst1_f16(dst_data + dst_step * out_c + 3 * out_c, m[8]); \ + vst1_f16(dst_data + dst_step * out_c + 4 * out_c, m[9]); \ + vst1_f16(dst_data + 2 * dst_step * out_c, m[10]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + out_c, m[11]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]); \ + vst1_f16(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]); \ + vst1_f16(dst_data + 3 * dst_step * out_c, m[15]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + out_c, m[16]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[17]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[18]); \ + vst1_f16(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]); \ + vst1_f16(dst_data + 4 * dst_step * out_c, m[20]); \ + vst1_f16(dst_data + 4 * dst_step * out_c + out_c, m[21]); \ + vst1_f16(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]); \ + vst1_f16(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \ + vst1_f16(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]); + void OutputTransform4x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); void OutputTransform4x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc index 43fd794501..78f1d509d5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc @@ -197,25 +197,27 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector & auto conv_param = reinterpret_cast(opParameter); int kernel_h = conv_param->kernel_h_; int kernel_w = conv_param->kernel_w_; - conv_param->input_h_ = inputs.front()->Height(); - conv_param->input_w_ = inputs.front()->Width(); - conv_param->output_h_ = outputs.front()->Height(); - conv_param->output_w_ = outputs.front()->Width(); + bool use_winograd = false; + int out_unit; + if (primitive != nullptr && primitive->GetInferFlag()) { + conv_param->input_h_ = inputs.front()->Height(); + conv_param->input_w_ = inputs.front()->Width(); + conv_param->input_channel_ = inputs.front()->Channel(); + conv_param->output_h_ = outputs.front()->Height(); + conv_param->output_w_ = outputs.front()->Width(); + conv_param->output_channel_ = outputs.front()->Channel(); + conv_param->op_parameter_.thread_num_ = ctx->thread_num_; + CheckIfUseWinograd(&use_winograd, &out_unit, conv_param); + } kernel::LiteKernel *kernel = nullptr; if (kernel_h == 1 && kernel_w == 1) { kernel = new (std::nothrow) kernel::Convolution1x1FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); + } else if (use_winograd) { + kernel = new (std::nothrow) + kernel::ConvolutionWinogradFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive, out_unit); } else { - bool use_winograd = false; - int out_unit; - CheckIfUseWinograd(&use_winograd, &out_unit, conv_param); - if (use_winograd) { - kernel = new (std::nothrow) - kernel::ConvolutionWinogradFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive, out_unit); - } - if (kernel_h != 1 && kernel_w != 1 && !use_winograd) { - kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); - } + kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); } if (kernel == nullptr) { MS_LOG(DEBUG) << "Create conv fp16 kernel failed."; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc index c3a0f9f307..e75845ae12 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc @@ -39,8 +39,6 @@ int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_ // original weight format : ohwi auto channel_in = conv_param_->input_channel_; auto channel_out = conv_param_->output_channel_; - int ic8 = UP_DIV(channel_in, C8NUM); - int ic4 = ic8 * 2; int input_unit_square = input_unit_ * input_unit_; int oc_block_num = UP_DIV(channel_out, oc_block); @@ -84,17 +82,7 @@ int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_ MS_LOG(ERROR) << "malloc trans_out_data failed."; return RET_ERROR; } - std::vector shape{input_unit_ * input_unit_, oc_block_num, ic4, C4NUM, oc_block}; - std::vector strides; - for (int i = 0; i < 4; i++) { - int stride = 1; - for (int j = i + 1; j < 5; j++) { - stride *= shape[j]; - } - strides.push_back(stride); - } - int kernel_plane_stride = channel_in; if (oc_block == 0) { MS_LOG(ERROR) << "Divide by zero"; free(tmp_weight_data); @@ -104,18 +92,17 @@ int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_ free(matrix_gt_data_fp16); return RET_ERROR; } + int stride1 = channel_in * oc_block; for (int i = 0; i < channel_out; i++) { int out_c_block = i / oc_block; int out_c_res = i % oc_block; int input_oz_offset = i * kernel_unit_ * kernel_unit_ * channel_in; - int output_oz_offset = out_c_block * strides[1] * input_unit_ * input_unit_ + out_c_res; + int output_oz_offset = out_c_block * stride1 + out_c_res; for (int j = 0; j < channel_in; j++) { - int ic4_block = j / C4NUM; - int ic4_res = j % C4NUM; int input_iz_offset = input_oz_offset + j; - int output_iz_offset = output_oz_offset + ic4_block * strides[2] + ic4_res * strides[3]; + int output_iz_offset = output_oz_offset + j * oc_block; for (int k = 0; k < kernel_unit_ * kernel_unit_; k++) { - int input_xy_offset = input_iz_offset + k * kernel_plane_stride; + int input_xy_offset = input_iz_offset + k * channel_in; tmp_weight_data[k] = *(weight_data + input_xy_offset); } // now we only support row-major matrix-multiply @@ -125,7 +112,7 @@ int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_ MatrixMultiplyFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit_, kernel_unit_, input_unit_); for (int z = 0; z < input_unit_square; z++) { - int output_xy_offset = output_iz_offset + z * strides[1]; + int output_xy_offset = output_iz_offset + z * oc_block_num * stride1; trans_weight_[output_xy_offset] = trans_out_data[z]; } } @@ -142,7 +129,6 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() { auto filter_tensor = in_tensors_.at(kWeightIndex); int in_channel = filter_tensor->Channel(); int out_channel = filter_tensor->Batch(); - int ic8 = UP_DIV(in_channel, C8NUM); conv_param_->input_channel_ = in_channel; conv_param_->output_channel_ = out_channel; @@ -157,7 +143,7 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() { } // set data - auto trans_matrix_data_size = input_unit_ * input_unit_ * ic8 * C8NUM * oc_block_num * oc_block * sizeof(float16_t); + auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * sizeof(float16_t); trans_weight_ = reinterpret_cast(malloc(trans_matrix_data_size)); if (trans_weight_ == nullptr) { MS_LOG(ERROR) << "malloc trans_weight_ failed."; @@ -209,9 +195,9 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() { const int cal_num = 16; int channel_out = conv_param_->output_channel_; int oc8 = UP_DIV(channel_out, C8NUM); - int ic8 = UP_DIV(conv_param_->input_channel_, C8NUM); - size_t tile_buffer_size = thread_count_ * cal_num * input_unit_ * input_unit_ * ic8 * C8NUM * sizeof(float16_t); + size_t tile_buffer_size = + thread_count_ * cal_num * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float16_t); trans_input_ = reinterpret_cast(ctx_->allocator->Malloc(tile_buffer_size)); if (trans_input_ == nullptr) { MS_LOG(ERROR) << "malloc trans_input_ failed."; @@ -232,9 +218,17 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() { return RET_ERROR; } + col_buffer_ = reinterpret_cast( + ctx_->allocator->Malloc(thread_count_ * cal_num * conv_param_->input_channel_ * sizeof(float16_t))); + if (col_buffer_ == nullptr) { + MS_LOG(ERROR) << "malloc col_buffer_ failed."; + return RET_ERROR; + } + tmp_buffer_address_list_[0] = trans_input_; tmp_buffer_address_list_[1] = gemm_out_; tmp_buffer_address_list_[2] = tmp_data_; + tmp_buffer_address_list_[3] = col_buffer_; return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h index f4c0e4ddb2..813cafbc10 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h @@ -67,6 +67,10 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { ctx_->allocator->Free(gemm_out_); gemm_out_ = nullptr; } + if (col_buffer_ != nullptr) { + ctx_->allocator->Free(col_buffer_); + col_buffer_ = nullptr; + } } int kernel_unit_; int input_unit_; @@ -75,7 +79,8 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { float16_t *trans_input_ = nullptr; float16_t *gemm_out_ = nullptr; float16_t *trans_weight_ = nullptr; - TmpBufferAddressFp16 tmp_buffer_address_list_[3]; + float16_t *col_buffer_ = nullptr; + TmpBufferAddressFp16 tmp_buffer_address_list_[4]; InputTransFp16Func in_func_; OutputTransFp16Func out_func_; };