diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc index 96e511f967..aa15ac86e8 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc @@ -96,202 +96,59 @@ int ArithmeticCPUKernel::InitBroadCastCase() { } void ArithmeticCPUKernel::InitRunFunction() { - switch (op_parameter_->type_) { - case PrimitiveType_Mul: - switch (arithmeticParameter_->activation_type_) { - case schema::ActivationType_RELU: - arithmetic_run_ = ElementMulRelu; - arithmetic_run_int_ = ElementMulReluInt; - break; - case schema::ActivationType_RELU6: - arithmetic_run_ = ElementMulRelu6; - arithmetic_run_int_ = ElementMulRelu6Int; - break; - default: - arithmetic_run_ = ElementMul; - arithmetic_run_int_ = ElementMulInt; - break; - } - break; - case PrimitiveType_Add: - switch (arithmeticParameter_->activation_type_) { - case schema::ActivationType_RELU: - arithmetic_run_ = ElementAddRelu; - break; - case schema::ActivationType_RELU6: - arithmetic_run_ = ElementAddRelu6; - break; - default: - arithmetic_run_ = ElementAdd; - arithmetic_run_int_ = ElementAddInt; - break; - } - break; - case PrimitiveType_Sub: - switch (arithmeticParameter_->activation_type_) { - case schema::ActivationType_RELU: - arithmetic_run_ = ElementSubRelu; - break; - case schema::ActivationType_RELU6: - arithmetic_run_ = ElementSubRelu6; - break; - default: - arithmetic_run_ = ElementSub; - arithmetic_run_int_ = ElementSubInt; - break; - } - break; - case PrimitiveType_Div: - case PrimitiveType_RealDiv: - switch (arithmeticParameter_->activation_type_) { - case schema::ActivationType_RELU: - arithmetic_run_ = ElementDivRelu; - break; - case schema::ActivationType_RELU6: - arithmetic_run_ = ElementDivRelu6; - break; - default: - arithmetic_run_ = ElementDiv; - break; - } - break; - case PrimitiveType_LogicalAnd: - arithmetic_run_ = ElementLogicalAnd; - arithmetic_run_int_ = ElementLogicalAndInt; - arithmetic_run_bool_ = ElementLogicalAndBool; - break; - case PrimitiveType_LogicalOr: - arithmetic_run_ = ElementLogicalOr; - break; - case PrimitiveType_Maximum: - arithmetic_run_ = ElementMaximum; - arithmetic_run_int_ = ElementMaximumInt; - break; - case PrimitiveType_Minimum: - arithmetic_run_ = ElementMinimum; - arithmetic_run_int_ = ElementMinimumInt; - break; - case PrimitiveType_FloorDiv: - arithmetic_run_ = ElementFloorDiv; - arithmetic_run_int_ = ElementFloorDivInt; - break; - case PrimitiveType_FloorMod: - arithmetic_run_ = ElementFloorMod; - arithmetic_run_int_ = ElementFloorModInt; - break; - case PrimitiveType_Mod: - arithmetic_run_ = ElementMod; - arithmetic_run_int_ = ElementModInt; - break; - case PrimitiveType_SquaredDifference: - arithmetic_run_ = ElementSquaredDifference; - break; - case PrimitiveType_Equal: - case PrimitiveType_Less: - case PrimitiveType_Greater: - case PrimitiveType_NotEqual: - case PrimitiveType_LessEqual: - case PrimitiveType_GreaterEqual: - arithmetic_run_ = nullptr; - arithmetic_run_int_ = nullptr; - break; - default: - MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_; - arithmetic_run_ = nullptr; - break; - } - return; -} + ARITHMETIC_FUNC_INFO_FP32 fun_table[] = { + {PrimitiveType_Mul, schema::ActivationType_RELU, ElementMulRelu, ElementMulReluInt, nullptr, ElementOptMulRelu, + ElementOptMulReluInt}, + {PrimitiveType_Mul, schema::ActivationType_RELU6, ElementMulRelu6, ElementMulRelu6Int, nullptr, ElementOptMulRelu6, + ElementOptMulRelu6Int}, + {PrimitiveType_Mul, schema::ActivationType_NO_ACTIVATION, ElementMul, ElementMulInt, nullptr, ElementOptMul, + ElementOptMulInt}, + {PrimitiveType_Add, schema::ActivationType_RELU, ElementAddRelu, nullptr, nullptr, ElementOptAddRelu, nullptr}, + {PrimitiveType_Add, schema::ActivationType_RELU6, ElementAddRelu6, nullptr, nullptr, ElementOptAddRelu6, nullptr}, + {PrimitiveType_Add, schema::ActivationType_NO_ACTIVATION, ElementAdd, ElementAddInt, nullptr, ElementOptAdd, + ElementOptAddInt}, + {PrimitiveType_Sub, schema::ActivationType_RELU, ElementSubRelu, nullptr, nullptr, ElementOptSubRelu, nullptr}, + {PrimitiveType_Sub, schema::ActivationType_RELU6, ElementSubRelu6, nullptr, nullptr, ElementOptSubRelu6, nullptr}, + {PrimitiveType_Sub, schema::ActivationType_NO_ACTIVATION, ElementSub, ElementSubInt, nullptr, ElementOptSub, + ElementOptSubInt}, + {PrimitiveType_Div, schema::ActivationType_RELU, ElementDivRelu, nullptr, nullptr, ElementOptDivRelu, nullptr}, + {PrimitiveType_Div, schema::ActivationType_RELU6, ElementDivRelu6, nullptr, nullptr, ElementOptDivRelu6, nullptr}, + {PrimitiveType_Div, schema::ActivationType_NO_ACTIVATION, ElementDiv, nullptr, nullptr, ElementOptDiv, + ElementOptDivInt}, + {PrimitiveType_RealDiv, schema::ActivationType_RELU, ElementDivRelu, nullptr, nullptr, ElementOptDivRelu, nullptr}, + {PrimitiveType_RealDiv, schema::ActivationType_RELU6, ElementDivRelu6, nullptr, nullptr, ElementOptDivRelu6, + nullptr}, + {PrimitiveType_RealDiv, schema::ActivationType_NO_ACTIVATION, ElementDiv, nullptr, nullptr, ElementOptDiv, + ElementOptDivInt}, + {PrimitiveType_LogicalAnd, schema::ActivationType_NO_ACTIVATION, ElementLogicalAnd, ElementLogicalAndInt, + ElementLogicalAndBool, nullptr, nullptr}, + {PrimitiveType_LogicalOr, schema::ActivationType_NO_ACTIVATION, ElementLogicalOr, nullptr, nullptr, nullptr, + nullptr}, + {PrimitiveType_Maximum, schema::ActivationType_NO_ACTIVATION, ElementMaximum, ElementMaximumInt, nullptr, nullptr, + nullptr}, + {PrimitiveType_Minimum, schema::ActivationType_NO_ACTIVATION, ElementMinimum, ElementMinimumInt, nullptr, nullptr, + nullptr}, + {PrimitiveType_FloorMod, schema::ActivationType_NO_ACTIVATION, ElementFloorMod, ElementFloorModInt, nullptr, + nullptr, nullptr}, + {PrimitiveType_FloorDiv, schema::ActivationType_NO_ACTIVATION, ElementFloorDiv, ElementFloorDivInt, nullptr, + nullptr, nullptr}, + {PrimitiveType_Mod, schema::ActivationType_NO_ACTIVATION, ElementMod, ElementModInt, nullptr, ElementOptMod, + ElementOptModInt}, + {PrimitiveType_SquaredDifference, schema::ActivationType_NO_ACTIVATION, ElementSquaredDifference, nullptr, nullptr, + nullptr, nullptr}}; -void ArithmeticCPUKernel::InitOptRunFunction() { - if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) { - switch (arithmeticParameter_->op_parameter_.type_) { - case PrimitiveType_Mul: - switch (arithmeticParameter_->activation_type_) { - case schema::ActivationType_RELU: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptMulRelu; - arithmetic_opt_run_int_ = ElementOptMulReluInt; - break; - case schema::ActivationType_RELU6: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptMulRelu6; - arithmetic_opt_run_int_ = ElementOptMulRelu6Int; - break; - default: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptMul; - arithmetic_opt_run_int_ = ElementOptMulInt; - break; - } - break; - case PrimitiveType_Add: - switch (arithmeticParameter_->activation_type_) { - case schema::ActivationType_RELU: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptAddRelu; - break; - case schema::ActivationType_RELU6: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptAddRelu6; - break; - default: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptAdd; - arithmetic_opt_run_int_ = ElementOptAddInt; - break; - } - break; - case PrimitiveType_Sub: - switch (arithmeticParameter_->activation_type_) { - case schema::ActivationType_RELU: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptSubRelu; - break; - case schema::ActivationType_RELU6: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptSubRelu6; - break; - default: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptSub; - arithmetic_opt_run_int_ = ElementOptSubInt; - break; - } - break; - case PrimitiveType_Div: - case PrimitiveType_RealDiv: - switch (arithmeticParameter_->activation_type_) { - case schema::ActivationType_RELU: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptDivRelu; - break; - case schema::ActivationType_RELU6: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptDivRelu6; - break; - default: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptDiv; - arithmetic_opt_run_int_ = ElementOptDivInt; - break; - } - break; - case PrimitiveType_Mod: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptMod; - arithmetic_opt_run_int_ = ElementOptModInt; - break; - default: - arithmetic_opt_run_ = nullptr; - arithmetic_opt_run_int_ = nullptr; - break; + size_t length = sizeof(fun_table) / sizeof(ARITHMETIC_FUNC_INFO_FP32); + for (size_t i = 0; i < length; i++) { + if (fun_table[i].primitive_type_ == op_parameter_->type_ && + fun_table[i].activation_type_ == arithmeticParameter_->activation_type_) { + arithmetic_run_ = fun_table[i].func_; + arithmetic_run_int_ = fun_table[i].int_func_; + arithmetic_run_bool_ = fun_table[i].bool_func_; + arithmetic_opt_run_ = fun_table[i].opt_func_; + arithmetic_opt_run_int_ = fun_table[i].opt_int_func_; + return; } - } else { - arithmetic_opt_run_ = nullptr; - arithmetic_opt_run_int_ = nullptr; } - return; } void ArithmeticCPUKernel::InitParam() { @@ -321,7 +178,6 @@ void ArithmeticCPUKernel::InitParam() { int ArithmeticCPUKernel::ReSize() { InitParam(); - InitOptRunFunction(); return InitBroadCastCase(); } @@ -359,6 +215,66 @@ int ArithmeticCPUKernel::BroadcastRun(void *input0, void *input1, void *output, return RET_OK; } +bool ArithmeticCPUKernel::CanBatchScalar() { // 2 32 240 240, 2 32 1 1 + if (input0_broadcast_ == true || input1_broadcast_ == true) { + return false; + } + if (arithmeticParameter_->in_elements_num0_ == arithmeticParameter_->in_elements_num1_ || + arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) { + return false; + } + size_t break_axis = 0; + for (size_t i = 0; i < arithmeticParameter_->ndim_; i++) { + if (arithmeticParameter_->in_shape0_[i] != arithmeticParameter_->in_shape1_[i]) { + break_axis = i; + break; + } + } + if (break_axis < arithmeticParameter_->ndim_) { + for (size_t i = break_axis; i < arithmeticParameter_->ndim_; i++) { + if (arithmeticParameter_->in_shape1_[i] != 1) { + return false; + } + } + } + break_pos_ = break_axis; + return true; +} + +int ArithmeticCPUKernel::BatchScalarCalc(int task_id) { + int batch = arithmeticParameter_->out_elements_num_ / arithmeticParameter_->out_strides_[break_pos_ - 1]; + int batch_per_thread = UP_DIV(batch, thread_count_); + + int start_batch = batch_per_thread * task_id; + int end_batch = MSMIN(start_batch + batch_per_thread, batch); + int batch_size = end_batch - start_batch; + + int stride0 = arithmeticParameter_->in_strides0_[break_pos_ - 1]; + int stride1 = arithmeticParameter_->in_strides1_[break_pos_ - 1]; + int out_stride = arithmeticParameter_->out_strides_[break_pos_ - 1]; + + int offset0 = stride0 * start_batch; + int offset1 = stride1 * start_batch; + int out_offset = out_stride * start_batch; + + int ret = RET_OK; + for (int i = 0; i < batch_size; i++) { + if (data_type_ == kDataTypeFloat) { + ret = arithmetic_opt_run_( + reinterpret_cast(input0_ptr_) + offset0, reinterpret_cast(input1_ptr_) + offset1, + reinterpret_cast(out_tensors_[0]->data_c()) + out_offset, out_stride, arithmeticParameter_); + } else { + ret = arithmetic_opt_run_int_( + reinterpret_cast(input0_ptr_) + offset0, reinterpret_cast(input1_ptr_) + offset1, + reinterpret_cast(out_tensors_[0]->data_c()) + out_offset, out_stride, arithmeticParameter_); + } + offset0 += stride0; + offset1 += stride1; + out_offset += out_stride; + } + return ret; +} + int ArithmeticCPUKernel::DoArithmetic(int task_id) { auto element_num = out_tensors_[0]->ElementsNum(); @@ -370,27 +286,12 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) { MS_LOG(ERROR) << "arithmetic_run function is nullptr!"; return RET_ERROR; } - - int error_code; - if (arithmeticParameter_->broadcasting_) { - /* need broadcast in runtime */ - stride = UP_DIV(outside_, thread_count_); - int out_count = MSMIN(stride, outside_ - stride * task_id); - if (out_count <= 0) { - return RET_OK; - } - int out_thread_stride = stride * task_id; - if (data_type_ == kDataTypeFloat) { - error_code = BroadcastRun(reinterpret_cast(input0_ptr_), reinterpret_cast(input1_ptr_), - reinterpret_cast(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride); - } else { - error_code = BroadcastRun(reinterpret_cast(input0_ptr_), reinterpret_cast(input1_ptr_), - reinterpret_cast(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride); - } - return error_code; + if (CanBatchScalar()) { + return BatchScalarCalc(task_id); } - - if (arithmetic_opt_run_ != nullptr) { + int error_code = RET_OK; + if ((arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) && + (arithmetic_opt_run_ != nullptr && arithmetic_opt_run_int_ != nullptr)) { /* run opt function * one of input is scalar */ if (arithmeticParameter_->in_elements_num0_ == 1) { @@ -413,11 +314,24 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) { reinterpret_cast(input0_ptr_) + stride * task_id, reinterpret_cast(input1_ptr_), reinterpret_cast(out_tensors_[0]->data_c()) + stride * task_id, count, arithmeticParameter_); } + } + return error_code; + } + if (arithmeticParameter_->broadcasting_) { + /* need broadcast in runtime */ + stride = UP_DIV(outside_, thread_count_); + int out_count = MSMIN(stride, outside_ - stride * task_id); + if (out_count <= 0) { + return RET_OK; + } + int out_thread_stride = stride * task_id; + if (data_type_ == kDataTypeFloat) { + error_code = BroadcastRun(reinterpret_cast(input0_ptr_), reinterpret_cast(input1_ptr_), + reinterpret_cast(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride); } else { - MS_LOG(ERROR) << "Arithmetic opt run: at least one of inputs is scalar"; - return RET_ERROR; + error_code = BroadcastRun(reinterpret_cast(input0_ptr_), reinterpret_cast(input1_ptr_), + reinterpret_cast(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride); } - return error_code; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h index 5c6ebfb8a7..c5c0842c53 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.h @@ -52,6 +52,16 @@ class ArithmeticCPUKernel : public LiteKernel { const ArithmeticParameter *param); typedef int (*ArithmeticBoolRun)(const bool *input0, const bool *input1, bool *output, const int element_size); + typedef struct { + int primitive_type_; + int activation_type_; + ArithmeticRun func_; + ArithmeticIntRun int_func_; + ArithmeticBoolRun bool_func_; + ArithmeticOptRun opt_func_; + ArithmeticOptIntRun opt_int_func_; + } ARITHMETIC_FUNC_INFO_FP32; + public: ArithmeticCPUKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs, const lite::InnerContext *ctx, @@ -75,6 +85,8 @@ class ArithmeticCPUKernel : public LiteKernel { void FreeTmpPtr(); int InitBroadCastCase(); void InitParamInRunTime(); + bool CanBatchScalar(); + int BatchScalarCalc(int task_id); protected: bool input0_broadcast_ = false;