diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc
index 735a0e881b..0934479802 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc
@@ -30,7 +30,29 @@ using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_Gather;
 
 namespace mindspore::kernel {
+// Frees the fp16 input copy cached by Init(), which Run() deliberately leaves allocated.
+GatherFp16CPUKernel::~GatherFp16CPUKernel() {
+  if (input_data_) {
+    context_->allocator->Free(input_data_);
+    input_data_ = nullptr;
+  }
+}
+
 int GatherFp16CPUKernel::Init() {
+  auto input_tensor = in_tensors_.at(0);
+  // A float32 input that already carries data at Init time is converted to fp16 once and cached in
+  // input_data_; const_input_ makes Run() reuse that buffer instead of converting and freeing per run.
+  if (input_tensor->data_type() == kNumberTypeFloat32 && input_tensor->data_c() != nullptr) {
+    input_data_ =
+      reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
+    if (input_data_ == nullptr) {
+      MS_LOG(ERROR) << "Malloc input_data_ failed.";
+      return lite::RET_ERROR;
+    }
+    const_input_ = true;
+    Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum());
+  }
+
   if (!InferShapeDone()) {
     return RET_OK;
   }
@@ -128,11 +150,13 @@ int GatherFp16CPUKernel::Run() {
     MS_LOG(ERROR) << "AssignIndicesData failed, error_code[" << ret << "]";
     return ret;
   }
-  auto input_tensor = in_tensors_.at(0);
-  if (input_tensor->data_type() == kNumberTypeFloat32) {
-    input_data_ =
-      reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
-    Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum());
+  if (!const_input_) {
+    auto input_tensor = in_tensors_.at(0);
+    if (input_tensor->data_type() == kNumberTypeFloat32) {
+      input_data_ =
+        reinterpret_cast<float16_t *>(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t)));
+      Float32ToFloat16(reinterpret_cast<float *>(input_tensor->data_c()), input_data_, input_tensor->ElementsNum());
+    }
   }
   ret = ParallelLaunch(this->context_->thread_pool_, GatherRunFp16, this, op_parameter_->thread_num_);
   if (ret != RET_OK) {
@@ -142,7 +166,7 @@ int GatherFp16CPUKernel::Run() {
     context_->allocator->Free(indices_data_);
     indices_data_ = nullptr;
   }
-  if (input_data_) {
+  if (!const_input_ && input_data_) {
     context_->allocator->Free(input_data_);
     input_data_ = nullptr;
   }
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.h
index 5092277749..23f27223ae 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.h
@@ -30,7 +30,7 @@ class GatherFp16CPUKernel : public LiteKernel {
   GatherFp16CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                       const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
       : LiteKernel(parameter, inputs, outputs, ctx) {}
-  ~GatherFp16CPUKernel() = default;
+  ~GatherFp16CPUKernel() override;
 
   int Init() override;
   int ReSize() override;
@@ -42,6 +42,7 @@ class GatherFp16CPUKernel : public LiteKernel {
   int *indices_data_ = nullptr;
   int AssignIndicesData(bool isIndicesInt32, int indices_num, lite::Tensor *indices_tensor);
   float16_t *input_data_ = nullptr;
+  bool const_input_ = false;  // true once Init() has cached an fp16 copy of the input in input_data_
 };
 }  // namespace mindspore::kernel
 