From 9b459ce48b5400e2b11d51b1614d8f78277af6a2 Mon Sep 17 00:00:00 2001 From: sunsuodong Date: Mon, 8 Mar 2021 17:36:58 +0800 Subject: [PATCH] arithmetic --- .../kernel/arm/fp16/arithmetic_fp16.cc | 4 - .../kernel/arm/fp32/arithmetic_fp32.cc | 128 ++++++++++++------ .../runtime/kernel/arm/fp32/arithmetic_fp32.h | 5 +- 3 files changed, 94 insertions(+), 43 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc index 05e6fbad04..cade91f227 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc @@ -119,10 +119,6 @@ void ArithmeticFP16CPUKernel::TileConstTensor(const void *in_data, void *out_dat int ArithmeticFP16CPUKernel::Execute(const void *input0, const void *input1, void *output, int size, bool is_opt) { int ret = RET_OK; - if (in_tensors_[0]->data_type() != kNumberTypeFloat16) { - MS_LOG(ERROR) << "data type is not fp16"; - return RET_ERROR; - } if (is_opt) { CHECK_NULL_RETURN(arithmetic_opt_func_, RET_ERROR); ret = arithmetic_opt_func_(reinterpret_cast(input0), reinterpret_cast(input1), diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc index bbdbd20fb5..acce89c95d 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc @@ -60,7 +60,11 @@ int ArithmeticCPUKernel::ReSize() { outside_ *= param_->out_shape_[i]; } } - return ConstTensorBroadCast(); + int ret = RET_OK; + if (!isScalarClac() && !isBatchScalarCalc() && !isBiasCalc()) { + ret = ConstTensorBroadCast(); + } + return ret; } int ArithmeticCPUKernel::CheckDataType() { @@ -73,6 +77,47 @@ int ArithmeticCPUKernel::CheckDataType() { return RET_OK; } +bool ArithmeticCPUKernel::isScalarClac() { // 2 32 240 240, 1 1 1 1 + if ((param_->in_elements_num0_ == 1 || 
param_->in_elements_num1_ == 1) && (arithmetic_opt_run_ != nullptr)) { + return true; + } else { + return false; + } +} + +bool ArithmeticCPUKernel::isBatchScalarCalc() { // 2 32 240 240, 2 32 1 1 + if (arithmetic_opt_run_ == nullptr) { + return false; + } + size_t break_axis = 0; + for (size_t i = 0; i < param_->ndim_; i++) { + if (param_->in_shape0_[i] != param_->in_shape1_[i]) { + break_axis = i; + break; + } + } + if (break_axis < param_->ndim_) { + for (size_t i = break_axis; i < param_->ndim_; i++) { + if (param_->in_shape1_[i] != 1) { + return false; + } + } + } + break_pos_ = break_axis; + return true; +} + +bool ArithmeticCPUKernel::isBiasCalc() { // 2 240 240 32, 1 1 1 32 + int last_shape0 = param_->in_shape0_[param_->ndim_ - 1]; + int last_shape1 = param_->in_shape1_[param_->ndim_ - 1]; + if (param_->in_elements_num0_ > param_->in_elements_num1_) { + return param_->in_elements_num1_ == last_shape1 && last_shape0 == last_shape1; + } else if (param_->in_elements_num0_ < param_->in_elements_num1_) { + return param_->in_elements_num0_ == last_shape0 && last_shape0 == last_shape1; + } + return false; +} + int ArithmeticCPUKernel::ConstTensorBroadCast() { /* if const node need broadcast and all need-broadcast-node are const, broadcast in resize */ if (!param_->broadcasting_) { @@ -86,11 +131,6 @@ int ArithmeticCPUKernel::ConstTensorBroadCast() { param_->in_elements_num1_ != param_->out_elements_num_) { return RET_OK; } - if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) && arithmetic_opt_run_ != nullptr) { - /* run opt function - * one of input is scalar */ - return RET_OK; - } FreeConstTileBuff(); if (in_tensors_[0]->data_c() != nullptr && param_->in_elements_num0_ != param_->out_elements_num_) { @@ -252,32 +292,6 @@ int ArithmeticCPUKernel::BroadcastRun(void *input0, void *input1, void *output, return RET_OK; } -bool ArithmeticCPUKernel::CanBatchScalar() { // 2 32 240 240, 2 32 1 1 - if (input0_broadcast_ || input1_broadcast_) { - return 
false; - } - if (param_->in_elements_num0_ == param_->in_elements_num1_ || param_->in_elements_num0_ == 1 || - param_->in_elements_num1_ == 1) { - return false; - } - size_t break_axis = 0; - for (size_t i = 0; i < param_->ndim_; i++) { - if (param_->in_shape0_[i] != param_->in_shape1_[i]) { - break_axis = i; - break; - } - } - if (break_axis < param_->ndim_) { - for (size_t i = break_axis; i < param_->ndim_; i++) { - if (param_->in_shape1_[i] != 1) { - return false; - } - } - } - break_pos_ = break_axis; - return true; -} - int ArithmeticCPUKernel::BatchScalarCalc(int task_id) { if (break_pos_ < 1) { return RET_ERROR; @@ -308,6 +322,40 @@ int ArithmeticCPUKernel::BatchScalarCalc(int task_id) { return ret; } +int ArithmeticCPUKernel::BiasCalc(int task_id) { + int last_shape = param_->out_shape_[param_->ndim_ - 1]; + int batch = param_->out_elements_num_ / last_shape; + int batch_per_thread = UP_DIV(batch, context_->thread_num_); + + int start_batch = batch_per_thread * task_id; + int end_batch = MSMIN(start_batch + batch_per_thread, batch); + int batch_size = end_batch - start_batch; + + int stride = last_shape * data_type_len_; + int offset = stride * start_batch; + int ret = RET_OK; + if (param_->in_elements_num0_ > param_->in_elements_num1_) { + for (int i = 0; i < batch_size; i++) { + ret = Execute(static_cast(input0_ptr_) + offset, static_cast(input1_ptr_), + static_cast(output_ptr_) + offset, last_shape, false); + if (ret != RET_OK) { + return ret; + } + offset += stride; + } + } else { + for (int i = 0; i < batch_size; i++) { + ret = Execute(static_cast(input0_ptr_), static_cast(input1_ptr_) + offset, + static_cast(output_ptr_) + offset, last_shape, false); + if (ret != RET_OK) { + return ret; + } + offset += stride; + } + } + return ret; +} + int ArithmeticCPUKernel::DoArithmetic(int task_id) { auto element_num = out_tensors_[0]->ElementsNum(); int stride = UP_DIV(element_num, context_->thread_num_); @@ -315,13 +363,9 @@ int 
ArithmeticCPUKernel::DoArithmetic(int task_id) { if (count <= 0) { return RET_OK; } - /* run opt function, every batch one of input is scalar */ - if (CanBatchScalar()) { - return BatchScalarCalc(task_id); - } int offset = stride * task_id * data_type_len_; /* run opt function, one of input is scalar */ - if ((param_->in_elements_num0_ == 1 || param_->in_elements_num1_ == 1) && arithmetic_opt_run_ != nullptr) { + if (isScalarClac()) { // 2 32 240 240, 1 1 1 1 if (param_->in_elements_num0_ == 1) { return Execute(input0_ptr_, static_cast(input1_ptr_) + offset, static_cast(output_ptr_) + offset, count, true); @@ -330,6 +374,14 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) { static_cast(output_ptr_) + offset, count, true); } } + /* run opt function, every batch one of input is scalar */ + if (isBatchScalarCalc()) { // 2 32 240 240, 2 32 1 1 + return BatchScalarCalc(task_id); + } + /* each batch is eltwise calculation */ + if (isBiasCalc()) { // 2 240 240 32, 1 1 1 32 + return BiasCalc(task_id); + } /* need broadcast in runtime */ if (param_->broadcasting_) { stride = UP_DIV(outside_, context_->thread_num_); @@ -339,7 +391,7 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) { } return BroadcastRun(input0_ptr_, input1_ptr_, output_ptr_, 0, out_count, stride * task_id); } - /* no broadcast in runtime */ + /* all elements eltwise calculation */ return Execute(static_cast(input0_ptr_) + offset, static_cast(input1_ptr_) + offset, static_cast(output_ptr_) + offset, count, false); } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h index a84de59192..02e321bea2 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h @@ -108,9 +108,12 @@ class ArithmeticCPUKernel : public LiteKernel { int data_type_len_ = sizeof(float); private: - bool CanBatchScalar(); int BatchScalarCalc(int task_id); + int 
BiasCalc(int task_id); void FreeConstTileBuff(); + bool isScalarClac(); + bool isBatchScalarCalc(); + bool isBiasCalc(); ArithmeticRun arithmetic_run_ = nullptr; ArithmeticOptRun arithmetic_opt_run_ = nullptr; ArithmeticIntRun arithmetic_run_int_ = nullptr;
/*
 * Reviewer notes (commentary only; not part of the recorded change).
 *
 * Purpose of the patch: move the "which fast path applies" decisions of the
 * fp32 arithmetic kernel into three predicates, add a new per-row ("bias")
 * execution path, and drop a data-type guard from the fp16 kernel.
 *
 * fp16/arithmetic_fp16.cc
 *   - ArithmeticFP16CPUKernel::Execute no longer returns RET_ERROR (after
 *     logging "data type is not fp16") when in_tensors_[0]->data_type() is
 *     not kNumberTypeFloat16.
 *
 * fp32/arithmetic_fp32.cc
 *   - ReSize(): ConstTensorBroadCast() is now called only when none of
 *     isScalarClac(), isBatchScalarCalc(), isBiasCalc() returns true;
 *     otherwise ReSize() returns RET_OK without broadcasting.
 *   - isScalarClac() (new): true when in_elements_num0_ == 1 or
 *     in_elements_num1_ == 1, and arithmetic_opt_run_ is non-null. It
 *     replaces the identical inline test removed from ConstTensorBroadCast()
 *     and from DoArithmetic().
 *   - isBatchScalarCalc() (new, replaces CanBatchScalar()): returns false if
 *     arithmetic_opt_run_ is null; otherwise finds the first axis where
 *     in_shape0_ and in_shape1_ differ, returns false unless in_shape1_ is 1
 *     on that axis and every later one, and on success stores the axis in
 *     break_pos_. Compared with CanBatchScalar() it drops the
 *     input0_broadcast_ / input1_broadcast_ test and the element-count
 *     tests, and adds the arithmetic_opt_run_ null test.
 *     If no axis differs, break_axis keeps its initial value 0, so the
 *     result is then decided by whether in_shape1_ is 1 on every axis.
 *   - isBiasCalc() (new): reads the last-axis size of both inputs. Returns
 *     true when the input with fewer elements has exactly as many elements
 *     as its last-axis size and both last-axis sizes are equal; returns
 *     false when both inputs have the same element count.
 *   - BiasCalc(task_id) (new): treats the output as
 *     out_elements_num_ / last_shape rows of last_shape elements, assigns
 *     UP_DIV(rows, context_->thread_num_) rows to each task, and calls
 *     Execute(..., last_shape, false) once per row. The byte offset
 *     (last_shape * data_type_len_ per row) is applied to the output and to
 *     the input with more elements; the other input pointer is passed
 *     unchanged for every row. The first non-RET_OK result is returned.
 *   - DoArithmetic(): dispatch order is now isScalarClac(), then
 *     isBatchScalarCalc(), then isBiasCalc(), then the broadcasting_ path,
 *     then plain element-wise Execute. Before this patch the
 *     CanBatchScalar() test ran ahead of the scalar test.
 *   - ConstTensorBroadCast(): the scalar early-return was removed (the
 *     equivalent decision is now taken in ReSize()).
 *
 * fp32/arithmetic_fp32.h
 *   - Declares BiasCalc, isScalarClac, isBatchScalarCalc and isBiasCalc in
 *     the private section and removes the CanBatchScalar declaration.
 *
 * Points to confirm:
 *   - NOTE(review): "isScalarClac" looks like a misspelling of
 *     "isScalarCalc". The spelling is consistent across the declaration,
 *     the definition and both call sites, so any rename has to touch all
 *     four together.
 *   - NOTE(review): isBatchScalarCalc() is a predicate that also writes
 *     break_pos_, and it is called from DoArithmetic(task_id) as well as
 *     from ReSize(). If DoArithmetic runs concurrently for several task
 *     ids (presumably, given thread_num_), every task stores the same
 *     member - confirm this is acceptable.
 *   - NOTE(review): isBiasCalc() and BiasCalc() index shape arrays with
 *     param_->ndim_ - 1, and BiasCalc() divides by last_shape; neither
 *     guards against ndim_ == 0 or a zero-sized last axis - confirm callers
 *     rule these out.
 *   - NOTE(review): ReSize() now depends on arithmetic_opt_run_ already
 *     being set when it runs; confirm the initialisation order.
 *   - NOTE(review): in this copy the static_cast / reinterpret_cast
 *     expressions carry no template argument list and the original line
 *     breaks are missing, presumably lost during extraction. The text
 *     above is therefore kept byte-for-byte and should not be expected to
 *     apply with git as it stands.
 */