From 8870ca716c65a854f6eebafd31ec7c0006e549d2 Mon Sep 17 00:00:00 2001
From: zhanyuan
Date: Wed, 23 Sep 2020 16:06:16 +0800
Subject: [PATCH] Fix the bug of fp16 fc tensor pack

---
 .../kernel/arm/fp16/fullconnection_fp16.cc    | 21 ++++++++++++++++++-
 .../kernel/arm/fp16/fullconnection_fp16.h     |  1 +
 2 files changed, 21 insertions(+), 1 deletion(-)

diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc
index b1c4bf44ec..bfded12ce2 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.cc
@@ -78,7 +78,15 @@ int FullconnectionFP16CPUKernel::ReSize() {
   }
   memset(b_pack_ptr_, 0, fc_param_->col_8_ * fc_param_->deep_ * sizeof(float16_t));
 
-  InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->data_c()), b_pack_ptr_);
+  fc_param_->b_const_ = (in_tensors_[1]->data_c() != nullptr);
+  if (fc_param_->b_const_) {
+    if (in_tensors_[1]->data_type() == kNumberTypeFloat32) {
+      InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->data_c()), b_pack_ptr_);
+    } else {
+      InitMatrixB(reinterpret_cast<float16_t *>(in_tensors_[1]->data_c()), b_pack_ptr_);
+    }
+  }
+
   if (in_tensors_.size() == 3) {
     bias_ptr_ = reinterpret_cast<float16_t *>(ctx_->allocator->Malloc(fc_param_->col_8_ * sizeof(float16_t)));
     if (bias_ptr_ == nullptr) {
@@ -108,6 +116,10 @@ void FullconnectionFP16CPUKernel::InitMatrixB(float *b_ptr, float16_t *b_pack_pt
   RowMajor2Col8MajorFp16(reinterpret_cast<void *>(b_ptr), b_pack_ptr, fc_param_->col_, fc_param_->deep_, true);
 }
 
+void FullconnectionFP16CPUKernel::InitMatrixB(float16_t *b_ptr, float16_t *b_pack_ptr) {
+  RowMajor2Col8MajorFp16(reinterpret_cast<void *>(b_ptr), b_pack_ptr, fc_param_->col_, fc_param_->deep_, false);
+}
+
 int FullconnectionFP16CPUKernel::Init() {
   if (!InferShapeDone()) {
     return RET_OK;
@@ -156,6 +168,13 @@ int FullconnectionFP16CPUKernel::Run() {
   } else {
     InitMatrixA(reinterpret_cast<float16_t *>(in_tensors_[0]->data_c()), a_pack_ptr_);
   }
+  if (!fc_param_->b_const_) {
+    if (in_tensors_[1]->data_type() == kNumberTypeFloat32) {
+      InitMatrixB(reinterpret_cast<float *>(in_tensors_[1]->data_c()), b_pack_ptr_);
+    } else {
+      InitMatrixB(reinterpret_cast<float16_t *>(in_tensors_[1]->data_c()), b_pack_ptr_);
+    }
+  }
   ParallelLaunch(this->context_->thread_pool_, FcFP16Run, this, thread_count_);
   if (out_tensor->data_type() == kNumberTypeFloat32) {
     auto size = out_tensor->ElementsNum();
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.h
index 408ffe70ad..ea4cdd7717 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/fullconnection_fp16.h
@@ -42,6 +42,7 @@ class FullconnectionFP16CPUKernel : public FullconnectionBaseCPUKernel {
   void InitMatrixA(float *a_ptr, float16_t *a_pack_ptr);
   void InitMatrixA(float16_t *a_ptr, float16_t *a_pack_ptr);
   void InitMatrixB(float *b_ptr, float16_t *b_pack_ptr);
+  void InitMatrixB(float16_t *b_ptr, float16_t *b_pack_ptr);
   void FreeTmpBuffer();
 
  private:
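
A minimal sketch of the pattern this patch applies, for readers outside the MindSpore Lite tree. The Tensor, DType, PackB, and PackWeights names below are hypothetical stand-ins, not the project's API; the point is only that the packing overload is chosen from the weight tensor's element type, whereas the old code always reinterpreted the buffer as float and so mispacked weights that were already stored as fp16.

#include <cstdint>
#include <cstdio>

// Placeholder storage type for an fp16 payload; the kernel itself uses float16_t.
using fp16_bits = uint16_t;

enum class DType { kFloat32, kFloat16 };

struct Tensor {
  DType dtype;        // element type of the raw buffer
  const void *data;   // raw weight buffer
};

// fp32 source: values are narrowed while being repacked (stand-in for a real fp32->fp16 convert).
void PackB(const float *src, fp16_bits *dst, int n) {
  for (int i = 0; i < n; ++i) dst[i] = static_cast<fp16_bits>(src[i]);
}

// fp16 source: the payload is repacked as-is, with no conversion.
void PackB(const fp16_bits *src, fp16_bits *dst, int n) {
  for (int i = 0; i < n; ++i) dst[i] = src[i];
}

// Dispatch on the element type, mirroring the data_type() check the patch adds in ReSize()/Run().
void PackWeights(const Tensor &b, fp16_bits *packed, int n) {
  if (b.dtype == DType::kFloat32) {
    PackB(static_cast<const float *>(b.data), packed, n);
  } else {
    PackB(static_cast<const fp16_bits *>(b.data), packed, n);  // the overload the patch introduces
  }
}

int main() {
  const fp16_bits fp16_weights[4] = {0x3C00, 0x4000, 0x4200, 0x4400};  // 1.0, 2.0, 3.0, 4.0 in fp16
  Tensor b{DType::kFloat16, fp16_weights};
  fp16_bits packed[4] = {};
  PackWeights(b, packed, 4);
  std::printf("first packed halfword: 0x%04X\n", static_cast<unsigned>(packed[0]));  // 0x3C00, not a float reinterpretation
  return 0;
}

In the patch itself the same dispatch appears twice: in ReSize() for constant weights (b_const_ set when data_c() is non-null) and again in Run() for weights whose data only becomes available at execution time. The new InitMatrixB(float16_t *, float16_t *) overload passes false as the last argument to RowMajor2Col8MajorFp16, where the existing float overload passes true, presumably the flag that marks the source as fp32 and requests conversion while packing.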