|
|
|
@ -220,32 +220,6 @@ int ConvolutionCPUKernel::Run() {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CheckIfUseWinograd(bool *use_winograd, int *output_unit, ConvParameter *conv_param,
|
|
|
|
|
InputTransformUnitFunc input_trans_func, OutputTransformUnitFunc output_trans_func) {
|
|
|
|
|
if (conv_param->kernel_w_ == conv_param->kernel_h_ && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 &&
|
|
|
|
|
conv_param->stride_h_ == 1 && conv_param->stride_w_ == 1) {
|
|
|
|
|
*output_unit = SelectOutputUnit(conv_param);
|
|
|
|
|
if (*output_unit > 1) {
|
|
|
|
|
*use_winograd = true;
|
|
|
|
|
int input_unit = conv_param->kernel_h_ + *output_unit - 1;
|
|
|
|
|
input_trans_func = GetInputTransFunc(input_unit);
|
|
|
|
|
if (input_trans_func == nullptr) {
|
|
|
|
|
MS_LOG(INFO) << "No matching input trans func. Turn back to common conv.";
|
|
|
|
|
*use_winograd = false;
|
|
|
|
|
}
|
|
|
|
|
output_trans_func = GetOutputTransFunc(input_unit, *output_unit);
|
|
|
|
|
if (output_trans_func == nullptr) {
|
|
|
|
|
MS_LOG(INFO) << "No matching output trans func. Turn back to common conv.";
|
|
|
|
|
*use_winograd = false;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
*use_winograd = false;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
*use_winograd = false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,
|
|
|
|
|
const std::vector<lite::tensor::Tensor *> &outputs,
|
|
|
|
|
OpParameter *opParameter, const Context *ctx,
|
|
|
|
@ -270,7 +244,8 @@ kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::tensor::Ten
|
|
|
|
|
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param, input_trans_func, output_trans_func);
|
|
|
|
|
kernel::LiteKernel *kernel;
|
|
|
|
|
if (kernel_h == 1 && kernel_w == 1) {
|
|
|
|
|
kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(opParameter, inputs, outputs, ctx);
|
|
|
|
|
// kernel = new (std::nothrow) kernel::Convolution1x1CPUKernel(opParameter, inputs, outputs, ctx);
|
|
|
|
|
kernel = new (std::nothrow) kernel::ConvolutionCPUKernel(opParameter, inputs, outputs, ctx);
|
|
|
|
|
} else if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
|
|
|
|
|
kernel = new (std::nothrow) kernel::Convolution3x3CPUKernel(opParameter, inputs, outputs, ctx);
|
|
|
|
|
} else if (use_winograd) {
|
|
|
|
|