Optimize the int8 mul op and delete unused code

pull/8542/head
fuzhiye 4 years ago
parent 9a8d4da5dc
commit d8467311f4

@ -125,25 +125,15 @@ void IndirectGemmFp16_16x8_c8(float16_t *output, float16_t *input, float16_t *we
void ConvFp16(float16_t *input_data, float16_t *packed_input, float16_t *packed_weight, float16_t *bias_data,
float16_t *col_major_input, float16_t *output_data, int task_id, ConvParameter *conv_param) {
const int tile_n = 16;
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int in_batch = conv_param->input_batch_;
int in_channel = conv_param->input_channel_;
int in_h = conv_param->input_h_;
int in_w = conv_param->input_w_;
int out_h = conv_param->output_h_;
int out_w = conv_param->output_w_;
int out_channel = conv_param->output_channel_;
int thread_count = conv_param->thread_num_;
int output_count = out_h * out_w;
int output_count = conv_param->output_h_ * conv_param->output_w_;
int output_tile_count = UP_DIV(output_count, tile_n);
int kernel_plane = kernel_h * kernel_w;
int deep = kernel_plane * in_channel;
int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
for (int b = 0; b < in_batch; b++) {
int in_batch_offset = b * in_channel * in_h * in_w;
int out_batch_offset = b * out_channel * out_h * out_w;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
for (int b = 0; b < conv_param->input_batch_; b++) {
int in_batch_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
int out_batch_offset = b * out_channel * output_count;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
int start_index = thread_id * tile_n;
int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n;
float16_t *gemm_input = packed_input + task_id * deep * tile_n;
@ -166,18 +156,13 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa
float16_t *output_data, TmpBufferAddressFp16 *buffer_list, int task_id, ConvParameter *conv_param,
InputTransFp16Func in_func, OutputTransFp16Func out_func) {
const int tile_num = 16;
int thread_num = conv_param->thread_num_;
int input_unit = conv_param->input_unit_;
int in_batch = conv_param->input_batch_;
int in_channel = conv_param->input_channel_;
int out_unit = conv_param->output_unit_;
int out_w_block = UP_DIV(conv_param->output_w_, out_unit);
int out_h_block = UP_DIV(conv_param->output_h_, out_unit);
int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_);
int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_);
int output_count = out_w_block * out_h_block;
int output_tile_count = UP_DIV(output_count, tile_num);
int out_channel = conv_param->output_channel_;
int oc8 = UP_DIV(out_channel, C8NUM);
int input_unit_square = input_unit * input_unit;
int oc8 = UP_DIV(conv_param->output_channel_, C8NUM);
int input_unit_square = conv_param->input_unit_ * conv_param->input_unit_;
float16_t *trans_input = buffer_list[0];
float16_t *gemm_out = buffer_list[1];
@ -189,10 +174,10 @@ void ConvWinogardFp16(float16_t *input_data, float16_t *trans_weight, const floa
int col_buffer_offset = tile_num * in_channel;
// step 1 : filter transform (pre-processed offline)
// step 2 : input transform (online)
for (int b = 0; b < in_batch; b++) {
for (int b = 0; b < conv_param->input_batch_; b++) {
int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_;
int out_batch_offset = b * out_channel * conv_param->output_h_ * conv_param->output_w_;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) {
int out_batch_offset = b * conv_param->output_channel_ * conv_param->output_h_ * conv_param->output_w_;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
int out_tile_index = thread_id * tile_num;
int cal_num = output_count - thread_id * tile_num;
cal_num = cal_num > tile_num ? tile_num : cal_num;

@ -23,30 +23,20 @@
// fp32 conv common
void ConvFp32(const float *input_data, float *packed_input, const float *packed_weight, const float *bias_data,
float *col_major_input, float *output_data, int task_id, ConvParameter *conv_param) {
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int in_batch = conv_param->input_batch_;
int in_channel = conv_param->input_channel_;
int in_h = conv_param->input_h_;
int in_w = conv_param->input_w_;
int out_h = conv_param->output_h_;
int out_w = conv_param->output_w_;
int out_channel = conv_param->output_channel_;
int thread_count = conv_param->thread_num_;
int output_count = out_h * out_w;
int deep = conv_param->kernel_h_ * conv_param->kernel_w_ * conv_param->input_channel_;
int output_count = conv_param->output_h_ * conv_param->output_w_;
#if defined(ENABLE_ARM32) || defined(ENABLE_X86_64_SSE)
const int cal_num = C4NUM;
#else
const int cal_num = C12NUM;
#endif
int output_tile_count = UP_DIV(output_count, cal_num);
int kernel_plane = kernel_h * kernel_w;
int deep = kernel_plane * in_channel;
for (int b = 0; b < in_batch; b++) {
int in_batch_offset = b * in_channel * in_h * in_w;
int out_batch_offset = b * out_channel * out_h * out_w;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
for (int b = 0; b < conv_param->input_batch_; b++) {
int in_batch_offset = b * conv_param->input_channel_ * conv_param->input_h_ * conv_param->input_w_;
int out_batch_offset = b * out_channel * output_count;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
int start_index = thread_id * cal_num;
int real_cal_num = (output_count - start_index) < cal_num ? (output_count - start_index) : cal_num;
float *gemm_input = packed_input + task_id * deep * cal_num;
@ -73,19 +63,14 @@ void ConvFp32(const float *input_data, float *packed_input, const float *packed_
void ConvWinogardFp32(const float *input_data, const float *trans_weight, const float *bias_data, float *output_data,
TmpBufferAddress *buffer_list, int task_id, ConvParameter *conv_param, InputTransFunc in_func,
OutputTransFunc out_func) {
int thread_num = conv_param->thread_num_;
int input_unit = conv_param->input_unit_;
int in_batch = conv_param->input_batch_;
int in_channel = conv_param->input_channel_;
int out_unit = conv_param->output_unit_;
int out_w_block = UP_DIV(conv_param->output_w_, out_unit);
int out_h_block = UP_DIV(conv_param->output_h_, out_unit);
int out_w_block = UP_DIV(conv_param->output_w_, conv_param->output_unit_);
int out_h_block = UP_DIV(conv_param->output_h_, conv_param->output_unit_);
int output_count = out_w_block * out_h_block;
const int tile_num = C12NUM;
int output_tile_count = UP_DIV(output_count, tile_num);
int out_channel = conv_param->output_channel_;
int oc8 = UP_DIV(out_channel, C8NUM);
int input_unit_square = input_unit * input_unit;
int oc8 = UP_DIV(conv_param->output_channel_, C8NUM);
int input_unit_square = conv_param->input_unit_ * conv_param->input_unit_;
float *trans_input = buffer_list[0];
float *gemm_out = buffer_list[1];
@ -97,10 +82,10 @@ void ConvWinogardFp32(const float *input_data, const float *trans_weight, const
int col_buffer_offset = tile_num * in_channel;
// step 1 : filter transform (pre-processed offline)
// step 2 : input transform (online)
for (int b = 0; b < in_batch; b++) {
for (int b = 0; b < conv_param->input_batch_; b++) {
int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_;
int out_batch_offset = b * out_channel * conv_param->output_w_ * conv_param->output_h_;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_num) {
int out_batch_offset = b * conv_param->output_channel_ * conv_param->output_w_ * conv_param->output_h_;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
int out_tile_index = thread_id * tile_num;
int cal_num = output_count - out_tile_index;
cal_num = cal_num > tile_num ? tile_num : cal_num;

@ -20,10 +20,6 @@
int AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id, float minf,
float maxf) {
int stride_w = pooling_param->stride_w_;
int stride_h = pooling_param->stride_h_;
int pad_w = pooling_param->pad_l_;
int pad_h = pooling_param->pad_u_;
int win_w = pooling_param->window_w_;
int win_h = pooling_param->window_h_;
int channel = pooling_param->input_channel_;
@ -32,10 +28,8 @@ int AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pool
int in_h = pooling_param->input_h_;
int output_w = pooling_param->output_w_;
int output_h = pooling_param->output_h_;
int output_batch = pooling_param->output_batch_;
int out_plane = output_w * output_h;
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
int thread_num = pooling_param->thread_num_;
int window = win_w * win_h;
#ifdef ENABLE_NEON
@ -43,18 +37,18 @@ int AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pool
float32x4_t max_value = vdupq_n_f32(maxf);
#endif
for (int batch = 0; batch < output_batch; batch++) {
for (int batch = 0; batch < pooling_param->output_batch_; batch++) {
const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel;
float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel;
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += pooling_param->thread_num_) {
int cal_start_index = thread_id * TILE_NUM;
int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index);
for (int i = 0; i < real_cal_num; i++) {
int index = cal_start_index + i;
int out_w_index = index % output_w;
int out_h_index = index / output_w;
int in_w_index = out_w_index * stride_w - pad_w;
int in_h_index = out_h_index * stride_h - pad_h;
int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_;
int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_;
const float *src_plane_ptr = src_b_ptr;
float *dst_plane_ptr = dst_b_ptr + index * channel;
@ -152,10 +146,6 @@ int AvgPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pool
void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *pooling_param, int task_id, float minf,
float maxf) {
int stride_w = pooling_param->stride_w_;
int stride_h = pooling_param->stride_h_;
int pad_w = pooling_param->pad_l_;
int pad_h = pooling_param->pad_u_;
int win_w = pooling_param->window_w_;
int win_h = pooling_param->window_h_;
int channel = pooling_param->input_channel_;
@ -166,7 +156,6 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
int output_batch = pooling_param->output_batch_;
int out_plane = output_w * output_h;
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
int thread_num = pooling_param->thread_num_;
int c4 = channel / C4NUM; /* oc && ic */
#ifdef ENABLE_NEON
@ -177,15 +166,15 @@ void MaxPooling(const float *input_ptr, float *output_ptr, PoolingParameter *poo
for (int batch = 0; batch < output_batch; batch++) {
const float *src_b_ptr = input_ptr + batch * in_h * in_w * channel;
float *dst_b_ptr = output_ptr + batch * output_h * output_w * channel;
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += pooling_param->thread_num_) {
int cal_start_index = thread_id * TILE_NUM;
int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index);
for (int i = 0; i < real_cal_num; i++) {
int index = cal_start_index + i;
int out_w_index = index % output_w;
int out_h_index = index / output_w;
int in_w_index = out_w_index * stride_w - pad_w;
int in_h_index = out_h_index * stride_h - pad_h;
int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_;
int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_;
const float *src_plane_ptr = src_b_ptr;
float *dst_plane_ptr = dst_b_ptr + index * channel;

@ -65,20 +65,12 @@ void Conv3x3Int8Gemm(int32_t *dst, const int16_t *src, const int16_t *weight, in
void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, int8_t *packed_weight,
const int32_t *bias_data, int8_t *output_data, int32_t *filter_zp, int32_t *input_sum, int task_id,
ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func, bool is_optimize) {
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int in_batch = conv_param->input_batch_;
int in_channel = conv_param->input_channel_;
int in_h = conv_param->input_h_;
int in_w = conv_param->input_w_;
int out_h = conv_param->output_h_;
int out_w = conv_param->output_w_;
int out_channel = conv_param->output_channel_;
int tile_n = conv_param->tile_num_;
int thread_count = conv_param->thread_num_;
int output_count = out_h * out_w;
int output_count = conv_param->output_h_ * conv_param->output_w_;
int output_tile_count = UP_DIV(output_count, tile_n);
int kernel_plane = kernel_h * kernel_w;
int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_;
int unit_size;
int input_sum_offset;
int up_round_oc;
@ -103,10 +95,10 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *matmul_input, in
per_channel = false;
}
for (int b = 0; b < in_batch; b++) {
int in_batch_offset = b * in_channel * in_h * in_w;
int out_batch_offset = b * out_channel * out_h * out_w;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
for (int b = 0; b < conv_param->input_batch_; b++) {
int in_batch_offset = b * in_channel * conv_param->input_h_ * conv_param->input_w_;
int out_batch_offset = b * out_channel * conv_param->output_h_ * conv_param->output_w_;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
int start_index = thread_id * tile_n;
int real_cal_num = (output_count - start_index) < tile_n ? (output_count - start_index) : tile_n;
int32_t *tmp_input_sum = input_sum + task_id * input_sum_offset;
@ -858,23 +850,20 @@ void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,
int16_t *tile_buffer, int16_t *block_unit_buffer, int32_t *tmp_dst_buffer, int8_t *tmp_out,
int task_id, ConvParameter *conv_param) {
int thread_count = conv_param->thread_num_;
int ic8 = UP_DIV(conv_param->input_channel_, C8NUM);
int output_channel = conv_param->output_channel_;
int out_w_block = UP_DIV(conv_param->output_w_, OUPUT_UNIT);
int out_h_block = UP_DIV(conv_param->output_h_, OUPUT_UNIT);
int output_count = out_w_block * out_h_block;
int output_tile_count = UP_DIV(output_count, TILE_NUM);
int oc4 = UP_DIV(output_channel, C4NUM);
int oc4 = UP_DIV(conv_param->output_channel_, C4NUM);
int tile_buffer_offset = TILE_NUM * 16 * ic8 * C8NUM;
const int block_unit_buffer_offset = 16 * C8NUM;
int tmp_dst_buffer_offset = TILE_NUM * 16 * oc4 * C4NUM;
int input_batch = conv_param->input_batch_;
for (int batch = 0; batch < input_batch; batch++) {
for (int batch = 0; batch < conv_param->input_batch_; batch++) {
int in_batch_offset = batch * ic8 * C8NUM * conv_param->input_h_ * conv_param->input_w_;
int tmp_out_batch_offset = batch * oc4 * C4NUM * conv_param->output_w_ * conv_param->output_h_;
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += thread_count) {
for (int thread_id = task_id; thread_id < output_tile_count; thread_id += conv_param->thread_num_) {
int start_index = thread_id * TILE_NUM;
int real_cal_num = (output_count - start_index) < TILE_NUM ? (output_count - start_index) : TILE_NUM;
@ -883,7 +872,7 @@ void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bi
out_w_block, conv_param);
Conv3x3Int8Gemm(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tile_buffer + task_id * tile_buffer_offset,
transed_weight, output_channel, ic8, real_cal_num);
transed_weight, conv_param->output_channel_, ic8, real_cal_num);
Conv3x3Int8OutputTransform(tmp_dst_buffer + task_id * tmp_dst_buffer_offset, tmp_out + tmp_out_batch_offset,
bias_data, start_index, real_cal_num, out_w_block, conv_param);

@ -24,15 +24,13 @@
#ifdef ENABLE_NEON
int16x4_t ClacSumHalfWordMul(int32x4_t scaled_input0, int32x4_t scaled_input1, int32x4_t left_shift_out_vec,
int32x4_t output_multiplier_vec, MulQuantArg para) {
int32x4_t input_scale = vmulq_s32(scaled_input0, scaled_input1);
int32x4_t raw_sum = RoundingDivideByPOTInt32x4(
SaturatingRoundingDoublingHighMulInt32x4(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec),
para.shift_right_);
raw_sum = vaddq_s32(raw_sum, vdupq_n_s32(para.out_quant_arg_.zp_));
raw_sum = vmaxq_s32(raw_sum, vdupq_n_s32(para.output_activation_min_));
raw_sum = vminq_s32(raw_sum, vdupq_n_s32(para.output_activation_max_));
// Requantizes the product of two s16x4 vectors (gemmlowp-style fixed-point
// multiply): widening multiply, apply the pre-shift, saturating rounding
// doubling-high multiply by the quantized output multiplier, then a rounding
// right shift, finally narrowing back to s16 with saturation.
// NOTE(review): right_shift_out_vec is expected to hold the NEGATED right
// shift amount (callers pass vdupq_n_s32(-para.shift_right_)) so that
// vrshlq_s32 performs a rounding right shift -- confirm against all callers.
int16x4_t ClacSumHalfWordMul(int16x4_t scaled_input0, int16x4_t scaled_input1, int32x4_t left_shift_out_vec,
int32x4_t right_shift_out_vec, int32x4_t output_multiplier_vec) {
// Widening multiply: s16 x s16 -> s32, no overflow possible here.
int32x4_t input_scale = vmull_s16(scaled_input0, scaled_input1);
// Saturating rounding doubling high multiply after applying the left shift.
int32x4_t raw_sum = vqrdmulhq_s32(vmulq_s32(input_scale, left_shift_out_vec), output_multiplier_vec);
// Round-to-nearest fixup for the subsequent arithmetic right shift
// (mirrors gemmlowp's RoundingDivideByPOT: add 1 when the discarded bits
// round away from zero for negative values).
const int32x4_t fixup = vshrq_n_s32(vandq_s32(raw_sum, right_shift_out_vec), 31);
const int32x4_t fixed_up_x = vqaddq_s32(raw_sum, fixup);
// vrshlq_s32 with a negative count = rounding right shift.
raw_sum = vrshlq_s32(fixed_up_x, right_shift_out_vec);
// Saturating narrow back to 16 bits; caller adds the output zero point.
return vqmovn_s32(raw_sum);
}
@ -40,27 +38,189 @@ void MulInt8NEON(int8_t *input0_data, int8_t *input1_data, int8_t *output_data,
MulQuantArg para, int *index) {
int32x4_t output_multiplier_vec = vdupq_n_s32(para.output_multiplier_);
int32x4_t left_shift_out_vec = vdupq_n_s32(1 << para.shift_left_);
int32x4_t right_shift_out_vec = vdupq_n_s32(-para.shift_right_);
int16x8_t out_zp_vec = vdupq_n_s16(para.out_quant_arg_.zp_);
int8x16_t out_min_vec = vdupq_n_s8(para.output_activation_min_);
int8x16_t out_max_vec = vdupq_n_s8(para.output_activation_max_);
int8x8_t out_min_vec_s8 = vdup_n_s8(para.output_activation_min_);
int8x8_t out_max_vec_s8 = vdup_n_s8(para.output_activation_max_);
for (; (*index) <= real_dst_count - 16; (*index) += 16) {
int16x8_t zp1_vec = vdupq_n_s16(para.in_quant_args_[0].zp_);
int16x8_t zp2_vec = vdupq_n_s16(para.in_quant_args_[1].zp_);
int8x16_t input0_vec = vld1q_s8(input0_data + *index);
int8x16_t input1_vec = vld1q_s8(input1_data + *index);
int16x8_t input0_low = vmovl_s8(vget_low_s8(input0_vec));
int16x8_t input0_high = vmovl_s8(vget_high_s8(input0_vec));
int16x8_t input1_low = vmovl_s8(vget_low_s8(input1_vec));
int16x8_t input1_high = vmovl_s8(vget_high_s8(input1_vec));
input0_low = vaddq_s16(input0_low, zp1_vec);
input0_high = vaddq_s16(input0_high, zp1_vec);
input1_low = vaddq_s16(input1_low, zp2_vec);
input1_high = vaddq_s16(input1_high, zp2_vec);
int16x4_t input0_low_low = vget_low_s16(input0_low);
int16x4_t input0_low_high = vget_high_s16(input0_low);
int16x4_t input0_high_low = vget_low_s16(input0_high);
int16x4_t input0_high_high = vget_high_s16(input0_high);
int16x4_t input1_low_low = vget_low_s16(input1_low);
int16x4_t input1_low_high = vget_high_s16(input1_low);
int16x4_t input1_high_low = vget_low_s16(input1_high);
int16x4_t input1_high_high = vget_high_s16(input1_high);
int16x4_t sum_low_low = ClacSumHalfWordMul(input0_low_low, input1_low_low, left_shift_out_vec, right_shift_out_vec,
output_multiplier_vec);
int16x4_t sum_low_high = ClacSumHalfWordMul(input0_low_high, input1_low_high, left_shift_out_vec,
right_shift_out_vec, output_multiplier_vec);
int16x4_t sum_high_low = ClacSumHalfWordMul(input0_high_low, input1_high_low, left_shift_out_vec,
right_shift_out_vec, output_multiplier_vec);
int16x4_t sum_high_high = ClacSumHalfWordMul(input0_high_high, input1_high_high, left_shift_out_vec,
right_shift_out_vec, output_multiplier_vec);
int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low_low, sum_low_high), out_zp_vec);
int16x8_t res_s162 = vaddq_s16(vcombine_s16(sum_high_low, sum_high_high), out_zp_vec);
int8x8_t res_u8_n0 = vqmovn_s16(res_s16);
int8x8_t res_u8_n1 = vqmovn_s16(res_s162);
int8x16_t res_s8 = vcombine_s8(res_u8_n0, res_u8_n1);
res_s8 = vminq_s8(res_s8, out_max_vec);
res_s8 = vmaxq_s8(res_s8, out_min_vec);
vst1q_s8(output_data, res_s8);
output_data += 16;
}
for (; (*index) <= real_dst_count - 8; (*index) += 8) {
int16x8_t input0_val = LoadAndAddOffset(input0_data, *index, para.in_quant_args_[0].zp_);
int16x8_t input1_val = LoadAndAddOffset(input1_data, *index, para.in_quant_args_[1].zp_);
int32x4_t input0_low = vmovl_s16(vget_low_s16(input0_val));
int32x4_t input0_high = vmovl_s16(vget_high_s16(input0_val));
int32x4_t input1_low = vmovl_s16(vget_low_s16(input1_val));
int32x4_t input1_high = vmovl_s16(vget_high_s16(input1_val));
int16x4_t input0_low = vget_low_s16(input0_val);
int16x4_t input0_high = vget_high_s16(input0_val);
int16x4_t input1_low = vget_low_s16(input1_val);
int16x4_t input1_high = vget_high_s16(input1_val);
int16x4_t sum_low = ClacSumHalfWordMul(input0_low, input1_low, left_shift_out_vec, output_multiplier_vec, para);
int16x4_t sum_high = ClacSumHalfWordMul(input0_high, input1_high, left_shift_out_vec, output_multiplier_vec, para);
int16x4_t sum_low =
ClacSumHalfWordMul(input0_low, input1_low, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec);
int16x4_t sum_high =
ClacSumHalfWordMul(input0_high, input1_high, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec);
int16x8_t res_s16 = vcombine_s16(sum_low, sum_high);
int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low, sum_high), out_zp_vec);
int8x8_t res_u8_n0 = vqmovn_s16(res_s16);
res_u8_n0 = vmin_s8(res_u8_n0, out_max_vec_s8);
res_u8_n0 = vmax_s8(res_u8_n0, out_min_vec_s8);
vst1_s8(output_data, res_u8_n0);
output_data += 8;
}
}
#endif
// Quantized int8 element-wise multiply where one operand is broadcast along
// the inner `depth` axis: for each of `real_dst_count` outer elements,
// input0_data[0..depth) is multiplied against a stride-1 walk of input1_data.
// Results are requantized with para's multiplier/shifts, offset by the output
// zero point, and clamped to the activation range.
// NOTE(review): despite the parameter name, when input1_broad is true only
// the two input zero points are swapped -- the broadcast/stride roles of the
// two data pointers are unchanged; confirm callers pre-swap the pointers.
void FastMul(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int depth, int64_t real_dst_count,
bool input1_broad, MulQuantArg para) {
// input0 need broadcast
int32_t zp1 = para.in_quant_args_[0].zp_;
int32_t zp2 = para.in_quant_args_[1].zp_;
if (input1_broad) {
zp1 = para.in_quant_args_[1].zp_;
zp2 = para.in_quant_args_[0].zp_;
}
#ifdef ENABLE_ARM
// Hoist all loop-invariant quantization constants into NEON registers.
int32x4_t output_multiplier_vec = vdupq_n_s32(para.output_multiplier_);
int32x4_t left_shift_out_vec = vdupq_n_s32(1 << para.shift_left_);
// Negated so ClacSumHalfWordMul / vrshlq_s32 performs a right shift.
int32x4_t right_shift_out_vec = vdupq_n_s32(-para.shift_right_);
int16x8_t out_zp_vec = vdupq_n_s16(para.out_quant_arg_.zp_);
int8x16_t out_min_vec = vdupq_n_s8(para.output_activation_min_);
int8x16_t out_max_vec = vdupq_n_s8(para.output_activation_max_);
int8x8_t out_min_vec_s8 = vdup_n_s8(para.output_activation_min_);
int8x8_t out_max_vec_s8 = vdup_n_s8(para.output_activation_max_);
int16x8_t zp1_vec = vdupq_n_s16(zp1);
int16x8_t zp2_vec = vdupq_n_s16(zp2);
#endif
for (int index = 0; index < real_dst_count; ++index) {
int j = 0;
#ifdef ENABLE_ARM
// 16-lane vector path: widen s8 -> s16, add input zero points, requantize
// each s16x4 quarter, re-add the output zero point and saturate back to s8.
for (; j <= depth - 16; j += 16) {
int8x16_t input0_vec = vld1q_s8(input0_data + j);
int8x16_t input1_vec = vld1q_s8(input1_data);
int16x8_t input0_low = vmovl_s8(vget_low_s8(input0_vec));
int16x8_t input0_high = vmovl_s8(vget_high_s8(input0_vec));
int16x8_t input1_low = vmovl_s8(vget_low_s8(input1_vec));
int16x8_t input1_high = vmovl_s8(vget_high_s8(input1_vec));
input0_low = vaddq_s16(input0_low, zp1_vec);
input0_high = vaddq_s16(input0_high, zp1_vec);
input1_low = vaddq_s16(input1_low, zp2_vec);
input1_high = vaddq_s16(input1_high, zp2_vec);
int16x4_t input0_low_low = vget_low_s16(input0_low);
int16x4_t input0_low_high = vget_high_s16(input0_low);
int16x4_t input0_high_low = vget_low_s16(input0_high);
int16x4_t input0_high_high = vget_high_s16(input0_high);
int16x4_t input1_low_low = vget_low_s16(input1_low);
int16x4_t input1_low_high = vget_high_s16(input1_low);
int16x4_t input1_high_low = vget_low_s16(input1_high);
int16x4_t input1_high_high = vget_high_s16(input1_high);
int16x4_t sum_low_low = ClacSumHalfWordMul(input0_low_low, input1_low_low, left_shift_out_vec,
right_shift_out_vec, output_multiplier_vec);
int16x4_t sum_low_high = ClacSumHalfWordMul(input0_low_high, input1_low_high, left_shift_out_vec,
right_shift_out_vec, output_multiplier_vec);
int16x4_t sum_high_low = ClacSumHalfWordMul(input0_high_low, input1_high_low, left_shift_out_vec,
right_shift_out_vec, output_multiplier_vec);
int16x4_t sum_high_high = ClacSumHalfWordMul(input0_high_high, input1_high_high, left_shift_out_vec,
right_shift_out_vec, output_multiplier_vec);
int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low_low, sum_low_high), out_zp_vec);
int16x8_t res_s162 = vaddq_s16(vcombine_s16(sum_high_low, sum_high_high), out_zp_vec);
int8x8_t res_u8_n0 = vqmovn_s16(res_s16);
int8x8_t res_u8_n1 = vqmovn_s16(res_s162);
int8x16_t res_s8 = vcombine_s8(res_u8_n0, res_u8_n1);
res_s8 = vminq_s8(res_s8, out_max_vec);
res_s8 = vmaxq_s8(res_s8, out_min_vec);
vst1q_s8(output_data, res_s8);
// Only the non-broadcast operand and the output advance here; the
// broadcast operand is re-read via input0_data + j next iteration.
input1_data += 16;
output_data += 16;
}
// 8-lane vector path for the remaining depth tail >= 8.
for (; j <= depth - 8; j += 8) {
int8x8_t input0_vec = vld1_s8(input0_data + j);
int8x8_t input1_vec = vld1_s8(input1_data);
int16x8_t input0_val = vmovl_s8(input0_vec);
int16x8_t input1_val = vmovl_s8(input1_vec);
input0_val = vaddq_s16(input0_val, zp1_vec);
input1_val = vaddq_s16(input1_val, zp2_vec);
int16x4_t input0_low = vget_low_s16(input0_val);
int16x4_t input0_high = vget_high_s16(input0_val);
int16x4_t input1_low = vget_low_s16(input1_val);
int16x4_t input1_high = vget_high_s16(input1_val);
int16x4_t sum_low =
ClacSumHalfWordMul(input0_low, input1_low, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec);
int16x4_t sum_high =
ClacSumHalfWordMul(input0_high, input1_high, left_shift_out_vec, right_shift_out_vec, output_multiplier_vec);
int16x8_t res_s16 = vaddq_s16(vcombine_s16(sum_low, sum_high), out_zp_vec);
int8x8_t res_u8_n0 = vqmovn_s16(res_s16);
res_u8_n0 = vmin_s8(res_u8_n0, out_max_vec_s8);
res_u8_n0 = vmax_s8(res_u8_n0, out_min_vec_s8);
vst1_s8(output_data, res_u8_n0);
input1_data += 8;
output_data += 8;
}
#endif
// Scalar tail (and the whole loop when NEON is disabled): mirrors the
// vector path's requantization using gemmlowp's scalar helpers.
for (; j < depth; ++j) {
const int32_t input0_val = zp1 + input0_data[j];
const int32_t input1_val = zp2 + input1_data[0];
int32_t mul_result = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(input0_val * input1_val * (1 << para.shift_left_), para.output_multiplier_),
para.shift_right_);
mul_result += para.out_quant_arg_.zp_;
// Clamp to the activation range before the int8 narrowing cast.
mul_result = mul_result < para.output_activation_max_ ? mul_result : para.output_activation_max_;
mul_result = mul_result > para.output_activation_min_ ? mul_result : para.output_activation_min_;
output_data[0] = (int8_t)mul_result;
input1_data++;
output_data++;
}
}
return;
}
void Mul(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, MulQuantArg para) {
int index = 0;
#ifdef ENABLE_NEON
@ -74,14 +234,10 @@ void Mul(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t
para.shift_right_);
mul_result += para.out_quant_arg_.zp_;
if (mul_result > para.output_activation_max_) {
output_data[index] = para.output_activation_max_;
} else if (mul_result < para.output_activation_min_) {
output_data[index] = para.output_activation_min_;
} else {
output_data[index] = (int8_t)mul_result;
}
mul_result = mul_result < para.output_activation_max_ ? mul_result : para.output_activation_max_;
mul_result = mul_result > para.output_activation_min_ ? mul_result : para.output_activation_min_;
output_data[0] = (int8_t)mul_result;
output_data++;
}
return;
}

@ -24,6 +24,8 @@
extern "C" {
#endif
void Mul(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int64_t real_dst_count, MulQuantArg para);
void FastMul(int8_t *input0_data, int8_t *input1_data, int8_t *output_data, int depth, int64_t real_dst_count,
bool input1_broad, MulQuantArg para);
#ifdef __cplusplus
}
#endif

@ -80,33 +80,24 @@ int AvgPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter
}
int AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id) {
int stride_w = pooling_param->stride_w_;
int stride_h = pooling_param->stride_h_;
int pad_w = pooling_param->pad_l_;
int pad_h = pooling_param->pad_u_;
int win_w = pooling_param->window_w_;
int win_h = pooling_param->window_h_;
int channel = pooling_param->input_channel_;
int c16 = channel / C16NUM;
int in_w = pooling_param->input_w_;
int in_h = pooling_param->input_h_;
int output_w = pooling_param->output_w_;
int output_h = pooling_param->output_h_;
int output_batch = pooling_param->output_batch_;
int out_plane = output_w * output_h;
int out_plane = output_w * pooling_param->output_h_;
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
int thread_num = out_tile_count < pooling_param->thread_num_ ? out_tile_count : pooling_param->thread_num_;
float input_scale = pooling_param->quant_args_[0][0].scale_;
int input_zp = pooling_param->quant_args_[0][0].zp_;
float output_scale = pooling_param->quant_args_[1][0].scale_;
int output_zp = pooling_param->quant_args_[1][0].zp_;
double real_multiplier = input_scale / output_scale;
int c16 = channel / C16NUM;
double real_multiplier = pooling_param->quant_args_[0][0].scale_ / pooling_param->quant_args_[1][0].scale_;
const int8_t out_min = INT8_MIN;
const int8_t out_max = INT8_MAX;
for (int batch = 0; batch < output_batch; batch++) {
int in_batch_offset = batch * in_h * in_w * channel;
int out_batch_offset = batch * output_h * output_w * channel;
for (int batch = 0; batch < pooling_param->output_batch_; batch++) {
int in_batch_offset = batch * pooling_param->input_h_ * in_w * channel;
int out_batch_offset = batch * pooling_param->output_h_ * output_w * channel;
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
int cal_start_index = thread_id * TILE_NUM;
int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index);
@ -114,14 +105,14 @@ int AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParame
int index = cal_start_index + i;
int out_w_index = index % output_w;
int out_h_index = index / output_w;
int in_w_index = out_w_index * stride_w - pad_w;
int in_h_index = out_h_index * stride_h - pad_h;
int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_;
int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_;
int out_plane_offset = out_batch_offset + index * channel;
int input_stride = (in_h_index * in_w + in_w_index) * channel;
int kw_s = MSMAX(0, -in_w_index);
int kw_e = MSMIN(win_w, in_w - in_w_index);
int kh_s = MSMAX(0, -in_h_index);
int kh_e = MSMIN(win_h, in_h - in_h_index);
int kh_e = MSMIN(win_h, pooling_param->input_h_ - in_h_index);
int real_count = (kw_e - kw_s) * (kh_e - kh_s);
if (real_count == 0) {
return NNACL_ERR;
@ -335,19 +326,11 @@ void MaxPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParamete
void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param,
int task_id) {
int stride_w = pooling_param->stride_w_;
int stride_h = pooling_param->stride_h_;
int pad_w = pooling_param->pad_l_;
int pad_h = pooling_param->pad_u_;
int win_w = pooling_param->window_w_;
int win_h = pooling_param->window_h_;
int channel = pooling_param->input_channel_;
int in_w = pooling_param->input_w_;
int in_h = pooling_param->input_h_;
int output_w = pooling_param->output_w_;
int output_h = pooling_param->output_h_;
int output_batch = pooling_param->output_batch_;
int out_plane = output_w * output_h;
int out_plane = output_w * pooling_param->output_h_;
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
int thread_num = out_tile_count < pooling_param->thread_num_ ? out_tile_count : pooling_param->thread_num_;
int c16 = UP_DIV(channel, 16);
@ -358,9 +341,9 @@ void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, Poolin
int output_zp = pooling_param->quant_args_[1][0].zp_;
double real_multiplier = input_scale / output_scale;
for (int batch = 0; batch < output_batch; batch++) {
for (int batch = 0; batch < pooling_param->output_batch_; batch++) {
int in_batch_offset = batch * in_h * in_w * channel;
int out_batch_offset = batch * output_h * output_w * channel;
int out_batch_offset = batch * pooling_param->output_h_ * output_w * channel;
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
int cal_start_index = thread_id * TILE_NUM;
int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index);
@ -368,8 +351,8 @@ void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, Poolin
int index = cal_start_index + i;
int out_w_index = index % output_w;
int out_h_index = index / output_w;
int in_w_index = out_w_index * stride_w - pad_w;
int in_h_index = out_h_index * stride_h - pad_h;
int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_;
int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_;
int out_plane_offset = out_batch_offset + index * channel;
for (int j = 0; j < c16 - 1; j++) {
int in_channel_offset = in_batch_offset + j * 16;
@ -382,8 +365,8 @@ void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, Poolin
tmp_max[m] = INT8_MIN;
}
#endif
for (int h = 0; h < win_h; h++) {
for (int w = 0; w < win_w; w++) {
for (int h = 0; h < pooling_param->window_h_; h++) {
for (int w = 0; w < pooling_param->window_w_; w++) {
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
(in_w_index + w) >= in_w) {
continue;
@ -418,8 +401,8 @@ void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, Poolin
int in_channel_offset = in_batch_offset + k;
int out_channel_offset = out_plane_offset + k;
int8_t tmp_max = INT8_MIN;
for (int h = 0; h < win_h; h++) {
for (int w = 0; w < win_w; w++) {
for (int h = 0; h < pooling_param->window_h_; h++) {
for (int w = 0; w < pooling_param->window_w_; w++) {
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
(in_w_index + w) >= in_w) {
continue;
@ -437,26 +420,17 @@ void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, Poolin
}
void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id) {
int stride_w = pooling_param->stride_w_;
int stride_h = pooling_param->stride_h_;
int pad_w = pooling_param->pad_l_;
int pad_h = pooling_param->pad_u_;
int win_w = pooling_param->window_w_;
int win_h = pooling_param->window_h_;
int channel = pooling_param->input_channel_;
int in_w = pooling_param->input_w_;
int in_h = pooling_param->input_h_;
int output_w = pooling_param->output_w_;
int output_h = pooling_param->output_h_;
int output_batch = pooling_param->output_batch_;
int out_plane = output_w * output_h;
int out_plane = output_w * pooling_param->output_h_;
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
int thread_num = MSMIN(out_tile_count, pooling_param->thread_num_);
int8_t out_array[MAX_MAXPOOL_SIZE];
for (int batch = 0; batch < output_batch; batch++) {
int in_batch_offset = batch * in_h * in_w * channel;
int out_batch_offset = batch * output_h * output_w * channel;
for (int batch = 0; batch < pooling_param->output_batch_; batch++) {
int in_batch_offset = batch * pooling_param->input_h_ * in_w * channel;
int out_batch_offset = batch * pooling_param->output_h_ * output_w * channel;
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
int cal_start_index = thread_id * TILE_NUM;
int real_cal_num = out_plane - cal_start_index;
@ -465,12 +439,12 @@ void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
int index = cal_start_index + i;
int out_w_index = index % output_w;
int out_h_index = index / output_w;
int in_w_index = out_w_index * stride_w - pad_w;
int in_h_index = out_h_index * stride_h - pad_h;
int in_w_index = out_w_index * pooling_param->stride_w_ - pooling_param->pad_l_;
int in_h_index = out_h_index * pooling_param->stride_h_ - pooling_param->pad_u_;
const int ky_s = 0 > (-in_h_index) ? 0 : (-in_h_index);
int ky_e = MSMIN(win_h, in_h - in_h_index);
int ky_e = MSMIN(pooling_param->window_h_, pooling_param->input_h_ - in_h_index);
const int kx_s = 0 > (-in_w_index) ? 0 : (-in_w_index);
int kx_e = MSMIN(win_w, in_w - in_w_index);
int kx_e = MSMIN(pooling_param->window_w_, in_w - in_w_index);
int input_stride = (in_h_index * in_w + in_w_index) * channel + in_batch_offset;
int out_plane_offset = out_batch_offset + index * channel;

@ -47,20 +47,17 @@ void DoSpaceToBatchNHWCInt8(const int8_t *input, int8_t *output, const int *bloc
}
void DoSpaceToBatchPaddingNHWCInt8(const int8_t *input, int8_t *output, SpaceToBatchParameter *param, int32_t zp) {
int *in_shape = param->input_shape_;
int *out_shape = param->output_shape_;
int *paddings = param->paddings_;
int block_shape_h = param->block_sizes_[0];
int block_shape_w = param->m_ == 2 ? param->block_sizes_[1] : 1;
int in_b = in_shape[0];
int in_h = in_shape[1];
int in_w = in_shape[2];
int channel = in_shape[3];
int out_h = out_shape[1];
int out_w = out_shape[2];
int pad_t = paddings[0];
int pad_l = param->m_ == 2 ? paddings[2] : 0;
for (int i = 0; i < out_shape[0]; ++i) {
int in_b = param->input_shape_[0];
int in_h = param->input_shape_[1];
int in_w = param->input_shape_[2];
int channel = param->input_shape_[3];
int out_h = param->output_shape_[1];
int out_w = param->output_shape_[2];
int pad_t = param->paddings_[0];
int pad_l = param->m_ == 2 ? param->paddings_[2] : 0;
for (int i = 0; i < param->output_shape_[0]; ++i) {
int in_batch = i % in_b;
int offset_w = (i / in_b) % block_shape_w;
int offset_h = (i / in_b) / block_shape_w;

@ -219,24 +219,19 @@ void Im2ColPackUnitFp32(const float *input_data, ConvParameter *conv_param, floa
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int kernel_plane = kernel_h * kernel_w;
int stride_h = conv_param->stride_h_;
int stride_w = conv_param->stride_w_;
int pad_h = conv_param->pad_u_;
int pad_w = conv_param->pad_l_;
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
int in_channel = conv_param->input_channel_;
int in_h = conv_param->input_h_;
int in_w = conv_param->input_w_;
int out_w = conv_param->output_w_;
for (int i = 0; i < real_cal_num; i++) {
int block_start = block_index + i;
int input_h = block_start / out_w * stride_h - pad_h;
int input_w = block_start % out_w * stride_w - pad_w;
int input_h = block_start / out_w * conv_param->stride_h_ - conv_param->pad_u_;
int input_w = block_start % out_w * conv_param->stride_w_ - conv_param->pad_l_;
int input_stride = (input_h * in_w + input_w) * in_channel;
int kh_s = MSMAX(0, UP_DIV(-input_h, dilation_h));
int kh_e = MSMIN(kernel_h, UP_DIV(in_h - input_h, dilation_h));
int kh_e = MSMIN(kernel_h, UP_DIV(conv_param->input_h_ - input_h, dilation_h));
int kw_s = MSMAX(0, UP_DIV(-input_w, dilation_w));
int kw_e = MSMIN(kernel_w, UP_DIV(in_w - input_w, dilation_w));
if (dilation_w == 1 && dilation_h == 1) {

@ -62,11 +62,46 @@ int MulInt8CPUKernel::Init() {
return ReSize();
}
void MulInt8CPUKernel::CheckSameShapeSize(std::vector<int> in_tensor0_shape, std::vector<int> in_tensor1_shape) {
  // Enable the fast multiply path when one 4-D input has H == W == 1 and both
  // inputs share batch and channel dims: the [N, 1, 1, C] operand can then be
  // broadcast over the other input's H*W plane.
  bool batch_match = in_tensor0_shape[0] == in_tensor1_shape[0];
  bool channel_match = in_tensor0_shape[3] == in_tensor1_shape[3];
  bool tensor0_hw_is_one = (in_tensor0_shape[1] == 1) && (in_tensor0_shape[2] == 1);
  bool tensor1_hw_is_one = (in_tensor1_shape[1] == 1) && (in_tensor1_shape[2] == 1);
  if (batch_match && channel_match) {
    if (tensor0_hw_is_one) {
      // input0 is the broadcast operand.
      fast_hw_broadcast_ = true;
    } else if (tensor1_hw_is_one) {
      // input1 is the broadcast operand; FastDoExecute swaps the inputs.
      fast_hw_broadcast_ = true;
      input1_hw_broadcast_ = true;
    }
  }
}
void MulInt8CPUKernel::CheckIfFastImpl() {
auto in_tensor0 = in_tensors_.at(0);
auto in_tensor1 = in_tensors_.at(1);
if (in_tensor0->ElementsNum() != in_tensor1->ElementsNum()) {
if (in_tensor0->shape().size() == 4 && in_tensor1->shape().size() == 4) {
CheckSameShapeSize(in_tensor0->shape(), in_tensor1->shape());
} else if (in_tensor0->shape().size() == 1 && in_tensor1->shape().size() == 4) {
if (in_tensor0->ElementsNum() == in_tensor1->shape()[3]) {
fast_hw_broadcast_ = true;
}
} else if (in_tensor0->shape().size() == 4 && in_tensor1->shape().size() == 1) {
if (in_tensor1->ElementsNum() == in_tensor0->shape()[3]) {
fast_hw_broadcast_ = true;
input1_hw_broadcast_ = true;
}
}
}
}
int MulInt8CPUKernel::ReSize() {
size_t input0_size = in_tensors_.at(0)->shape().size();
size_t input1_size = in_tensors_.at(1)->shape().size();
size_t output_size = out_tensors_.at(0)->shape().size();
tile_para->ndim_ = output_size;
if (input0_size == input1_size) {
for (size_t i = 0; i < output_size; i++) {
tile_para->in_shape0_[i] = in_tensors_.at(0)->DimensionSize(i);
@ -106,6 +141,14 @@ int MulInt8CPUKernel::Run() {
input1_data_ = static_cast<int8_t *>(in_tensors_.at(1)->MutableData());
output_data_ = static_cast<int8_t *>(out_tensors_.at(0)->MutableData());
CheckIfFastImpl();
// can implement fast broadcast mul
if (fast_hw_broadcast_) {
elements_num_ = out_tensors_.front()->Batch() * out_tensors_.front()->Height() * out_tensors_.front()->Width();
count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_;
return ParallelLaunch(this->context_->thread_pool_, FastHWBroadcatMulInt8Run, this, thread_count_);
}
elements_num_ = out_tensors_.at(0)->ElementsNum();
count_unit_ = thread_count_ > 1 ? UP_DIV(elements_num_, thread_count_) : elements_num_;
if (in_tensors_.at(0)->ElementsNum() != in_tensors_.at(1)->ElementsNum()) {
@ -132,12 +175,36 @@ int MulInt8CPUKernel::Run() {
return ret;
}
// ParallelLaunch trampoline for the fast HW-broadcast multiply path.
// cdata is the MulInt8CPUKernel that scheduled the launch.
int FastHWBroadcatMulInt8Run(void *cdata, int task_id) {
  auto mul = reinterpret_cast<MulInt8CPUKernel *>(cdata);
  // Propagate the worker's status instead of unconditionally reporting RET_OK,
  // so ParallelLaunch can surface a per-task failure.
  return mul->FastDoExecute(task_id);
}
// ParallelLaunch trampoline for the generic element-wise multiply path.
// cdata is the MulInt8CPUKernel that scheduled the launch.
int MulInt8Run(void *cdata, int task_id) {
  auto mul = reinterpret_cast<MulInt8CPUKernel *>(cdata);
  // Propagate the worker's status instead of unconditionally reporting RET_OK,
  // so ParallelLaunch can surface a per-task failure.
  return mul->DoExecute(task_id);
}
int MulInt8CPUKernel::FastDoExecute(int task_id) {
int depth = out_tensors_.front()->Channel();
int64_t real_dst_count = MSMIN(elements_num_ - task_id * count_unit_, count_unit_);
if (real_dst_count <= 0) {
return lite::RET_OK;
}
int8_t *cur_input0_data = input0_data_;
int8_t *cur_input1_data = input1_data_ + task_id * count_unit_ * depth;
int8_t *cur_output_data = output_data_ + task_id * count_unit_ * depth;
if (input1_hw_broadcast_) {
cur_input0_data = input1_data_;
cur_input1_data = input0_data_ + task_id * count_unit_ * depth;
}
FastMul(cur_input0_data, cur_input1_data, cur_output_data, depth, real_dst_count, input1_hw_broadcast_,
para_.mul_quant_arg_);
return RET_OK;
}
int MulInt8CPUKernel::DoExecute(int task_id) {
int64_t real_dst_count = MSMIN(elements_num_ - task_id * count_unit_, count_unit_);
if (real_dst_count <= 0) {

@ -35,13 +35,18 @@ class MulInt8CPUKernel : public LiteKernel {
int Init() override;
int ReSize() override;
void CheckSameShapeSize(std::vector<int> in_tensor0_shape, std::vector<int> in_tensor1_shape);
void CheckIfFastImpl();
int Run() override;
int DoExecute(int task_id);
int FastDoExecute(int task_id);
private:
const lite::InnerContext *ctx_ = nullptr;
ArithmeticParameter *tile_para = nullptr;
MulParameter para_;
bool fast_hw_broadcast_ = false;
bool input1_hw_broadcast_ = false;
int thread_count_ = 1;
int64_t elements_num_ = 0;
int64_t count_unit_ = 0;
@ -51,6 +56,7 @@ class MulInt8CPUKernel : public LiteKernel {
};
int MulInt8Run(void *cdata, int task_id);
int FastHWBroadcatMulInt8Run(void *cdata, int task_id);
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_INT8_MUL_INT8_H_

Loading…
Cancel
Save