!4399 matmul optimize

Merge pull request !4399 from ling/fc
pull/4399/MERGE
Committed by mindspore-ci-bot via Gitee
commit 9afdf9be9b

@@ -31,10 +31,6 @@ Convolution1x1CPUKernel::~Convolution1x1CPUKernel() {
     free(pack_input_);
     pack_input_ = nullptr;
   }
-  if (pack_output_ != nullptr) {
-    free(pack_output_);
-    pack_output_ = nullptr;
-  }
   if (pre_trans_input_ && input_ptr_ != nullptr) {
     free(input_ptr_);
     input_ptr_ = nullptr;
@@ -112,13 +108,6 @@ int Convolution1x1CPUKernel::InitConv1x1Param() {
     return RET_MEMORY_FAILED;
   }
   memset(pack_input_, 0, matmul_param_->row_8_ * matmul_param_->deep_ * sizeof(float));
 
-  pack_output_ = reinterpret_cast<float *>(malloc(matmul_param_->row_8_ * matmul_param_->col_8_ * sizeof(float)));
-  if (pack_output_ == nullptr) {
-    MS_LOG(ERROR) << "Conv1x1 Malloc pack_output_ error!";
-    return RET_MEMORY_FAILED;
-  }
-  memset(pack_output_, 0, matmul_param_->row_8_ * matmul_param_->col_8_ * sizeof(float));
-
   return RET_OK;
 }
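
These two hunks drop the pack_output_ staging buffer entirely: with the write_nhwc path added to MatMul further down, GEMM results land directly in the kernel's NHWC output tensor, so conv1x1 no longer mallocs, zeroes, and frees a row_8_ x col_8_ scratch area. A rough stand-alone sketch of the per-kernel saving (the shape is illustrative, not taken from this PR):

#include <stdio.h>

int main(void) {
  /* hypothetical 1x1 conv: N*H*W = 1*56*56 rows, 256 output channels */
  int row = 1 * 56 * 56, col = 256;
  int row8 = (row + 7) / 8 * 8, col8 = (col + 7) / 8 * 8;
  size_t saved = (size_t)row8 * col8 * sizeof(float); /* old pack_output_ size */
  printf("staging buffer removed: %zu bytes (%.2f MiB)\n",
         saved, saved / (1024.0 * 1024.0));
  return 0;
}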
@@ -157,7 +146,7 @@ int Convolution1x1CPUKernel::Init() {
 }
 
 int Convolution1x1CPUKernel::DoConv1x1(int task_id) {
-  int cur_oc = MSMIN(thread_stride_, matmul_param_->col_8_ - task_id * thread_stride_);
+  int cur_oc = MSMIN(thread_stride_, matmul_param_->col_ - task_id * thread_stride_);
   if (cur_oc <= 0) {
     return RET_OK;
   }
@@ -165,23 +154,12 @@ int Convolution1x1CPUKernel::DoConv1x1(int task_id) {
   }
 
   auto bias = (bias_data_ == nullptr) ? nullptr : reinterpret_cast<float *>(bias_data_) + thread_stride_ * task_id;
   MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * matmul_param_->deep_,
-         pack_output_ + task_id * thread_stride_ * matmul_param_->row_8_, bias, matmul_param_->act_type_,
-         matmul_param_->deep_, matmul_param_->row_8_, cur_oc);
+         output_ptr_ + task_id * thread_stride_, bias, matmul_param_->act_type_, matmul_param_->deep_,
+         matmul_param_->row_, cur_oc, matmul_param_->col_, true);
   return RET_OK;
 }
 
-int Convolution1x1CPUKernel::DoConv1x1Post(int task_id) {
-  int cur_oc = MSMIN(thread_stride_, matmul_param_->col_ - task_id * thread_stride_);
-  if (cur_oc <= 0) {
-    return RET_OK;
-  }
-  float *src = pack_output_ + task_id * thread_stride_ * matmul_param_->row_8_;
-  float *dst = output_ptr_ + task_id * thread_stride_;
-  Row8x8Major2RowMajor(src, dst, matmul_param_->row_, cur_oc, matmul_param_->col_);
-  return RET_OK;
-}
-
 int Convolution1x1Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
   auto conv1x1 = reinterpret_cast<Convolution1x1CPUKernel *>(cdata);
   auto error_code = conv1x1->DoConv1x1(task_id);
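
This hunk fuses what used to be two passes over the output: a packed-output MatMul followed by DoConv1x1Post, whose Row8x8Major2RowMajor call only remapped indices. The fusion is safe because, for every valid (r, c), the unpack step read a fixed source index in the col8x8 buffer; the new write_nhwc store simply writes to the destination index directly. A tiny stand-alone check of that index identity (shape chosen for this note):

#include <stdio.h>

int main(void) {
  int row = 5, col = 10, stride = 10;  /* illustrative; stride == col here */
  int row8 = (row + 7) / 8 * 8;        /* 8 */
  int col8 = (col + 7) / 8 * 8;        /* 16 */
  float packed[8 * 16];                /* row8 * col8 floats, col8x8-major */
  float dst[5 * 10];
  for (int i = 0; i < row8 * col8; i++) packed[i] = (float)i;

  /* what Row8x8Major2RowMajor effectively did for each valid element */
  for (int r = 0; r < row; r++)
    for (int c = 0; c < col; c++)
      dst[r * stride + c] = packed[(c / 8) * row8 * 8 + r * 8 + (c % 8)];

  /* e.g. output element (1, 9) came from packed index 1*64 + 8 + 1 = 73 */
  printf("dst[19] = %.0f\n", dst[1 * stride + 9]);
  return 0;
}

With the MatMul8x8 rewrite below, each value is computed once and stored at r * stride + c immediately, so the second pass (and the buffer between the passes) disappears.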
@@ -192,12 +170,6 @@ int Convolution1x1Run(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
   return RET_OK;
 }
 
-int Convolution1x1Post(int task_id, LiteParallelGroupEnv *penv, void *cdata) {
-  auto conv1x1 = reinterpret_cast<Convolution1x1CPUKernel *>(cdata);
-  conv1x1->DoConv1x1Post(task_id);
-  return RET_OK;
-}
-
 int Convolution1x1CPUKernel::Run() {
   auto prepare_ret = Prepare();
   if (prepare_ret != RET_OK) {
@@ -216,8 +188,6 @@ int Convolution1x1CPUKernel::Run() {
       MS_LOG(ERROR) << "conv1x1 strassen error error_code[" << error_code << "]";
       return RET_ERROR;
     }
-
-    LiteBackendParallelLaunch(Convolution1x1Post, this, thread_count_);
   }
   return RET_OK;
 }
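
With the direct store, the separate Convolution1x1Post launch is unnecessary: each DoConv1x1 task already writes a disjoint stripe of output columns, [task_id * thread_stride_, task_id * thread_stride_ + cur_oc), so one parallel pass covers the whole output. A toy version of that split (the stride formula is this note's guess at UP_DIV-style blocking, not copied from the kernel):

#include <stdio.h>
#define MSMIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
  int col = 20, thread_count = 3;
  int col8_blocks = (col + 7) / 8;                                         /* 3 */
  int thread_stride = (col8_blocks + thread_count - 1) / thread_count * 8; /* 8 */
  for (int task_id = 0; task_id < thread_count; task_id++) {
    int cur_oc = MSMIN(thread_stride, col - task_id * thread_stride);
    if (cur_oc <= 0) continue;
    printf("task %d writes columns [%d, %d)\n", task_id,
           task_id * thread_stride, task_id * thread_stride + cur_oc);
  }
  return 0;
}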

@@ -46,7 +46,6 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel {
 
  public:
   int DoConv1x1(int task_id);
-  int DoConv1x1Post(int task_id);
 
  private:
   int InitConv1x1Param();
@@ -61,7 +60,6 @@ class Convolution1x1CPUKernel : public ConvolutionBaseCPUKernel {
   int thread_stride_ = 0;
   float *weight_ptr_ = nullptr;
   float *pack_input_ = nullptr;
-  float *pack_output_ = nullptr;
   float *input_ptr_ = nullptr;
   float *output_ptr_ = nullptr;
 };

@@ -152,7 +152,7 @@ int DeConvolutionCPUKernel::DoDeconv(int task_id) {
   }
 
   MatMul(pack_input_, weight_ptr_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->deep_,
          tmp_buffer_ + task_id * thread_stride_ * C8NUM * kernel_plane_ * matmul_param_->row_8_, nullptr, ActType_No,
-         matmul_param_->deep_, matmul_param_->row_8_, oc * C8NUM * kernel_plane_);
+         matmul_param_->deep_, matmul_param_->row_8_, oc * C8NUM * kernel_plane_, matmul_param_->col_, false);
   return RET_OK;
 }

@@ -104,7 +104,7 @@ int FullconnectionCPUKernel::DoMatmul(int task_id) {
 
   MatMul(a_c8_ptr_, b_r8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->deep_,
          c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * fc_param_->row_8_,
          bias_ptr_ + task_id * thread_stride_ * C8NUM, fc_param_->act_type_, fc_param_->deep_, fc_param_->row_8_,
-         cur_oc * 8);
+         cur_oc * 8, 0, false);
   return RET_OK;
 }

@@ -77,7 +77,7 @@ int MatmulCPUKernel::RunImpl(int task_id) {
   }
 
   auto cur_b = b_r8_ptr_ + task_id * thread_stride_ * C8NUM * params_->deep_;
   auto cur_c = c_r8x8_ptr_ + task_id * thread_stride_ * C8NUM * params_->row_8_;
-  MatMul(a_c8_ptr_, cur_b, cur_c, NULL, ActType_No, params_->deep_, params_->row_8_, cur_oc * 8);
+  MatMul(a_c8_ptr_, cur_b, cur_c, NULL, ActType_No, params_->deep_, params_->row_8_, cur_oc * 8, 0, false);
   return RET_OK;
 }
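
Note how the callers split by output layout: deconvolution, fullconnection, and the generic matmul kernel still produce the col8x8-packed buffer (their post-processing is unchanged), so they pass write_nhwc = false; stride is then unused, and fullconnection and matmul pass 0. Only conv1x1 opts into the direct NHWC store. A toy version of the addressing switch this adds inside MatMul8x8 (helper name invented for this note):

#include <stdio.h>

static size_t out_index(int r, int c, int row8, int stride, int write_nhwc) {
  return write_nhwc ? (size_t)r * stride + c                      /* NHWC      */
                    : (size_t)(c / 8) * row8 * 8 + r * 8 + c % 8; /* col8x8    */
}

int main(void) {
  int row8 = 16, stride = 20;
  printf("packed (3,9) -> %zu\n", out_index(3, 9, row8, 0, 0));      /* 153 */
  printf("nhwc   (3,9) -> %zu\n", out_index(3, 9, row8, stride, 1)); /* 69  */
  return 0;
}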

@@ -640,7 +640,7 @@ IndirectGemmStart:
     add x15, x15, x7
     str s30, [x15]
     add x0, x0, #4
-    b WriteEnd
+    b WriteEndHalf
 Write2:
     dup s17, v16.s[1]
     stp s16, s17, [x15]
@@ -666,7 +666,7 @@ IndirectGemmStart:
     dup s31, v30.s[1]
     stp s30, s31, [x15]
    add x0, x0, #8
-    b WriteEnd
+    b WriteEndHalf
 Write3:
     add x17, x15, #8
     dup s17, v16.s[1]

@@ -221,34 +221,57 @@ void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col
   return;
 }
 
-void MatMul8x8(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row_8_,
-               int col_8_) {
-  /* col8-major * row8-major => col8x8-major */
-  for (int row = 0; row < row_8_; row++) {
-    for (int col = 0; col < col_8_; col++) {
-      int r8div = row / 8, r8mod = row % 8;
-      int c8div = col / 8, c8mod = col % 8;
-      size_t ci = c8div * row_8_ * 8 + row * 8 + c8mod;
-      float value = 0;
-      for (int d = 0; d < deep; d++) {
-        size_t ai = r8div * deep * 8 + d * 8 + r8mod;
-        size_t bi = c8div * deep * 8 + d * 8 + c8mod;
-        value = value + a[ai] * b[bi];
-      }
-      if (bias != NULL) value += bias[col];
-      if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
-      if (act_type != ActType_No) value = MSMAX(0.0f, value);
-      c[ci] = value;
-    }
-  }
+void MatMul8x8(const float *a, const float *b, float *dst, const float *bias, ActType act_type, int deep, int row,
+               int col, int stride, bool write_nhwc) {
+  if (write_nhwc) {
+    /* col8-major * row8-major => col-major */
+    for (int r = 0; r < row; r++) {
+      for (int c = 0; c < col; c++) {
+        int r8div = r / 8, r8mod = r % 8;
+        int c8div = c / 8, c8mod = c % 8;
+        size_t ci = r * stride + c;
+        float value = 0;
+        for (int d = 0; d < deep; d++) {
+          size_t ai = r8div * deep * 8 + d * 8 + r8mod;
+          size_t bi = c8div * deep * 8 + d * 8 + c8mod;
+          value = value + a[ai] * b[bi];
+        }
+        if (bias != NULL) value += bias[c];
+        if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
+        if (act_type != ActType_No) value = MSMAX(0.0f, value);
+        dst[ci] = value;
+      }
+    }
+  } else {
+    /* col8-major * row8-major => col8x8-major */
+    int col_8 = UP_ROUND(col, C8NUM);
+    int row_8 = UP_ROUND(row, C8NUM);
+    for (int r = 0; r < row_8; r++) {
+      for (int c = 0; c < col_8; c++) {
+        int r8div = r / 8, r8mod = r % 8;
+        int c8div = c / 8, c8mod = c % 8;
+        size_t ci = c8div * row_8 * 8 + r * 8 + c8mod;
+        float value = 0;
+        for (int d = 0; d < deep; d++) {
+          size_t ai = r8div * deep * 8 + d * 8 + r8mod;
+          size_t bi = c8div * deep * 8 + d * 8 + c8mod;
+          value = value + a[ai] * b[bi];
+        }
+        if (bias != NULL) value += bias[c];
+        if (act_type == ActType_Relu6) value = MSMIN(6.0f, value);
+        if (act_type != ActType_No) value = MSMAX(0.0f, value);
+        dst[ci] = value;
+      }
+    }
+  }
   return;
 }
 
-void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row_8_,
-            int col_8_) {
+void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int deep, int row, int col,
+            int stride, bool write_nhwc) {
 #ifdef __aarch64__
-  MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row_8_, col_8_);
+  MatmulFloatNeon64(a, b, c, bias, (int)act_type, deep, row, col, stride, write_nhwc);
 #else
-  MatMul8x8(a, b, c, bias, act_type, deep, row_8_, col_8_);
+  MatMul8x8(a, b, c, bias, act_type, deep, row, col, stride, write_nhwc);
 #endif
 }
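
The rewrite is easier to review with the operand layouts spelled out: a is col8-major (rows packed in panels of 8, interleaved along deep), b is row8-major (columns packed in panels of 8), and write_nhwc selects between storing straight into a row-major/NHWC destination (ci = r * stride + c) and the old col8x8-major scratch layout. Below is a self-contained sketch of those layouts, checked against a naive matmul; pack_col8 and pack_row8 are helpers written for this note, not MindSpore functions, and bias/activation are omitted.

#include <stdio.h>
#include <stdlib.h>
#include <math.h>

#define C8 8
#define UP8(x) (((x) + C8 - 1) / C8 * C8)

/* A (row x deep), row-major -> col8-major: rows in panels of 8 */
static void pack_col8(const float *src, float *dst, int row, int deep) {
  for (int r = 0; r < row; r++)
    for (int d = 0; d < deep; d++)
      dst[(r / C8) * deep * C8 + d * C8 + r % C8] = src[r * deep + d];
}

/* B (deep x col), row-major -> row8-major: columns in panels of 8 */
static void pack_row8(const float *src, float *dst, int deep, int col) {
  for (int d = 0; d < deep; d++)
    for (int c = 0; c < col; c++)
      dst[(c / C8) * deep * C8 + d * C8 + c % C8] = src[d * col + c];
}

int main(void) {
  int row = 5, deep = 3, col = 10; /* deliberately not multiples of 8 */
  float a[5 * 3], b[3 * 10], ref[5 * 10], out[5 * 10];
  for (int i = 0; i < row * deep; i++) a[i] = (float)(i % 7) - 3.0f;
  for (int i = 0; i < deep * col; i++) b[i] = (float)(i % 5) - 2.0f;

  float *pa = calloc((size_t)UP8(row) * deep, sizeof(float)); /* zero-padded panels */
  float *pb = calloc((size_t)UP8(col) * deep, sizeof(float));
  pack_col8(a, pa, row, deep);
  pack_row8(b, pb, deep, col);

  /* the write_nhwc branch of MatMul8x8, bias and activation dropped */
  for (int r = 0; r < row; r++)
    for (int c = 0; c < col; c++) {
      float v = 0.0f;
      for (int d = 0; d < deep; d++)
        v += pa[(r / C8) * deep * C8 + d * C8 + r % C8] *
             pb[(c / C8) * deep * C8 + d * C8 + c % C8];
      out[r * col + c] = v; /* stride == col: plain row-major / NHWC */
    }

  /* naive reference on the unpacked operands */
  for (int r = 0; r < row; r++)
    for (int c = 0; c < col; c++) {
      float v = 0.0f;
      for (int d = 0; d < deep; d++) v += a[r * deep + d] * b[d * col + c];
      ref[r * col + c] = v;
    }

  for (int i = 0; i < row * col; i++)
    if (fabsf(out[i] - ref[i]) > 1e-5f) {
      printf("mismatch at %d\n", i);
      return 1;
    }
  printf("packed matmul matches the naive reference\n");
  free(pa);
  free(pb);
  return 0;
}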

@@ -26,13 +26,14 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
-void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col);
+void MatMul(const float *a, const float *b, float *c, const float *bias, ActType act_type, int depth, int row, int col,
+            int stride, bool write_nhwc);
 void RowMajor2Row8Major(float *src_ptr, float *dst_ptr, int row, int col);
 void RowMajor2Col8Major(float *src_ptr, float *dst_ptr, size_t row, size_t col);
 void Row8x8Major2RowMajor(float *src_ptr, float *dst_ptr, size_t row, size_t col, size_t stride);
 #ifdef __aarch64__
 void MatmulFloatNeon64(const float *a, const float *b, float *c, const float *bias, int act_type, int depth, int row,
-                       int col);
+                       int col, size_t stride, bool write_nhwc);
 #endif
 #ifdef __cplusplus
 }
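
For review convenience, the new contract in one place, with each parameter annotated; the comments are this note's reading of MatMul8x8 above, not upstream documentation, and the enum is a stand-in so the snippet compiles on its own:

#include <stdbool.h>

typedef enum { ActType_No, ActType_Relu, ActType_Relu6 } ActType; /* stand-in for the nnacl enum */

void MatMul(const float *a,    /* lhs, col8-major: rows packed in panels of 8      */
            const float *b,    /* rhs, row8-major: columns packed in panels of 8   */
            float *c,          /* dst: row-major/NHWC if write_nhwc, else col8x8   */
            const float *bias, /* per-output-channel bias, may be NULL             */
            ActType act_type,  /* ActType_No / ActType_Relu / ActType_Relu6        */
            int depth, int row, int col,
            int stride,        /* dst row stride, read only when write_nhwc        */
            bool write_nhwc);  /* true: store directly into the NHWC output        */

One asymmetry worth noting: the C declaration takes int stride while MatmulFloatNeon64 takes size_t stride; the callers in this PR pass non-negative values, so the difference appears benign here.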

@@ -370,26 +370,35 @@ TEST_F(TestConv1x1Fp32, Conv1x1Test2) {
   conv1x1->Run();
   CompareOutputData(reinterpret_cast<float *>(outputs_[0]->Data()), correct, total_size, 0.0001);
 
-  /* running warm up */
-  for (int i = 0; i < 0; i++) {
-    conv1x1->Run();
-  }
-
-  /* running time cost */
-  int loop_count = 1;
-  auto time_start = mindspore::lite::GetTimeUs();
-  for (int i = 0; i < loop_count; i++) {
-    conv1x1->Run();
-  }
-  auto time_end = mindspore::lite::GetTimeUs();
-  auto cost = time_end - time_start;
-  uint64_t time_avg = cost / loop_count;
-  printf("1x1 average time : %f ms\n", time_avg / 1000.0f);
-
-  delete conv_param;
-  delete conv1x1;
-  for (auto t : inputs_) delete t;
-  for (auto t : outputs_) delete t;
-  free(correct);
+  auto ptr = reinterpret_cast<float *>(outputs_[0]->Data());
+  bool first = true;
+  for (int i = 0; i < total_size; i++) {
+    if (fabs(ptr[i] - correct[i]) > 0.001 && first) {
+      printf("%d %f %f\n", i, ptr[i], correct[i]);
+      first = false;
+    }
+  }
+
+  // /* running warm up */
+  // for (int i = 0; i < 0; i++) {
+  //   conv1x1->Run();
+  // }
+  //
+  // /* running time cost */
+  // int loop_count = 1;
+  // auto time_start = mindspore::lite::GetTimeUs();
+  // for (int i = 0; i < loop_count; i++) {
+  //   conv1x1->Run();
+  // }
+  // auto time_end = mindspore::lite::GetTimeUs();
+  // auto cost = time_end - time_start;
+  // uint64_t time_avg = cost / loop_count;
+  // printf("1x1 average time : %f ms\n", time_avg / 1000.0f);
+  //
+  // delete conv_param;
+  // delete conv1x1;
+  // for (auto t : inputs_) delete t;
+  // for (auto t : outputs_) delete t;
+  // free(correct);
 }
 }  // namespace mindspore
