|
|
@ -451,118 +451,83 @@ void MaxPoolingOptInt8(const int8_t *input_ptr, int8_t *output_ptr, PoolingParam
|
|
|
|
int output_batch = pooling_param->output_batch_;
|
|
|
|
int output_batch = pooling_param->output_batch_;
|
|
|
|
int out_plane = output_w * output_h;
|
|
|
|
int out_plane = output_w * output_h;
|
|
|
|
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
|
|
|
|
int out_tile_count = UP_DIV(out_plane, TILE_NUM);
|
|
|
|
int thread_num = out_tile_count < pooling_param->thread_num_ ? out_tile_count : pooling_param->thread_num_;
|
|
|
|
int thread_num = MSMIN(out_tile_count, pooling_param->thread_num_);
|
|
|
|
int c16 = channel / 16;
|
|
|
|
int8_t out_array[MAX_MAXPOOL_SIZE];
|
|
|
|
|
|
|
|
|
|
|
|
for (int batch = 0; batch < output_batch; batch++) {
|
|
|
|
for (int batch = 0; batch < output_batch; batch++) {
|
|
|
|
int in_batch_offset = batch * in_h * in_w * channel;
|
|
|
|
int in_batch_offset = batch * in_h * in_w * channel;
|
|
|
|
int out_batch_offset = batch * output_h * output_w * channel;
|
|
|
|
int out_batch_offset = batch * output_h * output_w * channel;
|
|
|
|
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
|
|
|
|
for (int thread_id = task_id; thread_id < out_tile_count; thread_id += thread_num) {
|
|
|
|
int cal_start_index = thread_id * TILE_NUM;
|
|
|
|
int cal_start_index = thread_id * TILE_NUM;
|
|
|
|
int real_cal_num = (out_plane - cal_start_index) > TILE_NUM ? TILE_NUM : (out_plane - cal_start_index);
|
|
|
|
int real_cal_num = out_plane - cal_start_index;
|
|
|
|
|
|
|
|
real_cal_num = MSMIN(real_cal_num, TILE_NUM);
|
|
|
|
for (int i = 0; i < real_cal_num; i++) {
|
|
|
|
for (int i = 0; i < real_cal_num; i++) {
|
|
|
|
int index = cal_start_index + i;
|
|
|
|
int index = cal_start_index + i;
|
|
|
|
int out_w_index = index % output_w;
|
|
|
|
int out_w_index = index % output_w;
|
|
|
|
int out_h_index = index / output_w;
|
|
|
|
int out_h_index = index / output_w;
|
|
|
|
int in_w_index = out_w_index * stride_w - pad_w;
|
|
|
|
int in_w_index = out_w_index * stride_w - pad_w;
|
|
|
|
int in_h_index = out_h_index * stride_h - pad_h;
|
|
|
|
int in_h_index = out_h_index * stride_h - pad_h;
|
|
|
|
|
|
|
|
int ky_s = 0 > (-in_h_index) ? 0 : (-in_h_index);
|
|
|
|
|
|
|
|
int ky_e = MSMIN(win_h, in_h - in_h_index);
|
|
|
|
|
|
|
|
int kx_s = 0 > (-in_w_index) ? 0 : (-in_w_index);
|
|
|
|
|
|
|
|
int kx_e = MSMIN(win_w, in_w - in_w_index);
|
|
|
|
|
|
|
|
int input_stride = (in_h_index * in_w + in_w_index) * channel + in_batch_offset;
|
|
|
|
int out_plane_offset = out_batch_offset + index * channel;
|
|
|
|
int out_plane_offset = out_batch_offset + index * channel;
|
|
|
|
for (int j = 0; j < c16; j++) {
|
|
|
|
|
|
|
|
int in_channel_offset = in_batch_offset + j * 16;
|
|
|
|
|
|
|
|
int out_channel_offset = out_plane_offset + j * 16;
|
|
|
|
|
|
|
|
#ifdef ENABLE_NEON
|
|
|
|
|
|
|
|
int8x16_t tmp_max = vdupq_n_s8(INT8_MIN);
|
|
|
|
|
|
|
|
#else
|
|
|
|
|
|
|
|
int8_t tmp_max[16];
|
|
|
|
|
|
|
|
for (int m = 0; m < C16NUM; ++m) {
|
|
|
|
|
|
|
|
tmp_max[m] = INT8_MIN;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
for (int h = 0; h < win_h; h++) {
|
|
|
|
|
|
|
|
for (int w = 0; w < win_w; w++) {
|
|
|
|
|
|
|
|
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
|
|
|
|
|
|
|
|
(in_w_index + w) >= in_w) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
|
|
|
|
|
|
|
|
#ifdef ENABLE_NEON
|
|
|
|
|
|
|
|
tmp_max = vmaxq_s8(tmp_max, vld1q_s8(input_ptr + in_offset));
|
|
|
|
|
|
|
|
#else
|
|
|
|
|
|
|
|
for (int k = 0; k < C16NUM; ++k) {
|
|
|
|
|
|
|
|
tmp_max[k] = MaxInt8(tmp_max[k], *(input_ptr + in_offset + k));
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
} // win_w loop
|
|
|
|
|
|
|
|
} // win_h loop
|
|
|
|
|
|
|
|
#ifdef ENABLE_NEON
|
|
|
|
|
|
|
|
vst1q_s8(output_ptr + out_channel_offset, tmp_max);
|
|
|
|
|
|
|
|
#else
|
|
|
|
|
|
|
|
for (int l = 0; l < C16NUM; ++l) {
|
|
|
|
|
|
|
|
*(output_ptr + out_channel_offset + l) = tmp_max[l];
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
} // in_channel loop
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// 8 channel
|
|
|
|
int c = 0;
|
|
|
|
int tmp_c = c16 * 16;
|
|
|
|
for (; c < channel; c += MAX_MAXPOOL_SIZE) {
|
|
|
|
int c8 = (channel - c16 * 16) / 8;
|
|
|
|
int real_channel = channel - c;
|
|
|
|
for (int k = 0; k < c8; k++) {
|
|
|
|
real_channel = MSMIN(real_channel, MAX_MAXPOOL_SIZE);
|
|
|
|
int in_channel_offset = in_batch_offset + tmp_c + k * 8;
|
|
|
|
memset(out_array, INT8_MIN, real_channel);
|
|
|
|
int out_channel_offset = out_plane_offset + tmp_c + k * 8;
|
|
|
|
int8_t *out_data = output_ptr + out_plane_offset + c;
|
|
|
|
#ifdef ENABLE_NEON
|
|
|
|
for (int h = ky_s; h < ky_e; ++h) {
|
|
|
|
int8x8_t tmp_max = vdup_n_s8(INT8_MIN);
|
|
|
|
int in_h_offset = input_stride + h * in_w * channel + c;
|
|
|
|
#else
|
|
|
|
for (int w = kx_s; w < kx_e; ++w) {
|
|
|
|
int8_t tmp_max[8];
|
|
|
|
const int8_t *in_data = input_ptr + in_h_offset + w * channel;
|
|
|
|
for (int m = 0; m < C8NUM; ++m) {
|
|
|
|
int j = 0;
|
|
|
|
tmp_max[m] = INT8_MIN;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
for (int h = 0; h < win_h; h++) {
|
|
|
|
|
|
|
|
for (int w = 0; w < win_w; w++) {
|
|
|
|
|
|
|
|
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
|
|
|
|
|
|
|
|
(in_w_index + w) >= in_w) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
|
|
|
|
|
|
|
|
#ifdef ENABLE_NEON
|
|
|
|
#ifdef ENABLE_NEON
|
|
|
|
tmp_max = vmax_s8(tmp_max, vld1_s8(input_ptr + in_offset));
|
|
|
|
int c16 = real_channel / 16 * 16;
|
|
|
|
#else
|
|
|
|
int c8 = real_channel / 8 * 8;
|
|
|
|
for (int l = 0; l < C8NUM; ++l) {
|
|
|
|
for (; j < c16; j += 16) {
|
|
|
|
tmp_max[l] = MaxInt8(tmp_max[l], *(input_ptr + in_offset + l));
|
|
|
|
int8x16_t ori_in = vld1q_s8(in_data);
|
|
|
|
}
|
|
|
|
int8x16_t out_array16 = vld1q_s8(out_array + j);
|
|
|
|
|
|
|
|
in_data += 16;
|
|
|
|
|
|
|
|
out_array16 = vmaxq_s8(ori_in, out_array16);
|
|
|
|
|
|
|
|
vst1q_s8(out_array + j, out_array16);
|
|
|
|
|
|
|
|
} // 16 channel loop
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (; j < c8; j += 8) {
|
|
|
|
|
|
|
|
int8x8_t ori_in = vld1_s8(in_data);
|
|
|
|
|
|
|
|
int8x8_t out_array8 = vld1_s8(out_array + j);
|
|
|
|
|
|
|
|
in_data += 8;
|
|
|
|
|
|
|
|
out_array8 = vmax_s8(ori_in, out_array8);
|
|
|
|
|
|
|
|
vst1_s8(out_array + j, out_array8);
|
|
|
|
|
|
|
|
} // 8 channel loop
|
|
|
|
#endif
|
|
|
|
#endif
|
|
|
|
|
|
|
|
for (; j < real_channel; ++j) {
|
|
|
|
|
|
|
|
out_array[j] = out_array[j] > in_data[j] ? out_array[j] : in_data[j];
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} // win_w loop
|
|
|
|
} // kw loop
|
|
|
|
} // win_h loop
|
|
|
|
} // kh loop
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int j = 0;
|
|
|
|
#ifdef ENABLE_NEON
|
|
|
|
#ifdef ENABLE_NEON
|
|
|
|
vst1_s8(output_ptr + out_channel_offset, tmp_max);
|
|
|
|
int c16 = real_channel / 16 * 16;
|
|
|
|
#else
|
|
|
|
int c8 = real_channel / 8 * 8;
|
|
|
|
for (int l = 0; l < C8NUM; ++l) {
|
|
|
|
for (; j < c16; j += 16) {
|
|
|
|
*(output_ptr + out_channel_offset + l) = tmp_max[l];
|
|
|
|
vst1q_s8(out_data, vld1q_s8(out_array + j));
|
|
|
|
}
|
|
|
|
out_data += 16;
|
|
|
|
#endif
|
|
|
|
} // 16 channel loop
|
|
|
|
} // 8 channel loop
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// res channel
|
|
|
|
for (; j < c8; j += 8) {
|
|
|
|
int channel_s = c16 * 16 + c8 * 8;
|
|
|
|
vst1_s8(out_data, vld1_s8(out_array + j));
|
|
|
|
for (int k = channel_s; k < channel; k++) {
|
|
|
|
out_data += 8;
|
|
|
|
int in_channel_offset = in_batch_offset + k;
|
|
|
|
} // 8 channel loop
|
|
|
|
int out_channel_offset = out_plane_offset + k;
|
|
|
|
#endif
|
|
|
|
int8_t tmp_max = INT8_MIN;
|
|
|
|
for (; j < real_channel; ++j) {
|
|
|
|
for (int h = 0; h < win_h; h++) {
|
|
|
|
out_data[j] = out_array[j];
|
|
|
|
for (int w = 0; w < win_w; w++) {
|
|
|
|
|
|
|
|
if ((in_h_index + h) < 0 || (in_h_index + h) >= in_h || (in_w_index + w) < 0 ||
|
|
|
|
|
|
|
|
(in_w_index + w) >= in_w) {
|
|
|
|
|
|
|
|
continue;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
int in_offset = in_channel_offset + ((in_h_index + h) * in_w + in_w_index + w) * channel;
|
|
|
|
|
|
|
|
tmp_max = MaxInt8(tmp_max, *(input_ptr + in_offset));
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} // win_w loop
|
|
|
|
} // 256 channel loop
|
|
|
|
} // win_h loop
|
|
|
|
|
|
|
|
*(output_ptr + out_channel_offset) = tmp_max;
|
|
|
|
|
|
|
|
} // channel_res loop
|
|
|
|
|
|
|
|
} // out_plane loop
|
|
|
|
} // out_plane loop
|
|
|
|
} // out_batch loop
|
|
|
|
} // out_batch loop
|
|
|
|
}
|
|
|
|
}
|
|
|
|