diff --git a/mindspore/lite/nnacl/arithmetic_common.h b/mindspore/lite/nnacl/arithmetic_common.h
index 744a6a797e..b30c762153 100644
--- a/mindspore/lite/nnacl/arithmetic_common.h
+++ b/mindspore/lite/nnacl/arithmetic_common.h
@@ -53,6 +53,8 @@ void ComputeStrides(const int *shape, int *strides, const int ndim);
 void CalcMultiplesAndStrides(ArithmeticParameter *param);
+void TileOneDimensionUint8(uint8_t *inData, uint8_t *outData, int dim, size_t ndim, int *inShape, int *inStrides,
+                           int *outStrides, int *multiple);
 void TileDimensions(float *data0, float *data1, float *tile_data0, float *tile_data1, ArithmeticParameter *param);
 void TileDimensionsUint8(uint8_t *data0, uint8_t *data1, uint8_t *tile_data0, uint8_t *tile_data1,
                          ArithmeticParameter *param);
diff --git a/mindspore/lite/nnacl/int8/scale_int8.c b/mindspore/lite/nnacl/int8/scale_int8.c
index bfe1dbc3f8..dd9374c572 100644
--- a/mindspore/lite/nnacl/int8/scale_int8.c
+++ b/mindspore/lite/nnacl/int8/scale_int8.c
@@ -17,78 +17,148 @@
 #include "nnacl/int8/scale_int8.h"
 #include "nnacl/quantization/fixed_point.h"
 
-void ScaleInnerInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, int outer_start, int outer_end,
-                    int axis_size, int inner_size, const ScaleParameter *scale_param, int max, int min) {
-  for (int out = outer_start; out < outer_end; out++) {
-    int out_offset = out * axis_size * inner_size;
-    for (int i = 0; i < axis_size; i++) {
-      int axis_offset = out_offset + i * inner_size;
-      int in_index = 0;
-
-      for (; in_index < inner_size; in_index++) {
-        int in_offset = axis_offset + in_index;
-        int tmp_input_scale = (in_data[in_offset] - scale_param->input_zp_) * (scale[i] - scale_param->scale_zp_);
-        int input_mul_scale =
-          RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
-                                tmp_input_scale * (1 << (unsigned int)scale_param->scale_mul_arg_.left_shift_),
-                                scale_param->scale_mul_arg_.multiplier_),
-                              scale_param->scale_mul_arg_.right_shift_);
-        int tmp = input_mul_scale + scale_param->output_zp_;
-        tmp = tmp > max ? max : tmp;
-        tmp = tmp < min ? min : tmp;
-        out_data[in_offset] = tmp;
-      }
-    }
-  }
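Throughout this patch, requantization relies on two gemmlowp-style fixed-point primitives from nnacl/quantization/fixed_point.h. Their definitions are not part of this diff; the simplified sketch below (hypothetical `*_Sketch` names) shows what they compute, which makes both the removed scalar path above and the new NEON helpers easier to audit.

```c
#include <stdint.h>

/* Sketch of SaturatingRoundingDoublingHighMul: the rounded, saturating
 * high half of a widening multiply, i.e. (a * b * 2) >> 32, treating
 * b as a Q31 fixed-point multiplier in [-1, 1). */
static int32_t SaturatingRoundingDoublingHighMul_Sketch(int32_t a, int32_t b) {
  if (a == INT32_MIN && b == INT32_MIN) {
    return INT32_MAX; /* the single overflowing case saturates */
  }
  int64_t ab = (int64_t)a * (int64_t)b;
  int64_t nudge = ab >= 0 ? (1 << 30) : (1 - (1 << 30));
  return (int32_t)((ab + nudge) >> 31);
}

/* Sketch of RoundingDivideByPOT: arithmetic right shift by `exponent`
 * with round-to-nearest, ties rounded away from zero. */
static int32_t RoundingDivideByPOT_Sketch(int32_t x, int exponent) {
  const int32_t mask = ((int32_t)1 << exponent) - 1;
  const int32_t remainder = x & mask;
  const int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}
```

Chained as in the code above, they apply the real-valued multiplier `multiplier_ * 2^(left_shift_ - 31 - right_shift_)` to an int32 accumulator without any floating-point arithmetic.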
+#ifdef ENABLE_NEON
+int16x4_t CalcSumHalfWordMul2(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t left_shift_out_vec,
+                              int32x4_t output_multiplier_vec, const ScaleParameter *scale_param) {
+  int32x4_t input_scale = vmulq_s32(scaled_input0, scaled_input1);
+  int32x4_t raw_sum = RoundingDivideByPOTInt32x4(
+    SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec),
+    scale_param->scale_mul_arg_.right_shift_);
+  raw_sum = vaddq_s32(raw_sum, vdupq_n_s32(scale_param->output_zp_));
+  raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_min_));
+  raw_sum = vminq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_max_));
+  return vqmovn_s32(raw_sum);
 }
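Per 32-bit lane, the helper above performs exactly the following scalar pipeline (a sketch for illustration; it reuses the scalar primitives from fixed_point.h that this file already includes). The inputs are already zero-point adjusted, hence the `scaled_input` naming.

```c
#include "nnacl/scale.h"
#include "nnacl/quantization/fixed_point.h"

/* What CalcSumHalfWordMul2 computes for one lane, written as scalar code. */
static int16_t CalcSumHalfWordMul2OneLane_Sketch(int32_t in0, int32_t in1, const ScaleParameter *p) {
  int32_t v = RoundingDivideByPOT(
    SaturatingRoundingDoublingHighMul(in0 * in1 * (1 << p->scale_mul_arg_.left_shift_),
                                      p->scale_mul_arg_.multiplier_),
    p->scale_mul_arg_.right_shift_);                                /* requantize */
  v += p->output_zp_;                                               /* vaddq_s32 */
  v = v < p->output_activation_min_ ? p->output_activation_min_ : v; /* vmaxq_s32 */
  v = v > p->output_activation_max_ ? p->output_activation_max_ : v; /* vminq_s32 */
  return (int16_t)v; /* vqmovn_s32; exact here, since v is already clamped */
}
```

The three-input variant defined next does the same, but additionally requantizes a third lane set with `offset_mul_arg_` and adds it in before the clamp.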
-void ScaleInnerWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset,
-                            int outer_start, int outer_end, int axis_size, int inner_size,
-                            const ScaleParameter *scale_param, int max, int min) {
-  for (int out = outer_start; out < outer_end; out++) {
-    int out_offset = out * axis_size * inner_size;
-    for (int i = 0; i < axis_size; i++) {
-      int axis_offset = out_offset + i * inner_size;
-      int in_index = 0;
-
-      for (; in_index < inner_size; in_index++) {
-        int in_offset = axis_offset + in_index;
-        int tmp_input_scale = (in_data[in_offset] - scale_param->input_zp_) * (scale[i] - scale_param->scale_zp_);
-        int input_mul_scale =
-          RoundingDivideByPOT(SaturatingRoundingDoublingHighMul(
-                                tmp_input_scale * (1 << (unsigned int)scale_param->scale_mul_arg_.left_shift_),
-                                scale_param->scale_mul_arg_.multiplier_),
-                              scale_param->scale_mul_arg_.right_shift_);
-        int tmp_bias = offset[i] - scale_param->offset_zp_;
-        int bias = RoundingDivideByPOT(
-          SaturatingRoundingDoublingHighMul(tmp_bias * (1 << (unsigned int)scale_param->offset_mul_arg_.left_shift_),
-                                            scale_param->offset_mul_arg_.multiplier_),
-          scale_param->offset_mul_arg_.right_shift_);
-        int tmp = input_mul_scale + bias + scale_param->output_zp_;
-        tmp = tmp > max ? max : tmp;
-        tmp = tmp < min ? min : tmp;
-        out_data[in_offset] = tmp;
-      }
-    }
-  }
+int16x4_t CalcSumHalfWordMul3(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t scaled_input2,
+                              const ScaleParameter *scale_param) {
+  int32x4_t output_multiplier_vec = vdupq_n_s32(scale_param->scale_mul_arg_.multiplier_);
+  int32x4_t output_multiplier_vec2 = vdupq_n_s32(scale_param->offset_mul_arg_.multiplier_);
+  int32x4_t left_shift_out_vec = vdupq_n_s32(1 << scale_param->scale_mul_arg_.left_shift_);
+  int32x4_t left_shift_out_vec2 = vdupq_n_s32(1 << scale_param->offset_mul_arg_.left_shift_);
+  int32x4_t input_scale = vmulq_s32(scaled_input0, scaled_input1);
+  int32x4_t raw_sum = RoundingDivideByPOTInt32x4(
+    SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec),
+    scale_param->scale_mul_arg_.right_shift_);
+  int32x4_t raw_sum2 = RoundingDivideByPOTInt32x4(
+    SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(scaled_input2, left_shift_out_vec2), output_multiplier_vec2),
+    scale_param->offset_mul_arg_.right_shift_);
+  raw_sum = vaddq_s32(raw_sum, vdupq_n_s32(scale_param->output_zp_));
+  raw_sum = vaddq_s32(raw_sum, raw_sum2);
+  raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_min_));
+  raw_sum = vminq_s32(raw_sum, vdupq_n_s32(scale_param->output_activation_max_));
+  return vqmovn_s32(raw_sum);
 }
+#endif
+
+void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const ScaleParameter *scale_param,
+                 int real_dst_count) {
+  int index = 0;
+#ifdef ENABLE_NEON
+  int32x4_t output_multiplier_vec = vdupq_n_s32(scale_param->scale_mul_arg_.multiplier_);
+  int32x4_t left_shift_out_vec = vdupq_n_s32(1 << scale_param->scale_mul_arg_.left_shift_);
+
+  for (; index <= real_dst_count - 8; index += 8) {
+    int8x8_t input_s8 = vld1_s8(in_data + index);
+    int16x8_t input_s16 = vmovl_s8(input_s8);
+    int16x8_t input0_val = vsubq_s16(input_s16, vdupq_n_s16(scale_param->input_zp_));
+
+    int8x8_t input1_s8 = vld1_s8(scale + index);
+    int16x8_t input1_s16 = vmovl_s8(input1_s8);
+    int16x8_t input1_val = vsubq_s16(input1_s16, vdupq_n_s16(scale_param->scale_zp_));
+
+    int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val));
+    int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val));
+    int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val));
+    int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val));
+
+    int16x4_t sum_low =
+      CalcSumHalfWordMul2(input0_low, input1_low, left_shift_out_vec, output_multiplier_vec, scale_param);
+    int16x4_t sum_high =
+      CalcSumHalfWordMul2(input0_high, input1_high, left_shift_out_vec, output_multiplier_vec, scale_param);
+
+    int16x8_t res_s16 = vcombine_s16(sum_low, sum_high);
+    int8x8_t res_u8_n0 = vqmovn_s16(res_s16);
+    vst1_s8(out_data + index, res_u8_n0);
+  }
+#endif
+  for (; index < real_dst_count; ++index) {
+    const int32_t input0_val = in_data[index] - scale_param->input_zp_;
+    const int32_t input1_val = scale[index] - scale_param->scale_zp_;
+    int32_t mul_result = RoundingDivideByPOT(
+      SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << scale_param->scale_mul_arg_.left_shift_),
+                                        scale_param->scale_mul_arg_.multiplier_),
+      scale_param->scale_mul_arg_.right_shift_);
-void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, int task_id,
-                 const ScaleParameter *scale_param, int max, int min) {
-  int outer_step = UP_DIV(scale_param->outer_size_, scale_param->op_parameter_.thread_num_);
-  int outer_start = task_id * outer_step;
-  int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_);
+    mul_result += scale_param->output_zp_;
-  ScaleInnerInt8(in_data, out_data, scale, outer_start, outer_end, scale_param->axis_size_, scale_param->inner_size_,
-                 scale_param, max, min);
+    if (mul_result > scale_param->output_activation_max_) {
+      out_data[index] = scale_param->output_activation_max_;
+    } else if (mul_result < scale_param->output_activation_min_) {
+      out_data[index] = scale_param->output_activation_min_;
+    } else {
+      out_data[index] = (int8_t)mul_result;
+    }
+  }
+  return;
 }
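The vector loop above follows a fixed width choreography; the following isolated sketch (hypothetical helper, guarded like the file itself) shows the pattern on its own. `Process` stands in for the requantization step that the real code performs with the helpers above.

```c
#ifdef ENABLE_NEON
#include <arm_neon.h>

/* 8 x int8 -> 8 x int16 -> two 4 x int32 halves -> 32-bit math ->
 * saturating narrow back to 8 x int8. */
static inline int8x8_t WidenProcessNarrow_Sketch(int8x8_t v, int32x4_t (*Process)(int32x4_t)) {
  int16x8_t w = vmovl_s8(v);                    /* widen 8 -> 16 bit */
  int32x4_t lo = vmovl_s16(vget_low_s16(w));    /* lanes 0..3 as 32 bit */
  int32x4_t hi = vmovl_s16(vget_high_s16(w));   /* lanes 4..7 as 32 bit */
  int16x4_t lo16 = vqmovn_s32(Process(lo));     /* saturating 32 -> 16 */
  int16x4_t hi16 = vqmovn_s32(Process(hi));
  return vqmovn_s16(vcombine_s16(lo16, hi16));  /* saturating 16 -> 8 */
}
#endif
```

Eight elements per iteration is exactly one `int8x8_t` load, which is why the main loop runs while `index <= real_dst_count - 8` and leaves the remainder to the scalar tail.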
 
 void DoScaleWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset,
-                         int task_id, const ScaleParameter *scale_param, int max, int min) {
-  int outer_step = UP_DIV(scale_param->outer_size_, scale_param->op_parameter_.thread_num_);
-  int outer_start = task_id * outer_step;
-  int outer_end = MSMIN(outer_start + outer_step, scale_param->outer_size_);
+                         const ScaleParameter *scale_param, int real_dst_count) {
+  int index = 0;
+#ifdef ENABLE_NEON
+  for (; index <= real_dst_count - 8; index += 8) {
+    int8x8_t input_s8 = vld1_s8(in_data + index);
+    int16x8_t input_s16 = vmovl_s8(input_s8);
+    int16x8_t input0_val = vsubq_s16(input_s16, vdupq_n_s16(scale_param->input_zp_));
+
+    int8x8_t input1_s8 = vld1_s8(scale + index);
+    int16x8_t input1_s16 = vmovl_s8(input1_s8);
+    int16x8_t input1_val = vsubq_s16(input1_s16, vdupq_n_s16(scale_param->scale_zp_));
+
+    int8x8_t input2_s8 = vld1_s8(offset + index);
+    int16x8_t input2_s16 = vmovl_s8(input2_s8);
+    int16x8_t input2_val = vsubq_s16(input2_s16, vdupq_n_s16(scale_param->offset_zp_));
+
+    int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val));
+    int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val));
+    int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val));
+    int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val));
+    int32x4_t input2_low = vmovl_s16(vget_low_s16(input2_val));
+    int32x4_t input2_high = vmovl_s16(vget_high_s16(input2_val));
+
+    int16x4_t sum_low = CalcSumHalfWordMul3(input0_low, input1_low, input2_low, scale_param);
+    int16x4_t sum_high = CalcSumHalfWordMul3(input0_high, input1_high, input2_high, scale_param);
-  ScaleInnerWithBiasInt8(in_data, out_data, scale, offset, outer_start, outer_end, scale_param->axis_size_,
-                         scale_param->inner_size_, scale_param, max, min);
+    int16x8_t res_s16 = vcombine_s16(sum_low, sum_high);
+    int8x8_t res_u8_n0 = vqmovn_s16(res_s16);
+    vst1_s8(out_data + index, res_u8_n0);
+  }
+#endif
+  for (; index < real_dst_count; ++index) {
+    const int32_t input0_val = in_data[index] - scale_param->input_zp_;
+    const int32_t input1_val = scale[index] - scale_param->scale_zp_;
+    int32_t mul_result = RoundingDivideByPOT(
+      SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << scale_param->scale_mul_arg_.left_shift_),
+                                        scale_param->scale_mul_arg_.multiplier_),
+      scale_param->scale_mul_arg_.right_shift_);
+    int tmp_bias = offset[index] - scale_param->offset_zp_;
+    int bias = RoundingDivideByPOT(
+      SaturatingRoundingDoublingHighMul(tmp_bias * (1 << (unsigned int)scale_param->offset_mul_arg_.left_shift_),
+                                        scale_param->offset_mul_arg_.multiplier_),
+      scale_param->offset_mul_arg_.right_shift_);
+
+    mul_result += bias + scale_param->output_zp_;
+
+    if (mul_result > scale_param->output_activation_max_) {
+      out_data[index] = scale_param->output_activation_max_;
+    } else if (mul_result < scale_param->output_activation_min_) {
+      out_data[index] = scale_param->output_activation_min_;
+    } else {
+      out_data[index] = (int8_t)mul_result;
+    }
+  }
+  return;
 }
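Both `(multiplier, left_shift, right_shift)` triples consumed above come from decomposing a real-valued multiplier (for the product term, `S_in * S_scale / S_out`; for the bias term, `S_offset / S_out`) into a Q31 mantissa and a power-of-two exponent. The production values are derived in ScaleInt8CPUKernel::InitQuantArgs further down in this patch; the following hypothetical helper sketches the standard frexp-based decomposition, including the same left/right shift split InitQuantArgs applies.

```c
#include <math.h>
#include <stdint.h>

/* Sketch: decompose real_multiplier into multiplier_ (Q31) and the
 * left_shift_/right_shift_ pair used by the requantization loops above. */
static void QuantizeMultiplier_Sketch(double real_multiplier, int32_t *multiplier,
                                      int *left_shift, int *right_shift) {
  if (real_multiplier == 0.0) {
    *multiplier = 0;
    *left_shift = 0;
    *right_shift = 0;
    return;
  }
  int shift = 0;
  double q = frexp(real_multiplier, &shift); /* real_multiplier = q * 2^shift, q in [0.5, 1) */
  int64_t q_fixed = (int64_t)llround(q * (1ll << 31)); /* Q31 fixed point */
  if (q_fixed == (1ll << 31)) { /* q rounded up to 1.0: renormalize */
    q_fixed /= 2;
    ++shift;
  }
  *multiplier = (int32_t)q_fixed;
  /* same split as "shift > 0 ? shift : 0" / "shift < 0 ? -shift : 0" below */
  *left_shift = shift > 0 ? shift : 0;
  *right_shift = shift < 0 ? -shift : 0;
}
```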
diff --git a/mindspore/lite/nnacl/int8/scale_int8.h b/mindspore/lite/nnacl/int8/scale_int8.h
index 7a4fcb855b..a773d6df1a 100644
--- a/mindspore/lite/nnacl/int8/scale_int8.h
+++ b/mindspore/lite/nnacl/int8/scale_int8.h
@@ -22,10 +22,10 @@
 #ifdef __cplusplus
 extern "C" {
 #endif
-void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, int task_id,
-                 const ScaleParameter *scale_param, int max, int min);
+void DoScaleInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const ScaleParameter *scale_param,
+                 int real_dst_count);
 void DoScaleWithBiasInt8(const int8_t *in_data, int8_t *out_data, const int8_t *scale, const int8_t *offset,
-                         int task_id, const ScaleParameter *scale_param, int max, int min);
+                         const ScaleParameter *scale_param, int real_dst_count);
 #ifdef __cplusplus
 }
 #endif
diff --git a/mindspore/lite/nnacl/scale.h b/mindspore/lite/nnacl/scale.h
index fb9c881a93..8bcf0e1220 100644
--- a/mindspore/lite/nnacl/scale.h
+++ b/mindspore/lite/nnacl/scale.h
@@ -34,6 +34,8 @@ typedef struct ScaleParameter {
   int offset_zp_;
   int output_zp_;
   int activation_type_;
+  int output_activation_min_;
+  int output_activation_max_;
 } ScaleParameter;
 
 #endif  // MINDSPORE_LITE_NNACL_SCALE_H_
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/scale_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/scale_int8.cc
index a06c821dce..ba193432b4 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/scale_int8.cc
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/scale_int8.cc
@@ -19,6 +19,7 @@
 #include <string.h>
 #include <vector>
 #include "nnacl/int8/scale_int8.h"
+#include "nnacl/arithmetic_common.h"
 #include "schema/model_generated.h"
 #include "src/kernel_registry.h"
 #include "include/errorcode.h"
@@ -35,63 +36,65 @@ constexpr size_t kScaleInputsSize = 2;
 constexpr size_t kScaleBiasInputsSize = 3;
 }  // namespace
 ScaleInt8CPUKernel::~ScaleInt8CPUKernel() {
-  if (scale_param_->const_scale_) {
-    if (scale_ != nullptr) {
-      free(scale_);
-      scale_ = nullptr;
-    }
+  if (tile_para != nullptr) {
+    free(tile_para);
+    tile_para = nullptr;
   }
-  if (has_bias_ && scale_param_->const_offset_) {
-    if (offset_ != nullptr) {
-      free(offset_);
-      offset_ = nullptr;
-    }
+  if (input1_data_ != nullptr && malloced_scale_) {
+    free(input1_data_);
+  }
+  if (input2_data_ != nullptr && malloced_offset_) {
+    free(input2_data_);
   }
 }
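InitScaleOffset below (and Run() later on) expands a broadcastable scale or offset tensor to the full output shape with TileOneDimensionUint8, whose declaration was added to arithmetic_common.h at the top of this patch. Its definition is not part of this diff; the sketch below shows the expected behavior, modeled on the existing float TileOneDimension in nnacl/arithmetic_common.c. Element size is one byte, which is why the int8 kernel can reuse a uint8 routine via casts.

```c
#include <stdint.h>
#include <string.h>

/* Sketch: recursively walk the dimensions, repeating the input
 * multiple[dim] times per dimension, so a broadcastable tensor is
 * expanded to the full output shape. */
static void TileOneDimensionUint8_Sketch(const uint8_t *in_data, uint8_t *out_data, int dim, size_t ndim,
                                         const int *in_shape, const int *in_strides, const int *out_strides,
                                         const int *multiple) {
  int src_dim_size = in_shape[dim];
  if (dim == (int)ndim - 1) {
    /* innermost dimension: copy the source run multiple[dim] times */
    for (int i = 0; i < multiple[dim]; i++) {
      memcpy(out_data, in_data, src_dim_size * sizeof(uint8_t));
      out_data += src_dim_size;
    }
    return;
  }
  for (int i = 0; i < src_dim_size; i++) {
    for (int j = 0; j < multiple[dim]; j++) {
      TileOneDimensionUint8_Sketch(in_data + in_strides[dim] * i,
                                   out_data + out_strides[dim] * (i + j * src_dim_size), dim + 1, ndim,
                                   in_shape, in_strides, out_strides, multiple);
    }
  }
}
```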
 
 int ScaleInt8CPUKernel::InitScaleOffset() {
-  auto scale_tensor = in_tensors_.at(1);
-  int8_t *scale_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data_c());
+  CalcMultiplesAndStrides(tile_para);
+  scale_param_->const_scale_ = false;
+  auto *scale_ptr = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data_c());
+  // scale may be a const value; if so, it can be broadcast once at prepare stage
   if (scale_ptr != nullptr) {
     scale_param_->const_scale_ = true;
-    if (scale_ != nullptr) {
-      free(scale_);
-      scale_ = nullptr;
-    }
-    scale_ = reinterpret_cast<int8_t *>(malloc(scale_tensor->ElementsNum() * sizeof(int8_t)));
-    if (scale_ == nullptr) {
-      MS_LOG(ERROR) << "Malloc buffer failed.";
-      return RET_ERROR;
+    input1_data_ = scale_ptr;
+    // need broadcasting
+    if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum()) {
+      input1_data_ = reinterpret_cast<int8_t *>(malloc(out_tensors_.at(0)->Size()));
+      if (input1_data_ == nullptr) {
+        MS_LOG(ERROR) << "malloc input1_data_ failed.";
+        return RET_ERROR;
+      }
+      malloced_scale_ = true;
+      TileOneDimensionUint8(reinterpret_cast<uint8_t *>(in_tensors_.at(1)->data_c()),
+                            reinterpret_cast<uint8_t *>(input1_data_), 0, tile_para->ndim_, tile_para->in_shape1_,
+                            tile_para->in_strides1_, tile_para->out_strides_, tile_para->multiples1_);
     }
-    memcpy(scale_, scale_ptr, scale_tensor->ElementsNum() * sizeof(int8_t));
-  } else {
-    scale_param_->const_scale_ = false;
-    scale_ = nullptr;
   }
 
+  scale_param_->const_offset_ = false;
   if (in_tensors_.size() == 3) {
     has_bias_ = true;
     auto offset_tensor = in_tensors_.at(2);
-    int8_t *offset_ptr = reinterpret_cast<int8_t *>(offset_tensor->data_c());
+    auto *offset_ptr = reinterpret_cast<int8_t *>(offset_tensor->data_c());
+    // offset may be a const value; if so, it can be broadcast once at prepare stage
     if (offset_ptr != nullptr) {
       scale_param_->const_offset_ = true;
-      if (offset_ != nullptr) {
-        free(offset_);
-        offset_ = nullptr;
-      }
-      offset_ = reinterpret_cast<int8_t *>(malloc(offset_tensor->ElementsNum() * sizeof(int8_t)));
-      if (offset_ == nullptr) {
-        MS_LOG(ERROR) << "Malloc buffer failed.";
-        return RET_ERROR;
+      input2_data_ = offset_ptr;
+      // need broadcasting
+      if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(2)->ElementsNum()) {
+        input2_data_ = reinterpret_cast<int8_t *>(malloc(out_tensors_.at(0)->Size()));
+        if (input2_data_ == nullptr) {
+          MS_LOG(ERROR) << "malloc input2_data_ failed.";
+          if (malloced_scale_) {
+            free(input1_data_);
+            input1_data_ = nullptr;
+          }
+          return RET_ERROR;
+        }
+        malloced_offset_ = true;
+        TileOneDimensionUint8(reinterpret_cast<uint8_t *>(in_tensors_.at(2)->data_c()),
                              reinterpret_cast<uint8_t *>(input2_data_), 0, tile_para->ndim_, tile_para->in_shape1_,
+                              tile_para->in_strides1_, tile_para->out_strides_, tile_para->multiples1_);
       }
-      memcpy(offset_, offset_ptr, offset_tensor->ElementsNum() * sizeof(int8_t));
-    } else {
-      scale_param_->const_offset_ = false;
-      offset_ = nullptr;
     }
-  } else {
-    has_bias_ = false;
   }
 
   return RET_OK;
 }
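InitScaleOffset above begins with `CalcMultiplesAndStrides(tile_para)`. That helper already exists in nnacl/arithmetic_common.c (its declaration is visible in the header hunk at the top of this patch); the sketch below shows roughly what it fills in, which is everything the tiling call needs: per-dimension repeat counts and ordinary row-major strides.

```c
#include "nnacl/arithmetic_common.h"

/* Sketch of CalcMultiplesAndStrides: multiples are how many times each
 * input dimension must repeat to reach the output shape; strides are
 * row-major strides computed by ComputeStrides from the same header. */
static void CalcMultiplesAndStrides_Sketch(ArithmeticParameter *param) {
  for (size_t i = 0; i < param->ndim_; i++) {
    param->multiples0_[i] = param->out_shape_[i] / param->in_shape0_[i];
    param->multiples1_[i] = param->out_shape_[i] / param->in_shape1_[i];
  }
  ComputeStrides(param->in_shape0_, param->in_strides0_, param->ndim_);
  ComputeStrides(param->in_shape1_, param->in_strides1_, param->ndim_);
  ComputeStrides(param->out_shape_, param->out_strides_, param->ndim_);
}
```

Note that this division only makes sense once InitParameter (next hunk) has padded `in_shape1_` so every dimension either matches the output or equals 1.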
 
@@ -102,29 +105,66 @@ int ScaleInt8CPUKernel::InitParameter() {
   auto scale_shape = scale_tensor->shape();
 
   if (scale_param_->axis_ < 0) {
-    scale_param_->axis_ = scale_param_->axis_ + in_shape.size();
+    scale_param_->axis_ += in_shape.size();
   }
   if (scale_shape.size() + scale_param_->axis_ > in_shape.size()) {
     MS_LOG(ERROR) << "Scale tensor shape is incorrect.";
     return RET_ERROR;
   }
-  scale_param_->outer_size_ = 1;
-  scale_param_->axis_size_ = 1;
-  scale_param_->inner_size_ = 1;
-  for (int i = 0; i < scale_param_->axis_; i++) {
-    scale_param_->outer_size_ *= in_shape[i];
-  }
+
   for (size_t i = 0; i < scale_shape.size(); i++) {
     if (in_shape[i + scale_param_->axis_] != scale_shape[i]) {
       MS_LOG(ERROR) << "Scale tensor shape is incorrect.";
       return RET_ERROR;
     }
-    scale_param_->axis_size_ *= in_shape[i + scale_param_->axis_];
   }
-  for (size_t i = scale_param_->axis_ + scale_shape.size(); i < in_shape.size(); i++) {
-    scale_param_->inner_size_ *= in_shape[i];
+
+  tile_para = reinterpret_cast<ArithmeticParameter *>(malloc(sizeof(ArithmeticParameter)));
+  if (tile_para == nullptr) {
+    MS_LOG(ERROR) << "malloc tile parameter failed.";
+    return RET_ERROR;
   }
-  scale_param_->op_parameter_.thread_num_ = MSMIN(scale_param_->op_parameter_.thread_num_, scale_param_->outer_size_);
+  size_t input0_size = in_tensors_.at(0)->shape().size();
+  size_t input1_size = in_tensors_.at(1)->shape().size();
+  size_t output_size = out_tensors_.at(0)->shape().size();
+  auto input1_shape = in_tensors_.at(1)->shape();
+  tile_para->ndim_ = output_size;
+  // pad the scale tensor's shape with trailing 1s so it covers every dimension from axis_ onward
+  size_t len = input0_size - scale_param_->axis_;
+  second_in_shape_ = input1_shape;
+  if (len != input1_size) {
+    second_in_shape_.resize(len);
+    size_t i = 0;
+    for (; i < input1_size; ++i) {
+      second_in_shape_[i] = input1_shape[i];
+    }
+    for (; i < len; ++i) {
+      second_in_shape_[i] = 1;
+    }
+    input1_size = len;
+  }
+
+  if (input0_size == input1_size) {
+    for (size_t i = 0; i < output_size; i++) {
+      tile_para->in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i);
+      tile_para->in_shape1_[i] = in_tensors_.at(1)->DimensionSize(i);
+      tile_para->out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
+    }
+  } else {
+    MS_ASSERT(input0_size > input1_size);
+    size_t fill_dim_num = input0_size - input1_size;
+    int j = 0;
+    for (size_t i = 0; i < output_size; i++) {
+      tile_para->in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i);
+      if (i < fill_dim_num) {
+        tile_para->in_shape1_[i] = 1;
+      } else {
+        tile_para->in_shape1_[i] = second_in_shape_[j++];
+      }
+      tile_para->out_shape_[i] = out_tensors_.at(0)->DimensionSize(i);
+    }
+  }
+
   return RET_OK;
 }
 
@@ -156,6 +196,24 @@ int ScaleInt8CPUKernel::InitQuantArgs() {
     scale_param_->offset_mul_arg_.left_shift_ = shift > 0 ? shift : 0;
     scale_param_->offset_mul_arg_.right_shift_ = shift < 0 ? -shift : 0;
   }
+
+  switch (scale_param_->activation_type_) {
+    case schema::ActivationType_RELU:
+      scale_param_->output_activation_min_ = 0;
+      scale_param_->output_activation_max_ = INT8_MAX;
+      break;
+    case schema::ActivationType_RELU6:
+      scale_param_->output_activation_min_ = 0;
+      scale_param_->output_activation_max_ = 6;
+      break;
+    case schema::ActivationType_NO_ACTIVATION:
+      scale_param_->output_activation_min_ = INT8_MIN;
+      scale_param_->output_activation_max_ = INT8_MAX;
+      break;
+    default:
+      MS_LOG(ERROR) << "Scale does not support activation type " << scale_param_->activation_type_;
+      return RET_ERROR;
+  }
   return RET_OK;
 }
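To make the shape handling in InitParameter above concrete, here is a small self-contained walkthrough with hypothetical shapes: input0 `{2, 4, 8, 16}`, `axis_ = 1`, scale `{4}`.

```c
#include <stdio.h>

/* Worked example of the padding + front-fill done in InitParameter above. */
int main(void) {
  const int ndim = 4, axis = 1;
  int in_shape0[4] = {2, 4, 8, 16};
  int scale_shape[1] = {4};
  int scale_dims = 1;
  int in_shape1[4], multiples1[4];

  /* pad with trailing 1s: len = ndim - axis = 3  ->  second_in_shape = {4, 1, 1} */
  int len = ndim - axis;
  int second_in_shape[4];
  for (int i = 0; i < len; ++i) second_in_shape[i] = i < scale_dims ? scale_shape[i] : 1;

  /* prepend 1s for the leading fill_dim_num = ndim - len dimensions */
  int fill_dim_num = ndim - len, j = 0;
  for (int i = 0; i < ndim; ++i) in_shape1[i] = i < fill_dim_num ? 1 : second_in_shape[j++];

  /* the repeat counts CalcMultiplesAndStrides derives: out_shape / in_shape1 */
  for (int i = 0; i < ndim; ++i) multiples1[i] = in_shape0[i] / in_shape1[i];
  for (int i = 0; i < ndim; ++i) {
    printf("dim %d: in_shape1 = %d, multiples1 = %d\n", i, in_shape1[i], multiples1[i]);
  }
  /* prints in_shape1 = {1, 4, 1, 1} and multiples1 = {2, 1, 8, 16}: the
   * 4-element scale is tiled across all 2 * 4 * 8 * 16 output elements. */
  return 0;
}
```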
 
@@ -176,13 +234,13 @@ int ScaleInt8CPUKernel::Init() {
 int ScaleInt8CPUKernel::ReSize() {
   auto ret = InitParameter();
   if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Scale fp32 InitParameter failed.";
+    MS_LOG(ERROR) << "Scale int8 InitParameter failed.";
     return RET_ERROR;
   }
 
   ret = InitScaleOffset();
   if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Scale fp32 InitScaleOffset failed.";
+    MS_LOG(ERROR) << "Scale int8 InitScaleOffset failed.";
     return RET_ERROR;
   }
 
@@ -195,38 +253,21 @@ int ScaleInt8CPUKernel::ReSize() {
 }
 
 int ScaleInt8CPUKernel::Scale(int task_id) {
+  int real_dst_count = MSMIN(elements_num_ - task_id * count_unit_, count_unit_);
+  if (real_dst_count <= 0) {
+    return lite::RET_OK;
+  }
+  int8_t *cur_input0_data = input0_data_ + task_id * count_unit_;
+  int8_t *cur_input1_data = input1_data_ + task_id * count_unit_;
+  int8_t *cur_output_data = output_data_ + task_id * count_unit_;
+
   if (has_bias_) {
-    switch (scale_param_->activation_type_) {
-      case schema::ActivationType_RELU:
-        DoScaleWithBiasInt8(input_ptr_, output_ptr_, scale_, offset_, task_id, scale_param_, INT8_MAX, 0);
-        break;
-      case schema::ActivationType_RELU6:
-        DoScaleWithBiasInt8(input_ptr_, output_ptr_, scale_, offset_, task_id, scale_param_, 6, 0);
-        break;
-      case schema::ActivationType_NO_ACTIVATION:
-        DoScaleWithBiasInt8(input_ptr_, output_ptr_, scale_, offset_, task_id, scale_param_, INT8_MAX, INT8_MIN);
-        break;
-      default:
-        MS_LOG(ERROR) << "Scale does not support activation type " << scale_param_->activation_type_;
-        return RET_ERROR;
-    }
+    int8_t *cur_input2_data = input2_data_ + task_id * count_unit_;
+    DoScaleWithBiasInt8(cur_input0_data, cur_output_data, cur_input1_data, cur_input2_data, scale_param_,
+                        real_dst_count);
   } else {
-    switch (scale_param_->activation_type_) {
-      case schema::ActivationType_RELU:
-        DoScaleInt8(input_ptr_, output_ptr_, scale_, task_id, scale_param_, INT8_MAX, 0);
-        break;
-      case schema::ActivationType_RELU6:
-        DoScaleInt8(input_ptr_, output_ptr_, scale_, task_id, scale_param_, 6, 0);
-        break;
-      case schema::ActivationType_NO_ACTIVATION:
-        DoScaleInt8(input_ptr_, output_ptr_, scale_, task_id, scale_param_, INT8_MAX, INT8_MIN);
-        break;
-      default:
-        MS_LOG(ERROR) << "Scale does not support activation type " << scale_param_->activation_type_;
-        return RET_ERROR;
-    }
+    DoScaleInt8(cur_input0_data, cur_output_data, cur_input1_data, scale_param_, real_dst_count);
   }
-
   return RET_OK;
 }
 
@@ -241,18 +282,59 @@ int ScaleRunInt8(void *cdata, int task_id) {
 }
 
 int ScaleInt8CPUKernel::Run() {
-  auto in_tensor = in_tensors_.front();
-  input_ptr_ = reinterpret_cast<int8_t *>(in_tensor->data_c());
-  if (scale_ == nullptr) {
-    auto scale_tensor = in_tensors_[1];
-    scale_ = reinterpret_cast<int8_t *>(scale_tensor->data_c());
+  elements_num_ = out_tensors_.at(0)->ElementsNum();
+  count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_;
+  input0_data_ = reinterpret_cast<int8_t *>(in_tensors_.at(0)->data_c());
+  output_data_ = reinterpret_cast<int8_t *>(out_tensors_.at(0)->data_c());
+
+  // need broadcasting
+  if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum()) {
+    // scale is produced by the previous node at runtime, so broadcast it online
+    if (!scale_param_->const_scale_) {
+      input1_data_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(out_tensors_.at(0)->Size()));
+      if (input1_data_ == nullptr) {
+        MS_LOG(ERROR) << "malloc input1_data_ failed.";
+        return RET_ERROR;
+      }
+      TileOneDimensionUint8(reinterpret_cast<uint8_t *>(in_tensors_.at(1)->data_c()),
+                            reinterpret_cast<uint8_t *>(input1_data_), 0, tile_para->ndim_, tile_para->in_shape1_,
+                            tile_para->in_strides1_, tile_para->out_strides_, tile_para->multiples1_);
+    }
+
+    // likewise, a bias produced by the previous node must be broadcast online
+    if (has_bias_ && !scale_param_->const_offset_) {
+      input2_data_ = reinterpret_cast<int8_t *>(ctx_->allocator->Malloc(out_tensors_.at(0)->Size()));
+      if (input2_data_ == nullptr) {
+        MS_LOG(ERROR) << "malloc input2_data_ failed.";
+        ctx_->allocator->Free(input1_data_);
+        input1_data_ = nullptr;
+        return RET_ERROR;
+      }
+      TileOneDimensionUint8(reinterpret_cast<uint8_t *>(in_tensors_.at(2)->data_c()),
+                            reinterpret_cast<uint8_t *>(input2_data_), 0, tile_para->ndim_, tile_para->in_shape1_,
+                            tile_para->in_strides1_, tile_para->out_strides_, tile_para->multiples1_);
+    }
+
+    auto ret = ParallelLaunch(this->context_->thread_pool_, ScaleRunInt8, this, op_parameter_->thread_num_);
+    // free the buffers taken from the memory pool
+    if (!scale_param_->const_scale_) {
+      ctx_->allocator->Free(input1_data_);
+      input1_data_ = nullptr;
+    }
+    if (has_bias_ && !scale_param_->const_offset_) {
+      ctx_->allocator->Free(input2_data_);
+      input2_data_ = nullptr;
+    }
+    return ret;
+  }
+
+  // input1 has the same shape as input0, so no broadcasting is needed
+  if (input1_data_ == nullptr) {
+    input1_data_ = reinterpret_cast<int8_t *>(in_tensors_.at(1)->data_c());
   }
   if (has_bias_ && !scale_param_->const_offset_) {
-    offset_ = reinterpret_cast<int8_t *>(in_tensors_.at(2)->data_c());
+    input2_data_ = reinterpret_cast<int8_t *>(in_tensors_.at(2)->data_c());
   }
-  auto out_tensor = out_tensors_.front();
-  output_ptr_ = reinterpret_cast<int8_t *>(out_tensor->data_c());
-
   auto ret = ParallelLaunch(this->context_->thread_pool_, ScaleRunInt8, this, op_parameter_->thread_num_);
   if (ret != RET_OK) {
     MS_LOG(ERROR) << "Scale error error_code[" << ret << "]";
@@ -260,6 +342,7 @@ int ScaleInt8CPUKernel::Run() {
   }
   return RET_OK;
 }
+
 kernel::LiteKernel *CpuScaleInt8KernelCreator(const std::vector<lite::Tensor *> &inputs,
                                               const std::vector<lite::Tensor *> &outputs, OpParameter *opParameter,
                                               const lite::InnerContext *ctx, const kernel::KernelKey &desc,
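With the data pre-broadcast to the output shape, Run()/Scale(task_id) above can split the flattened output evenly across threads instead of splitting `outer_size_` as the old code did. The standalone sketch below restates that slicing arithmetic; the macros mirror nnacl's `UP_DIV`/`MSMIN`.

```c
#include <stdint.h>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))
#define MSMIN(x, y) ((x) < (y) ? (x) : (y))

/* Each task gets count_unit elements, the last in-range task gets the
 * remainder, and tasks past the end see real_dst_count <= 0 and return
 * early, exactly as Scale(task_id) does above. Returns 1 if the task
 * has work. */
static int SliceForTask(int64_t elements_num, int thread_count, int task_id,
                        int64_t *offset, int64_t *real_dst_count) {
  int64_t count_unit = thread_count > 1 ? UP_DIV(elements_num, thread_count) : elements_num;
  *offset = task_id * count_unit;
  *real_dst_count = MSMIN(elements_num - *offset, count_unit);
  return *real_dst_count > 0;
}
```

For example, 10 elements on 4 threads gives `count_unit = 3` and per-task counts of 3, 3, 3, 1; with 2 elements on 4 threads, tasks 2 and 3 simply do nothing.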
diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/scale_int8.h b/mindspore/lite/src/runtime/kernel/arm/int8/scale_int8.h
index db76d4373e..17c896fa68 100644
--- a/mindspore/lite/src/runtime/kernel/arm/int8/scale_int8.h
+++ b/mindspore/lite/src/runtime/kernel/arm/int8/scale_int8.h
@@ -21,6 +21,7 @@
 #include "src/lite_kernel.h"
 #include "nnacl/scale.h"
 #include "nnacl/quantization/quantize.h"
+#include "nnacl/arithmetic_common.h"
 
 namespace mindspore::kernel {
 
@@ -29,7 +30,7 @@ class ScaleInt8CPUKernel : public LiteKernel {
   ScaleInt8CPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                      const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                      const mindspore::lite::PrimitiveC *primitive)
-      : LiteKernel(parameter, inputs, outputs, ctx, primitive) {
+      : LiteKernel(parameter, inputs, outputs, ctx, primitive), ctx_(ctx), thread_count_(ctx_->thread_num_) {
     scale_param_ = reinterpret_cast<ScaleParameter *>(op_parameter_);
   }
   ~ScaleInt8CPUKernel() override;
@@ -42,12 +43,20 @@ class ScaleInt8CPUKernel : public LiteKernel {
   int Scale(int task_id);
 
  private:
-  int8_t *input_ptr_ = nullptr;
-  int8_t *scale_ = nullptr;
-  int8_t *offset_ = nullptr;
-  int8_t *output_ptr_ = nullptr;
-  bool has_bias_ = false;
+  int8_t *input0_data_ = nullptr;
+  int8_t *input1_data_ = nullptr;
+  int8_t *input2_data_ = nullptr;
+  int8_t *output_data_ = nullptr;
+  const lite::InnerContext *ctx_;
   ScaleParameter *scale_param_;
+  ArithmeticParameter *tile_para = nullptr;
+  std::vector<int> second_in_shape_;
+  int thread_count_;
+  int64_t elements_num_;
+  int64_t count_unit_;
+  bool has_bias_ = false;
+  bool malloced_scale_ = false;
+  bool malloced_offset_ = false;
 
   int InitQuantArgs();
 };
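Finally, a minimal self-contained check of the new DoScaleInt8 contract. All parameter values here are hypothetical and chosen so the math is easy to verify by hand: every zero point is 0 and the tensor scales are S_in = 0.5, S_scale = 0.5, S_out = 0.25, so the effective multiplier S_in * S_scale / S_out is exactly 1.0, which decomposes (per the frexp sketch earlier) into `multiplier_ = 1 << 30` with `left_shift_ = 1`.

```c
#include <stdint.h>
#include <stdio.h>
#include <string.h>
#include "nnacl/int8/scale_int8.h"

int main(void) {
  ScaleParameter param;
  memset(&param, 0, sizeof(param)); /* all zero points = 0 */
  param.scale_mul_arg_.multiplier_ = 1 << 30; /* frexp(1.0): 0.5 * 2^1 */
  param.scale_mul_arg_.left_shift_ = 1;
  param.scale_mul_arg_.right_shift_ = 0;
  param.output_activation_min_ = INT8_MIN; /* NO_ACTIVATION clamp range */
  param.output_activation_max_ = INT8_MAX;

  int8_t in[4] = {10, -4, 6, 0};  /* real values 5.0, -2.0, 3.0, 0.0 */
  int8_t scale[4] = {2, 2, 2, 2}; /* real value 1.0 */
  int8_t out[4] = {0};
  DoScaleInt8(in, out, scale, &param, 4);

  for (int i = 0; i < 4; ++i) {
    printf("%d ", out[i]); /* expect 20 -8 12 0: the same real values at S_out = 0.25 */
  }
  printf("\n");
  return 0;
}
```

Because the inputs here are already broadcast to equal length, this exercises exactly the contract the kernel relies on after tiling: one flat run of `real_dst_count` elements per call.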