diff --git a/mindspore/lite/nnacl/int8/conv_int8.c b/mindspore/lite/nnacl/int8/conv_int8.c index e36da668fd..9c0b2fc7fb 100644 --- a/mindspore/lite/nnacl/int8/conv_int8.c +++ b/mindspore/lite/nnacl/int8/conv_int8.c @@ -264,7 +264,8 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c int output_tile_count = UP_DIV(output_count, tile_n); int ic4 = UP_DIV(in_channel, C4NUM); int kernel_plane = kernel_h * kernel_w; - int unit_size = kernel_plane * ic4 * C4NUM; + int plane_block = UP_DIV(kernel_plane, C4NUM); + int unit_size = plane_block * C4NUM * ic4 * C4NUM; int packed_input_size = output_tile_count * tile_n * unit_size; int input_sum_offset; if (conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL) { diff --git a/mindspore/lite/nnacl/int8/pooling_int8.c b/mindspore/lite/nnacl/int8/pooling_int8.c index 540d43deed..f27332776a 100644 --- a/mindspore/lite/nnacl/int8/pooling_int8.c +++ b/mindspore/lite/nnacl/int8/pooling_int8.c @@ -89,8 +89,13 @@ void AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam int output_batch = pooling_param->output_batch_; int out_plane = output_w * output_h; int out_tile_count = UP_DIV(out_plane, TILE_NUM); - int thread_num = pooling_param->thread_num_; - int c8 = UP_DIV(channel, C8NUM); + int thread_num = out_tile_count < pooling_param->thread_num_ ? out_tile_count : pooling_param->thread_num_; + float input_scale = pooling_param->quant_args_[0][0].scale_; + int input_zp = pooling_param->quant_args_[0][0].zp_; + float output_scale = pooling_param->quant_args_[1][0].scale_; + int output_zp = pooling_param->quant_args_[1][0].zp_; + double real_multiplier = input_scale / output_scale; + int c16 = channel / C16NUM; const int8_t out_min = INT8_MIN; const int8_t out_max = INT8_MAX; @@ -107,89 +112,159 @@ void AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam int in_w_index = out_w_index * stride_w - pad_w; int in_h_index = out_h_index * stride_h - pad_h; int out_plane_offset = out_batch_offset + index * channel; - for (int j = 0; j < c8 - 1; j++) { - int in_channel_offset = in_batch_offset + j * C8NUM; - int out_channel_offset = out_plane_offset + j * C8NUM; - int16_t tmp_avg1 = 0; - int16_t tmp_avg2 = 0; - int16_t tmp_avg3 = 0; - int16_t tmp_avg4 = 0; - int16_t tmp_avg5 = 0; - int16_t tmp_avg6 = 0; - int16_t tmp_avg7 = 0; - int16_t tmp_avg8 = 0; - int real_count = 0; - for (int h = 0; h < win_h; h++) { - for (int w = 0; w < win_w; w++) { - if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || - (in_w_index + w) >= in_w) { - continue; - } else { - int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; - tmp_avg1 += *(input_ptr + in_offset); - tmp_avg2 += *(input_ptr + in_offset + 1); - tmp_avg3 += *(input_ptr + in_offset + 2); - tmp_avg4 += *(input_ptr + in_offset + 3); - tmp_avg5 += *(input_ptr + in_offset + 4); - tmp_avg6 += *(input_ptr + in_offset + 5); - tmp_avg7 += *(input_ptr + in_offset + 6); - tmp_avg8 += *(input_ptr + in_offset + 7); - ++real_count; + int input_stride = (in_h_index * in_w + in_w_index) * channel; + int kw_s = MSMAX(0, -in_w_index); + int kw_e = MSMIN(win_w, in_w - in_w_index); + int kh_s = MSMAX(0, -in_h_index); + int kh_e = MSMIN(win_h, in_h - in_h_index); + int real_count = (kw_e - kw_s) * (kh_e - kh_s); + + // 16 channels + for (int j = 0; j < c16; j++) { +#ifdef ENABLE_NEON + int16x8_t tmp_avg[2]; + tmp_avg[0] = vmovq_n_s16(0); + tmp_avg[1] = vmovq_n_s16(0); +#else + int16_t tmp_avg[16]; + int16_t real_out[16]; + for (int m = 0; m < C16NUM; ++m) { + tmp_avg[m] = 0; + } +#endif + int in_channel_offset = in_batch_offset + j * C16NUM; + int out_channel_offset = out_plane_offset + j * C16NUM; + + for (int h = kh_s; h < kh_e; h++) { + for (int w = kw_s; w < kw_e; w++) { + int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel; +#ifdef ENABLE_NEON + int8x16_t in_ptr = vld1q_s8(input_ptr + in_offset); + int8x8_t in_data1 = vget_low_s8(in_ptr); + int8x8_t in_data2 = vget_high_s8(in_ptr); + int16x8_t data1 = vmovl_s8(in_data1); + int16x8_t data2 = vmovl_s8(in_data2); + tmp_avg[0] = vaddq_s16(tmp_avg[0], data1); + tmp_avg[1] = vaddq_s16(tmp_avg[1], data2); +#else + for (int k = 0; k < C16NUM; ++k) { + tmp_avg[k] += input_ptr[in_offset + k]; } +#endif } // win_w loop } // win_h loop - int16_t tmp_out1 = round((float)tmp_avg1 / (float)real_count); - int16_t tmp_out2 = round((float)tmp_avg2 / (float)real_count); - int16_t tmp_out3 = round((float)tmp_avg3 / (float)real_count); - int16_t tmp_out4 = round((float)tmp_avg4 / (float)real_count); - int16_t tmp_out5 = round((float)tmp_avg5 / (float)real_count); - int16_t tmp_out6 = round((float)tmp_avg6 / (float)real_count); - int16_t tmp_out7 = round((float)tmp_avg7 / (float)real_count); - int16_t tmp_out8 = round((float)tmp_avg8 / (float)real_count); - int16_t real_out1 = tmp_out1 < out_min ? out_min : tmp_out1; - int16_t real_out2 = tmp_out2 < out_min ? out_min : tmp_out2; - int16_t real_out3 = tmp_out3 < out_min ? out_min : tmp_out3; - int16_t real_out4 = tmp_out4 < out_min ? out_min : tmp_out4; - int16_t real_out5 = tmp_out5 < out_min ? out_min : tmp_out5; - int16_t real_out6 = tmp_out6 < out_min ? out_min : tmp_out6; - int16_t real_out7 = tmp_out7 < out_min ? out_min : tmp_out7; - int16_t real_out8 = tmp_out8 < out_min ? out_min : tmp_out8; - real_out1 = real_out1 > out_max ? out_max : real_out1; - real_out2 = real_out2 > out_max ? out_max : real_out2; - real_out3 = real_out3 > out_max ? out_max : real_out3; - real_out4 = real_out4 > out_max ? out_max : real_out4; - real_out5 = real_out5 > out_max ? out_max : real_out5; - real_out6 = real_out6 > out_max ? out_max : real_out6; - real_out7 = real_out7 > out_max ? out_max : real_out7; - real_out8 = real_out8 > out_max ? out_max : real_out8; - *(output_ptr + out_channel_offset) = (int8_t)real_out1; - *(output_ptr + out_channel_offset + 1) = (int8_t)real_out2; - *(output_ptr + out_channel_offset + 2) = (int8_t)real_out3; - *(output_ptr + out_channel_offset + 3) = (int8_t)real_out4; - *(output_ptr + out_channel_offset + 4) = (int8_t)real_out5; - *(output_ptr + out_channel_offset + 5) = (int8_t)real_out6; - *(output_ptr + out_channel_offset + 6) = (int8_t)real_out7; - *(output_ptr + out_channel_offset + 7) = (int8_t)real_out8; - } // in_channel loop - int channel_s = (c8 - 1) * C8NUM; - for (int k = channel_s; k < channel; k++) { - int in_channel_offset = in_batch_offset + k; - int out_channel_offset = out_plane_offset + k; - int16_t tmp_avg = 0; - int real_count = 0; - for (int h = 0; h < win_h; h++) { - for (int w = 0; w < win_w; w++) { - if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || - (in_w_index + w) >= in_w) { - continue; - } else { - int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; - tmp_avg += *(input_ptr + in_offset); - ++real_count; +#ifdef ENABLE_NEON + int16_t tmp_data[8]; + int16_t tmp_out[8]; + int16_t tmp_data1[8]; + int16_t tmp_out1[8]; + for (int l = 0; l < C8NUM; l++) { + tmp_data[l] = tmp_avg[0][l] + 128 * real_count; + tmp_out[l] = (tmp_data[l] + real_count / 2) / real_count; + tmp_out[l] -= 128; + tmp_out[l] = round((tmp_out[l] - input_zp) * real_multiplier) + output_zp; + } + for (int l = 0; l < C8NUM; l++) { + tmp_data1[l] = tmp_avg[1][l] + 128 * real_count; + tmp_out1[l] = (tmp_data1[l] + real_count / 2) / real_count; + tmp_out1[l] -= 128; + tmp_out1[l] = round((tmp_out1[l] - input_zp) * real_multiplier) + output_zp; + } + int8x8_t real_out[2]; + int8x8_t output_min = vdup_n_s8(out_min); + int8x8_t output_max = vdup_n_s8(out_max); + real_out[0] = vqmovn_s16(vld1q_s16(tmp_out)); + real_out[0] = vmin_s8(real_out[0], output_max); + real_out[0] = vmax_s8(real_out[0], output_min); + vst1_s8(output_ptr + out_channel_offset, real_out[0]); + real_out[1] = vqmovn_s16(vld1q_s16(tmp_out1)); + real_out[1] = vmin_s8(real_out[1], output_max); + real_out[1] = vmax_s8(real_out[1], output_min); + vst1_s8(output_ptr + out_channel_offset + 8, real_out[1]); +#else + for (int l = 0; l < C16NUM; ++l) { + int16_t tmp_data = tmp_avg[l] + 128 * real_count; + real_out[l] = (tmp_data + real_count / 2) / real_count - 128; + real_out[l] = (int8_t)(round((real_out[l] - input_zp) * real_multiplier) + output_zp); + real_out[l] = real_out[l] < out_min ? out_min : real_out[l]; + real_out[l] = real_out[l] > out_max ? out_max : real_out[l]; + *(output_ptr + out_channel_offset + l) = (int8_t)real_out[l]; + } +#endif + } + + // 8 channels + int channel_16_res = channel - c16 * C16NUM; + int c8 = channel_16_res / C8NUM; + int in_c16_offset = in_batch_offset + c16 * C16NUM; + int out_c16_offset = out_plane_offset + c16 * C16NUM; + for (int j = 0; j < c8; j++) { +#ifdef ENABLE_NEON + int16x8_t tmp_avg = vmovq_n_s16(0); +#else + int16_t tmp_avg[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + int16_t real_out[8]; +#endif + int in_channel_offset = in_c16_offset + j * C8NUM; + int out_channel_offset = out_c16_offset + j * C8NUM; + for (int h = kh_s; h < kh_e; h++) { + for (int w = kw_s; w < kw_e; w++) { + int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel; +#ifdef ENABLE_NEON + int8x8_t in_ptr = vld1_s8(input_ptr + in_offset); + int16x8_t data = vmovl_s8(in_ptr); + tmp_avg = vaddq_s16(tmp_avg, data); +#else + for (int k = 0; k < C8NUM; ++k) { + tmp_avg[k] += input_ptr[in_offset + k]; } +#endif } // win_w loop } // win_h loop - int16_t tmp_out = round((float)tmp_avg / (float)real_count); +#ifdef ENABLE_NEON + int16_t tmp_data[8]; + int16_t tmp_out[8]; + for (int l = 0; l < C8NUM; l++) { + tmp_data[l] = tmp_avg[l] + 128 * real_count; + tmp_out[l] = (tmp_data[l] + real_count / 2) / real_count; + tmp_out[l] -= 128; + tmp_out[l] = round((tmp_out[l] - input_zp) * real_multiplier) + output_zp; + } + int8x8_t real_out; + int8x8_t output_min = vdup_n_s8(out_min); + int8x8_t output_max = vdup_n_s8(out_max); + real_out = vqmovn_s16(vld1q_s16(tmp_out)); + real_out = vmin_s8(real_out, output_max); + real_out = vmax_s8(real_out, output_min); + vst1_s8(output_ptr + out_channel_offset, real_out); +#else + for (int l = 0; l < C8NUM; ++l) { + int16_t tmp_data = tmp_avg[l] + 128 * real_count; + real_out[l] = (tmp_data + real_count / 2) / real_count - 128; + real_out[l] = (int8_t)(round((real_out[l] - input_zp) * real_multiplier) + output_zp); + real_out[l] = real_out[l] < out_min ? out_min : real_out[l]; + real_out[l] = real_out[l] > out_max ? out_max : real_out[l]; + *(output_ptr + out_channel_offset + l) = (int8_t)real_out[l]; + } +#endif + } + + // less than 8 channel + int channel_8_res = channel_16_res - c8 * C8NUM; + int in_c8_offset = in_c16_offset + c8 * C8NUM; + int out_c8_offset = out_c16_offset + c8 * C8NUM; + for (int k = 0; k < channel_8_res; k++) { + int in_channel_offset = in_c8_offset + k; + int out_channel_offset = out_c8_offset + k; + int16_t tmp_avg = 0; + for (int h = kh_s; h < kh_e; h++) { + for (int w = kw_s; w < kw_e; w++) { + int in_offset = in_channel_offset + input_stride + (h * in_w + w) * channel; + tmp_avg += input_ptr[in_offset]; + } // win_w loop + } // win_h loop + int16_t tmp_out = round((float)tmp_avg / (float)real_count + 128) - 128; + tmp_out = (int8_t)(round((tmp_out - input_zp) * real_multiplier) + output_zp); int16_t real_out = tmp_out < out_min ? out_min : tmp_out; real_out = real_out > out_max ? out_max : real_out; *(output_ptr + out_channel_offset) = (int8_t)real_out; @@ -249,6 +324,109 @@ void MaxPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParamete } // out_batch loop } +void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, + int task_id) { + int stride_w = pooling_param->stride_w_; + int stride_h = pooling_param->stride_h_; + int pad_w = pooling_param->pad_l_; + int pad_h = pooling_param->pad_u_; + int win_w = pooling_param->window_w_; + int win_h = pooling_param->window_h_; + int channel = pooling_param->input_channel_; + int in_w = pooling_param->input_w_; + int in_h = pooling_param->input_h_; + int output_w = pooling_param->output_w_; + int output_h = pooling_param->output_h_; + int output_batch = pooling_param->output_batch_; + int out_plane = output_w * output_h; + int out_tile_count = UP_DIV(out_plane, TILE_NUM); + int thread_num = out_tile_count < pooling_param->thread_num_ ? out_tile_count : pooling_param->thread_num_; + int c16 = UP_DIV(channel, 16); + // input channel is equal to output channel + float input_scale = pooling_param->quant_args_[0][0].scale_; + int input_zp = pooling_param->quant_args_[0][0].zp_; + float output_scale = pooling_param->quant_args_[1][0].scale_; + int output_zp = pooling_param->quant_args_[1][0].zp_; + double real_multiplier = input_scale / output_scale; + + for (int batch = 0; batch < output_batch; batch++) { + int in_batch_offset = batch * in_h * in_w * channel; + int out_batch_offset = batch * output_h * output_w * channel; + for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) { + int cal_start_index = thread_id * TILE_NUM; + int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index); + for (int i = 0; i < real_cal_num; i++) { + int index = cal_start_index + i; + int out_w_index = index % output_w; + int out_h_index = index / output_w; + int in_w_index = out_w_index * stride_w - pad_w; + int in_h_index = out_h_index * stride_h - pad_h; + int out_plane_offset = out_batch_offset + index * channel; + for (int j = 0; j < c16 - 1; j++) { + int in_channel_offset = in_batch_offset + j * 16; + int out_channel_offset = out_plane_offset + j * 16; +#ifdef ENABLE_NEON + int8x16_t tmp_max = vdupq_n_s8(INT8_MIN); +#else + int8_t tmp_max[16]; + for (int m = 0; m < C16NUM; ++m) { + tmp_max[m] = INT8_MIN; + } +#endif + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; +#ifdef ENABLE_NEON + tmp_max = vmaxq_s8(tmp_max, vld1q_s8(input_ptr + in_offset)); +#else + for (int k = 0; k < C16NUM; ++k) { + tmp_max[k] = MaxInt8(tmp_max[k], *(input_ptr + in_offset + k)); + } +#endif + } + } // win_w loop + } // win_h loop +#ifdef ENABLE_NEON + for (int l = 0; l < C16NUM; ++l) { + tmp_max[l] = (int8_t)(round((tmp_max[l] - input_zp) * real_multiplier) + output_zp); + } + vst1q_s8(output_ptr + out_channel_offset, tmp_max); +#else + for (int l = 0; l < C16NUM; ++l) { + *(output_ptr + out_channel_offset + l) = + (int8_t)(round((tmp_max[l] - input_zp) * real_multiplier) + output_zp); + } +#endif + } // in_channel loop + + // res channel + int channel_s = (c16 - 1) * 16; + for (int k = channel_s; k < channel; k++) { + int in_channel_offset = in_batch_offset + k; + int out_channel_offset = out_plane_offset + k; + int8_t tmp_max = INT8_MIN; + for (int h = 0; h < win_h; h++) { + for (int w = 0; w < win_w; w++) { + if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 || + (in_w_index + w) >= in_w) { + continue; + } else { + int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel; + tmp_max = MaxInt8(tmp_max, *(input_ptr + in_offset)); + } + } // win_w loop + } // win_h loop + *(output_ptr + out_channel_offset) = (int8_t)(round((tmp_max - input_zp) * real_multiplier) + output_zp); + } // channel_res loop + } // out_plane loop + } // out_batch loop + } +} + void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id) { int stride_w = pooling_param->stride_w_; int stride_h = pooling_param->stride_h_; @@ -264,7 +442,7 @@ void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam int output_batch = pooling_param->output_batch_; int out_plane = output_w * output_h; int out_tile_count = UP_DIV(out_plane, TILE_NUM); - int thread_num = pooling_param->thread_num_; + int thread_num = out_tile_count < pooling_param->thread_num_ ? out_tile_count : pooling_param->thread_num_; int c16 = UP_DIV(channel, 16); for (int batch = 0; batch < output_batch; batch++) { @@ -286,22 +464,10 @@ void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam #ifdef ENABLE_NEON int8x16_t tmp_max = vdupq_n_s8(INT8_MIN); #else - int8_t tmp_max1 = INT8_MIN; - int8_t tmp_max2 = INT8_MIN; - int8_t tmp_max3 = INT8_MIN; - int8_t tmp_max4 = INT8_MIN; - int8_t tmp_max5 = INT8_MIN; - int8_t tmp_max6 = INT8_MIN; - int8_t tmp_max7 = INT8_MIN; - int8_t tmp_max8 = INT8_MIN; - int8_t tmp_max9 = INT8_MIN; - int8_t tmp_max10 = INT8_MIN; - int8_t tmp_max11 = INT8_MIN; - int8_t tmp_max12 = INT8_MIN; - int8_t tmp_max13 = INT8_MIN; - int8_t tmp_max14 = INT8_MIN; - int8_t tmp_max15 = INT8_MIN; - int8_t tmp_max16 = INT8_MIN; + int8_t tmp_max[16]; + for (int m = 0; m < C16NUM; ++m) { + tmp_max[m] = INT8_MIN; + } #endif for (int h = 0; h < win_h; h++) { for (int w = 0; w < win_w; w++) { @@ -313,22 +479,9 @@ void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam #ifdef ENABLE_NEON tmp_max = vmaxq_s8(tmp_max, vld1q_s8(input_ptr + in_offset)); #else - tmp_max1 = MaxInt8(tmp_max1, *(input_ptr + in_offset)); - tmp_max2 = MaxInt8(tmp_max2, *(input_ptr + in_offset + 1)); - tmp_max3 = MaxInt8(tmp_max3, *(input_ptr + in_offset + 2)); - tmp_max4 = MaxInt8(tmp_max4, *(input_ptr + in_offset + 3)); - tmp_max5 = MaxInt8(tmp_max5, *(input_ptr + in_offset + 4)); - tmp_max6 = MaxInt8(tmp_max6, *(input_ptr + in_offset + 5)); - tmp_max7 = MaxInt8(tmp_max7, *(input_ptr + in_offset + 6)); - tmp_max8 = MaxInt8(tmp_max8, *(input_ptr + in_offset + 7)); - tmp_max9 = MaxInt8(tmp_max9, *(input_ptr + in_offset + 8)); - tmp_max10 = MaxInt8(tmp_max10, *(input_ptr + in_offset + 9)); - tmp_max11 = MaxInt8(tmp_max11, *(input_ptr + in_offset + 10)); - tmp_max12 = MaxInt8(tmp_max12, *(input_ptr + in_offset + 11)); - tmp_max13 = MaxInt8(tmp_max13, *(input_ptr + in_offset + 12)); - tmp_max14 = MaxInt8(tmp_max14, *(input_ptr + in_offset + 13)); - tmp_max15 = MaxInt8(tmp_max15, *(input_ptr + in_offset + 14)); - tmp_max16 = MaxInt8(tmp_max16, *(input_ptr + in_offset + 15)); + for (int k = 0; k < C16NUM; ++k) { + tmp_max[k] = MaxInt8(tmp_max[k], *(input_ptr + in_offset + k)); + } #endif } } // win_w loop @@ -336,24 +489,13 @@ void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam #ifdef ENABLE_NEON vst1q_s8(output_ptr + out_channel_offset, tmp_max); #else - *(output_ptr + out_channel_offset) = tmp_max1; - *(output_ptr + out_channel_offset + 1) = tmp_max2; - *(output_ptr + out_channel_offset + 2) = tmp_max3; - *(output_ptr + out_channel_offset + 3) = tmp_max4; - *(output_ptr + out_channel_offset + 4) = tmp_max5; - *(output_ptr + out_channel_offset + 5) = tmp_max6; - *(output_ptr + out_channel_offset + 6) = tmp_max7; - *(output_ptr + out_channel_offset + 7) = tmp_max8; - *(output_ptr + out_channel_offset + 8) = tmp_max9; - *(output_ptr + out_channel_offset + 9) = tmp_max10; - *(output_ptr + out_channel_offset + 10) = tmp_max11; - *(output_ptr + out_channel_offset + 11) = tmp_max12; - *(output_ptr + out_channel_offset + 12) = tmp_max13; - *(output_ptr + out_channel_offset + 13) = tmp_max14; - *(output_ptr + out_channel_offset + 14) = tmp_max15; - *(output_ptr + out_channel_offset + 15) = tmp_max16; + for (int l = 0; l < C16NUM; ++l) { + *(output_ptr + out_channel_offset + l) = tmp_max[l]; + } #endif } // in_channel loop + + // res channel int channel_s = (c16 - 1) * 16; for (int k = channel_s; k < channel; k++) { int in_channel_offset = in_batch_offset + k; diff --git a/mindspore/lite/nnacl/int8/pooling_int8.h b/mindspore/lite/nnacl/int8/pooling_int8.h index 3926f6e682..498ad36774 100644 --- a/mindspore/lite/nnacl/int8/pooling_int8.h +++ b/mindspore/lite/nnacl/int8/pooling_int8.h @@ -32,6 +32,8 @@ void AvgPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam void MaxPoolingInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id); +void MaxPoolingWithQuantInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id); + void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParameter *pooling_param, int task_id); #ifdef __cplusplus } diff --git a/mindspore/lite/nnacl/pooling_parameter.h b/mindspore/lite/nnacl/pooling_parameter.h index 644d3d4c47..e3d7a239db 100644 --- a/mindspore/lite/nnacl/pooling_parameter.h +++ b/mindspore/lite/nnacl/pooling_parameter.h @@ -19,14 +19,16 @@ #include "nnacl/op_base.h" #include "nnacl/quantization/quantize.h" +typedef enum PoolMode { PoolMode_No, PoolMode_MaxPool, PoolMode_AvgPool } PoolMode; + +typedef enum RoundMode { RoundMode_No, RoundMode_Ceil, RoundMode_Floor } RoundMode; + typedef struct PoolingParameter { OpParameter op_parameter_; + PoolMode pool_mode_; + RoundMode round_mode_; + ActType act_type_; QuantArg **quant_args_; - bool global_; - bool max_pooling_; - bool avg_pooling_; - bool round_ceil_; - bool round_floor_; int window_w_; int window_h_; int input_w_; @@ -44,7 +46,8 @@ typedef struct PoolingParameter { int stride_w_; int stride_h_; int thread_num_; - ActType act_type_; + bool global_; + bool quantize_; } PoolingParameter; #endif // MINDSPORE_LITE_NNACL_POOLING_PARAMETER_H_ diff --git a/mindspore/lite/src/populate_parameter.cc b/mindspore/lite/src/populate_parameter.cc index 31c43fe55c..e3530025d0 100644 --- a/mindspore/lite/src/populate_parameter.cc +++ b/mindspore/lite/src/populate_parameter.cc @@ -294,32 +294,26 @@ OpParameter *PopulatePoolingParameter(const mindspore::lite::PrimitiveC *primiti auto pool_mode = pooling_primitive->GetPoolingMode(); switch (pool_mode) { case schema::PoolMode_MAX_POOLING: - pooling_param->max_pooling_ = true; - pooling_param->avg_pooling_ = false; + pooling_param->pool_mode_ = PoolMode_MaxPool; break; case schema::PoolMode_MEAN_POOLING: - pooling_param->max_pooling_ = false; - pooling_param->avg_pooling_ = true; + pooling_param->pool_mode_ = PoolMode_AvgPool; break; default: - pooling_param->max_pooling_ = false; - pooling_param->avg_pooling_ = false; + pooling_param->pool_mode_ = PoolMode_No; break; } auto round_mode = pooling_primitive->GetRoundMode(); switch (round_mode) { case schema::RoundMode_FLOOR: - pooling_param->round_floor_ = true; - pooling_param->round_ceil_ = false; + pooling_param->round_mode_ = RoundMode_Floor; break; case schema::RoundMode_CEIL: - pooling_param->round_floor_ = false; - pooling_param->round_ceil_ = true; + pooling_param->round_mode_ = RoundMode_Ceil; break; default: - pooling_param->round_floor_ = false; - pooling_param->round_ceil_ = false; + pooling_param->round_mode_ = RoundMode_No; break; } diff --git a/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc b/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc index c16eea8dc4..472a86a574 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/pooling_base.cc @@ -42,6 +42,12 @@ int PoolingBaseCPUKernel::SetQuantParam() { pooling_quant_arg_[1][0].scale_ = out_quant_arg.front().scale; pooling_quant_arg_[1][0].zp_ = out_quant_arg.front().zeroPoint; pooling_param_->quant_args_ = pooling_quant_arg_; + if (pooling_quant_arg_[0][0].scale_ == pooling_quant_arg_[1][0].scale_ && + pooling_quant_arg_[0][0].zp_ == pooling_quant_arg_[1][0].zp_) { + pooling_param_->quantize_ = false; + } else { + pooling_param_->quantize_ = true; + } return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/pooling_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/pooling_fp16.cc index 041116bed2..80116acb65 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/pooling_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/pooling_fp16.cc @@ -53,7 +53,7 @@ int PoolingFp16CPUKernel::ReSize() { } int PoolingFp16CPUKernel::RunImpl(int task_id) { - if (pooling_param_->max_pooling_) { + if (pooling_param_->pool_mode_ == PoolMode_MaxPool) { MaxPoolingFp16(fp16_input_, fp16_output_, pooling_param_, task_id); } else { AvgPoolingFp16(fp16_input_, fp16_output_, pooling_param_, task_id); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.cc index 1bcfcaff33..2891ed0494 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/pooling.cc @@ -52,7 +52,7 @@ int PoolingCPUKernel::ReSize() { int PoolingCPUKernel::RunImpl(int task_id) { auto input_ptr = reinterpret_cast(in_tensors_.at(kInputIndex)->Data()); auto output_ptr = reinterpret_cast(out_tensors_.at(kOutputIndex)->Data()); - if (pooling_param_->max_pooling_) { + if (pooling_param_->pool_mode_ == PoolMode_MaxPool) { switch (pooling_param_->act_type_) { case ActType_Relu: MaxPoolingRelu(input_ptr, output_ptr, pooling_param_, task_id); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc index fda7fc1a21..5348bdb3f3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32_grad/pooling_grad.cc @@ -163,7 +163,7 @@ int PoolingGradCPUKernel::Run() { auto input_ptr = reinterpret_cast(inputs_.at(0)->Data()); auto output_ptr = reinterpret_cast(outputs_.at(0)->Data()); - if (pool_param->max_pooling_) { + if (pool_param->pool_mode_ == PoolMode_MaxPool) { auto ind = reinterpret_cast(inputs_.at(1)->Data()); MaxPoolingGrad(input_ptr, ind, output_ptr, pool_param); } else { diff --git a/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.cc b/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.cc index 72749cddb3..db8ef66042 100644 --- a/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.cc +++ b/mindspore/lite/src/runtime/kernel/arm/int8/pooling_int8.cc @@ -61,10 +61,14 @@ int PoolingInt8CPUKernel::ReSize() { int PoolingInt8CPUKernel::RunImpl(int task_id) { auto input_data = reinterpret_cast(in_tensors_.at(kInputIndex)->Data()); auto output_data = reinterpret_cast(out_tensors_.at(kOutputIndex)->Data()); - if (pooling_param_->max_pooling_) { - MaxPoolingInt8(input_data, output_data, pooling_param_, task_id); + if (pooling_param_->pool_mode_ == PoolMode_MaxPool) { + if (pooling_param_->quantize_) { + MaxPoolingWithQuantInt8(input_data, output_data, pooling_param_, task_id); + } else { + MaxPoolingOptInt8(input_data, output_data, pooling_param_, task_id); + } } else { - AvgPoolingInt8(input_data, output_data, pooling_param_, task_id); + AvgPoolingOptInt8(input_data, output_data, pooling_param_, task_id); } return RET_OK; } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc index 65cac2c069..baabd30a72 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/pooling2d.cc @@ -43,13 +43,13 @@ int PoolingOpenCLKernel::Init() { std::string source; std::string program_name; #endif - if (parameter_->max_pooling_) { + if (parameter_->pool_mode_ == PoolMode_MaxPool) { kernel_name = "MaxPooling2d"; #ifndef PROGRAM_WITH_IL source = max_pool2d_source; program_name = "MaxPooling2d"; #endif - } else if (parameter_->avg_pooling_) { + } else if (parameter_->pool_mode_ == PoolMode_AvgPool) { kernel_name = "AvgPooling2d"; #ifndef PROGRAM_WITH_IL source = avg_pool2d_source; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc index e1a9053e3e..22aabd7757 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/arm/fp32_grad/pooling_grad_fp32_tests.cc @@ -26,7 +26,7 @@ #include "nnacl/fp32_grad/pooling_grad.h" namespace mindspore { -class TestPoolingGradFp32 : public mindspore::CommonTest { +class TestPoolingGradFp32 : public mindspore::CommonTest { public: TestPoolingGradFp32() {} }; @@ -161,8 +161,7 @@ TEST_F(TestPoolingGradFp32, MaxPoolingGradFp32) { auto pooling_param = new PoolingParameter(); InitPoolingParamFP32(pooling_param); pooling_param->output_channel_ = 3; - pooling_param->avg_pooling_ = false; - pooling_param->max_pooling_ = true; + pooling_param->pool_mode_ = PoolMode_MaxPool; // runtime part printf("Calculating runtime cost...\n"); uint64_t time_avg = 0; @@ -215,8 +214,7 @@ TEST_F(TestPoolingGradFp32, MaxPoolingKernelGradFp32) { // prepare stage auto maxpool = new PoolingParameter(); InitPoolingParamFP32(maxpool); - maxpool->avg_pooling_ = false; - maxpool->max_pooling_ = true; + maxpool->pool_mode_ = PoolMode_MaxPool; maxpool->input_h_ = 30; maxpool->input_w_ = 30; maxpool->input_channel_ = 3; @@ -268,8 +266,7 @@ TEST_F(TestPoolingGradFp32, MaxPoolingKernelGradFp32) { auto pooling_param = new PoolingParameter(); InitPoolingParamFP32(pooling_param); - pooling_param->avg_pooling_ = false; - pooling_param->max_pooling_ = true; + pooling_param->pool_mode_ = PoolMode_MaxPool; pooling_param->input_h_ = 10; pooling_param->input_w_ = 10; pooling_param->input_channel_ = 3; diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/avg_pooling_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/avg_pooling_tests.cc index 1c8edc5b35..f74cd7cfda 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/avg_pooling_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/avg_pooling_tests.cc @@ -48,8 +48,7 @@ void InitAvgPoolingParam(PoolingParameter *param) { param->pad_l_ = 0; param->pad_r_ = 0; - param->max_pooling_ = false; - param->avg_pooling_ = true; + param->pool_mode_ = PoolMode_AvgPool; } TEST_F(TestAvgPoolingOpenCL, AvgPoolFp32) { diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/max_pooling_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/max_pooling_tests.cc index fce2f582c2..321b1ccd06 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/max_pooling_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/max_pooling_tests.cc @@ -35,8 +35,7 @@ void InitParameter(PoolingParameter *param) { param->pad_d_ = 0; param->pad_l_ = 0; param->pad_r_ = 0; - param->avg_pooling_ = false; - param->max_pooling_ = true; + param->pool_mode_ = PoolMode_MaxPool; } TEST_F(TestMaxPoolingOpenCL, MaxPool_1_32_512_96) {