[MSLITE] support mod op

pull/10129/head
ling 4 years ago
parent 4ce11a930b
commit f54ec95bca

@ -850,6 +850,48 @@ int BroadcastFloorMod(const float *input0, const float *input1, float *tile_inpu
return ElementFloorMod(tile_input0, tile_input1, output, element_size);
}
int ElementMod(const float *input0, const float *input1, float *output, const int element_size) {
for (int i = 0; i < element_size; i++) {
output[i] = fmod(input0[i], input1[i]);
}
return NNACL_OK;
}
int ElementModInt(const int *input0, const int *input1, int *output, const int element_size) {
for (int i = 0; i < element_size; i++) {
output[i] = fmod(input0[i], input1[i]);
}
return NNACL_OK;
}
int ElementOptMod(const float *input0, const float *input1, float *output, const int element_size,
const ArithmeticParameter *param) {
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < element_size; index++) {
output[index] = fmod(input0[0], input1[index]);
}
} else {
for (int index = 0; index < element_size; index++) {
output[index] = fmod(input0[index], input1[0]);
}
}
return NNACL_OK;
}
int ElementOptModInt(const int *input0, const int *input1, int *output, const int element_size,
const ArithmeticParameter *param) {
if (param->in_elements_num0_ == 1) {
for (int index = 0; index < element_size; index++) {
output[index] = fmod(input0[0], input1[index]);
}
} else {
for (int index = 0; index < element_size; index++) {
output[index] = fmod(input0[index], input1[0]);
}
}
return NNACL_OK;
}
int ElementFloorDiv(const float *input0, const float *input1, float *output, const int element_size) {
for (int i = 0; i < element_size; i++) {
output[i] = floorf(input0[i] / input1[i]);

@ -118,6 +118,13 @@ int ElementFloorModInt(const int *input0, const int *input1, int *output, const
int BroadcastFloorMod(const float *input0, const float *input1, float *tile_input0, float *tile_input1, float *output,
int element_size, ArithmeticParameter *param);
int ElementMod(const float *input0, const float *input1, float *output, const int element_size);
int ElementModInt(const int *input0, const int *input1, int *output, const int element_size);
int ElementOptMod(const float *input0, const float *input1, float *output, const int element_size,
const ArithmeticParameter *param);
int ElementOptModInt(const int *input0, const int *input1, int *output, const int element_size,
const ArithmeticParameter *param);
int ElementSquaredDifference(const float *input0, const float *input1, float *output, const int element_size);
int BroadcastSquaredDifference(const float *input0, const float *input1, float *tile_input0, float *tile_input1,
float *output, int element_size, ArithmeticParameter *param);

@ -209,6 +209,10 @@ void ArithmeticCPUKernel::InitRunFunction() {
arithmetic_run_ = ElementFloorMod;
arithmetic_run_int_ = ElementFloorModInt;
break;
case PrimitiveType_Mod:
arithmetic_run_ = ElementMod;
arithmetic_run_int_ = ElementModInt;
break;
case PrimitiveType_SquaredDifference:
arithmetic_run_ = ElementSquaredDifference;
break;
@ -302,6 +306,11 @@ void ArithmeticCPUKernel::InitOptRunFunction() {
break;
}
break;
case PrimitiveType_Mod:
arithmeticParameter_->broadcasting_ = false;
arithmetic_opt_run_ = ElementOptMod;
arithmetic_opt_run_int_ = ElementOptModInt;
break;
default:
arithmetic_opt_run_ = nullptr;
arithmetic_opt_run_int_ = nullptr;
@ -534,16 +543,15 @@ kernel::LiteKernel *CpuArithmeticFp32KernelCreator(const std::vector<lite::Tenso
}
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Mul, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_Mul, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Mul, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Add, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_Add, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Add, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Sub, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Div, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_RealDiv, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Mod, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_Mod, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalAnd, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_LogicalAnd, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_LogicalOr, CpuArithmeticFp32KernelCreator)
@ -551,8 +559,6 @@ REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Maximum, CpuArithmeticFp32Ker
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_Minimum, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FloorDiv, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_FloorMod, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_FloorDiv, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt, PrimitiveType_FloorMod, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_FloorDiv, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeInt32, PrimitiveType_FloorMod, CpuArithmeticFp32KernelCreator)
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_SquaredDifference, CpuArithmeticFp32KernelCreator)

@ -35,6 +35,7 @@ using mindspore::schema::PrimitiveType_LogicalAnd;
using mindspore::schema::PrimitiveType_LogicalOr;
using mindspore::schema::PrimitiveType_Maximum;
using mindspore::schema::PrimitiveType_Minimum;
using mindspore::schema::PrimitiveType_Mod;
using mindspore::schema::PrimitiveType_Mul;
using mindspore::schema::PrimitiveType_NotEqual;
using mindspore::schema::PrimitiveType_RealDiv;

Loading…
Cancel
Save