From 0790e0ca08c0c91fd1f201067d919b5e2b80ac3e Mon Sep 17 00:00:00 2001 From: ling Date: Sat, 28 Nov 2020 18:33:18 +0800 Subject: [PATCH] [MSLITE] add const node bug --- mindspore/lite/nnacl/int8/add_int8.c | 188 +++++++++--------- mindspore/lite/nnacl/int8/add_int8.h | 24 +-- .../src/runtime/kernel/arm/int8/add_int8.cc | 59 ++++-- 3 files changed, 150 insertions(+), 121 deletions(-) diff --git a/mindspore/lite/nnacl/int8/add_int8.c b/mindspore/lite/nnacl/int8/add_int8.c index a4341939e2..e79dfe2449 100644 --- a/mindspore/lite/nnacl/int8/add_int8.c +++ b/mindspore/lite/nnacl/int8/add_int8.c @@ -21,23 +21,23 @@ #include "nnacl/quantization/fixed_point.h" void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, AddQuantParameter *params) { - int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_left_shift_); - int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_left_shift_); + int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_args_.left_shift_); + int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_args_.left_shift_); int index = 0; #ifdef ENABLE_ARM const int8x16_t min_vec = vdupq_n_s8(params->min_); const int8x16_t max_vac = vdupq_n_s8(params->max_); - const int16x8_t in0_zp_vec = vdupq_n_s16(params->in0_zp_); - const int16x8_t in1_zp_vec = vdupq_n_s16(params->in1_zp_); + const int16x8_t in0_zp_vec = vdupq_n_s16(params->in0_args_.zp_); + const int16x8_t in1_zp_vec = vdupq_n_s16(params->in1_args_.zp_); const int16x8_t out_zp_vec = vdupq_n_s16(params->out_zp_); const int32x4_t in0_left_vec = vdupq_n_s32(in0_left_shift); const int32x4_t in1_left_vec = vdupq_n_s32(in1_left_shift); - const int32x4_t in0_right_vec = vdupq_n_s32(-params->in0_right_shift_); - const int32x4_t in1_right_vec = vdupq_n_s32(-params->in1_right_shift_); + const int32x4_t in0_right_vec = vdupq_n_s32(-params->in0_args_.right_shift_); + const int32x4_t in1_right_vec = vdupq_n_s32(-params->in1_args_.right_shift_); const int32x4_t out_left_vec = vdupq_n_s32(params->out_left_shift_); const int32x4_t out_right_vec = vdupq_n_s32(-params->out_right_shift_); @@ -76,14 +76,14 @@ void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int siz in1_4 = vmulq_s32(in1_4, in1_left_vec); // Apply the fixed-point part of the multiplier. - in0_1 = vqrdmulhq_n_s32(in0_1, params->in0_multiplier_); - in0_2 = vqrdmulhq_n_s32(in0_2, params->in0_multiplier_); - in0_3 = vqrdmulhq_n_s32(in0_3, params->in0_multiplier_); - in0_4 = vqrdmulhq_n_s32(in0_4, params->in0_multiplier_); - in1_1 = vqrdmulhq_n_s32(in1_1, params->in1_multiplier_); - in1_2 = vqrdmulhq_n_s32(in1_2, params->in1_multiplier_); - in1_3 = vqrdmulhq_n_s32(in1_3, params->in1_multiplier_); - in1_4 = vqrdmulhq_n_s32(in1_4, params->in1_multiplier_); + in0_1 = vqrdmulhq_n_s32(in0_1, params->in0_args_.multiplier_); + in0_2 = vqrdmulhq_n_s32(in0_2, params->in0_args_.multiplier_); + in0_3 = vqrdmulhq_n_s32(in0_3, params->in0_args_.multiplier_); + in0_4 = vqrdmulhq_n_s32(in0_4, params->in0_args_.multiplier_); + in1_1 = vqrdmulhq_n_s32(in1_1, params->in1_args_.multiplier_); + in1_2 = vqrdmulhq_n_s32(in1_2, params->in1_args_.multiplier_); + in1_3 = vqrdmulhq_n_s32(in1_3, params->in1_args_.multiplier_); + in1_4 = vqrdmulhq_n_s32(in1_4, params->in1_args_.multiplier_); // Apply right shift in0_1 = vqaddq_s32(in0_1, vshrq_n_s32(vandq_s32(in0_1, in0_right_vec), 31)); @@ -149,10 +149,12 @@ void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int siz #endif for (; index < size; index++) { - const int32_t in0_left = (input0[index] + params->in0_zp_) * in0_left_shift; - const int32_t in1_left = (input1[index] + params->in1_zp_) * in1_left_shift; - const int32_t in0 = MultiplyByMultiplierAndRightShift(in0_left, params->in0_multiplier_, params->in0_right_shift_); - const int32_t in1 = MultiplyByMultiplierAndRightShift(in1_left, params->in1_multiplier_, params->in1_right_shift_); + const int32_t in0_left = (input0[index] + params->in0_args_.zp_) * in0_left_shift; + const int32_t in1_left = (input1[index] + params->in1_args_.zp_) * in1_left_shift; + const int32_t in0 = + MultiplyByMultiplierAndRightShift(in0_left, params->in0_args_.multiplier_, params->in0_args_.right_shift_); + const int32_t in1 = + MultiplyByMultiplierAndRightShift(in1_left, params->in1_args_.multiplier_, params->in1_args_.right_shift_); int32_t out = MultiplyByQuantizedMultiplier(in0 + in1, params->out_multiplier_, params->out_left_shift_, -params->out_right_shift_); @@ -162,110 +164,116 @@ void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int siz return; } -void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, AddQuantParameter *params) { - int in0_left_shift = (1 << params->left_shift_) * (1 << params->in0_left_shift_); - int in1_left_shift = (1 << params->left_shift_) * (1 << params->in1_left_shift_); +void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, AddQuantParameter *params, + AddQuantQrgs *ptr_args, AddQuantQrgs *ele_args) { + int ptr_left_shift = (1 << params->left_shift_) * (1 << ptr_args->left_shift_); + int ele_left_shift = (1 << params->left_shift_) * (1 << ele_args->left_shift_); int index = 0; #ifdef ENABLE_ARM - const int8x16_t in1_src = vdupq_n_s8(element_in); - + /* const value init */ const int8x16_t min_vec = vdupq_n_s8(params->min_); const int8x16_t max_vac = vdupq_n_s8(params->max_); - const int16x8_t in0_zp_vec = vdupq_n_s16(params->in0_zp_); - const int16x8_t in1_zp_vec = vdupq_n_s16(params->in1_zp_); + const int16x8_t ptr_zp_vec = vdupq_n_s16(ptr_args->zp_); + const int16x8_t ele_zp_vec = vdupq_n_s16(ele_args->zp_); const int16x8_t out_zp_vec = vdupq_n_s16(params->out_zp_); - const int32x4_t in0_left_vec = vdupq_n_s32(in0_left_shift); - const int32x4_t in1_left_vec = vdupq_n_s32(in1_left_shift); + const int32x4_t ptr_left_vec = vdupq_n_s32(ptr_left_shift); + const int32x4_t ele_left_vec = vdupq_n_s32(ele_left_shift); - const int32x4_t in0_right_vec = vdupq_n_s32(-params->in0_right_shift_); - const int32x4_t in1_right_vec = vdupq_n_s32(-params->in1_right_shift_); + const int32x4_t ptr_right_vec = vdupq_n_s32(-ptr_args->right_shift_); + const int32x4_t ele_right_vec = vdupq_n_s32(-ptr_args->right_shift_); const int32x4_t out_left_vec = vdupq_n_s32(params->out_left_shift_); const int32x4_t out_right_vec = vdupq_n_s32(-params->out_right_shift_); + /* deal with const node */ + const int8x16_t ele_src = vdupq_n_s8(element_in); + const int16x8_t ele_s16_low = vmovl_s8(vget_low_s8(ele_src)); + const int16x8_t ele_s16_high = vmovl_s8(vget_high_s8(ele_src)); + const int16x8_t ele_zp_low = vaddq_s16(ele_s16_low, ele_zp_vec); + const int16x8_t ele_zp_high = vaddq_s16(ele_s16_high, ele_zp_vec); + int32x4_t ele1 = vmovl_s16(vget_low_s16(ele_zp_low)); + int32x4_t ele2 = vmovl_s16(vget_high_s16(ele_zp_low)); + int32x4_t ele3 = vmovl_s16(vget_low_s16(ele_zp_high)); + int32x4_t ele4 = vmovl_s16(vget_high_s16(ele_zp_high)); + // Apply left shift + ele1 = vmulq_s32(ele1, ele_left_vec); + ele2 = vmulq_s32(ele2, ele_left_vec); + ele3 = vmulq_s32(ele3, ele_left_vec); + ele4 = vmulq_s32(ele4, ele_left_vec); + // Apply the fixed-point part of the multiplier. + ele1 = vqrdmulhq_n_s32(ele1, ele_args->multiplier_); + ele2 = vqrdmulhq_n_s32(ele2, ele_args->multiplier_); + ele3 = vqrdmulhq_n_s32(ele3, ele_args->multiplier_); + ele4 = vqrdmulhq_n_s32(ele4, ele_args->multiplier_); + // Apply right shift + ele1 = vqaddq_s32(ele1, vshrq_n_s32(vandq_s32(ele1, ele_right_vec), 31)); + ele2 = vqaddq_s32(ele2, vshrq_n_s32(vandq_s32(ele2, ele_right_vec), 31)); + ele3 = vqaddq_s32(ele3, vshrq_n_s32(vandq_s32(ele3, ele_right_vec), 31)); + ele4 = vqaddq_s32(ele4, vshrq_n_s32(vandq_s32(ele4, ele_right_vec), 31)); + ele1 = vrshlq_s32(ele1, ele_right_vec); + ele2 = vrshlq_s32(ele2, ele_right_vec); + ele3 = vrshlq_s32(ele3, ele_right_vec); + ele4 = vrshlq_s32(ele4, ele_right_vec); + for (; index <= size - 16; index += 16) { - const int8x16_t in0_src = vld1q_s8(ptr_in + index); + const int8x16_t ptr_src = vld1q_s8(ptr_in + index); - const int16x8_t in0_s16_low = vmovl_s8(vget_low_s8(in0_src)); - const int16x8_t in0_s16_high = vmovl_s8(vget_high_s8(in0_src)); - const int16x8_t in1_s16_low = vmovl_s8(vget_low_s8(in1_src)); - const int16x8_t in1_s16_high = vmovl_s8(vget_high_s8(in1_src)); + const int16x8_t ptr_s16_low = vmovl_s8(vget_low_s8(ptr_src)); + const int16x8_t ptr_s16_high = vmovl_s8(vget_high_s8(ptr_src)); - const int16x8_t in0_zp_low = vaddq_s16(in0_s16_low, in0_zp_vec); - const int16x8_t in0_zp_high = vaddq_s16(in0_s16_high, in0_zp_vec); - const int16x8_t in1_zp_low = vaddq_s16(in1_s16_low, in1_zp_vec); - const int16x8_t in1_zp_high = vaddq_s16(in1_s16_high, in1_zp_vec); + const int16x8_t ptr_zp_low = vaddq_s16(ptr_s16_low, ptr_zp_vec); + const int16x8_t ptr_zp_high = vaddq_s16(ptr_s16_high, ptr_zp_vec); - int32x4_t in0_1 = vmovl_s16(vget_low_s16(in0_zp_low)); - int32x4_t in0_2 = vmovl_s16(vget_high_s16(in0_zp_low)); - int32x4_t in0_3 = vmovl_s16(vget_low_s16(in0_zp_high)); - int32x4_t in0_4 = vmovl_s16(vget_high_s16(in0_zp_high)); - int32x4_t in1_1 = vmovl_s16(vget_low_s16(in1_zp_low)); - int32x4_t in1_2 = vmovl_s16(vget_high_s16(in1_zp_low)); - int32x4_t in1_3 = vmovl_s16(vget_low_s16(in1_zp_high)); - int32x4_t in1_4 = vmovl_s16(vget_high_s16(in1_zp_high)); + int32x4_t ptr1 = vmovl_s16(vget_low_s16(ptr_zp_low)); + int32x4_t ptr2 = vmovl_s16(vget_high_s16(ptr_zp_low)); + int32x4_t ptr3 = vmovl_s16(vget_low_s16(ptr_zp_high)); + int32x4_t ptr4 = vmovl_s16(vget_high_s16(ptr_zp_high)); // Apply left shift - in0_1 = vmulq_s32(in0_1, in0_left_vec); - in0_2 = vmulq_s32(in0_2, in0_left_vec); - in0_3 = vmulq_s32(in0_3, in0_left_vec); - in0_4 = vmulq_s32(in0_4, in0_left_vec); - in1_1 = vmulq_s32(in1_1, in1_left_vec); - in1_2 = vmulq_s32(in1_2, in1_left_vec); - in1_3 = vmulq_s32(in1_3, in1_left_vec); - in1_4 = vmulq_s32(in1_4, in1_left_vec); + ptr1 = vmulq_s32(ptr1, ptr_left_vec); + ptr2 = vmulq_s32(ptr2, ptr_left_vec); + ptr3 = vmulq_s32(ptr3, ptr_left_vec); + ptr4 = vmulq_s32(ptr4, ptr_left_vec); // Apply the fixed-point part of the multiplier. - in0_1 = vqrdmulhq_n_s32(in0_1, params->in0_multiplier_); - in0_2 = vqrdmulhq_n_s32(in0_2, params->in0_multiplier_); - in0_3 = vqrdmulhq_n_s32(in0_3, params->in0_multiplier_); - in0_4 = vqrdmulhq_n_s32(in0_4, params->in0_multiplier_); - in1_1 = vqrdmulhq_n_s32(in1_1, params->in1_multiplier_); - in1_2 = vqrdmulhq_n_s32(in1_2, params->in1_multiplier_); - in1_3 = vqrdmulhq_n_s32(in1_3, params->in1_multiplier_); - in1_4 = vqrdmulhq_n_s32(in1_4, params->in1_multiplier_); + ptr1 = vqrdmulhq_n_s32(ptr1, ptr_args->multiplier_); + ptr2 = vqrdmulhq_n_s32(ptr2, ptr_args->multiplier_); + ptr3 = vqrdmulhq_n_s32(ptr3, ptr_args->multiplier_); + ptr4 = vqrdmulhq_n_s32(ptr4, ptr_args->multiplier_); // Apply right shift - in0_1 = vqaddq_s32(in0_1, vshrq_n_s32(vandq_s32(in0_1, in0_right_vec), 31)); - in0_2 = vqaddq_s32(in0_2, vshrq_n_s32(vandq_s32(in0_2, in0_right_vec), 31)); - in0_3 = vqaddq_s32(in0_3, vshrq_n_s32(vandq_s32(in0_3, in0_right_vec), 31)); - in0_4 = vqaddq_s32(in0_4, vshrq_n_s32(vandq_s32(in0_4, in0_right_vec), 31)); - in1_1 = vqaddq_s32(in1_1, vshrq_n_s32(vandq_s32(in1_1, in1_right_vec), 31)); - in1_2 = vqaddq_s32(in1_2, vshrq_n_s32(vandq_s32(in1_2, in1_right_vec), 31)); - in1_3 = vqaddq_s32(in1_3, vshrq_n_s32(vandq_s32(in1_3, in1_right_vec), 31)); - in1_4 = vqaddq_s32(in1_4, vshrq_n_s32(vandq_s32(in1_4, in1_right_vec), 31)); + ptr1 = vqaddq_s32(ptr1, vshrq_n_s32(vandq_s32(ptr1, ptr_right_vec), 31)); + ptr2 = vqaddq_s32(ptr2, vshrq_n_s32(vandq_s32(ptr2, ptr_right_vec), 31)); + ptr3 = vqaddq_s32(ptr3, vshrq_n_s32(vandq_s32(ptr3, ptr_right_vec), 31)); + ptr4 = vqaddq_s32(ptr4, vshrq_n_s32(vandq_s32(ptr4, ptr_right_vec), 31)); - in0_1 = vrshlq_s32(in0_1, in0_right_vec); - in0_2 = vrshlq_s32(in0_2, in0_right_vec); - in0_3 = vrshlq_s32(in0_3, in0_right_vec); - in0_4 = vrshlq_s32(in0_4, in0_right_vec); - in1_1 = vrshlq_s32(in1_1, in1_right_vec); - in1_2 = vrshlq_s32(in1_2, in1_right_vec); - in1_3 = vrshlq_s32(in1_3, in1_right_vec); - in1_4 = vrshlq_s32(in1_4, in1_right_vec); + ptr1 = vrshlq_s32(ptr1, ptr_right_vec); + ptr2 = vrshlq_s32(ptr2, ptr_right_vec); + ptr3 = vrshlq_s32(ptr3, ptr_right_vec); + ptr4 = vrshlq_s32(ptr4, ptr_right_vec); /* calculate output */ - int32x4_t out1 = vaddq_s32(in0_1, in1_1); - int32x4_t out2 = vaddq_s32(in0_2, in1_2); - int32x4_t out3 = vaddq_s32(in0_3, in1_3); - int32x4_t out4 = vaddq_s32(in0_4, in1_4); + int32x4_t out1 = vaddq_s32(ptr1, ele1); + int32x4_t out2 = vaddq_s32(ptr2, ele2); + int32x4_t out3 = vaddq_s32(ptr3, ele3); + int32x4_t out4 = vaddq_s32(ptr4, ele4); - // Apply left shift + // Apply output left shift out1 = vshlq_s32(out1, out_left_vec); out2 = vshlq_s32(out2, out_left_vec); out3 = vshlq_s32(out3, out_left_vec); out4 = vshlq_s32(out4, out_left_vec); - // Apply the fixed-point part of the multiplier. + // Apply output fixed-point part of the multiplier. out1 = vqrdmulhq_n_s32(out1, params->out_multiplier_); out2 = vqrdmulhq_n_s32(out2, params->out_multiplier_); out3 = vqrdmulhq_n_s32(out3, params->out_multiplier_); out4 = vqrdmulhq_n_s32(out4, params->out_multiplier_); - // Apply right shift + // Apply output right shift out1 = vqaddq_s32(out1, vshrq_n_s32(vandq_s32(out1, out_right_vec), 31)); out2 = vqaddq_s32(out2, vshrq_n_s32(vandq_s32(out2, out_right_vec), 31)); out3 = vqaddq_s32(out3, vshrq_n_s32(vandq_s32(out3, out_right_vec), 31)); @@ -292,12 +300,12 @@ void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, i #endif for (; index < size; index++) { - const int32_t in0_left = (ptr_in[index] + params->in0_zp_) * in0_left_shift; - const int32_t in1_left = (element_in + params->in1_zp_) * in1_left_shift; - const int32_t in0 = MultiplyByMultiplierAndRightShift(in0_left, params->in0_multiplier_, params->in0_right_shift_); - const int32_t in1 = MultiplyByMultiplierAndRightShift(in1_left, params->in1_multiplier_, params->in1_right_shift_); + const int32_t ptr_left = (ptr_in[index] + ptr_args->zp_) * ptr_left_shift; + const int32_t ele_left = (element_in + ele_args->zp_) * ele_left_shift; + const int32_t ptr = MultiplyByMultiplierAndRightShift(ptr_left, ptr_args->multiplier_, ptr_args->right_shift_); + const int32_t ele = MultiplyByMultiplierAndRightShift(ele_left, ele_args->multiplier_, ele_args->right_shift_); - int32_t out = MultiplyByQuantizedMultiplier(in0 + in1, params->out_multiplier_, params->out_left_shift_, + int32_t out = MultiplyByQuantizedMultiplier(ptr + ele, params->out_multiplier_, params->out_left_shift_, -params->out_right_shift_); out += params->out_zp_; output[index] = (int8_t)MSMAX(params->min_, MSMIN(out, params->max_)); diff --git a/mindspore/lite/nnacl/int8/add_int8.h b/mindspore/lite/nnacl/int8/add_int8.h index 44e0d066d0..15fdc03d7a 100644 --- a/mindspore/lite/nnacl/int8/add_int8.h +++ b/mindspore/lite/nnacl/int8/add_int8.h @@ -19,23 +19,22 @@ #include "nnacl/op_base.h" +typedef struct AddQuantQrgs { + int32_t zp_; + int32_t left_shift_; + int32_t right_shift_; + int32_t multiplier_; +} AddQuantQrgs; + typedef struct AddQuantParameter { int left_shift_; int32_t min_; int32_t max_; - int32_t in0_zp_; - int32_t in1_zp_; - int32_t out_zp_; - - int32_t in0_left_shift_; - int32_t in0_right_shift_; - int32_t in0_multiplier_; - - int32_t in1_left_shift_; - int32_t in1_right_shift_; - int32_t in1_multiplier_; + AddQuantQrgs in0_args_; + AddQuantQrgs in1_args_; + int32_t out_zp_; int32_t out_left_shift_; int32_t out_right_shift_; int32_t out_multiplier_; @@ -46,7 +45,8 @@ extern "C" { #endif void AddInt8(const int8_t *input0, const int8_t *input1, int8_t *output, int size, AddQuantParameter *params); -void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, AddQuantParameter *params); +void AddOptInt8(const int8_t *ptr_in, const int8_t element_in, int8_t *output, int size, AddQuantParameter *params, + AddQuantQrgs *ptr_args, AddQuantQrgs *ele_args); #ifdef __cplusplus } diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc index 8842abd401..cb40d35035 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/add_int8.cc @@ -35,8 +35,8 @@ int QuantizedAddCPUKernel::Init() { auto *input1 = in_tensors_.at(1); auto *output = out_tensors_.at(0); - para_.in0_zp_ = input0->quant_params().front().zeroPoint * -1; - para_.in1_zp_ = input1->quant_params().front().zeroPoint * -1; + para_.in0_args_.zp_ = input0->quant_params().front().zeroPoint * -1; + para_.in1_args_.zp_ = input1->quant_params().front().zeroPoint * -1; para_.out_zp_ = output->quant_params().front().zeroPoint; const double in0_scale = input0->quant_params().front().scale; @@ -49,16 +49,16 @@ int QuantizedAddCPUKernel::Init() { const double in1_multiplier = in1_scale / twice_max_input_scale; const double out_multiplier = twice_max_input_scale / ((1 << para_.left_shift_) * out_scale); - QuantizeMultiplierSmallerThanOne(in0_multiplier, ¶_.in0_multiplier_, ¶_.in0_left_shift_); - QuantizeMultiplierSmallerThanOne(in1_multiplier, ¶_.in1_multiplier_, ¶_.in1_left_shift_); + QuantizeMultiplierSmallerThanOne(in0_multiplier, ¶_.in0_args_.multiplier_, ¶_.in0_args_.left_shift_); + QuantizeMultiplierSmallerThanOne(in1_multiplier, ¶_.in1_args_.multiplier_, ¶_.in1_args_.left_shift_); QuantizeMultiplierSmallerThanOne(out_multiplier, ¶_.out_multiplier_, ¶_.out_left_shift_); - para_.in0_right_shift_ = -para_.in0_left_shift_ > 0 ? 0 : para_.in0_left_shift_; - para_.in1_right_shift_ = -para_.in1_left_shift_ > 0 ? 0 : para_.in1_left_shift_; + para_.in0_args_.right_shift_ = -para_.in0_args_.left_shift_ > 0 ? 0 : para_.in0_args_.left_shift_; + para_.in1_args_.right_shift_ = -para_.in1_args_.left_shift_ > 0 ? 0 : para_.in1_args_.left_shift_; para_.out_right_shift_ = -para_.out_left_shift_ > 0 ? 0 : para_.out_left_shift_; - para_.in0_left_shift_ = -para_.in0_left_shift_ > 0 ? -para_.in0_left_shift_ : 0; - para_.in1_left_shift_ = -para_.in1_left_shift_ > 0 ? -para_.in1_left_shift_ : 0; + para_.in0_args_.left_shift_ = -para_.in0_args_.left_shift_ > 0 ? -para_.in0_args_.left_shift_ : 0; + para_.in1_args_.left_shift_ = -para_.in1_args_.left_shift_ > 0 ? -para_.in1_args_.left_shift_ : 0; para_.out_left_shift_ = -para_.out_left_shift_ > 0 ? -para_.out_left_shift_ : 0; auto act = arith_para_->activation_type_; @@ -87,9 +87,24 @@ int QuantizedAddCPUKernel::ReSize() { arith_para_->in_elements_num1_ = in_tensors_[1]->ElementsNum(); arith_para_->out_elements_num_ = out_tensors_[0]->ElementsNum(); - memcpy(arith_para_->in_shape0_, input0->shape().data(), input0->shape().size() * sizeof(int)); - memcpy(arith_para_->in_shape1_, input1->shape().data(), input1->shape().size() * sizeof(int)); - memcpy(arith_para_->out_shape_, output->shape().data(), output->shape().size() * sizeof(int)); + for (size_t i = 0; i < in_tensors_[0]->shape().size(); i++) { + if (arith_para_->in_shape0_[i] == -1) { + memcpy(arith_para_->in_shape0_, input0->shape().data(), input0->shape().size() * sizeof(int)); + break; + } + } + for (size_t i = 0; i < in_tensors_[1]->shape().size(); i++) { + if (arith_para_->in_shape1_[i] == -1) { + memcpy(arith_para_->in_shape1_, input1->shape().data(), input1->shape().size() * sizeof(int)); + break; + } + } + for (size_t i = 0; i < out_tensors_[0]->shape().size(); i++) { + if (arith_para_->out_shape_[i] == -1) { + memcpy(arith_para_->out_shape_, output->shape().data(), output->shape().size() * sizeof(int)); + break; + } + } if (arith_para_->broadcasting_) { size_t break_pos_ = 0; @@ -128,14 +143,18 @@ void QuantizedAddCPUKernel::BroadcastRun(int task_id) { if (real_out_count <= 0) { return; } - - int8_t *const_in = arith_para_->in_elements_num0_ == arith_para_->out_elements_num_ ? input1_data_ : input0_data_; - int8_t *offset_in = arith_para_->in_elements_num0_ == arith_para_->out_elements_num_ ? input0_data_ : input1_data_; - offset_in += task_id * stride * in_size_; - int8_t *cur_out = output_data_ + task_id * stride * in_size_; - + int8_t *cur_in0, *cur_in1, *cur_out; for (int i = 0; i < real_out_count; i++) { - AddInt8(offset_in + i * in_size_, const_in, cur_out + i * in_size_, in_size_, ¶_); + if (arith_para_->in_elements_num0_ == arith_para_->out_elements_num_) { + cur_in0 = input0_data_ + task_id * stride * in_size_ + i * in_size_; + cur_in1 = input1_data_; + cur_out = output_data_ + task_id * stride * in_size_ + i * in_size_; + } else { + cur_in0 = input0_data_; + cur_in1 = input1_data_ + task_id * stride * in_size_ + i * in_size_; + cur_out = output_data_ + task_id * stride * in_size_ + i * in_size_; + } + AddInt8(cur_in0, cur_in1, cur_out, in_size_, ¶_); } return; } @@ -160,7 +179,9 @@ int QuantizedAddCPUKernel::DoExecute(int task_id) { if (support_opt_add_) { int8_t *ptr_in = arith_para_->in_elements_num0_ == 1 ? cur_in1 : cur_in0; int8_t element_in = arith_para_->in_elements_num0_ == 1 ? input0_data_[0] : input1_data_[0]; - AddOptInt8(ptr_in, element_in, cur_out, rest_count, ¶_); + AddQuantQrgs *ptr_args = arith_para_->in_elements_num0_ == 1 ? ¶_.in1_args_ : ¶_.in0_args_; + AddQuantQrgs *ele_args = arith_para_->in_elements_num0_ == 1 ? ¶_.in0_args_ : ¶_.in1_args_; + AddOptInt8(ptr_in, element_in, cur_out, rest_count, ¶_, ptr_args, ele_args); } else { AddInt8(cur_in0, cur_in1, cur_out, rest_count, ¶_); }