From 2d00b74de2486a15da64e387df77fe8a822f779b Mon Sep 17 00:00:00 2001 From: fuzhiye Date: Mon, 28 Sep 2020 14:50:24 +0800 Subject: [PATCH] optimize winograd input transform func --- mindspore/lite/nnacl/fp32/conv.c | 13 +- mindspore/lite/nnacl/pack.c | 217 ---------- mindspore/lite/nnacl/pack.h | 23 - mindspore/lite/nnacl/winograd_transform.c | 6 +- mindspore/lite/nnacl/winograd_utils.c | 396 +++++++++--------- mindspore/lite/nnacl/winograd_utils.h | 8 +- .../kernel/arm/fp32/convolution_winograd.cc | 44 +- 7 files changed, 238 insertions(+), 469 deletions(-) diff --git a/mindspore/lite/nnacl/fp32/conv.c b/mindspore/lite/nnacl/fp32/conv.c index cb65c3159c..92643dddb8 100644 --- a/mindspore/lite/nnacl/fp32/conv.c +++ b/mindspore/lite/nnacl/fp32/conv.c @@ -77,7 +77,6 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_ int input_unit = conv_param->input_unit_; int in_batch = conv_param->input_batch_; int in_channel = conv_param->input_channel_; - int ic4 = UP_DIV(in_channel, C4NUM); int out_unit = conv_param->output_unit_; int out_w_block = UP_DIV(conv_param->output_w_, out_unit); int out_h_block = UP_DIV(conv_param->output_h_, out_unit); @@ -96,10 +95,10 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_ float *gemm_out = buffer_list[1]; float *tmp_data = buffer_list[2]; float *col_buffer = buffer_list[3]; - int trans_input_offset = tile_num * input_unit_square * ic4 * C4NUM; + int trans_input_offset = tile_num * input_unit_square * in_channel; int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM; int tmp_data_offset = input_unit_square * C4NUM; - int col_buffer_offset = tile_num * ic4 * C4NUM; + int col_buffer_offset = tile_num * in_channel; // step 1 : filter transform (pre-processed offline) // step 2 : input transform (online) for (int b = 0; b < in_batch; b++) { @@ -107,7 +106,7 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_ int out_batch_offset = b * out_channel * conv_param->output_w_ * conv_param->output_h_; for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) { int out_tile_index = thread_id * tile_num; - int cal_num = output_count - thread_id * tile_num; + int cal_num = output_count - out_tile_index; cal_num = cal_num > tile_num ? tile_num : cal_num; WinogradInputTransform(input_data + in_batch_offset, trans_input + task_id * trans_input_offset, tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param, @@ -118,11 +117,11 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_ float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset; for (int i = 0; i < input_unit_square; ++i) { #ifdef ENABLE_ARM32 - RowMajor2Col4Major(src_ptr + i * C4NUM * ic4 * C4NUM, tmp_col_ptr, C4NUM, ic4 * C4NUM); + RowMajor2Col4Major(src_ptr + i * C4NUM * in_channel, tmp_col_ptr, C4NUM, in_channel); #else - RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM); + RowMajor2Col12Major(src_ptr + i * C12NUM * in_channel, tmp_col_ptr, C12NUM, in_channel); #endif - MatMulOpt(tmp_col_ptr, trans_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, ic4 * C4NUM, + MatMulOpt(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel, cal_num, oc8 * C8NUM, input_unit_square, 2); } diff --git a/mindspore/lite/nnacl/pack.c b/mindspore/lite/nnacl/pack.c index fad8eccb44..2f02b65b42 100644 --- a/mindspore/lite/nnacl/pack.c +++ b/mindspore/lite/nnacl/pack.c @@ -630,26 +630,6 @@ void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int c } } -void PackNCHWToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) { - int nhwc4_batch_offset = 0; - int c4 = UP_DIV(channel, C4NUM); - int nhwc4_batch_unit_offset = c4 * C4NUM * plane; - - for (int b = 0; b < batch; b++) { - int batch_offset = b * channel * plane; - for (int c = 0; c < channel; c++) { - int src_c_offset = batch_offset + c * plane; - int dst_c_offset = nhwc4_batch_offset + c; - for (int i = 0; i < plane; i++) { - int src_plane_offset = src_c_offset + i; - int dst_plane_offset = dst_c_offset + i * c4 * C4NUM; - ((float *)dst)[dst_plane_offset] = ((float *)src)[src_plane_offset]; - } - } - nhwc4_batch_offset += nhwc4_batch_unit_offset; - } -} - void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel) { int c4 = UP_DIV(channel, C4NUM); for (int b = 0; b < batch; b++) { @@ -700,105 +680,6 @@ void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int } } -void PackNC4HW4ToNHWCReluFp32(const void *src, void *dst, int batch, int plane, int channel) { - int c4 = UP_DIV(channel, C4NUM); - for (int b = 0; b < batch; b++) { - int src_offset = b * plane * c4 * C4NUM; - int dst_offset = b * plane * channel; - for (int k = 0; k < plane; k++) { - int src_kernel_offset = src_offset + k * C4NUM; - int dst_kernel_offset = dst_offset + k * channel; - for (int c = 0; c < c4 - 1; c++) { - int src_c_offset = src_kernel_offset + c * plane * C4NUM; - int dst_c_offset = dst_kernel_offset + c * C4NUM; -#ifdef ENABLE_NEON - float32x4_t input_ptr = vld1q_f32((float *)src + src_c_offset); - float32x4_t zero = vdupq_n_f32(0); - input_ptr = vmaxq_f32(zero, input_ptr); - vst1q_f32((float *)dst + dst_c_offset, input_ptr); -#else - for (int i = 0; i < C4NUM; ++i) { - float input_data = ((float *)src + src_c_offset)[i]; - input_data = input_data < 0 ? 0 : input_data; - ((float *)dst + dst_c_offset)[i] = input_data; - } -#endif - } - // res part - int res_c = channel - (c4 - 1) * C4NUM; - for (int i = 0; i < res_c; i++) { - int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i; - int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i; - float input_data = ((float *)src + src_res_c_offset)[0]; - input_data = input_data < 0 ? 0 : input_data; - ((float *)dst + dst_res_c_offset)[0] = input_data; - } - } - } -} - -void PackNC4HW4ToNHWCRelu6Fp32(const void *src, void *dst, int batch, int plane, int channel) { - int c4 = UP_DIV(channel, C4NUM); - for (int b = 0; b < batch; b++) { - int src_offset = b * plane * c4 * C4NUM; - int dst_offset = b * plane * channel; - for (int k = 0; k < plane; k++) { - int src_kernel_offset = src_offset + k * C4NUM; - int dst_kernel_offset = dst_offset + k * channel; - for (int c = 0; c < c4 - 1; c++) { - int src_c_offset = src_kernel_offset + c * plane * C4NUM; - int dst_c_offset = dst_kernel_offset + c * C4NUM; -#ifdef ENABLE_NEON - float32x4_t input_ptr = vld1q_f32((float *)src + src_c_offset); - float32x4_t zero = vdupq_n_f32(0); - float32x4_t six = vdupq_n_f32(6); - input_ptr = vmaxq_f32(zero, input_ptr); - input_ptr = vminq_f32(six, input_ptr); - vst1q_f32((float *)dst + dst_c_offset, input_ptr); -#else - for (int i = 0; i < C4NUM; ++i) { - float input_data = ((float *)src + src_c_offset)[i]; - input_data = input_data < 0 ? 0 : input_data; - input_data = input_data > 6 ? 6 : input_data; - ((float *)dst + dst_c_offset)[i] = input_data; - } -#endif - } - // res part - int res_c = channel - (c4 - 1) * C4NUM; - for (int i = 0; i < res_c; i++) { - int src_res_c_offset = src_kernel_offset + (c4 - 1) * C4NUM * plane + i; - int dst_res_c_offset = dst_kernel_offset + (c4 - 1) * C4NUM + i; - float input_data = ((float *)src + src_res_c_offset)[0]; - input_data = input_data < 0 ? 0 : input_data; - input_data = input_data > 6 ? 6 : input_data; - ((float *)dst + dst_res_c_offset)[0] = input_data; - } - } - } -} - -void PackNC4HW4ToNHWCPreluFp32(const void *src, void *dst, const void *slope, int batch, int plane, int channel) {} - -void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel) { - int c4 = UP_DIV(channel, C4NUM); - for (int b = 0; b < batch; b++) { - int src_offset = b * plane * c4 * C4NUM; - int dst_offset = b * plane * channel; - for (int c = 0; c < channel; c++) { - int c4_block_num = c / C4NUM; - int c4_block_res = c % C4NUM; - int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; - int dst_c_offset = dst_offset + c * plane; - for (int k = 0; k < plane; k++) { - int src_kernel_offset = src_c_offset + k * C4NUM; - int dst_kernel_offset = dst_c_offset + k; - ((float *)dst + dst_kernel_offset)[0] = ((float *)src + src_kernel_offset)[0]; - } - } - } -} - void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel) { for (int n = 0; n < batch; n++) { for (int hw = 0; hw < plane; hw++) { @@ -896,45 +777,6 @@ void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int c } } -void PackNCHWToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) { - int nhwc4_batch_offset = 0; - int c4 = UP_DIV(channel, C4NUM); - int nhwc4_batch_unit_offset = c4 * C4NUM * plane; - - for (int b = 0; b < batch; b++) { - int batch_offset = b * channel * plane; - for (int c = 0; c < channel; c++) { - int src_c_offset = batch_offset + c * plane; - int dst_c_offset = nhwc4_batch_offset + c; - for (int i = 0; i < plane; i++) { - int src_plane_offset = src_c_offset + i; - int dst_plane_offset = dst_c_offset + i * c4 * C4NUM; - ((uint8_t *)dst)[dst_plane_offset] = ((uint8_t *)src)[src_plane_offset]; - } - } - nhwc4_batch_offset += nhwc4_batch_unit_offset; - } -} - -void PackNC4HW4ToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel) { - int c4 = UP_DIV(channel, C4NUM); - for (int b = 0; b < batch; b++) { - int src_offset = b * plane * c4 * C4NUM; - int dst_offset = b * plane * channel; - for (int c = 0; c < channel; c++) { - int c4_block_num = c / C4NUM; - int c4_block_res = c % C4NUM; - int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; - int dst_c_offset = dst_offset + c4_block_num * C4NUM + c4_block_res; - for (int k = 0; k < plane; k++) { - int src_kernel_offset = src_c_offset + k * C4NUM; - int dst_kernel_offset = dst_c_offset + k * c4 * C4NUM; - ((uint8_t *)dst + dst_kernel_offset)[0] = ((uint8_t *)src + src_kernel_offset)[0]; - } - } - } -} - void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { int c4 = UP_DIV(channel, C4NUM); for (int b = 0; b < batch; b++) { @@ -962,25 +804,6 @@ void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int } } -void PackNC4HW4ToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel) { - int c4 = UP_DIV(channel, C4NUM); - for (int b = 0; b < batch; b++) { - int src_offset = b * plane * c4 * C4NUM; - int dst_offset = b * plane * channel; - for (int c = 0; c < channel; c++) { - int c4_block_num = c / C4NUM; - int c4_block_res = c % C4NUM; - int src_c_offset = src_offset + c4_block_num * plane * C4NUM + c4_block_res; - int dst_c_offset = dst_offset + c * plane; - for (int k = 0; k < plane; k++) { - int src_kernel_offset = src_c_offset + k * C4NUM; - int dst_kernel_offset = dst_c_offset + k; - ((uint8_t *)dst + dst_kernel_offset)[0] = ((uint8_t *)src + src_kernel_offset)[0]; - } - } - } -} - void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel) { for (int n = 0; n < batch; n++) { for (int hw = 0; hw < plane; hw++) { @@ -996,25 +819,6 @@ void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int return; } -void PackNHWCToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel) { - int c8 = UP_DIV(channel, C8NUM); - for (int b = 0; b < batch; b++) { - int src_oc_offset = b * plane * channel; - int dst_oc_offset = b * plane * c8 * C8NUM; - for (int k = 0; k < plane; k++) { - int src_kernel_offset = src_oc_offset + k * channel; - int dst_kernel_offset = dst_oc_offset + k * C8NUM; - for (int i = 0; i < channel; i++) { - int c8_block_num = i / C8NUM; - int c8_block_rem = i % C8NUM; - int src_ic_offset = src_kernel_offset + i; - int dst_ic_offset = dst_kernel_offset + c8_block_num * plane * C8NUM + c8_block_rem; - ((int8_t *)dst + dst_ic_offset)[0] = ((int8_t *)src + src_ic_offset)[0]; - } - } - } -} - void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel) { for (int n = 0; n < batch; n++) { for (int c = 0; c < channel; c++) { @@ -1231,27 +1035,6 @@ void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int ch return PackNHWCToNCHWFp32(src, dst, batch, channel, plane); } -void MatrixPackUnit(const float *src, float *dst, size_t row, size_t col, size_t src_stride, size_t dst_stride) { - size_t copy_size = row * C4NUM * sizeof(float); - for (int c = 0; c < col; c++) { - memcpy(dst + c * dst_stride, src + c * src_stride, copy_size); - } -} - -void MatrixPack(const float *src, float *dst, int row, int ic4, int stride) { - int row4mod = row % 4; - int row4div = row / 4; - - for (int i = 0; i < row4div; i++) { - MatrixPackUnit(src + i * 4 * 4, dst + i * 4 * ic4 * 4, 4, ic4, stride, 16); - } - - if (row4mod > 0) { - MatrixPackUnit(src + row4div * 4 * 4, dst + row4div * 4 * ic4 * 4, row4mod, ic4, stride, row4mod * 4); - } - return; -} - void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param) { int input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_; int ic4 = UP_DIV(conv_param->input_channel_, C4NUM); diff --git a/mindspore/lite/nnacl/pack.h b/mindspore/lite/nnacl/pack.h index d4178d74f0..8674f9c770 100644 --- a/mindspore/lite/nnacl/pack.h +++ b/mindspore/lite/nnacl/pack.h @@ -46,11 +46,6 @@ void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParam void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, int32_t *filter_zp, ConvParameter *conv_param); -void PackInputSum8x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel, - size_t plane_size, ConvParameter *conv_param); - -void MatrixPack(const float *src, float *dst, int row, int ic4, int stride); - void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param); void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel); @@ -75,20 +70,10 @@ void PackNCHWToNHWCFp32(const void *src, void *dst, int batch, int plane, int ch void PackNHWC4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); -void PackNCHWToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel); - void PackNC4HW4ToNHWC4Fp32(const void *src, void *dst, int batch, int plane, int channel); void PackNC4HW4ToNHWCFp32(const void *src, void *dst, int batch, int plane, int channel); -void PackNC4HW4ToNHWCReluFp32(const void *src, void *dst, int batch, int plane, int channel); - -void PackNC4HW4ToNHWCRelu6Fp32(const void *src, void *dst, int batch, int plane, int channel); - -void PackNC4HW4ToNHWCPreluFp32(const void *src, void *dst, const void *slope, int batch, int plane, int channel); - -void PackNC4HW4ToNCHWFp32(const void *src, void *dst, int batch, int plane, int channel); - void PackNHWCToC8HWN8Fp32(const void *src, void *dst, int batch, int plane, int channel); void PackNHWCToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel); @@ -99,18 +84,10 @@ void PackNHWCToNHWC8Int8(const void *src, void *dst, int batch, int plane, int c void PackNHWC8ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); -void PackNCHWToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel); - -void PackNC4HW4ToNHWC4Int8(const void *src, void *dst, int batch, int plane, int channel); - void PackNC4HW4ToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); -void PackNC4HW4ToNCHWInt8(const void *src, void *dst, int batch, int plane, int channel); - void PackNHWCToC8HWN8Int8(const void *src, void *dst, int batch, int plane, int channel); -void PackNHWCToNC8HW8Int8(const void *src, void *dst, int batch, int plane, int channel); - void PackNCHWToNHWCInt8(const void *src, void *dst, int batch, int plane, int channel); void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param); diff --git a/mindspore/lite/nnacl/winograd_transform.c b/mindspore/lite/nnacl/winograd_transform.c index 6b185395c0..42967768f9 100644 --- a/mindspore/lite/nnacl/winograd_transform.c +++ b/mindspore/lite/nnacl/winograd_transform.c @@ -42,7 +42,7 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float * int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s); int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s); - int dst_plane_offset = c * C4NUM * ic4; + int dst_plane_offset = c * in_channel; for (int ic = 0; ic < ic4; ic++) { // clear tmp buffer memset(tmp_data, 0, input_unit * input_unit * C4NUM * sizeof(float)); @@ -91,9 +91,9 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float * const int tile_num = 12; #endif int dst_ic4_offset = dst_plane_offset + ic * C4NUM; - size_t dst_step = tile_num * ic4 * C4NUM; + size_t dst_step = tile_num * in_channel; float *trans_input_ptr = trans_input + dst_ic4_offset; - func(tmp_data, trans_input_ptr, C4NUM, dst_step); + func(tmp_data, trans_input_ptr, C4NUM, dst_step, real_c); } out_tile_index++; } // cal_tile_num loop diff --git a/mindspore/lite/nnacl/winograd_utils.c b/mindspore/lite/nnacl/winograd_utils.c index b9010b7956..117c5a77da 100644 --- a/mindspore/lite/nnacl/winograd_utils.c +++ b/mindspore/lite/nnacl/winograd_utils.c @@ -171,227 +171,241 @@ void GeneralOutputTransformUnit(const float *src_data, float *dst_data, const fl InputTransFunc GetInputTransFunc(int input_unit) { return InputTransFuncList[input_unit]; } -void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step) { +void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { #ifdef ENABLE_ARM - float32x4_t src[16]; - float32x4_t t[16]; - float32x4_t m[16]; - Load16Data; - for (int l = 0; l < 4; ++l) { - int offset = l * 4; - t[l] = vsubq_f32(src[offset], src[2 + offset]); - t[4 + l] = vaddq_f32(src[1 + offset], src[2 + offset]); - t[8 + l] = vsubq_f32(src[2 + offset], src[1 + offset]); - t[12 + l] = vsubq_f32(src[3 + offset], src[1 + offset]); - } - for (int l = 0; l < 4; ++l) { - int offset = l * 4; - m[l] = vsubq_f32(t[offset], t[2 + offset]); - m[4 + l] = vaddq_f32(t[1 + offset], t[2 + offset]); - m[8 + l] = vsubq_f32(t[2 + offset], t[1 + offset]); - m[12 + l] = vsubq_f32(t[3 + offset], t[1 + offset]); - } - for (int i = 0; i < 16; i++) { - vst1q_f32(dst_data + i * dst_step, m[i]); - } -#else - float src[16]; - float t[16]; - float m[16]; - for (int i = 0; i < C4NUM; ++i) { - for (int j = 0; j < 16; ++j) { - src[j] = src_data[i + j * src_step]; - } + if (real_c == 4) { + float32x4_t src[16]; + float32x4_t t[16]; + float32x4_t m[16]; + Load16Data; for (int l = 0; l < 4; ++l) { int offset = l * 4; - t[l] = src[offset] - src[2 + offset]; - t[4 + l] = src[1 + offset] + src[2 + offset]; - t[8 + l] = src[2 + offset] - src[1 + offset]; - t[12 + l] = src[3 + offset] - src[1 + offset]; + t[l] = vsubq_f32(src[offset], src[2 + offset]); + t[4 + l] = vaddq_f32(src[1 + offset], src[2 + offset]); + t[8 + l] = vsubq_f32(src[2 + offset], src[1 + offset]); + t[12 + l] = vsubq_f32(src[3 + offset], src[1 + offset]); } for (int l = 0; l < 4; ++l) { int offset = l * 4; - m[l] = t[offset] - t[2 + offset]; - m[4 + l] = t[1 + offset] + t[2 + offset]; - m[8 + l] = t[2 + offset] - t[1 + offset]; - m[12 + l] = t[3 + offset] - t[1 + offset]; + m[l] = vsubq_f32(t[offset], t[2 + offset]); + m[4 + l] = vaddq_f32(t[1 + offset], t[2 + offset]); + m[8 + l] = vsubq_f32(t[2 + offset], t[1 + offset]); + m[12 + l] = vsubq_f32(t[3 + offset], t[1 + offset]); } - for (int k = 0; k < 16; ++k) { - dst_data[i + k * dst_step] = m[k]; + for (int i = 0; i < 16; i++) { + vst1q_f32(dst_data + i * dst_step, m[i]); } + } else { +#endif + float src[16]; + float t[16]; + float m[16]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 16; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = src[offset] - src[2 + offset]; + t[4 + l] = src[1 + offset] + src[2 + offset]; + t[8 + l] = src[2 + offset] - src[1 + offset]; + t[12 + l] = src[3 + offset] - src[1 + offset]; + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = t[offset] - t[2 + offset]; + m[4 + l] = t[1 + offset] + t[2 + offset]; + m[8 + l] = t[2 + offset] - t[1 + offset]; + m[12 + l] = t[3 + offset] - t[1 + offset]; + } + for (int k = 0; k < 16; ++k) { + dst_data[i + k * dst_step] = m[k]; + } + } +#ifdef ENABLE_ARM } #endif } -void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step) { +void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { #ifdef ENABLE_ARM - float32x4_t src[36]; - float32x4_t t[36]; - float32x4_t m[36]; - Load36Data; - for (int l = 0; l < 6; ++l) { - int offset = l * 6; - float32x4_t tmp1 = vsubq_f32(src[3 + offset], src[1 + offset]); - float32x4_t tmp2 = vsubq_f32(src[4 + offset], src[2 + offset]); - t[l] = vaddq_f32(vsubq_f32(vmulq_n_f32(src[offset], 4), vmulq_n_f32(src[2 + offset], 5)), src[4 + offset]); - t[6 + l] = vaddq_f32(vmulq_n_f32(vaddq_f32(src[1 + offset], src[2 + offset]), -4), - vaddq_f32(src[3 + offset], src[4 + offset])); - t[12 + l] = vaddq_f32(vmulq_n_f32(vsubq_f32(src[1 + offset], src[2 + offset]), 4), - vsubq_f32(src[4 + offset], src[3 + offset])); - t[18 + l] = vaddq_f32(vmulq_n_f32(tmp1, 2), tmp2); - t[24 + l] = vaddq_f32(vmulq_n_f32(tmp1, -2), tmp2); - t[30 + l] = vaddq_f32(vsubq_f32(vmulq_n_f32(src[1 + offset], 4), vmulq_n_f32(src[3 + offset], 5)), src[5 + offset]); - } - for (int l = 0; l < 6; ++l) { - int offset = l * 6; - float32x4_t tmp1 = vsubq_f32(t[3 + offset], t[1 + offset]); - float32x4_t tmp2 = vsubq_f32(t[4 + offset], t[2 + offset]); - m[l] = vaddq_f32(vsubq_f32(vmulq_n_f32(t[offset], 4), vmulq_n_f32(t[2 + offset], 5)), t[4 + offset]); - m[6 + l] = - vaddq_f32(vmulq_n_f32(vaddq_f32(t[1 + offset], t[2 + offset]), -4), vaddq_f32(t[3 + offset], t[4 + offset])); - m[12 + l] = - vaddq_f32(vmulq_n_f32(vsubq_f32(t[1 + offset], t[2 + offset]), 4), vsubq_f32(t[4 + offset], t[3 + offset])); - m[18 + l] = vaddq_f32(vmulq_n_f32(tmp1, 2), tmp2); - m[24 + l] = vaddq_f32(vmulq_n_f32(tmp1, -2), tmp2); - m[30 + l] = vaddq_f32(vsubq_f32(vmulq_n_f32(t[1 + offset], 4), vmulq_n_f32(t[3 + offset], 5)), t[5 + offset]); - } - for (int i = 0; i < 36; i++) { - vst1q_f32(dst_data + i * dst_step, m[i]); - } -#else - float src[36]; - float t[36]; - float m[36]; - for (int i = 0; i < C4NUM; ++i) { - for (int j = 0; j < 36; ++j) { - src[j] = src_data[i + j * src_step]; - } + if (real_c == 4) { + float32x4_t src[36]; + float32x4_t t[36]; + float32x4_t m[36]; + Load36Data; for (int l = 0; l < 6; ++l) { int offset = l * 6; - float tmp1 = src[3 + offset] - src[1 + offset]; - float tmp2 = src[4 + offset] - src[2 + offset]; - t[l] = 4 * src[offset] - 5 * src[2 + offset] + src[4 + offset]; - t[6 + l] = -4 * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]); - t[12 + l] = 4 * (src[1 + offset] - src[2 + offset]) + (src[4 + offset] - src[3 + offset]); - t[18 + l] = 2 * tmp1 + tmp2; - t[24 + l] = -2 * tmp1 + tmp2; - t[30 + l] = 4 * src[1 + offset] - 5 * src[3 + offset] + src[5 + offset]; + float32x4_t tmp1 = vsubq_f32(src[3 + offset], src[1 + offset]); + float32x4_t tmp2 = vsubq_f32(src[4 + offset], src[2 + offset]); + t[l] = vaddq_f32(vsubq_f32(vmulq_n_f32(src[offset], 4), vmulq_n_f32(src[2 + offset], 5)), src[4 + offset]); + t[6 + l] = vaddq_f32(vmulq_n_f32(vaddq_f32(src[1 + offset], src[2 + offset]), -4), + vaddq_f32(src[3 + offset], src[4 + offset])); + t[12 + l] = vaddq_f32(vmulq_n_f32(vsubq_f32(src[1 + offset], src[2 + offset]), 4), + vsubq_f32(src[4 + offset], src[3 + offset])); + t[18 + l] = vaddq_f32(vmulq_n_f32(tmp1, 2), tmp2); + t[24 + l] = vaddq_f32(vmulq_n_f32(tmp1, -2), tmp2); + t[30 + l] = + vaddq_f32(vsubq_f32(vmulq_n_f32(src[1 + offset], 4), vmulq_n_f32(src[3 + offset], 5)), src[5 + offset]); } for (int l = 0; l < 6; ++l) { int offset = l * 6; - float tmp1 = t[3 + offset] - t[1 + offset]; - float tmp2 = t[4 + offset] - t[2 + offset]; - m[l] = 4 * t[offset] - 5 * t[2 + offset] + t[4 + offset]; - m[6 + l] = -4 * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]); - m[12 + l] = 4 * (t[1 + offset] - t[2 + offset]) + (t[4 + offset] - t[3 + offset]); - m[18 + l] = 2 * tmp1 + tmp2; - m[24 + l] = -2 * tmp1 + tmp2; - m[30 + l] = 4 * t[1 + offset] - 5 * t[3 + offset] + t[5 + offset]; - } - for (int k = 0; k < 36; ++k) { - dst_data[i + k * dst_step] = m[k]; + float32x4_t tmp1 = vsubq_f32(t[3 + offset], t[1 + offset]); + float32x4_t tmp2 = vsubq_f32(t[4 + offset], t[2 + offset]); + m[l] = vaddq_f32(vsubq_f32(vmulq_n_f32(t[offset], 4), vmulq_n_f32(t[2 + offset], 5)), t[4 + offset]); + m[6 + l] = + vaddq_f32(vmulq_n_f32(vaddq_f32(t[1 + offset], t[2 + offset]), -4), vaddq_f32(t[3 + offset], t[4 + offset])); + m[12 + l] = + vaddq_f32(vmulq_n_f32(vsubq_f32(t[1 + offset], t[2 + offset]), 4), vsubq_f32(t[4 + offset], t[3 + offset])); + m[18 + l] = vaddq_f32(vmulq_n_f32(tmp1, 2), tmp2); + m[24 + l] = vaddq_f32(vmulq_n_f32(tmp1, -2), tmp2); + m[30 + l] = vaddq_f32(vsubq_f32(vmulq_n_f32(t[1 + offset], 4), vmulq_n_f32(t[3 + offset], 5)), t[5 + offset]); + } + for (int i = 0; i < 36; i++) { + vst1q_f32(dst_data + i * dst_step, m[i]); } + } else { +#endif + float src[36]; + float t[36]; + float m[36]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 36; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float tmp1 = src[3 + offset] - src[1 + offset]; + float tmp2 = src[4 + offset] - src[2 + offset]; + t[l] = 4 * src[offset] - 5 * src[2 + offset] + src[4 + offset]; + t[6 + l] = -4 * (src[1 + offset] + src[2 + offset]) + (src[3 + offset] + src[4 + offset]); + t[12 + l] = 4 * (src[1 + offset] - src[2 + offset]) + (src[4 + offset] - src[3 + offset]); + t[18 + l] = 2 * tmp1 + tmp2; + t[24 + l] = -2 * tmp1 + tmp2; + t[30 + l] = 4 * src[1 + offset] - 5 * src[3 + offset] + src[5 + offset]; + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float tmp1 = t[3 + offset] - t[1 + offset]; + float tmp2 = t[4 + offset] - t[2 + offset]; + m[l] = 4 * t[offset] - 5 * t[2 + offset] + t[4 + offset]; + m[6 + l] = -4 * (t[1 + offset] + t[2 + offset]) + (t[3 + offset] + t[4 + offset]); + m[12 + l] = 4 * (t[1 + offset] - t[2 + offset]) + (t[4 + offset] - t[3 + offset]); + m[18 + l] = 2 * tmp1 + tmp2; + m[24 + l] = -2 * tmp1 + tmp2; + m[30 + l] = 4 * t[1 + offset] - 5 * t[3 + offset] + t[5 + offset]; + } + for (int k = 0; k < 36; ++k) { + dst_data[i + k * dst_step] = m[k]; + } + } +#ifdef ENABLE_ARM } #endif } -void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step) { +void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c) { #ifdef ENABLE_ARM - float32x4_t src[64]; - float32x4_t t[64]; - float32x4_t m[64]; - Load64Data; - for (int l = 0; l < 8; ++l) { - int offset = l * 8; - t[l] = vsubq_f32(vaddq_f32(vsubq_f32(vmulq_n_f32(src[offset], 0.5625), vmulq_n_f32(src[2 + offset], 3.0625)), - vmulq_n_f32(src[4 + offset], 3.5)), - src[6 + offset]); - float32x4_t tmp1 = vaddq_f32(vmulq_n_f32(src[1 + offset], 1.125), vmulq_n_f32(src[5 + offset], 0.5)); - float32x4_t tmp2 = vsubq_f32(vmulq_n_f32(src[2 + offset], 2.25), vmulq_n_f32(src[4 + offset], 3.25)); - t[8 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(src[3 + offset], 1.625)), src[6 + offset]); - t[16 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(src[3 + offset], 1.625)), src[6 + offset]); - tmp1 = vaddq_f32(vmulq_n_f32(src[1 + offset], 0.5625), src[5 + offset]); - tmp2 = vsubq_f32(vmulq_n_f32(src[2 + offset], 0.5625), vmulq_n_f32(src[4 + offset], 2.5)); - t[24 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(src[3 + offset], 2.5)), src[6 + offset]); - t[32 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(src[3 + offset], 2.5)), src[6 + offset]); - tmp1 = vaddq_f32(vmulq_n_f32(src[1 + offset], 0.375), vmulq_n_f32(src[5 + offset], 1.5)); - tmp2 = vsubq_f32(vmulq_n_f32(src[2 + offset], 0.25), vmulq_n_f32(src[4 + offset], 1.25)); - t[40 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(src[3 + offset], 1.875)), src[6 + offset]); - t[48 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(src[3 + offset], 1.875)), src[6 + offset]); - t[56 + l] = - vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src[1 + offset], -0.5625), vmulq_n_f32(src[3 + offset], 3.0625)), - vmulq_n_f32(src[5 + offset], 3.5)), - src[7 + offset]); - } - for (int l = 0; l < 8; ++l) { - int offset = l * 8; - m[l] = vsubq_f32(vaddq_f32(vsubq_f32(vmulq_n_f32(t[offset], 0.5625), vmulq_n_f32(t[2 + offset], 3.0625)), - vmulq_n_f32(t[4 + offset], 3.5)), - t[6 + offset]); - float32x4_t tmp1 = vaddq_f32(vmulq_n_f32(t[1 + offset], 1.125), vmulq_n_f32(t[5 + offset], 0.5)); - float32x4_t tmp2 = vsubq_f32(vmulq_n_f32(t[2 + offset], 2.25), vmulq_n_f32(t[4 + offset], 3.25)); - m[8 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(t[3 + offset], 1.625)), t[6 + offset]); - m[16 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(t[3 + offset], 1.625)), t[6 + offset]); - tmp1 = vaddq_f32(vmulq_n_f32(t[1 + offset], 0.5625), t[5 + offset]); - tmp2 = vsubq_f32(vmulq_n_f32(t[2 + offset], 0.5625), vmulq_n_f32(t[4 + offset], 2.5)); - m[24 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(t[3 + offset], 2.5)), t[6 + offset]); - m[32 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(t[3 + offset], 2.5)), t[6 + offset]); - tmp1 = vaddq_f32(vmulq_n_f32(t[1 + offset], 0.375), vmulq_n_f32(t[5 + offset], 1.5)); - tmp2 = vsubq_f32(vmulq_n_f32(t[2 + offset], 0.25), vmulq_n_f32(t[4 + offset], 1.25)); - m[40 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(t[3 + offset], 1.875)), t[6 + offset]); - m[48 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(t[3 + offset], 1.875)), t[6 + offset]); - m[56 + l] = vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t[1 + offset], -0.5625), vmulq_n_f32(t[3 + offset], 3.0625)), - vmulq_n_f32(t[5 + offset], 3.5)), - t[7 + offset]); - } - for (int i = 0; i < 64; i++) { - vst1q_f32(dst_data + i * dst_step, m[i]); - } -#else - float src[64]; - float t[64]; - float m[64]; - for (int i = 0; i < C4NUM; ++i) { - for (int j = 0; j < 64; ++j) { - src[j] = src_data[i + j * src_step]; - } + if (real_c == 4) { + float32x4_t src[64]; + float32x4_t t[64]; + float32x4_t m[64]; + Load64Data; for (int l = 0; l < 8; ++l) { int offset = l * 8; - t[l] = 0.5625f * src[offset] - 3.0625f * src[2 + offset] + 3.5f * src[4 + offset] - src[6 + offset]; - float tmp1 = 1.125f * src[1 + offset] + 0.5f * src[5 + offset]; - float tmp2 = 2.25f * src[2 + offset] - 3.25f * src[4 + offset]; - t[8 + l] = tmp1 + tmp2 - 1.625f * src[3 + offset] + src[6 + offset]; - t[16 + l] = tmp2 - tmp1 + 1.625f * src[3 + offset] + src[6 + offset]; - tmp1 = 0.5625f * src[1 + offset] + src[5 + offset]; - tmp2 = 0.5625f * src[2 + offset] - 2.5f * src[4 + offset]; - t[24 + l] = tmp1 + tmp2 - 2.5f * src[3 + offset] + src[6 + offset]; - t[32 + l] = tmp2 - tmp1 + 2.5f * src[3 + offset] + src[6 + offset]; - tmp1 = 0.375f * src[1 + offset] + 1.5f * src[5 + offset]; - tmp2 = 0.25f * src[2 + offset] - 1.25f * src[4 + offset]; - t[40 + l] = tmp1 + tmp2 - 1.875f * src[3 + offset] + src[6 + offset]; - t[48 + l] = tmp2 - tmp1 + 1.875f * src[3 + offset] + src[6 + offset]; - t[56 + l] = -0.5625f * src[1 + offset] + 3.0625f * src[3 + offset] - 3.5f * src[5 + offset] + src[7 + offset]; + t[l] = vsubq_f32(vaddq_f32(vsubq_f32(vmulq_n_f32(src[offset], 0.5625), vmulq_n_f32(src[2 + offset], 3.0625)), + vmulq_n_f32(src[4 + offset], 3.5)), + src[6 + offset]); + float32x4_t tmp1 = vaddq_f32(vmulq_n_f32(src[1 + offset], 1.125), vmulq_n_f32(src[5 + offset], 0.5)); + float32x4_t tmp2 = vsubq_f32(vmulq_n_f32(src[2 + offset], 2.25), vmulq_n_f32(src[4 + offset], 3.25)); + t[8 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(src[3 + offset], 1.625)), src[6 + offset]); + t[16 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(src[3 + offset], 1.625)), src[6 + offset]); + tmp1 = vaddq_f32(vmulq_n_f32(src[1 + offset], 0.5625), src[5 + offset]); + tmp2 = vsubq_f32(vmulq_n_f32(src[2 + offset], 0.5625), vmulq_n_f32(src[4 + offset], 2.5)); + t[24 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(src[3 + offset], 2.5)), src[6 + offset]); + t[32 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(src[3 + offset], 2.5)), src[6 + offset]); + tmp1 = vaddq_f32(vmulq_n_f32(src[1 + offset], 0.375), vmulq_n_f32(src[5 + offset], 1.5)); + tmp2 = vsubq_f32(vmulq_n_f32(src[2 + offset], 0.25), vmulq_n_f32(src[4 + offset], 1.25)); + t[40 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(src[3 + offset], 1.875)), src[6 + offset]); + t[48 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(src[3 + offset], 1.875)), src[6 + offset]); + t[56 + l] = + vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(src[1 + offset], -0.5625), vmulq_n_f32(src[3 + offset], 3.0625)), + vmulq_n_f32(src[5 + offset], 3.5)), + src[7 + offset]); } for (int l = 0; l < 8; ++l) { int offset = l * 8; - m[l] = 0.5625f * t[offset] - 3.0625f * t[2 + offset] + 3.5f * t[4 + offset] - t[6 + offset]; - float tmp1 = 1.125f * t[1 + offset] + 0.5f * t[5 + offset]; - float tmp2 = 2.25f * t[2 + offset] - 3.25f * t[4 + offset]; - m[8 + l] = tmp1 + tmp2 - 1.625f * t[3 + offset] + t[6 + offset]; - m[16 + l] = tmp2 - tmp1 + 1.625f * t[3 + offset] + t[6 + offset]; - tmp1 = 0.5625f * t[1 + offset] + t[5 + offset]; - tmp2 = 0.5625f * t[2 + offset] - 2.5f * t[4 + offset]; - m[24 + l] = tmp1 + tmp2 - 2.5f * t[3 + offset] + t[6 + offset]; - m[32 + l] = tmp2 - tmp1 + 2.5f * t[3 + offset] + t[6 + offset]; - tmp1 = 0.375f * t[1 + offset] + 1.5f * t[5 + offset]; - tmp2 = 0.25f * t[2 + offset] - 1.25f * t[4 + offset]; - m[40 + l] = tmp1 + tmp2 - 1.875f * t[3 + offset] + t[6 + offset]; - m[48 + l] = tmp2 - tmp1 + 1.875f * t[3 + offset] + t[6 + offset]; - m[56 + l] = -0.5625f * t[1 + offset] + 3.0625f * t[3 + offset] - 3.5f * t[5 + offset] + t[7 + offset]; - } - for (int k = 0; k < 64; ++k) { - dst_data[i + k * dst_step] = m[k]; + m[l] = vsubq_f32(vaddq_f32(vsubq_f32(vmulq_n_f32(t[offset], 0.5625), vmulq_n_f32(t[2 + offset], 3.0625)), + vmulq_n_f32(t[4 + offset], 3.5)), + t[6 + offset]); + float32x4_t tmp1 = vaddq_f32(vmulq_n_f32(t[1 + offset], 1.125), vmulq_n_f32(t[5 + offset], 0.5)); + float32x4_t tmp2 = vsubq_f32(vmulq_n_f32(t[2 + offset], 2.25), vmulq_n_f32(t[4 + offset], 3.25)); + m[8 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(t[3 + offset], 1.625)), t[6 + offset]); + m[16 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(t[3 + offset], 1.625)), t[6 + offset]); + tmp1 = vaddq_f32(vmulq_n_f32(t[1 + offset], 0.5625), t[5 + offset]); + tmp2 = vsubq_f32(vmulq_n_f32(t[2 + offset], 0.5625), vmulq_n_f32(t[4 + offset], 2.5)); + m[24 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(t[3 + offset], 2.5)), t[6 + offset]); + m[32 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(t[3 + offset], 2.5)), t[6 + offset]); + tmp1 = vaddq_f32(vmulq_n_f32(t[1 + offset], 0.375), vmulq_n_f32(t[5 + offset], 1.5)); + tmp2 = vsubq_f32(vmulq_n_f32(t[2 + offset], 0.25), vmulq_n_f32(t[4 + offset], 1.25)); + m[40 + l] = vaddq_f32(vsubq_f32(vaddq_f32(tmp1, tmp2), vmulq_n_f32(t[3 + offset], 1.875)), t[6 + offset]); + m[48 + l] = vaddq_f32(vaddq_f32(vsubq_f32(tmp2, tmp1), vmulq_n_f32(t[3 + offset], 1.875)), t[6 + offset]); + m[56 + l] = + vaddq_f32(vsubq_f32(vaddq_f32(vmulq_n_f32(t[1 + offset], -0.5625), vmulq_n_f32(t[3 + offset], 3.0625)), + vmulq_n_f32(t[5 + offset], 3.5)), + t[7 + offset]); + } + for (int i = 0; i < 64; i++) { + vst1q_f32(dst_data + i * dst_step, m[i]); + } + } else { +#endif + float src[64]; + float t[64]; + float m[64]; + for (int i = 0; i < real_c; ++i) { + for (int j = 0; j < 64; ++j) { + src[j] = src_data[i + j * src_step]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = 0.5625f * src[offset] - 3.0625f * src[2 + offset] + 3.5f * src[4 + offset] - src[6 + offset]; + float tmp1 = 1.125f * src[1 + offset] + 0.5f * src[5 + offset]; + float tmp2 = 2.25f * src[2 + offset] - 3.25f * src[4 + offset]; + t[8 + l] = tmp1 + tmp2 - 1.625f * src[3 + offset] + src[6 + offset]; + t[16 + l] = tmp2 - tmp1 + 1.625f * src[3 + offset] + src[6 + offset]; + tmp1 = 0.5625f * src[1 + offset] + src[5 + offset]; + tmp2 = 0.5625f * src[2 + offset] - 2.5f * src[4 + offset]; + t[24 + l] = tmp1 + tmp2 - 2.5f * src[3 + offset] + src[6 + offset]; + t[32 + l] = tmp2 - tmp1 + 2.5f * src[3 + offset] + src[6 + offset]; + tmp1 = 0.375f * src[1 + offset] + 1.5f * src[5 + offset]; + tmp2 = 0.25f * src[2 + offset] - 1.25f * src[4 + offset]; + t[40 + l] = tmp1 + tmp2 - 1.875f * src[3 + offset] + src[6 + offset]; + t[48 + l] = tmp2 - tmp1 + 1.875f * src[3 + offset] + src[6 + offset]; + t[56 + l] = -0.5625f * src[1 + offset] + 3.0625f * src[3 + offset] - 3.5f * src[5 + offset] + src[7 + offset]; + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = 0.5625f * t[offset] - 3.0625f * t[2 + offset] + 3.5f * t[4 + offset] - t[6 + offset]; + float tmp1 = 1.125f * t[1 + offset] + 0.5f * t[5 + offset]; + float tmp2 = 2.25f * t[2 + offset] - 3.25f * t[4 + offset]; + m[8 + l] = tmp1 + tmp2 - 1.625f * t[3 + offset] + t[6 + offset]; + m[16 + l] = tmp2 - tmp1 + 1.625f * t[3 + offset] + t[6 + offset]; + tmp1 = 0.5625f * t[1 + offset] + t[5 + offset]; + tmp2 = 0.5625f * t[2 + offset] - 2.5f * t[4 + offset]; + m[24 + l] = tmp1 + tmp2 - 2.5f * t[3 + offset] + t[6 + offset]; + m[32 + l] = tmp2 - tmp1 + 2.5f * t[3 + offset] + t[6 + offset]; + tmp1 = 0.375f * t[1 + offset] + 1.5f * t[5 + offset]; + tmp2 = 0.25f * t[2 + offset] - 1.25f * t[4 + offset]; + m[40 + l] = tmp1 + tmp2 - 1.875f * t[3 + offset] + t[6 + offset]; + m[48 + l] = tmp2 - tmp1 + 1.875f * t[3 + offset] + t[6 + offset]; + m[56 + l] = -0.5625f * t[1 + offset] + 3.0625f * t[3 + offset] - 3.5f * t[5 + offset] + t[7 + offset]; + } + for (int k = 0; k < 64; ++k) { + dst_data[i + k * dst_step] = m[k]; + } } +#ifdef ENABLE_ARM } #endif } diff --git a/mindspore/lite/nnacl/winograd_utils.h b/mindspore/lite/nnacl/winograd_utils.h index 8e7e8745d7..f9bdece471 100644 --- a/mindspore/lite/nnacl/winograd_utils.h +++ b/mindspore/lite/nnacl/winograd_utils.h @@ -28,7 +28,7 @@ #ifdef __cplusplus extern "C" { #endif -typedef void (*InputTransFunc)(const float *src_data, float *dst_data, int src_step, int dst_step); +typedef void (*InputTransFunc)(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); typedef void (*OutputTransFunc)(const float *src_data, float *dst_data, const float *bias_data, int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); @@ -163,11 +163,11 @@ void GeneralOutputTransformUnit(const float *src_data, float *dst_data, const fl InputTransFunc GetInputTransFunc(int input_unit); -void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step); +void InputTransform4x4Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); -void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step); +void InputTransform6x6Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); -void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step); +void InputTransform8x8Unit(const float *src_data, float *dst_data, int src_step, int dst_step, int real_c); OutputTransFunc GetOutputTransFunc(int input_unit, int output_unit, ActType act_type); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc index b1fdd2e48d..930d14dab9 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc @@ -39,21 +39,18 @@ int ConvolutionWinogradCPUKernel::WinogradFilterTransform(const float *weight_da // original weight format : ohwi auto channel_in = conv_param_->input_channel_; auto channel_out = conv_param_->output_channel_; - int ic4 = UP_DIV(channel_in, C4NUM); int oc_block_num = UP_DIV(channel_out, oc_block); - int c4_channel = ic4 * C4NUM; - int block_stride = c4_channel * oc_block; + int block_stride = channel_in * oc_block; int block_num_stride = block_stride * oc_block_num; // trans_filter = G*g*GT (g represents weight_data) // separate into two steps ===> tmp = (g * GT)T ===> trans = (tmp * GT)T use same function:MatrixMultiplyWinograd - auto tmp_data = reinterpret_cast(malloc(c4_channel * input_unit_ * kernel_unit_ * sizeof(float))); + auto tmp_data = reinterpret_cast(malloc(channel_in * input_unit_ * kernel_unit_ * sizeof(float))); if (tmp_data == nullptr) { MS_LOG(ERROR) << "malloc tmp_data failed."; return RET_MEMORY_FAILED; } - memset(tmp_data, 0, c4_channel * input_unit_ * kernel_unit_ * sizeof(float)); - auto trans_out_data = reinterpret_cast(malloc(c4_channel * input_unit_ * input_unit_ * sizeof(float))); + auto trans_out_data = reinterpret_cast(malloc(channel_in * input_unit_ * input_unit_ * sizeof(float))); if (trans_out_data == nullptr) { free(tmp_data); MS_LOG(ERROR) << "malloc trans_out_data failed."; @@ -61,14 +58,14 @@ int ConvolutionWinogradCPUKernel::WinogradFilterTransform(const float *weight_da } #ifndef ENABLE_ARM64 - auto tmp_data1 = reinterpret_cast(malloc(c4_channel * input_unit_ * kernel_unit_ * sizeof(float))); + auto tmp_data1 = reinterpret_cast(malloc(channel_in * input_unit_ * kernel_unit_ * sizeof(float))); if (tmp_data1 == nullptr) { free(tmp_data); free(trans_out_data); MS_LOG(ERROR) << "malloc tmp_data1 failed."; return RET_MEMORY_FAILED; } - auto trans_out_data1 = reinterpret_cast(malloc(c4_channel * input_unit_ * input_unit_ * sizeof(float))); + auto trans_out_data1 = reinterpret_cast(malloc(channel_in * input_unit_ * input_unit_ * sizeof(float))); if (trans_out_data1 == nullptr) { free(tmp_data); free(tmp_data1); @@ -87,30 +84,30 @@ int ConvolutionWinogradCPUKernel::WinogradFilterTransform(const float *weight_da #ifndef ENABLE_ARM64 // tmp_data = g * GT MatrixMultiplyWinograd(weight_data + i * input_oz_offset, matrix_gt, tmp_data, kernel_unit_, kernel_unit_, - input_unit_, channel_in, c4_channel * 4); + input_unit_, channel_in, channel_in * 4); // tmp_data1 = (tmp_data)T - PackHWCToWHC(tmp_data, tmp_data1, kernel_unit_, input_unit_, c4_channel); + PackHWCToWHC(tmp_data, tmp_data1, kernel_unit_, input_unit_, channel_in); // trans_out_data1 = tmp * GT - MatrixMultiplyWinograd(tmp_data1, matrix_gt, trans_out_data1, input_unit_, kernel_unit_, input_unit_, c4_channel, - c4_channel * 4); + MatrixMultiplyWinograd(tmp_data1, matrix_gt, trans_out_data1, input_unit_, kernel_unit_, input_unit_, channel_in, + channel_in * 4); // trans_out_data = (trans_out_data1)T - PackHWCToWHC(trans_out_data1, trans_out_data, input_unit_, input_unit_, c4_channel); + PackHWCToWHC(trans_out_data1, trans_out_data, input_unit_, input_unit_, channel_in); #else // tmp = (g * GT)T MatrixMultiplyWinograd(weight_data + i * input_oz_offset, matrix_gt, tmp_data, kernel_unit_, kernel_unit_, - input_unit_, channel_in, c4_channel * 4); + input_unit_, channel_in, channel_in * 4); // trans = (tmp * GT)T - MatrixMultiplyWinograd(tmp_data, matrix_gt, trans_out_data, input_unit_, kernel_unit_, input_unit_, c4_channel, - c4_channel * 4); + MatrixMultiplyWinograd(tmp_data, matrix_gt, trans_out_data, input_unit_, kernel_unit_, input_unit_, channel_in, + channel_in * 4); #endif int in_offset = 0; for (int j = 0; j < input_unit_; ++j) { for (int k = 0; k < input_unit_; ++k) { - for (int c = 0; c < c4_channel; ++c) { + for (int c = 0; c < channel_in; ++c) { *(trans_weight_ + output_oz_offset + c * oc_block) = trans_out_data[in_offset + c]; } - in_offset += c4_channel; + in_offset += channel_in; output_oz_offset += block_num_stride; } } @@ -128,7 +125,6 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() { auto filter_tensor = in_tensors_.at(kWeightIndex); int in_channel = filter_tensor->Channel(); int out_channel = filter_tensor->Batch(); - int ic4 = UP_DIV(in_channel, C4NUM); conv_param_->input_channel_ = in_channel; conv_param_->output_channel_ = out_channel; @@ -137,7 +133,7 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() { int oc_block_num = UP_DIV(out_channel, C8NUM); // set data - auto trans_matrix_data_size = input_unit_ * input_unit_ * ic4 * C4NUM * oc_block_num * oc_block * sizeof(float); + auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * sizeof(float); trans_weight_ = reinterpret_cast(malloc(trans_matrix_data_size)); if (trans_weight_ == nullptr) { MS_LOG(ERROR) << "malloc matrix_buffer failed."; @@ -188,7 +184,6 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() { int ConvolutionWinogradCPUKernel::InitTmpBuffer() { int channel_out = conv_param_->output_channel_; int oc8 = UP_DIV(channel_out, C8NUM); - int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM); #ifdef ENABLE_ARM32 int tile_num = 4; #else @@ -196,7 +191,8 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() { #endif MS_ASSERT(ctx_->allocator != nullptr); - size_t tile_buffer_size = thread_count_ * tile_num * input_unit_ * input_unit_ * ic4 * C4NUM * sizeof(float); + size_t tile_buffer_size = + thread_count_ * tile_num * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float); trans_input_ = reinterpret_cast(ctx_->allocator->Malloc(tile_buffer_size)); if (trans_input_ == nullptr) { MS_LOG(ERROR) << "malloc trans_input_ failed."; @@ -217,8 +213,8 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() { return RET_MEMORY_FAILED; } - col_buffer_ = - reinterpret_cast(ctx_->allocator->Malloc(thread_count_ * tile_num * ic4 * C4NUM * sizeof(float))); + col_buffer_ = reinterpret_cast( + ctx_->allocator->Malloc(thread_count_ * tile_num * conv_param_->input_channel_ * sizeof(float))); if (col_buffer_ == nullptr) { MS_LOG(ERROR) << "malloc col_buffer_ failed."; return RET_ERROR;