diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/ConvDwFp32Center.S b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/ConvDwFp32Center.S index 0252d28e3e..1e27860d72 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/ConvDwFp32Center.S +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/ConvDwFp32Center.S @@ -32,16 +32,6 @@ ConvDwFp32Center: ldr x14, [sp, #48] ldr x15, [sp, #56] - mov x16, #4 - mul x8, x8, x16 - mul x9, x9, x16 - mul x10, x10, x16 - mul x11, x11, x16 - mul x12, x12, x16 - mul x13, x13, x16 - mov x16, #16 - mul x19, x7, x16 - ld1 {v5.4s}, [x3] LoopH: @@ -52,20 +42,17 @@ ConvDwFp32Center: mov x16, x23 mov x17, x2 mov x20, x6 - ld1 {v0.4s}, [x3] - fadd v0.4s, v0.4s, v5.4s + mov v0.16b, v5.16b LoopKh: mov x18, x7 - mov x21, x17 mov x22, x16 LoopKw: ld1 {v1.4s}, [x22], x13 - ld1 {v2.4s}, [x21], #16 + ld1 {v2.4s}, [x17], #16 fmla v0.4s, v1.4s, v2.4s subs x18, x18, #1 bne LoopKw add x16, x16, x12 - add x17, x17, x19 subs x20, x20, #1 bne LoopKh cbnz x15, Relu6 diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/DeconvDwFp32Center.S b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/DeconvDwFp32Center.S index f23a100299..d88c61047c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/DeconvDwFp32Center.S +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/assembly/arm64/DeconvDwFp32Center.S @@ -27,16 +27,6 @@ DeconvDwFp32Center: ldr x11, [sp, #24] ldr x12, [sp, #32] - mov x13, #4 - mul x7, x7, x13 - mul x8, x8, x13 - mul x9, x9, x13 - mul x10, x10, x13 - mul x11, x11, x13 - mul x12, x12, x13 - mov x13, #16 - mul x14, x6, x13 - LoopH: mov x15, x0 mov x16, x1 @@ -45,20 +35,18 @@ DeconvDwFp32Center: mov x18, x15 mov x19, x2 mov x20, x5 LoopKh: mov x21, x18 - mov x22, x19 mov x13, x6 LoopKw: ld1 {v0.4s}, [x21] ld1 {v1.4s}, [x16] - ld1 {v2.4s}, [x22], #16 + ld1 {v2.4s}, [x19], #16 fmla v0.4s, v1.4s, v2.4s st1 {v0.4s},
[x21], x12 subs x13, x13, #1 bne LoopKw add x18, x18, x11 - add x19, x19, x14 subs x20, x20, #1 bne LoopKh add x15, x15, x10 diff --git a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv_depthwise.cc b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv_depthwise.cc index 97cba0972b..90d7240537 100644 --- a/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv_depthwise.cc +++ b/mindspore/lite/src/runtime/kernel/arm/opclib/fp32/conv_depthwise.cc @@ -120,13 +120,10 @@ void DepthwiseBorder(float *dst, const float *src, const float *weight, const fl } // height loop } +#ifndef ENABLE_ARM64 void DepthwiseCenter(float *dst, const float *src, const float *weight, const float *bias, int height, int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, int in_kh_step, int in_kw_step, bool is_relu, bool is_relu6) { -#ifdef ENABLE_ARM64 - ConvDwFp32Center(dst, src, weight, bias, height, width, kernel_h, kernel_w, out_h_step, block_channel, - in_sh_step, in_sw_step, in_kh_step, in_kw_step, is_relu, is_relu6); -#else float *dst_h = dst; const float *src_h = src; for (int oh = 0; oh < height; oh++) { @@ -139,17 +136,9 @@ void DepthwiseCenter(float *dst, const float *src, const float *weight, const fl const float *src_kw = src_kh; const float *weight_kw = weight_kh; for (int kw = 0; kw < kernel_w; kw++) { -#ifdef ENABLE_ARM64 - float32x4_t src_4 = vld1q_f32(src_kw); - float32x4_t weight_4 = vld1q_f32(weight_kw); - float32x4_t dst_4 = vld1q_f32(dst_w); - dst_4 = vfmaq_f32(dst_4, src_4, weight_4); - vst1q_f32(dst_w, dst_4); -#else for (int c = 0; c < C4NUM; c++) { dst_w[c] += src_kw[c] * weight_kw[c]; } -#endif src_kw += in_kw_step; weight_kw += C4NUM; } // kernel_w loop @@ -168,8 +157,8 @@ void DepthwiseCenter(float *dst, const float *src, const float *weight, const fl dst_h += out_h_step; src_h += in_sh_step; } // dst_height loop -#endif } +#endif // conv depthwise fp32: sliding window void ConvDwC4Fp32(float 
*output_data, const float *input_data, const float *weight_data, const float *bias_data, @@ -196,11 +185,18 @@ void ConvDwC4Fp32(float *output_data, const float *input_data, const float *weig int in_w_start = sliding->left_ * conv_param->stride_w_ - conv_param->pad_w_; const float *in_t = src_data + in_h_start * sliding->in_h_step_ + in_w_start * sliding->block_channel_; float *out_t = dst_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; - +#ifdef ENABLE_ARM64 + ConvDwFp32Center(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float), + sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float), + sliding->in_sw_step_ * sizeof(float), sliding->in_kh_step_ * sizeof(float), + sliding->in_kw_step_ * sizeof(float), conv_param->is_relu_, conv_param->is_relu6_); +#else DepthwiseCenter(out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_, sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_, conv_param->is_relu_, conv_param->is_relu6_); +#endif } } // output C4 loop src += sliding->in_step_; @@ -265,13 +261,10 @@ void DeconvDepthwiseBorder(float *dst, const float *src, const float *weight, in } // height loop } +#ifndef ENABLE_ARM64 void DeconvDepthwiseCenter(float *dst, const float *src, const float *weight, int height, int width, int kernel_h, int kernel_w, int out_h_step, int block_channel, int in_sh_step, int in_sw_step, int in_kh_step, int in_kw_step) { -#ifdef ENABLE_ARM64 - DeconvDwFp32Center(dst, src, weight, height, width, kernel_h, kernel_w, out_h_step, block_channel, - in_sh_step, in_sw_step, in_kh_step, in_kw_step); -#else float *dst_h = dst; const float *src_h = src; for (int oh = 0; oh < height; oh++) { @@ -284,17 
+277,9 @@ void DeconvDepthwiseCenter(float *dst, const float *src, const float *weight, in float *dst_kw = dst_kh; const float *weight_kw = weight_kh; for (int kw = 0; kw < kernel_w; kw++) { -#ifdef ENABLE_ARM64 - float32x4_t src_4 = vld1q_f32(src_w); - float32x4_t weight_4 = vld1q_f32(weight_kw); - float32x4_t dst_4 = vld1q_f32(dst_kw); - dst_4 = vfmaq_f32(dst_4, src_4, weight_4); - vst1q_f32(dst_kw, dst_4); -#else for (int c = 0; c < C4NUM; c++) { dst_kw[c] += src_w[c] * weight_kw[c]; } -#endif dst_kw += in_kw_step; weight_kw += C4NUM; } // kernel_w loop @@ -307,8 +292,8 @@ void DeconvDepthwiseCenter(float *dst, const float *src, const float *weight, in dst_h += in_sh_step; src_h += out_h_step; } // dst_height loop -#endif } +#endif void DeconvDepthwisePostFunc(float *dst, const float *bias, int block_channel, const ConvParameter *conv_param) { float *dst_k = dst; @@ -347,10 +332,18 @@ void DeconvDwC4Fp32(float *output_data, const float *input_data, const float *we float *out_t = dst_data + oh_h_start * sliding->in_h_step_ + oh_w_start * sliding->block_channel_; const float *in_t = src_data + sliding->top_ * sliding->out_h_step_ + sliding->left_ * sliding->block_channel_; +#ifdef ENABLE_ARM64 + DeconvDwFp32Center(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, + conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_ * sizeof(float), + sliding->block_channel_ * sizeof(float), sliding->in_sh_step_ * sizeof(float), + sliding->in_sw_step_ * sizeof(float), sliding->in_kh_step_ * sizeof(float), + sliding->in_kw_step_ * sizeof(float)); +#else DeconvDepthwiseCenter(out_t, in_t, weight, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_, conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_, sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_); +#endif } DeconvDepthwisePostFunc(dst_data, bias, sliding->block_channel_, conv_param); } 
// output C4 loop