[MSLITE] fp32 add optimize

Branch: pull/9420/head
Author: ling (4 years ago)
Parent: 07dbe082c5
Commit: 09417ba444
@@ -153,8 +153,8 @@ int TransposeFp16CPUKernel::Run() {
     fp16_out_data_ = reinterpret_cast<float16_t *>(out_tensor->MutableData());
   }
-  in_shape_ = const_cast<int *>(in_tensor->shape().data());
-  out_shape_ = const_cast<int *>(out_tensor->shape().data());
+  memcpy(in_shape_, in_tensor->shape().data(), in_tensor->shape().size() * sizeof(int));
+  memcpy(out_shape_, out_tensor->shape().data(), out_tensor->shape().size() * sizeof(int));
   ret = ParallelLaunch(this->context_->thread_pool_, TransposeFp16Run, this, thread_h_num_);
   if (ret != RET_OK) {

@@ -48,8 +48,8 @@ class TransposeFp16CPUKernel : public LiteKernel {
   float *out_data_;
   float16_t *fp16_in_data_ = nullptr;
   float16_t *fp16_out_data_ = nullptr;
-  int *in_shape_;
-  int *out_shape_;
+  int in_shape_[8];
+  int out_shape_[8];
 };
 }  // namespace mindspore::kernel
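Note: the two fp16 transpose hunks above replace pointer members that aliased the tensors' own shape vectors (written through a const_cast) with kernel-owned int[8] arrays filled by memcpy. A minimal standalone sketch of the pattern, assuming shapes never exceed 8 dimensions; FakeTensor, kMaxShapeDims, and CaptureShapes are illustrative names, not MindSpore APIs:

#include <cstring>
#include <vector>

constexpr int kMaxShapeDims = 8;  // assumption: shapes never exceed 8 dims

struct FakeTensor {
  std::vector<int> shape;
};

class TransposeKernelSketch {
 public:
  void CaptureShapes(const FakeTensor &in, const FakeTensor &out) {
    // Copy into kernel-owned storage: the copy stays valid even if the
    // tensor's shape vector later reallocates, and nothing writes through
    // a const_cast'ed pointer into tensor-owned memory.
    std::memcpy(in_shape_, in.shape.data(), in.shape.size() * sizeof(int));
    std::memcpy(out_shape_, out.shape.data(), out.shape.size() * sizeof(int));
  }

 private:
  int in_shape_[kMaxShapeDims] = {0};
  int out_shape_[kMaxShapeDims] = {0};
};

Copying removes both hazards of the old code at the cost of a small fixed copy per Run().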

@@ -56,111 +56,7 @@ class ArithmeticCPUKernel : public LiteKernel {
                       const mindspore::lite::PrimitiveC *primitive)
       : LiteKernel(parameter, inputs, outputs, ctx, primitive), thread_count_(ctx->thread_num_) {
     arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(parameter);
-    switch (parameter->type_) {
-      case PrimitiveType_Mul:
-        switch (arithmeticParameter_->activation_type_) {
-          case schema::ActivationType_RELU:
-            arithmetic_run_ = ElementMulRelu;
-            arithmetic_run_int_ = ElementMulReluInt;
-            break;
-          case schema::ActivationType_RELU6:
-            arithmetic_run_ = ElementMulRelu6;
-            arithmetic_run_int_ = ElementMulRelu6Int;
-            break;
-          default:
-            arithmetic_run_ = ElementMul;
-            arithmetic_run_int_ = ElementMulInt;
-            break;
-        }
-        break;
-      case PrimitiveType_Add:
-        switch (arithmeticParameter_->activation_type_) {
-          case schema::ActivationType_RELU:
-            arithmetic_run_ = ElementAddRelu;
-            break;
-          case schema::ActivationType_RELU6:
-            arithmetic_run_ = ElementAddRelu6;
-            break;
-          default:
-            arithmetic_run_ = ElementAdd;
-            arithmetic_run_int_ = ElementAddInt;
-            break;
-        }
-        break;
-      case PrimitiveType_Sub:
-        switch (arithmeticParameter_->activation_type_) {
-          case schema::ActivationType_RELU:
-            arithmetic_run_ = ElementSubRelu;
-            break;
-          case schema::ActivationType_RELU6:
-            arithmetic_run_ = ElementSubRelu6;
-            break;
-          default:
-            arithmetic_run_ = ElementSub;
-            arithmetic_run_int_ = ElementSubInt;
-            break;
-        }
-        break;
-      case PrimitiveType_Div:
-      case PrimitiveType_RealDiv:
-        switch (arithmeticParameter_->activation_type_) {
-          case schema::ActivationType_RELU:
-            arithmetic_run_ = ElementDivRelu;
-            break;
-          case schema::ActivationType_RELU6:
-            arithmetic_run_ = ElementDivRelu6;
-            break;
-          default:
-            arithmetic_run_ = ElementDiv;
-            break;
-        }
-        break;
-      case PrimitiveType_LogicalAnd:
-        arithmetic_run_ = ElementLogicalAnd;
-        break;
-      case PrimitiveType_LogicalOr:
-        arithmetic_run_ = ElementLogicalOr;
-        break;
-      case PrimitiveType_Maximum:
-        arithmetic_run_ = ElementMaximum;
-        break;
-      case PrimitiveType_Minimum:
-        arithmetic_run_ = ElementMinimum;
-        break;
-      case PrimitiveType_FloorDiv:
-        arithmetic_run_ = ElementFloorDiv;
-        arithmetic_run_int_ = ElementFloorDivInt;
-        break;
-      case PrimitiveType_FloorMod:
-        arithmetic_run_ = ElementFloorMod;
-        arithmetic_run_int_ = ElementFloorModInt;
-        break;
-      case PrimitiveType_Equal:
-        arithmetic_run_ = ElementEqual;
-        break;
-      case PrimitiveType_NotEqual:
-        arithmetic_run_ = ElementNotEqual;
-        break;
-      case PrimitiveType_Less:
-        arithmetic_run_ = ElementLess;
-        break;
-      case PrimitiveType_LessEqual:
-        arithmetic_run_ = ElementLessEqual;
-        break;
-      case PrimitiveType_Greater:
-        arithmetic_run_ = ElementGreater;
-        break;
-      case PrimitiveType_GreaterEqual:
-        arithmetic_run_ = ElementGreaterEqual;
-        break;
-      case PrimitiveType_SquaredDifference:
-        arithmetic_run_ = ElementSquaredDifference;
-        break;
-      default:
-        MS_LOG(ERROR) << "Error Operator type " << parameter->type_;
-        arithmetic_run_ = nullptr;
-        break;
-    }
+    InitRunFunction();
   }
   ~ArithmeticCPUKernel() override;
@@ -171,6 +67,20 @@ class ArithmeticCPUKernel : public LiteKernel {
   virtual int DoArithmetic(int task_id);
   virtual int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride);
 
+ private:
+  void InitRunFunction();
+  void InitOptRunFunction();
+  void InitParam();
+  void FreeTmpPtr();
+  int InitBroadCastCase();
+  void InitParamInRunTime();
+
+ private:
+  bool input0_broadcast_ = false;
+  bool input1_broadcast_ = false;
+  void *input0_ptr_ = nullptr;
+  void *input1_ptr_ = nullptr;
+
  protected:
   int break_pos_ = 0;
   int outside_ = 0;
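Note: these two ArithmeticCPUKernel hunks move the operator/activation dispatch out of the constructor into the new private InitRunFunction() (with InitOptRunFunction() presumably covering the optimized fp32 add paths this commit targets). The shape of that dispatch, condensed to one op as a sketch; SketchOp, SketchAct, and the element kernels below are illustrative stand-ins, not the real nnacl types or signatures:

using ArithmeticRun = int (*)(const float *in0, const float *in1, float *out, int size);

enum class SketchOp { kAdd, kSub };
enum class SketchAct { kNone, kRelu, kRelu6 };

int ElementAddSketch(const float *a, const float *b, float *o, int n) {
  for (int i = 0; i < n; ++i) o[i] = a[i] + b[i];
  return 0;
}

int ElementAddReluSketch(const float *a, const float *b, float *o, int n) {
  for (int i = 0; i < n; ++i) {
    float v = a[i] + b[i];
    o[i] = v > 0.0f ? v : 0.0f;  // fused ReLU, no separate activation pass
  }
  return 0;
}

// Resolve the function pointer once at init time, so the per-element inner
// loops carry no op/activation branches.
ArithmeticRun SelectRun(SketchOp op, SketchAct act) {
  if (op == SketchOp::kAdd) {
    return act == SketchAct::kRelu ? ElementAddReluSketch : ElementAddSketch;
  }
  return nullptr;  // ops not covered by this sketch
}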

@@ -263,7 +263,6 @@ int MatmulCPUKernel::RunImpl(int task_id) {
   MS_ASSERT(cur_a_ptr_);
   MS_ASSERT(b);
   MS_ASSERT(c);
-  MS_ASSERT(bias);
   if (is_vector_a_) {
     MatVecMul(cur_a_ptr_, b, c, bias, ActType_No, params_->deep_, cur_oc);
   } else {
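Note: dropping MS_ASSERT(bias) is consistent with the bias input being optional for matmul, so a null bias should fall through to the kernel rather than fire an assert in debug builds. A sketch of a mat-vec product that tolerates a null bias; illustrative only, the real MatVecMul signature and data layout live in the nnacl sources:

// 'bias' may be null: treat a missing bias as zeros instead of asserting.
void MatVecMulSketch(const float *a, const float *b, float *c, const float *bias, int deep, int col) {
  for (int oc = 0; oc < col; ++oc) {
    float acc = (bias != nullptr) ? bias[oc] : 0.0f;
    for (int d = 0; d < deep; ++d) {
      acc += a[d] * b[oc * deep + d];  // assumes row-major (col x deep) weights
    }
    c[oc] = acc;
  }
}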

@@ -143,8 +143,8 @@ int TransposeInt8CPUKernel::Run() {
   in_ptr_ = reinterpret_cast<int8_t *>(in_tensor->data_c());
   out_ptr_ = reinterpret_cast<int8_t *>(out_tensor->data_c());
-  in_shape_ = in_dims.data();
-  out_shape_ = out_dims.data();
+  memcpy(in_shape_, in_dims.data(), in_dims.size() * sizeof(int));
+  memcpy(out_shape_, out_dims.data(), out_dims.size() * sizeof(int));
   int ret = MallocTmpBuf();
   if (ret != RET_OK) {
@@ -157,8 +157,6 @@ int TransposeInt8CPUKernel::Run() {
   }
   FreeTmpBuf();
-  in_shape_ = nullptr;
-  out_shape_ = nullptr;
   return ret;
 }

@@ -48,8 +48,6 @@ class TransposeInt8CPUKernel : public LiteKernel {
   TransposeParameter *transpose_param_;
   int8_t *in_ptr_ = nullptr;
   int8_t *out_ptr_ = nullptr;
-  int *in_shape_ = nullptr;
-  int *out_shape_ = nullptr;
   int *dim_size_ = nullptr;
   int *position_ = nullptr;
   bool extra_dims_ = false;
@@ -57,6 +55,8 @@ class TransposeInt8CPUKernel : public LiteKernel {
   int thread_h_stride_ = 0;
   int thread_h_num_ = 0;
   int num_unit_ = 0;
+  int in_shape_[8];
+  int out_shape_[8];
 };
 }  // namespace mindspore::kernel
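Note: this is the int8 counterpart of the fp16 transpose change sketched earlier. With in_shape_/out_shape_ as fixed member arrays, the kernel no longer borrows the dimension vectors' storage, so the in_shape_ = nullptr / out_shape_ = nullptr cleanup after FreeTmpBuf() has nothing to reset and is deleted along with the pointer members.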
