!6081 optimization for matmul on arm32

Merge pull request !6081 from lixian/master
pull/6081/MERGE
mindspore-ci-bot authored, committed by Gitee, 5 years ago
commit fe902fe1f9

(File diff suppressed because it is too large.)
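In brief: on ARM32 the fp32 GEMM row tile shrinks from 12 to 4, presumably because AArch32 NEON exposes only 16 quad registers where AArch64 has 32. The patch adds RowMajor2Row4Major and RowMajor2Col4Major packing routines (with inline NEON fast paths), ARM32 branches in RowMajor2Col12Major and PackNHWCToNCHWFp32, a MatmulFloatNeon32Opt dispatch in MatMulOpt, 4-row-tile buffer sizing in the conv/deconv/fully-connected kernels, and re-enables the arm32 assembly sources in CMake; the suppressed file above is presumably the new assembly kernel itself.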

@@ -112,7 +112,8 @@ void IndirectGemmFp32_8x8(float *output, const float *input, const float *weight
}
}
#endif
-// #ifndef ENABLE_ARM32
+#ifndef ENABLE_ARM32
void IndirectGemmFp32_8x4(float *output, const float *input, const float *weight, const float *bias, size_t step,
size_t ic4, size_t output_channel, size_t offset, size_t mode, size_t writeC4, size_t relu,
size_t relu6) {
@@ -155,7 +156,7 @@ void IndirectGemmFp32_8x4(float *output, const float *input, const float *weight
}
}
}
-// #endif
+#endif
int8_t MinInt8(int8_t a, int8_t b) { return b ^ ((a ^ b) & -(a < b)); }

@@ -270,7 +270,12 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
int out_w_block = UP_DIV(conv_param->output_w_, out_unit);
int out_h_block = UP_DIV(conv_param->output_h_, out_unit);
int output_count = out_w_block * out_h_block;
-int output_tile_count = UP_DIV(output_count, C12NUM);
+#ifdef ENABLE_ARM32
+int tile_num = 4;
+#else
+int tile_num = 12;
+#endif
+int output_tile_count = UP_DIV(output_count, tile_num);
int out_channel = conv_param->output_channel_;
int oc4 = UP_DIV(out_channel, C4NUM);
int oc8 = UP_DIV(out_channel, C8NUM);
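The 4-versus-12 tile choice above is repeated verbatim at every call site in this patch. As a minimal sketch, assuming the nnacl C4NUM/C12NUM constants, the idiom could be factored into a helper like the following (hypothetical, not part of this change):

/* Hypothetical helper: row-tile width of the fp32 GEMM kernels.
 * ARM32 computes a 4-row output tile; other targets use 12 rows. */
static inline int MatmulRowTile(void) {
#ifdef ENABLE_ARM32
  return C4NUM;  /* 4 */
#else
  return C12NUM; /* 12 */
#endif
}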
@@ -281,19 +286,19 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
float *tmp_out_data = buffer_list[2];
float *tmp_data = buffer_list[3];
float *col_buffer = buffer_list[4];
-int trans_input_offset = C12NUM * input_unit_square * ic4 * C4NUM;
-int gemm_out_offset = C12NUM * input_unit_square * oc8 * C8NUM;
+int trans_input_offset = tile_num * input_unit_square * ic4 * C4NUM;
+int gemm_out_offset = tile_num * input_unit_square * oc8 * C8NUM;
int tmp_data_offset = input_unit_square * C4NUM;
-int col_buffer_offset = C12NUM * ic4 * C4NUM;
+int col_buffer_offset = tile_num * ic4 * C4NUM;
// step 1 : filter transform (pre-processed offline)
// step 2 : input transform (online)
for (int b = 0; b < in_batch; b++) {
int in_batch_offset = b * ic4 * C4NUM * conv_param->input_h_ * conv_param->input_w_;
int tmp_out_batch_offset = b * out_w_block * out_h_block * out_unit * out_unit * oc4 * C4NUM;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) {
-int out_tile_index = thread_id * C12NUM;
-int cal_num = output_count - thread_id * C12NUM;
-cal_num = cal_num > C12NUM ? C12NUM : cal_num;
+int out_tile_index = thread_id * tile_num;
+int cal_num = output_count - thread_id * tile_num;
+cal_num = cal_num > tile_num ? tile_num : cal_num;
WinogradInputTransform(input_data + in_batch_offset, trans_input + task_id * trans_input_offset,
tmp_data + task_id * tmp_data_offset, cal_num, out_tile_index, out_w_block, conv_param,
in_func);
@@ -302,7 +307,11 @@ void ConvWinogardFp32(float *input_data, float *trans_weight, const float *bias_
float *dst_ptr = gemm_out + task_id * gemm_out_offset;
float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
for (int i = 0; i < input_unit_square; ++i) {
+#ifdef ENABLE_ARM32
+RowMajor2Col4Major(src_ptr + i * C4NUM * ic4 * C4NUM, tmp_col_ptr, C4NUM, ic4 * C4NUM);
+#else
RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM);
+#endif
MatMulOpt(tmp_col_ptr, trans_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0, ic4 * C4NUM,
cal_num, oc8 * C8NUM, input_unit_square, 2);
}
@@ -460,7 +469,12 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT);
int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT);
int output_count = out_w_block * out_h_block;
-int output_tile_count = UP_DIV(output_count, C12NUM);
+#ifdef ENABLE_ARM32
+int tile_num = 4;
+#else
+int tile_num = 12;
+#endif
+int output_tile_count = UP_DIV(output_count, tile_num);
const int input_unit_square = 4 * 4;
float *tile_buffer = buffer_list[0];
@@ -468,10 +482,10 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
float *tmp_dst_buffer = buffer_list[2];
float *nc4hw4_out = buffer_list[3];
float *col_buffer = buffer_list[4];
-int tile_buffer_offset = C12NUM * input_unit_square * ic4 * C4NUM;
+int tile_buffer_offset = tile_num * input_unit_square * ic4 * C4NUM;
int block_unit_buffer_offset = input_unit_square * C4NUM;
-int tmp_dst_buffer_offset = C12NUM * input_unit_square * oc8 * C8NUM;
-int col_buffer_offset = C12NUM * ic4 * C4NUM;
+int tmp_dst_buffer_offset = tile_num * input_unit_square * oc8 * C8NUM;
+int col_buffer_offset = tile_num * ic4 * C4NUM;
int input_batch = conv_param->input_batch_;
for (int batch = 0; batch < input_batch; batch++) {
@@ -479,8 +493,8 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
int nc4hw4_buffer_offset = batch * oc4 * C4NUM * conv_param->output_h_ * conv_param->output_w_;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
-int start_index = thread_id * C12NUM;
-int real_cal_num = (output_count - start_index) < C12NUM ? (output_count - start_index) : C12NUM;
+int start_index = thread_id * tile_num;
+int real_cal_num = (output_count - start_index) < tile_num ? (output_count - start_index) : tile_num;
Conv3x3Fp32InputTransform(input_data + in_batch_offset, tile_buffer + task_id * tile_buffer_offset,
block_unit_buffer + task_id * block_unit_buffer_offset, start_index, real_cal_num,
out_w_block, conv_param);
@@ -489,7 +503,11 @@ void Conv3x3Fp32(float *input_data, float *transed_weight, const float *bias_dat
float *tmp_col_ptr = col_buffer + task_id * col_buffer_offset;
float *dst_ptr = tmp_dst_buffer + task_id * tmp_dst_buffer_offset;
for (int i = 0; i < input_unit_square; ++i) {
+#ifdef ENABLE_ARM32
+RowMajor2Col4Major(src_ptr + i * C4NUM * ic4 * C4NUM, tmp_col_ptr, C4NUM, ic4 * C4NUM);
+#else
RowMajor2Col12Major(src_ptr + i * C12NUM * ic4 * C4NUM, tmp_col_ptr, C12NUM, ic4 * C4NUM);
+#endif
MatMulOpt(tmp_col_ptr, transed_weight + i * ic4 * C4NUM * oc8 * C8NUM, dst_ptr + i * C8NUM, NULL, 0,
ic4 * C4NUM, real_cal_num, oc8 * C8NUM, input_unit_square, 2);
}

@@ -40,7 +40,12 @@ int DeConvPostFp32C12x8(const float *src, float *tmp, const float *bias, float *
size_t kernel_plane = conv_param->kernel_w_ * conv_param->kernel_h_;
size_t output_plane = conv_param->output_w_ * conv_param->output_h_;
int oc8 = UP_ROUND(output_channel, C8NUM);
-int in_plane12 = UP_ROUND(input_plane, C12NUM);
+#ifdef ENABLE_ARM32
+int tile_num = 4;
+#else
+int tile_num = 12;
+#endif
+int in_plane12 = UP_ROUND(input_plane, tile_num);
int src_iw_stride = C8NUM;
int src_ih_stride = conv_param->input_w_ * C8NUM;
int src_kw_stride = in_plane12 * C8NUM;

@@ -16,6 +16,18 @@
#include "nnacl/fp32/matmul.h"
+void RowMajor2Row4Major(float *src_ptr, float *dst_ptr, int row, int col) {
+for (int r = 0; r < row; r++) {
+float *src = src_ptr + r * col;
+for (int c = 0; c < col; c++) {
+int cd8 = c / 4;
+int cm8 = c % 4;
+dst_ptr[cd8 * 4 * row + r * 4 + cm8] = src[c];
+}
+}
+return;
+}
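RowMajor2Row4Major packs the matrix into blocks of four columns, each block stored row-major with stride 4; the cd8/cm8 names are carried over from the 8-wide variant below but divide by 4. A one-line model of the destination layout, same math as the loop above (illustrative only):

/* Destination index of element (r, c): column block c/4, then row r
 * within the block, then lane c%4. */
size_t Row4MajorIndex(size_t r, size_t c, size_t row) {
  return (c / 4) * 4 * row + r * 4 + (c % 4);
}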
void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col) {
for (int r = 0; r < row; r++) {
float *src = src_ptr + r * col;
@@ -115,6 +127,61 @@ void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col)
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
"v30", "v31");
+#elif ENABLE_ARM32
+size_t stride = col * sizeof(float);
+asm volatile(
+"mov r10, %[src_c]\n"
+"mov r12, %[dst_c]\n"
+"vld1.32 {q0}, [r10], %[stride]\n"
+"vld1.32 {q3}, [r10], %[stride]\n"
+"vld1.32 {q10}, [r10], %[stride]\n"
+"vld1.32 {q13}, [r10], %[stride]\n"
+"vtrn.32 d0, d6\n"
+"vtrn.32 d1, d7\n"
+"vtrn.32 d20, d26\n"
+"vtrn.32 d21, d27\n"
+"vld1.32 {q1}, [r10], %[stride]\n"
+"vld1.32 {q8}, [r10], %[stride]\n"
+"vld1.32 {q11}, [r10], %[stride]\n"
+"vld1.32 {q14}, [r10], %[stride]\n"
+"vswp d1, d20\n"
+"vswp d7, d26\n"
+"vld1.32 {q2}, [r10], %[stride]\n"
+"vld1.32 {q9}, [r10], %[stride]\n"
+"vld1.32 {q12}, [r10], %[stride]\n"
+"vld1.32 {q15}, [r10], %[stride]\n"
+"vtrn.32 d2, d16\n"
+"vtrn.32 d3, d17\n"
+"vtrn.32 d22, d28\n"
+"vtrn.32 d23, d29\n"
+"vswp d3, d22\n"
+"vswp d17, d28\n"
+"vtrn.32 d4, d18\n"
+"vtrn.32 d5, d19\n"
+"vtrn.32 d24, d30\n"
+"vtrn.32 d25, d31\n"
+"vswp d5, d24\n"
+"vswp d19, d30\n"
+"vst1.32 {q0, q1}, [r12]!\n"
+"vst1.32 {q2, q3}, [r12]!\n"
+"vst1.32 {q8, q9}, [r12]!\n"
+"vst1.32 {q10, q11}, [r12]!\n"
+"vst1.32 {q12, q13}, [r12]!\n"
+"vst1.32 {q14, q15}, [r12]!\n"
+:
+: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
+: "r10", "r12", "q0", "q1", "q2", "q3", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15");
#else
for (int tr = 0; tr < C12NUM; tr++) {
for (int tc = 0; tc < C4NUM; tc++) {
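The ARM32 branch above transposes entirely in registers: vtrn.32 swaps lanes across 2x2 sub-blocks and vswp exchanges 64-bit halves, turning the twelve loaded rows into three stacked 4x4 transposed tiles per group of four columns. The same idiom with NEON intrinsics for a single 4x4 tile, as a sketch (hypothetical helper; stride is in elements here, where the assembly uses a byte stride):

#include <arm_neon.h>

static void Transpose4x4(const float *src, size_t src_stride, float *dst) {
  float32x4_t r0 = vld1q_f32(src + 0 * src_stride);
  float32x4_t r1 = vld1q_f32(src + 1 * src_stride);
  float32x4_t r2 = vld1q_f32(src + 2 * src_stride);
  float32x4_t r3 = vld1q_f32(src + 3 * src_stride);
  /* vtrn: val[0] holds the even lanes interleaved, val[1] the odd lanes. */
  float32x4x2_t t01 = vtrnq_f32(r0, r1);
  float32x4x2_t t23 = vtrnq_f32(r2, r3);
  /* Recombine low/high halves; each store below is one output column. */
  vst1q_f32(dst + 0, vcombine_f32(vget_low_f32(t01.val[0]), vget_low_f32(t23.val[0])));
  vst1q_f32(dst + 4, vcombine_f32(vget_low_f32(t01.val[1]), vget_low_f32(t23.val[1])));
  vst1q_f32(dst + 8, vcombine_f32(vget_high_f32(t01.val[0]), vget_high_f32(t23.val[0])));
  vst1q_f32(dst + 12, vcombine_f32(vget_high_f32(t01.val[1]), vget_high_f32(t23.val[1])));
}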
@@ -242,6 +309,75 @@ void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col)
return;
}
+void RowMajor2Col4Major(float *src_ptr, float *dst_ptr, size_t row, size_t col) {
+size_t row8 = row / C4NUM * C4NUM;
+size_t col4 = col / C4NUM * C4NUM;
+float *src_r = src_ptr;
+float *dst_r = dst_ptr;
+size_t ri = 0;
+for (; ri < row8; ri += C4NUM) {
+size_t ci = 0;
+for (; ci < col4; ci += C4NUM) {
+float *src_c = src_r + ci;
+float *dst_c = dst_r + ci * C4NUM;
+/* 4x4 row-major to col-major */
+#ifdef ENABLE_ARM32
+size_t stride = col * 4;
+asm volatile(
+"mov r10, %[src_c]\n"
+"mov r12, %[dst_c]\n"
+"vld1.32 {q0}, [r10], %[stride]\n"
+"vld1.32 {q1}, [r10], %[stride]\n"
+"vld1.32 {q2}, [r10], %[stride]\n"
+"vld1.32 {q3}, [r10], %[stride]\n"
+"vtrn.32 d0, d2\n"
+"vtrn.32 d1, d3\n"
+"vtrn.32 d4, d6\n"
+"vtrn.32 d5, d7\n"
+"vswp d1, d4\n"
+"vswp d3, d6\n"
+"vst1.32 {q0}, [r12]!\n"
+"vst1.32 {q1}, [r12]!\n"
+"vst1.32 {q2}, [r12]!\n"
+"vst1.32 {q3}, [r12]!\n"
+:
+: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ stride ] "r"(stride)
+: "r10", "r12", "q0", "q1", "q2", "q3");
+#else
+for (int tr = 0; tr < C4NUM; tr++) {
+for (int tc = 0; tc < C4NUM; tc++) {
+dst_c[tc * C4NUM + tr] = src_c[tr * col + tc];
+}
+}
+#endif
+}
+for (; ci < col; ci++) {
+float *src_c = src_r + ci;
+float *dst_c = dst_r + ci * C4NUM;
+for (size_t i = 0; i < C4NUM; i++) {
+dst_c[i] = src_c[i * col];
+}
+}
+src_r += C4NUM * col;
+dst_r += C4NUM * col;
+}
+for (; ri < row; ri++) {
+for (size_t i = 0; i < col; i++) {
+dst_r[i * C4NUM] = src_r[i];
+}
+src_r += col;
+dst_r += 1;
+}
+return;
+}
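A usage sketch for the new packer, mirroring the kernel call sites below; the destination needs UP_ROUND(row, C4NUM) * col floats, which matches the row_4_ * deep_ allocations made later in this patch (shapes here are made up):

void PackExample(void) {
  float a[6 * 5] = {0};                /* 6x5 row-major LHS */
  float packed[8 * 5];                 /* UP_ROUND(6, 4) = 8 rows of 4-row tiles */
  RowMajor2Col4Major(a, packed, 6, 5); /* tail rows interleave into the last tile */
}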
void MatrixUnPackUnit(const void *src, void *dst, size_t row, size_t col, size_t src_stride, size_t dst_stride,
size_t data_lenth) {
size_t copy_size = col * data_lenth;
@@ -418,6 +554,9 @@ void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActT
MatmulFloatNeon64Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc),
(int)(out_type == OutType_TileC8));
}
+#elif ENABLE_ARM32
+MatmulFloatNeon32Opt(a, b, c, bias, (int)act_type, deep, row, col, stride, (int)(out_type == OutType_Nhwc),
+(int)(out_type == OutType_TileC8));
#else
MatMul12x8(a, b, c, bias, act_type, deep, row, col, stride, out_type);
#endif
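Put together, the ARM32 call pattern is: pack the left operand into 4-row tiles, then let MatMulOpt fall through to the 32-bit kernel. A sketch under assumed shapes, with a pre-packed right operand; GemmSketch is a hypothetical name, and the act/out enum values mirror calls visible elsewhere in this diff:

/* row x deep times deep x col; b_pack is assumed already packed 8-wide. */
void GemmSketch(float *a, const float *b_pack, float *c, const float *bias,
                float *a_pack, int row, int deep, int col) {
#ifdef ENABLE_ARM32
  RowMajor2Col4Major(a, a_pack, row, deep);
#else
  RowMajor2Col12Major(a, a_pack, row, deep);
#endif
  MatMulOpt(a_pack, b_pack, c, bias, ActType_No, deep, row, col, col, OutType_Nhwc);
}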

@@ -29,8 +29,10 @@ extern "C" {
void MatMulOpt(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row,
int col, size_t stride, int out_type);
+void RowMajor2Row4Major(float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
void RowMajor2Row12Major(float *src_ptr, float *dst_ptr, int row, int col);
+void RowMajor2Col4Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
void RowMajor2Col12Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride);
@@ -40,6 +42,9 @@ void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bi
void MatmulFloatNeon64Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
int col, size_t stride, size_t write_nhwc, size_t write_c4);
void MatmulFloatNeon64OptRemain(const float *a, const float *b, float *c, int depth, int row, int col, size_t stride);
+#elif ENABLE_ARM32
+void MatmulFloatNeon32Opt(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
+int col, size_t stride, size_t write_nhwc, size_t write_c4);
#endif
#ifdef __cplusplus
}

@@ -1223,6 +1223,78 @@ void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int
: "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
"v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29",
"v30", "v31");
+#elif ENABLE_ARM32
+size_t srcStride = channel * sizeof(float);
+size_t dstStride = plane * sizeof(float);
+asm volatile(
+"mov r10, %[src_ptr]\n"
+"mov r12, %[dst_ptr]\n"
+"vld1.32 {q0, q1}, [r10], %[srcStride]\n"
+"vld1.32 {q2, q3}, [r10], %[srcStride]\n"
+"vtrn.32 d0, d4\n"
+"vtrn.32 d1, d5\n"
+"vtrn.32 d2, d6\n"
+"vtrn.32 d3, d7\n"
+"vld1.32 {q4, q5}, [r10], %[srcStride]\n"
+"vld1.32 {q6, q7}, [r10], %[srcStride]\n"
+"vtrn.32 d8, d12\n"
+"vtrn.32 d9, d13\n"
+"vtrn.32 d10, d14\n"
+"vtrn.32 d11, d15\n"
+"vld1.32 {q8, q9}, [r10], %[srcStride]\n"
+"vld1.32 {q10, q11}, [r10], %[srcStride]\n"
+"vswp d1, d8\n"
+"vswp d3, d10\n"
+"vswp d5, d12\n"
+"vswp d7, d14\n"
+"vtrn.32 d16, d20\n"
+"vtrn.32 d17, d21\n"
+"vtrn.32 d18, d22\n"
+"vtrn.32 d19, d23\n"
+"vld1.32 {q12, q13}, [r10], %[srcStride]\n"
+"vld1.32 {q14, q15}, [r10], %[srcStride]\n"
+"vtrn.32 d24, d28\n"
+"vtrn.32 d25, d29\n"
+"vtrn.32 d26, d30\n"
+"vtrn.32 d27, d31\n"
+"vswp d17, d24\n"
+"vswp d19, d26\n"
+"vswp d21, d28\n"
+"vswp d23, d30\n"
+"add r10, r12, #16\n"
+"vst1.32 {q0}, [r12], %[dstStride]\n"
+"vst1.32 {q8}, [r10], %[dstStride]\n"
+"vst1.32 {q2}, [r12], %[dstStride]\n"
+"vst1.32 {q10}, [r10], %[dstStride]\n"
+"vst1.32 {q4}, [r12], %[dstStride]\n"
+"vst1.32 {q12}, [r10], %[dstStride]\n"
+"vst1.32 {q6}, [r12], %[dstStride]\n"
+"vst1.32 {q14}, [r10], %[dstStride]\n"
+"vst1.32 {q1}, [r12], %[dstStride]\n"
+"vst1.32 {q9}, [r10], %[dstStride]\n"
+"vst1.32 {q3}, [r12], %[dstStride]\n"
+"vst1.32 {q11}, [r10], %[dstStride]\n"
+"vst1.32 {q5}, [r12], %[dstStride]\n"
+"vst1.32 {q13}, [r10], %[dstStride]\n"
+"vst1.32 {q7}, [r12], %[dstStride]\n"
+"vst1.32 {q15}, [r10], %[dstStride]\n"
+:
+:
+[ dst_ptr ] "r"(dst_ptr), [ src_ptr ] "r"(src_ptr), [ srcStride ] "r"(srcStride), [ dstStride ] "r"(dstStride)
+: "r10", "r12", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14",
+"q15");
#else
for (int tr = 0; tr < C8NUM; tr++) {
for (int tc = 0; tc < C8NUM; tc++) {
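The ARM32 branch here is an 8x8 in-register transpose: eight rows arrive as q-register pairs, vtrn.32 and vswp rearrange them, and "add r10, r12, #16" sets a second store pointer 16 bytes past the first, so each transposed row of eight floats is written as two quad stores with both pointers advancing by dstStride (plane * sizeof(float)).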

@@ -67,8 +67,13 @@ void WinogradInputTransform(const float *input_data, float *trans_input, float *
}
}
// input transform
+#ifdef ENABLE_ARM32
+int tile_num = 4;
+#else
+int tile_num = 12;
+#endif
int dst_ic4_offset = dst_plane_offset + ic * C4NUM;
-size_t dst_step = C12NUM * ic4 * C4NUM;
+size_t dst_step = tile_num * ic4 * C4NUM;
float *trans_input_ptr = trans_input + dst_ic4_offset;
func(tmp_data, trans_input_ptr, C4NUM, dst_step);
// GeneralInputTransformUnit(tmp_data, trans_input_ptr, matrix_b, matrix_bt, C4NUM, dst_step, input_unit);
@@ -331,8 +336,13 @@ void Conv3x3Fp32InputTransform(const float *input_data, float *trans_input, floa
}
// input transform
+#ifdef ENABLE_ARM32
+int tile_num = 4;
+#else
+int tile_num = 12;
+#endif
int dst_ic4_offset = dst_plane_offset + ic * C4NUM;
-size_t dst_step = C12NUM * ic4 * C4NUM;
+size_t dst_step = tile_num * ic4 * C4NUM;
float *trans_input_ptr = trans_input + dst_ic4_offset;
Conv3x3Fp32InputUnit(tmp_data, trans_input_ptr, dst_step);
}

@@ -26,15 +26,13 @@ if (PLATFORM_ARM64)
set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC})
endif()
-#[[
if (PLATFORM_ARM32)
# assembly
-file(GLOB ASSEMBLY_SRC nnacl/assembly/arm32/*.s
-nnacl/assembly/arm32/*.S
+file(GLOB ASSEMBLY_SRC ${CMAKE_CURRENT_SOURCE_DIR}/../../../../nnacl/assembly/arm32/*.s
+${CMAKE_CURRENT_SOURCE_DIR}/../../../../nnacl/assembly/arm32/*.S
)
set_property(SOURCE ${ASSEMBLY_SRC} PROPERTY LANGUAGE C)
set(KERNEL_SRC ${KERNEL_SRC} ${ASSEMBLY_SRC})
endif()
-]]
add_library(cpu_kernel_mid_ OBJECT ${KERNEL_SRC} ${TRAIN_KERNEL_SRC})
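Removing the #[[ ... ]] block comment re-enables the arm32 assembly build, now with glob paths anchored at CMAKE_CURRENT_SOURCE_DIR; set_property(... PROPERTY LANGUAGE C) routes the .s/.S files through the C compiler driver, which preprocesses and assembles them in one step.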

@@ -59,6 +59,7 @@ void Convolution1x1CPUKernel::InitConv1x1MatmulParam() {
matmul_param_->row_ = conv_param_->output_h_ * conv_param_->output_w_;
matmul_param_->col_ = conv_param_->output_channel_;
matmul_param_->deep_ = conv_param_->input_channel_;
+matmul_param_->row_4_ = UP_ROUND(matmul_param_->row_, C4NUM);
matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM);
matmul_param_->col_8_ = UP_ROUND(matmul_param_->col_, C8NUM);
matmul_param_->act_type_ = conv_param_->act_type_;
@@ -120,8 +121,11 @@ void Convolution1x1CPUKernel::Pre1x1Trans(float *src_input, float *src_output) {
} else {
input_ptr_ = src_input;
}
+#ifdef ENABLE_ARM32
+RowMajor2Col4Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
+#else
RowMajor2Col12Major(input_ptr_, pack_input_, matmul_param_->row_, matmul_param_->deep_);
+#endif
return;
}
@@ -169,8 +173,13 @@ int Convolution1x1CPUKernel::Run() {
auto src_in = reinterpret_cast<float *>(in_tensors_[0]->MutableData());
auto src_out = reinterpret_cast<float *>(out_tensors_[0]->MutableData());
+#ifdef ENABLE_ARM32
+pack_input_ =
+reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float)));
+#else
pack_input_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_12_ * matmul_param_->deep_ * sizeof(float)));
+#endif
if (pack_input_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc pack_input_ error!";
return RET_MEMORY_FAILED;

@@ -95,7 +95,12 @@ int Convolution3x3CPUKernel::InitTmpBuffer() {
const int k_plane = 16;
MS_ASSERT(ctx_->allocator != nullptr);
-size_t tile_buffer_size = thread_count_ * C12NUM * C16NUM * ic4 * C4NUM * sizeof(float);
+#ifdef ENABLE_ARM32
+int tile_num = 4;
+#else
+int tile_num = 12;
+#endif
+size_t tile_buffer_size = thread_count_ * tile_num * C16NUM * ic4 * C4NUM * sizeof(float);
tile_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tile_buffer_size));
if (tile_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc tile buffer failed.";
@@ -109,14 +114,14 @@ int Convolution3x3CPUKernel::InitTmpBuffer() {
return RET_ERROR;
}
-size_t tmp_dst_buffer_size = thread_count_ * C12NUM * k_plane * oC8 * C8NUM * sizeof(float);
+size_t tmp_dst_buffer_size = thread_count_ * tile_num * k_plane * oC8 * C8NUM * sizeof(float);
tmp_dst_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tmp_dst_buffer_size));
if (tmp_dst_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc tmp_dst_buffer_ failed.";
return RET_ERROR;
}
-size_t col_buffer_size = thread_count_ * C12NUM * C4NUM * ic4 * sizeof(float);
+size_t col_buffer_size = thread_count_ * tile_num * C4NUM * ic4 * sizeof(float);
col_buffer_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(col_buffer_size));
if (col_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc col_buffer_ failed.";

@@ -150,9 +150,14 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
int oc4 = UP_DIV(channel_out, C4NUM);
int oc8 = UP_DIV(channel_out, C8NUM);
int ic4 = UP_DIV(conv_param_->input_channel_, C4NUM);
+#ifdef ENABLE_ARM32
+int tile_num = 4;
+#else
+int tile_num = 12;
+#endif
MS_ASSERT(ctx_->allocator != nullptr);
-size_t tile_buffer_size = thread_count_ * C12NUM * input_unit_ * input_unit_ * ic4 * C4NUM * sizeof(float);
+size_t tile_buffer_size = thread_count_ * tile_num * input_unit_ * input_unit_ * ic4 * C4NUM * sizeof(float);
trans_input_ = reinterpret_cast<float *>(ctx_->allocator->Malloc(tile_buffer_size));
if (trans_input_ == nullptr) {
MS_LOG(ERROR) << "malloc trans_input_ failed.";
@@ -160,7 +165,7 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
}
gemm_out_ = reinterpret_cast<float *>(
-ctx_->allocator->Malloc(thread_count_ * C12NUM * input_unit_ * input_unit_ * oc8 * C8NUM * sizeof(float)));
+ctx_->allocator->Malloc(thread_count_ * tile_num * input_unit_ * input_unit_ * oc8 * C8NUM * sizeof(float)));
if (gemm_out_ == nullptr) {
MS_LOG(ERROR) << "malloc gemm_out_ failed.";
return RET_ERROR;
@@ -184,7 +189,7 @@ int ConvolutionWinogradCPUKernel::InitTmpBuffer() {
}
col_buffer_ =
-reinterpret_cast<float *>(ctx_->allocator->Malloc(thread_count_ * C12NUM * ic4 * C4NUM * sizeof(float)));
+reinterpret_cast<float *>(ctx_->allocator->Malloc(thread_count_ * tile_num * ic4 * C4NUM * sizeof(float)));
if (col_buffer_ == nullptr) {
MS_LOG(ERROR) << "malloc col_buffer_ failed.";
return RET_ERROR;

@@ -85,6 +85,7 @@ int DeConvolutionCPUKernel::InitParam() {
matmul_param_->deep_ = conv_param_->input_channel_;
matmul_param_->col_ = conv_param_->output_channel_ * kernel_plane_;
matmul_param_->row_12_ = UP_ROUND(matmul_param_->row_, C12NUM);
+matmul_param_->row_4_ = UP_ROUND(matmul_param_->row_, C4NUM);
matmul_param_->col_8_ = UP_ROUND(conv_param_->output_channel_, C8NUM) * kernel_plane_;
thread_count_ = MSMIN(op_parameter_->thread_num_, UP_DIV(conv_param_->output_channel_, C8NUM));
@@ -112,10 +113,17 @@ int DeConvolutionCPUKernel::DoDeconv(int task_id) {
return RET_OK;
}
+#ifdef ENABLE_ARM32
+auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_4_;
+MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
+tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_4_, oc * C8NUM * kernel_plane_,
+matmul_param_->col_, OutType_C8);
+#else
auto tmp_buffer = tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_12_;
MatMulOpt(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
tmp_buffer, nullptr, ActType_No, matmul_param_->deep_, matmul_param_->row_12_, oc * C8NUM * kernel_plane_,
matmul_param_->col_, OutType_C8);
+#endif
DeConvPostFp32C12x8(tmp_buffer, pack_output_ + task_id * thread_stride_ * C8NUM * output_plane_,
reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id * C8NUM,
@@ -159,15 +167,25 @@ int DeConvolutionCPUKernel::InitRunBuf() {
return RET_NULL_PTR;
}
+#ifdef ENABLE_ARM32
+tmp_buffer_ =
+reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->col_8_ * sizeof(float)));
+#else
tmp_buffer_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_12_ * matmul_param_->col_8_ * sizeof(float)));
+#endif
if (tmp_buffer_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 Malloc tmp_buffer_ error!";
return RET_NULL_PTR;
}
+#ifdef ENABLE_ARM32
+pack_input_ =
+reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_4_ * matmul_param_->deep_ * sizeof(float)));
+#else
pack_input_ =
reinterpret_cast<float *>(ctx_->allocator->Malloc(matmul_param_->row_12_ * matmul_param_->deep_ * sizeof(float)));
+#endif
if (pack_input_ == nullptr) {
MS_LOG(ERROR) << "deconv Malloc pack_input_ error!";
return RET_ERROR;

@@ -49,6 +49,7 @@ int FullconnectionCPUKernel::ReSize() {
fc_param_->row_12_ = UP_ROUND(fc_param_->row_, C12NUM);
fc_param_->col_8_ = UP_ROUND(fc_param_->col_, C8NUM);
+fc_param_->row_4_ = UP_ROUND(fc_param_->row_, C4NUM);
thread_count_ = MSMIN(thread_count_, UP_DIV(fc_param_->col_8_, 8));
thread_stride_ = UP_DIV(UP_DIV(fc_param_->col_8_, 8), thread_count_);
@@ -59,11 +60,19 @@ int FullconnectionCPUKernel::ReSize() {
memcpy(bias_ptr_, in_tensors_[2]->MutableData(), fc_param_->col_ * sizeof(float));
}
+#ifdef ENABLE_ARM32
+a_c12_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->row_4_ * fc_param_->deep_ * sizeof(float)));
+if (a_c12_ptr_ == nullptr) {
+return RET_MEMORY_FAILED;
+}
+memset(a_c12_ptr_, 0, fc_param_->row_4_ * fc_param_->deep_ * sizeof(float));
+#else
a_c12_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->row_12_ * fc_param_->deep_ * sizeof(float)));
if (a_c12_ptr_ == nullptr) {
return RET_MEMORY_FAILED;
}
memset(a_c12_ptr_, 0, fc_param_->row_12_ * fc_param_->deep_ * sizeof(float));
+#endif
b_r8_ptr_ = reinterpret_cast<float *>(malloc(fc_param_->col_8_ * fc_param_->deep_ * sizeof(float)));
if (b_r8_ptr_ == nullptr) {
@@ -87,7 +96,11 @@ }
}
void FullconnectionCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) {
+#ifdef ENABLE_ARM32
+RowMajor2Col4Major(src_ptr, a_c12_ptr_, fc_param_->row_, fc_param_->deep_);
+#else
RowMajor2Col12Major(src_ptr, a_c12_ptr_, fc_param_->row_, fc_param_->deep_);
+#endif
}
void FullconnectionCPUKernel::InitMatrixB(float *src_ptr, float *dst_ptr) {

@@ -62,17 +62,27 @@ int MatmulCPUKernel::ReSize() {
params_->row_ = c_shape[c_shape.size() - 2];
params_->col_ = c_shape[c_shape.size() - 1];
params_->deep_ = params_->a_transpose_ ? a_shape[a_shape.size() - 2] : a_shape[a_shape.size() - 1];
+params_->row_4_ = UP_ROUND(params_->row_, C4NUM);
params_->row_12_ = UP_ROUND(params_->row_, C12NUM);
params_->col_8_ = UP_ROUND(params_->col_, 8);
thread_count_ = MSMIN(thread_count_, UP_DIV(params_->col_8_, 8));
thread_stride_ = UP_DIV(UP_DIV(params_->col_8_, 8), thread_count_);
+#ifdef ENABLE_ARM32
+a_c12_ptr_ = reinterpret_cast<float *>(malloc(params_->batch * params_->row_4_ * params_->deep_ * sizeof(float)));
+if (a_c12_ptr_ == nullptr) {
+FreeTmpBuffer();
+return RET_MEMORY_FAILED;
+}
+memset(a_c12_ptr_, 0, params_->row_4_ * params_->deep_ * sizeof(float));
+#else
a_c12_ptr_ = reinterpret_cast<float *>(malloc(params_->batch * params_->row_12_ * params_->deep_ * sizeof(float)));
if (a_c12_ptr_ == nullptr) {
FreeTmpBuffer();
return RET_MEMORY_FAILED;
}
memset(a_c12_ptr_, 0, params_->row_12_ * params_->deep_ * sizeof(float));
+#endif
b_r8_ptr_ = reinterpret_cast<float *>(malloc(params_->batch * params_->col_8_ * params_->deep_ * sizeof(float)));
if (b_r8_ptr_ == nullptr) {
@@ -106,12 +116,21 @@ int MatmulCPUKernel::ReSize() {
void MatmulCPUKernel::InitMatrixA(float *src_ptr, float *dst_ptr) {
for (int i = 0; i < params_->batch; i++) {
float *src = src_ptr + i * params_->deep_ * params_->row_;
+#ifdef ENABLE_ARM32
+float *dst = dst_ptr + i * params_->deep_ * params_->row_4_;
+if (params_->a_transpose_) {
+RowMajor2Row4Major(src, dst, params_->deep_, params_->row_);
+} else {
+RowMajor2Col4Major(src, dst, params_->row_, params_->deep_);
+}
+#else
float *dst = dst_ptr + i * params_->deep_ * params_->row_12_;
if (params_->a_transpose_) {
RowMajor2Row12Major(src, dst, params_->deep_, params_->row_);
} else {
RowMajor2Col12Major(src, dst, params_->row_, params_->deep_);
}
+#endif
}
return;
}
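A note on the transposed case: with a_transpose_ set, A is stored deep x row, and RowMajor2Row4Major over that storage yields the same 4-row-tile layout that RowMajor2Col4Major produces from the untransposed matrix; both place logical element (i, k) at offset (i / 4) * 4 * deep + k * 4 + (i % 4) within a batch, so the GEMM kernel consumes either result identically.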

@@ -79,7 +79,7 @@ if (PLATFORM_ARM64)
${TEST_ASSEMBLY_SRC}
)
endif()
-#[[
if (PLATFORM_ARM32)
# assembly
file(GLOB TEST_ASSEMBLY_SRC
@@ -91,7 +91,7 @@ if (PLATFORM_ARM32)
${TEST_ASSEMBLY_SRC}
)
endif()
-]]
if (ENABLE_FP16)
file(GLOB KERNEL_OP_FP16_SRC
${LITE_DIR}/src/runtime/kernel/arm/fp16/*.cc
