@@ -28,7 +28,21 @@ using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_ERROR;
 using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_Add;
+using mindspore::schema::PrimitiveType_Div;
+using mindspore::schema::PrimitiveType_Equal;
+using mindspore::schema::PrimitiveType_FloorDiv;
+using mindspore::schema::PrimitiveType_FloorMod;
+using mindspore::schema::PrimitiveType_Greater;
+using mindspore::schema::PrimitiveType_GreaterEqual;
+using mindspore::schema::PrimitiveType_Less;
+using mindspore::schema::PrimitiveType_LessEqual;
+using mindspore::schema::PrimitiveType_LogicalAnd;
+using mindspore::schema::PrimitiveType_LogicalOr;
+using mindspore::schema::PrimitiveType_Maximum;
+using mindspore::schema::PrimitiveType_Minimum;
 using mindspore::schema::PrimitiveType_Mul;
+using mindspore::schema::PrimitiveType_NotEqual;
+using mindspore::schema::PrimitiveType_SquaredDifference;
 using mindspore::schema::PrimitiveType_Sub;

 namespace mindspore::kernel {
@@ -97,7 +111,61 @@ int ArithmeticFP16CPUKernel::Init() {
           arithmetic_run_ = ElementSubFp16;
           break;
       }
       break;
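+    // Div is the only new case with fused-activation variants; the rest map one-to-one to an fp16 kernel.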
+    case PrimitiveType_Div:
+      switch (arithmeticParameter_->activation_type_) {
+        case schema::ActivationType_RELU:
+          arithmetic_run_ = ElementDivReluFp16;
+          break;
+        case schema::ActivationType_RELU6:
+          arithmetic_run_ = ElementDivRelu6Fp16;
+          break;
+        default:
+          arithmetic_run_ = ElementDivFp16;
+          break;
+      }
+      break;
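+    // Element-wise ops without activation fusion: each primitive maps straight to its fp16 kernel.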
+    case PrimitiveType_FloorMod:
+      arithmetic_run_ = ElementFloorModFp16;
+      break;
+    case PrimitiveType_FloorDiv:
+      arithmetic_run_ = ElementFloorDivFp16;
+      break;
+    case PrimitiveType_LogicalAnd:
+      arithmetic_run_ = ElementLogicalAndFp16;
+      break;
+    case PrimitiveType_LogicalOr:
+      arithmetic_run_ = ElementLogicalOrFp16;
+      break;
+    case PrimitiveType_SquaredDifference:
+      arithmetic_run_ = ElementSquaredDifferenceFp16;
+      break;
+    case PrimitiveType_Maximum:
+      arithmetic_run_ = ElementMaximumFp16;
+      break;
+    case PrimitiveType_Minimum:
+      arithmetic_run_ = ElementMinimumFp16;
+      break;
+    case PrimitiveType_NotEqual:
+      arithmetic_run_ = ElementNotEqualFp16;
+      break;
+    case PrimitiveType_Equal:
+      arithmetic_run_ = ElementEqualFp16;
+      break;
+    case PrimitiveType_Less:
+      arithmetic_run_ = ElementLessFp16;
+      break;
+    case PrimitiveType_LessEqual:
+      arithmetic_run_ = ElementLessEqualFp16;
+      break;
+    case PrimitiveType_Greater:
+      arithmetic_run_ = ElementGreaterFp16;
+      break;
+    case PrimitiveType_GreaterEqual:
+      arithmetic_run_ = ElementGreaterEqualFp16;
+      break;
     default:
       MS_LOG(ERROR) << "Error Operator type " << op_parameter_->type_;
       arithmetic_run_ = nullptr;
@@ -115,8 +183,9 @@ int ArithmeticFP16CPUKernel::ReSize() {
   arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
   arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
   if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
-    input0_fp16_ = reinterpret_cast<float16_t *>(context_->allocator->Malloc(
-      arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
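+    // fp32 input: grab an fp16 scratch buffer from the allocator; the data is converted below before the kernel runs.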
+    input0_fp16_ = reinterpret_cast<float16_t *>(
+      context_->allocator->Malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
     if (input0_fp16_ == nullptr) {
       MS_LOG(ERROR) << "malloc data fail!";
       return RET_ERROR;
@@ -125,8 +194,8 @@ int ArithmeticFP16CPUKernel::ReSize() {
                      arithmeticParameter_->in_elements_num0_);
   }
   if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) {
-    input1_fp16_ = reinterpret_cast<float16_t *>(context_->allocator->Malloc(
-      arithmeticParameter_->in_elements_num1_ * sizeof(float16_t)));
+    input1_fp16_ = reinterpret_cast<float16_t *>(
+      context_->allocator->Malloc(arithmeticParameter_->in_elements_num1_ * sizeof(float16_t)));
     if (input1_fp16_ == nullptr) {
       MS_LOG(ERROR) << "malloc data fail!";
       return RET_ERROR;
@@ -135,8 +204,8 @@ int ArithmeticFP16CPUKernel::ReSize() {
                      arithmeticParameter_->in_elements_num1_);
   }
   if (out_tensors_[0]->data_type() == kNumberTypeFloat32 || out_tensors_[0]->data_type() == kNumberTypeFloat) {
-    output_fp16_ = reinterpret_cast<float16_t *>(context_->allocator->Malloc(
-      arithmeticParameter_->out_elements_num_ * sizeof(float16_t)));
+    output_fp16_ = reinterpret_cast<float16_t *>(
+      context_->allocator->Malloc(arithmeticParameter_->out_elements_num_ * sizeof(float16_t)));
     if (output_fp16_ == nullptr) {
       MS_LOG(ERROR) << "malloc data fail!";
       return RET_ERROR;
@@ -197,22 +266,23 @@ int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) {
   int error_code = RET_OK;
   if (arithmeticParameter_->broadcasting_) {
-    error_code = arithmetic_run_(tile_data0_ + thread_stride, tile_data1_ + thread_stride,
-                                 output_data + thread_stride, count);
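+    // Broadcast path: both inputs were pre-tiled to the output shape, so each thread just offsets by its stride.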
+    error_code =
+      arithmetic_run_(tile_data0_ + thread_stride, tile_data1_ + thread_stride, output_data + thread_stride, count);
   } else if (arithmetic_opt_run_ != nullptr) {
     if (arithmeticParameter_->in_elements_num0_ == 1) {
-      error_code = arithmetic_opt_run_(input0_data, input1_data1 + thread_stride, output_data + thread_stride,
-                                       count, arithmeticParameter_);
+      error_code = arithmetic_opt_run_(input0_data, input1_data1 + thread_stride, output_data + thread_stride, count,
+                                       arithmeticParameter_);
     } else if (arithmeticParameter_->in_elements_num1_ == 1) {
-      error_code = arithmetic_opt_run_(input0_data + thread_stride, input1_data1, output_data + thread_stride,
-                                       count, arithmeticParameter_);
+      error_code = arithmetic_opt_run_(input0_data + thread_stride, input1_data1, output_data + thread_stride, count,
+                                       arithmeticParameter_);
     } else {
       error_code = arithmetic_opt_run_(input0_data + thread_stride, input1_data1 + thread_stride,
                                        output_data + thread_stride, count, arithmeticParameter_);
     }
   } else {
-    error_code = arithmetic_run_(input0_data + thread_stride, input1_data1 + thread_stride,
-                                 output_data + thread_stride, count);
+    error_code =
+      arithmetic_run_(input0_data + thread_stride, input1_data1 + thread_stride, output_data + thread_stride, count);
   }
   if (error_code != RET_OK) {
     return RET_ERROR;
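
Note on the pattern above: Init() binds one element kernel to the arithmetic_run_ function
pointer per node, and DoArithmetic(task_id) then runs that kernel over each worker's slice.
A minimal standalone sketch of the same dispatch-and-slice idea follows; it uses plain float
and a serial loop in place of float16_t and the thread pool, and ElementAddF is an
illustrative stand-in, not a MindSpore API.

#include <iostream>

// Same shape as the kernel's arithmetic_run_: two inputs, one output, element count.
using ElementRun = int (*)(const float *in0, const float *in1, float *out, int size);

static int ElementAddF(const float *in0, const float *in1, float *out, int size) {
  for (int i = 0; i < size; ++i) out[i] = in0[i] + in1[i];
  return 0;  // stands in for RET_OK
}

int main() {
  ElementRun arithmetic_run = ElementAddF;  // "Init()": pick the kernel once from the op type
  float in0[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  float in1[8] = {8, 7, 6, 5, 4, 3, 2, 1};
  float out[8] = {0};
  const int thread_num = 2;
  const int stride = 8 / thread_num;  // 8 divides evenly here; the real kernel rounds up
  for (int task_id = 0; task_id < thread_num; ++task_id) {  // "DoArithmetic(task_id)" per worker
    const int offset = task_id * stride;
    if (arithmetic_run(in0 + offset, in1 + offset, out + offset, stride) != 0) return 1;
  }
  std::cout << out[0] << " " << out[7] << std::endl;  // prints: 9 9
  return 0;
}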