From 6665e537f60f16cb08d94f92fd19013126071f2e Mon Sep 17 00:00:00 2001 From: tao_yunhao Date: Mon, 24 Aug 2020 15:50:18 +0800 Subject: [PATCH] modify arm cpu fp16&fp32 op: Arithmetic --- mindspore/lite/nnacl/fp16/arithmetic_fp16.c | 378 +++++++++------ mindspore/lite/nnacl/fp32/arithmetic.c | 442 +++++++++++++++++- mindspore/lite/nnacl/fp32/arithmetic.h | 6 + .../kernel/arm/fp16/arithmetic_fp16.cc | 155 +++--- .../runtime/kernel/arm/fp16/arithmetic_fp16.h | 4 +- .../src/runtime/kernel/arm/fp32/arithmetic.cc | 135 ++++-- .../src/runtime/kernel/arm/fp32/arithmetic.h | 26 +- 7 files changed, 842 insertions(+), 304 deletions(-) diff --git a/mindspore/lite/nnacl/fp16/arithmetic_fp16.c b/mindspore/lite/nnacl/fp16/arithmetic_fp16.c index a801e93621..ef8ae6fd64 100644 --- a/mindspore/lite/nnacl/fp16/arithmetic_fp16.c +++ b/mindspore/lite/nnacl/fp16/arithmetic_fp16.c @@ -74,33 +74,48 @@ int ElementOptMulFp16(float16_t *input0, float16_t *input1, float16_t *output, i ArithmeticParameter *param) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; #ifdef ENABLE_NEON float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; - float16_t in0_opt = input0[0]; - float16_t in1_opt = input1[0]; #endif - for (int index = 0; index < block_c8; index += C8NUM) { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c8; index += C8NUM) { #ifdef ENABLE_NEON - float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); - float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); - float16x8_t vout = vmulq_f16(vin0, vin1); - vst1q_f16(output, vout); + float16x8_t vin0 = vin0_opt; + float16x8_t vin1 = vld1q_f16(input1); + float16x8_t vout = vmulq_f16(vin0, vin1); + vst1q_f16(output, vout); #else - for (int i = 0; i < C8NUM; ++i) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; - output[i] = in0 * in1; + for (int i = 0; i < C8NUM; ++i) { + output[i] = in0_opt * input1[i]; + } +#endif + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = in0_opt * input1[index]; } + } else { + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = vld1q_f16(input0); + float16x8_t vin1 = vin1_opt; + float16x8_t vout = vmulq_f16(vin0, vin1); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + output[i] = input0[i] * in1_opt; + } #endif - input0 += C8NUM; - input1 += C8NUM; - output += C8NUM; - } - for (int index = 0; index < block_mod; ++index) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; - output[index] = in0 * in1; + input0 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = input0[index] * in1_opt; + } } return NNACL_OK; @@ -113,7 +128,6 @@ int ElementMulReluFp16(float16_t *input0, float16_t *input1, float16_t *output, #ifdef ENABLE_NEON float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; #endif - for (int index = 0; index < block_c8; index += C8NUM) { #ifdef ENABLE_NEON float16x8_t vin0 = vld1q_f16(input0); @@ -143,39 +157,58 @@ int ElementOptMulReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu ArithmeticParameter *param) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; #ifdef ENABLE_NEON float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; - float16_t in0_opt = input0[0]; - float16_t in1_opt = input1[0]; float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; #endif - for (int index = 0; index < block_c8; index += C8NUM) { -#ifdef ENABLE_NEON - float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); - float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); - float16x8_t vout = vmulq_f16(vin0, vin1); - vout = vmaxq_f16(vout, zeros); - vst1q_f16(output, vout); + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = vin0_opt; + float16x8_t vin1 = vld1q_f16(input1); + float16x8_t vout = vmulq_f16(vin0, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output, vout); #else - float16_t res; - for (int i = 0; i < C8NUM; ++i) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; - res = in0 * in1; - output[i] = res > 0 ? res : 0; + float16_t res; + for (int i = 0; i < C8NUM; ++i) { + res = in0_opt * input1[i]; + output[i] = res > 0 ? res : 0; + } +#endif + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t res = in0_opt * input1[index]; + output[index] = res > 0 ? res : 0; } + } else { + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = vld1q_f16(input0); + float16x8_t vin1 = vin1_opt; + float16x8_t vout = vmulq_f16(vin0, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output, vout); +#else + float16_t res; + for (int i = 0; i < C8NUM; ++i) { + res = input0[i] * in1_opt; + output[i] = res > 0 ? res : 0; + } #endif - input0 += C8NUM; - input1 += C8NUM; - output += C8NUM; - } - for (int index = 0; index < block_mod; ++index) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; - float16_t res = in0 * in1; - output[index] = res > 0 ? res : 0; + input0 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t res = input0[index] * in1_opt; + output[index] = res > 0 ? res : 0; + } } return NNACL_OK; @@ -216,37 +249,52 @@ int ElementOptMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp ArithmeticParameter *param) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; #ifdef ENABLE_NEON float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; - float16_t in0_opt = input0[0]; - float16_t in1_opt = input1[0]; float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; #endif - - for (int index = 0; index < block_c8; index += C8NUM) { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c8; index += C8NUM) { #ifdef ENABLE_NEON - float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); - float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); - float16x8_t vout = vmulq_f16(vin0, vin1); - vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); - vst1q_f16(output, vout); + float16x8_t vin0 = vin0_opt; + float16x8_t vin1 = vld1q_f16(input1); + float16x8_t vout = vmulq_f16(vin0, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output, vout); #else - for (int i = 0; i < C8NUM; ++i) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; - output[i] = MSMIN(MSMAX(in0 * in1, 0), 6); + for (int i = 0; i < C8NUM; ++i) { + output[i] = MSMIN(MSMAX(in0_opt * input1[i], 0), 6); + } +#endif + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMIN(MSMAX(in0_opt * input1[index], 0), 6); } + } else { + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = vld1q_f16(input0); + float16x8_t vin1 = vin1_opt; + float16x8_t vout = vmulq_f16(vin0, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + output[i] = MSMIN(MSMAX(input0[i] * in1_opt, 0), 6); + } #endif - input0 += C8NUM; - input1 += C8NUM; - output += C8NUM; - } - for (int index = 0; index < block_mod; ++index) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; - output[index] = MSMIN(MSMAX(in0 * in1, 0), 6); + input0 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMIN(MSMAX(input0[index] * in1_opt, 0), 6); + } } return NNACL_OK; @@ -255,7 +303,6 @@ int ElementOptMulRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp int ElementAddFp16(float16_t *input0, float16_t *input1, float16_t *output, int element_size) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; - for (int index = 0; index < block_c8; index += C8NUM) { #ifdef ENABLE_NEON float16x8_t vin0 = vld1q_f16(input0); @@ -280,34 +327,50 @@ int ElementOptAddFp16(float16_t *input0, float16_t *input1, float16_t *output, i ArithmeticParameter *param) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; #ifdef ENABLE_NEON float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; - float16_t in0_opt = input0[0]; - float16_t in1_opt = input1[0]; #endif - for (int index = 0; index < block_c8; index += C8NUM) { + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c8; index += C8NUM) { #ifdef ENABLE_NEON - float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); - float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); - float16x8_t vout = vaddq_f16(vin0, vin1); - vst1q_f16(output, vout); + float16x8_t vin0 = vin0_opt; + float16x8_t vin1 = vld1q_f16(input1); + float16x8_t vout = vaddq_f16(vin0, vin1); + vst1q_f16(output, vout); #else - for (int i = 0; i < C8NUM; ++i) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; - output[i] = in0 + in1; + for (int i = 0; i < C8NUM; ++i) { + output[i] = in0_opt + input1[i]; + } +#endif + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = in0_opt + input1[index]; } + } else { + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = vld1q_f16(input0); + float16x8_t vin1 = vin1_opt; + float16x8_t vout = vaddq_f16(vin0, vin1); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + output[i] = input0[i] + in1_opt; + } #endif - input0 += C8NUM; - input1 += C8NUM; - output += C8NUM; - } - for (int index = 0; index < block_mod; ++index) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; - output[index] = in0 + in1; + input0 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = input0[index] + in1_opt; + } } + return NNACL_OK; } @@ -345,37 +408,54 @@ int ElementOptAddReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu ArithmeticParameter *param) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; #ifdef ENABLE_NEON float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; - float16_t in0_opt = input0[0]; - float16_t in1_opt = input1[0]; float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; #endif - for (int index = 0; index < block_c8; index += C8NUM) { -#ifdef ENABLE_NEON - float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); - float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); - float16x8_t vout = vaddq_f16(vin0, vin1); - vout = vmaxq_f16(vout, zeros); - vst1q_f16(output, vout); + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = vin0_opt; + float16x8_t vin1 = vld1q_f16(input1); + float16x8_t vout = vaddq_f16(vin0, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output, vout); #else - for (int i = 0; i < C8NUM; ++i) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; - output[i] = MSMAX(in0 + in1, 0); + for (int i = 0; i < C8NUM; ++i) { + output[i] = MSMAX(in0_opt + input1[i], 0); + } +#endif + input1 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t res = in0_opt + input1[index]; + output[index] = res > 0 ? res : 0; } + } else { + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = vld1q_f16(input0); + float16x8_t vin1 = vin1_opt; + float16x8_t vout = vaddq_f16(vin0, vin1); + vout = vmaxq_f16(vout, zeros); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + output[i] = MSMAX(input0[i] + in1_opt, 0); + } #endif - input0 += C8NUM; - input1 += C8NUM; - output += C8NUM; - } - for (int index = 0; index < block_mod; ++index) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; - float16_t res = in0 + in1; - output[index] = res > 0 ? res : 0; + input0 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + float16_t res = input0[index] + in1_opt; + output[index] = res > 0 ? res : 0; + } } return NNACL_OK; } @@ -415,39 +495,54 @@ int ElementOptAddRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp ArithmeticParameter *param) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; #ifdef ENABLE_NEON float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; - float16_t in0_opt = input0[0]; - float16_t in1_opt = input1[0]; float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; #endif - for (int index = 0; index < block_c8; index += C8NUM) { -#ifdef ENABLE_NEON - float16x8_t vin0 = param->in_elements_num0_ == 1 ? vin0_opt : vld1q_f16(input0); - float16x8_t vin1 = param->in_elements_num1_ == 1 ? vin1_opt : vld1q_f16(input1); - float16x8_t vout = vaddq_f16(vin0, vin1); - vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); - vst1q_f16(output, vout); + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = vin0_opt; + float16x8_t vin1 = vld1q_f16(input1); + float16x8_t vout = vaddq_f16(vin0, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output, vout); #else - for (int i = 0; i < C8NUM; ++i) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[i]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[i]; - output[i] = MSMIN(MSMAX(in0 + in1, 0), 6); + for (int i = 0; i < C8NUM; ++i) { + output[i] = MSMIN(MSMAX(in0_opt + input1[i], 0), 6); + } +#endif + input1 += C8NUM; + output += C8NUM; } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMIN(MSMAX(in0_opt + input1[index], 0), 6); + } + } else { + for (int index = 0; index < block_c8; index += C8NUM) { +#ifdef ENABLE_NEON + float16x8_t vin0 = vld1q_f16(input0); + float16x8_t vin1 = vin1_opt; + float16x8_t vout = vaddq_f16(vin0, vin1); + vout = vminq_f16(vmaxq_f16(vout, zeros), bounds); + vst1q_f16(output, vout); +#else + for (int i = 0; i < C8NUM; ++i) { + output[i] = MSMIN(MSMAX(input0[i] + in1_opt, 0), 6); + } #endif - input0 += C8NUM; - input1 += C8NUM; - output += C8NUM; - } - for (int index = 0; index < block_mod; ++index) { - float16_t in0 = param->in_elements_num0_ == 1 ? in0_opt : input0[index]; - float16_t in1 = param->in_elements_num1_ == 1 ? in1_opt : input1[index]; - output[index] = MSMIN(MSMAX(in0 + in1, 0), 6); + input0 += C8NUM; + output += C8NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMIN(MSMAX(input0[index] + in1_opt, 0), 6); + } } - return NNACL_OK; } @@ -479,11 +574,11 @@ int ElementOptSubFp16(float16_t *input0, float16_t *input1, float16_t *output, i ArithmeticParameter *param) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; #ifdef ENABLE_NEON float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; - float16_t in0_opt = input0[0]; - float16_t in1_opt = input1[0]; #endif for (int index = 0; index < block_c8; index += C8NUM) { #ifdef ENABLE_NEON @@ -542,11 +637,11 @@ int ElementOptSubReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu ArithmeticParameter *param) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; #ifdef ENABLE_NEON float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; - float16_t in0_opt = input0[0]; - float16_t in1_opt = input1[0]; float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; #endif for (int index = 0; index < block_c8; index += C8NUM) { @@ -609,11 +704,11 @@ int ElementOptSubRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp ArithmeticParameter *param) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; #ifdef ENABLE_NEON float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; - float16_t in0_opt = input0[0]; - float16_t in1_opt = input1[0]; float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; #endif @@ -680,11 +775,11 @@ int ElementOptDivFp16(float16_t *input0, float16_t *input1, float16_t *output, i ArithmeticParameter *param) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; #ifdef ENABLE_NEON float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; - float16_t in0_opt = input0[0]; - float16_t in1_opt = input1[0]; #endif for (int index = 0; index < block_c8; index += C8NUM) { if (param->in_elements_num1_ == 1) { @@ -765,12 +860,11 @@ int ElementOptDivReluFp16(float16_t *input0, float16_t *input1, float16_t *outpu ArithmeticParameter *param) { int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; - + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; #ifdef ENABLE_NEON float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; - float16_t in0_opt = input0[0]; - float16_t in1_opt = input1[0]; float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; #endif for (int index = 0; index < block_c8; index += C8NUM) { @@ -855,11 +949,11 @@ int ElementOptDivRelu6Fp16(float16_t *input0, float16_t *input1, float16_t *outp int block_mod = element_size % C8NUM; int block_c8 = element_size - block_mod; + float16_t in0_opt = input0[0]; + float16_t in1_opt = input1[0]; #ifdef ENABLE_NEON float16x8_t vin0_opt = {input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0], input0[0]}; float16x8_t vin1_opt = {input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0], input1[0]}; - float16_t in0_opt = input0[0]; - float16_t in1_opt = input1[0]; float16x8_t zeros = {0, 0, 0, 0, 0, 0, 0, 0}; float16x8_t bounds = {6, 6, 6, 6, 6, 6, 6, 6}; #endif diff --git a/mindspore/lite/nnacl/fp32/arithmetic.c b/mindspore/lite/nnacl/fp32/arithmetic.c index 0bfc8e6164..d08ba5c92c 100644 --- a/mindspore/lite/nnacl/fp32/arithmetic.c +++ b/mindspore/lite/nnacl/fp32/arithmetic.c @@ -20,55 +20,455 @@ #define ACCURACY_DATA 0.00000001 int ElementOptMul(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + float in0_opt = input0[0]; + float in1_opt = input1[0]; +#ifdef ENABLE_NEON + float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]}; + float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]}; +#endif + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vin0_opt; + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vmulq_f32(vin0, vin1); + vst1q_f32(output, vout); +#else + for (int i = 0; i < C4NUM; ++i) { + output[i] = in0_opt * input1[i]; + } +#endif + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = in0_opt * input1[index]; + } + } else { + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vin1_opt; + float32x4_t vout = vmulq_f32(vin0, vin1); + vst1q_f32(output, vout); +#else + for (int i = 0; i < C4NUM; ++i) { + output[i] = input0[i] * in1_opt; + } +#endif + input0 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = input0[index] * in1_opt; + } + } + + return NNACL_OK; +} +int ElementOptMulRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + float in0_opt = input0[0]; + float in1_opt = input1[0]; +#ifdef ENABLE_NEON + float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]}; + float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]}; + float32x4_t zeros = {0, 0, 0, 0}; +#endif + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vin0_opt; + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vmaxq_f32(vmulq_f32(vin0, vin1), zeros); + vst1q_f32(output, vout); +#else + for (int i = 0; i < C4NUM; ++i) { + output[i] = MSMAX(in0_opt * input1[i], 0); + } +#endif + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMAX(in0_opt * input1[index], 0); + } + } else { + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vin1_opt; + float32x4_t vout = vmaxq_f32(vmulq_f32(vin0, vin1), zeros); + vst1q_f32(output, vout); +#else + for (int i = 0; i < C4NUM; ++i) { + output[i] = MSMAX(input0[i] * in1_opt, 0); + } +#endif + input0 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMAX(input0[index] * in1_opt, 0); + } + } + + return NNACL_OK; +} +int ElementOptMulRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + float in0_opt = input0[0]; + float in1_opt = input1[0]; +#ifdef ENABLE_NEON + float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]}; + float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]}; + float32x4_t zeros = {0, 0, 0, 0}; + float32x4_t bounds = {6, 6, 6, 6}; +#endif if (param->in_elements_num0_ == 1) { - for (int i = 0; i < element_size; ++i) { - output[i] = input0[0] * input1[i]; + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vin0_opt; + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vminq_f32(vmaxq_f32(vmulq_f32(vin0, vin1), zeros), bounds); + vst1q_f32(output, vout); +#else + for (int i = 0; i < C4NUM; ++i) { + output[i] = MSMIN(MSMAX(in0_opt * input1[i], 0), 6); + } +#endif + input1 += C4NUM; + output += C4NUM; } - } else if (param->in_elements_num1_ == 1) { - for (int i = 0; i < element_size; ++i) { - output[i] = input0[i] * input1[0]; + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMIN(MSMAX(in0_opt * input1[index], 0), 6); } } else { - for (int i = 0; i < element_size; ++i) { - output[i] = input0[i] * input1[i]; + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vin1_opt; + float32x4_t vout = vminq_f32(vmaxq_f32(vmulq_f32(vin0, vin1), zeros), bounds); + vst1q_f32(output, vout); +#else + for (int i = 0; i < C4NUM; ++i) { + output[i] = MSMIN(MSMAX(input0[i] * in1_opt, 0), 6); + } +#endif + input0 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMIN(MSMAX(input0[index] * in1_opt, 0), 6); } } + return NNACL_OK; } int ElementOptSub(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + float in0_opt = input0[0]; + float in1_opt = input1[0]; +#ifdef ENABLE_NEON + float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]}; + float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]}; +#endif if (param->in_elements_num0_ == 1) { - for (int i = 0; i < element_size; ++i) { - output[i] = input0[0] - input1[i]; + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vin0_opt; + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vsubq_f32(vin0, vin1); + vst1q_f32(output, vout); +#else + for (int i = 0; i < C4NUM; ++i) { + output[i] = in0_opt - input1[i]; + } +#endif + input1 += C4NUM; + output += C4NUM; } - } else if (param->in_elements_num1_ == 1) { - for (int i = 0; i < element_size; ++i) { - output[i] = input0[i] - input1[0]; + for (int index = 0; index < block_mod; ++index) { + output[index] = in0_opt - input1[index]; } } else { - for (int i = 0; i < element_size; ++i) { - output[i] = input0[i] - input1[i]; + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vin1_opt; + float32x4_t vout = vsubq_f32(vin0, vin1); + vst1q_f32(output, vout); +#else + for (int i = 0; i < C4NUM; ++i) { + output[i] = input0[i] - in1_opt; + } +#endif + input0 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = input0[index] - in1_opt; } } return NNACL_OK; } +int ElementOptSubRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + float in0_opt = input0[0]; + float in1_opt = input1[0]; +#ifdef ENABLE_NEON + float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]}; + float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]}; + float32x4_t zeros = {0, 0, 0, 0}; +#endif + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vin0_opt; + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vmaxq_f32(vsubq_f32(vin0, vin1), zeros); + vst1q_f32(output, vout); +#else + for (int i = 0; i < C4NUM; ++i) { + output[i] = MSMAX(in0_opt - input1[i], 0); + } +#endif + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMAX(in0_opt - input1[index], 0); + } + } else { + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vin1_opt; + float32x4_t vout = vmaxq_f32(vsubq_f32(vin0, vin1), zeros); + vst1q_f32(output, vout); +#else + for (int i = 0; i < C4NUM; ++i) { + output[i] = MSMAX(input0[i] - in1_opt, 0); + } +#endif + input0 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMAX(input0[index] - in1_opt, 0); + } + } + + return NNACL_OK; +} +int ElementOptSubRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + float in0_opt = input0[0]; + float in1_opt = input1[0]; +#ifdef ENABLE_NEON + float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]}; + float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]}; + float32x4_t zeros = {0, 0, 0, 0}; + float32x4_t bounds = {6, 6, 6, 6}; +#endif + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vin0_opt; + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vminq_f32(vmaxq_f32(vsubq_f32(vin0, vin1), zeros), bounds); + vst1q_f32(output, vout); +#else + for (int i = 0; i < C4NUM; ++i) { + output[i] = MSMIN(MSMAX(in0_opt - input1[i], 0), 6); + } +#endif + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMIN(MSMAX(in0_opt - input1[index], 0), 6); + } + } else { + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vin1_opt; + float32x4_t vout = vminq_f32(vmaxq_f32(vsubq_f32(vin0, vin1), zeros), bounds); + vst1q_f32(output, vout); +#else + for (int i = 0; i < C4NUM; ++i) { + output[i] = MSMIN(MSMAX(input0[i] - in1_opt, 0), 6); + } +#endif + input0 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMIN(MSMAX(input0[index] - in1_opt, 0), 6); + } + } + + return NNACL_OK; +} int ElementOptAdd(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + float in0_opt = input0[0]; + float in1_opt = input1[0]; +#ifdef ENABLE_NEON + float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]}; + float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]}; +#endif if (param->in_elements_num0_ == 1) { - for (int i = 0; i < element_size; ++i) { - output[i] = input0[0] + input1[i]; + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vin0_opt; + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vaddq_f32(vin0, vin1); + vst1q_f32(output, vout); +#else + for (int i = 0; i < C4NUM; ++i) { + output[i] = in0_opt + input1[i]; + } +#endif + input1 += C4NUM; + output += C4NUM; } - } else if (param->in_elements_num1_ == 1) { - for (int i = 0; i < element_size; ++i) { - output[i] = input0[i] + input1[0]; + for (int index = 0; index < block_mod; ++index) { + output[index] = in0_opt + input1[index]; } } else { - for (int i = 0; i < element_size; ++i) { - output[i] = input0[i] + input1[i]; + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vin1_opt; + float32x4_t vout = vaddq_f32(vin0, vin1); + vst1q_f32(output, vout); +#else + for (int i = 0; i < C4NUM; ++i) { + output[i] = input0[i] + in1_opt; + } +#endif + input0 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = input0[index] + in1_opt; } } return NNACL_OK; } +int ElementOptAddRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + float in0_opt = input0[0]; + float in1_opt = input1[0]; +#ifdef ENABLE_NEON + float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]}; + float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]}; + float32x4_t zeros = {0, 0, 0, 0}; +#endif + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vin0_opt; + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vmaxq_f32(vaddq_f32(vin0, vin1), zeros); + vst1q_f32(output, vout); +#else + for (int i = 0; i < C4NUM; ++i) { + output[i] = MSMAX(in0_opt + input1[i], 0); + } +#endif + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMAX(in0_opt + input1[index], 0); + } + } else { + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vin1_opt; + float32x4_t vout = vmaxq_f32(vaddq_f32(vin0, vin1), zeros); + vst1q_f32(output, vout); +#else + for (int i = 0; i < C4NUM; ++i) { + output[i] = MSMAX(input0[i] + in1_opt, 0); + } +#endif + input0 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMAX(input0[index] + in1_opt, 0); + } + } + + return NNACL_OK; +} +int ElementOptAddRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param) { + int block_mod = element_size % C4NUM; + int block_c4 = element_size - block_mod; + float in0_opt = input0[0]; + float in1_opt = input1[0]; +#ifdef ENABLE_NEON + float32x4_t vin0_opt = {input0[0], input0[0], input0[0], input0[0]}; + float32x4_t vin1_opt = {input1[0], input1[0], input1[0], input1[0]}; + float32x4_t zeros = {0, 0, 0, 0}; + float32x4_t bounds = {6, 6, 6, 6}; +#endif + if (param->in_elements_num0_ == 1) { + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vin0_opt; + float32x4_t vin1 = vld1q_f32(input1); + float32x4_t vout = vminq_f32(vmaxq_f32(vaddq_f32(vin0, vin1), zeros), bounds); + vst1q_f32(output, vout); +#else + for (int i = 0; i < C4NUM; ++i) { + output[i] = MSMIN(MSMAX(in0_opt + input1[i], 0), 6); + } +#endif + input1 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMIN(MSMAX(in0_opt + input1[index], 0), 6); + } + } else { + for (int index = 0; index < block_c4; index += C4NUM) { +#ifdef ENABLE_NEON + float32x4_t vin0 = vld1q_f32(input0); + float32x4_t vin1 = vin1_opt; + float32x4_t vout = vminq_f32(vmaxq_f32(vaddq_f32(vin0, vin1), zeros), bounds); + vst1q_f32(output, vout); +#else + for (int i = 0; i < C4NUM; ++i) { + output[i] = MSMIN(MSMAX(input0[i] + in1_opt, 0), 6); + } +#endif + input0 += C4NUM; + output += C4NUM; + } + for (int index = 0; index < block_mod; ++index) { + output[index] = MSMIN(MSMAX(input0[index] + in1_opt, 0), 6); + } + } + + return NNACL_OK; +} int ElementMul(float *input0, float *input1, float *output, int element_size) { int block_mod = element_size % C4NUM; diff --git a/mindspore/lite/nnacl/fp32/arithmetic.h b/mindspore/lite/nnacl/fp32/arithmetic.h index 5d3303ca0b..ab0e0d0297 100644 --- a/mindspore/lite/nnacl/fp32/arithmetic.h +++ b/mindspore/lite/nnacl/fp32/arithmetic.h @@ -27,8 +27,14 @@ extern "C" { #endif int ElementOptAdd(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); +int ElementOptAddRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); +int ElementOptAddRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); int ElementOptSub(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); +int ElementOptSubRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); +int ElementOptSubRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); int ElementOptMul(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); +int ElementOptMulRelu(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); +int ElementOptMulRelu6(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); int ElementMul(float *input0, float *input1, float *output, int element_size); int ElementMulRelu(float *input0, float *input1, float *output, int element_size); int ElementMulRelu6(float *input0, float *input1, float *output, int element_size); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc index 3bb3554999..f7da12b275 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc @@ -162,6 +162,7 @@ int ArithmeticFP16CPUKernel::Init() { } int ArithmeticFP16CPUKernel::ReSize() { + FreeTmpBuffer(); arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum(); @@ -175,10 +176,10 @@ int ArithmeticFP16CPUKernel::ReSize() { arithmetic_opt_run_ = ElementOptMulReluFp16; break; case schema::ActivationType_RELU6: - arithmetic_opt_run_ = ElementOptDivRelu6Fp16; + arithmetic_opt_run_ = ElementOptMulRelu6Fp16; break; default: - arithmetic_opt_run_ = ElementOptDivFp16; + arithmetic_opt_run_ = ElementOptMulFp16; break; } break; @@ -267,20 +268,46 @@ int ArithmeticFP16CPUKernel::ReSize() { break; } } + + if (arithmeticParameter_->broadcasting_) { + outside_ = 1; + for (int i = arithmeticParameter_->ndim_ - 1; i >= 0; --i) { + if (arithmeticParameter_->in_shape0_[i] != arithmeticParameter_->in_shape1_[i]) { + break_pos_ = i; + break; + } + outside_ *= arithmeticParameter_->out_shape_[i]; + } + ComputeStrides(arithmeticParameter_->in_shape0_, arithmeticParameter_->in_strides0_, arithmeticParameter_->ndim_); + ComputeStrides(arithmeticParameter_->in_shape1_, arithmeticParameter_->in_strides1_, arithmeticParameter_->ndim_); + ComputeStrides(arithmeticParameter_->out_shape_, arithmeticParameter_->out_strides_, arithmeticParameter_->ndim_); + } return RET_OK; } -int ArithmeticFP16CPUKernel::broadcast_run_(float16_t *input0, float16_t *input1, float16_t *output, int dim) { +int ArithmeticFP16CPUKernel::BroadcastRun(float16_t *input0, float16_t *input1, float16_t *output, int dim, + int out_count, int out_thread_stride) { if (dim > break_pos_) { - return arithmetic_run_(input0 + out_thread_stride_, input1 + out_thread_stride_, output + out_thread_stride_, - out_count_); + int error_code = + arithmetic_run_(input0 + out_thread_stride, input1 + out_thread_stride, output + out_thread_stride, out_count); + if (output_fp16_ != nullptr) { + auto output_fp32 = reinterpret_cast(out_tensors_[0]->Data()); + int bias = output - output_fp16_; + output_fp32 += bias; + Float16ToFloat32(output + out_thread_stride, output_fp32 + out_thread_stride, out_count); + } + return error_code; } for (int i = 0; i < arithmeticParameter_->out_shape_[dim]; ++i) { - int pos0_ = arithmeticParameter_->in_shape0_[0] == 1 ? 0 : i; - int pos1_ = arithmeticParameter_->in_shape1_[0] == 1 ? 0 : i; - return broadcast_run_(input0 + pos0_ * arithmeticParameter_->in_strides0_[dim], - input1 + pos1_ * arithmeticParameter_->in_strides1_[dim], - output + i * arithmeticParameter_->out_strides_[dim], dim + 1); + int pos0_ = arithmeticParameter_->in_shape0_[dim] == 1 ? 0 : i; + int pos1_ = arithmeticParameter_->in_shape1_[dim] == 1 ? 0 : i; + int error_code = + BroadcastRun(input0 + pos0_ * arithmeticParameter_->in_strides0_[dim], + input1 + pos1_ * arithmeticParameter_->in_strides1_[dim], + output + i * arithmeticParameter_->out_strides_[dim], dim + 1, out_count, out_thread_stride); + if (error_code != RET_OK) { + return RET_ERROR; + } } return RET_OK; } @@ -300,13 +327,16 @@ int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) { if (arithmetic_run_ == nullptr) { MS_LOG(ERROR) << "arithmetic_run function is nullptr!"; + FreeTmpBuffer(); return RET_ERROR; } int error_code = RET_OK; if (arithmeticParameter_->broadcasting_) { - error_code = - arithmetic_run_(tile_data0_ + thread_stride, tile_data1_ + thread_stride, output_data + thread_stride, count); + stride = UP_DIV(outside_, context_->thread_num_); + out_count_ = MSMIN(stride, outside_ - stride * task_id); + out_thread_stride_ = stride * task_id; + error_code = BroadcastRun(input0_data, input1_data1, output_data, 0, out_count_, out_thread_stride_); } else if (arithmetic_opt_run_ != nullptr) { if (arithmeticParameter_->in_elements_num0_ == 1) { error_code = arithmetic_opt_run_(input0_data, input1_data1 + thread_stride, output_data + thread_stride, count, @@ -323,17 +353,16 @@ int ArithmeticFP16CPUKernel::DoArithmetic(int task_id) { arithmetic_run_(input0_data + thread_stride, input1_data1 + thread_stride, output_data + thread_stride, count); } if (error_code != RET_OK) { - FreeTmpBuffer(); return RET_ERROR; } - if (output_fp16_ != nullptr) { + if (output_fp16_ != nullptr && !arithmeticParameter_->broadcasting_) { auto output_fp32 = reinterpret_cast(out_tensors_[0]->Data()); Float16ToFloat32(output_data + thread_stride, output_fp32 + thread_stride, count); } return RET_OK; } -static int ArithmeticsRun(int task_id, LiteParallelGroupEnv *penv, void *cdata) { +static int ArithmeticsRun_Fp16(int task_id, LiteParallelGroupEnv *penv, void *cdata) { auto arithmetic_kernel = reinterpret_cast(cdata); auto error_code = arithmetic_kernel->DoArithmetic(task_id); if (error_code != RET_OK) { @@ -353,24 +382,6 @@ int ArithmeticFP16CPUKernel::Run() { arithmeticParameter_->in_elements_num0_ = in_tensors_[0]->ElementsNum(); arithmeticParameter_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); arithmeticParameter_->out_elements_num_ = out_tensors_[0]->ElementsNum(); - if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) { - input0_fp16_ = reinterpret_cast( - context_->allocator->Malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t))); - if (input0_fp16_ == nullptr) { - MS_LOG(ERROR) << "malloc data fail!"; - FreeTmpBuffer(); - return RET_ERROR; - } - } - if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) { - input1_fp16_ = reinterpret_cast( - context_->allocator->Malloc(arithmeticParameter_->in_elements_num1_ * sizeof(float16_t))); - if (input0_fp16_ == nullptr) { - MS_LOG(ERROR) << "malloc data fail!"; - FreeTmpBuffer(); - return RET_ERROR; - } - } if (out_tensors_[0]->data_type() == kNumberTypeFloat32 || out_tensors_[0]->data_type() == kNumberTypeFloat) { output_fp16_ = reinterpret_cast( context_->allocator->Malloc(arithmeticParameter_->out_elements_num_ * sizeof(float16_t))); @@ -380,46 +391,30 @@ int ArithmeticFP16CPUKernel::Run() { return RET_ERROR; } } - if (in_tensors_[0]->data_type() == kNumberTypeFloat32 || in_tensors_[0]->data_type() == kNumberTypeFloat) { - Float32ToFloat16(reinterpret_cast(in_tensors_[0]->Data()), input0_fp16_, - arithmeticParameter_->in_elements_num0_); - } - if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) { - Float32ToFloat16(reinterpret_cast(in_tensors_[1]->Data()), input1_fp16_, - arithmeticParameter_->in_elements_num1_); - } - - if (arithmeticParameter_->broadcasting_) { - auto tile_size = arithmeticParameter_->out_elements_num_ * sizeof(float16_t); - tile_data0_ = reinterpret_cast(malloc(tile_size)); - if (tile_data0_ == nullptr) { + input0_fp16_ = reinterpret_cast( + context_->allocator->Malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t))); + if (input0_fp16_ == nullptr) { MS_LOG(ERROR) << "malloc data fail!"; FreeTmpBuffer(); return RET_ERROR; } - tile_data1_ = reinterpret_cast(malloc(tile_size)); - if (tile_data1_ == nullptr) { + Float32ToFloat16(reinterpret_cast(in_tensors_[0]->Data()), input0_fp16_, + arithmeticParameter_->in_elements_num0_); + } + if (in_tensors_[1]->data_type() == kNumberTypeFloat32 || in_tensors_[1]->data_type() == kNumberTypeFloat) { + input1_fp16_ = reinterpret_cast( + context_->allocator->Malloc(arithmeticParameter_->in_elements_num0_ * sizeof(float16_t))); + if (input1_fp16_ == nullptr) { MS_LOG(ERROR) << "malloc data fail!"; FreeTmpBuffer(); return RET_ERROR; } - auto input0 = reinterpret_cast(in_tensors_[0]->Data()); - auto input1 = reinterpret_cast(in_tensors_[1]->Data()); - - float16_t *input0_data = input0_fp16_ == nullptr ? input0 : input0_fp16_; - float16_t *input1_data1 = input1_fp16_ == nullptr ? input1 : input1_fp16_; - - TileDimensionsFp16(input0_data, input1_data1, tile_data0_, tile_data1_, arithmeticParameter_); - } - - ret = LiteBackendParallelLaunch(ArithmeticsRun, this, context_->thread_num_); - if (ret != RET_OK) { - MS_LOG(ERROR) << "Arithmetic function fail!ret: " << ret; - FreeTmpBuffer(); - return ret; + Float32ToFloat16(reinterpret_cast(in_tensors_[1]->Data()), input1_fp16_, + arithmeticParameter_->in_elements_num1_); } - return RET_OK; + ret = LiteBackendParallelLaunch(ArithmeticsRun_Fp16, this, context_->thread_num_); + return ret; } kernel::LiteKernel *CpuArithmeticFp16KernelCreator(const std::vector &inputs, @@ -446,21 +441,21 @@ kernel::LiteKernel *CpuArithmeticFp16KernelCreator(const std::vectorout_elements_num_ = out_tensors_[0]->ElementsNum(); if (arithmeticParameter_->in_elements_num0_ == 1 || arithmeticParameter_->in_elements_num1_ == 1) { - if (arithmeticParameter_->activation_type_ == schema::ActivationType_NO_ACTIVATION) { - switch (arithmeticParameter_->op_parameter_.type_) { - case PrimitiveType_Mul: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptMul; - break; - case PrimitiveType_Add: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptAdd; - break; - case PrimitiveType_Sub: - arithmeticParameter_->broadcasting_ = false; - arithmetic_opt_run_ = ElementOptSub; - break; - default: - break; - } + switch (arithmeticParameter_->op_parameter_.type_) { + case PrimitiveType_Mul: + switch (arithmeticParameter_->activation_type_) { + case schema::ActivationType_RELU: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptMulRelu; + break; + case schema::ActivationType_RELU6: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptMulRelu6; + break; + default: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptMul; + break; + } + break; + case PrimitiveType_Add: + switch (arithmeticParameter_->activation_type_) { + case schema::ActivationType_RELU: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptAddRelu; + break; + case schema::ActivationType_RELU6: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptAddRelu6; + break; + default: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptAdd; + break; + } + break; + case PrimitiveType_Sub: + switch (arithmeticParameter_->activation_type_) { + case schema::ActivationType_RELU: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptSubRelu; + break; + case schema::ActivationType_RELU6: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptSubRelu6; + break; + default: + arithmeticParameter_->broadcasting_ = false; + arithmetic_opt_run_ = ElementOptSub; + break; + } + break; + default: + break; + } + } + return RET_OK; +} + +int ArithmeticCPUKernel::BroadcastRun(float *input0, float *input1, float *output, int dim, int out_count, + int out_thread_stride) { + if (dim > break_pos_) { + return arithmetic_run_(input0 + out_thread_stride, input1 + out_thread_stride, output + out_thread_stride, + out_count); + } + for (int i = 0; i < arithmeticParameter_->out_shape_[dim]; ++i) { + int pos0_ = arithmeticParameter_->in_shape0_[dim] == 1 ? 0 : i; + int pos1_ = arithmeticParameter_->in_shape1_[dim] == 1 ? 0 : i; + int error_code = + BroadcastRun(input0 + pos0_ * arithmeticParameter_->in_strides0_[dim], + input1 + pos1_ * arithmeticParameter_->in_strides1_[dim], + output + i * arithmeticParameter_->out_strides_[dim], dim + 1, out_count, out_thread_stride); + if (error_code != RET_OK) { + return error_code; } } return RET_OK; @@ -81,8 +138,10 @@ int ArithmeticCPUKernel::DoArithmetic(int task_id) { int error_code = RET_OK; if (arithmeticParameter_->broadcasting_) { - error_code = arithmetic_run_(tile_data0_ + stride * task_id, tile_data1_ + stride * task_id, - output_data + stride * task_id, count); + stride = UP_DIV(outside_, thread_count_); + out_count_ = MSMIN(stride, outside_ - stride * task_id); + out_thread_stride_ = stride * task_id; + error_code = BroadcastRun(input0_data, input1_data1, output_data, 0, out_count_, out_thread_stride_); } else if (arithmetic_opt_run_ != nullptr) { if (arithmeticParameter_->in_elements_num0_ == 1) { error_code = arithmetic_opt_run_(input0_data, input1_data1 + stride * task_id, output_data + stride * task_id, @@ -120,31 +179,27 @@ int ArithmeticCPUKernel::Run() { MS_LOG(ERROR) << "Prepare fail!ret: " << ret; return ret; } - if (arithmeticParameter_->broadcasting_) { - auto input_data0 = reinterpret_cast(in_tensors_[0]->Data()); - auto input_data1 = reinterpret_cast(in_tensors_[1]->Data()); - auto length = arithmeticParameter_->out_elements_num_ * sizeof(float); - MS_ASSERT(context_->allocator != nullptr); - tile_data0_ = reinterpret_cast(context_->allocator->Malloc(length)); - tile_data1_ = reinterpret_cast(context_->allocator->Malloc(length)); - if (tile_data0_ == nullptr || tile_data1_ == nullptr) { - MS_LOG(ERROR) << "Memory allocation failed"; - context_->allocator->Free(tile_data0_); - context_->allocator->Free(tile_data1_); - return RET_ERROR; + outside_ = 1; + for (auto i = arithmeticParameter_->ndim_ - 1; i >= 0; --i) { + if (arithmeticParameter_->in_shape0_[i] != arithmeticParameter_->in_shape1_[i]) { + break_pos_ = i; + break; + } + outside_ *= arithmeticParameter_->out_shape_[i]; } - TileDimensions(input_data0, input_data1, tile_data0_, tile_data1_, arithmeticParameter_); + ComputeStrides(arithmeticParameter_->in_shape0_, arithmeticParameter_->in_strides0_, arithmeticParameter_->ndim_); + ComputeStrides(arithmeticParameter_->in_shape1_, arithmeticParameter_->in_strides1_, arithmeticParameter_->ndim_); + ComputeStrides(arithmeticParameter_->out_shape_, arithmeticParameter_->out_strides_, arithmeticParameter_->ndim_); } - ret = LiteBackendParallelLaunch(ArithmeticsRun, this, thread_count_); - if (arithmeticParameter_->broadcasting_) { - context_->allocator->Free(tile_data0_); - context_->allocator->Free(tile_data1_); - } - if (ret != RET_OK) { - MS_LOG(ERROR) << "Arithmetic function error error_code[" << ret << "]"; + + int error_code = LiteBackendParallelLaunch(ArithmeticsRun, this, thread_count_); + + if (error_code != RET_OK) { + MS_LOG(ERROR) << "Arithmetic function error error_code[" << error_code << "]"; + return RET_ERROR; } - return ret; + return RET_OK; } kernel::LiteKernel *CpuArithmeticFp32KernelCreator(const std::vector &inputs, diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h index f4b42a25a1..c55bf35bfa 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/arithmetic.h @@ -45,8 +45,6 @@ class ArithmeticCPUKernel : public LiteKernel { typedef int (*ArithmeticRun)(float *input0, float *input1, float *output, int element_size); typedef int (*ArithmeticOptRun)(float *input0, float *input1, float *output, int element_size, ArithmeticParameter *param); - typedef int (*ArithmeticBroadcastRun)(float *input0, float *input1, float *tile_input0, float *tile_input1, - float *output, int element_size, ArithmeticParameter *param); public: ArithmeticCPUKernel(OpParameter *parameter, const std::vector &inputs, @@ -109,64 +107,50 @@ class ArithmeticCPUKernel : public LiteKernel { break; case PrimitiveType_LogicalAnd: arithmetic_run_ = ElementLogicalAnd; - arithmetic_broadcast_run_ = BroadcastLogicalAnd; break; case PrimitiveType_LogicalOr: arithmetic_run_ = ElementLogicalOr; - arithmetic_broadcast_run_ = BroadcastLogicalOr; break; case PrimitiveType_Maximum: arithmetic_run_ = ElementMaximum; - arithmetic_broadcast_run_ = BroadcastMaximum; break; case PrimitiveType_Minimum: arithmetic_run_ = ElementMinimum; - arithmetic_broadcast_run_ = BroadcastMinimum; break; case PrimitiveType_FloorDiv: arithmetic_run_ = ElementFloorDiv; - arithmetic_broadcast_run_ = BroadcastFloorDiv; break; case PrimitiveType_FloorMod: arithmetic_run_ = ElementFloorMod; - arithmetic_broadcast_run_ = BroadcastFloorMod; break; case PrimitiveType_Equal: arithmetic_run_ = ElementEqual; - arithmetic_broadcast_run_ = BroadcastEqual; break; case PrimitiveType_NotEqual: arithmetic_run_ = ElementNotEqual; - arithmetic_broadcast_run_ = BroadcastNotEqual; break; case PrimitiveType_Less: arithmetic_run_ = ElementLess; - arithmetic_broadcast_run_ = BroadcastLess; break; case PrimitiveType_LessEqual: arithmetic_run_ = ElementLessEqual; - arithmetic_broadcast_run_ = BroadcastLessEqual; break; case PrimitiveType_Greater: arithmetic_run_ = ElementGreater; - arithmetic_broadcast_run_ = BroadcastGreater; break; case PrimitiveType_GreaterEqual: arithmetic_run_ = ElementGreaterEqual; - arithmetic_broadcast_run_ = BroadcastGreaterEqual; break; case PrimitiveType_SquaredDifference: arithmetic_run_ = ElementSquaredDifference; - arithmetic_broadcast_run_ = BroadcastSquaredDifference; break; default: MS_LOG(ERROR) << "Error Operator type " << parameter->type_; arithmetic_run_ = nullptr; - arithmetic_broadcast_run_ = nullptr; break; } } - ~ArithmeticCPUKernel() = default; + ~ArithmeticCPUKernel() override; int Init() override; int ReSize() override; @@ -174,12 +158,14 @@ class ArithmeticCPUKernel : public LiteKernel { int DoArithmetic(int task_id); private: + int BroadcastRun(float *input0, float *input1, float *output, int dim, int out_count, int out_thread_stride); + int break_pos_; + int outside_; + int out_thread_stride_; + int out_count_; int thread_count_; - float *tile_data0_ = nullptr; - float *tile_data1_ = nullptr; ArithmeticParameter *arithmeticParameter_; ArithmeticRun arithmetic_run_ = nullptr; - ArithmeticBroadcastRun arithmetic_broadcast_run_ = nullptr; ArithmeticOptRun arithmetic_opt_run_ = nullptr; }; } // namespace mindspore::kernel