diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc
index 373ebf7879..35334f0054 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc
@@ -29,6 +29,7 @@
 using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_Add;
 using mindspore::schema::PrimitiveType_Div;
+using mindspore::schema::PrimitiveType_Eltwise;
 using mindspore::schema::PrimitiveType_Equal;
 using mindspore::schema::PrimitiveType_FloorDiv;
 using mindspore::schema::PrimitiveType_FloorMod;
@@ -172,8 +173,6 @@ int ArithmeticFP16CPUKernel::ReSize() {
       MS_LOG(ERROR) << "malloc data fail!";
       return RET_ERROR;
     }
-    Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[0]->Data()), input0_fp16_,
-                     arithmeticParameter_->in_elements_num0_);
   }
   if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) {
     input1_fp16_ = reinterpret_cast<float16_t *>(
@@ -182,8 +181,6 @@ int ArithmeticFP16CPUKernel::ReSize() {
       MS_LOG(ERROR) << "malloc data fail!";
       return RET_ERROR;
     }
-    Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[1]->Data()), input1_fp16_,
-                     arithmeticParameter_->in_elements_num1_);
   }
   if (out_tensors_[0]->data_type() == kNumberTypeFloat32 || out_tensors_[0]->data_type() == kNumberTypeFloat) {
     output_fp16_ = reinterpret_cast<float16_t *>(
@@ -297,15 +294,33 @@ int ArithmeticFP16CPUKernel::ReSize() {
   }
   if (arithmeticParameter_->broadcasting_) {
-    auto tile_size = arithmeticParameter_->out_elements_num_ * sizeof(float16_t);
-    tile_data0_ = reinterpret_cast<float16_t *>(malloc(tile_size));
-    tile_data1_ = reinterpret_cast<float16_t *>(malloc(tile_size));
-    if (tile_data0_ == nullptr || tile_data1_ == nullptr) {
-      MS_LOG(ERROR) << "malloc tile data fail!";
-      return RET_ERROR;
+    outside_ = 1;
+    for (int i = arithmeticParameter_->ndim_ - 1; i >= 0; --i) {
+      if (arithmeticParameter_->in_shape0_[i] != arithmeticParameter_->in_shape1_[i]) {
+        break_pos_ = i;
+        break;
+      }
+      outside_ *= arithmeticParameter_->out_shape_[i];
     }
+    ComputeStrides(arithmeticParameter_->in_shape0_, arithmeticParameter_->in_strides0_, arithmeticParameter_->ndim_);
+    ComputeStrides(arithmeticParameter_->in_shape1_, arithmeticParameter_->in_strides1_, arithmeticParameter_->ndim_);
+    ComputeStrides(arithmeticParameter_->out_shape_, arithmeticParameter_->out_strides_, arithmeticParameter_->ndim_);
   }
+  return RET_OK;
+}
+
+int ArithmeticFP16CPUKernel::broadcast_run_(float16_t *input0, float16_t *input1, float16_t *output, int dim) {
+  if (dim > break_pos_) {
+    return arithmetic_run_(input0 + out_thread_stride_, input1 + out_thread_stride_, output + out_thread_stride_,
+                           out_count_);
+  }
+  for (int i = 0; i < arithmeticParameter_->out_shape_[dim]; ++i) {
+    int pos0_ = arithmeticParameter_->in_shape0_[0] == 1 ? 0 : i;
+    int pos1_ = arithmeticParameter_->in_shape1_[0] == 1 ? 0 : i;
+    return broadcast_run_(input0 + pos0_ * arithmeticParameter_->in_strides0_[dim],
+                          input1 + pos1_ * arithmeticParameter_->in_strides1_[dim],
+                          output + i * arithmeticParameter_->out_strides_[dim], dim + 1);
+  }
   return RET_OK;
 }
@@ -329,8 +344,10 @@ int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) {
   int error_code = RET_OK;
   if (arithmeticParameter_->broadcasting_) {
-    error_code =
-      arithmetic_run_(tile_data0_ + thread_stride, tile_data1_ + thread_stride, output_data + thread_stride, count);
+    stride = UP_DIV(outside_, context_->thread_num_);
+    out_count_ = MSMIN(stride, outside_ - stride * task_id);
+    out_thread_stride_ = stride * task_id;
+    error_code = broadcast_run_(input0_data, input1_data1, output_data, 0);
   } else if (arithmetic_opt_run_ != nullptr) {
     if (arithmeticParameter_->in_elements_num0_ == 1) {
       error_code = arithmetic_opt_run_(input0_data, input1_data1 + thread_stride, output_data + thread_stride, count,
@@ -373,13 +390,15 @@ int ArithmeticFP16CPUKernel::Run() {
     return ret;
   }

-  if (arithmeticParameter_->broadcasting_) {
-    auto input_data0 = reinterpret_cast<float16_t *>(in_tensors_[0]->Data());
-    auto input_data1 = reinterpret_cast<float16_t *>(in_tensors_[1]->Data());
-    float16_t *input0 = input0_fp16_ == nullptr ? input_data0 : input0_fp16_;
-    float16_t *input1 = input1_fp16_ == nullptr ? input_data1 : input1_fp16_;
-    TileDimensionsFp16(input0, input1, tile_data0_, tile_data1_, arithmeticParameter_);
+  if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
+    Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[0]->Data()), input0_fp16_,
+                     arithmeticParameter_->in_elements_num0_);
   }
+  if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) {
+    Float32ToFloat16(reinterpret_cast<float *>(in_tensors_[1]->Data()), input1_fp16_,
+                     arithmeticParameter_->in_elements_num1_);
+  }
+
   ret = LiteBackendParallelLaunch(ArithmeticsRun, this, context_->thread_num_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Arithmetic function fail!ret: " << ret;
@@ -428,4 +447,5 @@ REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Less, CpuArithmeticFp16KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LessEqual, CpuArithmeticFp16KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Greater, CpuArithmeticFp16KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_GreaterEqual, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Eltwise, CpuArithmeticFp16KernelCreator)
 }  // namespace mindspore::kernel
diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h
index c978f24971..0a8a9f67b2 100644
--- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h
+++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.h
@@ -30,20 +30,25 @@ class ArithmeticFP16CPUKernel : public LiteKernel {
  public:
   ArithmeticFP16CPUKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
-                           const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
-                           const mindspore::lite::PrimitiveC *primitive)
+                          const std::vector<lite::tensor::Tensor *> &outputs, const lite::Context *ctx,
+                          const mindspore::lite::PrimitiveC *primitive)
       : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
-      arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(parameter);
-  }
+    arithmeticParameter_ = reinterpret_cast<ArithmeticParameter *>(parameter);
+  }
   ~ArithmeticFP16CPUKernel() override;

   int Init() override;
   int ReSize() override;
   int Run() override;
   int DoArithmetic(int task_id);
+  int broadcast_run_(float16_t *input0, float16_t *input1, float16_t *output, int dim);

  private:
   void FreeTmpBuffer();
+  int break_pos_;
+  int outside_;
+  int out_thread_stride_;
+  int out_count_;
   float16_t *tile_data0_ = nullptr;
   float16_t *tile_data1_ = nullptr;
   float16_t *input0_fp16_ = nullptr;
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.c b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.c
index a801e93621..35f8be3907 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.c
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.c
@@ -18,33 +18,6 @@
 #include 
 #include "nnacl/arithmetic_common.h"

-void TileOneDimensionFp16(float16_t *inData, float16_t *outData, int dim, size_t ndim, int *inShape, int *inStrides,
-                          int *outStrides, int *multiple) {
-  int srcDimSize = inShape[dim];
-  if (dim == ndim - 1) {
-    for (int i = 0; i < multiple[dim]; i++) {
-      memcpy(outData, inData, srcDimSize * sizeof(float16_t));
-      outData += srcDimSize;
-    }
-    return;
-  }
-  for (size_t i = 0; i < srcDimSize; i++) {
-    for (size_t j = 0; j < multiple[dim]; j++) {
-      TileOneDimensionFp16(inData + inStrides[dim] * i, outData + outStrides[dim] * (i + j * srcDimSize), dim + 1, ndim,
-                           inShape, inStrides, outStrides, multiple);
-    }
-  }
-}
-
-void TileDimensionsFp16(float16_t *data0, float16_t *data1, float16_t *tile_data0, float16_t *tile_data1,
-                        ArithmeticParameter *param) {
-  CalcMultiplesAndStrides(param);
-  TileOneDimensionFp16(data0, tile_data0, 0, param->ndim_, param->in_shape0_, param->in_strides0_, param->out_strides_,
-                       param->multiples0_);
-  TileOneDimensionFp16(data1, tile_data1, 0, param->ndim_, param->in_shape1_, param->in_strides1_, param->out_strides_,
-                       param->multiples1_);
-}
-
 int ElementMulFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) {
   int block_mod = element_size % C8NUM;
   int block_c8 = element_size - block_mod;
diff --git a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.h b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.h
index c3369519a6..5bffd41e5c 100644
--- a/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.h
+++ b/mindspore/lite/src/runtime/kernel/arm/nnacl/fp16/arithmetic_fp16.h
@@ -111,8 +111,6 @@ int ElementLessEqual(float16_t *input0, float16_t *input1, float16_t *output, in
 int ElementGreaterFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);
 int ElementGreaterEqualFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size);

-void TileDimensionsFp16(float16_t *data0, float16_t *data1, float16_t *tile_data0, float16_t *tile_data1,
-                        ArithmeticParameter *param);
 #ifdef __cplusplus
 }
 #endif
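
Note on the broadcasting rewrite: instead of pre-tiling both inputs with TileDimensionsFp16 into out_elements_num_-sized buffers, ReSize now records the strides of both input shapes and of the output shape, and broadcast_run_ walks the leading output dimensions recursively, reusing index 0 on size-1 input dimensions; once all remaining trailing dimensions match (dim > break_pos_), it falls back to a flat element-wise call on the contiguous inner block. The sketch below is a minimal, self-contained illustration of that stride-based scheme under assumed names (BroadcastAdd, the local ComputeStrides mirroring nnacl's helper); float stands in for float16_t, addition stands in for arithmetic_run_, and the per-thread split through out_thread_stride_/out_count_ is omitted. It is not the kernel code itself.

// Sketch of stride-based broadcasting (hypothetical names, plain C++).
#include <cstdio>
#include <vector>

// Row-major strides for a shape, in the spirit of nnacl's ComputeStrides().
static void ComputeStrides(const std::vector<int> &shape, std::vector<int> &strides) {
  strides.assign(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
    strides[i] = strides[i + 1] * shape[i + 1];
  }
}

// Recurse over the leading (possibly mismatched) dimensions; past break_pos the
// remaining inner block is contiguous in both inputs and is processed flat.
static void BroadcastAdd(const float *in0, const float *in1, float *out, int dim, int break_pos,
                         const std::vector<int> &shape0, const std::vector<int> &shape1,
                         const std::vector<int> &out_shape, const std::vector<int> &strides0,
                         const std::vector<int> &strides1, const std::vector<int> &out_strides, int inner_count) {
  if (dim > break_pos) {
    for (int i = 0; i < inner_count; ++i) out[i] = in0[i] + in1[i];
    return;
  }
  for (int i = 0; i < out_shape[dim]; ++i) {
    int pos0 = shape0[dim] == 1 ? 0 : i;  // broadcast: reuse index 0 on size-1 dims
    int pos1 = shape1[dim] == 1 ? 0 : i;
    BroadcastAdd(in0 + pos0 * strides0[dim], in1 + pos1 * strides1[dim], out + i * out_strides[dim], dim + 1,
                 break_pos, shape0, shape1, out_shape, strides0, strides1, out_strides, inner_count);
  }
}

int main() {
  // Example: {2,1,3} + {1,4,3} -> {2,4,3}; the trailing dim matches, so break_pos = 1.
  std::vector<int> shape0 = {2, 1, 3}, shape1 = {1, 4, 3}, out_shape = {2, 4, 3};
  int break_pos = -1;  // -1 would mean fully matching shapes (one contiguous block)
  int inner_count = 1;
  for (int i = static_cast<int>(out_shape.size()) - 1; i >= 0; --i) {
    if (shape0[i] != shape1[i]) {
      break_pos = i;
      break;
    }
    inner_count *= out_shape[i];
  }
  std::vector<int> s0, s1, so;
  ComputeStrides(shape0, s0);
  ComputeStrides(shape1, s1);
  ComputeStrides(out_shape, so);
  std::vector<float> a(6), b(12), c(24);
  for (int i = 0; i < 6; ++i) a[i] = i;
  for (int i = 0; i < 12; ++i) b[i] = 100 + i;
  BroadcastAdd(a.data(), b.data(), c.data(), 0, break_pos, shape0, shape1, out_shape, s0, s1, so, inner_count);
  printf("c[0]=%g c[23]=%g\n", c[0], c[23]);  // expected: 100 and 116
  return 0;
}

The practical effect visible in the patch is that the two out_elements_num_-sized tile buffers no longer need to be allocated in ReSize; the inputs are indexed in place through their own strides.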