!4849 [MS][LITE][Develop]Deconv int8 neon code

Merge pull request !4849 from ling/deconv
pull/4849/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 01017492b2

@@ -15,9 +15,7 @@
*/
#include "src/runtime/kernel/arm/int8/deconvolution_int8.h"
#include "src/runtime/kernel/arm/nnacl/quantization/fixed_point.h"
#include "src/runtime/runtime_api.h"
#include "src/kernel_registry.h"
#include "src/runtime/kernel/arm/nnacl/optimized_kernel.h"
using mindspore::kernel::KERNEL_ARCH::kCPU;
@@ -89,9 +87,8 @@ int DeConvInt8CPUKernel::Init() {
}
void DeConvInt8CPUKernel::CheckSupportOptimize() {
matmul_func_ = nullptr;
support_optimize_ = true;
matmul_func_ = MatMulInt8_16x4;
#ifdef ENABLE_ARM64
void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_;
if (optimize_op_handler != nullptr) {
@@ -102,12 +99,15 @@ void DeConvInt8CPUKernel::CheckSupportOptimize() {
MS_LOG(ERROR) << "load matmul func failed! " << dlopen_error << ".";
support_optimize_ = false;
matmul_func_ = nullptr;
} else {
support_optimize_ = true;
}
} else {
support_optimize_ = false;
matmul_func_ = nullptr;
}
#endif
return;
}
int DeConvInt8CPUKernel::InitParam() {
@@ -120,6 +120,7 @@ int DeConvInt8CPUKernel::InitParam() {
matmul_param_->deep_ = conv_param_->input_channel_;
matmul_param_->col_ = conv_param_->output_channel_ * conv_param_->kernel_h_ * conv_param_->kernel_w_;
/* optimize normal -> same data layout */
input_trans_func_ = RowMajor2Row16x4MajorInt8;
size_t oc4 = UP_DIV(conv_param_->output_channel_, C4NUM);
thread_count_ = MSMIN(op_parameter_->thread_num_, oc4);
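The block-size bookkeeping above leans on the nnacl rounding helpers. As a minimal sketch, assuming the usual definitions from nnacl/op_base.h (C4NUM, C16NUM, UP_DIV, UP_ROUND and MSMIN are taken on that assumption, they are not part of this diff):

#define C4NUM 4
#define C16NUM 16
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))          /* ceiling division                */
#define UP_ROUND(x, y) (((x) + (y) - 1) / (y) * (y))  /* round x up to a multiple of y   */
#define MSMIN(x, y) ((x) < (y) ? (x) : (y))

/* Worked example: output_channel_ = 10  ->  oc4 = UP_DIV(10, C4NUM) = 3 blocks of 4 channels,
 * padded column count = UP_ROUND(10, C4NUM) = 12, thread_count_ = MSMIN(thread_num_, 3). */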

@@ -0,0 +1,246 @@
#ifdef __aarch64__
.text
.align 5
//.p2align 5,,15
.global PostFuncInt8C4Neon64
#ifndef __APPLE__
.type PostFuncInt8C4Neon64, %function
#endif
//void PostFuncInt8C4Neon64(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc4div, size_t oc4res,
// size_t plane, size_t stride, int32_t multiplier, int32_t left_shift, int32_t right_shift,
// int32_t zp, int32_t mini, int32_t maxi);
// x0 in
// x1 bias
// x2 out
// x3 oc4div
// x4 oc4res
// x5 plane
// x6 stride
// x7 multiplier
// x8 left_shift
// x9 right_shift
// x10 zp
// x11 mini
// x12 maxi
// v0 ~ v15 value
// x24 x25 write loop tmp buf
// v16 bias data
// v26 multiplier
// v27 left_shift
// v28 right_shift
// v29 zp
// v30 min
// v31 max
// w15 oc4 loop control
// w16 hw loop control
PostFuncInt8C4Neon64:
ldr w8, [sp]
ldr w9, [sp, #8]
ldr w10, [sp, #16]
ldr w11, [sp, #24]
ldr w12, [sp, #32]
ldr w13, [sp, #40]
dup v26.4s, w7
dup v27.4s, w8
dup v28.4s, w9
dup v29.4s, w10
dup v30.4s, w11
dup v31.4s, w12
mov w15, #0
Loop_C4:
cmp w15, w3
beq Loop_C1
mov x25, #4
mul x24, x15, x25
add x25, x2, x24
add w15, w15, #4
mov w16, w5
ld1 {v16.4s}, [x1], #16
Loop_4x4:
cmp w16, #4
blt Loop_1x4
sub w16, w16, #4
ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64
add v0.4s, v0.4s, v16.4s
add v1.4s, v1.4s, v16.4s
add v2.4s, v2.4s, v16.4s
add v3.4s, v3.4s, v16.4s
sqshl v0.4s, v0.4s, v27.4s
sqshl v1.4s, v1.4s, v27.4s
sqshl v2.4s, v2.4s, v27.4s
sqshl v3.4s, v3.4s, v27.4s
sqrdmulh v0.4s, v0.4s, v26.4s
sqrdmulh v1.4s, v1.4s, v26.4s
sqrdmulh v2.4s, v2.4s, v26.4s
sqrdmulh v3.4s, v3.4s, v26.4s
and v4.16b, v28.16b, v0.16b
and v5.16b, v28.16b, v1.16b
and v6.16b, v28.16b, v2.16b
and v7.16b, v28.16b, v3.16b
sshr v4.4s, v4.4s, #31
sshr v5.4s, v5.4s, #31
sshr v6.4s, v6.4s, #31
sshr v7.4s, v7.4s, #31
sqadd v0.4s, v0.4s, v4.4s
sqadd v1.4s, v1.4s, v5.4s
sqadd v2.4s, v2.4s, v6.4s
sqadd v3.4s, v3.4s, v7.4s
srshl v0.4s, v0.4s, v28.4s
srshl v1.4s, v1.4s, v28.4s
srshl v2.4s, v2.4s, v28.4s
srshl v3.4s, v3.4s, v28.4s
add v0.4s, v0.4s, v29.4s
add v1.4s, v1.4s, v29.4s
add v2.4s, v2.4s, v29.4s
add v3.4s, v3.4s, v29.4s
smax v0.4s, v0.4s, v30.4s
smax v1.4s, v1.4s, v30.4s
smax v2.4s, v2.4s, v30.4s
smax v3.4s, v3.4s, v30.4s
smin v0.4s, v0.4s, v31.4s
smin v1.4s, v1.4s, v31.4s
smin v2.4s, v2.4s, v31.4s
smin v3.4s, v3.4s, v31.4s
sqxtn v4.4h, v0.4s
sqxtn v5.4h, v1.4s
sqxtn v6.4h, v2.4s
sqxtn v7.4h, v3.4s
sqxtn v0.8b, v4.8h
sqxtn v1.8b, v5.8h
sqxtn v2.8b, v6.8h
sqxtn v3.8b, v7.8h
st1 {v0.s}[0], [x2], x6
st1 {v1.s}[0], [x2], x6
st1 {v2.s}[0], [x2], x6
st1 {v3.s}[0], [x2], x6
b Loop_4x4
Loop_1x4:
cmp w16, #0
beq Loop_C4
sub w16, w16, #1
ld1 {v0.4s}, [x0], #16
add v0.4s, v0.4s, v16.4s
sqshl v0.4s, v0.4s, v27.4s
sqrdmulh v0.4s, v0.4s, v26.4s
and v2.16b, v28.16b, v0.16b
sshr v2.4s, v2.4s, #31
sqadd v0.4s, v0.4s, v2.4s
srshl v0.4s, v0.4s, v28.4s
add v0.4s, v0.4s, v29.4s
smax v0.4s, v0.4s, v30.4s
smin v0.4s, v0.4s, v31.4s
sqxtn v1.4h, v0.4s
sqxtn v0.8b, v1.8h
st1 {v0.s}[0], [x2], x6
b Loop_1x4
Loop_C1:
cmp x4, #0
beq End
mov w16, w5
ld1 {v16.4s}, [x1], #16
mov x25, #4
mul x24, x15, x25
add x25, x2, x24
add x24, x25, #2
cmp x4, #1
beq Loop_C1_1
cmp x4, #2
beq Loop_C1_2
cmp x4, #3
beq Loop_C1_3
Loop_C1_1:
cmp w16, #0
beq End
sub w16, w16, #1
ld1 {v0.4s}, [x0], #16
add v0.4s, v0.4s, v16.4s
sqshl v0.4s, v0.4s, v27.4s
sqrdmulh v0.4s, v0.4s, v26.4s
and v2.16b, v28.16b, v0.16b
sshr v2.4s, v2.4s, #31
sqadd v0.4s, v0.4s, v2.4s
srshl v0.4s, v0.4s, v28.4s
add v0.4s, v0.4s, v29.4s
smax v0.4s, v0.4s, v30.4s
smin v0.4s, v0.4s, v31.4s
sqxtn v1.4h, v0.4s
sqxtn v0.8b, v1.8h
st1 {v0.b}[0], [x25], x6
b Loop_C1_1
Loop_C1_2:
cmp w16, #0
beq End
sub w16, w16, #1
ld1 {v0.4s}, [x0], #16
add v0.4s, v0.4s, v16.4s
sqshl v0.4s, v0.4s, v27.4s
sqrdmulh v0.4s, v0.4s, v26.4s
and v2.16b, v28.16b, v0.16b
sshr v2.4s, v2.4s, #31
sqadd v0.4s, v0.4s, v2.4s
srshl v0.4s, v0.4s, v28.4s
add v0.4s, v0.4s, v29.4s
smax v0.4s, v0.4s, v30.4s
smin v0.4s, v0.4s, v31.4s
sqxtn v1.4h, v0.4s
sqxtn v0.8b, v1.8h
st1 {v0.h}[0], [x25], x6
b Loop_C1_2
Loop_C1_3:
cmp w16, #0
beq End
sub w16, w16, #1
ld1 {v0.4s}, [x0], #16
add v0.4s, v0.4s, v16.4s
sqshl v0.4s, v0.4s, v27.4s
sqrdmulh v0.4s, v0.4s, v26.4s
and v2.16b, v28.16b, v0.16b
sshr v2.4s, v2.4s, #31
sqadd v0.4s, v0.4s, v2.4s
srshl v0.4s, v0.4s, v28.4s
add v0.4s, v0.4s, v29.4s
smax v0.4s, v0.4s, v30.4s
smin v0.4s, v0.4s, v31.4s
sqxtn v1.4h, v0.4s
sqxtn v0.8b, v1.8h
st1 {v0.h}[0], [x25], x6
st1 {v0.b}[2], [x24], x6
b Loop_C1_3
End:
ret
#endif
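For readers following the assembly, each lane of PostFuncInt8C4Neon64 applies a standard fixed-point requantization. Below is a scalar C sketch of that per-lane sequence; the names RequantizeOneLane, SatRoundDoublingHighMul and RoundingDivideByPOT are illustrative (not the fixed_point.h API), sqshl saturation is omitted, and it assumes right_shift is carried as the negative srshl shift count, so -right_shift is the actual exponent:

#include <stdint.h>

static int32_t SatRoundDoublingHighMul(int32_t a, int32_t b) {   /* models sqrdmulh */
  if (a == INT32_MIN && b == INT32_MIN) return INT32_MAX;        /* the only saturating case */
  int64_t prod = (int64_t)a * (int64_t)b;
  int32_t nudge = prod >= 0 ? (1 << 30) : (1 - (1 << 30));
  return (int32_t)((prod + nudge) / (1ll << 31));
}

static int32_t RoundingDivideByPOT(int32_t x, int exponent) {    /* models and/sshr/sqadd + srshl */
  int32_t mask = ((int32_t)1 << exponent) - 1;
  int32_t remainder = x & mask;
  int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold ? 1 : 0);
}

static int8_t RequantizeOneLane(int32_t acc, int32_t bias, int32_t multiplier, int32_t left_shift,
                                int32_t right_shift, int32_t zp, int32_t mini, int32_t maxi) {
  int32_t v = (acc + bias) * (1 << left_shift);       /* add v16, sqshl by v27 */
  v = SatRoundDoublingHighMul(v, multiplier);         /* sqrdmulh by v26       */
  v = RoundingDivideByPOT(v, -right_shift);           /* rounding right shift  */
  v += zp;                                            /* add v29               */
  if (v < mini) v = mini;                             /* smax v30              */
  if (v > maxi) v = maxi;                             /* smin v31              */
  return (int8_t)v;                                   /* sqxtn + sqxtn         */
}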

@@ -15,9 +15,10 @@
*/
#include "nnacl/int8/common_func.h"
#include "nnacl/quantization/fixed_point.h"
void PostConvFuncCommInt8(const int32_t *in, int8_t *out, const int32_t *bias, size_t oc, size_t plane,
size_t out_oc_stride, size_t in_plane_stride, int32_t multiplier, int8_t mini, int8_t maxi,
size_t out_oc_stride, size_t in_plane_stride, int32_t multiplier, int32_t mini, int32_t maxi,
int32_t left_shift, int32_t right_shift, int32_t zp, int size) {
if (size == 0) {
return;
@@ -40,18 +41,26 @@ void PostConvFuncCommInt8(const int32_t *in, int8_t *out, const int32_t *bias, s
return;
}
void PostFuncInt8C8(const int *in, const int *bias, int8_t *out, int oc, int plane, int32_t multiplier,
int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi) {
void PostFuncInt8C8(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, int32_t multiplier,
int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini, int32_t maxi) {
/* ((int32_t)row8x8-major + bias) * multiplier + output_zp => (int8)relu => (int8_t)row-major */
PostConvFuncCommInt8(in, out, bias, oc, plane, oc, UP_ROUND(plane, C8NUM) * C8NUM, multiplier, mini, maxi, left_shift,
right_shift, zp, C8NUM);
return;
}
void PostFuncInt8C4(const int *in, const int *bias, int8_t *out, int oc, int plane, int stride, int32_t multiplier,
int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi) {
/* ((int32_t)row4x4-major + bias) * multiplier + output_zp => (int8)relu => (int8_t)row-major */
void PostFuncInt8C4(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, size_t stride,
int32_t multiplier, int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini,
int32_t maxi) {
/* ((int32_t)row4x4-major + bias) * multiplier + output_zp => (int8)relu => (int8_t)row-major */
#ifndef ENABLE_ARM64
PostConvFuncCommInt8(in, out, bias, oc, plane, stride, UP_ROUND(plane, C4NUM) * C4NUM, multiplier, mini, maxi,
left_shift, right_shift, zp, C4NUM);
#else
size_t oc4div = oc / C4NUM * C4NUM;
size_t oc4res = oc % C4NUM;
PostFuncInt8C4Neon64(in, bias, out, oc4div, oc4res, plane, stride * sizeof(int8_t), multiplier, left_shift,
right_shift, zp, mini, maxi);
#endif
return;
}
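A small worked example of the channel split that the ARM64 branch hands to the assembly kernel above:

size_t oc = 10;             /* example channel count                                        */
size_t oc4div = oc / 4 * 4; /* 8: full C4NUM blocks, Loop_C4 path (st1 {v0.s}[0] per pixel) */
size_t oc4res = oc % 4;     /* 2: remainder channels, Loop_C1_2 path (st1 {v0.h}[0])        */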

@@ -27,30 +27,21 @@
extern "C" {
#endif
void PostFuncInt8C8(const int *in, const int *bias, int8_t *out, int oc, int plane, int32_t multiplier,
int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi);
void PostFuncInt8C4(const int *in, const int *bias, int8_t *out, int oc, int plane, int stride, int32_t multiplier,
int32_t left_shift, int32_t right_shift, int32_t zp, int8_t mini, int8_t maxi);
#ifdef ENABLE_ARM
void PostFuncInt8C8(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, int32_t multiplier,
int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini, int32_t maxi);
void PostFuncInt8C4(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc, size_t plane, size_t stride,
int32_t multiplier, int32_t left_shift, int32_t right_shift, int32_t zp, int32_t mini,
int32_t maxi);
#ifdef ENABLE_ARM64
void PostFuncInt8C4Neon64(const int32_t *in, const int32_t *bias, int8_t *out, size_t oc4div, size_t oc4res,
size_t plane, size_t stride, int32_t multiplier, int32_t left_shift, int32_t right_shift,
int32_t zp, int32_t mini, int32_t maxi);
void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t *weight, size_t ksize, size_t ic8,
size_t oc4, size_t offset);
#ifdef ENABLE_ARM64
void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize,
size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min,
size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before,
size_t shift_after);
// #elif defined(ENABLE_ARM32)
// void IndirectGemmInt8_2x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias,
// size_t ksize,
// size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min,
// size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before,
// size_t shift_after);
#endif
#endif
#ifdef ENABLE_ARM
void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width,
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);

@@ -136,60 +136,109 @@ int DeConvPostInt8C4(const int32_t *src, const int32_t *bias, int32_t *tmp, int8
void DeConvWeightTransInt8(int8_t *src, int8_t *dst, int input_channel, int output_channel, int plane,
bool support_optimize_) {
if (support_optimize_) {
int ic16 = UP_ROUND(input_channel, C16NUM);
int oc4 = UP_ROUND(output_channel, C4NUM);
for (int ic = 0; ic < input_channel; ic++) {
int ic16div = ic / C16NUM, ic16mod = ic % C16NUM;
for (int oc = 0; oc < output_channel; oc++) {
int oc4div = oc / C4NUM, oc4mod = oc % C4NUM;
for (int hw = 0; hw < plane; hw++) {
int src_index = ic * output_channel * plane + hw * output_channel + oc;
int dst_index =
hw * ic16 * oc4 + oc4div * ic16 * C4NUM + ic16div * C16NUM * C4NUM + oc4mod * C16NUM + ic16mod;
dst[dst_index] = src[src_index];
}
/* optimize normal -> same layout */
int ic16 = UP_ROUND(input_channel, C16NUM);
int oc4 = UP_ROUND(output_channel, C4NUM);
for (int ic = 0; ic < input_channel; ic++) {
int ic16div = ic / C16NUM, ic16mod = ic % C16NUM;
for (int oc = 0; oc < output_channel; oc++) {
int oc4div = oc / C4NUM, oc4mod = oc % C4NUM;
for (int hw = 0; hw < plane; hw++) {
int src_index = ic * output_channel * plane + hw * output_channel + oc;
int dst_index = hw * ic16 * oc4 + oc4div * ic16 * C4NUM + ic16div * C16NUM * C4NUM + oc4mod * C16NUM + ic16mod;
dst[dst_index] = src[src_index];
}
}
} else {
/* normal int8 deconv */
}
return;
}
void DeConvPackWeightSum(int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep16, int col4,
bool suppport_opt) {
if (suppport_opt) {
for (int c = 0; c < col4; c++) {
int c4div = c / C4NUM, c4mod = c % C4NUM;
int32_t value = 0;
for (int r = 0; r < deep16; r++) {
int r16div = r / 16, r16mod = r % 16;
int src_index = c4div * deep16 * C4NUM + r16div * C4NUM * C16NUM + c4mod * C16NUM + r16mod;
value += weight[src_index];
}
weight_sum[c] = filter_zp * input_zp * deep16 - value * input_zp;
/* optimize normal -> same layout */
for (int c = 0; c < col4; c++) {
int c4div = c / C4NUM, c4mod = c % C4NUM;
int32_t value = 0;
for (int r = 0; r < deep16; r++) {
int r16div = r / C16NUM, r16mod = r % C16NUM;
int src_index = c4div * deep16 * C4NUM + r16div * C4NUM * C16NUM + c4mod * C16NUM + r16mod;
value += weight[src_index];
}
} else {
/* normal int8 deconv */
weight_sum[c] = filter_zp * input_zp * deep16 - value * input_zp;
}
return;
}
void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, int row4, int col16, bool suppport_opt) {
if (suppport_opt) {
for (int r = 0; r < row4; r++) {
int32_t tmp_value = 0;
for (int c = 0; c < col16; c++) {
int r4div = r / C4NUM, r4mod = r % C4NUM, c16div = c / C16NUM, c16mod = c % C16NUM;
int src_index = r4div * C4NUM * col16 + c16div * C16NUM * C4NUM + r4mod * C16NUM + c16mod;
tmp_value += src[src_index];
}
dst[r] = tmp_value * filter_zp;
void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16,
bool suppport_opt) {
/* optimize normal -> same layout */
#ifdef ENABLE_ARM64
asm volatile(
"mov x10, %[src] \n"
"mov x11, %[dst] \n"
"dup v15.4s, %w[filter_zp] \n"
"mov x0, #0 \n"
"1: \n"
"cmp x0, %[row4] \n"
"beq 4f \n"
"add x0, x0, #4\n"
"dup v10.4s, wzr \n"
"mov x2, #0 \n"
"2: \n"
"cmp x2, %[col16] \n"
"beq 3f \n"
"add x2, x2, #16\n"
"ld1 {v0.16b}, [x10], #16\n"
"ld1 {v1.16b}, [x10], #16\n"
"ld1 {v2.16b}, [x10], #16\n"
"ld1 {v3.16b}, [x10], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v6.8h, v2.16b \n"
"saddlp v7.8h, v3.16b \n"
"saddlp v0.4S, v4.8h \n"
"saddlp v1.4S, v5.8h \n"
"saddlp v2.4S, v6.8h \n"
"saddlp v3.4S, v7.8h \n"
"addv s4, v0.4S \n"
"addv s5, v1.4S \n"
"addv s6, v2.4S \n"
"addv s7, v3.4S \n"
"mov v0.s[0], v4.s[0] \n"
"mov v0.s[1], v5.s[0] \n"
"mov v0.s[2], v6.s[0] \n"
"mov v0.s[3], v7.s[0] \n"
"add v10.4s, v10.4s, v0.4s \n"
"b 2b\n"
"3: \n"
"mul v10.4s, v10.4s, v15.4s \n"
"st1 {v10.4s}, [x11], #16 \n"
"beq 1b \n"
"4: \n"
:
: [ dst ] "r"(dst), [ src ] "r"(src), [ row4 ] "r"(row4), [ col16 ] "r"(col16), [ filter_zp ] "r"(filter_zp)
: "x0", "x1", "x2", "x3", "x10", "x11", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v15");
#else
for (int r = 0; r < row4; r++) {
int32_t tmp_value = 0;
for (int c = 0; c < col16; c++) {
int r4div = r / C4NUM, r4mod = r % C4NUM, c16div = c / C16NUM, c16mod = c % C16NUM;
int src_index = r4div * C4NUM * col16 + c16div * C16NUM * C4NUM + r4mod * C16NUM + c16mod;
tmp_value += src[src_index];
}
} else {
/* normal int8 deconv */
}
#endif
return;
}
@@ -199,18 +248,14 @@ int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, int32
if (matmul_func != NULL) {
matmul_func(input, weight, output, act_row, act_col, act_deep, input_sum, weight_sum);
} else {
/* normal int8 deconv */
MatMulInt8_16x4(input, weight, output, act_row, act_col, act_deep, input_sum, weight_sum);
}
return NNACL_OK;
}
int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t *out, int output_channel,
ConvParameter *conv_param, bool support_optimize) {
int error_code = NNACL_OK;
if (support_optimize) {
error_code = DeConvPostInt8C4(src, bias, tmp, out, output_channel, conv_param);
} else {
/* normal int8 deconv post */
}
/* optimize normal -> same layout (C4) */
int error_code = DeConvPostInt8C4(src, bias, tmp, out, output_channel, conv_param);
return error_code;
}
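The two precomputed sums exist so that the matmul can fold both zero-point corrections into one add and one subtract per output. A scalar sketch of the identity (DeconvAccRef is an illustrative name, and the packed 4x16 block walk is flattened to plain arrays here):

/*   sum_k (a[k] - input_zp) * (w[k] - filter_zp)
 * = sum_k a[k]*w[k]
 *   - filter_zp * sum_k a[k]                                    -> input_sum[r]  (DeConvPackInputSum)
 *   + (filter_zp * input_zp * deep16 - input_zp * sum_k w[k])   -> weight_sum[c] (DeConvPackWeightSum) */
int32_t DeconvAccRef(const int8_t *a_row, const int8_t *w_col, int deep16,
                     int32_t input_sum_r, int32_t weight_sum_c) {
  int32_t raw = 0;
  for (int k = 0; k < deep16; k++) {
    raw += (int32_t)a_row[k] * (int32_t)w_col[k];
  }
  return raw - input_sum_r + weight_sum_c;
}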

@@ -29,7 +29,8 @@ extern "C" {
#endif
void DeConvPackWeightSum(int8_t *weight, int32_t *weight_sum, int32_t input_zp, int32_t filter_zp, int deep16, int col4,
bool suppport_opt);
void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, int row4, int col16, bool suppport_opt);
void DeConvPackInputSum(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16,
bool suppport_opt);
void DeConvWeightTransInt8(int8_t *src, int8_t *dst, int input_channel, int output_channel, int plane,
bool support_optimize_);

@@ -28,18 +28,66 @@ void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col)
}
}
void MatrixPack4x16UnitInt8(int8_t *src, int8_t *dst, int row, int col, int stride) {
for (int r = 0; r < row; r++) {
int8_t *src_r = src + r * stride;
int8_t *dst_r = dst + r * C16NUM;
memcpy(dst_r, src_r, col * sizeof(int8_t));
}
return;
}
void RowMajor2Row16x4MajorInt8(void *src_ptr, void *dst_ptr, int row, int col) {
/* Row-major to row16x4-major (block row-major) */
int col16 = UP_ROUND(col, C16NUM);
for (int r = 0; r < row; r++) {
int r4div = r / C4NUM;
int r4mod = r % C4NUM;
for (int c = 0; c < col; c++) {
int c16div = c / C16NUM;
int c16mod = c % C16NUM;
int src_index = r * col + c;
int dst_index = r4div * C4NUM * col16 + c16div * C16NUM * C4NUM + r4mod * C16NUM + c16mod;
((int8_t *)dst_ptr)[dst_index] = ((int8_t *)src_ptr)[src_index];
size_t row_4div = row / C4NUM * C4NUM;
size_t row_4res = row - row_4div;
size_t col_16div = col / C16NUM * C16NUM;
size_t col_16res = col - col_16div;
int8_t *src_r = (int8_t *)src_ptr;
int8_t *dst_r = (int8_t *)dst_ptr;
for (int ri = 0; ri < row_4div; ri += C4NUM) {
for (int ci = 0; ci < col_16div; ci += C16NUM) {
#ifdef ENABLE_ARM64
int8_t *src_c = src_r + ci;
int8_t *dst_c = dst_r + ci * C4NUM;
asm volatile(
"mov x10, %[src_c] \n"
"mov x11, %[dst_c] \n"
"ld1 {v0.16b}, [x10], %[col]\n"
"ld1 {v1.16b}, [x10], %[col]\n"
"ld1 {v2.16b}, [x10], %[col]\n"
"ld1 {v3.16b}, [x10], %[col]\n"
"st1 {v0.16b}, [x11], #16\n"
"st1 {v1.16b}, [x11], #16\n"
"st1 {v2.16b}, [x11], #16\n"
"st1 {v3.16b}, [x11], #16\n"
:
: [ dst_c ] "r"(dst_c), [ src_c ] "r"(src_c), [ col ] "r"(col)
: "x10", "x11", "v0", "v1", "v2", "v3");
#else
MatrixPack4x16UnitInt8(src_r + ci, dst_r + ci * C4NUM, C4NUM, C16NUM, col);
#endif
}
if (col != col_16div) {
MatrixPack4x16UnitInt8(src_r + col_16div, dst_r + col_16div * C4NUM, C4NUM, col_16res, col);
}
src_r += C4NUM * col;
dst_r += C4NUM * col16;
}
if (row != row_4div) {
for (int ci = 0; ci < col_16div; ci += C16NUM) {
MatrixPack4x16UnitInt8(src_r + ci, dst_r + ci * C4NUM, row_4res, C16NUM, col);
}
if (col != col_16div) {
MatrixPack4x16UnitInt8(src_r + col_16div, dst_r + col_16div * C4NUM, row_4res, col_16res, col);
}
}
return;
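The packing above places element (r, c) of the row-major source at a fixed offset in the 4-row by 16-column blocked buffer, matching the index math used by the scalar fallbacks elsewhere in this patch. A small restatement of that formula (Row16x4Offset is an illustrative helper, not part of the diff):

static int Row16x4Offset(int r, int c, int col16) {  /* col16 = UP_ROUND(col, C16NUM) */
  int r4div = r / C4NUM, r4mod = r % C4NUM;
  int c16div = c / C16NUM, c16mod = c % C16NUM;
  return r4div * C4NUM * col16 + c16div * C16NUM * C4NUM + r4mod * C16NUM + c16mod;
}
/* e.g. the 6x20 input in PackInputTest1 below packs into an 8x32 (row4 x col16) buffer. */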
@@ -74,7 +122,7 @@ void MatMulInt8(const int8_t *a, const int8_t *b, int32_t *c, const int row8, co
}
}
void MatMulOptR4Int8(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
const int *input_sum, const int *bias) {
/* row4x16-major * row16x4-major => row4x4-major */
for (int r = 0; r < row_4; r++) {

@@ -17,6 +17,7 @@
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_MATMUL_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_INT8_MATMUL_H_
#include <string.h>
#include "nnacl/op_base.h"
#include "nnacl/matmul_parameter.h"
@@ -25,7 +26,7 @@ extern "C" {
#endif
void MatMulInt8(const int8_t *a, const int8_t *b, int *c, const int row8, const int col8, const int deep,
const int a_zp, const int b_zp);
void MatMulOptR4Int8(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int col_4, int deep_16,
const int *input_sum, const int *bias);
void RowMajor2Row8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);
void RowMajor2Col8MajorInt8(int8_t *src_ptr, int8_t *dst_ptr, int row, int col);

@@ -107,6 +107,32 @@ TEST_F(TestDeconvInt8, PackWeight2) {
CompareOutputData(dst, co, 528, 1);
}
TEST_F(TestDeconvInt8, PackInputTest1) {
/* 6 x 20 */
int8_t in[] = {40, 24, 94, 122, 67, 34, -89, 31, -43, 121, 48, -54, 44, -91, 35, 89, -37, 114, -8, 103,
-22, 32, 26, 112, -92, -23, 43, 9, 81, 118, -73, -54, 65, -99, 51, -90, 121, -62, 119, -93,
21, -92, -1, -82, -71, -54, 63, -93, 92, -93, 99, 122, -104, -16, -8, -32, 90, -126, 51, 91,
4, 70, -7, 116, 99, 81, -79, 124, -14, 28, 97, 9, -97, 99, 88, -15, 54, 26, 77, -25,
113, 119, 119, -75, -17, 7, 7, 1, 69, 66, 40, -13, 80, -115, -98, -8, -17, 31, 88, 65,
-1, -15, -98, 77, 56, 119, -20, -32, -54, -58, -16, 52, 121, 126, -33, 43, 92, -34, -17, -52};
int8_t co[] = {40, 24, 94, 122, 67, 34, -89, 31, -43, 121, 48, -54, 44, -91, 35, 89, -22, 32, 26, 112,
-92, -23, 43, 9, 81, 118, -73, -54, 65, -99, 51, -90, 21, -92, -1, -82, -71, -54, 63, -93,
92, -93, 99, 122, -104, -16, -8, -32, 4, 70, -7, 116, 99, 81, -79, 124, -14, 28, 97, 9,
-97, 99, 88, -15, -37, 114, -8, 103, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
121, -62, 119, -93, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 90, -126, 51, 91,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 54, 26, 77, -25, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 113, 119, 119, -75, -17, 7, 7, 1, 69, 66, 40, -13,
80, -115, -98, -8, -1, -15, -98, 77, 56, 119, -20, -32, -54, -58, -16, 52, 121, 126, -33, 43,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -17, 31, 88, 65, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 92, -34, -17, -52, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
int8_t dst[8 * 32] = {0};
RowMajor2Row16x4MajorInt8(in, dst, 6, 20);
CompareOutputData(dst, co, 8 * 32, 1);
}
TEST_F(TestDeconvInt8, MatMulTest1) {
int8_t a_row_major_10_12[] = {
-6, 76, 32, 80, -73, 8, -85, -3, 114, 80, 30, 42, -41, 117, 62, -76, -77, -111, 88, 105,
@@ -155,6 +181,30 @@ TEST_F(TestDeconvInt8, MatMulTest1) {
CompareOutputData(out_row_major, co_row_major_10_18, 180, 1);
}
TEST_F(TestDeconvInt8, InputSumTest1) {
int8_t packed_a[] = {
-6, 76, 32, 80, -73, 8, -85, -3, 114, 80, 30, 42, 15, 15, 15, 15, -41, 117, 62, -76, -77, -111,
88, 105, 68, 105, -74, 13, 15, 15, 15, 15, 51, 94, 31, -52, -92, -4, -35, -71, 101, -93, 46, -65,
15, 15, 15, 15, 57, -41, -51, 77, 1, 9, 73, -19, -36, 57, 81, -24, 15, 15, 15, 15, 40, 103,
112, 109, -41, -68, 57, 61, 55, -20, 3, 2, 15, 15, 15, 15, 17, -16, -31, 58, -4, 67, -4, -95,
-5, -72, 81, 15, 15, 15, 15, 15, -7, -16, -47, 112, 114, -26, -98, 53, 15, -49, 26, 19, 15, 15,
15, 15, 19, 8, -57, -35, -79, 118, 29, 21, 37, -48, 83, 7, 15, 15, 15, 15, 124, 113, -5, 15,
-8, 107, -65, -88, 50, -47, -80, -84, 15, 15, 15, 15, 3, -45, 92, 42, -20, -101, 106, -10, 89, 67,
55, 10, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15};
int32_t filter_zp = -20;
int32_t input_sum[12] = {0};
int32_t correct_input_sum[] = {-7100, -4780, 580, -4880, -9460, -1420, -3120, -3260, -1840, -6960, -4800, -4800};
DeConvPackInputSum(packed_a, input_sum, filter_zp, 12, 16, true);
CompareOutputData(input_sum, correct_input_sum, 12, 0);
int32_t input_sum_4[4] = {0};
int32_t correct_input_sum_4[] = {-18400, -13160, -7340, -12940};
DeConvPackInputSum(packed_a, input_sum_4, filter_zp, 4, 16 * 3, true);
CompareOutputData(input_sum_4, correct_input_sum_4, 4, 0);
}
TEST_F(TestDeconvInt8, MatMulOptTest1) {
int8_t a_src_ptr[] = {-6, 76, 32, 80, -73, 8, -85, -3, 114, 80, 30, 42, -41, 117, 62, -76, -77, -111,
88, 105, 68, 105, -74, 13, 51, 94, 31, -52, -92, -4, -35, -71, 101, -93, 46, -65,
@@ -191,8 +241,7 @@ TEST_F(TestDeconvInt8, MatMulOptTest1) {
15, 15, 19, 8, -57, -35, -79, 118, 29, 21, 37, -48, 83, 7, 15, 15, 15, 15, 124, 113, -5, 15,
-8, 107, -65, -88, 50, -47, -80, -84, 15, 15, 15, 15, 3, -45, 92, 42, -20, -101, 106, -10, 89, 67,
55, 10, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15,
};
15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15};
RowMajor2Row16x4MajorInt8(a_src_ptr, packed_a, 10, 12);
CompareOutputData(packed_a, correct_packed_a, 16 * 12, 0);
@@ -231,12 +280,6 @@ TEST_F(TestDeconvInt8, MatMulOptTest1) {
DeConvPackInputSum(packed_a, input_sum, filter_zp, 12, 16, true);
CompareOutputData(input_sum, correct_input_sum, 12, 0);
for (int i = 0; i < 12; i++) {
if (input_sum[i] != correct_input_sum[i]) {
printf("%d %d %d\n", i, input_sum[i], correct_input_sum[i]);
}
}
/*
* ---------------------- calculate weight_sum ------------------------- */
int32_t weight_sum[3 * 8] = {0};
@@ -270,7 +313,8 @@ TEST_F(TestDeconvInt8, MatMulOptTest1) {
7894, -51, 0, 0, -4775, -29785, 0, 0, -12597, 4088, 0, 0, -17420, 1815,
0, 0, 15796, 3101, 0, 0, -37969, -10818, 0, 0, 12714, -7827, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0};
MatMulOptR4Int8(packed_a, packed_b, tmp_output, 12, 24, 16, input_sum, weight_sum);
MatMulInt8_16x4(packed_a, packed_b, tmp_output, 12, 24, 16, input_sum, weight_sum);
CompareOutputData(tmp_output, correct_tmp_output, 12 * 3 * 8, 0);
}
