|
|
|
@ -18,6 +18,7 @@
|
|
|
|
|
#include <string.h>
|
|
|
|
|
#include "nnacl/fp32/common_func.h"
|
|
|
|
|
#include "nnacl/winograd_transform.h"
|
|
|
|
|
#include "nnacl/fp32/matmul.h"
|
|
|
|
|
|
|
|
|
|
void SWBorderPixel(float *dst, const float *src, const float *weight, const float *bias, int height, int width,
|
|
|
|
|
int in_kh_step, int in_kw_step, int kernel_h, int kernel_w, int ic4, bool is_relu, bool is_relu6) {
|
|
|
|
@ -57,16 +58,18 @@ void SWBorderPixel(float *dst, const float *src, const float *weight, const floa
|
|
|
|
|
void SWBorder(float *dst, const float *src, const float *weight, const float *bias, int top, int bottom, int left,
|
|
|
|
|
int right, const ConvParameter *conv_param, const SlidingWindowParam *sliding) {
|
|
|
|
|
int ic4 = sliding->ic4_channel_ / C4NUM;
|
|
|
|
|
bool relu = conv_param->act_type_ == ActType_Relu;
|
|
|
|
|
bool relu6 = conv_param->act_type_ == ActType_Relu6;
|
|
|
|
|
float *dst_h = dst + top * sliding->out_h_step_;
|
|
|
|
|
for (int oh = top; oh < bottom; oh++) {
|
|
|
|
|
int ih = oh * conv_param->stride_h_ - conv_param->pad_h_;
|
|
|
|
|
int ih = oh * conv_param->stride_h_ - conv_param->pad_u_;
|
|
|
|
|
int start_kh = MSMAX(0, UP_DIV(-ih, conv_param->dilation_h_));
|
|
|
|
|
int end_kh = MSMIN(conv_param->kernel_h_, UP_DIV(conv_param->input_h_ - ih, conv_param->dilation_h_));
|
|
|
|
|
const float *src_h = src + ih * sliding->in_h_step_;
|
|
|
|
|
|
|
|
|
|
float *dst_kernel = dst_h + left * sliding->block_channel_;
|
|
|
|
|
for (int ow = left; ow < right; ow++) {
|
|
|
|
|
int iw = ow * conv_param->stride_w_ - conv_param->pad_w_;
|
|
|
|
|
int iw = ow * conv_param->stride_w_ - conv_param->pad_l_;
|
|
|
|
|
int start_kw = MSMAX(0, UP_DIV(-iw, conv_param->dilation_w_));
|
|
|
|
|
int end_kw = MSMIN(conv_param->kernel_w_, UP_DIV(conv_param->input_w_ - iw, conv_param->dilation_w_));
|
|
|
|
|
const float *src_w = src_h + iw * sliding->ic4_channel_;
|
|
|
|
@ -75,8 +78,8 @@ void SWBorder(float *dst, const float *src, const float *weight, const float *bi
|
|
|
|
|
const float *weight_kernel = weight + (start_kh * conv_param->kernel_w_ + start_kw) * sliding->ic4_channel_;
|
|
|
|
|
|
|
|
|
|
SWBorderPixel(dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw,
|
|
|
|
|
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_h_, conv_param->kernel_w_, ic4,
|
|
|
|
|
conv_param->is_relu_, conv_param->is_relu6_);
|
|
|
|
|
sliding->in_kh_step_, sliding->in_kw_step_, conv_param->kernel_h_, conv_param->kernel_w_, ic4, relu,
|
|
|
|
|
relu6);
|
|
|
|
|
|
|
|
|
|
dst_kernel += sliding->block_channel_;
|
|
|
|
|
} // width loop
|
|
|
|
@ -144,6 +147,8 @@ void ConvSWFp32(const float *input_data, const float *packed_weight, const float
|
|
|
|
|
float *output_data, int task_id, ConvParameter *conv_param, SlidingWindowParam *slidingWindow_param) {
|
|
|
|
|
int ic4 = slidingWindow_param->ic4_channel_ / C4NUM;
|
|
|
|
|
int oc4_res = conv_param->output_channel_ % C4NUM;
|
|
|
|
|
bool relu = conv_param->act_type_ == ActType_Relu;
|
|
|
|
|
bool relu6 = conv_param->act_type_ == ActType_Relu6;
|
|
|
|
|
const float *src = input_data;
|
|
|
|
|
float *dst = NULL;
|
|
|
|
|
if (oc4_res == 0) {
|
|
|
|
@ -169,28 +174,26 @@ void ConvSWFp32(const float *input_data, const float *packed_weight, const float
|
|
|
|
|
|
|
|
|
|
if (slidingWindow_param->right_ > slidingWindow_param->left_ &&
|
|
|
|
|
slidingWindow_param->bottom_ > slidingWindow_param->top_) {
|
|
|
|
|
int in_h_start = slidingWindow_param->top_ * conv_param->stride_h_ - conv_param->pad_h_;
|
|
|
|
|
int in_w_start = slidingWindow_param->left_ * conv_param->stride_w_ - conv_param->pad_w_;
|
|
|
|
|
int in_h_start = slidingWindow_param->top_ * conv_param->stride_h_ - conv_param->pad_u_;
|
|
|
|
|
int in_w_start = slidingWindow_param->left_ * conv_param->stride_w_ - conv_param->pad_l_;
|
|
|
|
|
const float *in_t =
|
|
|
|
|
src_data + in_h_start * slidingWindow_param->in_h_step_ + in_w_start * slidingWindow_param->ic4_channel_;
|
|
|
|
|
float *out_t = dst_data + slidingWindow_param->top_ * slidingWindow_param->out_h_step_ +
|
|
|
|
|
slidingWindow_param->left_ * slidingWindow_param->block_channel_;
|
|
|
|
|
#ifdef ENABLE_ARM64
|
|
|
|
|
ConvSwFp32Center(out_t, in_t, weight, bias, slidingWindow_param->bottom_ - slidingWindow_param->top_,
|
|
|
|
|
slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_,
|
|
|
|
|
conv_param->kernel_w_, slidingWindow_param->out_h_step_ * sizeof(float),
|
|
|
|
|
slidingWindow_param->block_channel_ * sizeof(float), ic4,
|
|
|
|
|
slidingWindow_param->in_sh_step_ * sizeof(float),
|
|
|
|
|
slidingWindow_param->in_sw_step_ * sizeof(float),
|
|
|
|
|
slidingWindow_param->in_kh_step_ * sizeof(float),
|
|
|
|
|
slidingWindow_param->in_kw_step_ * sizeof(float),
|
|
|
|
|
conv_param->is_relu_, conv_param->is_relu6_);
|
|
|
|
|
ConvSwFp32Center(
|
|
|
|
|
out_t, in_t, weight, bias, slidingWindow_param->bottom_ - slidingWindow_param->top_,
|
|
|
|
|
slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_, conv_param->kernel_w_,
|
|
|
|
|
slidingWindow_param->out_h_step_ * sizeof(float), slidingWindow_param->block_channel_ * sizeof(float), ic4,
|
|
|
|
|
slidingWindow_param->in_sh_step_ * sizeof(float), slidingWindow_param->in_sw_step_ * sizeof(float),
|
|
|
|
|
slidingWindow_param->in_kh_step_ * sizeof(float), slidingWindow_param->in_kw_step_ * sizeof(float), relu,
|
|
|
|
|
relu6);
|
|
|
|
|
#else
|
|
|
|
|
SWCenter(out_t, in_t, weight, bias, slidingWindow_param->bottom_ - slidingWindow_param->top_,
|
|
|
|
|
slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_,
|
|
|
|
|
conv_param->kernel_w_, slidingWindow_param->out_h_step_, slidingWindow_param->block_channel_, ic4,
|
|
|
|
|
slidingWindow_param->right_ - slidingWindow_param->left_, conv_param->kernel_h_, conv_param->kernel_w_,
|
|
|
|
|
slidingWindow_param->out_h_step_, slidingWindow_param->block_channel_, ic4,
|
|
|
|
|
slidingWindow_param->in_sh_step_, slidingWindow_param->in_sw_step_, slidingWindow_param->in_kh_step_,
|
|
|
|
|
slidingWindow_param->in_kw_step_, conv_param->is_relu_, conv_param->is_relu6_);
|
|
|
|
|
slidingWindow_param->in_kw_step_, relu, relu6);
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
} // output C4 loop
|
|
|
|
@ -219,6 +222,8 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
|
|
|
|
|
int kernel_plane = kernel_h * kernel_w;
|
|
|
|
|
int unit_size = kernel_plane * ic4 * C4NUM;
|
|
|
|
|
int packed_input_size = output_tile_count * TILE_NUM * unit_size;
|
|
|
|
|
bool relu = conv_param->act_type_ == ActType_Relu;
|
|
|
|
|
bool relu6 = conv_param->act_type_ == ActType_Relu6;
|
|
|
|
|
|
|
|
|
|
// we accumulate 4 channels per time for input blocks
|
|
|
|
|
int conv_depth = kernel_h * kernel_w;
|
|
|
|
@ -240,11 +245,11 @@ void ConvFp32(float *input_data, float *packed_input, float *packed_weight, cons
|
|
|
|
|
if (real_cal_num == TILE_NUM) {
|
|
|
|
|
float *gemm_output = output_data + out_offset;
|
|
|
|
|
gemm_func(gemm_output, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, output_offset, 0, 0,
|
|
|
|
|
conv_param->is_relu_, conv_param->is_relu6_);
|
|
|
|
|
relu, relu6);
|
|
|
|
|
} else {
|
|
|
|
|
// res part
|
|
|
|
|
gemm_func(tmp_out_block, gemm_input, packed_weight, bias_data, conv_depth, ic4, out_channel, output_offset, 0,
|
|
|
|
|
0, conv_param->is_relu_, conv_param->is_relu6_);
|
|
|
|
|
0, relu, relu6);
|
|
|
|
|
memcpy(output_data + out_offset, tmp_out_block, real_cal_num * out_channel * sizeof(float));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -264,34 +269,42 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
|
|
|
|
|
int out_w_block = UP_DIV(conv_param->output_w_, out_unit);
|
|
|
|
|
int out_h_block = UP_DIV(conv_param->output_h_, out_unit);
|
|
|
|
|
int output_count = out_w_block * out_h_block;
|
|
|
|
|
int output_tile_count = UP_DIV(output_count, TILE_NUM);
|
|
|
|
|
int output_tile_count = UP_DIV(output_count, C12NUM);
|
|
|
|
|
int out_channel = conv_param->output_channel_;
|
|
|
|
|
int oc4 = UP_DIV(out_channel, C4NUM);
|
|
|
|
|
int oc8 = UP_DIV(out_channel, C8NUM);
|
|
|
|
|
int input_unit_square = input_unit * input_unit;
|
|
|
|
|
size_t output_offset = oc4 * C4NUM * input_unit_square * sizeof(float);
|
|
|
|
|
|
|
|
|
|
float *trans_input = buffer_list[0];
|
|
|
|
|
float *gemm_out = buffer_list[1];
|
|
|
|
|
float *tmp_out_data = buffer_list[2];
|
|
|
|
|
float *tmp_data = buffer_list[3];
|
|
|
|
|
int trans_input_offset = TILE_NUM * input_unit_square * ic4 * C4NUM;
|
|
|
|
|
int gemm_out_offset = TILE_NUM * input_unit_square * oc4 * C4NUM;
|
|
|
|
|
float *col_buffer = buffer_list[4];
|
|
|
|
|
int trans_input_offset = C12NUM * input_unit_square * ic4 * C4NUM;
|
|
|
|
|
int gemm_out_offset = C12NUM * input_unit_square * oc8 * C8NUM;
|
|
|
|
|
int tmp_data_offset = input_unit_square * C4NUM;
|
|
|
|
|
int col_buffer_offset = C12NUM * ic4 * C4NUM;
|
|
|
|
|
// step 1 : filter transform (pre-processed offline)
|
|
|
|
|
// step 2 : input transform (online)
|
|
|
|
|
for (int b = 0; b < in_batch; b++) {
|
|
|
|
|
int in_batch_offset = b * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_;
|
|
|
|
|
int tmp_out_batch_offset = b * out_w_block * out_h_block * out_unit * out_unit * oc4 * C4NUM;
|
|
|
|
|
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) {
|
|
|
|
|
int out_tile_index = thread_id * TILE_NUM;
|
|
|
|
|
int cal_num = output_count - thread_id * TILE_NUM;
|
|
|
|
|
cal_num = cal_num > TILE_NUM ? TILE_NUM : cal_num;
|
|
|
|
|
int out_tile_index = thread_id * C12NUM;
|
|
|
|
|
int cal_num = output_count - thread_id * C12NUM;
|
|
|
|
|
cal_num = cal_num > C12NUM ? C12NUM : cal_num;
|
|
|
|
|
WinogradInputTransform(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
|
|
|
|
|
tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
|
|
|
|
|
input_trans_func);
|
|
|
|
|
// step 3 : gemm
|
|
|
|
|
gemm_func(gemm_out + task_id * gemm_out_offset, trans_input + task_id * trans_input_offset, trans_weight, NULL,
|
|
|
|
|
input_unit_square, ic4, oc4 * C4NUM, output_offset, 1, 1, 0, 0);
|
|
|
|
|
float *src_ptr = trans_input + task_id * trans_input_offset;
|
|
|
|
|
float *dst_ptr = gemm_out + task_id * gemm_out_offset;
|
|
|
|
|
float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
|
|
|
|
|
for (int i = 0; i < input_unit_square; ++i) {
|
|
|
|
|
RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM);
|
|
|
|
|
MatMulOpt(tmp_col_ptr, trans_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, ic4 * C4NUM,
|
|
|
|
|
C12NUM, oc8 * C8NUM, input_unit_square, 2);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// step 4 : output transform
|
|
|
|
|
WinogradOutputTransform(gemm_out + task_id * gemm_out_offset, tmp_out_data + tmp_out_batch_offset, bias_data,
|
|
|
|
@ -442,18 +455,21 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
|
|
|
|
|
int ic4 = UP_DIV(conv_param->input_channel_, C4NUM);
|
|
|
|
|
int output_channel = conv_param->output_channel_;
|
|
|
|
|
int oc4 = UP_DIV(output_channel, C4NUM);
|
|
|
|
|
int oc8 = UP_DIV(output_channel, C8NUM);
|
|
|
|
|
int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT);
|
|
|
|
|
int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT);
|
|
|
|
|
int output_count = out_w_block * out_h_block;
|
|
|
|
|
int output_tile_count = UP_DIV(output_count, TILE_NUM);
|
|
|
|
|
int output_tile_count = UP_DIV(output_count, C12NUM);
|
|
|
|
|
const int input_unit_square = 4 * 4;
|
|
|
|
|
float *tile_buffer = buffer_list[0];
|
|
|
|
|
float *block_unit_buffer = buffer_list[1];
|
|
|
|
|
float *tmp_dst_buffer = buffer_list[2];
|
|
|
|
|
float *nc4hw4_out = buffer_list[3];
|
|
|
|
|
int tile_buffer_offset = TILE_NUM * input_unit_square * ic4 * C4NUM;
|
|
|
|
|
float *col_buffer = buffer_list[4];
|
|
|
|
|
int tile_buffer_offset = C12NUM * input_unit_square * ic4 * C4NUM;
|
|
|
|
|
int block_unit_buffer_offset = input_unit_square * C4NUM;
|
|
|
|
|
int tmp_dst_buffer_offset = TILE_NUM * input_unit_square * oc4 * C4NUM;
|
|
|
|
|
int tmp_dst_buffer_offset = C12NUM * input_unit_square * oc8 * C8NUM;
|
|
|
|
|
int col_buffer_offset = C12NUM * ic4 * C4NUM;
|
|
|
|
|
|
|
|
|
|
int input_batch = conv_param->input_batch_;
|
|
|
|
|
for (int batch = 0; batch < input_batch; batch++) {
|
|
|
|
@ -461,15 +477,20 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
|
|
|
|
|
int nc4hw4_buffer_offset = batch * oc4 * C4NUM * conv_param->output_h_ * conv_param->output_w_;
|
|
|
|
|
|
|
|
|
|
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
|
|
|
|
|
int start_index = thread_id * TILE_NUM;
|
|
|
|
|
int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM;
|
|
|
|
|
int start_index = thread_id * C12NUM;
|
|
|
|
|
int real_cal_num = (output_count - start_index) < C12NUM ? (output_count - start_index) : C12NUM;
|
|
|
|
|
Conv3x3Fp32InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset,
|
|
|
|
|
block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num,
|
|
|
|
|
out_w_block, conv_param);
|
|
|
|
|
|
|
|
|
|
gemm_func(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset,
|
|
|
|
|
transed_weight, NULL, input_unit_square, ic4, oc4 * C4NUM,
|
|
|
|
|
oc4 * C4NUM * input_unit_square * sizeof(float), 1, 1, 0, 0);
|
|
|
|
|
float *src_ptr = tile_buffer + task_id * tile_buffer_offset;
|
|
|
|
|
float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
|
|
|
|
|
float *dst_ptr = tmp_dst_buffer + task_id * tmp_dst_buffer_offset;
|
|
|
|
|
for (int i = 0; i < input_unit_square; ++i) {
|
|
|
|
|
RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM);
|
|
|
|
|
MatMulOpt(tmp_col_ptr, transed_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0,
|
|
|
|
|
ic4 * C4NUM, C12NUM, oc8 * C8NUM, input_unit_square, 2);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Conv3x3Fp32OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, nc4hw4_out + nc4hw4_buffer_offset,
|
|
|
|
|
bias_data, start_index, real_cal_num, out_w_block, conv_param);
|
|
|
|
|