From 6c0135ec97aa606e728bff025515930fa879758e Mon Sep 17 00:00:00 2001 From: zhangxuetong Date: Sat, 8 Aug 2020 10:33:29 +0800 Subject: [PATCH] modify fp16 conv creator --- .../src/runtime/kernel/arm/CMakeLists.txt | 2 ++ .../kernel/arm/fp16/convolution_fp16.cc | 18 ++++++++---- .../runtime/kernel/arm/fp32/convolution.cc | 29 ++----------------- .../kernel/arm/nnacl/winograd_utils.cc | 25 ++++++++++++++++ .../runtime/kernel/arm/nnacl/winograd_utils.h | 2 ++ 5 files changed, 44 insertions(+), 32 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt b/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt index 2a66748ecf..0fec5a929c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt +++ b/mindspore/lite/src/runtime/kernel/arm/CMakeLists.txt @@ -1,3 +1,5 @@ +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/) + file(GLOB KERNEL_SRC ${CMAKE_CURRENT_SOURCE_DIR}/base/*.cc nnacl/*.cc diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc index 92277efdf6..fc48c4d188 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc @@ -23,6 +23,7 @@ #include "src/kernel_registry.h" #include "include/errorcode.h" #include "src/runtime/runtime_api.h" +#include "nnacl/winograd_utils.h" using mindspore::kernel::KERNEL_ARCH::kCPU; using mindspore::lite::KernelRegistrar; @@ -242,7 +243,7 @@ int ConvolutionFP16CPUKernel::Run() { auto out_tensor = outputs_.at(kOutputIndex); auto output_addr = reinterpret_cast(out_tensor->Data()); for (int j = 0; j < out_tensor->ElementsNum(); ++j) { - output_addr[j] = static_cast(fp16_out_[j]); + output_addr[j] = static_cast(fp16_out_[j]); } return RET_OK; } @@ -264,20 +265,27 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vectorinput_w_ = inputs.front()->Width(); conv_param->output_h_ = outputs.front()->Height(); conv_param->output_w_ = outputs.front()->Width(); - kernel::LiteKernel *kernel; + kernel::LiteKernel *kernel = nullptr; if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) { kernel = new (std::nothrow) kernel::Convolution3x3FP16CPUKernel(opParameter, inputs, outputs, ctx); } else { - kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx); + bool use_winograd = false; + int out_unit; + InputTransformUnitFunc input_trans_func = nullptr; + OutputTransformUnitFunc output_trans_func = nullptr; + CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func); + if (kernel_h != 1 && kernel_w != 1 && !use_winograd) { + kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx); + } } if (kernel == nullptr) { - MS_LOG(ERROR) << "Create conv fp16 kernel failed."; + MS_LOG(DEBUG) << "Create conv fp16 kernel failed."; return nullptr; } auto ret = kernel->Init(); if (ret != RET_OK) { delete kernel; - MS_LOG(ERROR) << "Init kernel failed, name: " << opParameter->name_ << ", type: " + MS_LOG(INFO) << "Init fp16 kernel failed, name: " << opParameter->name_ << ", type: " << schema::EnumNamePrimitiveType(static_cast(opParameter->type_)); return nullptr; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc index bee5e46e75..f30d2af194 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution.cc @@ -220,32 +220,6 @@ int ConvolutionCPUKernel::Run() { return RET_OK; } -void CheckIfUseWinograd(bool *use_winograd, int *output_unit, ConvParameter *conv_param, - InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func) { - if (conv_param->kernel_w_ == conv_param->kernel_h_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && - conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1) { - *output_unit = SelectOutputUnit(conv_param); - if (*output_unit > 1) { - *use_winograd = true; - int input_unit = conv_param->kernel_h_ + *output_unit - 1; - input_trans_func = GetInputTransFunc(input_unit); - if (input_trans_func == nullptr) { - MS_LOG(INFO) << "No matching input trans func. Turn back to common conv."; - *use_winograd = false; - } - output_trans_func = GetOutputTransFunc(input_unit, *output_unit); - if (output_trans_func == nullptr) { - MS_LOG(INFO) << "No matching output trans func. Turn back to common conv."; - *use_winograd = false; - } - } else { - *use_winograd = false; - } - } else { - *use_winograd = false; - } -} - kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector &inputs, const std::vector &outputs, OpParameter *opParameter, const Context *ctx, @@ -270,7 +244,8 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vectorkernel_w_ == conv_param->kernel_h_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && + conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1) { + *output_unit = SelectOutputUnit(conv_param); + if (*output_unit > 1) { + *use_winograd = true; + int input_unit = conv_param->kernel_h_ + *output_unit - 1; + input_trans_func = GetInputTransFunc(input_unit); + if (input_trans_func == nullptr) { + *use_winograd = false; + } + output_trans_func = GetOutputTransFunc(input_unit, *output_unit); + if (output_trans_func == nullptr) { + *use_winograd = false; + } + } else { + *use_winograd = false; + } + } else { + *use_winograd = false; + } +} + diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_utils.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_utils.h index d7a7b7a69c..67bc39becd 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_utils.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/winograd_utils.h @@ -54,5 +54,7 @@ InputTransformUnitFunc GetInputTransFunc(int input_unit); OutputTransformUnitFunc GetOutputTransFunc(int input_unit, int output_unit); +void CheckIfUseWinograd(bool *use_winograd, int *output_unit, ConvParameter *conv_param, + InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func); #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_WINOGRAD_UTILS_H_