From 27cc6d6c17ace3e57cdf51bc0d72c857ce9cb1a6 Mon Sep 17 00:00:00 2001 From: yangruoqi713 Date: Wed, 4 Nov 2020 16:12:01 +0800 Subject: [PATCH] [MSLITE][Develop] fix bug of arm cpu fp16 op cast --- mindspore/lite/nnacl/fp16/cast_fp16.c | 12 ++++ mindspore/lite/nnacl/fp16/cast_fp16.h | 2 + .../src/runtime/kernel/arm/fp16/cast_fp16.cc | 71 ++++++++++++++----- .../src/runtime/kernel/arm/fp16/pad_fp16.cc | 2 - mindspore/lite/src/scheduler.cc | 6 +- 5 files changed, 70 insertions(+), 23 deletions(-) diff --git a/mindspore/lite/nnacl/fp16/cast_fp16.c b/mindspore/lite/nnacl/fp16/cast_fp16.c index 235a7b3be2..d973b2268a 100644 --- a/mindspore/lite/nnacl/fp16/cast_fp16.c +++ b/mindspore/lite/nnacl/fp16/cast_fp16.c @@ -27,6 +27,18 @@ void Uint8ToFloat16(const uint8_t *input, float16_t *output, int number) { } } +void Float16ToInt32(const float16_t *input, int32_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int32_t)input[i]; + } +} + +void Float16ToInt64(const float16_t *input, int64_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int64_t)input[i]; + } +} + #ifndef ENABLE_ARM64 void Float32ToFloat16(const float *input, float16_t *output, int number) { for (int i = 0; i < number; ++i) { diff --git a/mindspore/lite/nnacl/fp16/cast_fp16.h b/mindspore/lite/nnacl/fp16/cast_fp16.h index be942074a8..c64c30f683 100644 --- a/mindspore/lite/nnacl/fp16/cast_fp16.h +++ b/mindspore/lite/nnacl/fp16/cast_fp16.h @@ -24,6 +24,8 @@ extern "C" { #endif void BoolToFloat16(const bool *input, float16_t *output, int number); void Uint8ToFloat16(const uint8_t *input, float16_t *output, int number); +void Float16ToInt32(const float16_t *input, int32_t *output, int number); +void Float16ToInt64(const float16_t *input, int64_t *output, int number); void Float32ToFloat16(const float *input, float16_t *output, int number); void Float16ToFloat32(const float16_t *input, float *output, int number); #ifdef __cplusplus diff --git 
a/mindspore/lite/src/runtime/kernel/arm/fp16/cast_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/cast_fp16.cc index 3811b3e14c..70b4b08eaf 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/cast_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/cast_fp16.cc @@ -65,25 +65,58 @@ int CastFp16CPUKernel::DoCast(int thread_id) { } auto offset = thread_id * stride_; - auto output_data = out_tensors_.at(0)->MutableData(); - switch (input->data_type()) { - case kNumberTypeBool: - BoolToFloat16(reinterpret_cast<bool *>(input->MutableData()) + offset, - reinterpret_cast<float16_t *>(output_data) + offset, data_num); - case kNumberTypeUInt8: - Uint8ToFloat16(reinterpret_cast<uint8_t *>(input->MutableData()) + offset, - reinterpret_cast<float16_t *>(output_data) + offset, data_num); - case kNumberTypeFloat32: - Float32ToFloat16(reinterpret_cast<float *>(input->MutableData()) + offset, - reinterpret_cast<float16_t *>(output_data) + offset, data_num); - break; - case kNumberTypeFloat16: - Float16ToFloat32(reinterpret_cast<float16_t *>(input->MutableData()) + offset, - reinterpret_cast<float *>(output_data) + offset, data_num); - break; - default: - MS_LOG(ERROR) << "Unsupported input data type " << input->data_type(); - return RET_ERROR; + auto output = out_tensors_.at(0); + auto output_data = output->data_c(); + auto input_data_type = input->data_type(); + auto output_data_type = output->data_type(); + + if (input_data_type == kNumberTypeFloat16) { + switch (output_data_type) { + case kNumberTypeInt64: + Float16ToInt64(reinterpret_cast<float16_t *>(input->data_c()) + offset, + reinterpret_cast<int64_t *>(output_data) + offset, data_num); + break; + case kNumberTypeInt32: + Float16ToInt32(reinterpret_cast<float16_t *>(input->data_c()) + offset, + reinterpret_cast<int32_t *>(output_data) + offset, data_num); + break; + case kNumberTypeFloat32: + Float16ToFloat32(reinterpret_cast<float16_t *>(input->MutableData()) + offset, + reinterpret_cast<float *>(output_data) + offset, data_num); + break; + case kNumberTypeFloat16: + memcpy(reinterpret_cast<float16_t *>(output_data) + offset, + reinterpret_cast<float16_t *>(input->data_c()) + offset, 
data_num * sizeof(float16_t)); + break; + default: + MS_LOG(ERROR) << "Unsupported output data type " << output_data_type; + return RET_ERROR; + } + } else if (input_data_type == kNumberTypeFloat32) { + switch (output_data_type) { + case kNumberTypeInt64: + Float32ToInt64(reinterpret_cast<float *>(input->data_c()) + offset, + reinterpret_cast<int64_t *>(output_data) + offset, data_num); + break; + case kNumberTypeInt32: + Float32ToInt32(reinterpret_cast<float *>(input->data_c()) + offset, + reinterpret_cast<int32_t *>(output_data) + offset, data_num); + break; + case kNumberTypeFloat32: + memcpy(reinterpret_cast<float *>(output_data) + offset, reinterpret_cast<float *>(input->data_c()) + offset, + data_num * sizeof(float)); + break; + case kNumberTypeFloat16: + Float32ToFloat16(reinterpret_cast<float *>(input->MutableData()) + offset, + reinterpret_cast<float16_t *>(output_data) + offset, data_num); + break; + default: + MS_LOG(ERROR) << "Unsupported output data type " << output_data_type; + return RET_ERROR; + } + } else { + MS_LOG(ERROR) << "Unsupported input data type " << input_data_type; + return RET_ERROR; } return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/pad_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/pad_fp16.cc index a3fe9404b9..699c3f31ae 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/pad_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/pad_fp16.cc @@ -94,6 +94,4 @@ kernel::LiteKernel *CpuPadFp16KernelCreator(const std::vector<lite::Tensor *> &i } return kernel; } - -REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Pad, CpuPadFp16KernelCreator) } // namespace mindspore::kernel diff --git a/mindspore/lite/src/scheduler.cc b/mindspore/lite/src/scheduler.cc index e5a9efdac1..38b736d544 100644 --- a/mindspore/lite/src/scheduler.cc +++ b/mindspore/lite/src/scheduler.cc @@ -304,7 +304,8 @@ TypeId Scheduler::GetFirstFp32Fp16OrInt8Type(const std::vector<Tensor *> &in_ten return dtype; } } - return kNumberTypeFloat32; + MS_ASSERT(in_tensors.size() > 0); + return in_tensors[0]->data_type(); } void 
Scheduler::SetKernelTensorDataType(kernel::LiteKernel *kernel) { @@ -346,7 +347,8 @@ kernel::SubGraphType Scheduler::GetKernelSubGraphType(kernel::LiteKernel *kernel if (desc.data_type == kNumberTypeFloat16) { return kernel::kCpuFP16SubGraph; } else if (desc.data_type == kNumberTypeFloat32 || desc.data_type == kNumberTypeInt8 || - desc.data_type == kNumberTypeInt32 || desc.data_type == kNumberTypeBool) { + desc.data_type == kNumberTypeInt32 || desc.data_type == kNumberTypeBool || + desc.data_type == kNumberTypeUInt8) { return kernel::kCpuFP32SubGraph; } }