diff --git a/mindspore/lite/nnacl/base/cast_base.h b/mindspore/lite/nnacl/base/cast_base.h index ee62db13df..411f2926a3 100644 --- a/mindspore/lite/nnacl/base/cast_base.h +++ b/mindspore/lite/nnacl/base/cast_base.h @@ -71,6 +71,11 @@ inline void Int32ToInt64(const int32_t *input, int64_t *output, int number) { } } +inline void Float32ToInt16(const float *input, int16_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int16_t)input[i]; + } +} #ifdef __cplusplus } #endif diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.cc index 50783d362b..fbe5b255f4 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/cast_fp32.cc @@ -85,6 +85,9 @@ int CastCPUKernel::DoCast(int thread_id) { } else if (input_data_type == kNumberTypeInt32 && output_data_type == kNumberTypeInt64) { Int32ToInt64(reinterpret_cast(input->data_c()) + offset, reinterpret_cast(output_data) + offset, data_num); + } else if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt16) { + Float32ToInt16(reinterpret_cast(input->data_c()) + offset, + reinterpret_cast(output_data) + offset, data_num); } else { MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type; return RET_ERROR;