apply int8 4x16 kernel

pull/8645/head
lixian 4 years ago
parent 0ab808ec9e
commit 7f9d65cce0

File diff suppressed because it is too large Load Diff

@ -378,11 +378,11 @@ void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *
"b 16f \n"
"10: \n"
"ld1 {v16.h}[0], [x10] \n"
"ld1 {v16.d}[0], [x10] \n"
"b 16f \n"
"11: \n"
"ld1 {v16.h}[0], [x10] \n"
"ld1 {v16.d}[0], [x10] \n"
"add x10, x10, #8 \n"
"ld1 {v16.s}[2], [x10] \n"
"b 16f \n"
@ -802,11 +802,12 @@ void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *i
void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func) {
int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int *filter_zp) {
int is_per_oc = (int)conv_param->conv_quant_arg_.filter_arg_num_ != 1;
matmul_func(packed_input, packed_weight, dst, row, col, deep4, conv_param->output_channel_, input_sum, bias,
left_shift, right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], is_per_oc);
matmul_func(packed_input, packed_weight, dst, row, col, deep4, col, input_sum, bias, left_shift, right_shift,
multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], is_per_oc,
filter_zp);
return;
}

@ -46,7 +46,7 @@ void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t
int32_t *multiplier, ConvParameter *conv_param);
void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func);
int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_DP_FUNC matmul_func, int32_t *filter_zp);
void Conv1x1Int8Arm32(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param);

@ -64,6 +64,21 @@ void MatrixEmptyInt8(int8_t *dst, int row, int col) {
return;
}
void RowMajor2Row4x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
int col4 = UP_ROUND(col, C4NUM);
for (int r = 0; r < row; r++) {
int rd16 = r / C16NUM;
int rm16 = r % C16NUM;
for (int c = 0; c < col; c++) {
int cd4 = c / C4NUM;
int cm4 = c % C4NUM;
int dst_index = rd16 * col4 * C16NUM + cd4 * C16NUM * C4NUM + rm16 * C4NUM + cm4;
int src_index = r * col + c;
dst_ptr[dst_index] = src_ptr[src_index];
}
}
}
void RowMajor2Row16x4MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col) {
/* Row-major to row16x4-major (block row-major) */
int col16 = UP_ROUND(col, C16NUM);
@ -268,6 +283,223 @@ void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
return;
}
void MatMulInt8_4x16_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
size_t per_channel, int32_t *filter_zp) {
/* row4x4-major * row4x16-major => (int8)row-major */
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r4div = r / C4NUM, r4mod = r % C4NUM;
int c16div = c / C16NUM, c16mod = c % C16NUM;
size_t ci = r * col + c;
int32_t value = 0;
for (int d = 0; d < deep_4; d++) {
int d4div = d / C4NUM, d4mod = d % C4NUM;
size_t ai = r4div * deep_4 * C4NUM + d4div * C4NUM * C4NUM + r4mod * C4NUM + d4mod;
size_t bi = c16div * deep_4 * C16NUM + d4div * C16NUM * C4NUM + c16mod * C4NUM + d4mod;
value = value + a[ai] * b[bi];
}
int32_t cur_input_sum = per_channel ? input_sum[r] * filter_zp[c] : input_sum[r];
value -= cur_input_sum;
value += bias[c];
int32_t cur_left_shift = per_channel ? left_shift[c] : left_shift[0];
int32_t cur_right_shift = per_channel ? right_shift[c] : right_shift[0];
int32_t cur_multiplier = per_channel ? multiplier[c] : multiplier[0];
value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + output_zp;
value = MSMIN(maxi, value);
value = MSMAX(mini, value);
dst[ci] = (int8_t)value;
}
}
return;
}
void PackInput4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum,
size_t input_channel, size_t plane_size, int32_t filter_zp) {
int ic4 = UP_ROUND(input_channel, C4NUM);
int hw4 = UP_ROUND(plane_size, C4NUM);
size_t hw_4div = plane_size / C4NUM * C4NUM;
size_t ic_4div = input_channel / C4NUM * C4NUM;
const int8_t *src_r = src_input;
int8_t *pack_r = packed_input;
/* per layer */
for (int hwi = 0; hwi < hw_4div; hwi += C4NUM) {
const int8_t *src_ic = src_r;
int8_t *pack_ic = pack_r;
int32_t *input_sum_r = input_sum + hwi;
#ifdef ENABLE_ARM64
size_t src_stride = input_channel;
size_t ic_4res = input_channel - ic_4div;
asm volatile(
"dup v2.4s, wzr \n"
"mov x14, %[input_sum_r] \n"
"dup v3.4s, %w[filter_zp] \n"
"mov x10, %[src_ic] \n"
"mov x11, %[pack_ic] \n"
"mov x15, #0 \n"
"1: \n"
"cmp x15, %[ic_4div] \n"
"add x15, x15, #4\n"
"mov x12, x10 \n"
"add x10, x10, #4\n"
"blt 2f \n"
"cmp %[ic_4res], #0\n"
"beq 6f \n"
"cmp %[ic_4res], #1\n"
"beq 3f \n"
"cmp %[ic_4res], #2\n"
"beq 4f \n"
"cmp %[ic_4res], #3\n"
"beq 5f \n"
"2: \n"
"ld1 {v0.s}[0], [x12], %[src_stride]\n"
"ld1 {v0.s}[1], [x12], %[src_stride]\n"
"ld1 {v0.s}[2], [x12], %[src_stride]\n"
"ld1 {v0.s}[3], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"saddlp v1.8h, v0.16b \n"
"saddlp v0.4s, v1.8h \n"
"add v2.4s, v2.4s, v0.4s \n"
"b 1b \n"
"3: \n" /* ic res 1 */
"dup v0.4s, wzr \n"
"ld1 {v0.b}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[8], [x12], %[src_stride]\n"
"ld1 {v0.b}[12], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"saddlp v1.8h, v0.16b \n"
"saddlp v0.4s, v1.8h \n"
"add v2.4s, v2.4s, v0.4s \n"
"b 6f \n"
"4: \n" /* ic res 2 */
"dup v0.4s, wzr \n"
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"saddlp v1.8h, v0.16b \n"
"saddlp v0.4s, v1.8h \n"
"add v2.4s, v2.4s, v0.4s \n"
"b 6f \n"
"5: \n" /* ic res 3 */
"dup v0.4s, wzr \n"
"add x13, x12, #2 \n"
"ld1 {v0.h}[0], [x12], %[src_stride]\n"
"ld1 {v0.b}[2], [x13], %[src_stride]\n"
"ld1 {v0.h}[2], [x12], %[src_stride]\n"
"ld1 {v0.b}[6], [x13], %[src_stride]\n"
"ld1 {v0.h}[4], [x12], %[src_stride]\n"
"ld1 {v0.b}[10], [x13], %[src_stride]\n"
"ld1 {v0.h}[6], [x12], %[src_stride]\n"
"ld1 {v0.b}[14], [x13], %[src_stride]\n"
"st1 {v0.16b}, [x11], #16\n"
"saddlp v1.8h, v0.16b \n"
"saddlp v0.4s, v1.8h \n"
"add v2.4s, v2.4s, v0.4s \n"
"b 6f \n"
"6: \n"
"mul v2.4s, v2.4s, v3.4s \n"
"st1 {v2.4s}, [x14], #16 \n"
:
: [ src_ic ] "r"(src_ic), [ pack_ic ] "r"(pack_ic), [ input_sum_r ] "r"(input_sum_r),
[ src_stride ] "r"(src_stride), [ ic_4div ] "r"(ic_4div), [ ic_4res ] "r"(ic_4res), [ filter_zp ] "r"(filter_zp)
: "x10", "x11", "x12", "x13", "x14", "x15", "v0", "v1", "v2", "v3");
#else
int32_t tmp_sum_value[4] = {0};
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
for (int i = 0; i < C4NUM; i++) {
tmp_sum_value[i] += src_ic[0 + i * input_channel];
tmp_sum_value[i] += src_ic[1 + i * input_channel];
tmp_sum_value[i] += src_ic[2 + i * input_channel];
tmp_sum_value[i] += src_ic[3 + i * input_channel];
pack_ic[0 + i * C4NUM] = src_ic[0 + i * input_channel];
pack_ic[1 + i * C4NUM] = src_ic[1 + i * input_channel];
pack_ic[2 + i * C4NUM] = src_ic[2 + i * input_channel];
pack_ic[3 + i * C4NUM] = src_ic[3 + i * input_channel];
}
src_ic += C4NUM;
pack_ic += C4NUM * C4NUM;
}
for (int ici = ic_4div; ici < input_channel; ici += 1) {
for (int i = 0; i < C4NUM; i++) {
tmp_sum_value[i] += src_ic[i * input_channel];
pack_ic[i * C4NUM] = src_ic[i * input_channel];
}
src_ic += 1;
pack_ic += 1;
}
for (int ici = input_channel; ici < ic4; ici += 1) {
for (int i = 0; i < C4NUM; i++) {
pack_ic[i * C4NUM] = 0;
}
pack_ic += 1;
}
for (int i = 0; i < C4NUM; i++) {
input_sum_r[i] = tmp_sum_value[i] * filter_zp;
}
#endif
src_r += input_channel * C4NUM;
pack_r += ic4 * C4NUM;
}
if (hw_4div != plane_size) {
memset(pack_r, 0, C4NUM * ic4);
for (int hwi = hw_4div; hwi < plane_size; hwi += 1) {
int32_t tmp_sum_value = 0;
const int8_t *src_ic = src_r;
int8_t *pack_ic = pack_r;
for (int ici = 0; ici < ic_4div; ici += C4NUM) {
tmp_sum_value += src_ic[0];
tmp_sum_value += src_ic[1];
tmp_sum_value += src_ic[2];
tmp_sum_value += src_ic[3];
pack_ic[0] = src_ic[0];
pack_ic[1] = src_ic[1];
pack_ic[2] = src_ic[2];
pack_ic[3] = src_ic[3];
src_ic += C4NUM;
pack_ic += C4NUM * C4NUM;
}
for (int ici = ic_4div; ici < input_channel; ici += 1) {
tmp_sum_value += src_ic[0];
pack_ic[0] = src_ic[0];
src_ic += 1;
pack_ic += 1;
}
input_sum[hwi] = tmp_sum_value * filter_zp;
src_r += input_channel;
pack_r += C4NUM;
}
for (int hwi = plane_size; hwi < hw4; hwi++) {
input_sum[hwi] = 0;
}
}
return;
}
void RowMajor2Col16x4MajorInt8(int8_t *src, int row, int col, int8_t *dst) {
int row_16 = UP_ROUND(row, C16NUM);
int stride = sizeof(int8_t) * 16 * 4;

@ -52,6 +52,15 @@ void MatMulInt8_4x2_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
bool peroc);
/* 4x4 4x16 -> 4x16 */
void RowMajor2Row4x16MajorInt8(const int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void PackInput4x4AndInputSumPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum,
size_t input_channel, size_t plane_size, int32_t filter_zp);
void MatMulInt8_4x16_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
size_t per_channel, int32_t *filter_zp);
#ifdef ENABLE_ARM64
void MatmulInt8Neon64(const int8_t *a, const int8_t *b, int8_t *dst, int row4, int col4, int deep16, const int *a_sums,
const int *bias, int act_min, int act_max, int out_zp, int32_t *multiplier, int32_t *left_shift,

@ -27,6 +27,11 @@ typedef void (*MATMUL_OPT_R_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini,
int32_t maxi, size_t per_channel);
typedef void (*MATMUL_OPT_DP_FUNC)(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini,
int32_t maxi, size_t per_channel, int *filter_zp);
typedef enum OutType { OutType_C8 = 0, OutType_Nhwc = 1, OutType_TileC8 = 2 } OutType;
typedef struct MatMulParameter {
@ -40,6 +45,7 @@ typedef struct MatMulParameter {
int col_2_;
int col_4_;
int col_8_;
int col_16_;
int deep_;
int deep_4_;
int deep_16_;

@ -38,6 +38,12 @@ void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int8_
void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16);
void PackInputSum16x4PerChannelArm32(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr,
size_t plane_size, size_t input_channel, size_t output_channel);
void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr,
size_t plane_size, size_t input_channel, size_t output_channel);
void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_param, int data_size);
void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParameter *conv_param);

@ -45,8 +45,12 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel {
void FreeRunBuf();
public:
int RunImpl(int task_id);
int RunPre(int task_id);
int DoRun(int task_id);
private:
int RunArm32(int task_id);
int RunArm64(int task_id);
int RunArm64Opt(int task_id);
private:
void FreeResizeBuf();
@ -58,8 +62,8 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel {
int InitBiasByzp(void *src_weight, int input_channel, int output_channel, int round_oc);
private:
int32_t *input_sum_ = nullptr; /* per-oc: oc4 format */
int32_t *filter_zp_ptr_ = nullptr; /* per-oc */
int32_t *input_sum_ = nullptr; /* per-oc */
int32_t *filter_zp_ptr_ = nullptr; /* per-oc up round */
int32_t *left_shift_ = nullptr; /* per-oc up round */
int32_t *right_shift_ = nullptr; /* per-oc up round */
int32_t *multiplier_ = nullptr; /* per-oc up round */
@ -69,12 +73,10 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel {
int8_t *output_ptr_ = nullptr;
size_t thread_count_ = 1;
size_t thread_stride_ = 0;
size_t thread_count_hw_ = 1;
size_t thread_stride_hw_ = 0;
bool pre_trans_input_ = false;
size_t input_sum_size_ = 0;
MatMulParameter *matmul_param_ = nullptr;
MATMUL_OPT_R_FUNC matmul_func_ = nullptr;
MATMUL_OPT_DP_FUNC matmul_func_ = nullptr;
bool support_optimize_ = false;
bool filter_peroc_ = false;
};

@ -33,6 +33,9 @@ extern void MatmulInt8DpNeon64(const int8_t *a, const int8_t *b, int8_t *dst, in
const int *a_sums, const int *bias, int act_min, int act_max, int out_zp,
int *multiplier, int *left_shift, int *right_shift, int row, int col, int stride,
size_t peroc);
extern void MatmulInt8DpOpt(const int8_t *a, const int8_t *b, int8_t *dst, size_t row8, size_t col8, size_t deep4,
const int *a_sums, const int *bias, int act_min, int act_max, int out_zp, int *multiplier,
int *left_shift, int *right_shift, size_t stride, size_t peroc, int *filter_zp);
#ifdef ENABLE_ARM64
void IndirectGemmInt8_optimize_handler(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
@ -57,6 +60,13 @@ void MatMulRInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst,
return MatmulInt8DpNeon64(a, b, dst, UP_ROUND(row, 8), UP_ROUND(col, 8), deep_4, input_sum, bias, mini, maxi,
output_zp, multiplier, left_shift, right_shift, row, col, stride, per_channel);
}
void MatMulDpInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini,
int32_t maxi, size_t per_channel, int32_t *filter_zp) {
return MatmulInt8DpOpt(a, b, dst, row, col, deep_4, input_sum, bias, mini, maxi, output_zp, multiplier, left_shift,
right_shift, stride, per_channel, filter_zp);
}
#endif
#ifdef __cplusplus

@ -33,6 +33,10 @@ void MatMulRInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini,
int32_t maxi, size_t per_channel);
void MatMulDpInt8_optimize_handler(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_4,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini,
int32_t maxi, size_t per_channel, int32_t *filter_zp);
#endif
#ifdef __cplusplus

Loading…
Cancel
Save