diff --git a/mindspore/lite/nnacl/fp16/conv_fp16.c b/mindspore/lite/nnacl/fp16/conv_fp16.c index 4c23bfdc47..fad112f77d 100644 --- a/mindspore/lite/nnacl/fp16/conv_fp16.c +++ b/mindspore/lite/nnacl/fp16/conv_fp16.c @@ -120,208 +120,6 @@ void IndirectGemmFp16_16x8_c8(float16_t *output, float16_t *input, float16_t *we } #endif -void SWBorderPixel(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int height, - int width, int in_kh_step, int in_kw_step, int kernel_h, int kernel_w, int ic, bool is_relu, - bool is_relu6) { - int ic8 = ic / C8NUM; - int ic8_res = ic8 % C8NUM; - int ic4 = ic8_res / C4NUM; - for (int c = 0; c < C4NUM; c++) { - dst[c] = 0; - } - const float16_t *weight_oc = weight; - for (int oc = 0; oc < C4NUM; ++oc) { - const float16_t *weight_kh = weight_oc; - const float16_t *src_kh = src; - for (int kh = 0; kh < height; kh++) { - const float16_t *src_kw = src_kh; - const float16_t *weight_kw = weight_kh; - for (int kw = 0; kw < width; kw++) { - const float16_t *src_ic8 = src_kw; - const float16_t *weight_ic8 = weight_kw; - - for (int rc = 0; rc < ic8; ++rc) { - for (int c = 0; c < C8NUM; c++) { - dst[oc] += src_ic8[c] * weight_ic8[c]; - } - src_ic8 += C8NUM; - weight_ic8 += C8NUM; - } // ic8 loop - - const float16_t *src_ic4 = src_ic8; - const float16_t *weight_ic4 = weight_ic8; - for (int rc = 0; rc < ic4; ++rc) { - for (int c = 0; c < C4NUM; c++) { - dst[oc] += src_ic4[c] * weight_ic4[c]; - } - src_ic4 += C4NUM; - weight_ic4 += C4NUM; - } // ic4 loop - - src_kw += in_kw_step; - weight_kw += ic4 * C4NUM; - } // kernel_w loop - src_kh += in_kh_step; - weight_kh += kernel_w * ic4 * C4NUM; - } // kernel_h loop - dst[oc] += bias[oc]; - dst[oc] = (is_relu) ? (MSMAX(0, dst[oc])) : (dst[oc]); - dst[oc] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst[oc]))) : (dst[oc]); - weight_oc += kernel_h * kernel_w * ic4 * C4NUM; - } // oc loop -} - -void SWBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int top, - int bottom, int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) { - bool relu = conv_param->act_type_ == ActType_Relu; - bool relu6 = conv_param->act_type_ == ActType_Relu6; - float16_t *dst_h = dst + top * sliding->out_h_step_; - for (int oh = top; oh < bottom; oh++) { - int ih = oh * conv_param->stride_h_ - conv_param->pad_u_; - int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_)); - int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_)); - const float16_t *src_h = src + ih * sliding->in_h_step_; - - float16_t *dst_kernel = dst_h + left * sliding->block_channel_; - for (int ow = left; ow < right; ow++) { - int iw = ow * conv_param->stride_w_ - conv_param->pad_l_; - int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_)); - int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_)); - const float16_t *src_w = src_h + iw * sliding->ic4_channel_; - - const float16_t *src_kernel = src_w + start_kh * sliding->in_kh_step_ + start_kw * sliding->in_kw_step_; - const float16_t *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * sliding->ic4_channel_; - - SWBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, - sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_h_, conv_param->kernel_w_, - sliding->ic4_channel_, relu, relu6); - - dst_kernel += sliding->block_channel_; - } // width loop - dst_h += sliding->out_h_step_; - } // height loop -} - -void SWCenterFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int height, - int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int ic, int in_sh_step, - int in_sw_step, int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6) { - int ic8 = ic / C8NUM; - int ic8_res = ic % C8NUM; - int ic4 = ic8_res / C4NUM; - float16_t *dst_h = dst; - const float16_t *src_h = src; - for (int oh = 0; oh < height; oh++) { - float16_t *dst_w = dst_h; - const float16_t *src_w = src_h; - for (int ow = 0; ow < width; ow++) { - const float16_t *weight_oc = weight; - for (int c = 0; c < C4NUM; c++) { - dst_w[c] = 0; - } - - for (int oc = 0; oc < C4NUM; oc++) { - const float16_t *weight_kh = weight_oc; - const float16_t *src_kh = src_w; - for (int kh = 0; kh < kernel_h; kh++) { - const float16_t *src_kw = src_kh; - const float16_t *weight_kw = weight_kh; - for (int kw = 0; kw < kernel_w; kw++) { - const float16_t *src_ic8 = src_kw; - const float16_t *weight_ic8 = weight_kw; - - for (int rc = 0; rc < ic8; ++rc) { - for (int c = 0; c < C8NUM; c++) { - dst_w[oc] += src_ic8[c] * weight_ic8[c]; - } - - src_ic8 += C8NUM; - weight_ic8 += C8NUM; - } // ic8 loop - - const float16_t *src_ic4 = src_ic8; - const float16_t *weight_ic4 = weight_ic8; - for (int rc = 0; rc < ic4; ++rc) { - for (int c = 0; c < C4NUM; c++) { - dst_w[oc] += src_ic4[c] * weight_ic4[c]; - } - - src_ic4 += C4NUM; - weight_ic4 += C4NUM; - } // ic4 loop - - src_kw += in_kw_step; - weight_kw += ic4 * C4NUM; - } // kernel_w loop - src_kh += in_kh_step; - weight_kh += kernel_w * ic4 * C4NUM; - } // kernel_h loop - // add biad relu - - dst_w[oc] += bias[oc]; - dst_w[oc] = (is_relu) ? (MSMAX(0, dst_w[oc])) : (dst_w[oc]); - dst_w[oc] = (is_relu6) ? (MSMIN(6, MSMAX(0, dst_w[oc]))) : (dst_w[oc]); - weight_oc += kernel_h * kernel_w * ic4 * C4NUM; - } // oc block - - dst_w += block_channel; - src_w += in_sw_step; - } // dst_width loop - dst_h += out_h_step; - src_h += in_sh_step; - } // dst_height loop -} - -// fp16 conv sliding window -void ConvSWFp16(const float16_t *input_data, const float16_t *packed_weight, const float16_t *bias_data, - float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param, - SlidingWindowParam *slidingWindow_param) { - bool relu = conv_param->act_type_ == ActType_Relu; - bool relu6 = conv_param->act_type_ == ActType_Relu6; - int oc4_res = conv_param->output_channel_ % C4NUM; - const float16_t *src = input_data; - float16_t *dst = NULL; - if (oc4_res == 0) { - dst = output_data; - } else { - dst = tmp_out_block; - } - - for (int b = 0; b < conv_param->output_batch_; b++) { - for (int oc = task_id; oc < slidingWindow_param->c_block_; oc += conv_param->thread_num_) { - const float16_t *src_data = src; - float16_t *dst_data = dst + oc * C4NUM; - const float16_t *weight = packed_weight + oc * slidingWindow_param->kernel_step_; - const float16_t *bias = bias_data + oc * C4NUM; - SWBorderFp16(dst_data, src_data, weight, bias, 0, slidingWindow_param->top_, 0, conv_param->output_w_, conv_param, - slidingWindow_param); - SWBorderFp16(dst_data, src_data, weight, bias, slidingWindow_param->bottom_, conv_param->output_h_, 0, - conv_param->output_w_, conv_param, slidingWindow_param); - SWBorderFp16(dst_data, src_data, weight, bias, slidingWindow_param->top_, slidingWindow_param->bottom_, 0, - slidingWindow_param->left_, conv_param, slidingWindow_param); - SWBorderFp16(dst_data, src_data, weight, bias, slidingWindow_param->top_, slidingWindow_param->bottom_, - slidingWindow_param->right_, conv_param->output_w_, conv_param, slidingWindow_param); - - if (slidingWindow_param->right_ > slidingWindow_param->left_ && - slidingWindow_param->bottom_ > slidingWindow_param->top_) { - int in_h_start = slidingWindow_param->top_ * conv_param->stride_h_ - conv_param->pad_u_; - int in_w_start = slidingWindow_param->left_ * conv_param->stride_w_ - conv_param->pad_l_; - const float16_t *in_t = - src_data + in_h_start * slidingWindow_param->in_h_step_ + in_w_start * slidingWindow_param->ic4_channel_; - float16_t *out_t = dst_data + slidingWindow_param->top_ * slidingWindow_param->out_h_step_ + - slidingWindow_param->left_ * slidingWindow_param->block_channel_; - SWCenterFp16(out_t, in_t, weight, bias, slidingWindow_param->bottom_ - slidingWindow_param->top_, - slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_, - conv_param->kernel_w_, slidingWindow_param->out_h_step_, slidingWindow_param->block_channel_, - slidingWindow_param->ic4_channel_, slidingWindow_param->in_sh_step_, - slidingWindow_param->in_sw_step_, slidingWindow_param->in_kh_step_, - slidingWindow_param->in_kw_step_, relu, relu6); - } - } // output C4 loop - src += slidingWindow_param->in_step_; - dst += slidingWindow_param->out_step_; - } // batch loop -} - // fp16 convolution common (im2col+gemm) void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data, float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param) { @@ -537,8 +335,9 @@ void UnPack3x3Relu6OutputFp16(const float16_t *src, float16_t *dst, int batch, i // fp16 convolution winograd void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const float16_t *bias_data, - TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param, - MatricesFp16 *matrices) { + float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param, + InputTransFp16Func in_func, OutputTransFp16Func out_func) { + const int tile_num = 16; int thread_num = conv_param->thread_num_; int input_unit = conv_param->input_unit_; int in_batch = conv_param->input_batch_; @@ -547,7 +346,6 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa int out_unit = conv_param->output_unit_; int out_w_block = UP_DIV(conv_param->output_w_, out_unit); int out_h_block = UP_DIV(conv_param->output_h_, out_unit); - const int tile_num = 16; int output_count = out_w_block * out_h_block; int output_tile_count = UP_DIV(output_count, tile_num); int out_channel = conv_param->output_channel_; @@ -557,8 +355,7 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa float16_t *trans_input = buffer_list[0]; float16_t *gemm_out = buffer_list[1]; - float16_t *tmp_out_data = buffer_list[2]; - float16_t *tmp_data = buffer_list[3]; + float16_t *tmp_data = buffer_list[2]; int trans_input_offset = tile_num * input_unit_square * ic8 * C8NUM; int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM; int tmp_data_offset = input_unit_square * C8NUM; @@ -566,156 +363,21 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa // step 2 : input transform (online) for (int b = 0; b < in_batch; b++) { int in_batch_offset = b * ic8 * C8NUM * conv_param->input_h_ * conv_param->input_w_; - int tmp_out_batch_offset = b * out_w_block * out_h_block * out_unit * out_unit * oc8 * C8NUM; + int out_batch_offset = b * out_channel * conv_param->output_h_ * conv_param->output_w_; for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) { int out_tile_index = thread_id * tile_num; int cal_num = output_count - thread_id * tile_num; cal_num = cal_num > tile_num ? tile_num : cal_num; WinogradInputTransformFp16(input_data + in_batch_offset, trans_input + task_id * trans_input_offset, tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param, - matrices[2], matrices[3]); + in_func); // step 3 : gemm IndirectGemmFp16_16x8(gemm_out + task_id * gemm_out_offset, trans_input + task_id * trans_input_offset, trans_weight, NULL, input_unit_square, ic8 * 2, oc8 * C8NUM, output_offset, 1, 1, 0, 0); // step 4 : output transform - WinogradOutputTransformFp16(gemm_out + task_id * gemm_out_offset, tmp_out_data + tmp_out_batch_offset, bias_data, - cal_num, out_tile_index, out_w_block, conv_param, matrices[0], matrices[1]); - } - } -} - -void UnPackWinogradOutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel, - int output_unit) { - int out_h_block_num = UP_DIV(height, output_unit); - int out_w_block_num = UP_DIV(width, output_unit); - int c8 = UP_DIV(channel, C8NUM); - int c8_block = C8NUM * out_h_block_num * output_unit * out_w_block_num * output_unit; - for (int b = 0; b < batch; b++) { - int src_batch_offset = b * c8 * c8_block; - int dst_batch_offset = b * height * width * channel; - for (int h = 0; h < height; h++) { - int src_h_offset = src_batch_offset + C8NUM * (h * out_w_block_num * output_unit); - const int dst_h_offset = dst_batch_offset + h * width * channel; - for (int w = 0; w < width; w++) { - int src_w_offset = src_h_offset + w * C8NUM; - int dst_w_offset = dst_h_offset + w * channel; - for (int c = 0; c < c8 - 1; c++) { - int src_c8_offset = src_w_offset + c * c8_block; - int dst_c8_offset = dst_w_offset + c * C8NUM; -#ifdef ENABLE_NEON - vst1q_f16(dst + dst_c8_offset, vld1q_f16(src + src_c8_offset)); -#else - for (int i = 0; i < C8NUM; ++i) { - dst[dst_c8_offset + i] = src[src_c8_offset + i]; - } -#endif - } - int c_res = channel - (c8 - 1) * C8NUM; - int src_c_res_offset = (c8 - 1) * c8_block; - int dst_c_res_offset = (c8 - 1) * C8NUM; - for (int c = 0; c < c_res; c++) { - int src_c8_res_offset = src_w_offset + src_c_res_offset + c; - int dst_c8_res_offset = dst_w_offset + dst_c_res_offset + c; - dst[dst_c8_res_offset] = src[src_c8_res_offset]; - } - } - } - } -} - -void UnPackWinogradReluOutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel, - int output_unit) { - int out_h_block_num = UP_DIV(height, output_unit); - int out_w_block_num = UP_DIV(width, output_unit); - int c8 = UP_DIV(channel, C8NUM); - int c8_block = C8NUM * out_h_block_num * output_unit * out_w_block_num * output_unit; - for (int b = 0; b < batch; b++) { - int src_batch_offset = b * c8 * c8_block; - int dst_batch_offset = b * height * width * channel; - for (int h = 0; h < height; h++) { - int src_h_offset = src_batch_offset + C8NUM * (h * out_w_block_num * output_unit); - const int dst_h_offset = dst_batch_offset + h * width * channel; - for (int w = 0; w < width; w++) { - int src_w_offset = src_h_offset + w * C8NUM; - int dst_w_offset = dst_h_offset + w * channel; - for (int c = 0; c < c8 - 1; c++) { - int src_c8_offset = src_w_offset + c * c8_block; - int dst_c8_offset = dst_w_offset + c * C8NUM; -#ifdef ENABLE_NEON - float16x8_t input_ptr = vld1q_f16(src + src_c8_offset); - float16x8_t zero = vdupq_n_f16(0); - input_ptr = vmaxq_f16(zero, input_ptr); - vst1q_f16(dst + dst_c8_offset, input_ptr); -#else - for (int i = 0; i < C8NUM; ++i) { - float16_t input_data = src[src_c8_offset + i]; - input_data = input_data < 0 ? 0 : input_data; - dst[dst_c8_offset + i] = input_data; - } -#endif - } - int c_res = channel - (c8 - 1) * C8NUM; - int src_c_res_offset = (c8 - 1) * c8_block; - int dst_c_res_offset = (c8 - 1) * C8NUM; - for (int c = 0; c < c_res; c++) { - int src_c8_res_offset = src_w_offset + src_c_res_offset + c; - int dst_c8_res_offset = dst_w_offset + dst_c_res_offset + c; - float16_t input_data = src[src_c8_res_offset]; - input_data = input_data < 0 ? 0 : input_data; - dst[dst_c8_res_offset] = input_data; - } - } - } - } -} - -void UnPackWinogradRelu6OutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel, - int output_unit) { - int out_h_block_num = UP_DIV(height, output_unit); - int out_w_block_num = UP_DIV(width, output_unit); - int c8 = UP_DIV(channel, C8NUM); - int c8_block = C8NUM * out_h_block_num * output_unit * out_w_block_num * output_unit; - for (int b = 0; b < batch; b++) { - int src_batch_offset = b * c8 * c8_block; - int dst_batch_offset = b * height * width * channel; - for (int h = 0; h < height; h++) { - int src_h_offset = src_batch_offset + C8NUM * (h * out_w_block_num * output_unit); - const int dst_h_offset = dst_batch_offset + h * width * channel; - for (int w = 0; w < width; w++) { - int src_w_offset = src_h_offset + w * C8NUM; - int dst_w_offset = dst_h_offset + w * channel; - for (int c = 0; c < c8 - 1; c++) { - int src_c8_offset = src_w_offset + c * c8_block; - int dst_c8_offset = dst_w_offset + c * C8NUM; -#ifdef ENABLE_NEON - float16x8_t input_ptr = vld1q_f16(src + src_c8_offset); - float16x8_t zero = vdupq_n_f16(0); - float16x8_t six = vdupq_n_f16(6); - input_ptr = vmaxq_f16(zero, input_ptr); - input_ptr = vminq_f16(six, input_ptr); - vst1q_f16(dst + dst_c8_offset, input_ptr); -#else - for (int i = 0; i < C8NUM; ++i) { - float16_t input_data = src[src_c8_offset + i]; - input_data = input_data < 0 ? 0 : input_data; - input_data = input_data > 6 ? 6 : input_data; - dst[dst_c8_offset + i] = input_data; - } -#endif - } - int c_res = channel - (c8 - 1) * C8NUM; - int src_c_res_offset = (c8 - 1) * c8_block; - int dst_c_res_offset = (c8 - 1) * C8NUM; - for (int c = 0; c < c_res; c++) { - int src_c8_res_offset = src_w_offset + src_c_res_offset + c; - int dst_c8_res_offset = dst_w_offset + dst_c_res_offset + c; - float16_t input_data = src[src_c8_res_offset]; - input_data = input_data < 0 ? 0 : input_data; - input_data = input_data > 6 ? 6 : input_data; - dst[dst_c8_res_offset] = input_data; - } - } + WinogradOutputTransformFp16(gemm_out + task_id * gemm_out_offset, output_data + out_batch_offset, bias_data, + cal_num, out_tile_index, out_w_block, conv_param, out_func); } } } diff --git a/mindspore/lite/nnacl/fp16/conv_fp16.h b/mindspore/lite/nnacl/fp16/conv_fp16.h index 3d9ab6bb2a..0064a553d5 100644 --- a/mindspore/lite/nnacl/fp16/conv_fp16.h +++ b/mindspore/lite/nnacl/fp16/conv_fp16.h @@ -40,17 +40,6 @@ void IndirectGemmFp16_16x8_c8(float16_t *output, float16_t *input, float16_t *we #ifdef __cplusplus extern "C" { #endif -void SWBorderFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int top, - int bottom, int left, int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding); - -void SWCenterFp16(float16_t *dst, const float16_t *src, const float16_t *weight, const float16_t *bias, int height, - int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int ic, int in_sh_step, - int in_sw_step, int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6); - -// fp16 sliding window -void ConvSWFp16(const float16_t *input_data, const float16_t *packed_weight, const float16_t *bias_data, - float16_t *tmp_out_block, float16_t *output_data, int task_id, ConvParameter *conv_param, - SlidingWindowParam *slidingWindow_param); // fp16 convolution common (im2col+gemm) void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data, @@ -69,17 +58,9 @@ void UnPack3x3Relu6OutputFp16(const float16_t *src, float16_t *dst, int batch, i // fp16 convolution winograd void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const float16_t *bias_data, - TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param, - MatricesFp16 *matrices); - -void UnPackWinogradOutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel, - int output_unit); - -void UnPackWinogradReluOutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel, - int output_unit); + float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param, + InputTransFp16Func in_func, OutputTransFp16Func out_func); -void UnPackWinogradRelu6OutputFp16(const float16_t *src, float16_t *dst, int batch, int height, int width, int channel, - int output_unit); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/fp16/winograd_transform_fp16.c b/mindspore/lite/nnacl/fp16/winograd_transform_fp16.c index 5ced95a316..a284fb58f8 100644 --- a/mindspore/lite/nnacl/fp16/winograd_transform_fp16.c +++ b/mindspore/lite/nnacl/fp16/winograd_transform_fp16.c @@ -569,8 +569,8 @@ void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, // fp16 common winograd void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num, - int out_tile_index, int out_w_block_num, ConvParameter *conv_param, float16_t *matrix_b, - float16_t *matrix_bt) { + int out_tile_index, int out_w_block_num, ConvParameter *conv_param, + InputTransFp16Func func) { const int tile_num = 16; int input_unit = conv_param->input_unit_; int output_unit = conv_param->output_unit_; @@ -593,36 +593,56 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in int interval_x_e = src_x_e < input_w ? input_unit : (input_w - src_x_s); int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s); - int src_plane_offset = ic8 * C8NUM * (src_y_s * input_w + src_x_s); + int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s); int dst_plane_offset = c * C4NUM; for (int ic = 0; ic < ic8; ic++) { // clear tmp buffer memset(tmp_data, 0, input_unit * input_unit * C8NUM * sizeof(float16_t)); - // get real input block with padding + int real_c = in_channel - ic * C8NUM; + real_c = real_c > C8NUM ? C8NUM : real_c; int src_ic8_offset = src_plane_offset + ic * C8NUM; - for (int interval = interval_y_s; interval < interval_y_e; interval++) { - int src_y_offset = src_ic8_offset + (interval * input_w + interval_x_s) * ic8 * C8NUM; - int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM; - for (int j = 0; j < (interval_x_e - interval_x_s); j++) { - int src_x_offset = src_y_offset + j * ic8 * C8NUM; - int dst_x_offset = dst_y_offset + j * C8NUM; - const float16_t *src_addr = input_data + src_x_offset; - float16_t *dst_addr = tmp_data + dst_x_offset; + + // get real input block with padding + if (real_c == C8NUM) { + for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = src_ic8_offset + (interval * input_w + interval_x_s) * in_channel; + int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * in_channel; + int dst_x_offset = dst_y_offset + j * C8NUM; + const float16_t *src_addr = input_data + src_x_offset; + float16_t *dst_addr = tmp_data + dst_x_offset; #ifdef ENABLE_NEON - vst1q_f16(dst_addr, vld1q_f16(src_addr)); + vst1q_f16(dst_addr, vld1q_f16(src_addr)); #else - for (int k = 0; k < C8NUM; k++) { - dst_addr[k] = src_addr[k]; - } + for (int k = 0; k < C8NUM; k++) { + dst_addr[k] = src_addr[k]; + } #endif + } + } + } else { + for (int interval = interval_y_s; interval < interval_y_e; interval++) { + int src_y_offset = src_ic8_offset + (interval * input_w + interval_x_s) * in_channel; + int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM; + for (int j = 0; j < (interval_x_e - interval_x_s); j++) { + int src_x_offset = src_y_offset + j * in_channel; + int dst_x_offset = dst_y_offset + j * C8NUM; + const float16_t *src_addr = input_data + src_x_offset; + float16_t *dst_addr = tmp_data + dst_x_offset; + for (int k = 0; k < real_c; k++) { + dst_addr[k] = src_addr[k]; + } + } } } + // input transform int dst_ic8_offset = dst_plane_offset + ic * tile_num * C8NUM; size_t dst_step = ic8 * C8NUM * tile_num; float16_t *trans_input_ptr = trans_input + dst_ic8_offset; - GeneralInputTransformUnitFp16(tmp_data, trans_input_ptr, matrix_b, matrix_bt, C8NUM, dst_step, input_unit); + func(tmp_data, trans_input_ptr, C8NUM, dst_step); } out_tile_index++; } // cal_tile_num loop @@ -630,12 +650,10 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in void WinogradOutputTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data, int cal_num, int out_tile_index, int output_unit_num, ConvParameter *conv_param, - float16_t *matrix_a, float16_t *matrix_at) { + OutputTransFp16Func func) { int output_unit = conv_param->output_unit_; int output_w = conv_param->output_w_; int output_h = conv_param->output_h_; - int output_w_unit_block = UP_DIV(output_w, output_unit); - int output_h_unit_block = UP_DIV(output_h, output_unit); int output_channel = conv_param->output_channel_; int oc8 = UP_DIV(output_channel, C8NUM); int input_unit = conv_param->input_unit_; @@ -645,18 +663,27 @@ void WinogradOutputTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_d for (int i = 0; i < cal_num; i++) { int dst_x_s = out_tile_index % output_unit_num; int dst_y_s = out_tile_index / output_unit_num; + int r_w = output_w - dst_x_s * output_unit; + r_w = r_w > output_unit ? output_unit : r_w; + int r_h = output_h - dst_y_s * output_unit; + r_h = r_h > output_unit ? output_unit : r_h; + int tmp_ix = dst_x_s * output_unit; + dst_x_s = tmp_ix > output_w ? output_w : tmp_ix; + int tmp_iy = dst_y_s * output_unit; + dst_y_s = tmp_iy > output_h ? output_h : tmp_iy; + int src_tile_offset = i * oc8 * C8NUM * input_unit * input_unit; - int dst_tile_offset = C8NUM * output_unit * (dst_x_s + dst_y_s * output_w_unit_block * output_unit); + int dst_tile_offset = output_channel * (dst_x_s + dst_y_s * output_w); for (int j = 0; j < oc8; j++) { + int r_c = output_channel - j * C8NUM; + r_c = r_c > C8NUM ? C8NUM : r_c; int src_oc8_offset = src_tile_offset + j * input_unit * input_unit * C8NUM; - int dst_oc8_offset = - dst_tile_offset + j * C8NUM * output_h_unit_block * output_w_unit_block * output_unit * output_unit; + int dst_oc8_offset = dst_tile_offset + j * C8NUM; const float16_t *src_ptr = gemm_out + src_oc8_offset; const float16_t *bias_ptr = bias_data + j * C8NUM; float16_t *dst_ptr = tmp_out_data + dst_oc8_offset; - GeneralOutputTransformUnitFp16(src_ptr, dst_ptr, bias_ptr, matrix_a, matrix_at, C8NUM, - output_w_unit_block * output_unit, input_unit, output_unit); + func(src_ptr, dst_ptr, bias_ptr, C8NUM, output_w, output_channel, r_w, r_h, r_c); } out_tile_index++; } diff --git a/mindspore/lite/nnacl/fp16/winograd_transform_fp16.h b/mindspore/lite/nnacl/fp16/winograd_transform_fp16.h index 0ae23b2264..eaedbc498d 100644 --- a/mindspore/lite/nnacl/fp16/winograd_transform_fp16.h +++ b/mindspore/lite/nnacl/fp16/winograd_transform_fp16.h @@ -43,12 +43,12 @@ void Conv3x3Fp16OutputTransform(const float16_t *gemm_out, float16_t *out_data, // fp16 common winograd void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_input, float16_t *tmp_data, int cal_num, - int out_tile_index, int out_w_block_num, ConvParameter *conv_param, float16_t *matrix_b, - float16_t *matrix_bt); + int out_tile_index, int out_w_block_num, ConvParameter *conv_param, + InputTransFp16Func func); void WinogradOutputTransformFp16(const float16_t *gemm_out, float16_t *tmp_out_data, const float16_t *bias_data, int cal_num, int out_tile_index, int output_unit_num, ConvParameter *conv_param, - float16_t *matrix_a, float16_t *matrix_at); + OutputTransFp16Func func); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/fp16/winograd_utils_fp16.c b/mindspore/lite/nnacl/fp16/winograd_utils_fp16.c index d70789cfac..f3d6ab8086 100644 --- a/mindspore/lite/nnacl/fp16/winograd_utils_fp16.c +++ b/mindspore/lite/nnacl/fp16/winograd_utils_fp16.c @@ -76,3 +76,2088 @@ void GeneralOutputTransformUnitFp16(const float16_t *src_data, float16_t *dst_da } } } + +static InputTransFp16Func InputTransFp16FuncList[] = { + NULL, NULL, NULL, NULL, InputTransform4x4UnitFp16, NULL, InputTransform6x6UnitFp16, NULL, InputTransform8x8UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncList4[] = {NULL, NULL, OutputTransform4x2UnitFp16, + OutputTransform4x3UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncReluList4[] = {NULL, NULL, OutputTransform4x2ReluUnitFp16, + OutputTransform4x3ReluUnitFp16}; +static OutputTransFp16Func OutputTransFp16FuncRelu6List4[] = {NULL, NULL, OutputTransform4x2Relu6UnitFp16, + OutputTransform4x3Relu6UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncList6[] = {NULL, + NULL, + OutputTransform6x2UnitFp16, + OutputTransform6x3UnitFp16, + OutputTransform6x4UnitFp16, + OutputTransform6x5UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncReluList6[] = {NULL, + NULL, + OutputTransform6x2ReluUnitFp16, + OutputTransform6x3ReluUnitFp16, + OutputTransform6x4ReluUnitFp16, + OutputTransform6x5ReluUnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncRelu6List6[] = {NULL, + NULL, + OutputTransform6x2Relu6UnitFp16, + OutputTransform6x3Relu6UnitFp16, + OutputTransform6x4Relu6UnitFp16, + OutputTransform6x5Relu6UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncList8[] = {NULL, + NULL, + OutputTransform8x2UnitFp16, + OutputTransform8x3UnitFp16, + OutputTransform8x4UnitFp16, + OutputTransform8x5UnitFp16, + OutputTransform8x6UnitFp16, + OutputTransform8x7UnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncReluList8[] = {NULL, + NULL, + OutputTransform8x2ReluUnitFp16, + OutputTransform8x3ReluUnitFp16, + OutputTransform8x4ReluUnitFp16, + OutputTransform8x5ReluUnitFp16, + OutputTransform8x6ReluUnitFp16, + OutputTransform8x7ReluUnitFp16}; + +static OutputTransFp16Func OutputTransFp16FuncRelu6List8[] = {NULL, + NULL, + OutputTransform8x2Relu6UnitFp16, + OutputTransform8x3Relu6UnitFp16, + OutputTransform8x4Relu6UnitFp16, + OutputTransform8x5Relu6UnitFp16, + OutputTransform8x6Relu6UnitFp16, + OutputTransform8x7Relu6UnitFp16}; + +InputTransFp16Func GetInputTransFp16Func(int input_unit) { return InputTransFp16FuncList[input_unit]; } + +void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step) { + float16x8_t src[16]; + float16x8_t t[16]; + float16x8_t m[16]; + Load16DataFp16; + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vsubq_f16(src[offset], src[2 + offset]); + t[4 + l] = vaddq_f16(src[1 + offset], src[2 + offset]); + t[8 + l] = vsubq_f16(src[2 + offset], src[1 + offset]); + t[12 + l] = vsubq_f16(src[3 + offset], src[1 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + m[l] = vsubq_f16(t[offset], t[2 + offset]); + m[4 + l] = vaddq_f16(t[1 + offset], t[2 + offset]); + m[8 + l] = vsubq_f16(t[2 + offset], t[1 + offset]); + m[12 + l] = vsubq_f16(t[3 + offset], t[1 + offset]); + } + for (int i = 0; i < 16; i++) { + int dst_offset = i * dst_step; + vst1_f16(dst_data + dst_offset, vget_low_f16(m[i])); + vst1_f16(dst_data + dst_offset + 64, vget_high_f16(m[i])); + } +} + +void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step) { + float16x8_t src[36]; + float16x8_t t[36]; + float16x8_t m[36]; + Load36DataFp16; + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vsubq_f16(src[3 + offset], src[1 + offset]); + float16x8_t tmp2 = vsubq_f16(src[4 + offset], src[2 + offset]); + t[l] = vaddq_f16(vsubq_f16(vmulq_n_f16(src[offset], 4), vmulq_n_f16(src[2 + offset], 5)), src[4 + offset]); + t[6 + l] = vaddq_f16(vmulq_n_f16(vaddq_f16(src[1 + offset], src[2 + offset]), -4), + vaddq_f16(src[3 + offset], src[4 + offset])); + t[12 + l] = vaddq_f16(vmulq_n_f16(vsubq_f16(src[1 + offset], src[2 + offset]), 4), + vsubq_f16(src[4 + offset], src[3 + offset])); + t[18 + l] = vaddq_f16(vmulq_n_f16(tmp1, 2), tmp2); + t[24 + l] = vaddq_f16(vmulq_n_f16(tmp1, -2), tmp2); + t[30 + l] = vaddq_f16(vsubq_f16(vmulq_n_f16(src[1 + offset], 4), vmulq_n_f16(src[3 + offset], 5)), src[5 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vsubq_f16(t[3 + offset], t[1 + offset]); + float16x8_t tmp2 = vsubq_f16(t[4 + offset], t[2 + offset]); + m[l] = vaddq_f16(vsubq_f16(vmulq_n_f16(t[offset], 4), vmulq_n_f16(t[2 + offset], 5)), t[4 + offset]); + m[6 + l] = + vaddq_f16(vmulq_n_f16(vaddq_f16(t[1 + offset], t[2 + offset]), -4), vaddq_f16(t[3 + offset], t[4 + offset])); + m[12 + l] = + vaddq_f16(vmulq_n_f16(vsubq_f16(t[1 + offset], t[2 + offset]), 4), vsubq_f16(t[4 + offset], t[3 + offset])); + m[18 + l] = vaddq_f16(vmulq_n_f16(tmp1, 2), tmp2); + m[24 + l] = vaddq_f16(vmulq_n_f16(tmp1, -2), tmp2); + m[30 + l] = vaddq_f16(vsubq_f16(vmulq_n_f16(t[1 + offset], 4), vmulq_n_f16(t[3 + offset], 5)), t[5 + offset]); + } + for (int i = 0; i < 36; i++) { + int dst_offset = i * dst_step; + vst1_f16(dst_data + dst_offset, vget_low_f16(m[i])); + vst1_f16(dst_data + dst_offset + 64, vget_high_f16(m[i])); + } +} + +void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step) { + float16x8_t src[64]; + float16x8_t t[64]; + float16x8_t m[64]; + Load64DataFp16; + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + t[l] = vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(src[offset], 0.5625), vmulq_n_f16(src[2 + offset], 3.0625)), + vmulq_n_f16(src[4 + offset], 3.5)), + src[6 + offset]); + float16x8_t tmp1 = vaddq_f16(vmulq_n_f16(src[1 + offset], 1.125), vmulq_n_f16(src[5 + offset], 0.5)); + float16x8_t tmp2 = vsubq_f16(vmulq_n_f16(src[2 + offset], 2.25), vmulq_n_f16(src[4 + offset], 3.25)); + t[8 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(src[3 + offset], 1.625)), src[6 + offset]); + t[16 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(src[3 + offset], 1.625)), src[6 + offset]); + tmp1 = vaddq_f16(vmulq_n_f16(src[1 + offset], 0.5625), src[5 + offset]); + tmp2 = vsubq_f16(vmulq_n_f16(src[2 + offset], 0.5625), vmulq_n_f16(src[4 + offset], 2.5)); + t[24 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(src[3 + offset], 2.5)), src[6 + offset]); + t[32 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(src[3 + offset], 2.5)), src[6 + offset]); + tmp1 = vaddq_f16(vmulq_n_f16(src[1 + offset], 0.375), vmulq_n_f16(src[5 + offset], 1.5)); + tmp2 = vsubq_f16(vmulq_n_f16(src[2 + offset], 0.25), vmulq_n_f16(src[4 + offset], 1.25)); + t[40 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(src[3 + offset], 1.875)), src[6 + offset]); + t[48 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(src[3 + offset], 1.875)), src[6 + offset]); + t[56 + l] = + vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(src[1 + offset], -0.5625), vmulq_n_f16(src[3 + offset], 3.0625)), + vmulq_n_f16(src[5 + offset], 3.5)), + src[7 + offset]); + } + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + m[l] = vsubq_f16(vaddq_f16(vsubq_f16(vmulq_n_f16(t[offset], 0.5625), vmulq_n_f16(t[2 + offset], 3.0625)), + vmulq_n_f16(t[4 + offset], 3.5)), + t[6 + offset]); + float16x8_t tmp1 = vaddq_f16(vmulq_n_f16(t[1 + offset], 1.125), vmulq_n_f16(t[5 + offset], 0.5)); + float16x8_t tmp2 = vsubq_f16(vmulq_n_f16(t[2 + offset], 2.25), vmulq_n_f16(t[4 + offset], 3.25)); + m[8 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(t[3 + offset], 1.625)), t[6 + offset]); + m[16 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(t[3 + offset], 1.625)), t[6 + offset]); + tmp1 = vaddq_f16(vmulq_n_f16(t[1 + offset], 0.5625), t[5 + offset]); + tmp2 = vsubq_f16(vmulq_n_f16(t[2 + offset], 0.5625), vmulq_n_f16(t[4 + offset], 2.5)); + m[24 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(t[3 + offset], 2.5)), t[6 + offset]); + m[32 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(t[3 + offset], 2.5)), t[6 + offset]); + tmp1 = vaddq_f16(vmulq_n_f16(t[1 + offset], 0.375), vmulq_n_f16(t[5 + offset], 1.5)); + tmp2 = vsubq_f16(vmulq_n_f16(t[2 + offset], 0.25), vmulq_n_f16(t[4 + offset], 1.25)); + m[40 + l] = vaddq_f16(vsubq_f16(vaddq_f16(tmp1, tmp2), vmulq_n_f16(t[3 + offset], 1.875)), t[6 + offset]); + m[48 + l] = vaddq_f16(vaddq_f16(vsubq_f16(tmp2, tmp1), vmulq_n_f16(t[3 + offset], 1.875)), t[6 + offset]); + m[56 + l] = vaddq_f16(vsubq_f16(vaddq_f16(vmulq_n_f16(t[1 + offset], -0.5625), vmulq_n_f16(t[3 + offset], 3.0625)), + vmulq_n_f16(t[5 + offset], 3.5)), + t[7 + offset]); + } + for (int i = 0; i < 64; i++) { + int dst_offset = i * dst_step; + vst1_f16(dst_data + dst_offset, vget_low_f16(m[i])); + vst1_f16(dst_data + dst_offset + 64, vget_high_f16(m[i])); + } +} + +OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActType act_type) { + if (input_unit == 4 && output_unit < 4) { + if (act_type == ActType_Relu) { + return OutputTransFp16FuncReluList4[output_unit]; + } else if (act_type == ActType_Relu6) { + return OutputTransFp16FuncRelu6List4[output_unit]; + } else { + return OutputTransFp16FuncList4[output_unit]; + } + } else if (input_unit == 6 && output_unit < 6) { + if (act_type == ActType_Relu) { + return OutputTransFp16FuncReluList6[output_unit]; + } else if (act_type == ActType_Relu6) { + return OutputTransFp16FuncRelu6List6[output_unit]; + } else { + return OutputTransFp16FuncList6[output_unit]; + } + } else if (input_unit == 8 && output_unit < 8) { + if (act_type == ActType_Relu) { + return OutputTransFp16FuncReluList8[output_unit]; + } else if (act_type == ActType_Relu6) { + return OutputTransFp16FuncRelu6List8[output_unit]; + } else { + return OutputTransFp16FuncList8[output_unit]; + } + } else { + return NULL; + } +} + +void OutputTransform4x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[16]; + float16x8_t t[8]; + float16x8_t m[4]; + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform4x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[16]; + float16x8_t t[8]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform4x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[16]; + float16x8_t t[8]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + t[l] = vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]); + t[l + 4] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), src[3 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 4; + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), t[3 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + m[l + 2] = vminq_f16(six, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform4x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[16]; + float16x8_t t[12]; + float16x8_t m[9]; + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + float16x8_t tmp = vaddq_f16(src[1 + offset], src[2 + offset]); + t[l] = vaddq_f16(src[offset], tmp); + t[l + 4] = vsubq_f16(src[1 + offset], src[2 + offset]); + t[l + 8] = vaddq_f16(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + float16x8_t tmp = vaddq_f16(t[1 + offset], t[2 + offset]); + m[l] = vaddq_f16(vaddq_f16(t[offset], tmp), bias_ptr); + m[l + 3] = vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(tmp, t[3 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform4x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[16]; + float16x8_t t[12]; + float16x8_t m[9]; + float16x8_t zero = vdupq_n_f16(0); + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + float16x8_t tmp = vaddq_f16(src[1 + offset], src[2 + offset]); + t[l] = vaddq_f16(src[offset], tmp); + t[l + 4] = vsubq_f16(src[1 + offset], src[2 + offset]); + t[l + 8] = vaddq_f16(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + float16x8_t tmp = vaddq_f16(t[1 + offset], t[2 + offset]); + m[l] = vaddq_f16(vaddq_f16(t[offset], tmp), bias_ptr); + m[l + 3] = vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(tmp, t[3 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 3] = vmaxq_f16(zero, m[l + 3]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform4x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[16]; + float16x8_t t[12]; + float16x8_t m[9]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load16DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 4; ++l) { + int offset = l * 4; + float16x8_t tmp = vaddq_f16(src[1 + offset], src[2 + offset]); + t[l] = vaddq_f16(src[offset], tmp); + t[l + 4] = vsubq_f16(src[1 + offset], src[2 + offset]); + t[l + 8] = vaddq_f16(tmp, src[3 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 4; + float16x8_t tmp = vaddq_f16(t[1 + offset], t[2 + offset]); + m[l] = vaddq_f16(vaddq_f16(t[offset], tmp), bias_ptr); + m[l + 3] = vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(tmp, t[3 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 3] = vmaxq_f16(zero, m[l + 3]); + m[l + 3] = vminq_f16(six, m[l + 3]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + m[l + 6] = vminq_f16(six, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[12]; + float16x8_t m[4]; + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = vaddq_f16(vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), + vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), t[4 + offset]), + bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), + vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[12]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = vaddq_f16(vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), + vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), t[4 + offset]), + bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), + vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[12]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(src[offset], src[1 + offset]), src[2 + offset]), src[3 + offset]), + src[4 + offset]); + t[l + 6] = vaddq_f16(vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), + vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2)), + src[5 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 6; + m[l] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], t[1 + offset]), t[2 + offset]), t[3 + offset]), t[4 + offset]), + bias_ptr); + m[l + 2] = vaddq_f16(vaddq_f16(vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), + vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), + t[5 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + m[l + 2] = vminq_f16(six, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[18]; + float16x8_t m[9]; + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), + vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = vaddq_f16( + vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), t[5 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[18]; + float16x8_t m[9]; + float16x8_t zero = vdupq_n_f16(0); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), + vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = vaddq_f16( + vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), t[5 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 3] = vmaxq_f16(zero, m[l + 3]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[18]; + float16x8_t m[9]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(vsubq_f16(src[1 + offset], src[2 + offset]), + vmulq_n_f16(vsubq_f16(src[3 + offset], src[4 + offset]), 2)); + t[l + 12] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), src[5 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 3] = vaddq_f16( + vaddq_f16(vsubq_f16(t[1 + offset], t[2 + offset]), vmulq_n_f16(vsubq_f16(t[3 + offset], t[4 + offset]), 2)), + bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), t[5 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 3] = vmaxq_f16(zero, m[l + 3]); + m[l + 3] = vminq_f16(six, m[l + 3]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + m[l + 6] = vminq_f16(six, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x4UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[24]; + float16x8_t m[16]; + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)); + t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)); + t[l + 18] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr); + m[l + 8] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr); + m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), t[5 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + Store16DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x4ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[24]; + float16x8_t m[16]; + float16x8_t zero = vdupq_n_f16(0); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)); + t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)); + t[l + 18] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr); + m[l + 8] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr); + m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), t[5 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 4] = vmaxq_f16(zero, m[l + 4]); + m[l + 8] = vmaxq_f16(zero, m[l + 8]); + m[l + 12] = vmaxq_f16(zero, m[l + 12]); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + Store16DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x4Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[24]; + float16x8_t m[16]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)); + t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)); + t[l + 18] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), src[5 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 4] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr); + m[l + 8] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr); + m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), t[5 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 4] = vmaxq_f16(zero, m[l + 4]); + m[l + 4] = vminq_f16(six, m[l + 4]); + m[l + 8] = vmaxq_f16(zero, m[l + 8]); + m[l + 8] = vminq_f16(six, m[l + 8]); + m[l + 12] = vmaxq_f16(zero, m[l + 12]); + m[l + 12] = vminq_f16(six, m[l + 12]); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + Store16DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x5UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[30]; + float16x8_t m[25]; + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)); + t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)); + t[l + 18] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)); + t[l + 24] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr); + m[l + 10] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr); + m[l + 15] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), bias_ptr); + m[l + 20] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), t[5 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + Store25DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x5ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[30]; + float16x8_t m[25]; + float16x8_t zero = vdupq_n_f16(0); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)); + t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)); + t[l + 18] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)); + t[l + 24] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr); + m[l + 10] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr); + m[l + 15] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), bias_ptr); + m[l + 20] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), t[5 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 5] = vmaxq_f16(zero, m[l + 5]); + m[l + 10] = vmaxq_f16(zero, m[l + 10]); + m[l + 15] = vmaxq_f16(zero, m[l + 15]); + m[l + 20] = vmaxq_f16(zero, m[l + 20]); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + Store25DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform6x5Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[36]; + float16x8_t t[30]; + float16x8_t m[25]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load36DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 6; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp4 = vsubq_f16(src[3 + offset], src[4 + offset]); + t[l] = vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2); + t[l + 6] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)); + t[l + 12] = vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)); + t[l + 18] = vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)); + t[l + 24] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), src[5 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 6; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp4 = vsubq_f16(t[3 + offset], t[4 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), bias_ptr); + m[l + 5] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 2)), bias_ptr); + m[l + 10] = vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 4)), bias_ptr); + m[l + 15] = vaddq_f16(vaddq_f16(tmp3, vmulq_n_f16(tmp4, 8)), bias_ptr); + m[l + 20] = vaddq_f16(vaddq_f16(vaddq_f16(tmp1, vmulq_n_f16(tmp2, 16)), t[5 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 5] = vmaxq_f16(zero, m[l + 5]); + m[l + 5] = vminq_f16(six, m[l + 5]); + m[l + 10] = vmaxq_f16(zero, m[l + 10]); + m[l + 10] = vminq_f16(six, m[l + 10]); + m[l + 15] = vmaxq_f16(zero, m[l + 15]); + m[l + 15] = vminq_f16(six, m[l + 15]); + m[l + 20] = vmaxq_f16(zero, m[l + 20]); + m[l + 20] = vminq_f16(six, m[l + 20]); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + Store25DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[16]; + float16x8_t m[4]; + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), t[7 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[16]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), t[7 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[16]; + float16x8_t m[4]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), src[7 + offset]); + } + for (int l = 0; l < 2; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 2] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), t[7 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 2] = vmaxq_f16(zero, m[l + 2]); + m[l + 2] = vminq_f16(six, m[l + 2]); + } + if (r_c == C8NUM && r_h == 2 && r_w == 2) { + Store4DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 2; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[24]; + float16x8_t m[9]; + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 6] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), t[7 + offset]), bias_ptr); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[24]; + float16x8_t m[9]; + float16x8_t zero = vdupq_n_f16(0); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 6] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), t[7 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 3] = vmaxq_f16(zero, m[l + 3]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[24]; + float16x8_t m[9]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), src[7 + offset]); + } + for (int l = 0; l < 3; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 3] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 6] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), t[7 + offset]), bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 3] = vmaxq_f16(zero, m[l + 3]); + m[l + 3] = vminq_f16(six, m[l + 3]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + m[l + 6] = vminq_f16(six, m[l + 6]); + } + if (r_c == C8NUM && r_h == 3 && r_w == 3) { + Store9DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 3; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x4UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[32]; + float16x8_t m[16]; + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 12] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + Store16DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x4ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[32]; + float16x8_t m[16]; + float16x8_t zero = vdupq_n_f16(0); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 12] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 4] = vmaxq_f16(zero, m[l + 4]); + m[l + 8] = vmaxq_f16(zero, m[l + 8]); + m[l + 12] = vmaxq_f16(zero, m[l + 12]); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + Store16DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x4Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[32]; + float16x8_t m[16]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), src[7 + offset]); + } + for (int l = 0; l < 4; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 4] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 8] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 12] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 4] = vmaxq_f16(zero, m[l + 4]); + m[l + 4] = vminq_f16(six, m[l + 4]); + m[l + 8] = vmaxq_f16(zero, m[l + 8]); + m[l + 8] = vminq_f16(six, m[l + 8]); + m[l + 12] = vmaxq_f16(zero, m[l + 12]); + m[l + 12] = vminq_f16(six, m[l + 12]); + } + if (r_c == C8NUM && r_h == 4 && r_w == 4) { + Store16DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 4; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x5UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[40]; + float16x8_t m[25]; + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 10] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 15] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 20] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + Store25DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x5ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[40]; + float16x8_t m[25]; + float16x8_t zero = vdupq_n_f16(0); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 10] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 15] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 20] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 5] = vmaxq_f16(zero, m[l + 5]); + m[l + 10] = vmaxq_f16(zero, m[l + 10]); + m[l + 15] = vmaxq_f16(zero, m[l + 15]); + m[l + 20] = vmaxq_f16(zero, m[l + 20]); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + Store25DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x5Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[40]; + float16x8_t m[25]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), src[7 + offset]); + } + for (int l = 0; l < 5; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 5] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 10] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 15] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 20] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 5] = vmaxq_f16(zero, m[l + 5]); + m[l + 5] = vminq_f16(six, m[l + 5]); + m[l + 10] = vmaxq_f16(zero, m[l + 10]); + m[l + 10] = vminq_f16(six, m[l + 10]); + m[l + 15] = vmaxq_f16(zero, m[l + 15]); + m[l + 15] = vminq_f16(six, m[l + 15]); + m[l + 20] = vmaxq_f16(zero, m[l + 20]); + m[l + 20] = vminq_f16(six, m[l + 20]); + } + if (r_c == C8NUM && r_h == 5 && r_w == 5) { + Store25DataFp16; + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 5; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[48]; + float16x8_t m[36]; + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x6ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[48]; + float16x8_t m[36]; + float16x8_t zero = vdupq_n_f16(0); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + m[l + 12] = vmaxq_f16(zero, m[l + 12]); + m[l + 18] = vmaxq_f16(zero, m[l + 18]); + m[l + 24] = vmaxq_f16(zero, m[l + 24]); + m[l + 30] = vmaxq_f16(zero, m[l + 30]); + } + if (r_c == C8NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x6Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[48]; + float16x8_t m[36]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), src[7 + offset]); + } + for (int l = 0; l < 6; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 6] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 12] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 18] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 24] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 30] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 6] = vmaxq_f16(zero, m[l + 6]); + m[l + 6] = vminq_f16(six, m[l + 6]); + m[l + 12] = vmaxq_f16(zero, m[l + 12]); + m[l + 12] = vminq_f16(six, m[l + 12]); + m[l + 18] = vmaxq_f16(zero, m[l + 18]); + m[l + 18] = vminq_f16(six, m[l + 18]); + m[l + 24] = vmaxq_f16(zero, m[l + 24]); + m[l + 24] = vminq_f16(six, m[l + 24]); + m[l + 30] = vmaxq_f16(zero, m[l + 30]); + m[l + 30] = vminq_f16(six, m[l + 30]); + } + if (r_c == C8NUM && r_h == 6 && r_w == 6) { + for (int i = 0; i < 6; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 6; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 6; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x7UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[56]; + float16x8_t m[49]; + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)); + t[l + 48] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 14] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 21] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 28] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 35] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), bias_ptr); + m[l + 42] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), t[7 + offset]), + bias_ptr); + } + if (r_c == C8NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + vst1q_f16(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x7ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[56]; + float16x8_t m[49]; + float16x8_t zero = vdupq_n_f16(0); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)); + t[l + 48] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 14] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 21] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 28] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 35] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), bias_ptr); + m[l + 42] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l + 7] = vmaxq_f16(zero, m[l + 7]); + m[l + 14] = vmaxq_f16(zero, m[l + 14]); + m[l + 21] = vmaxq_f16(zero, m[l + 21]); + m[l + 28] = vmaxq_f16(zero, m[l + 28]); + m[l + 35] = vmaxq_f16(zero, m[l + 35]); + m[l + 42] = vmaxq_f16(zero, m[l + 42]); + } + if (r_c == C8NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + vst1q_f16(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} + +void OutputTransform8x7Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c) { + float16x8_t src[64]; + float16x8_t t[56]; + float16x8_t m[49]; + float16x8_t zero = vdupq_n_f16(0); + float16x8_t six = vdupq_n_f16(6); + Load64DataFp16; + float16x8_t bias_ptr = vld1q_f16(bias_data); + for (int l = 0; l < 8; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp2 = vaddq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp3 = vaddq_f16(src[5 + offset], src[6 + offset]); + float16x8_t tmp4 = vsubq_f16(src[1 + offset], src[2 + offset]); + float16x8_t tmp5 = vsubq_f16(src[3 + offset], src[4 + offset]); + float16x8_t tmp6 = vsubq_f16(src[5 + offset], src[6 + offset]); + t[l] = vaddq_f16(vaddq_f16(vaddq_f16(src[offset], tmp1), tmp2), tmp3); + t[l + 8] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)); + t[l + 16] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)); + t[l + 24] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)); + t[l + 32] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)); + t[l + 40] = vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)); + t[l + 48] = + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), src[7 + offset]); + } + for (int l = 0; l < 7; ++l) { + int offset = l * 8; + float16x8_t tmp1 = vaddq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp2 = vaddq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp3 = vaddq_f16(t[5 + offset], t[6 + offset]); + float16x8_t tmp4 = vsubq_f16(t[1 + offset], t[2 + offset]); + float16x8_t tmp5 = vsubq_f16(t[3 + offset], t[4 + offset]); + float16x8_t tmp6 = vsubq_f16(t[5 + offset], t[6 + offset]); + m[l] = vaddq_f16(vaddq_f16(vaddq_f16(vaddq_f16(t[offset], tmp1), tmp2), tmp3), bias_ptr); + m[l + 7] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.5), tmp5), vmulq_n_f16(tmp6, 1.5)), bias_ptr); + m[l + 14] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.25), tmp2), vmulq_n_f16(tmp3, 2.25)), bias_ptr); + m[l + 21] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.125), tmp5), vmulq_n_f16(tmp6, 3.375)), bias_ptr); + m[l + 28] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.0625), tmp2), vmulq_n_f16(tmp3, 5.0625)), bias_ptr); + m[l + 35] = vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp4, 0.03125), tmp5), vmulq_n_f16(tmp6, 7.59375)), bias_ptr); + m[l + 42] = vaddq_f16( + vaddq_f16(vaddq_f16(vaddq_f16(vmulq_n_f16(tmp1, 0.015625), tmp2), vmulq_n_f16(tmp3, 11.390625)), t[7 + offset]), + bias_ptr); + m[l] = vmaxq_f16(zero, m[l]); + m[l] = vminq_f16(six, m[l]); + m[l + 7] = vmaxq_f16(zero, m[l + 7]); + m[l + 7] = vminq_f16(six, m[l + 7]); + m[l + 14] = vmaxq_f16(zero, m[l + 14]); + m[l + 14] = vminq_f16(six, m[l + 14]); + m[l + 21] = vmaxq_f16(zero, m[l + 21]); + m[l + 21] = vminq_f16(six, m[l + 21]); + m[l + 28] = vmaxq_f16(zero, m[l + 28]); + m[l + 28] = vminq_f16(six, m[l + 28]); + m[l + 35] = vmaxq_f16(zero, m[l + 35]); + m[l + 35] = vminq_f16(six, m[l + 35]); + m[l + 42] = vmaxq_f16(zero, m[l + 42]); + m[l + 42] = vminq_f16(six, m[l + 42]); + } + if (r_c == C8NUM && r_h == 7 && r_w == 7) { + for (int i = 0; i < 7; i++) { + int dst_k_offset = i * dst_step * out_c; + int m_k_offset = i * 7; + vst1q_f16(dst_data + dst_k_offset + 0 * out_c, m[m_k_offset]); + vst1q_f16(dst_data + dst_k_offset + 1 * out_c, m[m_k_offset + 1]); + vst1q_f16(dst_data + dst_k_offset + 2 * out_c, m[m_k_offset + 2]); + vst1q_f16(dst_data + dst_k_offset + 3 * out_c, m[m_k_offset + 3]); + vst1q_f16(dst_data + dst_k_offset + 4 * out_c, m[m_k_offset + 4]); + vst1q_f16(dst_data + dst_k_offset + 5 * out_c, m[m_k_offset + 5]); + vst1q_f16(dst_data + dst_k_offset + 6 * out_c, m[m_k_offset + 6]); + } + } else { + for (int i = 0; i < r_c; i++) { + for (int j = 0; j < r_h; j++) { + int dst_k_offset = j * dst_step * out_c; + int m_k_offset = j * 7; + for (int k = 0; k < r_w; k++) { + dst_data[i + dst_k_offset + k * out_c] = m[k + m_k_offset][i]; + } + } + } + } +} diff --git a/mindspore/lite/nnacl/fp16/winograd_utils_fp16.h b/mindspore/lite/nnacl/fp16/winograd_utils_fp16.h index b961f6a2da..ac057e4b79 100644 --- a/mindspore/lite/nnacl/fp16/winograd_utils_fp16.h +++ b/mindspore/lite/nnacl/fp16/winograd_utils_fp16.h @@ -26,12 +26,286 @@ #ifdef __cplusplus extern "C" { #endif +typedef void (*InputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step); + +typedef void (*OutputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); + void GeneralInputTransformUnitFp16(const float16_t *src_data, float16_t *dst_data, float16_t *matrix_b, float16_t *matrix_bt, int src_step, int dst_step, int in_unit); void GeneralOutputTransformUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, float16_t *matrix_a, float16_t *matrix_at, int src_step, int dst_step, int in_unit, int out_unit); + +#define Load16DataFp16 \ + src[0] = vld1q_f16(src_data + 0 * src_step); \ + src[1] = vld1q_f16(src_data + 1 * src_step); \ + src[2] = vld1q_f16(src_data + 2 * src_step); \ + src[3] = vld1q_f16(src_data + 3 * src_step); \ + src[4] = vld1q_f16(src_data + 4 * src_step); \ + src[5] = vld1q_f16(src_data + 5 * src_step); \ + src[6] = vld1q_f16(src_data + 6 * src_step); \ + src[7] = vld1q_f16(src_data + 7 * src_step); \ + src[8] = vld1q_f16(src_data + 8 * src_step); \ + src[9] = vld1q_f16(src_data + 9 * src_step); \ + src[10] = vld1q_f16(src_data + 10 * src_step); \ + src[11] = vld1q_f16(src_data + 11 * src_step); \ + src[12] = vld1q_f16(src_data + 12 * src_step); \ + src[13] = vld1q_f16(src_data + 13 * src_step); \ + src[14] = vld1q_f16(src_data + 14 * src_step); \ + src[15] = vld1q_f16(src_data + 15 * src_step); + +#define Load36DataFp16 \ + src[0] = vld1q_f16(src_data + 0 * src_step); \ + src[1] = vld1q_f16(src_data + 1 * src_step); \ + src[2] = vld1q_f16(src_data + 2 * src_step); \ + src[3] = vld1q_f16(src_data + 3 * src_step); \ + src[4] = vld1q_f16(src_data + 4 * src_step); \ + src[5] = vld1q_f16(src_data + 5 * src_step); \ + src[6] = vld1q_f16(src_data + 6 * src_step); \ + src[7] = vld1q_f16(src_data + 7 * src_step); \ + src[8] = vld1q_f16(src_data + 8 * src_step); \ + src[9] = vld1q_f16(src_data + 9 * src_step); \ + src[10] = vld1q_f16(src_data + 10 * src_step); \ + src[11] = vld1q_f16(src_data + 11 * src_step); \ + src[12] = vld1q_f16(src_data + 12 * src_step); \ + src[13] = vld1q_f16(src_data + 13 * src_step); \ + src[14] = vld1q_f16(src_data + 14 * src_step); \ + src[15] = vld1q_f16(src_data + 15 * src_step); \ + src[16] = vld1q_f16(src_data + 16 * src_step); \ + src[17] = vld1q_f16(src_data + 17 * src_step); \ + src[18] = vld1q_f16(src_data + 18 * src_step); \ + src[19] = vld1q_f16(src_data + 19 * src_step); \ + src[20] = vld1q_f16(src_data + 20 * src_step); \ + src[21] = vld1q_f16(src_data + 21 * src_step); \ + src[22] = vld1q_f16(src_data + 22 * src_step); \ + src[23] = vld1q_f16(src_data + 23 * src_step); \ + src[24] = vld1q_f16(src_data + 24 * src_step); \ + src[25] = vld1q_f16(src_data + 25 * src_step); \ + src[26] = vld1q_f16(src_data + 26 * src_step); \ + src[27] = vld1q_f16(src_data + 27 * src_step); \ + src[28] = vld1q_f16(src_data + 28 * src_step); \ + src[29] = vld1q_f16(src_data + 29 * src_step); \ + src[30] = vld1q_f16(src_data + 30 * src_step); \ + src[31] = vld1q_f16(src_data + 31 * src_step); \ + src[32] = vld1q_f16(src_data + 32 * src_step); \ + src[33] = vld1q_f16(src_data + 33 * src_step); \ + src[34] = vld1q_f16(src_data + 34 * src_step); \ + src[35] = vld1q_f16(src_data + 35 * src_step); + +#define Load64DataFp16 \ + src[0] = vld1q_f16(src_data + 0 * src_step); \ + src[1] = vld1q_f16(src_data + 1 * src_step); \ + src[2] = vld1q_f16(src_data + 2 * src_step); \ + src[3] = vld1q_f16(src_data + 3 * src_step); \ + src[4] = vld1q_f16(src_data + 4 * src_step); \ + src[5] = vld1q_f16(src_data + 5 * src_step); \ + src[6] = vld1q_f16(src_data + 6 * src_step); \ + src[7] = vld1q_f16(src_data + 7 * src_step); \ + src[8] = vld1q_f16(src_data + 8 * src_step); \ + src[9] = vld1q_f16(src_data + 9 * src_step); \ + src[10] = vld1q_f16(src_data + 10 * src_step); \ + src[11] = vld1q_f16(src_data + 11 * src_step); \ + src[12] = vld1q_f16(src_data + 12 * src_step); \ + src[13] = vld1q_f16(src_data + 13 * src_step); \ + src[14] = vld1q_f16(src_data + 14 * src_step); \ + src[15] = vld1q_f16(src_data + 15 * src_step); \ + src[16] = vld1q_f16(src_data + 16 * src_step); \ + src[17] = vld1q_f16(src_data + 17 * src_step); \ + src[18] = vld1q_f16(src_data + 18 * src_step); \ + src[19] = vld1q_f16(src_data + 19 * src_step); \ + src[20] = vld1q_f16(src_data + 20 * src_step); \ + src[21] = vld1q_f16(src_data + 21 * src_step); \ + src[22] = vld1q_f16(src_data + 22 * src_step); \ + src[23] = vld1q_f16(src_data + 23 * src_step); \ + src[24] = vld1q_f16(src_data + 24 * src_step); \ + src[25] = vld1q_f16(src_data + 25 * src_step); \ + src[26] = vld1q_f16(src_data + 26 * src_step); \ + src[27] = vld1q_f16(src_data + 27 * src_step); \ + src[28] = vld1q_f16(src_data + 28 * src_step); \ + src[29] = vld1q_f16(src_data + 29 * src_step); \ + src[30] = vld1q_f16(src_data + 30 * src_step); \ + src[31] = vld1q_f16(src_data + 31 * src_step); \ + src[32] = vld1q_f16(src_data + 32 * src_step); \ + src[33] = vld1q_f16(src_data + 33 * src_step); \ + src[34] = vld1q_f16(src_data + 34 * src_step); \ + src[35] = vld1q_f16(src_data + 35 * src_step); \ + src[36] = vld1q_f16(src_data + 36 * src_step); \ + src[37] = vld1q_f16(src_data + 37 * src_step); \ + src[38] = vld1q_f16(src_data + 38 * src_step); \ + src[39] = vld1q_f16(src_data + 39 * src_step); \ + src[40] = vld1q_f16(src_data + 40 * src_step); \ + src[41] = vld1q_f16(src_data + 41 * src_step); \ + src[42] = vld1q_f16(src_data + 42 * src_step); \ + src[43] = vld1q_f16(src_data + 43 * src_step); \ + src[44] = vld1q_f16(src_data + 44 * src_step); \ + src[45] = vld1q_f16(src_data + 45 * src_step); \ + src[46] = vld1q_f16(src_data + 46 * src_step); \ + src[47] = vld1q_f16(src_data + 47 * src_step); \ + src[48] = vld1q_f16(src_data + 48 * src_step); \ + src[49] = vld1q_f16(src_data + 49 * src_step); \ + src[50] = vld1q_f16(src_data + 50 * src_step); \ + src[51] = vld1q_f16(src_data + 51 * src_step); \ + src[52] = vld1q_f16(src_data + 52 * src_step); \ + src[53] = vld1q_f16(src_data + 53 * src_step); \ + src[54] = vld1q_f16(src_data + 54 * src_step); \ + src[55] = vld1q_f16(src_data + 55 * src_step); \ + src[56] = vld1q_f16(src_data + 56 * src_step); \ + src[57] = vld1q_f16(src_data + 57 * src_step); \ + src[58] = vld1q_f16(src_data + 58 * src_step); \ + src[59] = vld1q_f16(src_data + 59 * src_step); \ + src[60] = vld1q_f16(src_data + 60 * src_step); \ + src[61] = vld1q_f16(src_data + 61 * src_step); \ + src[62] = vld1q_f16(src_data + 62 * src_step); \ + src[63] = vld1q_f16(src_data + 63 * src_step); + +InputTransFp16Func GetInputTransFp16Func(int input_unit); + +void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step); + +void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step); + +void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step); + +OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActType act_type); + +#define Store4DataFp16 \ + vst1q_f16(dst_data, m[0]); \ + vst1q_f16(dst_data + out_c, m[1]); \ + vst1q_f16(dst_data + dst_step * out_c, m[2]); \ + vst1q_f16(dst_data + dst_step * out_c + out_c, m[3]); + +#define Store9DataFp16 \ + vst1q_f16(dst_data, m[0]); \ + vst1q_f16(dst_data + out_c, m[1]); \ + vst1q_f16(dst_data + 2 * out_c, m[2]); \ + vst1q_f16(dst_data + dst_step * out_c, m[3]); \ + vst1q_f16(dst_data + dst_step * out_c + out_c, m[4]); \ + vst1q_f16(dst_data + dst_step * out_c + 2 * out_c, m[5]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c, m[6]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + out_c, m[7]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]); + +#define Store16DataFp16 \ + vst1q_f16(dst_data, m[0]); \ + vst1q_f16(dst_data + out_c, m[1]); \ + vst1q_f16(dst_data + 2 * out_c, m[2]); \ + vst1q_f16(dst_data + 3 * out_c, m[3]); \ + vst1q_f16(dst_data + dst_step * out_c, m[4]); \ + vst1q_f16(dst_data + dst_step * out_c + out_c, m[5]); \ + vst1q_f16(dst_data + dst_step * out_c + 2 * out_c, m[6]); \ + vst1q_f16(dst_data + dst_step * out_c + 3 * out_c, m[7]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c, m[8]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + out_c, m[9]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c, m[12]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + out_c, m[13]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]); + +#define Store25DataFp16 \ + vst1q_f16(dst_data, m[0]); \ + vst1q_f16(dst_data + out_c, m[1]); \ + vst1q_f16(dst_data + 2 * out_c, m[2]); \ + vst1q_f16(dst_data + 3 * out_c, m[3]); \ + vst1q_f16(dst_data + 4 * out_c, m[4]); \ + vst1q_f16(dst_data + dst_step * out_c, m[5]); \ + vst1q_f16(dst_data + dst_step * out_c + out_c, m[6]); \ + vst1q_f16(dst_data + dst_step * out_c + 2 * out_c, m[7]); \ + vst1q_f16(dst_data + dst_step * out_c + 3 * out_c, m[8]); \ + vst1q_f16(dst_data + dst_step * out_c + 4 * out_c, m[9]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c, m[10]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + out_c, m[11]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]); \ + vst1q_f16(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c, m[15]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + out_c, m[16]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[17]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[18]); \ + vst1q_f16(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]); \ + vst1q_f16(dst_data + 4 * dst_step * out_c, m[20]); \ + vst1q_f16(dst_data + 4 * dst_step * out_c + out_c, m[21]); \ + vst1q_f16(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]); \ + vst1q_f16(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \ + vst1q_f16(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]); + +void OutputTransform4x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform4x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); + +void OutputTransform6x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x4Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform6x5Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); + +void OutputTransform8x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x2Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x3Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x4Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x5Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x6Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); +void OutputTransform8x7Relu6UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data, + int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c); #ifdef __cplusplus } #endif diff --git a/mindspore/lite/nnacl/fp32/conv.c b/mindspore/lite/nnacl/fp32/conv.c index 2778ec8f85..cb65c3159c 100644 --- a/mindspore/lite/nnacl/fp32/conv.c +++ b/mindspore/lite/nnacl/fp32/conv.c @@ -41,8 +41,7 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons #endif int output_tile_count = UP_DIV(output_count, cal_num); int kernel_plane = kernel_h * kernel_w; - int unit_size = kernel_plane * in_channel; - int deep = in_channel * kernel_plane; + int deep = kernel_plane * in_channel; for (int b = 0; b < in_batch; b++) { int in_batch_offset = b * in_channel * in_h * in_w; @@ -50,9 +49,9 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) { int start_index = thread_id * cal_num; int real_cal_num = (output_count - start_index) < cal_num ? (output_count - start_index) : cal_num; - float *gemm_input = packed_input + task_id * unit_size * cal_num; - float *col_major_gemm_input = col_major_input + task_id * unit_size * cal_num; - size_t packed_input_size = unit_size * cal_num * sizeof(float); + float *gemm_input = packed_input + task_id * deep * cal_num; + float *col_major_gemm_input = col_major_input + task_id * deep * cal_num; + size_t packed_input_size = deep * cal_num * sizeof(float); memset(gemm_input, 0, packed_input_size); memset(col_major_gemm_input, 0, packed_input_size); Im2ColPackUnitFp32(input_data + in_batch_offset, conv_param, gemm_input, real_cal_num, start_index); @@ -95,8 +94,8 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_ float *trans_input = buffer_list[0]; float *gemm_out = buffer_list[1]; - float *tmp_data = buffer_list[3]; - float *col_buffer = buffer_list[4]; + float *tmp_data = buffer_list[2]; + float *col_buffer = buffer_list[3]; int trans_input_offset = tile_num * input_unit_square * ic4 * C4NUM; int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM; int tmp_data_offset = input_unit_square * C4NUM; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc index a7ae4dd4c7..ca61d26254 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc @@ -16,7 +16,6 @@ #include "src/runtime/kernel/arm/fp16/convolution_fp16.h" #include -#include "src/runtime/kernel/arm/fp16/convolution_sw_fp16.h" #include "src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h" #include "src/runtime/kernel/arm/fp16/convolution_3x3_fp16.h" #include "src/runtime/kernel/arm/fp16/convolution_1x1_fp16.h" @@ -203,19 +202,13 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector & auto conv_param = reinterpret_cast(opParameter); int kernel_h = conv_param->kernel_h_; int kernel_w = conv_param->kernel_w_; - int stride_h = conv_param->stride_h_; - int stride_w = conv_param->stride_w_; - int dilation_h = conv_param->dilation_h_; - int dilation_w = conv_param->dilation_w_; conv_param->input_h_ = inputs.front()->Height(); conv_param->input_w_ = inputs.front()->Width(); conv_param->output_h_ = outputs.front()->Height(); conv_param->output_w_ = outputs.front()->Width(); kernel::LiteKernel *kernel = nullptr; - if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { - kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); - } else if (kernel_h == 1 && kernel_w == 1) { + if (kernel_h == 1 && kernel_w == 1) { kernel = new (std::nothrow) kernel::Convolution1x1FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive); } else { bool use_winograd = false; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_sw_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_sw_fp16.cc deleted file mode 100644 index bdbcc3a182..0000000000 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_sw_fp16.cc +++ /dev/null @@ -1,236 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "src/runtime/kernel/arm/fp16/convolution_sw_fp16.h" -#include -#include "nnacl/fp16/conv_fp16.h" -#include "nnacl/fp16/cast_fp16.h" -#include "nnacl/fp16/pack_fp16.h" -#include "nnacl/fp32/conv_depthwise.h" -#include "src/runtime/kernel/arm/fp16/layout_transform_fp16.h" -#include "schema/model_generated.h" -#include "src/kernel_registry.h" -#include "include/errorcode.h" -#include "src/runtime/runtime_api.h" - -using mindspore::kernel::KERNEL_ARCH::kCPU; -using mindspore::lite::KernelRegistrar; -using mindspore::lite::RET_ERROR; -using mindspore::lite::RET_OK; -using mindspore::schema::PrimitiveType_Conv2D; - -namespace mindspore::kernel { -int ConvolutionSWFP16CPUKernel::ProcessFilter() { - int kernel_h = conv_param_->kernel_h_; - int kernel_w = conv_param_->kernel_w_; - int in_channel = conv_param_->input_channel_; - int out_channel = conv_param_->output_channel_; - int ic4 = UP_DIV(in_channel, C4NUM); - - auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Get Execute filter failed."; - return ret; - } - - for (int oc = 0; oc < out_channel; ++oc) { - int src_oc_offset = oc * kernel_h * kernel_w * in_channel; - int dst_oc_offset = oc * kernel_h * kernel_w * ic4 * C4NUM; - for (int i = 0; i < kernel_h * kernel_w; ++i) { - const float16_t *src = execute_weight_ + src_oc_offset + i * in_channel; - float16_t *dst = packed_weight_ + dst_oc_offset + i * ic4 * C4NUM; - memcpy(dst, src, in_channel * sizeof(float16_t)); - } - } - - return RET_OK; -} - -int ConvolutionSWFP16CPUKernel::InitWeightBias() { - auto filter_tensor = in_tensors_.at(kWeightIndex); - int kernel_h = filter_tensor->Height(); - int kernel_w = filter_tensor->Width(); - int in_channel = filter_tensor->Channel(); - int out_channel = filter_tensor->Batch(); - conv_param_->input_channel_ = in_channel; - conv_param_->output_channel_ = out_channel; - int oc4 = UP_DIV(out_channel, C4NUM); - int ic4 = UP_DIV(in_channel, C4NUM); - int kernel_plane = kernel_h * kernel_w; - int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * kernel_plane; - - packed_weight_ = reinterpret_cast(malloc(pack_weight_size * sizeof(float16_t))); - if (packed_weight_ == nullptr) { - MS_LOG(ERROR) << "malloc packed_weight_ failed."; - return RET_ERROR; - } - memset(packed_weight_, 0, pack_weight_size * sizeof(float16_t)); - auto ret = ProcessFilter(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Process filter failed."; - return ret; - } - - bias_data_ = malloc(oc4 * C4NUM * sizeof(float16_t)); - if (bias_data_ == nullptr) { - MS_LOG(ERROR) << "malloc bias_data_ failed."; - return RET_ERROR; - } - memset(bias_data_, 0, oc4 * C4NUM * sizeof(float16_t)); - auto fp16_bias_data = reinterpret_cast(bias_data_); - if (in_tensors_.size() == kInputSize2) { - auto ori_bias = reinterpret_cast(in_tensors_.at(kBiasIndex)->MutableData()); - for (int i = 0; i < out_channel; ++i) { - fp16_bias_data[i] = (float16_t)ori_bias[i]; - } - } else { - MS_ASSERT(in_tensor_.size() == kInputSize1); - } - return RET_OK; -} - -int ConvolutionSWFP16CPUKernel::InitTmpBuffer() { - int out_channel = conv_param_->output_channel_; - int oc4 = UP_DIV(out_channel, C4NUM); - - int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM); - size_t nhwc4_input_size = - ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t); - nhwc4_input_ = ctx_->allocator->Malloc(nhwc4_input_size); - if (nhwc4_input_ == nullptr) { - MS_LOG(ERROR) << "malloc nhwc4_input_ failed."; - return RET_ERROR; - } - - tmp_output_block_ = reinterpret_cast(ctx_->allocator->Malloc( - conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * oc4 * C4NUM * sizeof(float16_t))); - if (tmp_output_block_ == nullptr) { - MS_LOG(ERROR) << "malloc tmp_output_block_ failed."; - return RET_ERROR; - } - return RET_OK; -} - -void ConvolutionSWFP16CPUKernel::ConfigInputOutput() { - auto input_tensor = in_tensors_.at(kInputIndex); - auto input_format = input_tensor->GetFormat(); - schema::Format execute_format = schema::Format::Format_NHWC4; - convert_func_ = LayoutTransformFp16(input_format, execute_format); - if (convert_func_ == nullptr) { - MS_LOG(ERROR) << "layout convert func is nullptr."; - return; - } -} - -int ConvolutionSWFP16CPUKernel::Init() { - auto ret = InitWeightBias(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Init weight bias failed."; - return RET_ERROR; - } - if (!InferShapeDone()) { - return RET_OK; - } - ConfigInputOutput(); - return ReSize(); -} - -int ConvolutionSWFP16CPUKernel::ReSize() { - auto ret = ConvolutionBaseCPUKernel::CheckResizeValid(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Resize is invalid."; - return ret; - } - - if (slidingWindow_param_ != nullptr) { - delete slidingWindow_param_; - slidingWindow_param_ = nullptr; - } - - ret = ConvolutionBaseCPUKernel::Init(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "ConvolutionBase init fail!ret: " << ret; - return ret; - } - - // init sliding window param - slidingWindow_param_ = new (std::nothrow) SlidingWindowParam; - if (slidingWindow_param_ == nullptr) { - MS_LOG(ERROR) << "new SlidingWindowParam fail!"; - return RET_ERROR; - } - InitSlidingParamConv(slidingWindow_param_, conv_param_, C4NUM); - return RET_OK; -} - -int ConvolutionSWFP16CPUKernel::RunImpl(int task_id) { - ConvSWFp16(reinterpret_cast(nhwc4_input_), packed_weight_, reinterpret_cast(bias_data_), - tmp_output_block_, execute_output_, task_id, conv_param_, slidingWindow_param_); - return RET_OK; -} - -static int ConvolutionSWFp16Impl(void *cdata, int task_id) { - auto conv = reinterpret_cast(cdata); - auto error_code = conv->RunImpl(task_id); - if (error_code != RET_OK) { - MS_LOG(ERROR) << "ConvolutionFp16 Run error task_id[" << task_id << "] error_code[" << error_code << "]"; - return RET_ERROR; - } - return RET_OK; -} - -int ConvolutionSWFP16CPUKernel::Run() { - auto ret = Prepare(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Prepare failed."; - return RET_ERROR; - } - ret = ConvolutionBaseFP16CPUKernel::GetExecuteTensor(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Get Execute tensor failed."; - return ret; - } - ret = InitTmpBuffer(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Init tmp buffer failed."; - return RET_ERROR; - } - - int in_batch = conv_param_->input_batch_; - int in_h = conv_param_->input_h_; - int in_w = conv_param_->input_w_; - int in_channel = conv_param_->input_channel_; - convert_func_(reinterpret_cast(execute_input_), nhwc4_input_, in_batch, in_h * in_w, in_channel); - - int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionSWFp16Impl, this, thread_count_); - if (error_code != RET_OK) { - MS_LOG(ERROR) << "conv fp16 error error_code[" << error_code << "]"; - FreeTmpBuffer(); - return RET_ERROR; - } - - // output nhwc4 - int oc4_res = conv_param_->output_channel_ % C4NUM; - if (oc4_res != 0) { - PackNHWC4ToNHWCFp16(reinterpret_cast(tmp_output_block_), reinterpret_cast(execute_output_), - conv_param_->output_batch_, conv_param_->output_h_ * conv_param_->output_w_, - conv_param_->output_channel_); - } - ConvolutionBaseFP16CPUKernel::IfCastOutput(); - ConvolutionBaseFP16CPUKernel::FreeTmpBuffer(); - FreeTmpBuffer(); - return RET_OK; -} -} // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_sw_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_sw_fp16.h deleted file mode 100644 index a0b3680f20..0000000000 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_sw_fp16.h +++ /dev/null @@ -1,72 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_SW_FP16_H_ -#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_SW_FP16_H_ - -#include -#include -#include "src/lite_kernel.h" -#include "src/runtime/kernel/arm/fp16/convolution_base_fp16.h" - -namespace mindspore::kernel { -class ConvolutionSWFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { - public: - ConvolutionSWFP16CPUKernel(OpParameter *parameter, const std::vector &inputs, - const std::vector &outputs, const InnerContext *ctx, - const mindspore::lite::PrimitiveC *primitive) - : ConvolutionBaseFP16CPUKernel(parameter, inputs, outputs, ctx, primitive) {} - ~ConvolutionSWFP16CPUKernel() override { - if (fp16_weight_ != nullptr) { - free(fp16_weight_); - fp16_weight_ = nullptr; - } - if (packed_weight_ != nullptr) { - free(packed_weight_); - packed_weight_ = nullptr; - } - if (slidingWindow_param_ != nullptr) { - delete slidingWindow_param_; - slidingWindow_param_ = nullptr; - } - } - - int Init() override; - int ReSize() override; - int Run() override; - int RunImpl(int task_id); - int InitWeightBias(); - int InitTmpBuffer(); - void ConfigInputOutput(); - int ProcessFilter(); - - private: - void FreeTmpBuffer() { - if (nhwc4_input_ != nullptr) { - ctx_->allocator->Free(nhwc4_input_); - nhwc4_input_ = nullptr; - } - if (tmp_output_block_ != nullptr) { - ctx_->allocator->Free(tmp_output_block_); - tmp_output_block_ = nullptr; - } - } - float16_t *packed_weight_ = nullptr; - float16_t *tmp_output_block_ = nullptr; - SlidingWindowParam *slidingWindow_param_ = nullptr; -}; -} // namespace mindspore::kernel - -#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP16_CONVOLUTION_SW_FP16_H_ diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc index fd3b3fd77f..c3a0f9f307 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc @@ -138,50 +138,6 @@ int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_ return RET_OK; } -int ConvolutionWinogradFP16CPUKernel::MallocTransformMatrices() { - matrix_a_ = reinterpret_cast(malloc(input_unit_ * output_unit_ * sizeof(float16_t))); - if (matrix_a_ == nullptr) { - MS_LOG(ERROR) << "malloc matrix_a_ failed."; - return RET_ERROR; - } - matrix_at_ = reinterpret_cast(malloc(input_unit_ * output_unit_ * sizeof(float16_t))); - if (matrix_at_ == nullptr) { - MS_LOG(ERROR) << "malloc matrix_at_ failed."; - return RET_ERROR; - } - matrix_b_ = reinterpret_cast(malloc(input_unit_ * input_unit_ * sizeof(float16_t))); - if (matrix_b_ == nullptr) { - MS_LOG(ERROR) << "malloc matrix_b_ failed."; - return RET_ERROR; - } - matrix_bt_ = reinterpret_cast(malloc(input_unit_ * input_unit_ * sizeof(float16_t))); - if (matrix_bt_ == nullptr) { - MS_LOG(ERROR) << "malloc matrix_bt_ failed."; - return RET_ERROR; - } - return RET_OK; -} - -void ConvolutionWinogradFP16CPUKernel::FreeTransformMatrices() { - if (matrix_a_ != nullptr) { - free(matrix_a_); - matrix_a_ = nullptr; - } - if (matrix_at_ != nullptr) { - free(matrix_at_); - matrix_at_ = nullptr; - } - if (matrix_b_ != nullptr) { - free(matrix_b_); - matrix_b_ = nullptr; - } - if (matrix_bt_ != nullptr) { - free(matrix_bt_); - matrix_bt_ = nullptr; - } - return; -} - int ConvolutionWinogradFP16CPUKernel::InitWeightBias() { auto filter_tensor = in_tensors_.at(kWeightIndex); int in_channel = filter_tensor->Channel(); @@ -190,9 +146,8 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() { conv_param_->input_channel_ = in_channel; conv_param_->output_channel_ = out_channel; - int oc_block, oc_block_num; - oc_block = C8NUM; - oc_block_num = UP_DIV(out_channel, C8NUM); + const int oc_block = C8NUM; + int oc_block_num = UP_DIV(out_channel, C8NUM); // init weight auto ret = ConvolutionBaseFP16CPUKernel::GetExecuteFilter(); @@ -209,49 +164,24 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() { return RET_ERROR; } memset(trans_weight_, 0, trans_matrix_data_size); - auto *matrix_g = reinterpret_cast(malloc(input_unit_ * kernel_unit_ * sizeof(float))); - if (matrix_g == nullptr) { - MS_LOG(ERROR) << "malloc matrix_g failed."; - return RET_ERROR; - } - auto matrix_gt = reinterpret_cast(malloc(input_unit_ * kernel_unit_ * sizeof(float))); - if (matrix_gt == nullptr) { - free(matrix_g); - MS_LOG(ERROR) << "malloc matrix_gt failed."; - return RET_ERROR; - } - ret = MallocTransformMatrices(); - if (ret != RET_OK) { - free(matrix_g); - free(matrix_gt); - MS_LOG(ERROR) << "Malloc transform matrices failed."; - return ret; - } - float matrix_a[MAX_LEN]; - float matrix_at[MAX_LEN]; - float matrix_b[MAX_LEN]; - float matrix_bt[MAX_LEN]; - ret = CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, 0.5f, output_unit_, kernel_unit_); + float matrix_g[64]; + float matrix_gt[64]; + float matrix_a[64]; + float matrix_at[64]; + float matrix_b[64]; + float matrix_bt[64]; + float coef = 1.0f; + if (input_unit_ == 8) { + coef = 0.5f; + } + ret = CookToomFilter(matrix_a, matrix_at, matrix_b, matrix_bt, matrix_g, matrix_gt, coef, output_unit_, kernel_unit_); if (ret != RET_OK) { - free(matrix_g); - free(matrix_gt); MS_LOG(ERROR) << "get matrix g from CookToomFilter failed."; return ret; } - Float32ToFloat16(matrix_a, matrix_a_, input_unit_ * output_unit_); - Float32ToFloat16(matrix_at, matrix_at_, input_unit_ * output_unit_); - Float32ToFloat16(matrix_b, matrix_b_, input_unit_ * input_unit_); - Float32ToFloat16(matrix_bt, matrix_bt_, input_unit_ * input_unit_); - matrices_[0] = matrix_a_; - matrices_[1] = matrix_at_; - matrices_[2] = matrix_b_; - matrices_[3] = matrix_bt_; - ret = WinogradFilterTransformFp16(execute_weight_, matrix_g, matrix_gt, oc_block); if (ret != RET_OK) { - free(matrix_g); - free(matrix_gt); MS_LOG(ERROR) << "winograd filter transfrom failed."; return ret; } @@ -259,8 +189,6 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() { // init bias bias_data_ = malloc(oc_block_num * oc_block * sizeof(float16_t)); if (bias_data_ == nullptr) { - free(matrix_g); - free(matrix_gt); MS_LOG(ERROR) << "malloc bias_data_ failed."; return RET_ERROR; } @@ -274,27 +202,15 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() { } else { MS_ASSERT(inputs_.size() == kInputSize1); } - free(matrix_g); - free(matrix_gt); return RET_OK; } int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() { - int cal_num = 16; + const int cal_num = 16; int channel_out = conv_param_->output_channel_; - int output_h = conv_param_->output_h_; - int output_w = conv_param_->output_w_; int oc8 = UP_DIV(channel_out, C8NUM); int ic8 = UP_DIV(conv_param_->input_channel_, C8NUM); - size_t nhwc8_input_size = - ic8 * C8NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float16_t); - nhwc4_input_ = ctx_->allocator->Malloc(nhwc8_input_size); - if (nhwc4_input_ == nullptr) { - MS_LOG(ERROR) << "malloc nhwc4_input_ failed."; - return RET_ERROR; - } - size_t tile_buffer_size = thread_count_ * cal_num * input_unit_ * input_unit_ * ic8 * C8NUM * sizeof(float16_t); trans_input_ = reinterpret_cast(ctx_->allocator->Malloc(tile_buffer_size)); if (trans_input_ == nullptr) { @@ -309,16 +225,6 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() { return RET_ERROR; } - int out_w_block = UP_DIV(output_w, output_unit_); - int out_h_block = UP_DIV(output_h, output_unit_); - tmp_out_data_ = reinterpret_cast( - ctx_->allocator->Malloc(conv_param_->output_batch_ * out_w_block * out_h_block * output_unit_ * output_unit_ * oc8 * - C8NUM * sizeof(float16_t))); - if (tmp_out_data_ == nullptr) { - MS_LOG(ERROR) << "malloc tmp_out_data_ failed."; - return RET_ERROR; - } - tmp_data_ = reinterpret_cast( ctx_->allocator->Malloc(thread_count_ * C8NUM * input_unit_ * input_unit_ * sizeof(float16_t))); if (tmp_data_ == nullptr) { @@ -328,14 +234,21 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() { tmp_buffer_address_list_[0] = trans_input_; tmp_buffer_address_list_[1] = gemm_out_; - tmp_buffer_address_list_[2] = tmp_out_data_; - tmp_buffer_address_list_[3] = tmp_data_; + tmp_buffer_address_list_[2] = tmp_data_; return RET_OK; } int ConvolutionWinogradFP16CPUKernel::ConfigInputOutput() { - auto output_tensor = out_tensors_.at(kOutputIndex); - output_tensor->SetFormat(schema::Format_NHWC); + in_func_ = GetInputTransFp16Func(input_unit_); + if (in_func_ == nullptr) { + MS_LOG(ERROR) << "in_func_ is null."; + return RET_ERROR; + } + out_func_ = GetOutputTransFp16Func(input_unit_, output_unit_, conv_param_->act_type_); + if (out_func_ == nullptr) { + MS_LOG(ERROR) << "out_func_ is null."; + return RET_ERROR; + } return RET_OK; } @@ -381,9 +294,8 @@ int ConvolutionWinogradFP16CPUKernel::ReSize() { } int ConvolutionWinogradFP16CPUKernel::RunImpl(int task_id) { - ConvWinogardFp16(reinterpret_cast(nhwc4_input_), trans_weight_, - reinterpret_cast(bias_data_), tmp_buffer_address_list_, task_id, conv_param_, - matrices_); + ConvWinogardFp16(execute_input_, trans_weight_, reinterpret_cast(bias_data_), execute_output_, + tmp_buffer_address_list_, task_id, conv_param_, in_func_, out_func_); return RET_OK; } @@ -397,28 +309,6 @@ static int ConvolutionWinogradFp16Impl(void *cdata, int task_id) { return RET_OK; } -int ConvolutionWinogradFP16CPUKernel::PostProcess() { - auto act_type = conv_param_->act_type_; - switch (act_type) { - case ActType_No: - UnPackWinogradOutputFp16(tmp_out_data_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_, - conv_param_->output_w_, conv_param_->output_channel_, output_unit_); - break; - case ActType_Relu: - UnPackWinogradReluOutputFp16(tmp_out_data_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_, - conv_param_->output_w_, conv_param_->output_channel_, output_unit_); - break; - case ActType_Relu6: - UnPackWinogradRelu6OutputFp16(tmp_out_data_, execute_output_, conv_param_->output_batch_, conv_param_->output_h_, - conv_param_->output_w_, conv_param_->output_channel_, output_unit_); - break; - default: - MS_LOG(ERROR) << "Unsupport activation type."; - return RET_ERROR; - } - return RET_OK; -} - int ConvolutionWinogradFP16CPUKernel::Run() { auto prepare_ret = Prepare(); if (prepare_ret != RET_OK) { @@ -438,12 +328,6 @@ int ConvolutionWinogradFP16CPUKernel::Run() { return RET_ERROR; } - int in_batch = conv_param_->input_batch_; - int in_h = conv_param_->input_h_; - int in_w = conv_param_->input_w_; - int in_channel = conv_param_->input_channel_; - PackNHWCToNHWC8Fp16(execute_input_, nhwc4_input_, in_batch, in_h * in_w, in_channel); - int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionWinogradFp16Impl, this, thread_count_); if (error_code != RET_OK) { MS_LOG(ERROR) << "conv winograd error error_code[" << error_code << "]"; @@ -451,11 +335,6 @@ int ConvolutionWinogradFP16CPUKernel::Run() { return RET_ERROR; } - ret = PostProcess(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Post process failed."; - return ret; - } ConvolutionBaseFP16CPUKernel::IfCastOutput(); ConvolutionBaseFP16CPUKernel::FreeTmpBuffer(); FreeTmpBuffer(); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h index 5d74715dd5..f4c0e4ddb2 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h @@ -42,7 +42,6 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { free(trans_weight_); trans_weight_ = nullptr; } - FreeTransformMatrices(); } int Init() override; @@ -50,19 +49,12 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { int Run() override; int RunImpl(int task_id); int InitWeightBias(); - int MallocTransformMatrices(); - void FreeTransformMatrices(); int InitTmpBuffer(); int ConfigInputOutput(); - int PostProcess(); int WinogradFilterTransformFp16(const float16_t *weight_data, float *matrix_g, float *matrix_gt, int oc_block); private: void FreeTmpBuffer() { - if (nhwc4_input_ != nullptr) { - ctx_->allocator->Free(nhwc4_input_); - nhwc4_input_ = nullptr; - } if (trans_input_ != nullptr) { ctx_->allocator->Free(trans_input_); trans_input_ = nullptr; @@ -75,10 +67,6 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { ctx_->allocator->Free(gemm_out_); gemm_out_ = nullptr; } - if (tmp_out_data_ != nullptr) { - ctx_->allocator->Free(tmp_out_data_); - tmp_out_data_ = nullptr; - } } int kernel_unit_; int input_unit_; @@ -86,14 +74,10 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { float16_t *tmp_data_ = nullptr; float16_t *trans_input_ = nullptr; float16_t *gemm_out_ = nullptr; - float16_t *tmp_out_data_ = nullptr; - float16_t *matrix_a_ = nullptr; - float16_t *matrix_at_ = nullptr; - float16_t *matrix_b_ = nullptr; - float16_t *matrix_bt_ = nullptr; float16_t *trans_weight_ = nullptr; - TmpBufferAddressFp16 tmp_buffer_address_list_[4]; - MatricesFp16 matrices_[4]; + TmpBufferAddressFp16 tmp_buffer_address_list_[3]; + InputTransFp16Func in_func_; + OutputTransFp16Func out_func_; }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc index e7fb418f47..e0f483d77e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.cc @@ -117,9 +117,8 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() { conv_param_->output_channel_ = out_channel; int oc4 = UP_DIV(out_channel, C4NUM); - int oc_block, oc_block_num; - oc_block = C8NUM; - oc_block_num = UP_DIV(out_channel, C8NUM); + const int oc_block = C8NUM; + int oc_block_num = UP_DIV(out_channel, C8NUM); // set data auto trans_matrix_data_size = input_unit_ * input_unit_ * ic4 * C4NUM * oc_block_num * oc_block * sizeof(float); @@ -172,9 +171,6 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() { int ConvolutionWinogradCPUKernel::InitTmpBuffer() { int channel_out = conv_param_->output_channel_; - int output_h = conv_param_->output_h_; - int output_w = conv_param_->output_w_; - int oc4 = UP_DIV(channel_out, C4NUM); int oc8 = UP_DIV(channel_out, C8NUM); int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM); #ifdef ENABLE_ARM32 @@ -198,16 +194,6 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() { return RET_ERROR; } - int out_w_block = UP_DIV(output_w, output_unit_); - int out_h_block = UP_DIV(output_h, output_unit_); - tmp_out_data_ = - reinterpret_cast(ctx_->allocator->Malloc(conv_param_->output_batch_ * out_w_block * out_h_block * - output_unit_ * output_unit_ * oc4 * C4NUM * sizeof(float))); - if (tmp_out_data_ == nullptr) { - MS_LOG(ERROR) << "malloc tmp_out_data_ failed."; - return RET_MEMORY_FAILED; - } - tmp_data_ = reinterpret_cast( ctx_->allocator->Malloc(thread_count_ * C4NUM * input_unit_ * input_unit_ * sizeof(float))); if (tmp_data_ == nullptr) { @@ -224,16 +210,12 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() { tmp_buffer_address_list_[0] = trans_input_; tmp_buffer_address_list_[1] = gemm_out_; - tmp_buffer_address_list_[2] = tmp_out_data_; - tmp_buffer_address_list_[3] = tmp_data_; - tmp_buffer_address_list_[4] = col_buffer_; + tmp_buffer_address_list_[2] = tmp_data_; + tmp_buffer_address_list_[3] = col_buffer_; return RET_OK; } int ConvolutionWinogradCPUKernel::ConfigInputOutput() { - auto output_tensor = out_tensors_.at(kOutputIndex); - output_tensor->SetFormat(schema::Format::Format_NHWC); - in_func_ = GetInputTransFunc(input_unit_); if (in_func_ == nullptr) { MS_LOG(ERROR) << "in_func_ is null."; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.h index 319aaf1c4e..ee22d8bff0 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd.h @@ -61,10 +61,6 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel { ctx_->allocator->Free(gemm_out_); gemm_out_ = nullptr; } - if (tmp_out_data_ != nullptr) { - ctx_->allocator->Free(tmp_out_data_); - tmp_out_data_ = nullptr; - } if (col_buffer_ != nullptr) { ctx_->allocator->Free(col_buffer_); col_buffer_ = nullptr; @@ -76,10 +72,9 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel { float *tmp_data_ = nullptr; float *trans_input_ = nullptr; float *gemm_out_ = nullptr; - float *tmp_out_data_ = nullptr; float *col_buffer_ = nullptr; float *trans_weight_ = nullptr; - TmpBufferAddress tmp_buffer_address_list_[5]; + TmpBufferAddress tmp_buffer_address_list_[4]; InputTransFunc in_func_; OutputTransFunc out_func_; };