From 0bebd24edae8dd7953f9114b1bbd065f64e489d3 Mon Sep 17 00:00:00 2001 From: zhanyuan Date: Thu, 18 Feb 2021 17:41:48 +0800 Subject: [PATCH] [MSLITE] Fix the bug of fp16 gather working in multithreading --- .../lite/src/runtime/kernel/arm/fp16/gather_fp16.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc index 90694a0375..6be60ca84d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/gather_fp16.cc @@ -91,9 +91,6 @@ int GatherFp16CPUKernel::DoGather(int task_id) { auto thread_stride = stride * task_id; int8_t *int8_in = nullptr; if (input_tensor->data_type() == kNumberTypeFloat32) { - input_data_ = - reinterpret_cast(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t))); - Float32ToFloat16(reinterpret_cast(input_tensor->data_c()), input_data_, input_tensor->ElementsNum()); int8_in = reinterpret_cast(input_data_); } else if (input_tensor->data_type() == kNumberTypeFloat16) { int8_in = reinterpret_cast(input_tensor->data_c()); @@ -127,7 +124,12 @@ int GatherFp16CPUKernel::Run() { MS_LOG(ERROR) << "AssignIndicesData failed, error_code[" << ret << "]"; return ret; } - + auto input_tensor = in_tensors_.at(0); + if (input_tensor->data_type() == kNumberTypeFloat32) { + input_data_ = + reinterpret_cast(context_->allocator->Malloc(input_tensor->ElementsNum() * sizeof(float16_t))); + Float32ToFloat16(reinterpret_cast(input_tensor->data_c()), input_data_, input_tensor->ElementsNum()); + } ret = ParallelLaunch(this->context_->thread_pool_, GatherRunFp16, this, op_parameter_->thread_num_); if (ret != RET_OK) { MS_LOG(ERROR) << "Gather function error error_code[" << ret << "]"; @@ -140,7 +142,6 @@ int GatherFp16CPUKernel::Run() { context_->allocator->Free(input_data_); input_data_ = nullptr; } - return ret; }