|
|
|
@ -56,111 +56,7 @@ class ArithmeticCPUKernel : public LiteKernel {
|
|
|
|
|
const mindspore::lite::PrimitiveC *primitive)
|
|
|
|
|
: LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {
|
|
|
|
|
arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(parameter);
|
|
|
|
|
switch (parameter->type_) {
|
|
|
|
|
case PrimitiveType_Mul:
|
|
|
|
|
switch (arithmeticParameter_->activation_type_) {
|
|
|
|
|
case schema::ActivationType_RELU:
|
|
|
|
|
arithmetic_run_ = ElementMulRelu;
|
|
|
|
|
arithmetic_run_int_ = ElementMulReluInt;
|
|
|
|
|
break;
|
|
|
|
|
case schema::ActivationType_RELU6:
|
|
|
|
|
arithmetic_run_ = ElementMulRelu6;
|
|
|
|
|
arithmetic_run_int_ = ElementMulRelu6Int;
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
arithmetic_run_ = ElementMul;
|
|
|
|
|
arithmetic_run_int_ = ElementMulInt;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
case PrimitiveType_Add:
|
|
|
|
|
switch (arithmeticParameter_->activation_type_) {
|
|
|
|
|
case schema::ActivationType_RELU:
|
|
|
|
|
arithmetic_run_ = ElementAddRelu;
|
|
|
|
|
break;
|
|
|
|
|
case schema::ActivationType_RELU6:
|
|
|
|
|
arithmetic_run_ = ElementAddRelu6;
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
arithmetic_run_ = ElementAdd;
|
|
|
|
|
arithmetic_run_int_ = ElementAddInt;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
case PrimitiveType_Sub:
|
|
|
|
|
switch (arithmeticParameter_->activation_type_) {
|
|
|
|
|
case schema::ActivationType_RELU:
|
|
|
|
|
arithmetic_run_ = ElementSubRelu;
|
|
|
|
|
break;
|
|
|
|
|
case schema::ActivationType_RELU6:
|
|
|
|
|
arithmetic_run_ = ElementSubRelu6;
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
arithmetic_run_ = ElementSub;
|
|
|
|
|
arithmetic_run_int_ = ElementSubInt;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
case PrimitiveType_Div:
|
|
|
|
|
case PrimitiveType_RealDiv:
|
|
|
|
|
switch (arithmeticParameter_->activation_type_) {
|
|
|
|
|
case schema::ActivationType_RELU:
|
|
|
|
|
arithmetic_run_ = ElementDivRelu;
|
|
|
|
|
break;
|
|
|
|
|
case schema::ActivationType_RELU6:
|
|
|
|
|
arithmetic_run_ = ElementDivRelu6;
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
arithmetic_run_ = ElementDiv;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
case PrimitiveType_LogicalAnd:
|
|
|
|
|
arithmetic_run_ = ElementLogicalAnd;
|
|
|
|
|
break;
|
|
|
|
|
case PrimitiveType_LogicalOr:
|
|
|
|
|
arithmetic_run_ = ElementLogicalOr;
|
|
|
|
|
break;
|
|
|
|
|
case PrimitiveType_Maximum:
|
|
|
|
|
arithmetic_run_ = ElementMaximum;
|
|
|
|
|
break;
|
|
|
|
|
case PrimitiveType_Minimum:
|
|
|
|
|
arithmetic_run_ = ElementMinimum;
|
|
|
|
|
break;
|
|
|
|
|
case PrimitiveType_FloorDiv:
|
|
|
|
|
arithmetic_run_ = ElementFloorDiv;
|
|
|
|
|
arithmetic_run_int_ = ElementFloorDivInt;
|
|
|
|
|
break;
|
|
|
|
|
case PrimitiveType_FloorMod:
|
|
|
|
|
arithmetic_run_ = ElementFloorMod;
|
|
|
|
|
arithmetic_run_int_ = ElementFloorModInt;
|
|
|
|
|
break;
|
|
|
|
|
case PrimitiveType_Equal:
|
|
|
|
|
arithmetic_run_ = ElementEqual;
|
|
|
|
|
break;
|
|
|
|
|
case PrimitiveType_NotEqual:
|
|
|
|
|
arithmetic_run_ = ElementNotEqual;
|
|
|
|
|
break;
|
|
|
|
|
case PrimitiveType_Less:
|
|
|
|
|
arithmetic_run_ = ElementLess;
|
|
|
|
|
break;
|
|
|
|
|
case PrimitiveType_LessEqual:
|
|
|
|
|
arithmetic_run_ = ElementLessEqual;
|
|
|
|
|
break;
|
|
|
|
|
case PrimitiveType_Greater:
|
|
|
|
|
arithmetic_run_ = ElementGreater;
|
|
|
|
|
break;
|
|
|
|
|
case PrimitiveType_GreaterEqual:
|
|
|
|
|
arithmetic_run_ = ElementGreaterEqual;
|
|
|
|
|
break;
|
|
|
|
|
case PrimitiveType_SquaredDifference:
|
|
|
|
|
arithmetic_run_ = ElementSquaredDifference;
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
MS_LOG(ERROR) << "Error Operator type " << parameter->type_;
|
|
|
|
|
arithmetic_run_ = nullptr;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
InitRunFunction();
|
|
|
|
|
}
|
|
|
|
|
~ArithmeticCPUKernel() override;
|
|
|
|
|
|
|
|
|
@ -171,6 +67,20 @@ class ArithmeticCPUKernel : public LiteKernel {
|
|
|
|
|
virtual int DoArithmetic(int task_id);
|
|
|
|
|
virtual int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride);
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
void InitRunFunction();
|
|
|
|
|
void InitOptRunFunction();
|
|
|
|
|
void InitParam();
|
|
|
|
|
void FreeTmpPtr();
|
|
|
|
|
int InitBroadCastCase();
|
|
|
|
|
void InitParamInRunTime();
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
bool input0_broadcast_ = false;
|
|
|
|
|
bool input1_broadcast_ = false;
|
|
|
|
|
void *input0_ptr_ = nullptr;
|
|
|
|
|
void *input1_ptr_ = nullptr;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
int break_pos_ = 0;
|
|
|
|
|
int outside_ = 0;
|
|
|
|
|