From d9b4c5a04bfc16e083f1896bbaab300fd8dabfd9 Mon Sep 17 00:00:00 2001
From: fuzhiye
Date: Fri, 12 Mar 2021 11:12:01 +0800
Subject: [PATCH] add int reduce_mean func

---
 mindspore/lite/nnacl/fp32/reduce_fp32.c       | 49 ++++++++++
 mindspore/lite/nnacl/fp32/reduce_fp32.h       |  2 +
 .../runtime/kernel/arm/fp32/reduce_fp32.cc    | 97 ++++++++-----------
 .../src/runtime/kernel/arm/fp32/reduce_fp32.h | 25 +++--
 4 files changed, 111 insertions(+), 62 deletions(-)

diff --git a/mindspore/lite/nnacl/fp32/reduce_fp32.c b/mindspore/lite/nnacl/fp32/reduce_fp32.c
index 7363f4cdfa..385097e030 100644
--- a/mindspore/lite/nnacl/fp32/reduce_fp32.c
+++ b/mindspore/lite/nnacl/fp32/reduce_fp32.c
@@ -44,6 +44,49 @@ int ReduceMean(int outer_size, int inner_size, int axis_size, const float *src_d
   }
   return NNACL_OK;
 }
+
+int IntReduceMean(int outer_size, int inner_size, int axis_size, const int *src_data, int *dst_data, int tid,
+                  int thread_num) {
+  if (src_data == NULL || dst_data == NULL) {
+    return NNACL_NULL_PTR;
+  }
+  int i, j;
+#ifdef ENABLE_NEON
+  int block_mod = inner_size % C4NUM;
+  int block_c4 = inner_size - block_mod;
+#endif
+  for (j = tid; j < outer_size; j += thread_num) {
+    const int *outer_src = src_data + j * axis_size * inner_size;
+    int *outer_dst = dst_data + j * inner_size;
+    int k = 0;
+#ifdef ENABLE_NEON
+    for (; k < block_c4; k += C4NUM) {
+      const int *inner_src = outer_src + k;
+      int *inner_dst = outer_dst + k;
+      int32x4_t tmp = {0, 0, 0, 0};
+      for (i = 0; i < axis_size; i++) {
+        tmp = vaddq_s32(tmp, vld1q_s32(inner_src + i * inner_size));
+      }
+      tmp[0] /= axis_size;
+      tmp[1] /= axis_size;
+      tmp[2] /= axis_size;
+      tmp[3] /= axis_size;
+      vst1q_s32(inner_dst, tmp);
+    }
+#endif
+    for (; k < inner_size; k++) {
+      const int *inner_src = outer_src + k;
+      int *inner_dst = outer_dst + k;
+      int tmp = 0;
+      for (i = 0; i < axis_size; i++) {
+        tmp += inner_src[i * inner_size];
+      }
+      *inner_dst = tmp / axis_size;
+    }
+  }
+  return NNACL_OK;
+}
+
 int ReduceSum(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid,
               int thread_num) {
   if (src_data == NULL || dst_data == NULL) {
@@ -81,6 +124,7 @@ int ReduceSum(int outer_size, int inner_size, int axis_size, const float *src_da
   }
   return NNACL_OK;
 }
+
 int IntReduceSum(int outer_size, int inner_size, int axis_size, const int *src_data, int *dst_data, int tid,
                  int thread_num) {
   if (src_data == NULL || dst_data == NULL) {
@@ -118,6 +162,7 @@ int IntReduceSum(int outer_size, int inner_size, int axis_size, const int *src_d
   }
   return NNACL_OK;
 }
+
 int ReduceMax(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid,
               int thread_num) {
   if (src_data == NULL || dst_data == NULL) {
@@ -139,6 +184,7 @@ int ReduceMax(int outer_size, int inner_size, int axis_size, const float *src_da
   }
   return NNACL_OK;
 }
+
 int IntReduceMax(int outer_size, int inner_size, int axis_size, const int *src_data, int *dst_data, int tid,
                  int thread_num) {
   if (src_data == NULL || dst_data == NULL) {
@@ -160,6 +206,7 @@ int IntReduceMax(int outer_size, int inner_size, int axis_size, const int *src_d
   }
   return NNACL_OK;
 }
+
 int ReduceMin(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid,
               int thread_num) {
   if (src_data == NULL || dst_data == NULL) {
@@ -181,6 +228,7 @@ int ReduceMin(int outer_size, int inner_size, int axis_size, const float *src_da
   }
   return NNACL_OK;
 }
+
 int IntReduceMin(int outer_size, int inner_size, int axis_size, const int *src_data, int *dst_data, int tid,
                  int thread_num) {
   if (src_data == NULL || dst_data == NULL) {
@@ -271,6 +319,7 @@ int IntReduceProd(int outer_size, int inner_size, int axis_size, const int *src_
   }
   return NNACL_OK;
 }
+
 int ReduceSumSquare(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid,
                     int thread_num) {
   if (src_data == NULL || dst_data == NULL) {
diff --git a/mindspore/lite/nnacl/fp32/reduce_fp32.h b/mindspore/lite/nnacl/fp32/reduce_fp32.h
index 9a6a878936..30901f0622 100644
--- a/mindspore/lite/nnacl/fp32/reduce_fp32.h
+++ b/mindspore/lite/nnacl/fp32/reduce_fp32.h
@@ -24,6 +24,8 @@ extern "C" {
 #endif
 int ReduceMean(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid,
                int thread_num);
+int IntReduceMean(int outer_size, int inner_size, int axis_size, const int *src_data, int *dst_data, int tid,
+                  int thread_num);
 int ReduceSum(int outer_size, int inner_size, int axis_size, const float *src_data, float *dst_data, int tid,
               int thread_num);
 int IntReduceSum(int outer_size, int inner_size, int axis_size, const int *src_data, int *dst_data, int tid,
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reduce_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce_fp32.cc
index 7233231452..d0f61a8098 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/reduce_fp32.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce_fp32.cc
@@ -40,54 +40,13 @@ using mindspore::schema::ReduceMode_ReduceSum;
 using mindspore::schema::ReduceMode_ReduceSumSquare;
 
 namespace mindspore::kernel {
-
 int ReduceCPUKernel::Init() {
   auto ret = ReduceBaseCPUKernel::Init();
   if (ret != RET_OK) {
     return ret;
   }
-  switch (mode_) {
-    case static_cast<int>(ReduceMode_ReduceSum): {
-      reducer_ = ReduceSum;
-      int_reducer_ = IntReduceSum;
-      break;
-    }
-    case static_cast<int>(ReduceMode_ReduceMean): {
-      reducer_ = ReduceMean;
-      break;
-    }
-    case static_cast<int>(ReduceMode_ReduceMax): {
-      reducer_ = ReduceMax;
-      int_reducer_ = IntReduceMax;
-      break;
-    }
-    case static_cast<int>(ReduceMode_ReduceMin): {
-      reducer_ = ReduceMin;
-      int_reducer_ = IntReduceMin;
-      break;
-    }
-    case static_cast<int>(ReduceMode_ReduceProd): {
-      reducer_ = ReduceProd;
-      int_reducer_ = IntReduceProd;
-      break;
-    }
-    case static_cast<int>(ReduceMode_ReduceSumSquare): {
-      reducer_ = ReduceSum;
-      break;
-    }
-    case static_cast<int>(ReduceMode_ReduceASum): {
-      reducer_ = ReduceSum;
-      break;
-    }
-    case static_cast<int>(ReduceMode_ReduceAll): {
-      bool_reducer_ = ReduceAll;
-      break;
-    }
-    default:
-      MS_LOG(ERROR) << "Reduce unsupported reduce mode: " << mode_;
-      return RET_ERROR;
-  }
+  InitialKernelList();
 
   if (!InferShapeDone()) {
     return RET_OK;
@@ -98,19 +57,29 @@ int ReduceCPUKernel::Init() {
 int ReduceCPUKernel::ReSize() { return ReduceBaseCPUKernel::ReSize(); }
 
 int ReduceCPUKernel::CallReduceUnit(int task_id) {
-  int ret;
   if (data_type_ == kDataTypeFloat) {
-    ret = reducer_(outer_size_, inner_size_, axis_size_, static_cast<const float *>(src_data_),
-                   static_cast<float *>(dst_data_), task_id, context_->thread_num_);
+    if (!reducer_) {
+      MS_LOG(ERROR) << "function reducer_ is null.";
+      return RET_NULL_PTR;
+    }
+    reducer_(outer_size_, inner_size_, axis_size_, static_cast<const float *>(src_data_),
+             static_cast<float *>(dst_data_), task_id, context_->thread_num_);
   } else if (data_type_ == KDataTypeBool) {
-    ret = bool_reducer_(outer_size_, inner_size_, axis_size_, static_cast<const bool *>(src_data_),
-                        static_cast<bool *>(dst_data_), task_id, context_->thread_num_);
+    if (!bool_reducer_) {
+      MS_LOG(ERROR) << "function bool_reducer_ is null.";
+      return RET_NULL_PTR;
+    }
+    bool_reducer_(outer_size_, inner_size_, axis_size_, static_cast<const bool *>(src_data_),
+                  static_cast<bool *>(dst_data_), task_id, context_->thread_num_);
   } else {
-    ret = int_reducer_(outer_size_, inner_size_, axis_size_, static_cast<const int *>(src_data_),
-                       static_cast<int *>(dst_data_), task_id, context_->thread_num_);
+    if (!int_reducer_) {
+      MS_LOG(ERROR) << "function int_reducer_ is null.";
+      return RET_NULL_PTR;
+    }
+    int_reducer_(outer_size_, inner_size_, axis_size_, static_cast<const int *>(src_data_),
+                 static_cast<int *>(dst_data_), task_id, context_->thread_num_);
   }
-
-  return ret;
+  return RET_OK;
 }
 
 int ReduceImpl(void *cdata, int task_id) {
@@ -143,7 +112,7 @@ int ReduceCPUKernel::Run() {
     if (i != static_cast<size_t>(num_axes_ - 1)) {
       dst_data_ = data_buffers_.at(i);
     } else {
-      dst_data_ = out_tensors_.at(0)->MutableData();
+      dst_data_ = out_tensors_.at(0)->data_c();
     }
     outer_size_ = outer_sizes_.at(i);
     inner_size_ = inner_sizes_.at(i);
@@ -173,7 +142,7 @@ void ReduceCPUKernel::HandleASumAndSumSquare() {
     return;
   }
   int num = in_tensors_.at(0)->ElementsNum();
-  float *data = reinterpret_cast<float *>(in_tensors_.at(0)->data_c());
+  auto *data = reinterpret_cast<float *>(in_tensors_.at(0)->data_c());
   if (data == nullptr) {
     return;
   }
@@ -197,7 +166,7 @@ int ReduceCPUKernel::CalculateCoeffOutput() {
   if (data_type_ != kDataTypeFloat) {
     return RET_ERROR;
   }
-  float *out_data = reinterpret_cast<float *>(out_tensor->MutableData());
+  auto *out_data = reinterpret_cast<float *>(out_tensor->data_c());
   if (out_data == nullptr) {
     return RET_NULL_PTR;
   }
@@ -237,6 +206,26 @@ void ReduceCPUKernel::FreeTmpBuffer() {
   data_buffers_.clear();
 }
 
+void ReduceCPUKernel::InitialKernelList() {
+  ReduceKernelList func_list[] = {{ReduceMode_ReduceSum, ReduceSum, IntReduceSum, nullptr},
+                                  {ReduceMode_ReduceMean, ReduceMean, IntReduceMean, nullptr},
+                                  {ReduceMode_ReduceMax, ReduceMax, IntReduceMax, nullptr},
+                                  {ReduceMode_ReduceMin, ReduceMin, IntReduceMin, nullptr},
+                                  {ReduceMode_ReduceProd, ReduceProd, IntReduceProd, nullptr},
+                                  {ReduceMode_ReduceSumSquare, ReduceSum, IntReduceSum, nullptr},
+                                  {ReduceMode_ReduceASum, ReduceSum, IntReduceSum, nullptr},
+                                  {ReduceMode_ReduceAll, nullptr, nullptr, ReduceAll}};
+  int list_len = sizeof(func_list) / sizeof(ReduceKernelList);
+  for (int i = 0; i < list_len; ++i) {
+    if (mode_ == func_list[i].type_) {
+      reducer_ = func_list[i].float_func_;
+      int_reducer_ = func_list[i].int_func_;
+      bool_reducer_ = func_list[i].bool_func_;
+      break;
+    }
+  }
+}
+
 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_ReduceFusion, LiteKernelCreator<ReduceCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_ReduceFusion, LiteKernelCreator<ReduceCPUKernel>)
 REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_ReduceFusion, LiteKernelCreator<ReduceCPUKernel>)
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/reduce_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce_fp32.h
index 06b228e341..943e5d319d 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp32/reduce_fp32.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp32/reduce_fp32.h
@@ -26,21 +26,27 @@
 using mindspore::schema::ReduceMode;
 
 namespace mindspore::kernel {
-class ReduceCPUKernel : public ReduceBaseCPUKernel {
-  typedef int (*Reducer)(const int outer_size, const int inner_size, const int axis_size, const float *src_data,
-                         float *dst_data, const int tid, const int thread_num);
-  typedef int (*IntReducer)(const int outer_size, const int inner_size, const int axis_size, const int *src_data,
-                            int *dst_data, const int tid, const int thread_num);
-  typedef int (*BoolReducer)(const int outer_size, const int inner_size, const int axis_size, const bool *src_data,
-                             bool *dst_data, const int tid, const int thread_num);
+typedef int (*Reducer)(const int outer_size, const int inner_size, const int axis_size, const float *src_data,
+                       float *dst_data, const int tid, const int thread_num);
+typedef int (*IntReducer)(const int outer_size, const int inner_size, const int axis_size, const int *src_data,
+                          int *dst_data, const int tid, const int thread_num);
+typedef int (*BoolReducer)(const int outer_size, const int inner_size, const int axis_size, const bool *src_data,
+                           bool *dst_data, const int tid, const int thread_num);
+struct ReduceKernelList {
+  int type_;
+  Reducer float_func_;
+  IntReducer int_func_;
+  BoolReducer bool_func_;
+};
+class ReduceCPUKernel : public ReduceBaseCPUKernel {
  public:
   ReduceCPUKernel(OpParameter *param, const std::vector<lite::Tensor *> &inputs,
                   const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx)
       : ReduceBaseCPUKernel(param, inputs, outputs, ctx) {
     reduce_param_ = reinterpret_cast<ReduceParameter *>(param);
   }
-  ~ReduceCPUKernel() {
+  ~ReduceCPUKernel() override {
     src_data_ = nullptr;
     dst_data_ = nullptr;
     reducer_ = nullptr;
@@ -52,6 +58,9 @@ class ReduceCPUKernel : public ReduceBaseCPUKernel {
   int Run() override;
   int CallReduceUnit(int task_id);
 
+ protected:
+  void InitialKernelList();
+
  private:
   ReduceParameter *reduce_param_;
   Reducer reducer_ = nullptr;
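
Illustrative usage sketch (not part of the patch): the snippet below exercises the IntReduceMean contract added above with a single thread. It assumes the MindSpore Lite tree is on the include path so that nnacl/fp32/reduce_fp32.h and the NNACL_OK code it defines are available; reducing a 2x3x4 int tensor over its middle axis maps to outer_size = 2, axis_size = 3, inner_size = 4, and the parameter order of the function is (outer_size, inner_size, axis_size, ...).

/* Illustrative only -- not part of the patch; assumes the nnacl headers from this tree are reachable. */
#include <stdio.h>
#include "nnacl/fp32/reduce_fp32.h"

int main(void) {
  /* Reduce a 2x3x4 tensor over axis 1: outer_size = 2, axis_size = 3, inner_size = 4. */
  int src[24];
  for (int i = 0; i < 24; ++i) {
    src[i] = i; /* row-major values 0..23 */
  }
  int dst[8] = {0};
  /* Single-threaded call: tid = 0, thread_num = 1 walks every outer index.
     Note the argument order: outer_size, inner_size, axis_size. */
  if (IntReduceMean(2, 4, 3, src, dst, 0, 1) != NNACL_OK) {
    return 1;
  }
  /* Each output is the integer mean of axis_size values strided by inner_size,
     e.g. dst[0] = (src[0] + src[4] + src[8]) / 3 = 4. */
  for (int i = 0; i < 8; ++i) {
    printf("%d ", dst[i]);
  }
  printf("\n");
  return 0;
}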