replace gemm with matmul for fp16 conv winograd

pull/7043/head
fuzhiye 4 years ago
parent 86bbd1dc98
commit 1f9a122f17

@@ -170,7 +170,6 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa
int input_unit = conv_param->input_unit_;
int in_batch = conv_param->input_batch_;
int in_channel = conv_param->input_channel_;
int ic8 = UP_DIV(in_channel, C8NUM);
int out_unit = conv_param->output_unit_;
int out_w_block = UP_DIV(conv_param->output_w_, out_unit);
int out_h_block = UP_DIV(conv_param->output_h_, out_unit);
@@ -179,18 +178,19 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa
int out_channel = conv_param->output_channel_;
int oc8 = UP_DIV(out_channel, C8NUM);
int input_unit_square = input_unit * input_unit;
size_t output_offset = oc8 * C8NUM * input_unit_square * sizeof(float16_t);
float16_t *trans_input = buffer_list[0];
float16_t *gemm_out = buffer_list[1];
float16_t *tmp_data = buffer_list[2];
int trans_input_offset = tile_num * input_unit_square * ic8 * C8NUM;
float16_t *col_buffer = buffer_list[3];
int trans_input_offset = tile_num * input_unit_square * in_channel;
int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM;
int tmp_data_offset = input_unit_square * C8NUM;
int col_buffer_offset = tile_num * in_channel;
// step 1 : filter transform (pre-processed offline)
// step 2 : input transform (online)
for (int b = 0; b < in_batch; b++) {
int in_batch_offset = b * ic8 * C8NUM * conv_param->input_h_ * conv_param->input_w_;
int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_;
int out_batch_offset = b * out_channel * conv_param->output_h_ * conv_param->output_w_;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) {
int out_tile_index = thread_id * tile_num;
@@ -200,8 +200,14 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa
tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
in_func);
// step 3 : gemm
IndirectGemmFp16_16x8(gemm_out + task_id * gemm_out_offset, trans_input + task_id * trans_input_offset,
trans_weight, NULL, input_unit_square, ic8 * 2, oc8 * C8NUM, output_offset, 1, 1, 0, 0);
float16_t *src_ptr = trans_input + task_id * trans_input_offset;
float16_t *dst_ptr = gemm_out + task_id * gemm_out_offset;
float16_t *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
for (int i = 0; i < input_unit_square; ++i) {
RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, tile_num, in_channel);
MatMul16x8(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel,
cal_num, oc8 * C8NUM, input_unit_square, false);
}
// step 4 : output transform
WinogradOutputTransformFp16(gemm_out + task_id * gemm_out_offset, output_data + out_batch_offset, bias_data,

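For orientation, the replaced step 3 reduces to one small matrix multiply per transformed point: the tile_num x in_channel activation slice is repacked into 16-row column blocks and multiplied by that point's in_channel x (oc8 * C8NUM) weight slab. The plain-float sketch below mirrors the indexing; pack_col16 and matmul_16x8_ref are illustrative stand-ins for RowMajor2Col16MajorFp16Opt and MatMul16x8, not the library routines.

/* Illustrative only: plain-float model of the per-point Winograd gemm.
 * rows = tile_num, deep = in_channel, cols = oc8 * C8NUM, stride = input_unit_square. */
#include <string.h>

#define C16 16
#define C8 8
#define UPROUND(x, n) (((x) + (n)-1) / (n) * (n))

/* Row-major [rows x deep] -> col16-major blocks [rows/16][deep][16], zero padded. */
void pack_col16(const float *src, float *dst, int rows, int deep) {
  memset(dst, 0, (size_t)UPROUND(rows, C16) * deep * sizeof(float));
  for (int r = 0; r < rows; ++r) {
    for (int d = 0; d < deep; ++d) {
      dst[(r / C16) * deep * C16 + d * C16 + r % C16] = src[r * deep + d];
    }
  }
}

/* col16-major A times row8-major B ([cols/8][deep][8]); the destination uses the
 * same row / 8-block / stride layout as the write_nhwc == false branch of MatMul16x8. */
void matmul_16x8_ref(const float *a, const float *b, float *dst,
                     int deep, int rows, int cols, int stride) {
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      float acc = 0.0f;
      for (int d = 0; d < deep; ++d) {
        acc += a[(r / C16) * deep * C16 + d * C16 + r % C16] *
               b[(c / C8) * deep * C8 + d * C8 + c % C8];
      }
      dst[r * cols * stride + (c / C8) * C8 * stride + c % C8] = acc;
    }
  }
}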
@@ -41,8 +41,8 @@ void ColMajor2Row8MajorFp16(void *src_ptr, float16_t *dst_ptr, size_t row, size_
void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
int deep, int row, int col, int stride, bool write_nhwc) {
int row_16 = UP_ROUND(row, C16NUM);
int col_8 = UP_ROUND(col, C8NUM);
// int row_16 = UP_ROUND(row, C16NUM);
// int col_8 = UP_ROUND(col, C8NUM);
if (write_nhwc) {
/* col16-major * row8-major => col-major */
for (int r = 0; r < row; r++) {
@@ -63,24 +63,42 @@ void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const fl
}
}
} else {
/* col16-major * row8-major => row16x8-major */
for (int r = 0; r < row_16; r++) {
for (int c = 0; c < col_8; c++) {
int r16div = r / C16NUM, r16mod = r % C16NUM;
int c8div = c / C8NUM, c8mod = c % C8NUM;
size_t ci = c8div * row_16 * C8NUM + r * C8NUM + c8mod;
float16_t value = 0;
for (int d = 0; d < deep; d++) {
size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod;
size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
for (int i = 0; i < row; ++i) {
int src_r_offset = i;
int dst_r_offset = i * col * stride;
for (int j = 0; j < col; ++j) {
int c8div = j / 8, c8mod = j % 8;
size_t ci = dst_r_offset + c8div * 8 * stride + c8mod;
float value = 0;
for (int d = 0; d < deep; ++d) {
size_t ai = src_r_offset + d * C16NUM;
size_t bi = c8div * deep * 8 + d * 8 + c8mod;
value = value + a[ai] * b[bi];
}
if (bias != NULL) value += bias[col];
if (bias != NULL) value += bias[j];
if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
if (act_type != ActType_No) value = MSMAX(0.0f, value);
dst[ci] = value;
}
}
// /* col16-major * row8-major => row16x8-major */
// for (int r = 0; r < row_16; r++) {
// for (int c = 0; c < col_8; c++) {
// int r16div = r / C16NUM, r16mod = r % C16NUM;
// int c8div = c / C8NUM, c8mod = c % C8NUM;
// size_t ci = c8div * row_16 * C8NUM + r * C8NUM + c8mod;
// float16_t value = 0;
// for (int d = 0; d < deep; d++) {
// size_t ai = r16div * deep * C16NUM + d * C16NUM + r16mod;
// size_t bi = c8div * deep * C8NUM + d * C8NUM + c8mod;
// value = value + a[ai] * b[bi];
// }
// if (bias != NULL) value += bias[col];
// if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
// if (act_type != ActType_No) value = MSMAX(0.0f, value);
// dst[ci] = value;
// }
// }
}
return;
}

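As a quick sanity check of that index math, the standalone toy below compares the col16/row8 packed indexing against a naive row-major multiply on one 8-wide column block with stride 1, where the packed destination collapses to plain row-major; it is plain-float test code, not part of the library.

#include <assert.h>
#include <math.h>
#include <stdio.h>

#define DEEP 5
#define ROW 3
#define COL 8 /* one 8-wide block, so the packed output equals row-major at stride 1 */

int main(void) {
  float a_rm[ROW][DEEP], b_rm[DEEP][COL];
  float a_col16[1 * DEEP * 16] = {0}; /* UP_ROUND(ROW, 16) / 16 == 1 block */
  float b_row8[1 * DEEP * 8] = {0};   /* COL / 8 == 1 block */
  for (int r = 0; r < ROW; ++r)
    for (int d = 0; d < DEEP; ++d) a_rm[r][d] = (float)(r * DEEP + d + 1);
  for (int d = 0; d < DEEP; ++d)
    for (int c = 0; c < COL; ++c) b_rm[d][c] = (float)(d - c);
  /* Pack A col16-major and B row8-major exactly as the matmul indexes them. */
  for (int r = 0; r < ROW; ++r)
    for (int d = 0; d < DEEP; ++d) a_col16[d * 16 + r] = a_rm[r][d];
  for (int d = 0; d < DEEP; ++d)
    for (int c = 0; c < COL; ++c) b_row8[d * 8 + c] = b_rm[d][c];
  for (int r = 0; r < ROW; ++r) {
    for (int c = 0; c < COL; ++c) {
      float ref = 0.0f, got = 0.0f;
      for (int d = 0; d < DEEP; ++d) {
        ref += a_rm[r][d] * b_rm[d][c];
        got += a_col16[(r / 16) * DEEP * 16 + d * 16 + r % 16] *
               b_row8[(c / 8) * DEEP * 8 + d * 8 + c % 8];
      }
      assert(fabsf(ref - got) < 1e-4f);
    }
  }
  printf("packed index math matches naive gemm\n");
  return 0;
}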
@@ -29,6 +29,9 @@
#ifdef __cplusplus
extern "C" {
#endif
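/* a: activations packed col16-major ([deep][16] per 16-row block); b: weights packed
 * row8-major ([col/8][deep][8]); bias and act_type are applied per output column.
 * With write_nhwc == false the result is written in 8-wide column blocks spaced by
 * `stride`, which is the layout the Winograd gemm_out buffer expects. */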
void MatMul16x8(const float16_t *a, const float16_t *b, float16_t *dst, const float16_t *bias, ActType act_type,
int deep, int row, int col, int stride, bool write_nhwc);
void MatMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const float16_t *bias, ActType act_type,
int depth, int row, int col, int stride, bool write_nhwc);

@@ -594,7 +594,7 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in
int interval_y_e = src_y_e < input_h ? input_unit : (input_h - src_y_s);
int src_plane_offset = in_channel * (src_y_s * input_w + src_x_s);
int dst_plane_offset = c * C4NUM;
int dst_plane_offset = c * in_channel;
for (int ic = 0; ic < ic8; ic++) {
// clear tmp buffer
memset(tmp_data, 0, input_unit * input_unit * C8NUM * sizeof(float16_t));
@@ -622,6 +622,30 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in
#endif
}
}
} else if (real_c < 8 && real_c >= 4) {
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
int src_y_offset = src_ic8_offset + (interval * input_w + interval_x_s) * in_channel;
int dst_y_offset = interval * input_unit * C8NUM + interval_x_s * C8NUM;
for (int j = 0; j < (interval_x_e - interval_x_s); j++) {
int src_x_offset = src_y_offset + j * in_channel;
int dst_x_offset = dst_y_offset + j * C8NUM;
const float16_t *src_addr = input_data + src_x_offset;
float16_t *dst_addr = tmp_data + dst_x_offset;
int rc = real_c - 4;
#ifdef ENABLE_NEON
vst1_f16(dst_addr, vld1_f16(src_addr));
#else
for (int k = 0; k < C4NUM; k++) {
dst_addr[k] = src_addr[k];
}
#endif
src_addr += 4;
dst_addr += 4;
for (int i = 0; i < rc; ++i) {
dst_addr[i] = src_addr[i];
}
}
}
} else {
for (int interval = interval_y_s; interval < interval_y_e; interval++) {
int src_y_offset = src_ic8_offset + (interval * input_w + interval_x_s) * in_channel;
@@ -639,10 +663,10 @@ void WinogradInputTransformFp16(const float16_t *input_data, float16_t *trans_in
}
// input transform
int dst_ic8_offset = dst_plane_offset + ic * tile_num * C8NUM;
size_t dst_step = ic8 * C8NUM * tile_num;
int dst_ic8_offset = dst_plane_offset + ic * C8NUM;
size_t dst_step = in_channel * tile_num;
float16_t *trans_input_ptr = trans_input + dst_ic8_offset;
func(tmp_data, trans_input_ptr, C8NUM, dst_step);
func(tmp_data, trans_input_ptr, C8NUM, dst_step, real_c);
}
out_tile_index++;
} // cal_tile_num loop

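The new middle branch above covers a partial channel group with 4 <= real_c < 8: the first four fp16 lanes move through a single 64-bit vld1_f16/vst1_f16 pair and the remaining real_c - 4 values are copied scalar into the zero-initialized, 8-channel-padded tmp buffer. A plain-C equivalent of that copy, with float standing in for float16_t, is simply:

/* Copy real_c channels (4 <= real_c < 8) from an NHWC source row into an
 * 8-channel-padded tile slot; the untouched tail lanes keep their cleared value. */
void copy_partial_c8(const float *src_addr, float *dst_addr, int real_c) {
  for (int k = 0; k < 4; ++k) { /* done as one vld1_f16/vst1_f16 pair on NEON */
    dst_addr[k] = src_addr[k];
  }
  for (int i = 0; i < real_c - 4; ++i) { /* scalar tail of 0..3 extra channels */
    dst_addr[4 + i] = src_addr[4 + i];
  }
}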
File diff suppressed because it is too large.

@@ -26,7 +26,8 @@
#ifdef __cplusplus
extern "C" {
#endif
typedef void (*InputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step);
typedef void (*InputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step,
int real_c);
typedef void (*OutputTransFp16Func)(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c);
@@ -56,6 +57,24 @@ void GeneralOutputTransformUnitFp16(const float16_t *src_data, float16_t *dst_da
src[14] = vld1q_f16(src_data + 14 * src_step); \
src[15] = vld1q_f16(src_data + 15 * src_step);
#define Load16DataC4Fp16 \
src[0] = vld1_f16(src_data + 0 * src_step); \
src[1] = vld1_f16(src_data + 1 * src_step); \
src[2] = vld1_f16(src_data + 2 * src_step); \
src[3] = vld1_f16(src_data + 3 * src_step); \
src[4] = vld1_f16(src_data + 4 * src_step); \
src[5] = vld1_f16(src_data + 5 * src_step); \
src[6] = vld1_f16(src_data + 6 * src_step); \
src[7] = vld1_f16(src_data + 7 * src_step); \
src[8] = vld1_f16(src_data + 8 * src_step); \
src[9] = vld1_f16(src_data + 9 * src_step); \
src[10] = vld1_f16(src_data + 10 * src_step); \
src[11] = vld1_f16(src_data + 11 * src_step); \
src[12] = vld1_f16(src_data + 12 * src_step); \
src[13] = vld1_f16(src_data + 13 * src_step); \
src[14] = vld1_f16(src_data + 14 * src_step); \
src[15] = vld1_f16(src_data + 15 * src_step);
#define Load36DataFp16 \
src[0] = vld1q_f16(src_data + 0 * src_step); \
src[1] = vld1q_f16(src_data + 1 * src_step); \
@@ -94,6 +113,44 @@ void GeneralOutputTransformUnitFp16(const float16_t *src_data, float16_t *dst_da
src[34] = vld1q_f16(src_data + 34 * src_step); \
src[35] = vld1q_f16(src_data + 35 * src_step);
#define Load36DataC4Fp16 \
src[0] = vld1_f16(src_data + 0 * src_step); \
src[1] = vld1_f16(src_data + 1 * src_step); \
src[2] = vld1_f16(src_data + 2 * src_step); \
src[3] = vld1_f16(src_data + 3 * src_step); \
src[4] = vld1_f16(src_data + 4 * src_step); \
src[5] = vld1_f16(src_data + 5 * src_step); \
src[6] = vld1_f16(src_data + 6 * src_step); \
src[7] = vld1_f16(src_data + 7 * src_step); \
src[8] = vld1_f16(src_data + 8 * src_step); \
src[9] = vld1_f16(src_data + 9 * src_step); \
src[10] = vld1_f16(src_data + 10 * src_step); \
src[11] = vld1_f16(src_data + 11 * src_step); \
src[12] = vld1_f16(src_data + 12 * src_step); \
src[13] = vld1_f16(src_data + 13 * src_step); \
src[14] = vld1_f16(src_data + 14 * src_step); \
src[15] = vld1_f16(src_data + 15 * src_step); \
src[16] = vld1_f16(src_data + 16 * src_step); \
src[17] = vld1_f16(src_data + 17 * src_step); \
src[18] = vld1_f16(src_data + 18 * src_step); \
src[19] = vld1_f16(src_data + 19 * src_step); \
src[20] = vld1_f16(src_data + 20 * src_step); \
src[21] = vld1_f16(src_data + 21 * src_step); \
src[22] = vld1_f16(src_data + 22 * src_step); \
src[23] = vld1_f16(src_data + 23 * src_step); \
src[24] = vld1_f16(src_data + 24 * src_step); \
src[25] = vld1_f16(src_data + 25 * src_step); \
src[26] = vld1_f16(src_data + 26 * src_step); \
src[27] = vld1_f16(src_data + 27 * src_step); \
src[28] = vld1_f16(src_data + 28 * src_step); \
src[29] = vld1_f16(src_data + 29 * src_step); \
src[30] = vld1_f16(src_data + 30 * src_step); \
src[31] = vld1_f16(src_data + 31 * src_step); \
src[32] = vld1_f16(src_data + 32 * src_step); \
src[33] = vld1_f16(src_data + 33 * src_step); \
src[34] = vld1_f16(src_data + 34 * src_step); \
src[35] = vld1_f16(src_data + 35 * src_step);
#define Load64DataFp16 \
src[0] = vld1q_f16(src_data + 0 * src_step); \
src[1] = vld1q_f16(src_data + 1 * src_step); \
@@ -160,13 +217,79 @@ void GeneralOutputTransformUnitFp16(const float16_t *src_data, float16_t *dst_da
src[62] = vld1q_f16(src_data + 62 * src_step); \
src[63] = vld1q_f16(src_data + 63 * src_step);
#define Load64DataC4Fp16 \
src[0] = vld1_f16(src_data + 0 * src_step); \
src[1] = vld1_f16(src_data + 1 * src_step); \
src[2] = vld1_f16(src_data + 2 * src_step); \
src[3] = vld1_f16(src_data + 3 * src_step); \
src[4] = vld1_f16(src_data + 4 * src_step); \
src[5] = vld1_f16(src_data + 5 * src_step); \
src[6] = vld1_f16(src_data + 6 * src_step); \
src[7] = vld1_f16(src_data + 7 * src_step); \
src[8] = vld1_f16(src_data + 8 * src_step); \
src[9] = vld1_f16(src_data + 9 * src_step); \
src[10] = vld1_f16(src_data + 10 * src_step); \
src[11] = vld1_f16(src_data + 11 * src_step); \
src[12] = vld1_f16(src_data + 12 * src_step); \
src[13] = vld1_f16(src_data + 13 * src_step); \
src[14] = vld1_f16(src_data + 14 * src_step); \
src[15] = vld1_f16(src_data + 15 * src_step); \
src[16] = vld1_f16(src_data + 16 * src_step); \
src[17] = vld1_f16(src_data + 17 * src_step); \
src[18] = vld1_f16(src_data + 18 * src_step); \
src[19] = vld1_f16(src_data + 19 * src_step); \
src[20] = vld1_f16(src_data + 20 * src_step); \
src[21] = vld1_f16(src_data + 21 * src_step); \
src[22] = vld1_f16(src_data + 22 * src_step); \
src[23] = vld1_f16(src_data + 23 * src_step); \
src[24] = vld1_f16(src_data + 24 * src_step); \
src[25] = vld1_f16(src_data + 25 * src_step); \
src[26] = vld1_f16(src_data + 26 * src_step); \
src[27] = vld1_f16(src_data + 27 * src_step); \
src[28] = vld1_f16(src_data + 28 * src_step); \
src[29] = vld1_f16(src_data + 29 * src_step); \
src[30] = vld1_f16(src_data + 30 * src_step); \
src[31] = vld1_f16(src_data + 31 * src_step); \
src[32] = vld1_f16(src_data + 32 * src_step); \
src[33] = vld1_f16(src_data + 33 * src_step); \
src[34] = vld1_f16(src_data + 34 * src_step); \
src[35] = vld1_f16(src_data + 35 * src_step); \
src[36] = vld1_f16(src_data + 36 * src_step); \
src[37] = vld1_f16(src_data + 37 * src_step); \
src[38] = vld1_f16(src_data + 38 * src_step); \
src[39] = vld1_f16(src_data + 39 * src_step); \
src[40] = vld1_f16(src_data + 40 * src_step); \
src[41] = vld1_f16(src_data + 41 * src_step); \
src[42] = vld1_f16(src_data + 42 * src_step); \
src[43] = vld1_f16(src_data + 43 * src_step); \
src[44] = vld1_f16(src_data + 44 * src_step); \
src[45] = vld1_f16(src_data + 45 * src_step); \
src[46] = vld1_f16(src_data + 46 * src_step); \
src[47] = vld1_f16(src_data + 47 * src_step); \
src[48] = vld1_f16(src_data + 48 * src_step); \
src[49] = vld1_f16(src_data + 49 * src_step); \
src[50] = vld1_f16(src_data + 50 * src_step); \
src[51] = vld1_f16(src_data + 51 * src_step); \
src[52] = vld1_f16(src_data + 52 * src_step); \
src[53] = vld1_f16(src_data + 53 * src_step); \
src[54] = vld1_f16(src_data + 54 * src_step); \
src[55] = vld1_f16(src_data + 55 * src_step); \
src[56] = vld1_f16(src_data + 56 * src_step); \
src[57] = vld1_f16(src_data + 57 * src_step); \
src[58] = vld1_f16(src_data + 58 * src_step); \
src[59] = vld1_f16(src_data + 59 * src_step); \
src[60] = vld1_f16(src_data + 60 * src_step); \
src[61] = vld1_f16(src_data + 61 * src_step); \
src[62] = vld1_f16(src_data + 62 * src_step); \
src[63] = vld1_f16(src_data + 63 * src_step);
InputTransFp16Func GetInputTransFp16Func(int input_unit);
void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step);
void InputTransform4x4UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c);
void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step);
void InputTransform6x6UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c);
void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step);
void InputTransform8x8UnitFp16(const float16_t *src_data, float16_t *dst_data, int src_step, int dst_step, int real_c);
OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActType act_type);
@@ -176,6 +299,12 @@ OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActT
vst1q_f16(dst_data + dst_step * out_c, m[2]); \
vst1q_f16(dst_data + dst_step * out_c + out_c, m[3]);
#define Store4DataC4Fp16 \
vst1_f16(dst_data, m[0]); \
vst1_f16(dst_data + out_c, m[1]); \
vst1_f16(dst_data + dst_step * out_c, m[2]); \
vst1_f16(dst_data + dst_step * out_c + out_c, m[3]);
#define Store9DataFp16 \
vst1q_f16(dst_data, m[0]); \
vst1q_f16(dst_data + out_c, m[1]); \
@@ -187,6 +316,17 @@ OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActT
vst1q_f16(dst_data + 2 * dst_step * out_c + out_c, m[7]); \
vst1q_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]);
#define Store9DataC4Fp16 \
vst1_f16(dst_data, m[0]); \
vst1_f16(dst_data + out_c, m[1]); \
vst1_f16(dst_data + 2 * out_c, m[2]); \
vst1_f16(dst_data + dst_step * out_c, m[3]); \
vst1_f16(dst_data + dst_step * out_c + out_c, m[4]); \
vst1_f16(dst_data + dst_step * out_c + 2 * out_c, m[5]); \
vst1_f16(dst_data + 2 * dst_step * out_c, m[6]); \
vst1_f16(dst_data + 2 * dst_step * out_c + out_c, m[7]); \
vst1_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[8]);
#define Store16DataFp16 \
vst1q_f16(dst_data, m[0]); \
vst1q_f16(dst_data + out_c, m[1]); \
@@ -205,6 +345,24 @@ OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActT
vst1q_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \
vst1q_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]);
#define Store16DataC4Fp16 \
vst1_f16(dst_data, m[0]); \
vst1_f16(dst_data + out_c, m[1]); \
vst1_f16(dst_data + 2 * out_c, m[2]); \
vst1_f16(dst_data + 3 * out_c, m[3]); \
vst1_f16(dst_data + dst_step * out_c, m[4]); \
vst1_f16(dst_data + dst_step * out_c + out_c, m[5]); \
vst1_f16(dst_data + dst_step * out_c + 2 * out_c, m[6]); \
vst1_f16(dst_data + dst_step * out_c + 3 * out_c, m[7]); \
vst1_f16(dst_data + 2 * dst_step * out_c, m[8]); \
vst1_f16(dst_data + 2 * dst_step * out_c + out_c, m[9]); \
vst1_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[10]); \
vst1_f16(dst_data + 2 * dst_step * out_c + 3 * out_c, m[11]); \
vst1_f16(dst_data + 3 * dst_step * out_c, m[12]); \
vst1_f16(dst_data + 3 * dst_step * out_c + out_c, m[13]); \
vst1_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[14]); \
vst1_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[15]);
#define Store25DataFp16 \
vst1q_f16(dst_data, m[0]); \
vst1q_f16(dst_data + out_c, m[1]); \
@@ -232,6 +390,33 @@ OutputTransFp16Func GetOutputTransFp16Func(int input_unit, int output_unit, ActT
vst1q_f16(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \
vst1q_f16(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]);
#define Store25DataC4Fp16 \
vst1_f16(dst_data, m[0]); \
vst1_f16(dst_data + out_c, m[1]); \
vst1_f16(dst_data + 2 * out_c, m[2]); \
vst1_f16(dst_data + 3 * out_c, m[3]); \
vst1_f16(dst_data + 4 * out_c, m[4]); \
vst1_f16(dst_data + dst_step * out_c, m[5]); \
vst1_f16(dst_data + dst_step * out_c + out_c, m[6]); \
vst1_f16(dst_data + dst_step * out_c + 2 * out_c, m[7]); \
vst1_f16(dst_data + dst_step * out_c + 3 * out_c, m[8]); \
vst1_f16(dst_data + dst_step * out_c + 4 * out_c, m[9]); \
vst1_f16(dst_data + 2 * dst_step * out_c, m[10]); \
vst1_f16(dst_data + 2 * dst_step * out_c + out_c, m[11]); \
vst1_f16(dst_data + 2 * dst_step * out_c + 2 * out_c, m[12]); \
vst1_f16(dst_data + 2 * dst_step * out_c + 3 * out_c, m[13]); \
vst1_f16(dst_data + 2 * dst_step * out_c + 4 * out_c, m[14]); \
vst1_f16(dst_data + 3 * dst_step * out_c, m[15]); \
vst1_f16(dst_data + 3 * dst_step * out_c + out_c, m[16]); \
vst1_f16(dst_data + 3 * dst_step * out_c + 2 * out_c, m[17]); \
vst1_f16(dst_data + 3 * dst_step * out_c + 3 * out_c, m[18]); \
vst1_f16(dst_data + 3 * dst_step * out_c + 4 * out_c, m[19]); \
vst1_f16(dst_data + 4 * dst_step * out_c, m[20]); \
vst1_f16(dst_data + 4 * dst_step * out_c + out_c, m[21]); \
vst1_f16(dst_data + 4 * dst_step * out_c + 2 * out_c, m[22]); \
vst1_f16(dst_data + 4 * dst_step * out_c + 3 * out_c, m[23]); \
vst1_f16(dst_data + 4 * dst_step * out_c + 4 * out_c, m[24]);
void OutputTransform4x2UnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,
int src_step, int dst_step, int out_c, int r_w, int r_h, int r_c);
void OutputTransform4x2ReluUnitFp16(const float16_t *src_data, float16_t *dst_data, const float16_t *bias_data,

@@ -197,26 +197,28 @@ kernel::LiteKernel *CpuConvFp16KernelCreator(const std::vector<lite::Tensor *> &
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
bool use_winograd = false;
int out_unit;
if (primitive != nullptr && primitive->GetInferFlag()) {
conv_param->input_h_ = inputs.front()->Height();
conv_param->input_w_ = inputs.front()->Width();
conv_param->input_channel_ = inputs.front()->Channel();
conv_param->output_h_ = outputs.front()->Height();
conv_param->output_w_ = outputs.front()->Width();
conv_param->output_channel_ = outputs.front()->Channel();
conv_param->op_parameter_.thread_num_ = ctx->thread_num_;
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param);
}
kernel::LiteKernel *kernel = nullptr;
if (kernel_h == 1 && kernel_w == 1) {
kernel = new (std::nothrow) kernel::Convolution1x1FP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
} else {
bool use_winograd = false;
int out_unit;
CheckIfUseWinograd(&use_winograd, &out_unit, conv_param);
if (use_winograd) {
} else if (use_winograd) {
kernel = new (std::nothrow)
kernel::ConvolutionWinogradFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive, out_unit);
}
if (kernel_h != 1 && kernel_w != 1 && !use_winograd) {
} else {
kernel = new (std::nothrow) kernel::ConvolutionFP16CPUKernel(opParameter, inputs, outputs, ctx, primitive);
}
}
if (kernel == nullptr) {
MS_LOG(DEBUG) << "Create conv fp16 kernel failed.";
if (dequant_flag) {

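Read together, the creator now resolves use_winograd up front (only when shape inference has run, via CheckIfUseWinograd) and then picks exactly one kernel. A condensed plain-C sketch of the new decision order, using an illustrative enum rather than the actual kernel classes:

typedef enum { kConv1x1Fp16, kConvWinogradFp16, kConvGeneralFp16 } ConvKernelKind;

/* Mirrors the selection in CpuConvFp16KernelCreator after this change:
 * 1x1 convolutions first, then Winograd when CheckIfUseWinograd enabled it,
 * otherwise the general fp16 convolution kernel. */
ConvKernelKind SelectConvFp16Kernel(int kernel_h, int kernel_w, int use_winograd) {
  if (kernel_h == 1 && kernel_w == 1) {
    return kConv1x1Fp16;
  } else if (use_winograd) {
    return kConvWinogradFp16;
  }
  return kConvGeneralFp16;
}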
@@ -39,8 +39,6 @@ int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_
// original weight format : ohwi
auto channel_in = conv_param_->input_channel_;
auto channel_out = conv_param_->output_channel_;
int ic8 = UP_DIV(channel_in, C8NUM);
int ic4 = ic8 * 2;
int input_unit_square = input_unit_ * input_unit_;
int oc_block_num = UP_DIV(channel_out, oc_block);
@@ -84,17 +82,7 @@ int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_
MS_LOG(ERROR) << "malloc trans_out_data failed.";
return RET_ERROR;
}
std::vector<int> shape{input_unit_ * input_unit_, oc_block_num, ic4, C4NUM, oc_block};
std::vector<int> strides;
for (int i = 0; i < 4; i++) {
int stride = 1;
for (int j = i + 1; j < 5; j++) {
stride *= shape[j];
}
strides.push_back(stride);
}
int kernel_plane_stride = channel_in;
if (oc_block == 0) {
MS_LOG(ERROR) << "Divide by zero";
free(tmp_weight_data);
@@ -104,18 +92,17 @@ int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_
free(matrix_gt_data_fp16);
return RET_ERROR;
}
int stride1 = channel_in * oc_block;
for (int i = 0; i < channel_out; i++) {
int out_c_block = i / oc_block;
int out_c_res = i % oc_block;
int input_oz_offset = i * kernel_unit_ * kernel_unit_ * channel_in;
int output_oz_offset = out_c_block * strides[1] * input_unit_ * input_unit_ + out_c_res;
int output_oz_offset = out_c_block * stride1 + out_c_res;
for (int j = 0; j < channel_in; j++) {
int ic4_block = j / C4NUM;
int ic4_res = j % C4NUM;
int input_iz_offset = input_oz_offset + j;
int output_iz_offset = output_oz_offset + ic4_block * strides[2] + ic4_res * strides[3];
int output_iz_offset = output_oz_offset + j * oc_block;
for (int k = 0; k < kernel_unit_ * kernel_unit_; k++) {
int input_xy_offset = input_iz_offset + k * kernel_plane_stride;
int input_xy_offset = input_iz_offset + k * channel_in;
tmp_weight_data[k] = *(weight_data + input_xy_offset);
}
// now we only support row-major matrix-multiply
@@ -125,7 +112,7 @@ int ConvolutionWinogradFP16CPUKernel::WinogradFilterTransformFp16(const float16_
MatrixMultiplyFp16(tmp_data, matrix_gt_data_fp16, trans_out_data, input_unit_, kernel_unit_, input_unit_);
for (int z = 0; z < input_unit_square; z++) {
int output_xy_offset = output_iz_offset + z * strides[1];
int output_xy_offset = output_iz_offset + z * oc_block_num * stride1;
trans_weight_[output_xy_offset] = trans_out_data[z];
}
}
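With the ic4 padding gone, the transformed weight is laid out as [input_unit^2][oc_block_num][in_channel][oc_block], which is the per-point row8-major B that MatMul16x8 indexes (oc_block_num * oc_block corresponds to the oc8 * C8NUM width used in ConvWinogardFp16). A small helper expressing the flat index implied by the offsets above (illustrative, not library code):

/* Flat index into trans_weight_ for transform point z, output channel oc and input
 * channel ic, given the [unit^2][oc_block_num][in_channel][oc_block] layout. */
int TransWeightIndex(int z, int oc, int ic, int in_channel, int oc_block, int oc_block_num) {
  int stride1 = in_channel * oc_block; /* size of one oc block per transform point */
  return z * oc_block_num * stride1    /* transform point */
         + (oc / oc_block) * stride1   /* which oc block */
         + ic * oc_block               /* input-channel row inside the block */
         + oc % oc_block;              /* lane inside the block */
}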
@@ -142,7 +129,6 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
auto filter_tensor = in_tensors_.at(kWeightIndex);
int in_channel = filter_tensor->Channel();
int out_channel = filter_tensor->Batch();
int ic8 = UP_DIV(in_channel, C8NUM);
conv_param_->input_channel_ = in_channel;
conv_param_->output_channel_ = out_channel;
@@ -157,7 +143,7 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() {
}
// set data
auto trans_matrix_data_size = input_unit_ * input_unit_ * ic8 * C8NUM * oc_block_num * oc_block * sizeof(float16_t);
auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * sizeof(float16_t);
trans_weight_ = reinterpret_cast<float16_t *>(malloc(trans_matrix_data_size));
if (trans_weight_ == nullptr) {
MS_LOG(ERROR) << "malloc trans_weight_ failed.";
@@ -209,9 +195,9 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
const int cal_num = 16;
int channel_out = conv_param_->output_channel_;
int oc8 = UP_DIV(channel_out, C8NUM);
int ic8 = UP_DIV(conv_param_->input_channel_, C8NUM);
size_t tile_buffer_size = thread_count_ * cal_num * input_unit_ * input_unit_ * ic8 * C8NUM * sizeof(float16_t);
size_t tile_buffer_size =
thread_count_ * cal_num * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float16_t);
trans_input_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(tile_buffer_size));
if (trans_input_ == nullptr) {
MS_LOG(ERROR) << "malloc trans_input_ failed.";
@@ -232,9 +218,17 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() {
return RET_ERROR;
}
col_buffer_ = reinterpret_cast<float16_t *>(
ctx_->allocator->Malloc(thread_count_ * cal_num * conv_param_->input_channel_ * sizeof(float16_t)));
if (col_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc col_buffer_ failed.";
return RET_ERROR;
}
tmp_buffer_address_list_[0] = trans_input_;
tmp_buffer_address_list_[1] = gemm_out_;
tmp_buffer_address_list_[2] = tmp_data_;
tmp_buffer_address_list_[3] = col_buffer_;
return RET_OK;
}

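Taken together, the four per-thread work buffers wired into tmp_buffer_address_list_ size out as below; the trans_input_ and col_buffer_ counts come from this diff, while the gemm_out_ and tmp_data_ expressions are assumptions reconstructed from the per-task offsets in ConvWinogardFp16 (their allocations are unchanged and not shown here).

#include <stddef.h>

/* Per-thread float16 element counts for the four Winograd work buffers
 * (cal_num == 16 tiles per pass, C8NUM == 8). */
void WinogradBufferElems(int input_unit, int in_channel, int oc8, size_t elems[4]) {
  const size_t cal_num = 16;
  const size_t unit_sq = (size_t)input_unit * input_unit;
  elems[0] = cal_num * unit_sq * (size_t)in_channel; /* trans_input_ (buffer_list[0]) */
  elems[1] = cal_num * unit_sq * (size_t)oc8 * 8;    /* gemm_out_ (buffer_list[1], assumed) */
  elems[2] = unit_sq * 8;                            /* tmp_data_ (buffer_list[2], assumed) */
  elems[3] = cal_num * (size_t)in_channel;           /* col_buffer_ (buffer_list[3]) */
}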
@@ -67,6 +67,10 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
ctx_->allocator->Free(gemm_out_);
gemm_out_ = nullptr;
}
if (col_buffer_ != nullptr) {
ctx_->allocator->Free(col_buffer_);
col_buffer_ = nullptr;
}
}
int kernel_unit_;
int input_unit_;
@@ -75,7 +79,8 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel {
float16_t *trans_input_ = nullptr;
float16_t *gemm_out_ = nullptr;
float16_t *trans_weight_ = nullptr;
TmpBufferAddressFp16 tmp_buffer_address_list_[3];
float16_t *col_buffer_ = nullptr;
TmpBufferAddressFp16 tmp_buffer_address_list_[4];
InputTransFp16Func in_func_;
OutputTransFp16Func out_func_;
};
