diff --git a/mindspore/lite/nnacl/fp32/arithmetic_fp32.c b/mindspore/lite/nnacl/fp32/arithmetic_fp32.c index d888be181e..770815589d 100644 --- a/mindspore/lite/nnacl/fp32/arithmetic_fp32.c +++ b/mindspore/lite/nnacl/fp32/arithmetic_fp32.c @@ -850,6 +850,48 @@ int BroadcastFloorMod(const float *input0, const float *input1, float *tile_inpu return ElementFloorMod(tile_input0, tile_input1, output, element_size); } +int ElementMod(const float *input0, const float *input1, float *output, const int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = fmod(input0[i], input1[i]); + } + return NNACL_OK; +} + +int ElementModInt(const int *input0, const int *input1, int *output, const int element_size) { + for (int i = 0; i < element_size; i++) { + output[i] = fmod(input0[i], input1[i]); + } + return NNACL_OK; +} + +int ElementOptMod(const float *input0, const float *input1, float *output, const int element_size, + const ArithmeticParameter *param) { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < element_size; index++) { + output[index] = fmod(input0[0], input1[index]); + } + } else { + for (int index = 0; index < element_size; index++) { + output[index] = fmod(input0[index], input1[0]); + } + } + return NNACL_OK; +} + +int ElementOptModInt(const int *input0, const int *input1, int *output, const int element_size, + const ArithmeticParameter *param) { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < element_size; index++) { + output[index] = fmod(input0[0], input1[index]); + } + } else { + for (int index = 0; index < element_size; index++) { + output[index] = fmod(input0[index], input1[0]); + } + } + return NNACL_OK; +} + int ElementFloorDiv(const float *input0, const float *input1, float *output, const int element_size) { for (int i = 0; i < element_size; i++) { output[i] = floorf(input0[i] / input1[i]); diff --git a/mindspore/lite/nnacl/fp32/arithmetic_fp32.h b/mindspore/lite/nnacl/fp32/arithmetic_fp32.h index f3261f0ddb..e2d1ac6a28 100644 --- a/mindspore/lite/nnacl/fp32/arithmetic_fp32.h +++ b/mindspore/lite/nnacl/fp32/arithmetic_fp32.h @@ -118,6 +118,13 @@ int ElementFloorModInt(const int *input0, const int *input1, int *output, const int BroadcastFloorMod(const float *input0, const float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, ArithmeticParameter *param); +int ElementMod(const float *input0, const float *input1, float *output, const int element_size); +int ElementModInt(const int *input0, const int *input1, int *output, const int element_size); +int ElementOptMod(const float *input0, const float *input1, float *output, const int element_size, + const ArithmeticParameter *param); +int ElementOptModInt(const int *input0, const int *input1, int *output, const int element_size, + const ArithmeticParameter *param); + int ElementSquaredDifference(const float *input0, const float *input1, float *output, const int element_size); int BroadcastSquaredDifference(const float *input0, const float *input1, float *tile_input0, float *tile_input1, float *output, int element_size, ArithmeticParameter *param); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc index 5bdf8a123a..6db337e0dd 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic_fp32.cc @@ -209,6 +209,10 @@ void ArithmeticCPUKernel::InitRunFunction() { arithmetic_run_ = ElementFloorMod; arithmetic_run_int_ = ElementFloorModInt; break; + case PrimitiveType_Mod: + arithmetic_run_ = ElementMod; + arithmetic_run_int_ = ElementModInt; + break; case PrimitiveType_SquaredDifference: arithmetic_run_ = ElementSquaredDifference; break; @@ -302,6 +306,11 @@ void ArithmeticCPUKernel::InitOptRunFunction() { break; } break; + case PrimitiveType_Mod: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptMod; + arithmetic_opt_run_int_ = ElementOptModInt; + break; default: arithmetic_opt_run_ = nullptr; arithmetic_opt_run_int_ = nullptr; @@ -534,16 +543,15 @@ kernel::LiteKernel *CpuArithmeticFp32KernelCreator(const std::vector