From 4b8c9da7e7f446158b058c1cef9f809c63a72258 Mon Sep 17 00:00:00 2001 From: chenjianping Date: Wed, 12 Aug 2020 15:37:04 +0800 Subject: [PATCH] cast support fp32->int --- mindspore/lite/src/ops/ops.h | 2 +- .../lite/src/runtime/kernel/arm/fp32/cast.cc | 25 +++++++++++++++---- .../lite/src/runtime/kernel/arm/fp32/crop.cc | 5 ++-- .../src/runtime/kernel/arm/fp32/softmax.cc | 4 +++ .../src/runtime/kernel/arm/fp32/softmax.h | 2 +- .../src/runtime/kernel/arm/nnacl/fp32/cast.cc | 6 +++++ .../src/runtime/kernel/arm/nnacl/fp32/cast.h | 1 + 7 files changed, 36 insertions(+), 9 deletions(-) diff --git a/mindspore/lite/src/ops/ops.h b/mindspore/lite/src/ops/ops.h index 44b62f1642..6da3911374 100644 --- a/mindspore/lite/src/ops/ops.h +++ b/mindspore/lite/src/ops/ops.h @@ -37,7 +37,7 @@ constexpr uint32_t kNHWC_w_index = 2; constexpr uint32_t kNHWC_c_index = 3; constexpr uint32_t kDimension_4d = 4; -const std::set kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt32}; +const std::set kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeFloat32}; class Primitive { public: diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc index 9a0d5e2655..58b431c22d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/cast.cc @@ -65,17 +65,32 @@ int CastCPUKernel::DoCast(int thread_id) { } auto offset = thread_id * stride_; - auto output_data = reinterpret_cast(out_tensors_.at(0)->Data()); - switch (input->data_type()) { + auto output = out_tensors_.at(0); + auto output_data = output->Data(); + auto input_data_type = input->data_type(); + auto output_data_type = output->data_type(); + if (output_data_type != kNumberTypeFloat32) { + if (input_data_type == kNumberTypeFloat32 && output_data_type == kNumberTypeInt32) { + Float32ToInt32(reinterpret_cast(input->Data()) + offset, + reinterpret_cast(output_data) + offset, data_num); + } else { + MS_LOG(ERROR) << "Unsupport datatype from " << input_data_type << " to " << output_data_type; + return RET_ERROR; + } + } else { + switch (input_data_type) { case kNumberTypeUInt8: - Uint8ToFloat32(reinterpret_cast(input->Data()) + offset, output_data + offset, data_num); + Uint8ToFloat32(reinterpret_cast(input->Data()) + offset, + reinterpret_cast(output_data) + offset, data_num); break; case kNumberTypeInt32: - Int32ToFloat32(reinterpret_cast(input->Data()) + offset, output_data + offset, data_num); + Int32ToFloat32(reinterpret_cast(input->Data()) + offset, + reinterpret_cast(output_data) + offset, data_num); break; default: - MS_LOG(ERROR) << "Unsupport input data type " << input->data_type(); + MS_LOG(ERROR) << "Unsupport input data type " << input_data_type; return RET_ERROR; + } } return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/crop.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/crop.cc index a045e796b7..57214ad08a 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/crop.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/crop.cc @@ -47,8 +47,9 @@ int CropCPUKernel::CropParallelRun(int thread_id) { auto output = out_tensors_[0]; float *input_data = reinterpret_cast(input->Data()); float *output_data = reinterpret_cast(output->Data()); - Crop4D(input_data, output_data, input->shape().data(), output->shape().data(), - reinterpret_cast(op_parameter_)); + auto param = reinterpret_cast(op_parameter_); + param->thread_id_ = thread_id; + Crop4D(input_data, output_data, input->shape().data(), output->shape().data(), param); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.cc index 2b813e7125..f7bde07c1c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.cc @@ -65,6 +65,10 @@ int SoftmaxCPUKernel::ReSize() { free(sum_data_); } sum_data_ = reinterpret_cast(malloc(out_plane_size * in_plane_size * sizeof(float))); + if (sum_data_ == nullptr) { + MS_LOG(ERROR) << "malloc data for softmax fail!"; + return RET_ERROR; + } memset(sum_data_, 0, out_plane_size * in_plane_size * sizeof(float)); return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h index 515535a328..6ca9e56d3e 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/softmax.h @@ -27,7 +27,7 @@ class SoftmaxCPUKernel : public SoftmaxBaseCPUKernel { SoftmaxCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::Context *ctx, const lite::Primitive *primitive) - : SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx, primitive) {} + : SoftmaxBaseCPUKernel(parameter, inputs, outputs, ctx, primitive), sum_data_(nullptr) {} ~SoftmaxCPUKernel() override { if (sum_data_ != nullptr) { free(sum_data_); diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/cast.cc b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/cast.cc index 9b9d326a50..4a8af888af 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/cast.cc +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/cast.cc @@ -40,6 +40,12 @@ void Int32ToFloat32(const int32_t *input, float *output, int number) { } } +void Float32ToInt32(const float *input, int32_t *output, int number) { + for (int i = 0; i < number; ++i) { + output[i] = (int32_t)input[i]; + } +} + #ifdef ENABLE_FP16 void Float32ToFloat16(const float *input, float16_t *output, int number) { for (int i = 0; i < number; ++i) { diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/cast.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/cast.h index 9bb2df36e1..3c744a6703 100644 --- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/cast.h +++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp32/cast.h @@ -32,6 +32,7 @@ void Uint8ToFloat32(const uint8_t *input, float *output, int number); void Uint8ToInt8(const uint8_t *input, int8_t *output, int number); void Int8ToUint8(const int8_t *input, uint8_t *output, int number); void Int32ToFloat32(const int32_t *input, float *output, int number); +void Float32ToInt32(const float *input, int32_t *output, int number); #ifdef ENABLE_FP16 void Float32ToFloat16(const float *input, float16_t *output, int number); void Float16ToFloat32(const float16_t *input, float *output, int number);