!4861 [MS][LITE][Develop] add conv per-channel support for int8

Merge pull request !4861 from lixian/master
pull/4861/MERGE
mindspore-ci-bot authored 5 years ago, committed by Gitee
commit 90552c4933

@@ -8,8 +8,8 @@
#endif
// void IndirectGemmInt8_4x4(int8_t *output, int8_t *input, int8_t *weight, int32_t *bias, size_t ksize, size_t ic4,
// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp, size_t out_multiplier,
// size_t shift_before, size_t shift_after);
// size_t oc, size_t offset, int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
// int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after, size_t asymmetric, size_t per_channel);
// x0: output, x1: input, x2: weight, x3: bias, x4: kSize, x5: ic4, x6: oc, x7: offset
IndirectGemmInt8_4x4:
@@ -36,18 +36,26 @@ IndirectGemmInt8_4x4:
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
// r19 ~ r29 should also be preserved
// whereas our coding style does not permit such a number of parameters
sub sp, sp, #144
sub sp, sp, #176
st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
stp x19, x20, [sp], #16
stp x21, x22, [sp], #16
stp x23, x24, [sp], #16
ldr x15, [sp]
ldr w8, [sp, #8]
ldr w9, [sp, #16]
ldr w16, [sp, #24]
ldr w17, [sp, #32]
ldr w18, [sp, #40]
ldr w19, [sp, #48]
ldr x17, [sp, #32]
ldr x18, [sp, #40]
ldr x19, [sp, #48]
ldr x20, [sp, #56]
ldr x21, [sp, #64]
add x24, x6, #3
mov x23, #4
sdiv x23, x24, x23
mul x5, x4, x5
mov x4, #1
@@ -189,12 +197,6 @@ IndirectGemmInt8_4x4:
sadalp v30.4s, v14.8h
sadalp v31.4s, v15.8h
// load sum
mov x20, x15
ld1r {v8.4s}, [x20], #4
ld1r {v9.4s}, [x20], #4
ld1r {v10.4s}, [x20], #4
ld1r {v11.4s}, [x20]
// pairwise add
addp v16.4s, v16.4s, v17.4s
addp v18.4s, v18.4s, v19.4s
@@ -212,28 +214,51 @@ IndirectGemmInt8_4x4:
addp v20.4s, v20.4s, v22.4s
addp v24.4s, v24.4s, v26.4s
addp v28.4s, v28.4s, v30.4s
cbz x20, NoSum
// load sum
mov x22, x15
cbz x21, SymSum
ld1r {v8.4s}, [x22], x23
ld1r {v9.4s}, [x22], x23
ld1r {v10.4s}, [x22], x23
ld1r {v11.4s}, [x22]
b AddSum
SymSum:
ld1r {v8.4s}, [x22], #4
ld1r {v9.4s}, [x22], #4
ld1r {v10.4s}, [x22], #4
ld1r {v11.4s}, [x22]
AddSum:
sub v16.4s, v16.4s, v8.4s
sub v20.4s, v20.4s, v9.4s
sub v24.4s, v24.4s, v10.4s
sub v28.4s, v28.4s, v11.4s
NoSum:
add v16.4s, v16.4s, v12.4s
add v20.4s, v20.4s, v12.4s
add v24.4s, v24.4s, v12.4s
add v28.4s, v28.4s, v12.4s
dup v2.4s, w18
cbnz x21, PerChannel
ld1r {v2.4s}, [x18]
ld1r {v3.4s}, [x17]
ld1r {v4.4s}, [x19]
b QuantizeStart
PerChannel:
ld1 {v2.4s}, [x18]
ld1 {v3.4s}, [x17]
ld1 {v4.4s}, [x19]
QuantizeStart:
sqshl v16.4s, v16.4s, v2.4s
sqshl v20.4s, v20.4s, v2.4s
sqshl v24.4s, v24.4s, v2.4s
sqshl v28.4s, v28.4s, v2.4s
dup v3.4s, w17
sqrdmulh v16.4s, v16.4s, v3.4s
sqrdmulh v20.4s, v20.4s, v3.4s
sqrdmulh v24.4s, v24.4s, v3.4s
sqrdmulh v28.4s, v28.4s, v3.4s
dup v4.4s, w19
and v0.16b, v4.16b, v16.16b
sshr v0.4s, v0.4s, #31
sqadd v16.4s, v16.4s, v0.4s
@@ -325,15 +350,25 @@ IndirectGemmInt8_4x4:
bne LoopKsize
subs x6, x6, #4
cbz x21, NoChannelForward
cbz x20, NoSumForward
add x15, x15, #16
NoSumForward:
add x17, x17, #16
add x18, x18, #16
add x19, x19, #16
NoChannelForward:
cbz x3, NoStepForward
add x3, x3, #16
NoStepForward:
bgt LoopOc
sub sp, sp, #144
sub sp, sp, #176
ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [sp], #64
ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [sp], #64
ldp x19, x20, [sp], #16
ldp x21, x22, [sp], #16
ldp x23, x24, [sp], #16
ret
#endif
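
For readers who don't want to trace the AArch64, the new epilogue reduces to the following C model. This is a minimal sketch under assumed gemmlowp-style semantics; SqRdMulh, RoundingRightShift, and RequantizeOne are illustrative names, not the shipped kernel, and the saturation of the left shift is glossed over. The asymmetric flag gates the input_sum subtraction (the AddSum path), while per_channel only changes whether the parameters below vary with the output channel.

#include <stdint.h>

/* Models AArch64 sqrdmulh: saturating rounding doubling high multiply. */
static int32_t SqRdMulh(int32_t a, int32_t b) {
  if (a == INT32_MIN && b == INT32_MIN) return INT32_MAX; /* the only saturating case */
  return (int32_t)(((int64_t)a * b * 2 + (1LL << 30)) >> 31);
}

/* Models the and/sshr/sqadd/srshl sequence: a rounding right shift whose
 * nudge rounds ties away from zero for negative inputs. */
static int32_t RoundingRightShift(int32_t x, int exponent) {
  int32_t mask = (exponent == 0) ? 0 : (int32_t)((1u << exponent) - 1);
  int32_t remainder = x & mask;
  int32_t threshold = (mask >> 1) + ((x < 0) ? 1 : 0);
  return (x >> exponent) + ((remainder > threshold) ? 1 : 0);
}

/* One int32 accumulator -> one int8 output; shift_after is expected <= 0. */
static int8_t RequantizeOne(int32_t acc, int32_t bias, int32_t input_sum,
                            int32_t shift_before, int32_t out_multiplier,
                            int32_t shift_after, int32_t out_zp,
                            int32_t act_min, int32_t act_max, int asymmetric) {
  if (asymmetric) acc -= input_sum;                    /* AddSum path */
  acc += bias;                                         /* the v12 add */
  acc = SqRdMulh(acc << shift_before, out_multiplier); /* sqshl + sqrdmulh */
  acc = RoundingRightShift(acc, -shift_after);         /* srshl by a negative amount */
  acc += out_zp;
  if (acc < act_min) acc = act_min;                    /* activation clamp */
  if (acc > act_max) acc = act_max;
  return (int8_t)acc;
}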

@@ -16,7 +16,7 @@
#include "nnacl/fp32/common_func.h"
#ifndef __aarch64__
#ifndef ENABLE_ARM64
void MatrixAdd(const float *a_ptr, const float *b_ptr, float *dst, size_t a_stride, size_t b_stride, size_t c_stride,
size_t row, size_t col) {
for (int r = 0; r < row; r++) {

@@ -40,8 +40,8 @@ void IndirectGemmInt16to32_8x4(int32_t *dst, const int16_t *src, const int16_t *
size_t oc4, size_t offset);
void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *weight, const int32_t *bias, size_t ksize,
size_t ic4, size_t oc, size_t offset, const int32_t *input_sum, size_t act_min,
size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before,
size_t shift_after);
size_t act_max, size_t out_zp, int32_t *out_multiplier, int32_t *shift_before,
int32_t *shift_after, size_t asymmetric, size_t per_channel);
void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width,
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
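
The signature change above is the core of the patch: out_multiplier, shift_before, and shift_after turn from scalar values into pointers, so each output channel can carry its own quantization parameters, and two new flags (asymmetric, per_channel) select the behavior at run time. A hedged sketch of how a scalar port would read them; PickParam is an illustrative helper, not project code:

#include <stddef.h>
#include <stdint.h>

/* With per_channel set, channel oc reads its own entry (the ld1 vector load
 * in the assembly); otherwise entry 0 is broadcast to all channels (ld1r). */
static inline int32_t PickParam(const int32_t *p, size_t oc, size_t per_channel) {
  return (per_channel != 0) ? p[oc] : p[0];
}

This is also why existing per-tensor callers keep working: they pass a pointer to a single value and leave the flag at 0.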

@@ -29,14 +29,12 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const in
int32_t act_min = conv_param->conv_quant_arg_.out_act_min_[0];
int32_t act_max = conv_param->conv_quant_arg_.out_act_max_[0];
int oc4 = UP_DIV(output_channel, C4NUM);
#ifdef __aarch64__
#ifdef ENABLE_ARM64
size_t asymmetric = conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC;
size_t per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
IndirectGemmInt8_4x4(dst, src, weight, bias, UP_DIV(kernel_plane, C4NUM), ic4, output_channel,
output_channel * sizeof(int8_t), input_sum, act_min, act_max, out_zp, out_multiplier,
shift_before, shift_after);
// #elif defined(ENABLE_ARM32)
// IndirectGemmInt8_2x4(dst, src, weight, bias, UP_DIV(kernel_plane, C4NUM), ic4, output_channel,
// output_channel * sizeof(int8_t), input_sum, act_min, act_max, out_zp, out_multiplier,
// shift_before, shift_after);
shift_before, shift_after, asymmetric, per_channel);
#else
int tile_num = conv_param->tile_num_;
int plane_c4 = UP_DIV(kernel_plane, C4NUM);
@@ -124,8 +122,10 @@ void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const
int oc4 = UP_DIV(output_channel, C4NUM);
if (gemm_func != NULL) {
#ifdef __aarch64__
size_t asymmetric = conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC;
size_t per_channel = conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL;
gemm_func(dst, src, weight, bias, kernel_plane, ic4, output_channel, output_channel * sizeof(int8_t), input_sum,
act_min, act_max, out_zp, out_multiplier, shift_before, shift_after);
act_min, act_max, out_zp, out_multiplier, shift_before, shift_after, asymmetric, per_channel);
#endif
} else {
int tile_num = conv_param->tile_num_;
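
Where do the per-channel multiplier and shift values come from? In the usual gemmlowp/TFLite scheme, which the sqrdmulh-plus-shift epilogue here matches, each channel's real rescale factor input_scale * filter_scale[c] / output_scale is decomposed into a Q31 fixed-point multiplier and a power-of-two exponent; a positive exponent becomes shift_before and a negative one shift_after. A sketch of that decomposition (DecomposeScale is illustrative, not the project's own helper):

#include <math.h>
#include <stdint.h>

/* Split real_scale into multiplier * 2^shift with the multiplier in Q31. */
static void DecomposeScale(double real_scale, int32_t *multiplier, int *shift) {
  if (real_scale <= 0.0) { *multiplier = 0; *shift = 0; return; }
  int exp = 0;
  double q = frexp(real_scale, &exp);           /* real_scale = q * 2^exp, q in [0.5, 1) */
  int64_t q31 = (int64_t)llround(q * (double)(1LL << 31));
  if (q31 == (1LL << 31)) { q31 /= 2; ++exp; }  /* q rounded up to exactly 1.0 */
  *multiplier = (int32_t)q31;
  *shift = exp;                                 /* > 0 -> shift_before, <= 0 -> shift_after */
}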

@@ -28,8 +28,8 @@
typedef void (*GEMM_FUNC)(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias, size_t ksize,
size_t ic4, size_t output_channel, size_t offset, const int32_t *input_sum, size_t act_min,
size_t act_max, size_t out_zp, size_t out_multiplier, size_t shift_before,
size_t shift_after);
size_t act_max, size_t out_zp, int32_t *out_multiplier, int32_t *shift_before,
int32_t *shift_after, size_t asymmetric, size_t per_channel);
#ifdef __cplusplus
extern "C" {

@@ -22,11 +22,11 @@ extern "C" {
extern void IndirectGemmInt8_24x4_dp(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
size_t ksize, size_t ic4, size_t output_channel, size_t offset,
const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
size_t out_multiplier, size_t shift_before, size_t shift_after);
int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after,
size_t asymmetric, size_t per_channel);
extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,
const int *input_sum, const int *bias);
#ifdef __cplusplus
}
#endif
@@ -35,9 +35,10 @@ extern void MatMulOptR4Int8Neon64(const int8_t *a, const int8_t *b, int *dst, in
void IndirectGemmInt8_optimize_handler(int8_t *dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
size_t ksize, size_t ic4, size_t output_channel, size_t offset,
const int32_t *input_sum, size_t act_min, size_t act_max, size_t out_zp,
size_t out_multiplier, size_t shift_before, size_t shift_after) {
int32_t *out_multiplier, int32_t *shift_before, int32_t *shift_after,
size_t asymmetric, size_t per_channel) {
return IndirectGemmInt8_24x4_dp(dst, src, weight, bias, ksize, ic4, output_channel, offset, input_sum, act_min,
act_max, out_zp, out_multiplier, shift_before, shift_after);
act_max, out_zp, out_multiplier, shift_before, shift_after, asymmetric, per_channel);
}
void MatMulR4Int8_optimize_handler(const int8_t *a, const int8_t *b, int *dst, int row4, int col4, int deep16,

@@ -879,8 +879,8 @@ void PackNHWCToNCHWFp32(const void *src, void *dst, int batches, int plane, int
const float *src_ptr = src_batch + hw * channel + c;
float *dst_ptr = dst_batch + c * plane + hw;
#ifdef ENABLE_ARM64
int srcStride = channel * 4;
int dstStride = plane * 4;
size_t srcStride = channel * sizeof(float);
size_t dstStride = plane * sizeof(float);
asm volatile(
"mov x10, %[src_ptr]\n"
"mov x11, %[dst_ptr]\n"
