!7245 [MSLITE]deconv optimize common

Merge pull request !7245 from ling/sr
pull/7245/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 92288517df

@ -71,4 +71,57 @@ typedef struct SlidingWindowParam {
int kernel_step_;
} SlidingWindowParam;
#define DECONV_WINOGRAD_DEFAULT_UNIT 3
#define DECONV_WINOGRAD_DEFAULT_TILE 8
#define DECONV_WINOGRAD_BUFFER_COUNT 8
typedef struct DeConvWg {
void *b_buffer_;
void *AT_;
void *BT_;
int kh_;
int kw_;
int k_;
int i_;
int o_;
} DeConvWg;
typedef struct DeConvWgABuffer {
bool buf_init_;
bool trans_formed_;
void *middle_buffer_;
void *dest_buffer_;
} DeConvWgABuffer;
typedef struct DeConvComputeUnit {
void *weight_;
void *tmp_buffer_;
int w_start_;
int h_start_;
int w_size_;
int h_size_;
bool use_winograd_;
DeConvWg winograd_;
} DeConvComputeUnit;
typedef struct DeConvParam {
DeConvComputeUnit *compute_units_;
int compute_size_;
DeConvWgABuffer a_buffer_[DECONV_WINOGRAD_BUFFER_COUNT];
int input_plane_;
int output_plane_;
int kernel_plane_;
int ic_div4_;
int oc_div4_;
int ic_up4_;
int oc_up4_;
int thread_num_;
int in_tile_count_;
int in_tile_h_count_;
int in_tile_w_count_;
int out_tile_h_;
int out_tile_w_;
} DeConvParam;
#endif // MINDSPORE_LITE_NNACL_CONV_PARAMETER_H_

@ -22,7 +22,7 @@ typedef enum ErrorCodeCommonEnum {
NNACL_ERR = 1,
NNACL_NULL_PTR,
NNACL_PARAM_INVALID,
OPLIB_COMMON_END = 9999
NNACL_COMMON_END = 9999
} ErrorCodeCommonEnum;
typedef enum ErrorCodeFp32OpEnum {
@ -34,6 +34,7 @@ typedef enum ErrorCodeFp32OpEnum {
NNACL_ERRCODE_LOG_NEGATIVE_OR_ZERO,
NNACL_ERRCODE_DIVISOR_ZERO,
NNACL_ERRCODE_INDEX_OUT_OF_RANGE,
NNACL_ERRCODE_WINOGRAD_GENERATOR_ERROR,
NNACL_ERRCODE_OP_FP32_END = 19999
} ErrorCodeFp32OpEnum;

@ -15,8 +15,9 @@
*/
#include "nnacl/fp32/common_func.h"
void PostConvFuncComm(const float *src_ptr_, float *out_ptr, const float *bias_ptr, size_t output_channel,
size_t plane_size, size_t stride, bool is_relu, bool is_relu6, int size) {
size_t plane_size, size_t plane_stride, size_t oc_stride, bool is_relu, bool is_relu6, int size) {
int oc_div = 0, oc_mod = 0;
for (int oc = 0; oc < output_channel; oc++) {
if (size != 0) {
@ -26,8 +27,8 @@ void PostConvFuncComm(const float *src_ptr_, float *out_ptr, const float *bias_p
return;
}
for (int hw = 0; hw < plane_size; hw++) {
int src_index = oc_div * size * plane_size + hw * size + oc_mod;
int dst_index = hw * stride + oc;
int src_index = oc_div * size * plane_stride + hw * size + oc_mod;
int dst_index = hw * oc_stride + oc;
float value = src_ptr_[src_index];
if (bias_ptr != NULL) {
value = value + bias_ptr[oc];
@ -43,7 +44,8 @@ void PostConvFuncComm(const float *src_ptr_, float *out_ptr, const float *bias_p
void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel,
size_t plane_size, size_t stride, bool is_relu, bool is_relu6) {
#ifndef ENABLE_ARM
PostConvFuncComm(c8_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, stride, is_relu, is_relu6, C8NUM);
PostConvFuncComm(c8_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_size, stride, is_relu, is_relu6,
C8NUM);
#else
size_t oc8mod = output_channel % C8NUM;
size_t oc8div = output_channel - oc8mod;
@ -55,6 +57,59 @@ void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bi
return;
}
void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel,
size_t plane_size, size_t plane_stride, bool is_relu, bool is_relu6) {
PostConvFuncComm(c4_out_ptr, out_ptr, bias_ptr, output_channel, plane_size, plane_stride, output_channel, is_relu,
is_relu6, C4NUM);
return;
}
void WinogradMatrixProductLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) {
int unitStep = 4 * length;
for (int y = 0; y < h; ++y) {
float *dstY = M + y * w * unitStep;
for (int x = 0; x < w; ++x) {
float *dstX = dstY + x * unitStep;
const float *srcX = S + x * unitStep;
memset(dstX, 0, unitStep * sizeof(float));
for (int i = 0; i < k; ++i) {
float b = B[i * h + y];
const float *srcY = srcX + i * w * unitStep;
if (0.0f == b) {
continue;
}
for (int j = 0; j < unitStep; ++j) {
dstX[j] += srcY[j] * b;
}
}
}
}
}
// M = S * B , M = w*h * l, S = k*h * l, B = w*k
void WinogradMatrixProductRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length) {
int unitStep = 4 * length;
for (int y = 0; y < h; ++y) {
float *dstY = M + y * w * unitStep;
const float *srcY = S + y * k * unitStep;
for (int x = 0; x < w; ++x) {
float *dstX = dstY + x * unitStep;
memset(dstX, 0, unitStep * sizeof(float));
for (int i = 0; i < k; ++i) {
const float *srcX = srcY + i * unitStep;
float b = B[i * h + x];
if (0.0f == b) {
continue;
}
for (int j = 0; j < unitStep; ++j) {
dstX[j] += srcX[j] * b;
}
}
}
}
}
union float32_bits {
unsigned int u;
float f;

@ -29,6 +29,12 @@ extern "C" {
void PostConvFuncFp32C8(const float *c8_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel,
size_t plane_size, size_t stride, bool is_relu, bool is_relu6);
void PostConvFuncFp32C4(const float *c4_out_ptr, float *out_ptr, const float *bias_ptr, size_t output_channel,
size_t plane_size, size_t plane_stride, bool is_relu, bool is_relu6);
void WinogradMatrixProductLeft(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length);
void WinogradMatrixProductRight(const float *S, const float *B, float *M, size_t w, size_t h, size_t k, size_t length);
float ShortToFloat32(uint16_t src_value);
uint16_t Float32ToShort(float src_value);

@ -33,9 +33,10 @@ void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, in
return;
}
int DeConvPostFp32C12x8(const float *src, float *tmp, const float *bias, float *dst, int output_channel,
ConvParameter *conv_param) {
/* row12x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */
void DeConvPostFp32C8(const float *src, float *tmp, const float *bias, float *dst, int output_channel,
ConvParameter *conv_param) {
/* arm64 row12x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */
/* arm32 row4x8-major(ih*iw x oc*kh*kw) -> row8-major(oh*ow x oc) */
size_t input_plane = conv_param->input_w_ * conv_param->input_h_;
size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
size_t output_plane = conv_param->output_w_ * conv_param->output_h_;
@ -45,11 +46,11 @@ int DeConvPostFp32C12x8(const float *src, float *tmp, const float *bias, float *
#else
const int tile_num = 12;
#endif
int in_plane12 = UP_ROUND(input_plane, tile_num);
int in_plane_round = UP_ROUND(input_plane, tile_num);
int src_iw_stride = C8NUM;
int src_ih_stride = conv_param->input_w_ * C8NUM;
int src_kw_stride = in_plane12 * C8NUM;
int src_kh_stride = in_plane12 * conv_param->kernel_w_ * C8NUM;
int src_kw_stride = in_plane_round * C8NUM;
int src_kh_stride = in_plane_round * conv_param->kernel_w_ * C8NUM;
int dst_oh_stride = conv_param->output_w_ * C8NUM;
int dst_ow_stride = C8NUM;
int dst_kh_stride = conv_param->dilation_h_ * conv_param->output_w_ * C8NUM;
@ -57,7 +58,7 @@ int DeConvPostFp32C12x8(const float *src, float *tmp, const float *bias, float *
for (int c = 0; c < oc8; c += 8) {
float *dst_ptr = tmp + c * output_plane;
const float *src_ptr = src + c * in_plane12 * kernel_plane;
const float *src_ptr = src + c * in_plane_round * kernel_plane;
memset(dst_ptr, 0, output_plane * C8NUM * sizeof(float));
for (int ih = 0; ih < conv_param->input_h_; ih++) {
@ -104,5 +105,5 @@ int DeConvPostFp32C12x8(const float *src, float *tmp, const float *bias, float *
PostConvFuncFp32C8(tmp, dst, bias, output_channel, output_plane, conv_param->output_channel_,
conv_param->act_type_ == ActType_Relu, conv_param->act_type_ == ActType_Relu6);
return NNACL_OK;
return;
}

@ -22,13 +22,15 @@
#include "nnacl/conv_parameter.h"
#include "nnacl/errorcode.h"
#include "nnacl/fp32/common_func.h"
#include "nnacl/fp32/conv.h"
#include "nnacl/minimal_filtering_generator.h"
#ifdef __cplusplus
extern "C" {
#endif
void PackDeConvWeightFp32(const float *weight, float *dst, int input_channel, int output_channel, int plane);
int DeConvPostFp32C12x8(const float *src, float *tmp_out, const float *bias, float *dst, int output_channel,
ConvParameter *conv_param);
void DeConvPostFp32C8(const float *src, float *tmp_out, const float *bias, float *dst, int output_channel,
ConvParameter *conv_param);
#ifdef __cplusplus
}
#endif

File diff suppressed because it is too large Load Diff

@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_NNACL_FP32_DECONV_WINOGRAD_H_
#define MINDSPORE_LITE_NNACL_FP32_DECONV_WINOGRAD_H_
#include <string.h>
#include "nnacl/pack.h"
#include "nnacl/op_base.h"
#include "nnacl/conv_parameter.h"
#include "nnacl/errorcode.h"
#include "nnacl/fp32/common_func.h"
#include "nnacl/minimal_filtering_generator.h"
#ifdef __cplusplus
extern "C" {
#endif
int PackDeConvWgDataFp32(float *nhwc_weight, DeConvComputeUnit *unit, ConvParameter *conv_param,
DeConvParam *deconv_param);
void DeconvWg(float *nhwc_input_, float *tile_in, float *tile_out, int start_index, int calculate_count,
ConvParameter *conv_param, DeConvParam *deconv_param, int task_id);
void DeconvWgPost(float *tile_out, float *nc4hw4_output, ConvParameter *conv_param, DeConvParam *deconv_param,
int calculate_count, int tile_index);
#ifdef __cplusplus
}
#endif
#endif // MINDSPORE_LITE_NNACL_FP32_DECONV_WINOGRAD_H_

@ -254,3 +254,88 @@ void MatrixMultiplyVec(const float32x4_t *matrix_a, const float32x4_t *matrix_b,
}
}
#endif
int WinogradWeightTransform(const float *weight_data, float *winograd_data, float *matrix_g, float *matrix_gt,
int oc_block, int input_unit, int kernel_unit, int channel, int batch, bool pack) {
// original weight format : ohwi
int oc_block_num = UP_DIV(batch, oc_block);
int block_stride = channel * oc_block;
int block_num_stride = block_stride * oc_block_num;
// trans_filter = G*g*GT (g represents weight_data)
// separate into two steps ===> tmp = (g * GT)T ===> trans = (tmp * GT)T use same function:MatrixMultiplyWinograd
float *tmp_data = (float *)(malloc(channel * input_unit * kernel_unit * sizeof(float)));
if (tmp_data == NULL) {
return NNACL_ERR;
}
float *trans_out_data = (float *)(malloc(channel * input_unit * input_unit * sizeof(float)));
if (trans_out_data == NULL) {
free(tmp_data);
return NNACL_ERR;
}
#ifndef ENABLE_ARM
float *tmp_data1 = (float *)(malloc(channel * input_unit * kernel_unit * sizeof(float)));
if (tmp_data1 == NULL) {
free(tmp_data);
free(trans_out_data);
return NNACL_ERR;
}
float *trans_out_data1 = (float *)(malloc(channel * input_unit * input_unit * sizeof(float)));
if (trans_out_data1 == NULL) {
free(tmp_data);
free(tmp_data1);
free(trans_out_data);
return NNACL_ERR;
}
#endif
int input_oz_offset = kernel_unit * kernel_unit * channel;
for (int i = 0; i < batch; i++) {
int out_c_block = i / oc_block;
int out_c_res = i % oc_block;
int output_oz_offset = out_c_block * block_stride + out_c_res;
#ifndef ENABLE_ARM
// tmp_data = g * GT
MatrixMultiplyWinograd(weight_data + i * input_oz_offset, matrix_gt, tmp_data, kernel_unit, kernel_unit, input_unit,
channel, channel * 4);
// tmp_data1 = (tmp_data)T
PackHWCToWHC(tmp_data, tmp_data1, kernel_unit, input_unit, channel);
// trans_out_data1 = tmp * GT
MatrixMultiplyWinograd(tmp_data1, matrix_gt, trans_out_data1, input_unit, kernel_unit, input_unit, channel,
channel * 4);
// trans_out_data = (trans_out_data1)T
PackHWCToWHC(trans_out_data1, trans_out_data, input_unit, input_unit, channel);
#else
// tmp = (g * GT)T
MatrixMultiplyWinograd(weight_data + i * input_oz_offset, matrix_gt, tmp_data, kernel_unit, kernel_unit, input_unit,
channel, channel * 4);
// trans = (tmp * GT)T
MatrixMultiplyWinograd(tmp_data, matrix_gt, trans_out_data, input_unit, kernel_unit, input_unit, channel,
channel * 4);
#endif
if (pack) {
int in_offset = 0;
for (int j = 0; j < input_unit; ++j) {
for (int k = 0; k < input_unit; ++k) {
for (int c = 0; c < channel; ++c) {
*(winograd_data + output_oz_offset + c * oc_block) = trans_out_data[in_offset + c];
}
in_offset += channel;
output_oz_offset += block_num_stride;
}
}
} else {
memcpy(winograd_data + i * channel * input_unit * input_unit, trans_out_data,
channel * input_unit * input_unit * sizeof(float));
}
}
#ifndef ENABLE_ARM
free(tmp_data1);
free(trans_out_data1);
#endif
free(tmp_data);
free(trans_out_data);
return NNACL_OK;
}

@ -20,6 +20,8 @@
#ifdef ENABLE_ARM
#include <arm_neon.h>
#endif
#include <stdbool.h>
#include "nnacl/pack.h"
#ifdef __cplusplus
extern "C" {
@ -47,6 +49,9 @@ int CookToomFilter(float *matrix_a, float *matrix_at, float *matrix_b, float *ma
void MatrixMultiplyWinograd(const float *matix_a, const float *matrix_b, float *matrix_c, int m, int k, int n,
int in_channel, int c4_channel);
int WinogradWeightTransform(const float *weight_data, float *winograd_data, float *matrix_g, float *matrix_gt,
int oc_block, int input_unit_, int kernel_unit_, int channel, int batch, bool pack);
#ifdef ENABLE_ARM
void MatrixMultiplyVec(const float32x4_t *matrix_a, const float32x4_t *matrix_b, float32x4_t *matrix_c,
const float *bias, int m, int k, int n);

@ -36,89 +36,9 @@ int ConvolutionWinogradCPUKernel::WinogradFilterTransform(const float *weight_da
MS_LOG(ERROR) << "Divide by zero";
return RET_ERROR;
}
// original weight format : ohwi
auto channel_in = conv_param_->input_channel_;
auto channel_out = conv_param_->output_channel_;
int oc_block_num = UP_DIV(channel_out, oc_block);
int block_stride = channel_in * oc_block;
int block_num_stride = block_stride * oc_block_num;
// trans_filter = G*g*GT (g represents weight_data)
// separate into two steps ===> tmp = (g * GT)T ===> trans = (tmp * GT)T use same function:MatrixMultiplyWinograd
auto tmp_data = reinterpret_cast<float *>(malloc(channel_in * input_unit_ * kernel_unit_ * sizeof(float)));
if (tmp_data == nullptr) {
MS_LOG(ERROR) << "malloc tmp_data failed.";
return RET_MEMORY_FAILED;
}
auto trans_out_data = reinterpret_cast<float *>(malloc(channel_in * input_unit_ * input_unit_ * sizeof(float)));
if (trans_out_data == nullptr) {
free(tmp_data);
MS_LOG(ERROR) << "malloc trans_out_data failed.";
return RET_MEMORY_FAILED;
}
#ifndef ENABLE_ARM
auto tmp_data1 = reinterpret_cast<float *>(malloc(channel_in * input_unit_ * kernel_unit_ * sizeof(float)));
if (tmp_data1 == nullptr) {
free(tmp_data);
free(trans_out_data);
MS_LOG(ERROR) << "malloc tmp_data1 failed.";
return RET_MEMORY_FAILED;
}
auto trans_out_data1 = reinterpret_cast<float *>(malloc(channel_in * input_unit_ * input_unit_ * sizeof(float)));
if (trans_out_data1 == nullptr) {
free(tmp_data);
free(tmp_data1);
free(trans_out_data);
MS_LOG(ERROR) << "malloc trans_out_data1 failed.";
return RET_MEMORY_FAILED;
}
#endif
int input_oz_offset = kernel_unit_ * kernel_unit_ * channel_in;
for (int i = 0; i < channel_out; i++) {
int out_c_block = i / oc_block;
int out_c_res = i % oc_block;
int output_oz_offset = out_c_block * block_stride + out_c_res;
#ifndef ENABLE_ARM
// tmp_data = g * GT
MatrixMultiplyWinograd(weight_data + i * input_oz_offset, matrix_gt, tmp_data, kernel_unit_, kernel_unit_,
input_unit_, channel_in, channel_in * 4);
// tmp_data1 = (tmp_data)T
PackHWCToWHC(tmp_data, tmp_data1, kernel_unit_, input_unit_, channel_in);
// trans_out_data1 = tmp * GT
MatrixMultiplyWinograd(tmp_data1, matrix_gt, trans_out_data1, input_unit_, kernel_unit_, input_unit_, channel_in,
channel_in * 4);
// trans_out_data = (trans_out_data1)T
PackHWCToWHC(trans_out_data1, trans_out_data, input_unit_, input_unit_, channel_in);
#else
// tmp = (g * GT)T
MatrixMultiplyWinograd(weight_data + i * input_oz_offset, matrix_gt, tmp_data, kernel_unit_, kernel_unit_,
input_unit_, channel_in, channel_in * 4);
// trans = (tmp * GT)T
MatrixMultiplyWinograd(tmp_data, matrix_gt, trans_out_data, input_unit_, kernel_unit_, input_unit_, channel_in,
channel_in * 4);
#endif
int in_offset = 0;
for (int j = 0; j < input_unit_; ++j) {
for (int k = 0; k < input_unit_; ++k) {
for (int c = 0; c < channel_in; ++c) {
*(trans_weight_ + output_oz_offset + c * oc_block) = trans_out_data[in_offset + c];
}
in_offset += channel_in;
output_oz_offset += block_num_stride;
}
}
}
#ifndef ENABLE_ARM
free(tmp_data1);
free(trans_out_data1);
#endif
free(tmp_data);
free(trans_out_data);
return RET_OK;
return WinogradWeightTransform(weight_data, trans_weight_, matrix_g, matrix_gt, oc_block, input_unit_, kernel_unit_,
conv_param_->input_channel_, conv_param_->output_channel_, true);
}
int ConvolutionWinogradCPUKernel::InitWeightBias() {

@ -15,6 +15,7 @@
*/
#include "src/runtime/kernel/arm/fp32/deconvolution.h"
#include "src/runtime/kernel/arm/fp32/deconvolution_winograd.h"
#include "src/runtime/runtime_api.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
@ -125,9 +126,9 @@ int DeConvolutionCPUKernel::DoDeconv(int task_id) {
matmul_param_->col_, OutType_C8);
#endif
DeConvPostFp32C12x8(tmp_buffer, pack_output_ + task_id * thread_stride_ * C8NUM * output_plane_,
reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id * C8NUM,
output_ptr_ + task_id * thread_stride_ * C8NUM, oc_res, conv_param_);
DeConvPostFp32C8(tmp_buffer, pack_output_ + task_id * thread_stride_ * C8NUM * output_plane_,
reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id * C8NUM,
output_ptr_ + task_id * thread_stride_ * C8NUM, oc_res, conv_param_);
return RET_OK;
}
@ -246,7 +247,17 @@ kernel::LiteKernel *CpuDeConvFp32KernelCreator(const std::vector<lite::Tensor *>
}
weight_tensor->SetData(dequant_weight);
}
auto kernel = new (std::nothrow) kernel::DeConvolutionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
kernel::LiteKernel *kernel;
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
if ((conv_param->stride_h_ != 1 || conv_param->stride_w_ != 1) &&
(conv_param->dilation_w_ == 1 && conv_param->dilation_h_ == 1)) {
/* DeConvolutionWinogradCPUKernel */
kernel = new (std::nothrow) kernel::DeConvolutionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
} else {
kernel = new (std::nothrow) kernel::DeConvolutionCPUKernel(opParameter, inputs, outputs, ctx, primitive);
}
if (kernel == nullptr) {
MS_LOG(ERROR) << "kernel is nullptr.";
if (weight_tensor->data_type() == kNumberTypeInt8) {

@ -0,0 +1,70 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DECONVOLUTION_WINOGRAD_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DECONVOLUTION_WINOGRAD_H_
#include <float.h>
#include <vector>
#include "src/lite_kernel.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "schema/model_generated.h"
#include "nnacl/fp32/matmul.h"
#include "nnacl/fp32/deconv_winograd.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
namespace mindspore::kernel {
class DeConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel {
public:
DeConvolutionWinogradCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {
deconv_param_ = new DeConvParam();
for (auto &wg : deconv_param_->a_buffer_) {
wg.buf_init_ = false;
}
}
~DeConvolutionWinogradCPUKernel() override;
int Init() override;
int Run() override;
int ReSize() override;
public:
int DoDeconv(int task_id);
int DeDeconvPost(int task_id);
private:
int InitComputeParam();
int InitDataParam();
int InitParameter();
void FreeDeconvParam();
void FreeResizeBuf();
private:
DeConvParam *deconv_param_;
float *nhwc_input_ = nullptr;
float *nhwc_output_ = nullptr;
float *nc4hw4_output_ = nullptr;
float *tile_input_ = nullptr;
float *tile_output_ = nullptr;
std::mutex lock_;
int thread_num_hw_;
int thread_stride_hw_;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_DECONVOLUTION_WINOGRAD_H_
Loading…
Cancel
Save