From 63b4695bffd84285ba241861d9b2aa0218a1aac6 Mon Sep 17 00:00:00 2001
From: gongdaguo
Date: Fri, 27 Nov 2020 14:52:21 +0800
Subject: [PATCH] fix range

---
 mindspore/lite/nnacl/fp32/range_fp32.c        | 13 +++--
 mindspore/lite/nnacl/fp32/range_fp32.h        |  3 +-
 mindspore/lite/src/ops/range.cc               | 36 ++++++++++++--
 .../src/runtime/kernel/arm/fp32/range_fp32.cc | 47 +++++++++++++------
 .../src/runtime/kernel/arm/fp32/range_fp32.h  |  3 ++
 5 files changed, 77 insertions(+), 25 deletions(-)

diff --git a/mindspore/lite/nnacl/fp32/range_fp32.c b/mindspore/lite/nnacl/fp32/range_fp32.c
index 8234688700..e1097a2b8c 100644
--- a/mindspore/lite/nnacl/fp32/range_fp32.c
+++ b/mindspore/lite/nnacl/fp32/range_fp32.c
@@ -16,9 +16,14 @@
 
 #include "nnacl/fp32/range_fp32.h"
 
-void Range(float *output_ptr, int start, int limit, int delta) {
-  size_t index = 0;
-  for (size_t i = start; i < limit; i += delta) {
-    output_ptr[index++] = (float)(i);
+void Range(float *output_ptr, float start, float delta, int nums) {
+  for (int i = 0; i < nums; ++i, start += delta) {
+    output_ptr[i] = start;
+  }
+}
+
+void RangeInt(int *output_ptr, int start, int delta, int nums) {
+  for (int i = 0; i < nums; ++i, start += delta) {
+    output_ptr[i] = start;
   }
 }
diff --git a/mindspore/lite/nnacl/fp32/range_fp32.h b/mindspore/lite/nnacl/fp32/range_fp32.h
index 5e28d59fbd..712c4d4e08 100644
--- a/mindspore/lite/nnacl/fp32/range_fp32.h
+++ b/mindspore/lite/nnacl/fp32/range_fp32.h
@@ -30,7 +30,8 @@ typedef struct RangeParameter {
 #ifdef __cplusplus
 extern "C" {
 #endif
-void Range(float *output_ptr, int start, int limit, int delta);
+void Range(float *output_ptr, float start, float delta, int nums);
+void RangeInt(int *output_ptr, int start, int delta, int nums);
 #ifdef __cplusplus
 }
 #endif
diff --git a/mindspore/lite/src/ops/range.cc b/mindspore/lite/src/ops/range.cc
index d4b55cbc8a..baaa1c9956 100644
--- a/mindspore/lite/src/ops/range.cc
+++ b/mindspore/lite/src/ops/range.cc
@@ -64,7 +64,11 @@ int Range::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
   auto output = outputs_.front();
   MS_ASSERT(output != nullptr);
 
-  output->set_data_type(mindspore::kNumberTypeFloat32);
+  if (inputs_.size() == 3) {
+    output->set_data_type(input->data_type());
+  } else {
+    output->set_data_type(mindspore::kNumberTypeInt32);
+  }
   output->set_format(input->format());
   if (!infer_flag()) {
     return RET_OK;
@@ -72,14 +76,36 @@ int Range::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outpu
 
   int shape_size = 0;
   if (inputs_.size() == 3) {
-    shape_size = -1;
+    if ((inputs_.at(0)->data_c() == nullptr) || (inputs_.at(1)->data_c() == nullptr) ||
+        (inputs_.at(2)->data_c() == nullptr)) {
+      return RET_INFER_INVALID;
+    }
+    switch (inputs_.at(0)->data_type()) {
+      case kNumberTypeInt:
+      case kNumberTypeInt32: {
+        auto start = *reinterpret_cast<int *>(inputs_.at(0)->data_c());
+        auto limit = *reinterpret_cast<int *>(inputs_.at(1)->data_c());
+        auto delta = *reinterpret_cast<int *>(inputs_.at(2)->data_c());
+        shape_size = std::max(static_cast<int>(std::ceil(static_cast<float>(limit - start) / delta)), 0);
+      } break;
+      case kNumberTypeFloat32:
+      case kNumberTypeFloat: {
+        auto start = *reinterpret_cast<float *>(inputs_.at(0)->data_c());
+        auto limit = *reinterpret_cast<float *>(inputs_.at(1)->data_c());
+        auto delta = *reinterpret_cast<float *>(inputs_.at(2)->data_c());
+        shape_size = std::max(static_cast<int>(std::ceil(static_cast<float>(limit - start) / delta)), 0);
+      } break;
+      default: {
+        MS_LOG(ERROR) << "Range has unsupported dataType: " << inputs_.at(0)->data_type();
+        return RET_INFER_ERR;
+      }
+    }
   } else {
     shape_size = std::ceil(static_cast<float>(GetLimit() - GetStart()) / GetDelta());
   }
 
-  std::vector<int> in_shape;
-  in_shape.push_back(shape_size);
-  output->set_shape(in_shape);
+  std::vector<int> in_shape = {shape_size};
+  output->set_shape(in_shape);
   return RET_OK;
 }
 }  // namespace lite
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/range_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/range_fp32.cc
index 9ca46cbedb..f309fbafd6 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/range_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/range_fp32.cc
@@ -27,29 +27,44 @@ using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_Range;
 
 namespace mindspore::kernel {
-int RangeCPUKernel::Init() { return RET_OK; }
+int RangeCPUKernel::Init() {
+  if (!InferShapeDone()) {
+    return RET_OK;
+  }
+  return ReSize();
+}
 
-int RangeCPUKernel::ReSize() { return RET_OK; }
+int RangeCPUKernel::ReSize() {
+  if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat16 ||
+      in_tensors_[0]->data_type() == kNumberTypeFloat) {
+    data_type_ = kDataTypeFloat;
+  } else {
+    data_type_ = kDataTypeInt;
+  }
+  return RET_OK;
+}
 
 int RangeCPUKernel::Run() {
-  size_t start = (reinterpret_cast<RangeParameter *>(op_parameter_))->start_;
-  size_t limit = (reinterpret_cast<RangeParameter *>(op_parameter_))->limit_;
-  size_t delta = (reinterpret_cast<RangeParameter *>(op_parameter_))->delta_;
   if (in_tensors_.size() == 3) {
-    if ((in_tensors_.at(0)->data_type() == mindspore::kNumberTypeInt32) &&
-        (in_tensors_.at(1)->data_type() == mindspore::kNumberTypeInt32) &&
-        (in_tensors_.at(2)->data_type() == mindspore::kNumberTypeInt32)) {
-      start = *reinterpret_cast<int *>(in_tensors_.at(0)->data_c());
-      limit = *reinterpret_cast<int *>(in_tensors_.at(1)->data_c());
-      delta = *reinterpret_cast<int *>(in_tensors_.at(2)->data_c());
+    if (data_type_ == kDataTypeInt) {
+      RangeInt(reinterpret_cast<int *>(out_tensors_.at(0)->data_c()),
+               *reinterpret_cast<int *>(in_tensors_.at(0)->data_c()),
+               *reinterpret_cast<int *>(in_tensors_.at(2)->data_c()), out_tensors_.at(0)->shape()[0]);
+    } else {
+      Range(reinterpret_cast<float *>(out_tensors_.at(0)->data_c()),
+            *reinterpret_cast<float *>(in_tensors_.at(0)->data_c()),
+            *reinterpret_cast<float *>(in_tensors_.at(2)->data_c()), out_tensors_.at(0)->shape()[0]);
+    }
+  } else {
+    if (data_type_ == kDataTypeInt) {
+      RangeInt(reinterpret_cast<int *>(out_tensors_.at(0)->data_c()),
+               (reinterpret_cast<RangeParameter *>(op_parameter_))->start_,
+               (reinterpret_cast<RangeParameter *>(op_parameter_))->delta_, out_tensors_.at(0)->shape()[0]);
     } else {
       MS_LOG(ERROR) << "Unsupported parameter type : " << in_tensors_.at(0)->data_type() << ".";
       return RET_ERROR;
     }
   }
-  auto output_ptr = reinterpret_cast<float *>(out_tensors_.at(0)->data_c());
-  MS_ASSERT(output_ptr);
-  Range(output_ptr, start, limit, delta);
   return RET_OK;
 }
 
@@ -77,5 +77,7 @@ kernel::LiteKernel *CpuRangeFp32KernelCreator(const std::vector<lite::Tensor *>
 }
 
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Range, CpuRangeFp32KernelCreator)
-
+REG_KERNEL(kCPU, kNumberTypeFloat, PrimitiveType_Range, CpuRangeFp32KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Range, CpuRangeFp32KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_Range, CpuRangeFp32KernelCreator)
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/range_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/range_fp32.h
index da6ed7110a..47b935f4c6 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/range_fp32.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/range_fp32.h
@@ -33,6 +33,9 @@ class RangeCPUKernel : public LiteKernel {
   int Init() override;
   int ReSize() override;
   int Run() override;
+
+ private:
+  LiteDataType data_type_ = kDataTypeFloat;
 };
 }  // namespace mindspore::kernel
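
Note (illustrative, not part of the patch): a minimal sketch of how the reworked NNACL helpers are driven after this change. The element count is derived the same way the new Range::InferShape() computes it, ceil((limit - start) / delta) clamped at zero, and only start, delta, and that count are passed down, which is what RangeCPUKernel::Run() now forwards from the input tensors or from RangeParameter. The standalone main() harness and the local nums variables are assumptions for demonstration, not code from the repository.

  #include <math.h>
  #include <stdio.h>
  #include "nnacl/fp32/range_fp32.h"

  int main(void) {
    /* Float case: start 1.5, limit 5.0, delta 0.5 -> ceil(3.5 / 0.5) = 7 elements. */
    float start = 1.5f, limit = 5.0f, delta = 0.5f;
    int nums = (int)ceilf((limit - start) / delta);
    if (nums < 0) nums = 0; /* same clamp as InferShape's std::max(..., 0) */
    float out_f[16];
    Range(out_f, start, delta, nums); /* 1.5 2.0 2.5 3.0 3.5 4.0 4.5 */

    /* Int case: start 0, limit 10, delta 3 -> ceil(10 / 3.0) = 4 elements. */
    int nums_i = (int)ceilf((float)(10 - 0) / 3);
    int out_i[16];
    RangeInt(out_i, 0, 3, nums_i); /* 0 3 6 9 */

    for (int i = 0; i < nums; ++i) printf("%g ", out_f[i]);
    printf("\n");
    for (int i = 0; i < nums_i; ++i) printf("%d ", out_i[i]);
    printf("\n");
    return 0;
  }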