replace 8x8 block with 12x8 block in common conv

pull/6785/head
fuzhiye 4 years ago
parent 8558d4f06e
commit 2b056b7a28

File diff suppressed because it is too large.

@@ -28,46 +28,19 @@
#include "nnacl/fp32/conv_depthwise.h"
typedef float *TmpBufferAddress;
typedef float *Matrices;
typedef void (*GEMM_FUNC_FP32)(float *output, const float *input, const float *weight, const float *bias, size_t step,
size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC4,
size_t relu, size_t relu6);
#ifdef __cplusplus
extern "C" {
#endif
void SWBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, int left,
int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding);
void SWCenter(float *dst, const float *src, const float *weight, const float *bias, int height, int width, int kernel_h,
int kernel_w, int out_h_step, int block_channel, int ic4, int in_sh_step, int in_sw_step, int in_kh_step,
int in_kw_step, bool is_relu, bool is_relu6);
// fp32 sliding window
void ConvSWFp32(const float *input_data, const float *packed_weight, const float *bias_data, float *tmp_out_block,
float *output_data, int task_id, ConvParameter *conv_param, SlidingWindowParam *slidingWindow_param);
// fp32 convolution common (im2col+gemm)
void ConvFp32(float *input_data, float *packed_input, float *packed_weight, const float *bias_data,
float *tmp_out_block, float *output_data, int task_id, ConvParameter *conv_param,
GEMM_FUNC_FP32 gemm_func);
float *col_major_input, float *output_data, int task_id, ConvParameter *conv_param);
// fp32 convolution winograd
void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_data, float *output_data,
TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, InputTransFunc in_func,
OutputTransFunc out_func);
void UnPackWinogradOutput(const float *src, float *dst, int batch, int height, int width, int channel, int output_unit);
void UnPackWinogradReluOutput(const float *src, float *dst, int batch, int height, int width, int channel,
int output_unit);
void UnPackWinogradRelu6Output(const float *src, float *dst, int batch, int height, int width, int channel,
int output_unit);
// fp32 conv3x3
void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_data, TmpBufferAddress *buffer_list,
int task_id, ConvParameter *conv_param);
#ifdef __cplusplus
}
#endif

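The hunk above removes the indirect-GEMM function pointer (GEMM_FUNC_FP32) from ConvFp32 and instead threads a col_major_input buffer through the call; per the commit title, the common convolution now computes its output in 12x8 tiles (12 im2col rows by 8 output channels) rather than 8x8. For orientation only, here is a scalar sketch of what one such tile computes. The name and signature below are illustrative, not the nnacl API, and the real kernels use packed layouts and NEON/assembly matmul routines rather than this reference loop.

```c
/* Illustrative scalar reference for one 12x8 output tile (hypothetical name).
 * a:    12 x deep im2col rows for 12 output positions (row-major)
 * b:    deep x 8 weights for one block of 8 output channels (row-major)
 * bias: 8 bias values for this output-channel block (may be NULL)
 * c:    12 x 8 output tile (row-major)                                     */
static void Conv12x8TileRef(const float *a, const float *b, const float *bias, float *c, int deep) {
  for (int r = 0; r < 12; ++r) {
    for (int oc = 0; oc < 8; ++oc) {
      float acc = bias ? bias[oc] : 0.0f;
      for (int d = 0; d < deep; ++d) {
        acc += a[r * deep + d] * b[d * 8 + oc];
      }
      c[r * 8 + oc] = acc;  /* activation (ReLU/ReLU6) would be fused here */
    }
  }
}
```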
@@ -18,50 +18,6 @@
#include <string.h>
#include <stdlib.h>
void PackWeightFp32(float *weight_data, ConvParameter *conv_param, float *packed_weight, int oc_block,
int oc_block_num) {
// original weight format : ohwi
if (oc_block_num == 0) {
return;
}
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int in_channel = conv_param->input_channel_;
int out_channel = conv_param->output_channel_;
int ic4 = UP_DIV(in_channel, C4NUM);
int kernel_plane = kernel_h * kernel_w;
int pack_weight_size = oc_block * oc_block_num * ic4 * C4NUM * kernel_plane;
int unit_size = oc_block * C4NUM;
const int block_size = pack_weight_size / oc_block_num;
for (int m = 0; m < kernel_plane; m++) {
int kernel_plane_stride = m * in_channel;
int packed_kernel_plane_stride = m * unit_size * ic4;
for (int i = 0; i < ic4; i++) {
int channel_block_stride = kernel_plane_stride + i * C4NUM;
int packed_channel_block_size = packed_kernel_plane_stride + i * unit_size;
int ic_remainder = in_channel - i * C4NUM;
int real_ic_num = ic_remainder < C4NUM ? ic_remainder : C4NUM;
for (int h = 0; h < real_ic_num; h++) {
int block_stride = channel_block_stride + h;
int packed_block_stride = packed_channel_block_size + h * oc_block;
for (int j = 0; j < oc_block_num; j++) {
int kernel_block_stride = block_stride + j * oc_block * kernel_plane * in_channel;
int packed_kernel_block_size = packed_block_stride + j * block_size;
int oc_remainder = out_channel - j * oc_block;
int real_oc_num = oc_remainder < oc_block ? oc_remainder : oc_block;
for (int k = 0; k < real_oc_num; k++) {
float *origin_data_ptr = weight_data + kernel_block_stride + k * kernel_plane * in_channel;
float *packed_data_ptr = packed_weight + packed_kernel_block_size + k;
*packed_data_ptr = *origin_data_ptr;
}
} // kernel block loop
} // inchannel block loop
} // channel block loop
} // kernel plane loop
}
void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel) {
return PackNCHWToNHWCFp32(src, dst, 1, plane, channel);
}
@@ -301,6 +257,7 @@ void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, floa
// input format : nhwc
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int kernel_plane = kernel_h * kernel_w;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int pad_h = conv_param->pad_u_;
@@ -311,8 +268,6 @@ void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, floa
int in_h = conv_param->input_h_;
int in_w = conv_param->input_w_;
int out_w = conv_param->output_w_;
int ic4_minus = in_channel / C4NUM;
int ic4 = UP_DIV(in_channel, C4NUM);
for (int i = 0; i < real_cal_num; i++) {
int block_start = block_index + i;
@@ -323,31 +278,25 @@ void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, floa
int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
for (int j = kh_s; j < kh_e; j++) {
int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
for (int n = kw_s; n < kw_e; n++) {
int input_x_stride = input_y_stride + n * dilation_w * in_channel;
int input_plane_offset = (j * kernel_w + n) * C8NUM * C4NUM * ic4 + i * C4NUM;
for (int m = 0; m < ic4_minus; m++) {
int channel_block_stride = input_x_stride + m * C4NUM;
int channel_block_offset = input_plane_offset + m * C8NUM * C4NUM;
#ifdef ENABLE_NEON
vst1q_f32(packed_input + channel_block_offset, vld1q_f32(input_data + channel_block_stride));
#else
for (int k = 0; k < C4NUM; ++k) {
(packed_input + channel_block_offset)[k] = (input_data + channel_block_stride)[k];
}
#endif
} // channel_block loop
int ic_res = conv_param->input_channel_ - ic4_minus * C4NUM;
for (int l = 0; l < ic_res; ++l) {
int channel_block_stride = input_x_stride + ic4_minus * C4NUM + l;
int channel_block_offset = input_plane_offset + ic4_minus * C8NUM * C4NUM + l;
packed_input[channel_block_offset] = input_data[channel_block_stride];
if (dilation_w == 1 && dilation_h == 1) {
for (int j = kh_s; j < kh_e; j++) {
int input_y_stride = j * in_w * in_channel + input_stride;
int input_x_stride = input_y_stride + kw_s * in_channel;
int input_plane_offset = (j * kernel_w + kw_s) * in_channel + i * in_channel * kernel_plane;
memcpy(packed_input + input_plane_offset, input_data + input_x_stride,
(kw_e - kw_s) * in_channel * sizeof(float));
} // kernel_h loop
} else {
for (int j = kh_s; j < kh_e; j++) {
int input_y_stride = j * dilation_h * in_w * in_channel + input_stride;
for (int k = kw_s; k < kw_e; ++k) {
int input_x_stride = input_y_stride + k * dilation_w * in_channel;
int input_plane_offset = (j * kernel_w + k) * in_channel + i * in_channel * kernel_plane;
memcpy(packed_input + input_plane_offset, input_data + input_x_stride, in_channel * sizeof(float));
}
} // kernel_w loop
} // kernel_h loop
} // tile num loop
} // kernel_h loop
}
} // tile num loop
}
void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real_cal_num, int block_index,

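The Im2ColPackUnitFp32 hunk above replaces the old per-4-channel NEON/scalar copies with whole-row memcpy calls: when dilation is 1, the (kw_e - kw_s) in-bounds kernel columns of a kernel row are contiguous in NHWC input, so each kernel row can be copied in one shot. Below is a simplified standalone sketch of that fast path, with hypothetical names and without the per-tile offset and padding bookkeeping of the real function.

```c
#include <string.h>

/* Simplified im2col row copy for one output position, dilation == 1 only.
 * input:   NHWC input data for the current batch
 * packed:  destination row of length kernel_h * kernel_w * in_channel
 * input_h, input_w: top-left input coordinate of the receptive field
 * kh_s/kh_e, kw_s/kw_e: kernel rows/cols that fall inside the image     */
static void Im2ColCopyUnit(const float *input, float *packed, int in_w, int in_channel,
                           int kernel_w, int input_h, int input_w,
                           int kh_s, int kh_e, int kw_s, int kw_e) {
  for (int kh = kh_s; kh < kh_e; ++kh) {
    const float *src = input + ((input_h + kh) * in_w + (input_w + kw_s)) * in_channel;
    float *dst = packed + (kh * kernel_w + kw_s) * in_channel;
    memcpy(dst, src, (size_t)(kw_e - kw_s) * in_channel * sizeof(float));
  }
}
```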
@@ -51,9 +51,6 @@ void MatrixPack(const float *src, float *dst, int row, int ic4, int stride);
void PackInputToC8Int8(const int8_t *input_data, int16_t *packed_input, ConvParameter *conv_param);
void PackWeightFp32(float *weight_data, ConvParameter *conv_param, float *packed_weight, int oc_block,
int oc_block_num);
void PackWeightKHWToHWKFp32(const void *src, void *dst, int plane, int channel);
void PackWeightInt8(int8_t *weight_data, ConvParameter *conv_param, int8_t *packed_weight, int32_t *weight_sum);

File diff suppressed because it is too large.

@@ -38,21 +38,6 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
void WinogradOutputTransform(const float *gemm_out, float *out_data, const float *bias_data, int cal_num,
int out_tile_index, int output_unit_num, ConvParameter *conv_param, OutputTransFunc func);
// for fp32 convolution 3x3 filter/input/output transform
void Conv3x3Fp32InputUnit(const float *tmp_data, float *trans_input_data, size_t step);
void Conv3x3Fp32InputTransform(const float *input_data, float *trans_input, float *tmp_data, int start_index,
int real_cal_num, int out_w_block, ConvParameter *conv_param);
void Conv3x3Fp32FilterTransform(float *weight_data, float *trans_weight, int iC4, int output_channel, int kernel_plane,
int oc_block);
void Conv3x3Fp32OutputUnit(const float *gemm_out, const float *bias_data, float *output_data, bool h_not_bound,
bool w_not_bound, int output_w);
void Conv3x3Fp32OutputTransform(const float *gemm_out, float *out_data, const float *bias_data, int start_index,
int real_cal_num, int out_w_block, ConvParameter *conv_param);
// for int8 convolution 3x3 filter/input/output transform
void Conv3x3Int8InputUnit(int16_t *tmp_data, int16_t *trans_input_data, size_t step, int input_zp);

@@ -25,6 +25,7 @@ static InputTransFunc InputTransFuncList[] = {
NULL, NULL, NULL, NULL, InputTransform4x4Unit, NULL, InputTransform6x6Unit, NULL, InputTransform8x8Unit};
static OutputTransFunc OutputTransFuncList4[] = {NULL, NULL, OutputTransform4x2Unit, OutputTransform4x3Unit};
static OutputTransFunc OutputTransFuncReluList4[] = {NULL, NULL, OutputTransform4x2ReluUnit,
OutputTransform4x3ReluUnit};
static OutputTransFunc OutputTransFuncRelu6List4[] = {NULL, NULL, OutputTransform4x2Relu6Unit,
@@ -32,12 +33,14 @@ static OutputTransFunc OutputTransFuncRelu6List4[] = {NULL, NULL, OutputTransfor
static OutputTransFunc OutputTransFuncList6[] = {
NULL, NULL, OutputTransform6x2Unit, OutputTransform6x3Unit, OutputTransform6x4Unit, OutputTransform6x5Unit};
static OutputTransFunc OutputTransFuncReluList6[] = {NULL,
NULL,
OutputTransform6x2ReluUnit,
OutputTransform6x3ReluUnit,
OutputTransform6x4ReluUnit,
OutputTransform6x5ReluUnit};
static OutputTransFunc OutputTransFuncRelu6List6[] = {NULL,
NULL,
OutputTransform6x2Relu6Unit,
@@ -53,6 +56,7 @@ static OutputTransFunc OutputTransFuncList8[] = {NULL,
OutputTransform8x5Unit,
OutputTransform8x6Unit,
OutputTransform8x7Unit};
static OutputTransFunc OutputTransFuncReluList8[] = {NULL,
NULL,
OutputTransform8x2ReluUnit,
@@ -61,6 +65,7 @@ static OutputTransFunc OutputTransFuncReluList8[] = {NULL,
OutputTransform8x5ReluUnit,
OutputTransform8x6ReluUnit,
OutputTransform8x7ReluUnit};
static OutputTransFunc OutputTransFuncRelu6List8[] = {NULL,
NULL,
OutputTransform8x2Relu6Unit,

@@ -15,9 +15,7 @@
*/
#include "src/runtime/kernel/arm/fp32/convolution.h"
#include "src/runtime/kernel/arm/fp32/convolution_slidewindow.h"
#include "src/runtime/kernel/arm/fp32/convolution_1x1.h"
#include "src/runtime/kernel/arm/fp32/convolution_3x3.h"
#include "src/runtime/kernel/arm/fp32/convolution_winograd.h"
#include "nnacl/fp32/conv.h"
#include "nnacl/common_func.h"
@@ -42,17 +40,10 @@ int ConvolutionCPUKernel::InitWeightBias() {
int out_channel = filter_tensor->Batch();
conv_param_->input_channel_ = in_channel;
conv_param_->output_channel_ = out_channel;
int ic4 = UP_DIV(in_channel, C4NUM);
int kernel_plane = kernel_h * kernel_w;
int oc_block, oc_block_num;
#ifdef ENABLE_ARM32
oc_block = C4NUM;
oc_block_num = UP_DIV(out_channel, C4NUM);
#else
oc_block = C8NUM;
oc_block_num = UP_DIV(out_channel, C8NUM);
#endif
int pack_weight_size = oc_block_num * oc_block * ic4 * C4NUM * kernel_plane;
const int oc_block = C8NUM;
int oc_block_num = UP_DIV(out_channel, C8NUM);
int pack_weight_size = oc_block_num * oc_block * in_channel * kernel_plane;
auto origin_weight = reinterpret_cast<float *>(filter_tensor->MutableData());
packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
@@ -61,7 +52,7 @@ int ConvolutionCPUKernel::InitWeightBias() {
return RET_ERROR;
}
memset(packed_weight_, 0, pack_weight_size * sizeof(float));
PackWeightFp32(origin_weight, conv_param_, packed_weight_, oc_block, oc_block_num);
RowMajor2Col8Major(origin_weight, packed_weight_, out_channel, in_channel * kernel_plane);
bias_data_ = reinterpret_cast<float *>(malloc(oc_block_num * oc_block * sizeof(float)));
if (bias_data_ == nullptr) {
@@ -80,38 +71,28 @@
}
int ConvolutionCPUKernel::InitTmpBuffer() {
int out_channel = conv_param_->output_channel_;
int in_channel = conv_param_->input_channel_;
MS_ASSERT(ctx_->allocator != nullptr);
int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM);
int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * ic4 * C4NUM * TILE_NUM * thread_count_;
#ifdef ENABLE_ARM32
int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * in_channel * C4NUM * thread_count_;
#else
int unit_size = conv_param_->kernel_h_ * conv_param_->kernel_w_ * in_channel * C12NUM * thread_count_;
#endif
packed_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(unit_size * sizeof(float)));
if (packed_input_ == nullptr) {
MS_LOG(ERROR) << "malloc packed input failed.";
return RET_ERROR;
}
tmp_output_block_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(thread_count_ * TILE_NUM * out_channel * sizeof(float)));
if (tmp_output_block_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp output block failed.";
col_major_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(unit_size * sizeof(float)));
if (col_major_input_ == nullptr) {
MS_LOG(ERROR) << "malloc col_major_input_ failed.";
return RET_ERROR;
}
return RET_OK;
}
void ConvolutionCPUKernel::ConfigInputOutput() {
// set output format
auto output_tensor = out_tensors_.at(kOutputIndex);
output_tensor->SetFormat(schema::Format::Format_NHWC);
#ifdef ENABLE_ARM32
gemm_func_ = IndirectGemmFp32_8x4;
#else
gemm_func_ = IndirectGemmFp32_8x8;
#endif
}
int ConvolutionCPUKernel::Init() {
auto ret = InitWeightBias();
if (ret != RET_OK) {
@@ -121,7 +102,6 @@ int ConvolutionCPUKernel::Init() {
if (!InferShapeDone()) {
return RET_OK;
}
ConfigInputOutput();
return ReSize();
}
@@ -141,15 +121,11 @@ int ConvolutionCPUKernel::ReSize() {
}
int ConvolutionCPUKernel::RunImpl(int task_id) {
if (gemm_func_ == nullptr) {
MS_LOG(ERROR) << "gemm_func is nullptr.";
return RET_ERROR;
}
auto input_tensor = in_tensors_.at(kInputIndex);
auto ori_input_data = reinterpret_cast<float *>(input_tensor->MutableData());
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData());
ConvFp32(ori_input_data, packed_input_, packed_weight_, reinterpret_cast<float *>(bias_data_), tmp_output_block_,
output_addr, task_id, conv_param_, gemm_func_);
ConvFp32(ori_input_data, packed_input_, packed_weight_, reinterpret_cast<float *>(bias_data_), col_major_input_,
output_addr, task_id, conv_param_);
return RET_OK;
}
@@ -186,19 +162,6 @@ int ConvolutionCPUKernel::Run() {
return RET_OK;
}
bool CheckIfUseSlideWindow(ConvParameter *conv_param) {
int in_channel = conv_param->input_channel_;
int out_h = conv_param->output_h_;
int out_w = conv_param->output_w_;
int out_channel = conv_param->output_channel_;
int ic4 = UP_DIV(in_channel, C4NUM);
int oc4 = UP_DIV(out_channel, C4NUM);
if (out_h * out_w <= 32 || ic4 < 4 || oc4 < 4) {
return true;
}
return false;
}
kernel::LiteKernel *CpuConvFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, OpParameter *op_parameter,
const InnerContext *ctx, const kernel::KernelKey &desc,

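In the kernel changes above, PackWeightFp32 (IC4-padded, oc_block-interleaved) is replaced by a single RowMajor2Col8Major call over the OHWI weights viewed as an out_channel x (in_channel * kernel_plane) matrix, and the packed-input buffer is sized for 12 rows per thread (C12NUM) instead of the old ic4-padded TILE_NUM layout. A plausible reference for that repacking is sketched below; the assumed destination layout is [UP_DIV(row, 8)][col][8] with a zero-padded tail block, which may differ in detail from nnacl's actual RowMajor2Col8Major.

```c
#include <string.h>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

/* Reference repack: row-major (row x col) matrix -> 8-row column blocks,
 * i.e. dst[block][c][r8] = src[block * 8 + r8][c], tail rows zero-padded. */
static void RowMajor2Col8MajorRef(const float *src, float *dst, int row, int col) {
  int row8 = UP_DIV(row, 8) * 8;
  memset(dst, 0, (size_t)row8 * col * sizeof(float));
  for (int r = 0; r < row; ++r) {
    int block = r / 8;
    int r8 = r % 8;
    for (int c = 0; c < col; ++c) {
      dst[(block * col + c) * 8 + r8] = src[r * col + c];
    }
  }
}

/* Usage mirroring InitWeightBias above: row = out_channel, col = in_channel * kernel_plane. */
```

With such a layout, each block of 8 output channels is contiguous per depth element, which pairs with the 12-row packed input to form the 12x8 output tiles named in the commit title.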
@@ -43,23 +43,21 @@ class ConvolutionCPUKernel : public ConvolutionBaseCPUKernel {
int RunImpl(int task_id);
int InitWeightBias();
int InitTmpBuffer();
void ConfigInputOutput();
private:
void FreeTmpBuffer() {
if (tmp_output_block_ != nullptr) {
ctx_->allocator->Free(tmp_output_block_);
tmp_output_block_ = nullptr;
}
if (packed_input_ != nullptr) {
ctx_->allocator->Free(packed_input_);
packed_input_ = nullptr;
}
if (col_major_input_ != nullptr) {
ctx_->allocator->Free(col_major_input_);
col_major_input_ = nullptr;
}
}
float *packed_input_ = nullptr;
float *packed_weight_ = nullptr;
float *tmp_output_block_ = nullptr;
GEMM_FUNC_FP32 gemm_func_ = nullptr;
float *col_major_input_ = nullptr;
};
} // namespace mindspore::kernel

@@ -1,245 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/convolution_3x3.h"
#include "nnacl/fp32/conv.h"
#include "src/runtime/kernel/arm/base/layout_transform.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Conv2D;
namespace mindspore::kernel {
void ProcessFilter(float *origin_weight, float *dst_weight, ConvParameter *conv_param, int oc_block, int oc_block_num) {
auto input_channel = conv_param->input_channel_;
auto output_channel = conv_param->output_channel_;
auto kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
int iC4 = UP_DIV(input_channel, C4NUM);
size_t tmp_size = oc_block_num * oc_block * iC4 * C4NUM * kernel_plane * sizeof(float);
auto tmp_addr = reinterpret_cast<float *>(malloc(tmp_size));
if (tmp_addr == nullptr) {
MS_LOG(ERROR) << "malloc tmp_addr failed.";
return;
}
memset(tmp_addr, 0, tmp_size);
PackNHWCToNC4HW4Fp32(origin_weight, tmp_addr, output_channel, kernel_plane, input_channel);
Conv3x3Fp32FilterTransform(tmp_addr, dst_weight, iC4, output_channel, kernel_plane, oc_block);
free(tmp_addr);
}
int Convolution3x3CPUKernel::InitWeightBias() {
auto filter_tensor = in_tensors_.at(kWeightIndex);
auto input_channel = filter_tensor->Channel();
auto output_channel = filter_tensor->Batch();
conv_param_->input_channel_ = input_channel;
conv_param_->output_channel_ = output_channel;
int iC4 = UP_DIV(input_channel, C4NUM);
int oC4 = UP_DIV(output_channel, C4NUM);
int oc_block, oc_block_num;
oc_block = C8NUM;
oc_block_num = UP_DIV(output_channel, C8NUM);
const int k_plane = 16;
// init weight
size_t transformed_size = iC4 * C4NUM * oc_block_num * oc_block * k_plane * sizeof(float);
transformed_filter_addr_ = reinterpret_cast<float *>(malloc(transformed_size));
if (transformed_filter_addr_ == nullptr) {
MS_LOG(ERROR) << "malloc transformed filter addr failed.";
return RET_ERROR;
}
memset(transformed_filter_addr_, 0, transformed_size);
auto weight_data = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->MutableData());
ProcessFilter(weight_data, transformed_filter_addr_, conv_param_, oc_block, oc_block_num);
// init bias
size_t new_bias_size = oC4 * C4NUM * sizeof(float);
bias_data_ = reinterpret_cast<float *>(malloc(new_bias_size));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "malloc bias data failed.";
return RET_ERROR;
}
memset(bias_data_, 0, new_bias_size);
if (in_tensors_.size() == kInputSize2) {
auto ori_bias_addr = reinterpret_cast<float *>(in_tensors_.at(kBiasIndex)->MutableData());
memcpy(bias_data_, ori_bias_addr, output_channel * sizeof(float));
} else {
MS_ASSERT(in_tensors_.size() == kInputSize1);
}
return RET_OK;
}
int Convolution3x3CPUKernel::InitTmpBuffer() {
int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM);
int oC4 = UP_DIV(conv_param_->output_channel_, C4NUM);
int oC8 = UP_DIV(conv_param_->output_channel_, C8NUM);
const int k_plane = 16;
MS_ASSERT(ctx_->allocator != nullptr);
#ifdef ENABLE_ARM32
const int tile_num = 4;
#else
const int tile_num = 12;
#endif
size_t tile_buffer_size = thread_count_ * tile_num * C16NUM * ic4 * C4NUM * sizeof(float);
tile_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tile_buffer_size));
if (tile_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc tile buffer failed.";
return RET_ERROR;
}
size_t block_unit_buffer_size = thread_count_ * k_plane * C4NUM * sizeof(float);
block_unit_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(block_unit_buffer_size));
if (block_unit_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc block_unit_buffer_ failed.";
return RET_ERROR;
}
size_t tmp_dst_buffer_size = thread_count_ * tile_num * k_plane * oC8 * C8NUM * sizeof(float);
tmp_dst_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tmp_dst_buffer_size));
if (tmp_dst_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_dst_buffer_ failed.";
return RET_ERROR;
}
size_t col_buffer_size = thread_count_ * tile_num * C4NUM * ic4 * sizeof(float);
col_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(col_buffer_size));
if (col_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc col_buffer_ failed.";
return RET_ERROR;
}
size_t nc4hw4_out_size =
oC4 * C4NUM * conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * sizeof(float);
nc4hw4_out_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(nc4hw4_out_size));
if (nc4hw4_out_ == nullptr) {
MS_LOG(ERROR) << "malloc nc4hw4_out_ failed.";
return RET_ERROR;
}
tmp_buffer_address_list_[0] = tile_buffer_;
tmp_buffer_address_list_[1] = block_unit_buffer_;
tmp_buffer_address_list_[2] = tmp_dst_buffer_;
tmp_buffer_address_list_[3] = nc4hw4_out_;
tmp_buffer_address_list_[4] = col_buffer_;
return RET_OK;
}
int Convolution3x3CPUKernel::Init() {
auto ret = InitWeightBias();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init weight bias failed.ret: " << ret;
return RET_ERROR;
}
if (!InferShapeDone()) {
return RET_OK;
}
return ReSize();
}
int Convolution3x3CPUKernel::ReSize() {
auto ret = ConvolutionBaseCPUKernel::CheckResizeValid();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Resize is invalid.";
return ret;
}
ret = ConvolutionBaseCPUKernel::Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionBase init failed.ret: " << ret;
return RET_ERROR;
}
return RET_OK;
}
int Convolution3x3CPUKernel::RunImpl(int task_id) {
auto input_tensor = in_tensors_.at(kInputIndex);
auto ori_input_data = reinterpret_cast<float *>(input_tensor->MutableData());
Conv3x3Fp32(ori_input_data, transformed_filter_addr_, reinterpret_cast<float *>(bias_data_), tmp_buffer_address_list_,
task_id, conv_param_);
return RET_OK;
}
int Convolution3x3Impl(void *cdata, int task_id) {
auto conv3x3 = reinterpret_cast<Convolution3x3CPUKernel *>(cdata);
auto error_code = conv3x3->RunImpl(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Convolution3x3 Run error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
int Convolution3x3CPUKernel::PostProcess() {
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData());
auto act_type = conv_param_->act_type_;
switch (act_type) {
case ActType_No:
PackNC4HW4ToNHWCFp32(nc4hw4_out_, output_addr, conv_param_->output_batch_,
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
break;
case ActType_Relu:
PackNC4HW4ToNHWCReluFp32(nc4hw4_out_, output_addr, conv_param_->output_batch_,
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
break;
case ActType_Relu6:
PackNC4HW4ToNHWCRelu6Fp32(nc4hw4_out_, output_addr, conv_param_->output_batch_,
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
break;
default:
MS_LOG(ERROR) << "Unsupport activation type.";
return RET_ERROR;
}
return RET_OK;
}
int Convolution3x3CPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
auto ret = InitTmpBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init tmp buffer failed.ret: " << ret;
return RET_ERROR;
}
int error_code = ParallelLaunch(this->context_->thread_pool_, Convolution3x3Impl, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "conv3x3 error error_code[" << error_code << "]";
FreeTmpBuffer();
return RET_ERROR;
}
ret = PostProcess();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Post process failed.";
FreeTmpBuffer();
return ret;
}
FreeTmpBuffer();
return RET_OK;
}
} // namespace mindspore::kernel

@@ -1,80 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_3X3_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_3X3_H_
#include <vector>
#include "src/lite_kernel.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "nnacl/winograd_transform.h"
namespace mindspore::kernel {
class Convolution3x3CPUKernel : public ConvolutionBaseCPUKernel {
public:
Convolution3x3CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~Convolution3x3CPUKernel() override {
if (transformed_filter_addr_ != nullptr) {
free(transformed_filter_addr_);
}
}
int Init() override;
int ReSize() override;
int Run() override;
int RunImpl(int task_id);
int InitWeightBias();
int InitTmpBuffer();
int PostProcess();
private:
void FreeTmpBuffer() {
if (tile_buffer_ != nullptr) {
ctx_->allocator->Free(tile_buffer_);
tile_buffer_ = nullptr;
}
if (block_unit_buffer_ != nullptr) {
ctx_->allocator->Free(block_unit_buffer_);
block_unit_buffer_ = nullptr;
}
if (tmp_dst_buffer_ != nullptr) {
ctx_->allocator->Free(tmp_dst_buffer_);
tmp_dst_buffer_ = nullptr;
}
if (nc4hw4_out_ != nullptr) {
ctx_->allocator->Free(nc4hw4_out_);
nc4hw4_out_ = nullptr;
}
if (col_buffer_ != nullptr) {
ctx_->allocator->Free(col_buffer_);
col_buffer_ = nullptr;
}
}
float *transformed_filter_addr_ = nullptr;
float *tile_buffer_ = nullptr;
float *block_unit_buffer_ = nullptr;
float *tmp_dst_buffer_ = nullptr;
float *col_buffer_ = nullptr;
float *nc4hw4_out_ = nullptr;
TmpBufferAddress tmp_buffer_address_list_[5];
};
void ProcessFilter(float *origin_weight, float *dst_weight, ConvParameter *conv_param, int oc_block, int oc_block_num);
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_3X3_H_

@@ -1,201 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/runtime/kernel/arm/fp32/convolution_slidewindow.h"
#include "nnacl/common_func.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"
#include "src/runtime/runtime_api.h"
namespace mindspore::kernel {
using mindspore::kernel::KERNEL_ARCH::kCPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_INFER_INVALID;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_Conv2D;
int ConvolutionSWCPUKernel::InitWeightBias() {
auto filter_tensor = in_tensors_.at(kWeightIndex);
auto input_channel = filter_tensor->Channel();
auto output_channel = filter_tensor->Batch();
int kernel_h = filter_tensor->Height();
int kernel_w = filter_tensor->Width();
conv_param_->input_channel_ = input_channel;
conv_param_->output_channel_ = output_channel;
int ic4 = UP_DIV(input_channel, C4NUM);
int kernel_plane = kernel_h * kernel_w;
int oc_block = C4NUM;
int oc_block_num = UP_DIV(output_channel, C4NUM);
int pack_weight_size = oc_block_num * oc_block * ic4 * C4NUM * kernel_plane;
auto origin_weight = reinterpret_cast<float *>(in_tensors_.at(kWeightIndex)->MutableData());
packed_weight_ = reinterpret_cast<float *>(malloc(pack_weight_size * sizeof(float)));
if (packed_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc packed weight failed.";
return RET_ERROR;
}
memset(packed_weight_, 0, pack_weight_size * sizeof(float));
for (int oc = 0; oc < output_channel; ++oc) {
int src_oc_offset = oc * kernel_h * kernel_w * input_channel;
int dst_oc_offset = oc * kernel_h * kernel_w * ic4 * C4NUM;
for (int i = 0; i < kernel_h * kernel_w; ++i) {
const float *src = origin_weight + src_oc_offset + i * input_channel;
float *dst = packed_weight_ + dst_oc_offset + i * ic4 * C4NUM;
memcpy(dst, src, input_channel * sizeof(float));
}
}
bias_data_ = reinterpret_cast<float *>(malloc(oc_block_num * oc_block * sizeof(float)));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "malloc bias failed.";
return RET_ERROR;
}
memset(bias_data_, 0, oc_block_num * oc_block * sizeof(float));
if (in_tensors_.size() == kInputSize2) {
auto ori_bias = reinterpret_cast<float *>(in_tensors_.at(kBiasIndex)->MutableData());
memcpy(bias_data_, ori_bias, output_channel * sizeof(float));
} else {
MS_ASSERT(in_tensors_.size() == kInputSize1);
}
return RET_OK;
}
int ConvolutionSWCPUKernel::InitTmpBuffer() {
int out_channel = conv_param_->output_channel_;
int oc4 = UP_DIV(out_channel, C4NUM);
MS_ASSERT(ctx_->allocator != nullptr);
int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM);
size_t nhwc4_input_size =
ic4 * C4NUM * conv_param_->input_batch_ * conv_param_->input_h_ * conv_param_->input_w_ * sizeof(float);
nhwc4_input_ = ctx_->allocator->Malloc(nhwc4_input_size);
if (nhwc4_input_ == nullptr) {
MS_LOG(ERROR) << "malloc nhwc4 input failed.";
return RET_ERROR;
}
tmp_output_block_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(
conv_param_->output_batch_ * conv_param_->output_h_ * conv_param_->output_w_ * oc4 * C4NUM * sizeof(float)));
if (tmp_output_block_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp output block failed.";
return RET_ERROR;
}
return RET_OK;
}
void ConvolutionSWCPUKernel::ConfigInputOutput() {
// set output format
auto output_tensor = out_tensors_.at(kOutputIndex);
output_tensor->SetFormat(schema::Format::Format_NHWC);
}
int ConvolutionSWCPUKernel::Init() {
auto ret = InitWeightBias();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init weight bias failed.";
return RET_ERROR;
}
if (!InferShapeDone()) {
return RET_OK;
}
// config input output
ConfigInputOutput();
return ReSize();
}
int ConvolutionSWCPUKernel::ReSize() {
auto ret = ConvolutionBaseCPUKernel::CheckResizeValid();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Resize is invalid.";
return ret;
}
if (slidingWindow_param_ != nullptr) {
delete slidingWindow_param_;
slidingWindow_param_ = nullptr;
}
ret = ConvolutionBaseCPUKernel::Init();
if (ret != RET_OK) {
MS_LOG(ERROR) << "ConvolutionBase init failed.";
return RET_ERROR;
}
// init sliding window param
slidingWindow_param_ = new (std::nothrow) SlidingWindowParam;
if (slidingWindow_param_ == nullptr) {
MS_LOG(ERROR) << "new SlidingWindowParam fail!";
return RET_ERROR;
}
InitSlidingParamConv(slidingWindow_param_, conv_param_, C4NUM);
return RET_OK;
}
int ConvolutionSWCPUKernel::RunImpl(int task_id) {
auto output_addr = reinterpret_cast<float *>(out_tensors_.at(kOutputIndex)->MutableData());
ConvSWFp32(reinterpret_cast<float *>(nhwc4_input_), packed_weight_, reinterpret_cast<float *>(bias_data_),
tmp_output_block_, output_addr, task_id, conv_param_, slidingWindow_param_);
return RET_OK;
}
int ConvolutionSWImpl(void *cdata, int task_id) {
auto conv = reinterpret_cast<ConvolutionSWCPUKernel *>(cdata);
auto error_code = conv->RunImpl(task_id);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "Convolution Sliding Window Run error task_id[" << task_id << "] error_code[" << error_code << "]";
return RET_ERROR;
}
return RET_OK;
}
int ConvolutionSWCPUKernel::Run() {
auto prepare_ret = Prepare();
if (prepare_ret != RET_OK) {
MS_LOG(ERROR) << "Prepare fail!ret: " << prepare_ret;
return prepare_ret;
}
// init tmp input, output
auto ret = InitTmpBuffer();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init tmp buffer failed.";
return RET_ERROR;
}
auto input_tensor = in_tensors_.at(kInputIndex);
auto ori_input_data = input_tensor->MutableData();
PackNHWCToNHWC4Fp32(ori_input_data, nhwc4_input_, conv_param_->input_batch_,
conv_param_->input_h_ * conv_param_->input_w_, conv_param_->input_channel_);
int error_code = ParallelLaunch(this->context_->thread_pool_, ConvolutionSWImpl, this, thread_count_);
if (error_code != RET_OK) {
MS_LOG(ERROR) << "conv error error_code[" << error_code << "]";
FreeTmpBuffer();
return RET_ERROR;
}
auto out_tensor = out_tensors_.front();
auto out_data = reinterpret_cast<float *>(out_tensor->MutableData());
int oc4_res = conv_param_->output_channel_ % C4NUM;
if (oc4_res != 0) {
PackNHWC4ToNHWCFp32(tmp_output_block_, out_data, conv_param_->output_batch_,
conv_param_->output_h_ * conv_param_->output_w_, conv_param_->output_channel_);
}
FreeTmpBuffer();
return RET_OK;
}
} // namespace mindspore::kernel

@@ -1,70 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_SLIDEWINDOW_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_SLIDEWINDOW_H_
#include <vector>
#include "src/lite_kernel.h"
#include "nnacl/op_base.h"
#include "src/runtime/kernel/arm/base/convolution_base.h"
#include "nnacl/fp32/conv.h"
#include "nnacl/fp32/conv_depthwise.h"
namespace mindspore::kernel {
class ConvolutionSWCPUKernel : public ConvolutionBaseCPUKernel {
public:
ConvolutionSWCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
const mindspore::lite::PrimitiveC *primitive)
: ConvolutionBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {}
~ConvolutionSWCPUKernel() override {
if (packed_weight_ != nullptr) {
free(packed_weight_);
packed_weight_ = nullptr;
}
if (slidingWindow_param_ != nullptr) {
delete slidingWindow_param_;
slidingWindow_param_ = nullptr;
}
}
int Init() override;
int ReSize() override;
int Run() override;
int RunImpl(int task_id);
int InitWeightBias();
int InitTmpBuffer();
void ConfigInputOutput();
private:
void FreeTmpBuffer() {
if (nhwc4_input_ != nullptr) {
ctx_->allocator->Free(nhwc4_input_);
nhwc4_input_ = nullptr;
}
if (tmp_output_block_ != nullptr) {
ctx_->allocator->Free(tmp_output_block_);
tmp_output_block_ = nullptr;
}
}
float *packed_weight_ = nullptr;
float *tmp_output_block_ = nullptr;
SlidingWindowParam *slidingWindow_param_ = nullptr;
};
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_CONVOLUTION_SLIDEWINDOW_H_

@@ -107,35 +107,6 @@ TEST_F(TestPack, PackInputFp32) {
MS_LOG(INFO) << "TestPackInputFp32 passed";
}
TEST_F(TestPack, PackWeightFp32) {
auto conv_param = new ConvParameter;
InitConvParamPack(conv_param);
int k_h = conv_param->kernel_h_;
int k_w = conv_param->kernel_w_;
int in_channel = conv_param->input_channel_;
int out_channel = conv_param->output_channel_;
int ic4 = UP_DIV(in_channel, C4NUM);
int oc8 = UP_DIV(out_channel, C8NUM);
size_t weight_size;
std::string weight_path = "./test_data/conv/convfp32_weight_32_3_3_3.bin";
auto weight_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(weight_path.c_str(), &weight_size));
auto packed_weight = reinterpret_cast<float *>(malloc(k_h * k_w * ic4 * C4NUM * oc8 * C8NUM * sizeof(float)));
PackWeightFp32(weight_data, conv_param, packed_weight, C8NUM, oc8);
printf("==================output data=================\n");
for (int i = 0; i < 20; i++) {
std::cout << packed_weight[i] << " ,";
}
std::cout << std::endl;
free(packed_weight);
delete conv_param;
MS_LOG(INFO) << "TestPackWeightFp32 passed";
}
#ifdef ENABLE_FP16
TEST_F(TestPack, PackInputFp16) {
size_t input_size;
