!4201 modify CPU op arithmetic fp16

Merge pull request !4201 from 陶云浩/test
Branch: pull/4201/MERGE
mindspore-ci-bot committed 5 years ago via Gitee
Commit 5613116caf

Summary of the change (from the diff below): the tile_data0_/tile_data1_ staging buffers are deleted, FreeTmpBuffer() is no longer called from the destructor, ReSize(), or the DoArithmetic() error path, the fp16 temporaries switch from context_->allocator->Malloc to plain malloc and are now freed at the end of Run(), and the fp16 REG_KERNEL registrations are re-enabled.

@@ -48,14 +48,6 @@ using mindspore::schema::PrimitiveType_Sub;
 namespace mindspore::kernel {
 void ArithmeticFP16CPUKernel::FreeTmpBuffer() {
-  if (tile_data0_ != nullptr) {
-    free(tile_data0_);
-    tile_data0_ = nullptr;
-  }
-  if (tile_data1_ != nullptr) {
-    free(tile_data1_);
-    tile_data1_ = nullptr;
-  }
   if (input0_fp16_ != nullptr) {
     context_->allocator->Free(input0_fp16_);
     input0_fp16_ = nullptr;
@@ -70,7 +62,7 @@ void ArithmeticFP16CPUKernel::FreeTmpBuffer() {
   }
 }
-ArithmeticFP16CPUKernel::~ArithmeticFP16CPUKernel() { FreeTmpBuffer(); }
+ArithmeticFP16CPUKernel::~ArithmeticFP16CPUKernel() {}
 int ArithmeticFP16CPUKernel::Init() {
   switch (op_parameter_->type_) {
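Review note (not part of the diff): with the tile_data staging buffers deleted above, broadcasting no longer materializes tiled copies, and the remaining fp16 temporaries are scoped to a single Run() call. That is why the destructor here, ReSize(), and the DoArithmetic() error path below all drop their FreeTmpBuffer() calls; cleanup now happens once, at the end of Run() (see the sketch after that hunk).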
@@ -162,7 +154,6 @@ int ArithmeticFP16CPUKernel::Init() {
 }
 int ArithmeticFP16CPUKernel::ReSize() {
-  FreeTmpBuffer();
   arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
   arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
   arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
@@ -286,7 +277,7 @@ int ArithmeticFP16CPUKernel::ReSize() {
 }
 int ArithmeticFP16CPUKernel::BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim,
                                           int out_count, int out_thread_stride) {
   if (dim > break_pos_) {
     int error_code =
       arithmetic_run_(input0 + out_thread_stride, input1 + out_thread_stride, output + out_thread_stride, out_count);
@@ -303,8 +294,8 @@ int ArithmeticFP16CPUKernel::BroadcastRun(float16_t *input0, float16_t *input1,
     int pos1_ = arithmeticParameter_->in_shape1_[dim] == 1 ? 0 : i;
     int error_code =
       BroadcastRun(input0 + pos0_ * arithmeticParameter_->in_strides0_[dim],
                    input1 + pos1_ * arithmeticParameter_->in_strides1_[dim],
                    output + i * arithmeticParameter_->out_strides_[dim], dim + 1, out_count, out_thread_stride);
     if (error_code != RET_OK) {
       return RET_ERROR;
     }
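Review note (context, unchanged by this PR): BroadcastRun recurses over the leading output dimensions, pinning the index to 0 in any dimension where an input has size 1, and falls through to the plain element-wise arithmetic_run_ once it passes break_pos_, where both inputs are contiguous. This recursion is what makes the deleted tile_data buffers unnecessary. A minimal standalone sketch of the same idea, with simplified names, addition as the op, and float standing in for float16_t (all assumptions, not the kernel's actual code):

#include <cstdio>

constexpr int kDims = 3;
// Hypothetical shapes: input0 broadcasts over dim 1, input1 over dim 0.
int in_shape0[kDims] = {2, 1, 4};
int in_shape1[kDims] = {1, 3, 4};
int out_shape[kDims] = {2, 3, 4};
int in_strides0[kDims], in_strides1[kDims], out_strides[kDims];
int break_pos = 2;  // dims >= break_pos are contiguous in both inputs

void ComputeStrides(const int *shape, int *strides) {
  int stride = 1;
  for (int d = kDims - 1; d >= 0; --d) {
    strides[d] = stride;
    stride *= shape[d];
  }
}

void ElementAdd(const float *a, const float *b, float *out, int n) {
  for (int i = 0; i < n; ++i) out[i] = a[i] + b[i];
}

// Pin a coordinate in each leading dimension; an input of size 1 always
// reuses index 0, which is exactly numpy-style broadcasting.
void BroadcastRun(const float *in0, const float *in1, float *out, int dim, int count) {
  if (dim >= break_pos) {
    ElementAdd(in0, in1, out, count);  // contiguous tail, one flat loop
    return;
  }
  for (int i = 0; i < out_shape[dim]; ++i) {
    int pos0 = in_shape0[dim] == 1 ? 0 : i;
    int pos1 = in_shape1[dim] == 1 ? 0 : i;
    BroadcastRun(in0 + pos0 * in_strides0[dim], in1 + pos1 * in_strides1[dim],
                 out + i * out_strides[dim], dim + 1, count);
  }
}

int main() {
  ComputeStrides(in_shape0, in_strides0);
  ComputeStrides(in_shape1, in_strides1);
  ComputeStrides(out_shape, out_strides);
  float a[8], b[12], out[24];
  for (int i = 0; i < 8; ++i) a[i] = static_cast<float>(i);
  for (int i = 0; i < 12; ++i) b[i] = 100.0f * i;
  BroadcastRun(a, b, out, 0, out_shape[kDims - 1]);
  printf("out[0]=%g out[23]=%g\n", out[0], out[23]);  // expect 0 and 1107
  return 0;
}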
@@ -327,7 +318,6 @@ int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) {
   if (arithmetic_run_ == nullptr) {
     MS_LOG(ERROR) << "arithmetic_run function is nullptr!";
-    FreeTmpBuffer();
     return RET_ERROR;
   }
@@ -383,8 +373,7 @@ int ArithmeticFP16CPUKernel::Run() {
   arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
   arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
   if (out_tensors_[0]->data_type() == kNumberTypeFloat32 || out_tensors_[0]->data_type() == kNumberTypeFloat) {
-    output_fp16_ = reinterpret_cast<float16_t *>(
-      context_->allocator->Malloc(arithmeticParameter_->out_elements_num_ * sizeof(float16_t)));
+    output_fp16_ = reinterpret_cast<float16_t *>(malloc(arithmeticParameter_->out_elements_num_ * sizeof(float16_t)));
     if (output_fp16_ == nullptr) {
       MS_LOG(ERROR) << "malloc data fail!";
       FreeTmpBuffer();
@@ -392,8 +381,7 @@ int ArithmeticFP16CPUKernel::Run() {
     }
   }
   if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
-    input0_fp16_ = reinterpret_cast<float16_t *>(
-      context_->allocator->Malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
+    input0_fp16_ = reinterpret_cast<float16_t *>(malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
     if (input0_fp16_ == nullptr) {
       MS_LOG(ERROR) << "malloc data fail!";
       FreeTmpBuffer();
@@ -403,8 +391,7 @@ int ArithmeticFP16CPUKernel::Run() {
                       arithmeticParameter_->in_elements_num0_);
   }
   if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) {
-    input1_fp16_ = reinterpret_cast<float16_t *>(
-      context_->allocator->Malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
+    input1_fp16_ = reinterpret_cast<float16_t *>(malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
     if (input1_fp16_ == nullptr) {
       MS_LOG(ERROR) << "malloc data fail!";
       FreeTmpBuffer();
@@ -414,6 +401,7 @@ int ArithmeticFP16CPUKernel::Run() {
                       arithmeticParameter_->in_elements_num1_);
   }
   ret = ParallelLaunch(THREAD_POOL_DEFAULT, ArithmeticsRun_Fp16, this, context_->thread_num_);
+  FreeTmpBuffer();
   return ret;
 }
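Review note (beyond the diff): after this change each fp16 staging buffer lives for exactly one Run() call: allocated with plain malloc when the corresponding tensor arrives as fp32, consumed by ParallelLaunch, and released by the newly added FreeTmpBuffer() before return. One caveat visible in the context lines: FreeTmpBuffer() still calls context_->allocator->Free on these pointers even though Run() now allocates them with malloc; the sketch below pairs malloc with free for consistency. Simplified, hypothetical names throughout, with uint16_t standing in for float16_t so it builds off-device:

#include <cstdint>
#include <cstdio>
#include <cstdlib>

typedef uint16_t float16_t;  // stand-in; the real code uses the ARM fp16 type

struct ArithmeticFp16Sketch {
  float16_t *input0_fp16_ = nullptr;  // staging buffer for an fp32 input

  void FreeTmpBuffer() {
    if (input0_fp16_ != nullptr) {
      free(input0_fp16_);  // paired with the malloc below
      input0_fp16_ = nullptr;
    }
  }

  int Run(const float *in0, int n, bool in0_is_fp32) {
    (void)in0;  // consumed by the conversion step in the real kernel
    if (in0_is_fp32) {
      input0_fp16_ = reinterpret_cast<float16_t *>(malloc(n * sizeof(float16_t)));
      if (input0_fp16_ == nullptr) {
        fprintf(stderr, "malloc data fail!\n");
        FreeTmpBuffer();
        return -1;  // RET_ERROR
      }
      // Float32ToFloat16(in0, input0_fp16_, n) would convert here.
    }
    int ret = 0;  // ParallelLaunch(...) would run the per-thread arithmetic here
    FreeTmpBuffer();  // buffers never outlive a single Run() call
    return ret;
  }
};

int main() {
  float data[4] = {1, 2, 3, 4};
  ArithmeticFp16Sketch k;
  return k.Run(data, 4, true);
}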
@@ -441,21 +429,21 @@ kernel::LiteKernel *CpuArithmeticFp16KernelCreator(const std::vector<lite::tenso
   return kernel;
 }
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Mul, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Add, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Sub, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Div, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_FloorMod, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_FloorDiv, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LogicalAnd, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LogicalOr, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Maximum, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Minimum, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_NotEqual, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Equal, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Less, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LessEqual, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Greater, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_GreaterEqual, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Eltwise, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Mul, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Add, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Sub, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Div, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_FloorMod, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_FloorDiv, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LogicalAnd, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LogicalOr, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Maximum, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Minimum, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_NotEqual, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Equal, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Less, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LessEqual, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Greater, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_GreaterEqual, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Eltwise, CpuArithmeticFp16KernelCreator)
 }  // namespace mindspore::kernel
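Review note: re-enabling these REG_KERNEL lines is what makes the fp16 kernels reachable; the macro registers CpuArithmeticFp16KernelCreator under each (kCPU, kNumberTypeFloat16, primitive) key so the scheduler can construct the kernel when it sees an fp16 tensor of that op. A hedged sketch of the static-registrar pattern such a macro typically relies on (simplified key and types, hypothetical names, not MindSpore's actual definitions):

#include <cstdio>
#include <functional>
#include <map>
#include <utility>

using KernelKey = std::pair<int, int>;          // (data type, primitive type), simplified
using Creator = std::function<const char *()>;  // real creators build LiteKernel objects

std::map<KernelKey, Creator> &Registry() {
  static std::map<KernelKey, Creator> registry;  // constructed on first use
  return registry;
}

// A static object whose constructor runs at load time, before main();
// this is how a REG_KERNEL-style macro gets its side effect.
struct KernelRegistrar {
  KernelRegistrar(int data_type, int prim_type, Creator creator) {
    Registry()[{data_type, prim_type}] = std::move(creator);
  }
};

#define REG_KERNEL_SKETCH(dt, pt, fn) static KernelRegistrar g_reg_##dt##_##pt(dt, pt, fn)

const char *CreateFp16AddKernel() { return "fp16 Add kernel"; }  // placeholder creator
REG_KERNEL_SKETCH(16, 1, CreateFp16AddKernel);  // e.g. (Float16, Add)

int main() {
  // Scheduler-side lookup by (dtype, op): commenting the REG_KERNEL line out,
  // as the old code did, makes this key simply not exist.
  printf("%s\n", Registry()[{16, 1}]());
  return 0;
}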

@@ -50,8 +50,6 @@ class ArithmeticFP16CPUKernel : public LiteKernel {
   int break_pos_;
   int out_thread_stride_;
   int out_count_;
-  float16_t *tile_data0_ = nullptr;
-  float16_t *tile_data1_ = nullptr;
   float16_t *input0_fp16_ = nullptr;
   float16_t *input1_fp16_ = nullptr;
   float16_t *output_fp16_ = nullptr;
