@@ -48,14 +48,6 @@ using mindspore::schema::PrimitiveType_Sub;
 namespace mindspore::kernel {
 
 void ArithmeticFP16CPUKernel::FreeTmpBuffer() {
-  if (tile_data0_ != nullptr) {
-    free(tile_data0_);
-    tile_data0_ = nullptr;
-  }
-  if (tile_data1_ != nullptr) {
-    free(tile_data1_);
-    tile_data1_ = nullptr;
-  }
   if (input0_fp16_ != nullptr) {
     context_->allocator->Free(input0_fp16_);
     input0_fp16_ = nullptr;
@@ -70,7 +62,7 @@ void ArithmeticFP16CPUKernel::FreeTmpBuffer() {
   }
 }
 
-ArithmeticFP16CPUKernel::~ArithmeticFP16CPUKernel() { FreeTmpBuffer(); }
+ArithmeticFP16CPUKernel::~ArithmeticFP16CPUKernel() {}
 
 int ArithmeticFP16CPUKernel::Init() {
   switch (op_parameter_->type_) {
@@ -162,7 +154,6 @@ int ArithmeticFP16CPUKernel::Init() {
 }
 
 int ArithmeticFP16CPUKernel::ReSize() {
-  FreeTmpBuffer();
   arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum();
   arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
   arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
@@ -286,7 +277,7 @@ int ArithmeticFP16CPUKernel::ReSize() {
 }
 
 int ArithmeticFP16CPUKernel::BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim,
                                           int out_count, int out_thread_stride) {
   if (dim > break_pos_) {
     int error_code =
       arithmetic_run_(input0 + out_thread_stride, input1 + out_thread_stride, output + out_thread_stride, out_count);
@@ -303,8 +294,8 @@ int ArithmeticFP16CPUKernel::BroadcastRun(float16_t *input0, float16_t *input1,
     int pos1_ = arithmeticParameter_->in_shape1_[dim] == 1 ? 0 : i;
     int error_code =
       BroadcastRun(input0 + pos0_ * arithmeticParameter_->in_strides0_[dim],
                    input1 + pos1_ * arithmeticParameter_->in_strides1_[dim],
                    output + i * arithmeticParameter_->out_strides_[dim], dim + 1, out_count, out_thread_stride);
     if (error_code != RET_OK) {
       return RET_ERROR;
     }
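
The two BroadcastRun hunks above are shown for context: the method walks the output shape one dimension at a time, keeps a broadcast input (shape 1 in the current dimension) at offset 0 while the other input advances by its stride, and hands the remaining data to arithmetic_run_ as one flat span once dim passes break_pos_. A condensed sketch of that recursion follows; plain float arrays and a BroadcastCtx struct stand in for the kernel's float16_t buffers and ArithmeticParameter fields, and the per-thread offsets (out_count, out_thread_stride) are folded into a single flat count:

    #include <cstdio>

    // Condensed sketch of the BroadcastRun recursion shown above. These
    // names are illustrative stand-ins, not the MindSpore types.
    struct BroadcastCtx {
      const int *in_shape0, *in_shape1, *out_shape;        // per-dimension extents
      const int *in_strides0, *in_strides1, *out_strides;  // per-dimension strides
      int break_pos;  // past this dim the remaining elements form one flat run
      void (*op)(const float *, const float *, float *, int);
    };

    int BroadcastRun(const BroadcastCtx &c, const float *in0, const float *in1, float *out, int dim, int count) {
      if (dim > c.break_pos) {
        c.op(in0, in1, out, count);  // flat element-wise run, like arithmetic_run_
        return 0;
      }
      for (int i = 0; i < c.out_shape[dim]; ++i) {
        // An input broadcast along this dim (shape 1) stays at offset 0.
        int pos0 = c.in_shape0[dim] == 1 ? 0 : i;
        int pos1 = c.in_shape1[dim] == 1 ? 0 : i;
        int rc = BroadcastRun(c, in0 + pos0 * c.in_strides0[dim], in1 + pos1 * c.in_strides1[dim],
                              out + i * c.out_strides[dim], dim + 1, count);
        if (rc != 0) {
          return rc;
        }
      }
      return 0;
    }

    int main() {
      // Add a [3] vector onto each row of a [2][3] matrix (in0 broadcast on dim 0).
      const float a[3] = {1, 2, 3};
      const float b[6] = {10, 20, 30, 40, 50, 60};
      float out[6] = {};
      const int shape0[] = {1, 3}, shape1[] = {2, 3}, oshape[] = {2, 3};
      const int strides0[] = {0, 1}, strides1[] = {3, 1}, ostrides[] = {3, 1};
      BroadcastCtx c{shape0, shape1, oshape, strides0, strides1, ostrides, 0,
                     [](const float *x, const float *y, float *z, int n) {
                       for (int i = 0; i < n; ++i) z[i] = x[i] + y[i];
                     }};
      if (BroadcastRun(c, a, b, out, 0, 3) == 0) {
        for (float v : out) std::printf("%g ", v);  // 11 22 33 41 52 63
      }
      std::printf("\n");
      return 0;
    }

Recursing only down to break_pos_ keeps the per-element work inside the innermost flat run, so the recursion cost scales with the number of broadcast slices rather than with the total element count.
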
@@ -327,7 +318,6 @@ int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) {
 
   if (arithmetic_run_ == nullptr) {
     MS_LOG(ERROR) << "arithmetic_run function is nullptr!";
-    FreeTmpBuffer();
     return RET_ERROR;
   }
 
@@ -383,8 +373,7 @@ int ArithmeticFP16CPUKernel::Run() {
   arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum();
   arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum();
   if (out_tensors_[0]->data_type() == kNumberTypeFloat32 || out_tensors_[0]->data_type() == kNumberTypeFloat) {
-    output_fp16_ = reinterpret_cast<float16_t *>(
-      context_->allocator->Malloc(arithmeticParameter_->out_elements_num_ * sizeof(float16_t)));
+    output_fp16_ = reinterpret_cast<float16_t *>(malloc(arithmeticParameter_->out_elements_num_ * sizeof(float16_t)));
     if (output_fp16_ == nullptr) {
       MS_LOG(ERROR) << "malloc data fail!";
       FreeTmpBuffer();
@@ -392,8 +381,7 @@ int ArithmeticFP16CPUKernel::Run() {
     }
   }
   if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) {
-    input0_fp16_ = reinterpret_cast<float16_t *>(
-      context_->allocator->Malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
+    input0_fp16_ = reinterpret_cast<float16_t *>(malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
     if (input0_fp16_ == nullptr) {
       MS_LOG(ERROR) << "malloc data fail!";
       FreeTmpBuffer();
@@ -403,8 +391,7 @@ int ArithmeticFP16CPUKernel::Run() {
                                  arithmeticParameter_->in_elements_num0_);
   }
   if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) {
-    input1_fp16_ = reinterpret_cast<float16_t *>(
-      context_->allocator->Malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
+    input1_fp16_ = reinterpret_cast<float16_t *>(malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t)));
     if (input1_fp16_ == nullptr) {
       MS_LOG(ERROR) << "malloc data fail!";
       FreeTmpBuffer();
@@ -414,6 +401,7 @@ int ArithmeticFP16CPUKernel::Run() {
                                  arithmeticParameter_->in_elements_num1_);
   }
   ret = ParallelLaunch(THREAD_POOL_DEFAULT, ArithmeticsRun_Fp16, this, context_->thread_num_);
+  FreeTmpBuffer();
   return ret;
 }
 
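
Taken together, the Run() hunks switch the temporary fp16 staging buffers from context_->allocator->Malloc to plain malloc and add a FreeTmpBuffer() call after ParallelLaunch, now that the destructor, ReSize(), and the DoArithmetic() error path no longer release them. Two details worth noting: the input1_fp16_ buffer is sized with in_elements_num0_ on both sides of its hunk even though it is filled with in_elements_num1_ elements, which looks pre-existing rather than introduced here; and the unchanged FreeTmpBuffer() context in the first hunk still frees through context_->allocator->Free, so the matching Free-side change presumably lives outside the hunks shown. Below is a minimal standalone sketch of the lifecycle these hunks establish; Fp16Kernel and float16 are illustrative stand-ins, not MindSpore APIs, and each buffer is sized by its own element count:

    #include <cstdint>
    #include <cstdlib>

    using float16 = std::uint16_t;  // stand-in for float16_t when building off-ARM

    struct Fp16Kernel {
      float16 *input0_fp16_ = nullptr;
      float16 *input1_fp16_ = nullptr;
      float16 *output_fp16_ = nullptr;

      // Mirrors FreeTmpBuffer(): free(nullptr) is a no-op, so calling this
      // unconditionally at the end of every Run() is safe and idempotent.
      void FreeTmpBuffer() {
        std::free(input0_fp16_);  input0_fp16_ = nullptr;
        std::free(input1_fp16_);  input1_fp16_ = nullptr;
        std::free(output_fp16_);  output_fp16_ = nullptr;
      }

      // The buffers never outlive Run(), which is why the diff can leave
      // the destructor empty.
      ~Fp16Kernel() = default;

      int Run(std::size_t n0, std::size_t n1, std::size_t n_out) {
        input0_fp16_ = static_cast<float16 *>(std::malloc(n0 * sizeof(float16)));
        input1_fp16_ = static_cast<float16 *>(std::malloc(n1 * sizeof(float16)));
        output_fp16_ = static_cast<float16 *>(std::malloc(n_out * sizeof(float16)));
        if (input0_fp16_ == nullptr || input1_fp16_ == nullptr || output_fp16_ == nullptr) {
          FreeTmpBuffer();  // release whatever did get allocated, then fail
          return -1;
        }
        // ... convert fp32 inputs to fp16, run the arithmetic, convert back ...
        FreeTmpBuffer();  // same release point the new hunk adds after ParallelLaunch
        return 0;
      }
    };

    int main() {
      Fp16Kernel kernel;
      return kernel.Run(8, 8, 8);
    }
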
@@ -441,21 +429,21 @@ kernel::LiteKernel *CpuArithmeticFp16KernelCreator(const std::vector<lite::tenso
   return kernel;
 }
 
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Mul, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Add, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Sub, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Div, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_FloorMod, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_FloorDiv, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LogicalAnd, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LogicalOr, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Maximum, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Minimum, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_NotEqual, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Equal, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Less, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LessEqual, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Greater, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_GreaterEqual, CpuArithmeticFp16KernelCreator)
-// REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Eltwise, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Mul, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Add, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Sub, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Div, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_FloorMod, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_FloorDiv, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LogicalAnd, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LogicalOr, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Maximum, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Minimum, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_NotEqual, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Equal, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Less, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_LessEqual, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Greater, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_GreaterEqual, CpuArithmeticFp16KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeFloat16, PrimitiveType_Eltwise, CpuArithmeticFp16KernelCreator)
 }  // namespace mindspore::kernel
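
The final hunk is what actually turns the fp16 path on: REG_KERNEL registers a creator function keyed by architecture, data type, and primitive type, and these registrations had been commented out. A toy registry sketching the mechanism follows; the map layout, Registrar type, and primitive id are illustrative assumptions, not MindSpore Lite internals:

    #include <cstdio>
    #include <functional>
    #include <map>
    #include <string>
    #include <tuple>

    // Illustrative stand-ins for the parts of the registry key REG_KERNEL uses.
    enum Arch { kCPU };
    enum DataType { kNumberTypeFloat16, kNumberTypeFloat32 };
    using Creator = std::function<std::string()>;  // real creators build LiteKernel objects
    using Key = std::tuple<Arch, DataType, int /* primitive type */>;

    std::map<Key, Creator> g_registry;

    // A file-scope registrar object gives REG_KERNEL-style registration at load time.
    struct Registrar {
      Registrar(Arch arch, DataType dt, int prim, Creator c) {
        g_registry[{arch, dt, prim}] = std::move(c);
      }
    };

    constexpr int PrimitiveType_Mul = 1;  // illustrative id
    static Registrar mul_fp16(kCPU, kNumberTypeFloat16, PrimitiveType_Mul,
                              [] { return std::string("CpuArithmeticFp16Kernel"); });

    int main() {
      // Lookup succeeds only because the registration above is not commented out.
      auto it = g_registry.find({kCPU, kNumberTypeFloat16, PrimitiveType_Mul});
      std::printf("%s\n", it != g_registry.end() ? it->second().c_str() : "fallback");
      return 0;
    }

In this pattern, registration is a load-time side effect of file-scope objects, which is why commenting a line in or out is enough to toggle a kernel without editing a central dispatch table.
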