diff --git a/mindspore/lite/nnacl/fp16/conv_fp16.c b/mindspore/lite/nnacl/fp16/conv_fp16.c index dd017aacc7..069fe5139f 100644 --- a/mindspore/lite/nnacl/fp16/conv_fp16.c +++ b/mindspore/lite/nnacl/fp16/conv_fp16.c @@ -160,7 +160,9 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_); int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_); int output_count = out_w_block * out_h_block; - int output_tile_count = UP_DIV(output_count, tile_num); + int per_thread_num = UP_DIV(output_count, conv_param->thread_num_); + int real_tile = per_thread_num < tile_num ? per_thread_num : tile_num; + int output_tile_count = UP_DIV(output_count, real_tile); int oc8 = UP_DIV(conv_param->output_channel_, C8NUM); int input_unit_square = conv_param->input_unit_ * conv_param->input_unit_; @@ -178,9 +180,12 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_; int out_batch_offset = b * conv_param->output_channel_ * conv_param->output_h_ * conv_param->output_w_; for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { - int out_tile_index = thread_id * tile_num; - int cal_num = output_count - thread_id * tile_num; - cal_num = cal_num > tile_num ? tile_num : cal_num; + int out_tile_index = thread_id * real_tile; + int cal_num = output_count - thread_id * real_tile; + cal_num = cal_num > real_tile ? real_tile : cal_num; + if (cal_num <= 0) { + return; + } WinogradInputTransformFp16(input_data + in_batch_offset, trans_input + task_id * trans_input_offset, tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param, in_func); @@ -189,7 +194,7 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa float16_t *dst_ptr = gemm_out + task_id * gemm_out_offset; float16_t *tmp_col_ptr = col_buffer + task_id * col_buffer_offset; for (int i = 0; i < input_unit_square; ++i) { - RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, tile_num, in_channel); + RowMajor2Col16MajorFp16Opt(src_ptr + i * tile_num * in_channel, tmp_col_ptr, cal_num, in_channel); MatMulFp16(tmp_col_ptr, trans_weight + i * in_channel * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, in_channel, cal_num, oc8 * C8NUM, input_unit_square, OutType_TileC8); } diff --git a/mindspore/lite/nnacl/fp16/matmul_fp16.c b/mindspore/lite/nnacl/fp16/matmul_fp16.c index b50bc74888..02cc6e48eb 100644 --- a/mindspore/lite/nnacl/fp16/matmul_fp16.c +++ b/mindspore/lite/nnacl/fp16/matmul_fp16.c @@ -16,202 +16,213 @@ #include "nnacl/fp16/matmul_fp16.h" -void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col, bool src_float16) { +static void Col2Row8SrcFromFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { int row_c8 = row / C8NUM * C8NUM; int col_c8 = col / C8NUM * C8NUM; + const float16_t *src = (const float16_t *)src_ptr; int ci = 0; - if (src_float16) { - const float16_t *src = (const float16_t *)src_ptr; - for (; ci < col_c8; ci += C8NUM) { - int ri = 0; - for (; ri < row_c8; ri += C8NUM) { - const float16_t *src_ptr1 = src + ci * row + ri; - float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM; + for (; ci < col_c8; ci += C8NUM) { + int ri = 0; + for (; ri < row_c8; ri += C8NUM) { + const float16_t *src_ptr1 = src + ci * row + ri; + float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM; #ifdef ENABLE_ARM64 - size_t strid_row = row * 2; - asm volatile( - "mov x10, %[src_ptr1]\n" - "mov x11, %[dst_ptr1]\n" - "mov x12, %[strid_row]\n" - "ld1 {v0.8h}, [x10], x12\n" - "ld1 {v1.8h}, [x10], x12\n" - "ld1 {v2.8h}, [x10], x12\n" - "ld1 {v3.8h}, [x10], x12\n" - "ld1 {v4.8h}, [x10], x12\n" - "ld1 {v5.8h}, [x10], x12\n" - "ld1 {v6.8h}, [x10], x12\n" - "ld1 {v7.8h}, [x10], x12\n" - - "zip1 v8.8h, v0.8h, v1.8h\n" - "zip1 v9.8h, v2.8h, v3.8h\n" - "zip1 v10.8h, v4.8h, v5.8h\n" - "zip1 v11.8h, v6.8h, v7.8h\n" - - "trn1 v12.4s, v8.4s, v9.4s\n" - "trn1 v14.4s, v10.4s, v11.4s\n" - "trn2 v13.4s, v8.4s, v9.4s\n" - "trn2 v15.4s, v10.4s, v11.4s\n" - - "trn1 v16.2d, v12.2d, v14.2d\n" - "trn2 v18.2d, v12.2d, v14.2d\n" - "trn1 v17.2d, v13.2d, v15.2d\n" - "trn2 v19.2d, v13.2d, v15.2d\n" - - "zip2 v8.8h, v0.8h, v1.8h\n" - "zip2 v9.8h, v2.8h, v3.8h\n" - "zip2 v10.8h, v4.8h, v5.8h\n" - "zip2 v11.8h, v6.8h, v7.8h\n" - - "trn1 v12.4s, v8.4s, v9.4s\n" - "trn1 v14.4s, v10.4s, v11.4s\n" - "trn2 v13.4s, v8.4s, v9.4s\n" - "trn2 v15.4s, v10.4s, v11.4s\n" - - "trn1 v20.2d, v12.2d, v14.2d\n" - "trn2 v22.2d, v12.2d, v14.2d\n" - "trn1 v21.2d, v13.2d, v15.2d\n" - "trn2 v23.2d, v13.2d, v15.2d\n" - - "st1 {v16.8h}, [x11], #16\n" - "st1 {v17.8h}, [x11], #16\n" - "st1 {v18.8h}, [x11], #16\n" - "st1 {v19.8h}, [x11], #16\n" - "st1 {v20.8h}, [x11], #16\n" - "st1 {v21.8h}, [x11], #16\n" - "st1 {v22.8h}, [x11], #16\n" - "st1 {v23.8h}, [x11], #16\n" - : - : [ dst_ptr1 ] "r"(dst_ptr1), [ src_ptr1 ] "r"(src_ptr1), [ strid_row ] "r"(strid_row) - : "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", - "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); + size_t strid_row = row * 2; + asm volatile( + "mov x10, %[src_ptr1]\n" + "mov x11, %[dst_ptr1]\n" + "mov x12, %[strid_row]\n" + "ld1 {v0.8h}, [x10], x12\n" + "ld1 {v1.8h}, [x10], x12\n" + "ld1 {v2.8h}, [x10], x12\n" + "ld1 {v3.8h}, [x10], x12\n" + "ld1 {v4.8h}, [x10], x12\n" + "ld1 {v5.8h}, [x10], x12\n" + "ld1 {v6.8h}, [x10], x12\n" + "ld1 {v7.8h}, [x10], x12\n" + + "zip1 v8.8h, v0.8h, v1.8h\n" + "zip1 v9.8h, v2.8h, v3.8h\n" + "zip1 v10.8h, v4.8h, v5.8h\n" + "zip1 v11.8h, v6.8h, v7.8h\n" + + "trn1 v12.4s, v8.4s, v9.4s\n" + "trn1 v14.4s, v10.4s, v11.4s\n" + "trn2 v13.4s, v8.4s, v9.4s\n" + "trn2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v16.2d, v12.2d, v14.2d\n" + "trn2 v18.2d, v12.2d, v14.2d\n" + "trn1 v17.2d, v13.2d, v15.2d\n" + "trn2 v19.2d, v13.2d, v15.2d\n" + + "zip2 v8.8h, v0.8h, v1.8h\n" + "zip2 v9.8h, v2.8h, v3.8h\n" + "zip2 v10.8h, v4.8h, v5.8h\n" + "zip2 v11.8h, v6.8h, v7.8h\n" + + "trn1 v12.4s, v8.4s, v9.4s\n" + "trn1 v14.4s, v10.4s, v11.4s\n" + "trn2 v13.4s, v8.4s, v9.4s\n" + "trn2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v20.2d, v12.2d, v14.2d\n" + "trn2 v22.2d, v12.2d, v14.2d\n" + "trn1 v21.2d, v13.2d, v15.2d\n" + "trn2 v23.2d, v13.2d, v15.2d\n" + + "st1 {v16.8h}, [x11], #16\n" + "st1 {v17.8h}, [x11], #16\n" + "st1 {v18.8h}, [x11], #16\n" + "st1 {v19.8h}, [x11], #16\n" + "st1 {v20.8h}, [x11], #16\n" + "st1 {v21.8h}, [x11], #16\n" + "st1 {v22.8h}, [x11], #16\n" + "st1 {v23.8h}, [x11], #16\n" + : + : [ dst_ptr1 ] "r"(dst_ptr1), [ src_ptr1 ] "r"(src_ptr1), [ strid_row ] "r"(strid_row) + : "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); #else - for (int tr = 0; tr < C8NUM; ++tr) { - for (int tc = 0; tc < C8NUM; ++tc) { - dst_ptr1[tr * C8NUM + tc] = src_ptr1[tc * row + tr]; - } - } -#endif - } - for (; ri < row; ++ri) { - const float16_t *src_ptr1 = src + ci * row; - float16_t *dst_ptr1 = dst_ptr + ci * row; + for (int tr = 0; tr < C8NUM; ++tr) { for (int tc = 0; tc < C8NUM; ++tc) { - dst_ptr1[ri * C8NUM + tc] = src_ptr1[tc * row + ri]; + dst_ptr1[tr * C8NUM + tc] = src_ptr1[tc * row + tr]; } } +#endif } - for (int r = 0; r < row; r++) { - for (int tc = ci; tc < col; tc++) { - int cd8 = tc / C8NUM; - int cm8 = tc % C8NUM; - dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[tc * row + r]; + for (; ri < row; ++ri) { + const float16_t *src_ptr1 = src + ci * row; + float16_t *dst_ptr1 = dst_ptr + ci * row; + for (int tc = 0; tc < C8NUM; ++tc) { + dst_ptr1[ri * C8NUM + tc] = src_ptr1[tc * row + ri]; } } - } else { - const float *src = (const float *)src_ptr; - for (; ci < col_c8; ci += C8NUM) { - int ri = 0; - for (; ri < row_c8; ri += C8NUM) { - const float *src_ptr1 = src + ci * row + ri; - float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM; + } + for (int r = 0; r < row; r++) { + for (int tc = ci; tc < col; tc++) { + int cd8 = tc / C8NUM; + int cm8 = tc % C8NUM; + dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = src[tc * row + r]; + } + } +} + +static void Col2Row8SrcFromFp32(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { + int row_c8 = row / C8NUM * C8NUM; + int col_c8 = col / C8NUM * C8NUM; + int ci = 0; + const float *src = (const float *)src_ptr; + for (; ci < col_c8; ci += C8NUM) { + int ri = 0; + for (; ri < row_c8; ri += C8NUM) { + const float *src_ptr1 = src + ci * row + ri; + float16_t *dst_ptr1 = dst_ptr + ci * row + ri * C8NUM; #ifdef ENABLE_ARM64 - size_t strid_row = row * 4; - asm volatile( - "mov x10, %[src_ptr1]\n" - "mov x11, %[dst_ptr1]\n" - "mov x12, %[strid_row]\n" - "ld1 {v8.4s, v9.4s}, [x10], x12\n" - "ld1 {v10.4s, v11.4s}, [x10], x12\n" - "ld1 {v12.4s, v13.4s}, [x10], x12\n" - "ld1 {v14.4s, v15.4s}, [x10], x12\n" - "ld1 {v16.4s, v17.4s}, [x10], x12\n" - "ld1 {v18.4s, v19.4s}, [x10], x12\n" - "ld1 {v20.4s, v21.4s}, [x10], x12\n" - "ld1 {v22.4s, v23.4s}, [x10], x12\n" - - "fcvtn v0.4h, v8.4s\n" - "fcvtn2 v0.8h, v9.4s\n" - "fcvtn v1.4h, v10.4s\n" - "fcvtn2 v1.8h, v11.4s\n" - "fcvtn v2.4h, v12.4s\n" - "fcvtn2 v2.8h, v13.4s\n" - "fcvtn v3.4h, v14.4s\n" - "fcvtn2 v3.8h, v15.4s\n" - "fcvtn v4.4h, v16.4s\n" - "fcvtn2 v4.8h, v17.4s\n" - "fcvtn v5.4h, v18.4s\n" - "fcvtn2 v5.8h, v19.4s\n" - "fcvtn v6.4h, v20.4s\n" - "fcvtn2 v6.8h, v21.4s\n" - "fcvtn v7.4h, v22.4s\n" - "fcvtn2 v7.8h, v23.4s\n" - - "zip1 v8.8h, v0.8h, v1.8h\n" - "zip1 v9.8h, v2.8h, v3.8h\n" - "zip1 v10.8h, v4.8h, v5.8h\n" - "zip1 v11.8h, v6.8h, v7.8h\n" - - "trn1 v12.4s, v8.4s, v9.4s\n" - "trn1 v14.4s, v10.4s, v11.4s\n" - "trn2 v13.4s, v8.4s, v9.4s\n" - "trn2 v15.4s, v10.4s, v11.4s\n" - - "trn1 v16.2d, v12.2d, v14.2d\n" - "trn2 v18.2d, v12.2d, v14.2d\n" - "trn1 v17.2d, v13.2d, v15.2d\n" - "trn2 v19.2d, v13.2d, v15.2d\n" - - "zip2 v8.8h, v0.8h, v1.8h\n" - "zip2 v9.8h, v2.8h, v3.8h\n" - "zip2 v10.8h, v4.8h, v5.8h\n" - "zip2 v11.8h, v6.8h, v7.8h\n" - - "trn1 v12.4s, v8.4s, v9.4s\n" - "trn1 v14.4s, v10.4s, v11.4s\n" - "trn2 v13.4s, v8.4s, v9.4s\n" - "trn2 v15.4s, v10.4s, v11.4s\n" - - "trn1 v20.2d, v12.2d, v14.2d\n" - "trn2 v22.2d, v12.2d, v14.2d\n" - "trn1 v21.2d, v13.2d, v15.2d\n" - "trn2 v23.2d, v13.2d, v15.2d\n" - - "st1 {v16.8h}, [x11], #16\n" - "st1 {v17.8h}, [x11], #16\n" - "st1 {v18.8h}, [x11], #16\n" - "st1 {v19.8h}, [x11], #16\n" - "st1 {v20.8h}, [x11], #16\n" - "st1 {v21.8h}, [x11], #16\n" - "st1 {v22.8h}, [x11], #16\n" - "st1 {v23.8h}, [x11], #16\n" - : - : [ dst_ptr1 ] "r"(dst_ptr1), [ src_ptr1 ] "r"(src_ptr1), [ strid_row ] "r"(strid_row) - : "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", - "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); + size_t strid_row = row * 4; + asm volatile( + "mov x10, %[src_ptr1]\n" + "mov x11, %[dst_ptr1]\n" + "mov x12, %[strid_row]\n" + "ld1 {v8.4s, v9.4s}, [x10], x12\n" + "ld1 {v10.4s, v11.4s}, [x10], x12\n" + "ld1 {v12.4s, v13.4s}, [x10], x12\n" + "ld1 {v14.4s, v15.4s}, [x10], x12\n" + "ld1 {v16.4s, v17.4s}, [x10], x12\n" + "ld1 {v18.4s, v19.4s}, [x10], x12\n" + "ld1 {v20.4s, v21.4s}, [x10], x12\n" + "ld1 {v22.4s, v23.4s}, [x10], x12\n" + + "fcvtn v0.4h, v8.4s\n" + "fcvtn2 v0.8h, v9.4s\n" + "fcvtn v1.4h, v10.4s\n" + "fcvtn2 v1.8h, v11.4s\n" + "fcvtn v2.4h, v12.4s\n" + "fcvtn2 v2.8h, v13.4s\n" + "fcvtn v3.4h, v14.4s\n" + "fcvtn2 v3.8h, v15.4s\n" + "fcvtn v4.4h, v16.4s\n" + "fcvtn2 v4.8h, v17.4s\n" + "fcvtn v5.4h, v18.4s\n" + "fcvtn2 v5.8h, v19.4s\n" + "fcvtn v6.4h, v20.4s\n" + "fcvtn2 v6.8h, v21.4s\n" + "fcvtn v7.4h, v22.4s\n" + "fcvtn2 v7.8h, v23.4s\n" + + "zip1 v8.8h, v0.8h, v1.8h\n" + "zip1 v9.8h, v2.8h, v3.8h\n" + "zip1 v10.8h, v4.8h, v5.8h\n" + "zip1 v11.8h, v6.8h, v7.8h\n" + + "trn1 v12.4s, v8.4s, v9.4s\n" + "trn1 v14.4s, v10.4s, v11.4s\n" + "trn2 v13.4s, v8.4s, v9.4s\n" + "trn2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v16.2d, v12.2d, v14.2d\n" + "trn2 v18.2d, v12.2d, v14.2d\n" + "trn1 v17.2d, v13.2d, v15.2d\n" + "trn2 v19.2d, v13.2d, v15.2d\n" + + "zip2 v8.8h, v0.8h, v1.8h\n" + "zip2 v9.8h, v2.8h, v3.8h\n" + "zip2 v10.8h, v4.8h, v5.8h\n" + "zip2 v11.8h, v6.8h, v7.8h\n" + + "trn1 v12.4s, v8.4s, v9.4s\n" + "trn1 v14.4s, v10.4s, v11.4s\n" + "trn2 v13.4s, v8.4s, v9.4s\n" + "trn2 v15.4s, v10.4s, v11.4s\n" + + "trn1 v20.2d, v12.2d, v14.2d\n" + "trn2 v22.2d, v12.2d, v14.2d\n" + "trn1 v21.2d, v13.2d, v15.2d\n" + "trn2 v23.2d, v13.2d, v15.2d\n" + + "st1 {v16.8h}, [x11], #16\n" + "st1 {v17.8h}, [x11], #16\n" + "st1 {v18.8h}, [x11], #16\n" + "st1 {v19.8h}, [x11], #16\n" + "st1 {v20.8h}, [x11], #16\n" + "st1 {v21.8h}, [x11], #16\n" + "st1 {v22.8h}, [x11], #16\n" + "st1 {v23.8h}, [x11], #16\n" + : + : [ dst_ptr1 ] "r"(dst_ptr1), [ src_ptr1 ] "r"(src_ptr1), [ strid_row ] "r"(strid_row) + : "x10", "x11", "x12", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", + "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"); #else - for (int tr = 0; tr < C8NUM; ++tr) { - for (int tc = 0; tc < C8NUM; ++tc) { - dst_ptr1[tr * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + tr]); - } - } -#endif - } - for (; ri < row; ++ri) { - const float *src_ptr1 = src + ci * row; - float16_t *dst_ptr1 = dst_ptr + ci * row; + for (int tr = 0; tr < C8NUM; ++tr) { for (int tc = 0; tc < C8NUM; ++tc) { - dst_ptr1[ri * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + ri]); + dst_ptr1[tr * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + tr]); } } +#endif } - for (int r = 0; r < row; r++) { - for (int tc = ci; tc < col; tc++) { - int cd8 = tc / C8NUM; - int cm8 = tc % C8NUM; - dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = (float16_t)(src[tc * row + r]); + for (; ri < row; ++ri) { + const float *src_ptr1 = src + ci * row; + float16_t *dst_ptr1 = dst_ptr + ci * row; + for (int tc = 0; tc < C8NUM; ++tc) { + dst_ptr1[ri * C8NUM + tc] = (float16_t)(src_ptr1[tc * row + ri]); } } } + for (int r = 0; r < row; r++) { + for (int tc = ci; tc < col; tc++) { + int cd8 = tc / C8NUM; + int cm8 = tc % C8NUM; + dst_ptr[cd8 * C8NUM * row + r * C8NUM + cm8] = (float16_t)(src[tc * row + r]); + } + } +} + +void ColMajor2Row8MajorFp16(const void *src_ptr, float16_t *dst_ptr, size_t row, size_t col, bool src_float16) { + if (src_float16) { + Col2Row8SrcFromFp16(src_ptr, dst_ptr, row, col); + } else { + Col2Row8SrcFromFp32(src_ptr, dst_ptr, row, col); + } return; } @@ -274,126 +285,129 @@ void MatVecMulFp16(const float16_t *a, const float16_t *b, float16_t *c, const f MatVecMulFp16Neon64(a, b, c, bias, (int)act_type, depth, col); } +static void Row2Col16Block16(const float16_t *src_ptr, float16_t *dst_ptr, size_t col) { + size_t stride = col * 2; + asm volatile( + "mov x10, %[src_c]\n" + "mov x11, %[dst_c]\n" + + "ld1 {v0.8h}, [x10], %[stride]\n" + "ld1 {v1.8h}, [x10], %[stride]\n" + "ld1 {v2.8h}, [x10], %[stride]\n" + "ld1 {v3.8h}, [x10], %[stride]\n" + "ld1 {v4.8h}, [x10], %[stride]\n" + "ld1 {v5.8h}, [x10], %[stride]\n" + "ld1 {v6.8h}, [x10], %[stride]\n" + "ld1 {v7.8h}, [x10], %[stride]\n" + + "zip1 v16.8h, v0.8h, v1.8h\n" + "zip1 v17.8h, v2.8h, v3.8h\n" + "zip1 v18.8h, v4.8h, v5.8h\n" + "zip1 v19.8h, v6.8h, v7.8h\n" + + "ld1 {v8.8h}, [x10], %[stride]\n" + "ld1 {v9.8h}, [x10], %[stride]\n" + "ld1 {v10.8h}, [x10], %[stride]\n" + "ld1 {v11.8h}, [x10], %[stride]\n" + "ld1 {v12.8h}, [x10], %[stride]\n" + "ld1 {v13.8h}, [x10], %[stride]\n" + "ld1 {v14.8h}, [x10], %[stride]\n" + "ld1 {v15.8h}, [x10], %[stride]\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip1 v16.8h, v8.8h, v9.8h\n" + "zip1 v17.8h, v10.8h, v11.8h\n" + "zip1 v18.8h, v12.8h, v13.8h\n" + "zip1 v19.8h, v14.8h, v15.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v28.2d, v20.2d, v22.2d\n" + "trn2 v29.2d, v20.2d, v22.2d\n" + "trn1 v30.2d, v21.2d, v23.2d\n" + "trn2 v31.2d, v21.2d, v23.2d\n" + + "st1 {v24.8h}, [x11], #16\n" + "st1 {v28.8h}, [x11], #16\n" + "st1 {v26.8h}, [x11], #16\n" + "st1 {v30.8h}, [x11], #16\n" + "st1 {v25.8h}, [x11], #16\n" + "st1 {v29.8h}, [x11], #16\n" + "st1 {v27.8h}, [x11], #16\n" + "st1 {v31.8h}, [x11], #16\n" + + "zip2 v16.8h, v0.8h, v1.8h\n" + "zip2 v17.8h, v2.8h, v3.8h\n" + "zip2 v18.8h, v4.8h, v5.8h\n" + "zip2 v19.8h, v6.8h, v7.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v24.2d, v20.2d, v22.2d\n" + "trn2 v25.2d, v20.2d, v22.2d\n" + "trn1 v26.2d, v21.2d, v23.2d\n" + "trn2 v27.2d, v21.2d, v23.2d\n" + + "zip2 v16.8h, v8.8h, v9.8h\n" + "zip2 v17.8h, v10.8h, v11.8h\n" + "zip2 v18.8h, v12.8h, v13.8h\n" + "zip2 v19.8h, v14.8h, v15.8h\n" + + "trn1 v20.4s, v16.4s, v17.4s\n" + "trn2 v21.4s, v16.4s, v17.4s\n" + "trn1 v22.4s, v18.4s, v19.4s\n" + "trn2 v23.4s, v18.4s, v19.4s\n" + + "trn1 v28.2d, v20.2d, v22.2d\n" + "trn2 v29.2d, v20.2d, v22.2d\n" + "trn1 v30.2d, v21.2d, v23.2d\n" + "trn2 v31.2d, v21.2d, v23.2d\n" + + "st1 {v24.8h}, [x11], #16\n" + "st1 {v28.8h}, [x11], #16\n" + "st1 {v26.8h}, [x11], #16\n" + "st1 {v30.8h}, [x11], #16\n" + "st1 {v25.8h}, [x11], #16\n" + "st1 {v29.8h}, [x11], #16\n" + "st1 {v27.8h}, [x11], #16\n" + "st1 {v31.8h}, [x11], #16\n" + : + : [ dst_c ] "r"(dst_ptr), [ src_c ] "r"(src_ptr), [ stride ] "r"(stride) + : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", + "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", + "v31"); +} + void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, size_t row, size_t col) { size_t row_up_16 = UP_ROUND(row, C16NUM); size_t row16 = row / C16NUM * C16NUM; size_t col8 = col / C8NUM * C8NUM; const float16_t *src_r = src_ptr; float16_t *dst_r = dst_ptr; - size_t ri = 0; + // find 16 block unit for (; ri < row16; ri += C16NUM) { size_t ci = 0; for (; ci < col8; ci += C8NUM) { const float16_t *src_c = src_r + ci; float16_t *dst_c = dst_r + ci * C16NUM; - #ifdef ENABLE_ARM64 - size_t stride = col * 2; - asm volatile( - "mov x10, %[src_c]\n" - "mov x11, %[dst_c]\n" - - "ld1 {v0.8h}, [x10], %[stride]\n" - "ld1 {v1.8h}, [x10], %[stride]\n" - "ld1 {v2.8h}, [x10], %[stride]\n" - "ld1 {v3.8h}, [x10], %[stride]\n" - "ld1 {v4.8h}, [x10], %[stride]\n" - "ld1 {v5.8h}, [x10], %[stride]\n" - "ld1 {v6.8h}, [x10], %[stride]\n" - "ld1 {v7.8h}, [x10], %[stride]\n" - - "zip1 v16.8h, v0.8h, v1.8h\n" - "zip1 v17.8h, v2.8h, v3.8h\n" - "zip1 v18.8h, v4.8h, v5.8h\n" - "zip1 v19.8h, v6.8h, v7.8h\n" - - "ld1 {v8.8h}, [x10], %[stride]\n" - "ld1 {v9.8h}, [x10], %[stride]\n" - "ld1 {v10.8h}, [x10], %[stride]\n" - "ld1 {v11.8h}, [x10], %[stride]\n" - "ld1 {v12.8h}, [x10], %[stride]\n" - "ld1 {v13.8h}, [x10], %[stride]\n" - "ld1 {v14.8h}, [x10], %[stride]\n" - "ld1 {v15.8h}, [x10], %[stride]\n" - - "trn1 v20.4s, v16.4s, v17.4s\n" - "trn2 v21.4s, v16.4s, v17.4s\n" - "trn1 v22.4s, v18.4s, v19.4s\n" - "trn2 v23.4s, v18.4s, v19.4s\n" - - "trn1 v24.2d, v20.2d, v22.2d\n" - "trn2 v25.2d, v20.2d, v22.2d\n" - "trn1 v26.2d, v21.2d, v23.2d\n" - "trn2 v27.2d, v21.2d, v23.2d\n" - - "zip1 v16.8h, v8.8h, v9.8h\n" - "zip1 v17.8h, v10.8h, v11.8h\n" - "zip1 v18.8h, v12.8h, v13.8h\n" - "zip1 v19.8h, v14.8h, v15.8h\n" - - "trn1 v20.4s, v16.4s, v17.4s\n" - "trn2 v21.4s, v16.4s, v17.4s\n" - "trn1 v22.4s, v18.4s, v19.4s\n" - "trn2 v23.4s, v18.4s, v19.4s\n" - - "trn1 v28.2d, v20.2d, v22.2d\n" - "trn2 v29.2d, v20.2d, v22.2d\n" - "trn1 v30.2d, v21.2d, v23.2d\n" - "trn2 v31.2d, v21.2d, v23.2d\n" - - "st1 {v24.8h}, [x11], #16\n" - "st1 {v28.8h}, [x11], #16\n" - "st1 {v26.8h}, [x11], #16\n" - "st1 {v30.8h}, [x11], #16\n" - "st1 {v25.8h}, [x11], #16\n" - "st1 {v29.8h}, [x11], #16\n" - "st1 {v27.8h}, [x11], #16\n" - "st1 {v31.8h}, [x11], #16\n" - - "zip2 v16.8h, v0.8h, v1.8h\n" - "zip2 v17.8h, v2.8h, v3.8h\n" - "zip2 v18.8h, v4.8h, v5.8h\n" - "zip2 v19.8h, v6.8h, v7.8h\n" - - "trn1 v20.4s, v16.4s, v17.4s\n" - "trn2 v21.4s, v16.4s, v17.4s\n" - "trn1 v22.4s, v18.4s, v19.4s\n" - "trn2 v23.4s, v18.4s, v19.4s\n" - - "trn1 v24.2d, v20.2d, v22.2d\n" - "trn2 v25.2d, v20.2d, v22.2d\n" - "trn1 v26.2d, v21.2d, v23.2d\n" - "trn2 v27.2d, v21.2d, v23.2d\n" - - "zip2 v16.8h, v8.8h, v9.8h\n" - "zip2 v17.8h, v10.8h, v11.8h\n" - "zip2 v18.8h, v12.8h, v13.8h\n" - "zip2 v19.8h, v14.8h, v15.8h\n" - - "trn1 v20.4s, v16.4s, v17.4s\n" - "trn2 v21.4s, v16.4s, v17.4s\n" - "trn1 v22.4s, v18.4s, v19.4s\n" - "trn2 v23.4s, v18.4s, v19.4s\n" - - "trn1 v28.2d, v20.2d, v22.2d\n" - "trn2 v29.2d, v20.2d, v22.2d\n" - "trn1 v30.2d, v21.2d, v23.2d\n" - "trn2 v31.2d, v21.2d, v23.2d\n" - - "st1 {v24.8h}, [x11], #16\n" - "st1 {v28.8h}, [x11], #16\n" - "st1 {v26.8h}, [x11], #16\n" - "st1 {v30.8h}, [x11], #16\n" - "st1 {v25.8h}, [x11], #16\n" - "st1 {v29.8h}, [x11], #16\n" - "st1 {v27.8h}, [x11], #16\n" - "st1 {v31.8h}, [x11], #16\n" - : - : [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride) - : "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", - "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", - "v30", "v31"); + Row2Col16Block16(src_c, dst_c, col); #else for (int tr = 0; tr < C16NUM; tr++) { for (int tc = 0; tc < C8NUM; tc++) { @@ -413,7 +427,7 @@ void RowMajor2Col16MajorFp16Opt(const float16_t *src_ptr, float16_t *dst_ptr, si dst_r += C16NUM * col; } for (; ri < row; ri++) { - for (size_t i = 0; i < col; i++) { + for (size_t i = 0; i < col; ++i) { dst_r[i * C16NUM] = src_r[i]; } src_r += col; diff --git a/mindspore/lite/nnacl/fp32/conv_common_fp32.c b/mindspore/lite/nnacl/fp32/conv_common_fp32.c index fd314aae4b..fb244351ae 100644 --- a/mindspore/lite/nnacl/fp32/conv_common_fp32.c +++ b/mindspore/lite/nnacl/fp32/conv_common_fp32.c @@ -40,6 +40,9 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_ for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) { int start_index = thread_id * cal_num; int real_cal_num = (output_count - start_index) < cal_num ? (output_count - start_index) : cal_num; + if (real_cal_num <= 0) { + return; + } float *gemm_input = packed_input + task_id * deep * cal_num; float *col_major_gemm_input = col_major_input + task_id * deep * cal_num; size_t packed_input_size = deep * cal_num * sizeof(float); diff --git a/mindspore/lite/nnacl/fp32/conv_winograd_fp32.c b/mindspore/lite/nnacl/fp32/conv_winograd_fp32.c index 9c1311c1e1..d8370aa468 100644 --- a/mindspore/lite/nnacl/fp32/conv_winograd_fp32.c +++ b/mindspore/lite/nnacl/fp32/conv_winograd_fp32.c @@ -56,6 +56,9 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const int out_tile_index = thread_id * tile_num; int cal_num = output_count - out_tile_index; cal_num = cal_num > tile_num ? tile_num : cal_num; + if (cal_num <= 0) { + return; + } WinogradInputTransform(input_data + in_batch_offset, trans_input + task_id * trans_input_offset, tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param, in_func); diff --git a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc index c7771f80be..716ae15d51 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/convolution_base.cc @@ -36,7 +36,6 @@ ConvolutionBaseCPUKernel::~ConvolutionBaseCPUKernel() { } void ConvolutionBaseCPUKernel::FreeQuantParam() { - ConvQuantArg *conv_quant_arg_ = &conv_param_->conv_quant_arg_; if (conv_quant_arg_ == nullptr) { return; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.h index 6ca864c8f9..da4d11ec83 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_delegate_fp16.h @@ -44,7 +44,10 @@ class ConvolutionDelegateFP16CPUKernel : public LiteKernel { void FreeCopiedData(); int Init() override; int ReSize() override; - int Run() override { return fp16_conv_kernel_->Run(); } + int Run() override { + fp16_conv_kernel_->set_name(name_); + return fp16_conv_kernel_->Run(); + } private: uint8_t need_free_ = 0b00; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc index ec883276b1..5cfff64ce8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.cc @@ -102,6 +102,13 @@ int ConvolutionFP16CPUKernel::Init() { return RET_OK; } +void ConvolutionFP16CPUKernel::AdjustNumberOfThread() { + auto out_tensor = out_tensors_.front(); + int out_plane = out_tensor->Height() * out_tensor->Width(); + thread_count_ = MSMIN(ctx_->thread_num_, UP_DIV(out_plane, C16NUM)); + conv_param_->thread_num_ = thread_count_; +} + int ConvolutionFP16CPUKernel::ReSize() { auto ret = ConvolutionBaseCPUKernel::CheckResizeValid(); if (ret != RET_OK) { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h index 2c75a1f465..f8d34b20bc 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_fp16.h @@ -44,6 +44,7 @@ class ConvolutionFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { int RunImpl(int task_id); int InitWeightBias(); int InitTmpBuffer(); + void AdjustNumberOfThread(); private: void FreeTmpBuffer() { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc index f86b58ef37..44528e4e07 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.cc @@ -108,7 +108,6 @@ int ConvolutionWinogradFP16CPUKernel::InitWeightBias() { int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() { const int cal_num = 16; int channel_out = conv_param_->output_channel_; - int oc8 = UP_DIV(channel_out, C8NUM); size_t tile_buffer_size = thread_count_ * cal_num * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float16_t); @@ -118,8 +117,8 @@ int ConvolutionWinogradFP16CPUKernel::InitTmpBuffer() { return RET_ERROR; } - gemm_out_ = reinterpret_cast( - ctx_->allocator->Malloc(thread_count_ * cal_num * input_unit_ * input_unit_ * oc8 * C8NUM * sizeof(float16_t))); + gemm_out_ = reinterpret_cast(ctx_->allocator->Malloc( + thread_count_ * cal_num * input_unit_ * input_unit_ * UP_ROUND(channel_out, C8NUM) * sizeof(float16_t))); if (gemm_out_ == nullptr) { MS_LOG(ERROR) << "malloc gemm_out_ failed."; return RET_ERROR; @@ -174,6 +173,13 @@ int ConvolutionWinogradFP16CPUKernel::Init() { return RET_OK; } +void ConvolutionWinogradFP16CPUKernel::AdjustNumberOfThread() { + auto out_tensor = out_tensors_.front(); + int cal_plane = UP_DIV(out_tensor->Height(), output_unit_) * UP_DIV(out_tensor->Width(), output_unit_); + thread_count_ = MSMIN(ctx_->thread_num_, UP_DIV(cal_plane, C8NUM)); + conv_param_->thread_num_ = thread_count_; +} + int ConvolutionWinogradFP16CPUKernel::ReSize() { auto ret = ConvolutionBaseCPUKernel::CheckResizeValid(); if (ret != RET_OK) { @@ -190,6 +196,7 @@ int ConvolutionWinogradFP16CPUKernel::ReSize() { MS_LOG(ERROR) << "ConfigInputOutput failed."; return RET_ERROR; } + AdjustNumberOfThread(); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h index a00dcad848..c1a4b3cfc4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/convolution_winograd_fp16.h @@ -52,6 +52,7 @@ class ConvolutionWinogradFP16CPUKernel : public ConvolutionBaseFP16CPUKernel { int InitTmpBuffer(); int ConfigInputOutput(); int WinogradFilterTransformFp16(const float16_t *weight_data, float *matrix_g, float *matrix_gt, int oc_block); + void AdjustNumberOfThread(); private: void FreeTmpBuffer() { diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc index 7d1f37ad83..428bd8f76e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.cc @@ -48,16 +48,9 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() { conv_param_->input_channel_ = in_channel; conv_param_->output_channel_ = out_channel; - int oc4 = UP_DIV(out_channel, C4NUM); -#ifdef ENABLE_AVX - const int oc_block = C16NUM; -#else - const int oc_block = C8NUM; -#endif - int oc_block_num = UP_DIV(out_channel, oc_block); - // set data - auto trans_matrix_data_size = input_unit_ * input_unit_ * in_channel * oc_block_num * oc_block * sizeof(float); + auto trans_matrix_data_size = + input_unit_ * input_unit_ * in_channel * UP_ROUND(out_channel, oc_block_) * sizeof(float); if (trans_weight_ == nullptr) { trans_weight_ = reinterpret_cast(malloc(trans_matrix_data_size)); if (trans_weight_ == nullptr) { @@ -83,14 +76,15 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() { MS_LOG(ERROR) << "get matrix g from CookToomFilter failed."; return ret; } - ret = WinogradFilterTransform(origin_weight_, matrix_g, matrix_gt, oc_block); + ret = WinogradFilterTransform(origin_weight_, matrix_g, matrix_gt, oc_block_); if (ret != RET_OK) { MS_LOG(ERROR) << "winograd filter transform failed."; return ret; } // init bias - size_t new_bias_size = oc4 * C4NUM * sizeof(float); + size_t new_bias_size = UP_ROUND(out_channel, C4NUM) * sizeof(float); + bias_data_ = malloc(new_bias_size); if (bias_data_ == nullptr) { bias_data_ = reinterpret_cast(malloc(new_bias_size)); if (bias_data_ == nullptr) { @@ -98,31 +92,30 @@ int ConvolutionWinogradCPUKernel::InitWeightBias() { return RET_MEMORY_FAILED; } } - memset(bias_data_, 0, new_bias_size); if (in_tensors_.size() == kInputSize2) { - memcpy(bias_data_, origin_bias_, out_channel * sizeof(float)); + size_t origin_size = out_channel * sizeof(float); + memcpy(bias_data_, origin_bias_, origin_size); + memset(reinterpret_cast(bias_data_) + out_channel, 0, new_bias_size - origin_size); } else { MS_ASSERT(in_tensors_.size() == kInputSize1); + memset(bias_data_, 0, new_bias_size); } return RET_OK; } int ConvolutionWinogradCPUKernel::InitTmpBuffer() { - int channel_out = conv_param_->output_channel_; - int oc8 = UP_DIV(channel_out, C8NUM); - int tile_num = C12NUM; MS_ASSERT(ctx_->allocator != nullptr); - size_t tile_buffer_size = - thread_count_ * tile_num * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float); + thread_count_ * tile_num_ * input_unit_ * input_unit_ * conv_param_->input_channel_ * sizeof(float); trans_input_ = reinterpret_cast(ctx_->allocator->Malloc(tile_buffer_size)); if (trans_input_ == nullptr) { MS_LOG(ERROR) << "malloc trans_input_ failed."; return RET_MEMORY_FAILED; } + int oc8 = UP_ROUND(conv_param_->output_channel_, C8NUM); gemm_out_ = reinterpret_cast( - ctx_->allocator->Malloc(thread_count_ * tile_num * input_unit_ * input_unit_ * oc8 * C8NUM * sizeof(float))); + ctx_->allocator->Malloc(thread_count_ * tile_num_ * input_unit_ * input_unit_ * oc8 * sizeof(float))); if (gemm_out_ == nullptr) { MS_LOG(ERROR) << "malloc gemm_out_ failed."; return RET_ERROR; @@ -136,7 +129,7 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() { } col_buffer_ = reinterpret_cast( - ctx_->allocator->Malloc(thread_count_ * tile_num * conv_param_->input_channel_ * sizeof(float))); + ctx_->allocator->Malloc(thread_count_ * tile_num_ * conv_param_->input_channel_ * sizeof(float))); if (col_buffer_ == nullptr) { MS_LOG(ERROR) << "malloc col_buffer_ failed."; return RET_ERROR; @@ -164,10 +157,17 @@ int ConvolutionWinogradCPUKernel::ConfigInputOutput() { } int ConvolutionWinogradCPUKernel::Init() { + tile_num_ = C12NUM; +#ifdef ENABLE_AVX + oc_block_ = C16NUM; +#else + oc_block_ = C8NUM; +#endif kernel_unit_ = conv_param_->kernel_h_; input_unit_ = output_unit_ + kernel_unit_ - 1; conv_param_->input_unit_ = input_unit_; conv_param_->output_unit_ = output_unit_; + auto ret = InitWeightBias(); if (ret != RET_OK) { MS_LOG(ERROR) << "Init weight bias failed."; @@ -197,8 +197,8 @@ int ConvolutionWinogradCPUKernel::ReSize() { int ConvolutionWinogradCPUKernel::RunImpl(int task_id) { auto input_tensor = in_tensors_.at(kInputIndex); - auto ori_input_data = reinterpret_cast(input_tensor->MutableData()); - auto output_data = reinterpret_cast(out_tensors_.front()->MutableData()); + auto ori_input_data = reinterpret_cast(input_tensor->data_c()); + auto output_data = reinterpret_cast(out_tensors_.front()->data_c()); ConvWinogardFp32(ori_input_data, trans_weight_, reinterpret_cast(bias_data_), output_data, tmp_buffer_address_list_, task_id, conv_param_, in_func_, out_func_); return RET_OK; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h index 45ae9febd1..f2cb63b4e5 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/convolution_winograd_fp32.h @@ -70,9 +70,11 @@ class ConvolutionWinogradCPUKernel : public ConvolutionBaseCPUKernel { col_buffer_ = nullptr; } } - int kernel_unit_; - int input_unit_; + int kernel_unit_{0}; + int input_unit_{0}; int output_unit_; + int oc_block_{0}; + int tile_num_{0}; float *origin_weight_; // do not free float *origin_bias_; // do not free float *tmp_data_ = nullptr;