From aec6dfd51396c3d4ca3561a6d1abc5a7c3d3761a Mon Sep 17 00:00:00 2001 From: lixian Date: Mon, 15 Mar 2021 09:06:07 +0800 Subject: [PATCH] add 1d f(2,3) support for 3x3 dw conv --- .../lite/nnacl/fp32/conv_depthwise_fp32.c | 558 +++++++++++------- .../lite/nnacl/fp32/conv_depthwise_fp32.h | 13 +- mindspore/lite/nnacl/fp32/pack_fp32.c | 20 + mindspore/lite/nnacl/fp32/pack_fp32.h | 4 + .../nnacl/intrinsics/ms_simd_instructions.h | 3 - .../arm/fp32/convolution_delegate_fp32.cc | 10 +- .../fp32/convolution_depthwise_3x3_fp32.cc | 87 ++- .../arm/fp32/convolution_depthwise_3x3_fp32.h | 6 +- 8 files changed, 412 insertions(+), 289 deletions(-) diff --git a/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c b/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c index 1fee1230f7..7c2ff0ca63 100644 --- a/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c +++ b/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.c @@ -18,6 +18,7 @@ #include "nnacl/common_func.h" #include "nnacl/fp32/common_func_fp32.h" #include "nnacl/fp32/winograd_transform.h" +#include "nnacl/intrinsics/ms_simd_instructions.h" #ifdef ENABLE_ARM64 #include #endif @@ -337,260 +338,373 @@ bool CheckConvDwUse3X3(const ConvParameter *conv_param) { in_w == (conv_param->input_w_ + 2 * conv_param->pad_l_); } -void ConvDw3x3BorderPixel(float *dst, const float *src, const float *weight, const float *bias, int height, int width, - int in_kh_step, int in_kw_step, int channel, bool relu, bool relu6) { - for (int c = 0; c < channel; c += C4NUM) { - for (int i = 0; i < C4NUM; i++) { - dst[i] = 0; - } - const float *src_kh = src; - const float *weight_kh = weight; - for (int kh = 0; kh < height; kh++) { - const float *src_kw = src_kh; - const float *weight_kw = weight_kh; - for (int kw = 0; kw < width; kw++) { - for (int i = 0; i < C4NUM; i++) { - dst[i] += src_kw[c + i] * weight_kw[c + i]; - } - src_kw += in_kw_step; - weight_kw += channel; - } // kernel_w loop - src_kh += in_kh_step; - weight_kh += 3 * channel; - } // kernel_h loop - for (int i = 0; i < C4NUM; i++) { - dst[i] += bias[c + i]; - dst[i] = (relu) ? (MSMAX(0, dst[i])) : (dst[i]); - dst[i] = (relu6) ? (MSMIN(6, MSMAX(0, dst[i]))) : (dst[i]); +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +bool CheckConvDw1DWinograd(const ConvParameter *conv_param, int thread_num) { + return conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && conv_param->stride_w_ == 1 && + conv_param->stride_h_ == 1 && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && + conv_param->pad_u_ == 1 && conv_param->pad_d_ == 1 && conv_param->pad_l_ == 1 && conv_param->pad_r_ == 1 && + conv_param->input_channel_ == conv_param->output_channel_ && + conv_param->output_h_ / thread_num >= 4; // better had more than 4 rows for each thread +} + +void ConvDw3x3RowLeft(const float *src, float *line, int lw, int channel) { + MS_FLOAT32X4 v0, v1, v2, v3; + v0 = MS_MOVQ_F32(0.0f); + int ic = 0; + for (; ic < channel - 3; ic += 4) { + v1 = MS_LDQ_F32(src + ic); + v2 = MS_LDQ_F32(src + channel + ic); + v3 = MS_LDQ_F32(src + 2 * channel + ic); + MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2); + MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2); + MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1); + MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1); + MS_STQ_F32(line + lw * ic, b0); + MS_STQ_F32(line + lw * ic + 4, b1); + MS_STQ_F32(line + lw * ic + 8, b2); + MS_STQ_F32(line + lw * ic + 12, b3); + } + if (ic < channel) { + float *remain_line = line + ic * lw; + memset(remain_line, 0, 16); + memset(remain_line + 4, 0, 16); + memset(remain_line + 8, 0, 16); + memset(remain_line + 12, 0, 16); + for (int i = 0; i < channel - ic; i++) { + float d1 = src[i + ic]; + float d2 = src[i + ic + channel]; + float d3 = src[i + ic + 2 * channel]; + remain_line[i] = 0.0f - d2; + remain_line[i + 4] = d1 + d2; + remain_line[i + 8] = d2 - d1; + remain_line[i + 12] = d3 - d1; } - dst += C4NUM; } } -#ifndef ENABLE_ARM64 -void ConvDw3x3Corner(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, - int in_kw_step, int channel, bool relu, bool relu6) { - ConvDw3x3BorderPixel(dst, src, weight, bias, 2, 2, in_kh_step, in_kw_step, channel, relu, relu6); -} - -void ConvDw3x3Vertical(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, - int in_kw_step, int channel, bool relu, bool relu6) { - ConvDw3x3BorderPixel(dst, src, weight, bias, 2, 3, in_kh_step, in_kw_step, channel, relu, relu6); -} - -void ConvDw3x3Horizontal(float *dst, const float *src, const float *weight, const float *bias, int in_kh_step, - int in_kw_step, int channel, bool relu, bool relu6) { - ConvDw3x3BorderPixel(dst, src, weight, bias, 3, 2, in_kh_step, in_kw_step, channel, relu, relu6); -} -#endif - -void ConvDw3x3Pad(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, - const ConvParameter *conv_param, const SlidingWindowParam *sliding) { - int input_row_size = conv_param->input_w_ * conv_param->input_channel_; - int weight_row_size = conv_param->kernel_w_ * conv_param->input_channel_; - int output_row_size = conv_param->output_w_ * conv_param->output_channel_; - int in_kh_step = sliding->in_kh_step_; - int in_kw_step = sliding->in_kw_step_; - bool relu = conv_param->act_type_ == ActType_Relu; - bool relu6 = conv_param->act_type_ == ActType_Relu6; - - for (int b = 0; b < conv_param->output_batch_; b++) { - const float *input_batch = - input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; - float *output_batch = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; - // top - const float *input = input_batch; - const float *weight = weight_data + weight_row_size + conv_param->input_channel_; - float *output = output_batch; - ConvDw3x3Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, relu6); - input += (conv_param->stride_w_ - 1) * conv_param->input_channel_; - weight = weight_data + weight_row_size; - output += conv_param->output_channel_; - for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) { - ConvDw3x3Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, - relu6); - input += conv_param->stride_w_ * conv_param->input_channel_; - output += conv_param->output_channel_; - } - ConvDw3x3Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, relu6); - - // left - input = input_batch + (conv_param->stride_h_ - 1) * input_row_size; - weight = weight_data + conv_param->input_channel_; - output = output_batch + output_row_size; - for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) { - ConvDw3x3Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, - relu6); - input += conv_param->stride_h_ * input_row_size; - output += output_row_size; +void ConvDw3x3RowMiddle(const float *src, float *line, int lw, int channel) { + MS_FLOAT32X4 v0, v1, v2, v3; + int ic = 0; + for (; ic < channel - 3; ic += 4) { + v0 = MS_LDQ_F32(src + ic); + v1 = MS_LDQ_F32(src + channel + ic); + v2 = MS_LDQ_F32(src + 2 * channel + ic); + v3 = MS_LDQ_F32(src + 3 * channel + ic); + MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2); + MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2); + MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1); + MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1); + MS_STQ_F32(line + lw * ic, b0); + MS_STQ_F32(line + lw * ic + 4, b1); + MS_STQ_F32(line + lw * ic + 8, b2); + MS_STQ_F32(line + lw * ic + 12, b3); + } + if (ic < channel) { + float *remain_line = line + ic * lw; + memset(remain_line, 0, 16); + memset(remain_line + 4, 0, 16); + memset(remain_line + 8, 0, 16); + memset(remain_line + 12, 0, 16); + for (int i = 0; i < channel - ic; i++) { + float d0 = src[i + ic]; + float d1 = src[i + ic + channel]; + float d2 = src[i + ic + 2 * channel]; + float d3 = src[i + ic + 3 * channel]; + remain_line[i] = d0 - d2; + remain_line[i + 4] = d1 + d2; + remain_line[i + 8] = d2 - d1; + remain_line[i + 12] = d3 - d1; } + } +} - // right - input = input_batch + (conv_param->input_w_ - 2) * conv_param->input_channel_ + - (conv_param->stride_h_ - 1) * input_row_size; - weight = weight_data; - output = output_batch + output_row_size + (conv_param->output_w_ - 1) * conv_param->output_channel_; - for (int out_h = sliding->top_; out_h < sliding->bottom_; out_h++) { - ConvDw3x3Horizontal(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, - relu6); - input += conv_param->stride_h_ * input_row_size; - output += output_row_size; +void ConvDw3x3RowRight(const float *src, float *line, int lw, int channel) { + MS_FLOAT32X4 v0, v1, v2, v3; + int ic = 0; + v3 = MS_MOVQ_F32(0.0f); + for (; ic < channel - 3; ic += 4) { + v0 = MS_LDQ_F32(src + ic); + v1 = MS_LDQ_F32(src + channel + ic); + v2 = MS_LDQ_F32(src + 2 * channel + ic); + MS_FLOAT32X4 b0 = MS_SUBQ_F32(v0, v2); + MS_FLOAT32X4 b1 = MS_ADDQ_F32(v1, v2); + MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1); + MS_FLOAT32X4 b3 = MS_SUBQ_F32(v3, v1); + MS_STQ_F32(line + lw * ic, b0); + MS_STQ_F32(line + lw * ic + 4, b1); + MS_STQ_F32(line + lw * ic + 8, b2); + MS_STQ_F32(line + lw * ic + 12, b3); + } + if (ic < channel) { + float *remain_line = line + ic * lw; + memset(remain_line, 0, 16); + memset(remain_line + 4, 0, 16); + memset(remain_line + 8, 0, 16); + memset(remain_line + 12, 0, 16); + for (int i = 0; i < channel - ic; i++) { + float d0 = src[i + ic]; + float d1 = src[i + ic + channel]; + float d2 = src[i + ic + 2 * channel]; + remain_line[i] = d0 - d2; + remain_line[i + 4] = d1 + d2; + remain_line[i + 8] = d2 - d1; + remain_line[i + 12] = 0.0f - d1; } + } +} - // bottom - input = input_batch + (conv_param->input_h_ - 2) * input_row_size; - weight = weight_data + conv_param->input_channel_; - output = output_batch + (conv_param->output_h_ - 1) * output_row_size; - ConvDw3x3Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, relu6); - input += conv_param->stride_w_ == 1 ? 0 : conv_param->input_channel_; - weight = weight_data; - output += conv_param->output_channel_; - for (int out_w = sliding->left_; out_w < sliding->right_; out_w++) { - ConvDw3x3Vertical(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, - relu6); - input += conv_param->stride_w_ * conv_param->input_channel_; - output += conv_param->output_channel_; +void ConvDw3x3RowSingle(const float *src, float *line, int lw, int channel) { + MS_FLOAT32X4 v0, v1, v2; + int ic = 0; + v2 = MS_MOVQ_F32(0.0f); + for (; ic < channel - 3; ic += 4) { + v0 = MS_LDQ_F32(src + ic); + v1 = MS_LDQ_F32(src + channel + ic); + MS_FLOAT32X4 b2 = MS_SUBQ_F32(v2, v1); + MS_STQ_F32(line + lw * ic, v0); + MS_STQ_F32(line + lw * ic + 4, v1); + MS_STQ_F32(line + lw * ic + 8, b2); + memset(line + lw * ic + 12, 0, 16); + } + if (ic < channel) { + float *remain_line = line + ic * lw; + memset(remain_line, 0, 16); + memset(remain_line + 4, 0, 16); + memset(remain_line + 8, 0, 16); + memset(remain_line + 12, 0, 16); + for (int i = 0; i < channel - ic; i++) { + float d0 = src[i + ic]; + float d1 = src[i + ic + channel]; + remain_line[i] = d0; + remain_line[i + 4] = d1; + remain_line[i + 8] = 0.0f - d1; } - ConvDw3x3Corner(output, input, weight, bias_data, in_kh_step, in_kw_step, conv_param->input_channel_, relu, relu6); } } -void ConvDw3x3InitBuffer(float *buffer, const float *input, const ConvParameter *conv_param, int block_input_h, - int block_input_w) { - for (int h = 0; h < block_input_h; h++) { - const float *src = input; - for (int w = 0; w < block_input_w; w++) { - memcpy(buffer, src, 64 * sizeof(float)); - src += conv_param->input_channel_; - buffer += 64; - } - input += conv_param->input_w_ * conv_param->input_channel_; +void ConvDw3x3InitTop(const float *src, float **lines, int width, int channel) { + float *line0 = lines[0]; + float *line1 = lines[1]; + float *line2 = lines[2]; + int c4 = UP_ROUND(channel, C4NUM); + int lw = UP_DIV(width, C2NUM) * C4NUM; + memset(line0, 0, c4 * lw * sizeof(float)); + ConvDw3x3RowLeft(src, line1, lw, channel); + ConvDw3x3RowLeft(src + width * channel, line2, lw, channel); + int ow = 2; + for (; ow < width - 2; ow += 2) { + ConvDw3x3RowMiddle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + ConvDw3x3RowMiddle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); + } + int remain = width - ow; + if (remain == 2) { + ConvDw3x3RowRight(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + ConvDw3x3RowRight(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); + } else if (remain == 1) { + ConvDw3x3RowSingle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + ConvDw3x3RowSingle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); } } -void ConvDw3x3Window(float *output, const float *buffer, const float *weight, const float *bias, int col_size, - int row_size, int channel, int output_h, int output_w, int stride, bool relu, bool relu6) { - for (int w = 0; w < output_w; w++) { - for (int i = 0; i < C4NUM; i++) { - output[i] = bias[i]; - } - const float *src_kh = buffer; - const float *weight_kh = weight; - for (int kh = 0; kh < 3; kh++) { - const float *src_kw = src_kh; - const float *weight_kw = weight_kh; - for (int kw = 0; kw < 3; kw++) { - for (int c = 0; c < C4NUM; c++) { - output[c] += src_kw[c] * weight_kw[c]; - } - src_kw += col_size; - weight_kw += channel; - } - src_kh += row_size; - weight_kh += 3 * channel; - } - for (int i = 0; i < C4NUM; i++) { - output[i] = (relu) ? (MSMAX(0, output[i])) : (output[i]); - output[i] = (relu6) ? (MSMIN(6, MSMAX(0, output[i]))) : (output[i]); - } - output += channel; - buffer += col_size * stride; +void ConvDw3x3InitRow(const float *src, float **lines, int width, int channel) { + float *line0 = lines[0]; + float *line1 = lines[1]; + float *line2 = lines[2]; + int lw = UP_DIV(width, C2NUM) * C4NUM; + ConvDw3x3RowLeft(src - width * channel, line0, lw, channel); + ConvDw3x3RowLeft(src, line1, lw, channel); + ConvDw3x3RowLeft(src + width * channel, line2, lw, channel); + int ow = 2; + for (; ow < width - 2; ow += 2) { + ConvDw3x3RowMiddle(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel); + ConvDw3x3RowMiddle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + ConvDw3x3RowMiddle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); + } + int remain = width - ow; + if (remain == 2) { + ConvDw3x3RowRight(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel); + ConvDw3x3RowRight(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + ConvDw3x3RowRight(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); + } else if (remain == 1) { + ConvDw3x3RowSingle(src - width * channel + (ow - 1) * channel, line0 + 2 * ow * 4, lw, channel); + ConvDw3x3RowSingle(src + (ow - 1) * channel, line1 + 2 * ow * 4, lw, channel); + ConvDw3x3RowSingle(src + width * channel + (ow - 1) * channel, line2 + 2 * ow * 4, lw, channel); } } -void ConvDw3x3Block(float *output, const float *buffer, const float *weight, const float *bias, int start_c, int end_c, - int col_size, int row_size, int channel, int output_h, int output_w, int stride, bool relu, - bool relu6) { - for (; start_c <= end_c - C4NUM; start_c += C4NUM) { -#ifdef ENABLE_ARM64 - if (stride == 1) { - ConvDw3x3Stride1(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, relu, relu6); - } else { - ConvDw3x3Stride2(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, relu, relu6); - } -#else - ConvDw3x3Window(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, stride, relu, relu6); -#endif - output += C4NUM; - buffer += C4NUM; - weight += C4NUM; - bias += C4NUM; +void ConvDw3x3Row(const float *src, float **lines, int width, int channel) { + float *tmp = lines[0]; + lines[0] = lines[1]; + lines[1] = lines[2]; + lines[2] = tmp; + int c4 = UP_ROUND(channel, C4NUM); + int lw = UP_DIV(width, C2NUM) * C4NUM; + memset(tmp, 0, c4 * lw * sizeof(float)); + ConvDw3x3RowLeft(src, tmp, lw, channel); + int ow = 2; + for (; ow < width - 2; ow += 2) { + ConvDw3x3RowMiddle(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel); + } + int remain = width - ow; + if (remain == 2) { + ConvDw3x3RowRight(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel); + } else if (remain == 1) { + ConvDw3x3RowSingle(src + (ow - 1) * channel, tmp + 2 * ow * 4, lw, channel); } } -void ConvDw3x3Row(float *output, float *buffer, const float *input, const float *weight, const float *bias, - const ConvParameter *conv_param, int start_w, int end_w, int block_output_h, int block_output_w, - int block_input_h, int block_input_w) { - bool relu = conv_param->act_type_ == ActType_Relu; - bool relu6 = conv_param->act_type_ == ActType_Relu6; - const int ih_offset = 64 * block_input_w; - int w = start_w; - if (conv_param->output_channel_ > 64 || (conv_param->output_channel_ < 64 && conv_param->input_w_ > 150)) { - for (; w <= end_w - block_output_w; w += block_output_w) { - float *output_ptr = output; - const float *input_ptr = input; - const float *weight_ptr = weight; - const float *bias_ptr = bias; - int c = 0; - for (; c <= conv_param->output_channel_ - 64; c += 64) { - ConvDw3x3InitBuffer(buffer, input_ptr, conv_param, block_input_h, block_input_w); - ConvDw3x3Block(output_ptr, buffer, weight_ptr, bias_ptr, 0, 64, 64, ih_offset, conv_param->input_channel_, - block_output_h, block_output_w, conv_param->stride_h_, relu, relu6); - output_ptr += 64; - input_ptr += 64; - weight_ptr += 64; - bias_ptr += 64; +void ConvDw3x3Bottom(float **lines, int width, int channel) { + float *tmp = lines[0]; + lines[0] = lines[1]; + lines[1] = lines[2]; + lines[2] = tmp; + int c4 = UP_ROUND(channel, C4NUM); + memset(tmp, 0, UP_DIV(width, C2NUM) * c4 * C4NUM * sizeof(float)); +} + +void ConvDw3x3Line(float *dst, float **lines, const float *weight, const float *bias_data, int width, int ori_channel, + bool relu, bool relu6) { + int channel = ori_channel; + float *line0 = lines[0]; + float *line1 = lines[1]; + float *line2 = lines[2]; + for (; channel > 0; channel -= 4) { + MS_FLOAT32X4 bias = MS_LDQ_F32(bias_data); + bias_data += 4; + MS_FLOAT32X4 g00 = MS_LDQ_F32(weight); + MS_FLOAT32X4 g01 = MS_LDQ_F32(weight + 4); + MS_FLOAT32X4 g02 = MS_LDQ_F32(weight + 8); + MS_FLOAT32X4 g03 = MS_LDQ_F32(weight + 12); + MS_FLOAT32X4 g10 = MS_LDQ_F32(weight + 16); + MS_FLOAT32X4 g11 = MS_LDQ_F32(weight + 20); + MS_FLOAT32X4 g12 = MS_LDQ_F32(weight + 24); + MS_FLOAT32X4 g13 = MS_LDQ_F32(weight + 28); + MS_FLOAT32X4 g20 = MS_LDQ_F32(weight + 32); + MS_FLOAT32X4 g21 = MS_LDQ_F32(weight + 36); + MS_FLOAT32X4 g22 = MS_LDQ_F32(weight + 40); + MS_FLOAT32X4 g23 = MS_LDQ_F32(weight + 44); + weight += 48; + float *cur_dst = dst; + int ow = 0; + for (; ow < width - 1; ow += 2) { + MS_FLOAT32X4 acc0 = MS_MULQ_F32(MS_LDQ_F32(line0), g00); + MS_FLOAT32X4 acc1 = MS_MULQ_F32(MS_LDQ_F32(line0 + 4), g01); + MS_FLOAT32X4 acc2 = MS_MULQ_F32(MS_LDQ_F32(line0 + 8), g02); + MS_FLOAT32X4 acc3 = MS_MULQ_F32(MS_LDQ_F32(line0 + 12), g03); + line0 += 16; + acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line1), g10); + acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line1 + 4), g11); + acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line1 + 8), g12); + acc3 = MS_MLAQ_F32(acc3, MS_LDQ_F32(line1 + 12), g13); + line1 += 16; + acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line2), g20); + acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line2 + 4), g21); + acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line2 + 8), g22); + acc3 = MS_MLAQ_F32(acc3, MS_LDQ_F32(line2 + 12), g23); + line2 += 16; + MS_FLOAT32X4 res0 = MS_ADDQ_F32(acc0, MS_ADDQ_F32(acc2, acc1)); + MS_FLOAT32X4 res1 = MS_ADDQ_F32(acc1, MS_SUBQ_F32(acc3, acc2)); + res0 = MS_ADDQ_F32(res0, bias); + res1 = MS_ADDQ_F32(res1, bias); + if (relu || relu6) { + res0 = MS_MAXQ_F32(res0, MS_MOVQ_F32(0.0f)); + res1 = MS_MAXQ_F32(res1, MS_MOVQ_F32(0.0f)); + } + if (relu6) { + res0 = MS_MINQ_F32(res0, MS_MOVQ_F32(6.0f)); + res1 = MS_MINQ_F32(res1, MS_MOVQ_F32(6.0f)); } - // left channel - ConvDw3x3Block(output_ptr, input_ptr, weight_ptr, bias_ptr, c, conv_param->input_channel_, - conv_param->input_channel_, conv_param->input_w_ * conv_param->input_channel_, - conv_param->input_channel_, block_output_h, block_output_w, conv_param->stride_h_, relu, relu6); - output += block_output_w * conv_param->input_channel_; - input += conv_param->stride_w_ * block_output_w * conv_param->input_channel_; + if (channel >= 4) { + MS_STQ_F32(cur_dst, res0); + MS_STQ_F32(cur_dst + ori_channel, res1); + } else { + for (int i = 0; i < channel; i++) { + cur_dst[i] = res0[i]; + cur_dst[ori_channel + i] = res1[i]; + } + } + cur_dst += 2 * ori_channel; } - } - // left width - int left_width = end_w - w; - if (left_width > 0) { - ConvDw3x3Block(output, input, weight, bias, 0, conv_param->input_channel_, conv_param->input_channel_, - conv_param->input_w_ * conv_param->input_channel_, conv_param->input_channel_, block_output_h, - left_width, conv_param->stride_h_, relu, relu6); + if (ow < width) { + MS_FLOAT32X4 acc0 = MS_MULQ_F32(MS_LDQ_F32(line0), g00); + MS_FLOAT32X4 acc1 = MS_MULQ_F32(MS_LDQ_F32(line0 + 4), g01); + MS_FLOAT32X4 acc2 = MS_MULQ_F32(MS_LDQ_F32(line0 + 8), g02); + line0 += 16; + acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line1), g10); + acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line1 + 4), g11); + acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line1 + 8), g12); + line1 += 16; + acc0 = MS_MLAQ_F32(acc0, MS_LDQ_F32(line2), g20); + acc1 = MS_MLAQ_F32(acc1, MS_LDQ_F32(line2 + 4), g21); + acc2 = MS_MLAQ_F32(acc2, MS_LDQ_F32(line2 + 8), g22); + line2 += 16; + MS_FLOAT32X4 res0 = MS_ADDQ_F32(acc0, MS_ADDQ_F32(acc2, acc1)); + res0 = MS_ADDQ_F32(res0, bias); + if (relu || relu6) { + res0 = MS_MAXQ_F32(res0, MS_MOVQ_F32(0.0f)); + } + if (relu6) { + res0 = MS_MINQ_F32(res0, MS_MOVQ_F32(6.0f)); + } + if (channel >= 4) { + MS_STQ_F32(cur_dst, res0); + } else { + for (int i = 0; i < channel; i++) { + cur_dst[i] = res0[i]; + } + } + } + dst += 4; } } void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data, - const float *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, - int task_id) { - int output_h = sliding->bottom_ - sliding->top_; - int step_oh = UP_DIV(output_h, conv_param->thread_num_); - int start_oh = step_oh * task_id + sliding->top_; - int end_oh = MSMIN(start_oh + step_oh, sliding->bottom_); - int start_ow = sliding->left_; - int end_ow = sliding->right_; - - const int block_output_h = 1; - int block_output_w = conv_param->stride_w_ == 1 ? 30 : 14; - const int block_input_h = 3; - int block_input_w = conv_param->stride_w_ * (block_output_w - 1) + 3; + const float *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh) { + int units = UP_DIV(conv_param->output_w_, C2NUM); + int c4 = UP_ROUND(conv_param->input_channel_, C4NUM); + int line = conv_param->input_channel_ * conv_param->input_w_; + + bool relu = conv_param->act_type_ == ActType_Relu; + bool relu6 = conv_param->act_type_ == ActType_Relu6; for (int b = 0; b < conv_param->output_batch_; b++) { - int start_ih = start_oh * conv_param->stride_h_ - conv_param->pad_u_; - int start_iw = start_ow * conv_param->stride_w_ - conv_param->pad_l_; - const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_ + - start_ih * conv_param->input_w_ * conv_param->input_channel_ + - start_iw * conv_param->input_channel_; - float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_ + - start_oh * conv_param->output_w_ * conv_param->output_channel_ + - start_ow * conv_param->output_channel_; - - for (int oh = start_oh; oh < end_oh; oh++) { - ConvDw3x3Row(dst, buffer, src, weight_data, bias_data, conv_param, start_ow, end_ow, block_output_h, - block_output_w, block_input_h, block_input_w); - src += conv_param->stride_h_ * conv_param->input_w_ * conv_param->input_channel_; - dst += conv_param->output_w_ * conv_param->output_channel_; + const float *src = input_data + b * conv_param->input_h_ * conv_param->input_w_ * conv_param->input_channel_; + float *dst = output_data + b * conv_param->output_h_ * conv_param->output_w_ * conv_param->output_channel_; + float *line0 = buffer; + float *line1 = buffer + units * c4 * C4NUM; + float *line2 = buffer + units * c4 * C8NUM; + float *lines[3] = {line0, line1, line2}; + int oh = start_oh; + if (oh == 0) { + // input trans + ConvDw3x3InitTop(src, lines, conv_param->output_w_, conv_param->input_channel_); + } else { + // input trans + ConvDw3x3InitRow(src + oh * line, lines, conv_param->output_w_, conv_param->input_channel_); + } + // dst calc and trans + ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_, + relu, relu6); + for (oh = start_oh + 1; oh < end_oh - 1; oh++) { + // input trans + ConvDw3x3Row(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_); + // dst calc and trans + ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_, + relu, relu6); } + if (oh == conv_param->output_h_ - 1) { + // input trans + ConvDw3x3Bottom(lines, conv_param->output_w_, conv_param->input_channel_); + } else { + // input trans + ConvDw3x3Row(src + oh * line + line, lines, conv_param->output_w_, conv_param->input_channel_); + } + // dst calc and trans + ConvDw3x3Line(dst + oh * line, lines, weight_data, bias_data, conv_param->output_w_, conv_param->input_channel_, + relu, relu6); } } +#endif /*conv depthwise indirect buffer fp32 begin*/ bool CheckConvDwUseIndirectBuffer(const ConvParameter *conv_param) { diff --git a/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.h b/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.h index 933b0bf633..26ba01e895 100644 --- a/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.h +++ b/mindspore/lite/nnacl/fp32/conv_depthwise_fp32.h @@ -47,12 +47,6 @@ void ConvDwSWFp32(float *output_data, const float *input_data, const float *weig bool CheckConvDwUse3X3(const ConvParameter *conv_param); -void ConvDw3x3Pad(float *output_data, const float *input_data, const float *weight_data, const float *bias_data, - const ConvParameter *conv_param, const SlidingWindowParam *sliding); - -void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data, - const float *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding, int task_id); - bool CheckConvDwUseIndirectBuffer(const ConvParameter *conv_param); void ConvDwInitIndirection(float **indirect_buffer, float *src, float *zero_ptr, const ConvParameter *conv_param, @@ -74,6 +68,13 @@ void ConvDwFp32Avx5x5(float *output, float **input, const float *weights, const size_t output_width, size_t input_stride, size_t relu, size_t relu6); #endif +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +void ConvDw3x3(float *output_data, float *buffer, const float *input_data, const float *weight_data, + const float *bias_data, const ConvParameter *conv_param, int start_oh, int end_oh); + +bool CheckConvDw1DWinograd(const ConvParameter *conv_param, int thread_num); +#endif + void ConvDwFp32IndirectRow(float *output, float **input, const float *weights, const float *bias, int channels, int output_width, int input_stride, bool relu, bool relu6, int kernel); diff --git a/mindspore/lite/nnacl/fp32/pack_fp32.c b/mindspore/lite/nnacl/fp32/pack_fp32.c index aa4f970b91..bc2d7c9ef7 100644 --- a/mindspore/lite/nnacl/fp32/pack_fp32.c +++ b/mindspore/lite/nnacl/fp32/pack_fp32.c @@ -632,3 +632,23 @@ inline void Transpose8X8Fp32Sse(const float *src_ptr, float *dst_ptr, int src_st _mm_storeu_ps(dst_ptr + (C4NUM + 3) * dst_stride + C4NUM, v11_ma); } #endif + +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +void PackWeightConvDw3x3Fp32(const void *src, void *dst, int channel) { + // nchw to nc4hw4 with 1D F(2,3) + for (int i = 0; i < channel; i++) { + float *src_kernel = (float *)src + i * 9; + float *dst_kernel = (float *)dst + (i / 4) * 48 + i % 4; + for (int y = 0; y < 3; y++) { + float g0 = src_kernel[3 * y]; + float g1 = src_kernel[3 * y + 1]; + float g2 = src_kernel[3 * y + 2]; + + dst_kernel[16 * y] = g0; + dst_kernel[16 * y + 4] = 0.5f * (g0 + g1 + g2); + dst_kernel[16 * y + 8] = 0.5f * (g0 - g1 + g2); + dst_kernel[16 * y + 12] = g2; + } + } +} +#endif diff --git a/mindspore/lite/nnacl/fp32/pack_fp32.h b/mindspore/lite/nnacl/fp32/pack_fp32.h index cf89227846..230f4eb957 100644 --- a/mindspore/lite/nnacl/fp32/pack_fp32.h +++ b/mindspore/lite/nnacl/fp32/pack_fp32.h @@ -44,6 +44,10 @@ void PackDepthwiseIndirectWeightC8Fp32(const void *src, void *dst, int height, i void Im2ColPackUnitFp32(const float *input_data, const ConvParameter *conv_param, float *packed_input, int real_cal_num, int block_index); +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) +void PackWeightConvDw3x3Fp32(const void *src, void *dst, int channel); +#endif + // Transpose 8X8 Fp32 block data typedef void (*Transpose8X8Fp32Func)(const float *src_ptr, float *dst_ptr, int src_stride, int dst_stride); #ifdef ENABLE_ARM64 diff --git a/mindspore/lite/nnacl/intrinsics/ms_simd_instructions.h b/mindspore/lite/nnacl/intrinsics/ms_simd_instructions.h index 8830d6146a..a101a62f5c 100644 --- a/mindspore/lite/nnacl/intrinsics/ms_simd_instructions.h +++ b/mindspore/lite/nnacl/intrinsics/ms_simd_instructions.h @@ -32,7 +32,6 @@ #define MS_ADDQ_EPI32 vaddq_s32 #define MS_MOVQ_F32 vmovq_n_f32 #define MS_MOVQ_EPI32 vmovq_n_s32 -#define MS_DUPQ_F32 vdupq_n_f32 // It is recommended to replace with MS_MOVQ_F32. #define MS_SUBQ_F32 vsubq_f32 #define MS_MLAQ_F32(src1, src2, src3) vmlaq_f32(src1, src2, src3) #define MS_STQ_F32 vst1q_f32 @@ -76,7 +75,6 @@ inline static float32x4_t vrecp(float32x4_t v) { #define MS_ADD256_EPI32 _mm256_add_epi32 #define MS_MOV256_F32 _mm256_set1_ps #define MS_MOV256_EPI32 _mm256_set1_epi32 -#define MS_DUP256_F32 _mm256_load_ps1 // It is recommended to replace with MS_MOV256_F32. #define MS_MLA256_F32(src1, src2, src3) _mm256_add_ps(src1, _mm256_mul_ps(src2, src3)) #define MS_ST256_F32 _mm256_storeu_ps #define MS_ST256_EPI32(src1, src2) _mm256_storeu_si256((__m256i *)(src1), src2) @@ -109,7 +107,6 @@ inline static float32x4_t vrecp(float32x4_t v) { #define MS_ADDQ_EPI32 _mm_add_epi32 #define MS_MOVQ_F32 _mm_set1_ps #define MS_MOVQ_EPI32 _mm_set1_epi32 -#define MS_DUPQ_F32 _mm_load_ps1 // It is recommended to replace with MS_MOVQ_F32. #define MS_MLAQ_F32(src1, src2, src3) _mm_add_ps(src1, _mm_mul_ps(src2, src3)) #define MS_STQ_F32 _mm_storeu_ps #define MS_STQ_EPI32(src1, src2) _mm_storeu_si128((__m128i *)(src1), src2) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.cc index 7815722d26..c103dcabc4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_delegate_fp32.cc @@ -21,6 +21,7 @@ #include "src/runtime/kernel/arm/fp32/convolution_depthwise_fp32.h" #include "src/runtime/kernel/arm/fp32/convolution_depthwise_slidewindow_fp32.h" #include "src/runtime/kernel/arm/fp32/convolution_depthwise_indirect_fp32.h" +#include "src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.h" #include "schema/model_generated.h" #include "src/kernel_registry.h" #include "include/errorcode.h" @@ -354,8 +355,13 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector auto conv_param = reinterpret_cast(opParameter); kernel::LiteKernel *kernel = nullptr; if (opParameter != nullptr && opParameter->infer_flag_) { +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) + if (CheckConvDw1DWinograd(conv_param, ctx->thread_num_)) { + kernel = new (std::nothrow) kernel::ConvolutionDepthwise3x3CPUKernel(opParameter, inputs, outputs, ctx); + } +#endif #if defined(ENABLE_ARM64) || defined(ENABLE_AVX) - if (CheckConvDwUseIndirectBuffer(conv_param)) { + if (kernel == nullptr && CheckConvDwUseIndirectBuffer(conv_param)) { kernel = new (std::nothrow) kernel::ConvolutionDepthwiseIndirectCPUKernel(opParameter, inputs, outputs, ctx); } #endif @@ -367,7 +373,7 @@ kernel::LiteKernel *CpuConvDwFp32KernelCreator(const std::vector kernel = new (std::nothrow) kernel::ConvolutionDepthwiseCPUKernel(opParameter, inputs, outputs, ctx); } return kernel; -} +} // namespace mindspore::kernel kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *op_parameter, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.cc index e685fe2311..0149af5d46 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.cc @@ -18,8 +18,10 @@ #include "include/errorcode.h" #include "src/runtime/runtime_api.h" +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) using mindspore::lite::RET_ERROR; using mindspore::lite::RET_INFER_INVALID; +using mindspore::lite::RET_MEMORY_FAILED; using mindspore::lite::RET_OK; namespace mindspore::kernel { @@ -28,10 +30,6 @@ ConvolutionDepthwise3x3CPUKernel::~ConvolutionDepthwise3x3CPUKernel() { free(packed_weight_); packed_weight_ = nullptr; } - if (sliding_ != nullptr) { - delete sliding_; - sliding_ = nullptr; - } } int ConvolutionDepthwise3x3CPUKernel::InitWeightBias() { @@ -39,22 +37,26 @@ int ConvolutionDepthwise3x3CPUKernel::InitWeightBias() { auto weight_tensor = in_tensors_[kWeightIndex]; auto origin_weight = reinterpret_cast(weight_tensor->MutableData()); int channel = weight_tensor->Batch(); - int pack_weight_size = weight_tensor->Batch() * weight_tensor->Height() * weight_tensor->Width(); + int c4 = UP_ROUND(channel, C4NUM); + int pack_weight_size = c4 * C12NUM; - packed_weight_ = reinterpret_cast(malloc(pack_weight_size * sizeof(float))); if (packed_weight_ == nullptr) { - MS_LOG(ERROR) << "Malloc buffer failed."; - return RET_ERROR; + packed_weight_ = reinterpret_cast(malloc(pack_weight_size * sizeof(float))); + if (packed_weight_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } } - PackWeightKHWToHWKFp32(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(), channel); + PackWeightConvDw3x3Fp32(origin_weight, packed_weight_, channel); - bias_data_ = reinterpret_cast(malloc(channel * sizeof(float))); if (bias_data_ == nullptr) { - MS_LOG(ERROR) << "Malloc buffer failed."; - return RET_ERROR; + bias_data_ = reinterpret_cast(malloc(c4 * sizeof(float))); + if (bias_data_ == nullptr) { + MS_LOG(ERROR) << "Malloc buffer failed."; + return RET_ERROR; + } } - - memset(bias_data_, 0, channel * sizeof(float)); + memset(bias_data_, 0, c4 * sizeof(float)); if (in_tensors_.size() == kInputSize2) { auto bias_tensor = in_tensors_[kBiasIndex]; auto ori_bias = reinterpret_cast(bias_tensor->MutableData()); @@ -65,11 +67,6 @@ int ConvolutionDepthwise3x3CPUKernel::InitWeightBias() { } int ConvolutionDepthwise3x3CPUKernel::Init() { - sliding_ = new (std::nothrow) SlidingWindowParam; - if (sliding_ == nullptr) { - MS_LOG(ERROR) << "new sliding window param failed."; - return RET_ERROR; - } auto ret = InitWeightBias(); if (ret != 0) { MS_LOG(ERROR) << "Convolution depthwise 3x3 fp32 InitWeightBias failed."; @@ -83,15 +80,19 @@ int ConvolutionDepthwise3x3CPUKernel::Init() { int ConvolutionDepthwise3x3CPUKernel::ReSize() { ConvolutionBaseCPUKernel::Init(); - InitSlidingParamConvDw(sliding_, conv_param_, conv_param_->input_channel_); conv_param_->thread_num_ = MSMIN(thread_count_, conv_param_->output_h_); return RET_OK; } int ConvolutionDepthwise3x3CPUKernel::Execute(int task_id) { - auto buffer = buffer_ + 64 * 10 * 10 * task_id; + int units = UP_DIV(conv_param_->output_w_, C2NUM); // F(2, 3) contains 2 conv units + int c4 = UP_ROUND(conv_param_->input_channel_, C4NUM); + auto buffer = buffer_ + C12NUM * c4 * units * task_id; + int step_oh = UP_DIV(conv_param_->output_h_, conv_param_->thread_num_); + int start_oh = step_oh * task_id; + int end_oh = MSMIN(start_oh + step_oh, conv_param_->output_h_); ConvDw3x3(output_ptr_, buffer, input_ptr_, packed_weight_, reinterpret_cast(bias_data_), conv_param_, - sliding_, task_id); + start_oh, end_oh); return RET_OK; } @@ -105,25 +106,18 @@ int ConvDw3x3Run(void *cdata, int task_id) { return RET_OK; } -int ConvolutionDepthwise3x3CPUKernel::InitBuffer() { - int buffer_size = 64 * 10 * 10 * conv_param_->thread_num_; - buffer_ = reinterpret_cast(context_->allocator->Malloc(buffer_size * sizeof(float))); - if (buffer_ == nullptr) { - MS_LOG(ERROR) << "Malloc buffer failed."; - return RET_ERROR; - } - return RET_OK; -} - int ConvolutionDepthwise3x3CPUKernel::Run() { - auto ret = InitBuffer(); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Depthwise int8 ReSize error!"; - return ret; + int units = UP_DIV(conv_param_->output_w_, C2NUM); // F(2, 3) contains 2 conv units + int c4 = UP_ROUND(conv_param_->input_channel_, C4NUM); + int buffer_size = units * c4 * C12NUM * conv_param_->thread_num_; + buffer_ = reinterpret_cast(ctx_->allocator->Malloc(buffer_size * sizeof(float))); + if (buffer_ == nullptr) { + MS_LOG(ERROR) << "ConvDw3x3Run failed to allocate buffer"; + return RET_MEMORY_FAILED; } if (IsTrain() && is_trainable()) { - PackWeight(); + InitWeightBias(); } auto input_tensor = in_tensors_.at(kInputIndex); @@ -132,32 +126,21 @@ int ConvolutionDepthwise3x3CPUKernel::Run() { auto output_tensor = out_tensors_.at(kOutputIndex); output_ptr_ = reinterpret_cast(output_tensor->data_c()); - if (sliding_->top_ > 0 || sliding_->bottom_ < conv_param_->output_h_ || sliding_->left_ > 0 || - sliding_->right_ < conv_param_->output_w_) { - ConvDw3x3Pad(output_ptr_, input_ptr_, packed_weight_, reinterpret_cast(bias_data_), conv_param_, sliding_); - } - ret = ParallelLaunch(this->context_->thread_pool_, ConvDw3x3Run, this, conv_param_->thread_num_); + auto ret = ParallelLaunch(this->context_->thread_pool_, ConvDw3x3Run, this, conv_param_->thread_num_); + ctx_->allocator->Free(buffer_); if (ret != RET_OK) { - context_->allocator->Free(buffer_); MS_LOG(ERROR) << "ConvDw3x3Run error: error_code[" << ret << "]"; return RET_ERROR; } - context_->allocator->Free(buffer_); return RET_OK; } -void ConvolutionDepthwise3x3CPUKernel::PackWeight() { - auto weight_tensor = in_tensors_.at(kWeightIndex); - auto origin_weight = reinterpret_cast(weight_tensor->MutableData()); - PackWeightKHWToHWKFp32(origin_weight, packed_weight_, weight_tensor->Height() * weight_tensor->Width(), - weight_tensor->Batch()); -} - int ConvolutionDepthwise3x3CPUKernel::Eval() { LiteKernel::Eval(); if (is_trainable()) { - PackWeight(); + InitWeightBias(); } return RET_OK; } } // namespace mindspore::kernel +#endif diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.h index a919a58ccb..f02d04327e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_depthwise_3x3_fp32.h @@ -17,6 +17,7 @@ #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP32_H_ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP32_H_ +#if defined(ENABLE_ARM) || (defined(ENABLE_SSE) && !defined(ENABLE_AVX)) #include #include "src/lite_kernel.h" #include "src/runtime/kernel/arm/base/convolution_base.h" @@ -39,14 +40,11 @@ class ConvolutionDepthwise3x3CPUKernel : public ConvolutionBaseCPUKernel { int Eval() override; private: - void PackWeight(); - int InitBuffer(); - SlidingWindowParam *sliding_ = nullptr; float *packed_weight_ = nullptr; float *input_ptr_ = nullptr; float *output_ptr_ = nullptr; float *buffer_ = nullptr; }; } // namespace mindspore::kernel - +#endif #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_DEPTHWISE_3X3_FP32_H_