[MSLITE][Develop]Conv1x1 preTrans neon code -> .S

pull/6194/head
ling 4 years ago
parent 19874b83e7
commit 96d01f17ec

@@ -0,0 +1,130 @@
.text
.align 5
.global PreSum4x16Int8Peroc
#ifndef __APPLE__
.type PreSum4x16Int8Peroc, %function
#endif
//void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div2,
// size_t oc_res2, size_t stride);
// r0 src
// r1 sum
// r2 zp
// r3 hw4
// r4 ic16
// r5 oc_div2
// r6 oc_res2
// r7 stride
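// Per-channel variant: for each group of 4 rows it accumulates the signed sum of
// ic16 int8 inputs per row, then multiplies every row sum by each output-channel
// zero point zp[oc]. oc is handled 2 channels at a time (oc_div2 full pairs,
// oc_res2 = 0 or 1 leftover). With the stride passed by
// PackInputSum16x4PerChannelArm32 the output is laid out as oc2-major blocks of
// hw4 x 2 int32; `stride` is the byte gap from one 4x2 block to the same rows in
// the next oc2 block. hw4 and ic16 are assumed to be multiples of 4 and 16.
// ic16, oc_div2, oc_res2 and stride are passed on the stack and read into r4-r7
// after raising sp past the saved registers; sp is restored at End.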
PreSum4x16Int8Peroc:
push {r4-r8, r10, r11, lr}
vpush {q4-q7}
add sp, sp, #96
ldr r4, [sp]
ldr r5, [sp, #4]
ldr r6, [sp, #8]
ldr r7, [sp, #12]
mov r8, #0
mov r10, #8
RowLoop:
cmp r8, r3
beq End
add r8, r8, #4
vmov.s32 q13, #0
mov r9, #0
mov r11, r2
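// Sum: each iteration consumes a 4x16 int8 block (one 16-byte vector per row) and
// reduces it with pairwise widening adds; q13 accumulates the 4 per-row sums.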
Sum:
cmp r9, r4
beq Mul
add r9, r9, #16
vld1.8 {q0, q1}, [r0]!
vld1.8 {q2, q3}, [r0]!
vpaddl.s8 q4, q0
vpaddl.s8 q5, q1
vpaddl.s8 q6, q2
vpaddl.s8 q7, q3
vpaddl.s16 q0, q4
vpaddl.s16 q1, q5
vpaddl.s16 q2, q6
vpaddl.s16 q3, q7
vpaddl.s32 q4, q0
vpaddl.s32 q5, q1
vpaddl.s32 q6, q2
vpaddl.s32 q7, q3
vqmovn.s64 d0, q4
vqmovn.s64 d1, q5
vqmovn.s64 d2, q6
vqmovn.s64 d3, q7
vpaddl.s32 q4, q0
vpaddl.s32 q5, q1
vqmovn.s64 d0, q4
vqmovn.s64 d1, q5
vadd.i32 q13, q13, q0
b Sum
Mul:
mov r12, r1
add r1, r1, #32
mov r9, #0
vdup.32 d1, d26[0]
vdup.32 d2, d26[1]
vdup.32 d3, d27[0]
vdup.32 d4, d27[1]
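// Write: d1-d4 hold the 4 row sums broadcast; each iteration loads 2 zero points,
// stores a 4x2 int32 block of row_sum * zp, then skips `stride` bytes to the same
// rows in the next oc2 block.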
Write:
cmp r9, r5
beq OcRes
add r9, r9, #2
vld1.32 {d9}, [r11]!
vmul.i32 d5, d1, d9
vmul.i32 d6, d2, d9
vmul.i32 d7, d3, d9
vmul.i32 d8, d4, d9
vst1.32 {d5}, [r12], r10
vst1.32 {d6}, [r12], r10
vst1.32 {d7}, [r12], r10
vst1.32 {d8}, [r12], r10
add r12, r12, r7
b Write
OcRes:
cmp r6, #0
beq RowLoop
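// one output channel left: load its zero point into lane 0 (lane 1 stays zero)
// and reuse the same 4x2 store.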
vmov.s32 d9, #0
vld1.32 {d9[0]}, [r11] // zp is int32_t*, load the single remaining zero point
vmul.i32 d5, d1, d9
vmul.i32 d6, d2, d9
vmul.i32 d7, d3, d9
vmul.i32 d8, d4, d9
vst1.32 {d5}, [r12], r10
vst1.32 {d6}, [r12], r10
vst1.32 {d7}, [r12], r10
vst1.32 {d8}, [r12], r10
b RowLoop
End:
sub sp, sp, #96
vpop {q4-q7}
pop {r4-r8, r10, r11, pc}

@@ -0,0 +1,81 @@
.text
.align 5
.global PreSum4x16Int8Pert
#ifndef __APPLE__
.type PreSum4x16Int8Pert, %function
#endif
// void PreSum4x16Int8Pert(const int8_t *src, int32_t *sum, size_t row4, size_t col16, int32_t filter_zp);
// r0 src
// r1 sum
// r2 row4
// r3 col16
// r4 filter_zp
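// Per-tensor variant: a single zero point for all channels. For each group of 4
// rows it accumulates the signed sum of col16 int8 inputs per row and stores the
// 4 row sums multiplied by filter_zp as 4 consecutive int32 in sum.
// row4 and col16 are assumed to be multiples of 4 and 16; filter_zp is the fifth
// argument and is read from the stack into r4.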
PreSum4x16Int8Pert:
push {r4-r8, r10, r11, lr}
vpush {q4-q7}
add sp, sp, #96
ldr r4, [sp]
vdup.32 q10, r4
mov r5, #0
mov r7, #16
RowLoop:
cmp r5, r2
beq End
add r5, r5, #4
vmov.s32 q13, #0
mov r6, #0
CalLoop:
cmp r6, r3
beq Write
add r6, r6, #16
vld1.8 {q0, q1}, [r0]!
vld1.8 {q2, q3}, [r0]!
vpaddl.s8 q4, q0
vpaddl.s8 q5, q1
vpaddl.s8 q6, q2
vpaddl.s8 q7, q3
vpaddl.s16 q0, q4
vpaddl.s16 q1, q5
vpaddl.s16 q2, q6
vpaddl.s16 q3, q7
vpaddl.s32 q4, q0
vpaddl.s32 q5, q1
vpaddl.s32 q6, q2
vpaddl.s32 q7, q3
vqmovn.s64 d0, q4
vqmovn.s64 d1, q5
vqmovn.s64 d2, q6
vqmovn.s64 d3, q7
vpaddl.s32 q4, q0
vpaddl.s32 q5, q1
vqmovn.s64 d0, q4
vqmovn.s64 d1, q5
vadd.i32 q13, q13, q0
b CalLoop
Write:
vmul.i32 q13, q13, q10
vst1.32 {q13}, [r1], r7
b RowLoop
End:
sub sp, sp, #96
vpop {q4-q7}
pop {r4-r8, r10, r11, pc}

@@ -0,0 +1,129 @@
#ifdef __aarch64__
.text
.align 5
//.p2align 5,,15
.global PreSum4x16Int8Peroc
#ifndef __APPLE__
.type PreSum4x16Int8Peroc, %function
#endif
//void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div4,
// size_t oc_res4, size_t stride);
// x0 src
// x1 sum
// x2 zp
// w3 hw4
// w4 ic16
// w5 oc_div4
// w6 oc_res4
// w7 stride
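// Per-channel variant (arm64): for each group of 4 rows it accumulates the signed
// sum of ic16 int8 inputs per row, then multiplies every row sum by each
// output-channel zero point zp[oc]. oc is handled 4 channels at a time. With the
// stride passed by PackInputSum16x4PerChannel the output is laid out as oc4-major
// blocks of hw4 x 4 int32; `stride` is the byte gap from one 4x4 block to the
// same rows in the next oc4 block. hw4 and ic16 are assumed to be multiples of 4
// and 16.
// A scalar sketch of the same computation, assuming the layout above
// (indices in elements, not bytes):
//   for (int r = 0; r < hw4; r += 4)
//     for (int i = 0; i < 4; i++) {
//       int32_t row_sum = 0;
//       for (int c = 0; c < ic16; c++)
//         row_sum += src[r * ic16 + (c / 16) * 64 + i * 16 + (c % 16)];
//       for (int oc = 0; oc < oc_div4 + oc_res4; oc++)
//         sum[(oc / 4) * hw4 * 4 + (r + i) * 4 + (oc % 4)] = row_sum * zp[oc];
//     }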
PreSum4x16Int8Peroc:
mov w8, #0
RowLoop:
cmp w8, w3
beq End
add w8, w8, #4
dup v16.4s, wzr
mov w9, #0
mov x16, x2
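// Sum: each iteration consumes a 4x16 int8 block (one 16-byte vector per row);
// saddlp widens 16 bytes -> 8 halfwords -> 4 words, addv finishes the horizontal
// add, and v16 accumulates the 4 per-row sums.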
Sum:
cmp w9, w4
beq Mul
add w9, w9, #16
ld1 {v0.16b}, [x0], #16
ld1 {v1.16b}, [x0], #16
ld1 {v2.16b}, [x0], #16
ld1 {v3.16b}, [x0], #16
saddlp v4.8h, v0.16b
saddlp v5.8h, v1.16b
saddlp v6.8h, v2.16b
saddlp v7.8h, v3.16b
saddlp v0.4S, v4.8h
saddlp v1.4S, v5.8h
saddlp v2.4S, v6.8h
saddlp v3.4S, v7.8h
addv s4, v0.4S
addv s5, v1.4S
addv s6, v2.4S
addv s7, v3.4S
mov v0.s[0], v4.s[0]
mov v0.s[1], v5.s[0]
mov v0.s[2], v6.s[0]
mov v0.s[3], v7.s[0]
add v16.4s, v16.4s, v0.4s
b Sum
Mul:
mov x12, x1
add x1, x1, #64
mov w9, #0
dup v1.4s, v16.s[0]
dup v2.4s, v16.s[1]
dup v3.4s, v16.s[2]
dup v4.4s, v16.s[3]
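// WriteOc4: v1-v4 hold the 4 row sums broadcast; each iteration loads 4 zero
// points, stores a 4x4 int32 block of row_sum * zp, then skips `stride` bytes to
// the same rows in the next oc4 block.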
WriteOc4:
cmp w9, w5
beq OcRes4
add w9, w9, #4
ld1 {v5.4s}, [x16], #16
mul v16.4s, v5.4s, v1.4s
mul v17.4s, v5.4s, v2.4s
mul v18.4s, v5.4s, v3.4s
mul v19.4s, v5.4s, v4.4s
st1 {v16.4s}, [x12], #16
st1 {v17.4s}, [x12], #16
st1 {v18.4s}, [x12], #16
st1 {v19.4s}, [x12], #16
add x12, x12, x7
b WriteOc4
OcRes4:
cmp w6, #0
beq RowLoop
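// 1 to 3 output channels left: load that many int32 zero points into v15 (the
// remaining lanes stay zero) and reuse the same 4x4 store.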
dup v15.4s, wzr
cmp w6, #1
beq OcRes4_1
cmp w6, #2
beq OcRes4_2
cmp w6, #3
beq OcRes4_3
OcRes4_1:
ld1 {v15.s}[0], [x16]
b OcRes4End
OcRes4_2:
ld1 {v15.d}[0], [x16] // two int32 zero points remain, load both
b OcRes4End
OcRes4_3:
ld1 {v15.d}[0], [x16] // first two of the three remaining zero points
add x16, x16, #8
ld1 {v15.s}[2], [x16]
b OcRes4End
OcRes4End:
mul v16.4s, v15.4s, v1.4s
mul v17.4s, v15.4s, v2.4s
mul v18.4s, v15.4s, v3.4s
mul v19.4s, v15.4s, v4.4s
st1 {v16.4s}, [x12], #16
st1 {v17.4s}, [x12], #16
st1 {v18.4s}, [x12], #16
st1 {v19.4s}, [x12], #16
b RowLoop
End:
ret
#endif

@@ -0,0 +1,70 @@
#ifdef __aarch64__
.text
.align 5
//.p2align 5,,15
.global PreSum4x16Int8Pert
#ifndef __APPLE__
.type PreSum4x16Int8Pert, %function
#endif
// void PreSum4x16Int8Pert(const int8_t *src, int32_t *dst, size_t row4, size_t col16, int32_t filter_zp);
// x0 src
// x1 dst
// w2 row4
// w3 co16
// w4 filter_zp
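// Per-tensor variant (arm64): a single zero point for all channels. For each
// group of 4 rows it accumulates the signed sum of col16 int8 inputs per row and
// stores the 4 row sums multiplied by filter_zp as 4 consecutive int32 in dst.
// row4 and col16 are assumed to be multiples of 4 and 16.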
PreSum4x16Int8Pert:
dup v17.4s, w4
mov w5, #0
RowLoop:
cmp w5, w2
beq End
add w5, w5, #4
dup v16.4s, wzr
mov w6, #0
CalLoop:
cmp w6, w3
beq Write
add w6, w6, #16
ld1 {v0.16b}, [x0], #16
ld1 {v1.16b}, [x0], #16
ld1 {v2.16b}, [x0], #16
ld1 {v3.16b}, [x0], #16
saddlp v4.8h, v0.16b
saddlp v5.8h, v1.16b
saddlp v6.8h, v2.16b
saddlp v7.8h, v3.16b
saddlp v0.4S, v4.8h
saddlp v1.4S, v5.8h
saddlp v2.4S, v6.8h
saddlp v3.4S, v7.8h
addv s4, v0.4S
addv s5, v1.4S
addv s6, v2.4S
addv s7, v3.4S
mov v0.s[0], v4.s[0]
mov v0.s[1], v5.s[0]
mov v0.s[2], v6.s[0]
mov v0.s[3], v7.s[0]
add v16.4s, v16.4s, v0.4s
b CalLoop
Write:
mul v16.4s, v16.4s, v17.4s
st1 {v16.4s}, [x1], #16
b RowLoop
End:
ret
#endif

@@ -1029,6 +1029,14 @@ void Conv1x1Int8Arm32(const int8_t *packed_input, const int8_t *packed_weight, i
const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param) {
int is_per_channel = conv_param->conv_quant_arg_.filter_arg_num_ != 1 ? true : false;
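/* per-channel filter quantization: use the 4x2 tiled matmul, which consumes the per-oc left/right shift and multiplier arrays */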
if (is_per_channel == 1) {
return MatMulInt8_4x2_r(
packed_input, packed_weight, dst, row, col, deep16, conv_param->output_channel_, input_sum, bias, left_shift,
right_shift, multiplier, conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0], true);
}
#ifdef ENABLE_ARM32
MatmulInt8Neon32(packed_input, packed_weight, dst, row, col, deep16, input_sum, bias,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0],

@@ -117,10 +117,10 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) {
for (int ri = 0; ri < row_4div; ri += C4NUM) {
for (int ci = 0; ci < col_16div; ci += C16NUM) {
#ifdef ENABLE_ARM64
size_t col_offset = col;
int8_t *src_c = src_r + ci;
int8_t *dst_c = dst_r + ci * C4NUM;
#ifdef ENABLE_ARM64
asm volatile(
"mov x10, %[src_c] \n"
"mov x11, %[dst_c] \n"
@@ -138,8 +138,28 @@ void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) {
:
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ col_offset ] "r"(col_offset)
: "x10", "x11", "v0", "v1", "v2", "v3");
#elif ENABLE_ARM32
asm volatile(
"mov r0, %[src_c] \n"
"mov r1, %[dst_c] \n"
"mov r2, %[col_offset] \n"
"mov r3, #16 \n"
"vld1.8 {q0}, [r0], r2 \n"
"vld1.8 {q1}, [r0], r2 \n"
"vld1.8 {q2}, [r0], r2 \n"
"vld1.8 {q3}, [r0], r2 \n"
"vst1.32 {q0}, [r1], r3 \n"
"vst1.32 {q1}, [r1], r3 \n"
"vst1.32 {q2}, [r1], r3 \n"
"vst1.32 {q3}, [r1], r3 \n"
:
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ col_offset ] "r"(col_offset)
: "r0", "r1", "r2", "r3", "q0", "q1", "q2", "q3");
#else
MatrixPack4x16UnitInt8(src_r + ci, dst_r + ci * C4NUM, C4NUM, C16NUM, col);
MatrixPack4x16UnitInt8(src_c, dst_c, C4NUM, C16NUM, col_offset);
#endif
}

@@ -189,63 +189,8 @@ void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParam
void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16) {
/* normal matmul : 4x16 * 16x4 -> 4x4 */
#ifdef ENABLE_ARM64
asm volatile(
"mov x10, %[src] \n"
"mov x11, %[dst] \n"
"dup v15.4s, %w[filter_zp] \n"
"mov x0, #0 \n"
"1: \n"
"cmp x0, %[row4] \n"
"beq 4f \n"
"add x0, x0, #4\n"
"dup v10.4s, wzr \n"
"mov x2, #0 \n"
"2: \n"
"cmp x2, %[col16] \n"
"beq 3f \n"
"add x2, x2, #16\n"
"ld1 {v0.16b}, [x10], #16\n"
"ld1 {v1.16b}, [x10], #16\n"
"ld1 {v2.16b}, [x10], #16\n"
"ld1 {v3.16b}, [x10], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v6.8h, v2.16b \n"
"saddlp v7.8h, v3.16b \n"
"saddlp v0.4S, v4.8h \n"
"saddlp v1.4S, v5.8h \n"
"saddlp v2.4S, v6.8h \n"
"saddlp v3.4S, v7.8h \n"
"addv s4, v0.4S \n"
"addv s5, v1.4S \n"
"addv s6, v2.4S \n"
"addv s7, v3.4S \n"
"mov v0.s[0], v4.s[0] \n"
"mov v0.s[1], v5.s[0] \n"
"mov v0.s[2], v6.s[0] \n"
"mov v0.s[3], v7.s[0] \n"
"add v10.4s, v10.4s, v0.4s \n"
"b 2b\n"
"3: \n"
"mul v10.4s, v10.4s, v15.4s \n"
"st1 {v10.4s}, [x11], #16 \n"
"beq 1b \n"
"4: \n"
:
: [ dst ] "r"(dst), [ src ] "r"(src), [ row4 ] "r"(row4), [ col16 ] "r"(col16), [ filter_zp ] "r"(filter_zp)
: "x0", "x1", "x2", "x3", "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v15");
#ifdef ENABLE_ARM
PreSum4x16Int8Pert(src, dst, row4, col16, filter_zp);
#else
for (int r = 0; r < row4; r++) {
int32_t tmp_value = 0;
@@ -268,121 +213,7 @@ void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, i
size_t oc_div4 = output_channel / C4NUM * C4NUM;
size_t oc_res4 = output_channel - oc_div4;
size_t inputsun_stride = hw4 * C4NUM * 4 - C4NUM * C4NUM * 4;
asm volatile(
"mov x10, %[input_value] \n"
"mov x11, %[input_sum] \n"
"mov x15, %[filter_zp_ptr] \n"
"mov x0, #0 \n"
"1: \n"
"cmp x0, %[hw4] \n"
"beq 11f \n"
"add x0, x0, #4\n"
"dup v10.4s, wzr \n"
"mov x2, #0 \n"
"mov x16, x15 \n"
"2: \n"
"cmp x2, %[ic16] \n"
"beq 3f \n"
"add x2, x2, #16 \n"
"ld1 {v0.16b}, [x10], #16\n"
"ld1 {v1.16b}, [x10], #16\n"
"ld1 {v2.16b}, [x10], #16\n"
"ld1 {v3.16b}, [x10], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v6.8h, v2.16b \n"
"saddlp v7.8h, v3.16b \n"
"saddlp v0.4S, v4.8h \n"
"saddlp v1.4S, v5.8h \n"
"saddlp v2.4S, v6.8h \n"
"saddlp v3.4S, v7.8h \n"
"addv s4, v0.4S \n"
"addv s5, v1.4S \n"
"addv s6, v2.4S \n"
"addv s7, v3.4S \n"
"mov v0.s[0], v4.s[0] \n"
"mov v0.s[1], v5.s[0] \n"
"mov v0.s[2], v6.s[0] \n"
"mov v0.s[3], v7.s[0] \n"
"add v10.4s, v10.4s, v0.4s \n"
"b 2b \n"
"3: \n"
"mov x12, x11 \n"
"add x11, x11, #64 \n"
"mov x4, #0 \n"
"dup v1.4s, v10.s[0] \n"
"dup v2.4s, v10.s[1] \n"
"dup v3.4s, v10.s[2] \n"
"dup v4.4s, v10.s[3] \n"
"4: \n"
"cmp x4, %[oc_div4] \n"
"beq 6f \n"
"add x4, x4, #4\n"
"ld1 {v15.4s}, [x16], #16\n"
"mul v16.4s, v15.4s, v1.4s \n"
"mul v17.4s, v15.4s, v2.4s \n"
"mul v18.4s, v15.4s, v3.4s \n"
"mul v19.4s, v15.4s, v4.4s \n"
"st1 {v16.4s}, [x12], #16 \n"
"st1 {v17.4s}, [x12], #16 \n"
"st1 {v18.4s}, [x12], #16 \n"
"st1 {v19.4s}, [x12], #16 \n"
"add x12, x12, %[inputsun_stride] \n"
"b 4b \n"
"6: \n"
"cmp %[oc_res4], #0\n"
"beq 1b \n"
"dup v15.4s, wzr \n"
"cmp %[oc_res4], #1\n"
"beq 7f \n"
"cmp %[oc_res4], #2\n"
"beq 8f \n"
"cmp %[oc_res4], #3\n"
"beq 9f \n"
"7: \n"
"ld1 {v15.s}[0], [x16] \n"
"b 10f \n"
"8: \n"
"ld1 {v15.h}[0], [x16] \n"
"b 10f \n"
"9: \n"
"ld1 {v15.h}[0], [x16] \n"
"add x16, x16, #8 \n"
"ld1 {v15.s}[2], [x16] \n"
"b 10f \n"
"10: \n"
"mul v16.4s, v15.4s, v1.4s \n"
"mul v17.4s, v15.4s, v2.4s \n"
"mul v18.4s, v15.4s, v3.4s \n"
"mul v19.4s, v15.4s, v4.4s \n"
"st1 {v16.4s}, [x12], #16 \n"
"st1 {v17.4s}, [x12], #16 \n"
"st1 {v18.4s}, [x12], #16 \n"
"st1 {v19.4s}, [x12], #16 \n"
"b 1b \n"
"11: \n"
:
: [ input_value ] "r"(input_value), [ input_sum ] "r"(input_sum), [ filter_zp_ptr ] "r"(filter_zp_ptr),
[ hw4 ] "r"(hw4), [ ic16 ] "r"(ic16), [ oc_div4 ] "r"(oc_div4), [ oc_res4 ] "r"(oc_res4),
[ inputsun_stride ] "r"(inputsun_stride)
: "x0", "x2", "x4", "x10", "x11", "x12", "x15", "x16", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v15",
"v16", "v17", "v18", "v19");
PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div4, oc_res4, inputsun_stride);
#else
for (int ri = 0; ri < plane_size; ri++) {
@@ -409,6 +240,12 @@ void PackInputSum16x4PerChannelArm32(const int8_t *input_value, int32_t *input_s
size_t hw4 = UP_ROUND(plane_size, C4NUM);
size_t ic16 = UP_ROUND(input_channel, C16NUM);
#ifdef ENABLE_ARM32
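/* arm32 handles 2 output channels per block; inputsun_stride is the byte gap from one 4x2 input_sum block to the same rows in the next oc2 block */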
size_t oc_div2 = output_channel / C2NUM * C2NUM;
size_t oc_res2 = output_channel - oc_div2;
size_t inputsun_stride = hw4 * C2NUM * 4 - C4NUM * C2NUM * 4;
PreSum4x16Int8Peroc(input_value, input_sum, filter_zp_ptr, hw4, ic16, oc_div2, oc_res2, inputsun_stride);
#else
for (int ri = 0; ri < plane_size; ri++) {
int ri4div = ri / C4NUM, ri4mod = ri % C4NUM;
for (int ci = 0; ci < output_channel; ci++) {
@@ -424,6 +261,7 @@ void PackInputSum16x4PerChannelArm32(const int8_t *input_value, int32_t *input_s
input_sum[dst_index] = tmp_sum_value * filter_zp;
}
}
#endif
return;
}

@@ -121,6 +121,13 @@ void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight
void PackDeconvDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, int plane, int channel,
ConvQuantArg *quant_qrg);
#ifdef ENABLE_ARM
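/* NEON pre-computation of the int8 conv1x1 input sums (per-tensor and per-channel variants), implemented in arm32/arm64 assembly */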
void PreSum4x16Int8Pert(const int8_t *src, int32_t *sum, size_t row4, size_t col16, int32_t filter_zp);
void PreSum4x16Int8Peroc(const int8_t *src, int32_t *sum, int32_t *zp, size_t hw4, size_t ic16, int32_t oc_div,
size_t oc_res, size_t stride);
#endif
#ifdef __cplusplus
}
#endif

@@ -71,7 +71,7 @@ void Convolution1x1Int8CPUKernel::FreeResizeBuf() {
}
void Convolution1x1Int8CPUKernel::CheckSupportOptimize() {
support_optimize_ = true;
support_optimize_ = false;
matmul_func_ = MatMulInt8_8x8_r;
#ifdef ENABLE_ARM64
void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_;
@@ -94,7 +94,7 @@ void Convolution1x1Int8CPUKernel::CheckSupportOptimize() {
return;
}
int Convolution1x1Int8CPUKernel::InitBiasByzp(void *src_weight, int input_channel, int output_channel) {
int Convolution1x1Int8CPUKernel::InitBiasByzp(void *src_weight, int input_channel, int output_channel, int round_oc) {
/* bias = bias - v2 x zp1 + zp1 x zp2 */
int32_t *bias_data = reinterpret_cast<int32_t *>(bias_data_);
int8_t *weight = reinterpret_cast<int8_t *>(src_weight);
@@ -118,24 +118,23 @@ int Convolution1x1Int8CPUKernel::InitBiasByzp(void *src_weight, int input_channe
filter_zp_ptr_[fi] = conv_param_->conv_quant_arg_.filter_quant_args_[fi].zp_;
}
int up_round_oc_size = support_optimize_ ? UP_ROUND(output_channel, C8NUM) : UP_ROUND(output_channel, C4NUM);
left_shift_ = reinterpret_cast<int32_t *>(malloc(up_round_oc_size * sizeof(int32_t)));
left_shift_ = reinterpret_cast<int32_t *>(malloc(round_oc * sizeof(int32_t)));
if (left_shift_ == nullptr) {
return RET_ERROR;
}
memset(left_shift_, 0, up_round_oc_size * sizeof(int32_t));
memset(left_shift_, 0, round_oc * sizeof(int32_t));
memcpy(left_shift_, conv_param_->conv_quant_arg_.left_shift_, output_channel * sizeof(int32_t));
right_shift_ = reinterpret_cast<int32_t *>(malloc(up_round_oc_size * sizeof(int32_t)));
right_shift_ = reinterpret_cast<int32_t *>(malloc(round_oc * sizeof(int32_t)));
if (right_shift_ == nullptr) {
return RET_ERROR;
}
memset(right_shift_, 0, up_round_oc_size * sizeof(int32_t));
memset(right_shift_, 0, round_oc * sizeof(int32_t));
memcpy(right_shift_, conv_param_->conv_quant_arg_.right_shift_, output_channel * sizeof(int32_t));
multiplier_ = reinterpret_cast<int32_t *>(malloc(up_round_oc_size * sizeof(int32_t)));
multiplier_ = reinterpret_cast<int32_t *>(malloc(round_oc * sizeof(int32_t)));
if (multiplier_ == nullptr) {
return RET_ERROR;
}
memset(multiplier_, 0, up_round_oc_size * sizeof(int32_t));
memset(multiplier_, 0, round_oc * sizeof(int32_t));
memcpy(multiplier_, conv_param_->conv_quant_arg_.quant_multiplier_, output_channel * sizeof(int32_t));
}
return RET_OK;
@@ -165,18 +164,18 @@ int Convolution1x1Int8CPUKernel::InitWeightBias() {
int col4 = UP_ROUND(output_channel, C4NUM);
int col8 = UP_ROUND(output_channel, C8NUM);
size = support_optimize_ ? col8 * sizeof(int32_t) : col4 * sizeof(int32_t);
bias_data_ = malloc(size);
size = support_optimize_ ? col8 : col4;
bias_data_ = malloc(size * sizeof(int32_t));
if (bias_data_ == nullptr) {
MS_LOG(ERROR) << "Conv1x1 int8 Malloc bias_ptr_ error!";
return RET_ERROR;
}
memset(bias_data_, 0, size);
memset(bias_data_, 0, size * sizeof(int32_t));
if (in_tensors_.size() == 3) {
memcpy(bias_data_, in_tensors_[kBiasIndex]->MutableData(), output_channel * sizeof(int32_t));
}
InitBiasByzp(filter_tensor->MutableData(), input_channel, output_channel);
InitBiasByzp(filter_tensor->MutableData(), input_channel, output_channel, size);
return RET_OK;
}
@@ -208,7 +207,7 @@ int Convolution1x1Int8CPUKernel::InitWeightBiasArm32() {
memcpy(bias_data_, in_tensors_[kBiasIndex]->MutableData(), output_channel * sizeof(int32_t));
}
InitBiasByzp(filter_tensor->MutableData(), input_channel, output_channel);
InitBiasByzp(filter_tensor->MutableData(), input_channel, output_channel, UP_ROUND(output_channel, C2NUM));
return RET_OK;
}
@@ -342,6 +341,12 @@ int Convolution1x1Int8CPUKernel::RunImpl(int task_id) {
if (cur_oc <= 0) {
return RET_OK;
}
if (filter_peroc_) {
cur_input_sum = input_sum_ + task_id * matmul_param_->row_4_ * thread_stride_ * C2NUM;
cur_left_shift = left_shift_ + task_id * thread_stride_ * C2NUM;
cur_right_shift = right_shift_ + task_id * thread_stride_ * C2NUM;
cur_multiplier = multiplier_ + task_id * thread_stride_ * C2NUM;
}
Conv1x1Int8Arm32(packed_input_, packed_weight_ + task_id * thread_stride_ * C2NUM * matmul_param_->deep_16_,
output_ptr_ + task_id * thread_stride_ * C2NUM, cur_input_sum,
reinterpret_cast<int32_t *>(bias_data_) + task_id * thread_stride_ * C2NUM, matmul_param_->row_,

@@ -55,7 +55,7 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel {
int InitWeightBiasArm32();
void Pre1x1Trans(int8_t *src_input, int8_t *src_output);
void CheckSupportOptimize();
int InitBiasByzp(void *src_weight, int input_channel, int output_channel);
int InitBiasByzp(void *src_weight, int input_channel, int output_channel, int round_oc);
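// round_oc: output channel count rounded up to the packing unit (C8/C4 on arm64 depending on support_optimize_, C2 on arm32)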
private:
int32_t *input_sum_ = nullptr; /* per-oc: oc4 format */
