[MSLITE][Develop] optimize arm cpu int8 op conv dw 3x3: add assembly

pull/7563/head
yangruoqi713 5 years ago
parent a5a5ef9617
commit 89e83b92d0

@ -0,0 +1,106 @@
#ifdef __aarch64__
.text
.align 5
.global ConvDw3x3BorderPixelInt8
#ifndef __APPLE__
.type ConvDw3x3BorderPixelInt8, %function
#endif
// void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height,
// size_t width, size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp,
// size_t out_multiplier, size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max) {
// x0: dst, x1: src, x2: weight, x3: bias, x4: height, x5: width, x6: in_kh_step, x7: in_kw_step,
// x8: channel, x9: in_zp, x10: out_zp, x11: out_multiplier, x12: left_shift, x13: right_shift
// x14: acc_min, x15: acc_max
ConvDw3x3BorderPixelInt8:
// registers v8 ~ v15 must be preserved by a callee across subroutine calls, according to
// https://github.com/ARM-software/abi-aa/blob/master/aapcs64/aapcs64.rst#simd-and-floating-point-registers
// x19 ~ x29 should be also preserved
// whereas our coding style do not permit such amount of parameters
ldr x8, [sp]
ldrb w9, [sp, #8]
dup v25.8b, w9 // in_zp
ldr x9, [sp, #16]
dup v26.4s, w9 // out_zp
ldr x9, [sp, #24]
dup v27.4s, w9 // out_multiplier
ldr x9, [sp, #32]
dup v28.4s, w9 // left_shift
ldr x9, [sp, #40]
dup v29.4s, w9 // right_shift
ldr x9, [sp, #48]
dup v30.4s, w9 // acc_min
ldr x9, [sp, #56]
dup v31.4s, w9 // acc_max
mov x9, #2
mul x13, x8, x9 // x8 * 2
mov x9, #3
mul x14, x13, x9 // x8 * 3 * 2
LoopC:
mov x9, x1
mov x10, x2
mov x17, x4 // height
ld1 {v5.4s}, [x3], #16
mov v3.16b, v5.16b
ld1 {v6.4s}, [x3], #16
mov v4.16b, v6.16b
LoopH:
mov x11, x9
mov x12, x10
mov x18, x5 // width
LoopW:
ld1 {v0.8b}, [x11], x7
ssubl v1.8h, v0.8b, v25.8b
ld1 {v2.8h}, [x12], x13 // weight
smlal v3.4s, v1.4h, v2.4h
smlal2 v4.4s, v1.8h, v2.8h
subs x18, x18, #1
bne LoopW
subs x17, x17, #1
add x9, x9, x6
add x10, x10, x14
bne LoopH
sqshl v3.4s, v3.4s, v28.4s
sqshl v4.4s, v4.4s, v28.4s
sqrdmulh v3.4s, v3.4s, v27.4s
sqrdmulh v4.4s, v4.4s, v27.4s
and v12.16b, v29.16b, v3.16b
sshr v12.4s, v12.4s, #31
sqadd v3.4s, v3.4s, v12.4s
srshl v3.4s, v3.4s, v29.4s
and v11.16b, v29.16b, v4.16b
sshr v11.4s, v11.4s, #31
sqadd v4.4s, v4.4s, v11.4s
srshl v4.4s, v4.4s, v29.4s
add v3.4s, v3.4s, v26.4s
add v4.4s, v4.4s, v26.4s
smax v3.4s, v3.4s, v30.4s
smax v4.4s, v4.4s, v30.4s
smin v3.4s, v3.4s, v31.4s
smin v4.4s, v4.4s, v31.4s
sqxtn v3.4h, v3.4s
sqxtn v4.4h, v4.4s
sqxtn v3.8b, v3.8h
sqxtn v4.8b, v4.8h
st1 {v3.s}[0], [x0], #4
st1 {v4.s}[0], [x0], #4
add x1, x1, #8
add x2, x2, #16
sub x8, x8, #8
cmp x8, #8
bge LoopC
ret
#endif

@ -27,7 +27,7 @@ ConvDwInt8PostAlign4:
dup v30.4s, w7
dup v31.4s, w8
cmp x2, 16
cmp x2, #16
blt LoopDepth8
LoopDepth16:

@ -67,8 +67,11 @@ void IndirectGemmInt8_4x4(int8_t *output, const int8_t *input, const int8_t *wei
void DeconvDwInt8Center(int32_t *dst, const int16_t *src, const int16_t *weight, size_t height, size_t width,
size_t kernel_h, size_t kernel_w, size_t out_h_step, size_t block_channel, size_t in_sh_step,
size_t in_sw_step, size_t in_kh_step, size_t in_kw_step);
void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t height,
size_t width, size_t in_kh_step, size_t in_kw_step, size_t channel, size_t in_zp,
size_t out_zp, size_t out_multiplier, size_t left_shift, size_t right_shift,
size_t acc_min, size_t acc_max);
#endif
#ifdef __cplusplus
}
#endif

@ -140,9 +140,13 @@ void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_da
/*conv depthwise 3x3 int8 begin*/
bool CheckIfUse3X3(const ConvParameter *conv_param, int channel) {
bool use_3x3 = conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && conv_param->stride_h_ == 1 &&
conv_param->stride_w_ == 1 && conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 &&
(channel % 8 == 0);
bool use_3x3 = conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 &&
(conv_param->stride_h_ == 1 || conv_param->stride_h_ == 2) &&
(conv_param->stride_w_ == 1 || conv_param->stride_w_ == 2) &&
conv_param->stride_h_ == conv_param->stride_w_ &&
(conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) &&
(conv_param->pad_l_ == 0 || conv_param->pad_l_ == 1) && conv_param->pad_u_ == conv_param->pad_l_ &&
conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && (channel % 8 == 0);
return use_3x3;
}
@ -159,10 +163,10 @@ void InitInputBuffer(int8_t *buffer, const int8_t *input, const ConvParameter *c
}
}
// stride 1
void ConvDw3x3Int8Window(int8_t *output, const int8_t *buffer, const int16_t *weight, const int32_t *bias, int col_size,
int row_size, int channel, int output_h, int output_w, int8_t in_zp, int32_t out_zp,
int out_multiplier, int left_shift, int right_shift, int32_t acc_min, int32_t acc_max) {
int out_multiplier, int left_shift, int right_shift, int32_t acc_min, int32_t acc_max,
int stride) {
for (int w = 0; w < output_w; w++) {
int tmp_buffer[C8NUM];
for (int i = 0; i < C8NUM; i++) {
@ -195,17 +199,17 @@ void ConvDw3x3Int8Window(int8_t *output, const int8_t *buffer, const int16_t *we
*output_tmp++ = (tmp_buffer[c]);
}
output += channel;
buffer += col_size;
buffer += col_size * stride;
}
}
void ConvDw3x3Int8Block(int8_t *output, const int8_t *buffer, const int16_t *weight, const int32_t *bias, int start_c,
int end_c, int col_size, int row_size, int channel, int output_h, int output_w, int8_t in_zp,
int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min,
int32_t acc_max) {
int32_t acc_max, int stride) {
for (; start_c <= end_c - 8; start_c += 8) {
ConvDw3x3Int8Window(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp,
out_multiplier, left_shift, right_shift, acc_min, acc_max);
out_multiplier, left_shift, right_shift, acc_min, acc_max, stride);
output += 8;
buffer += 8;
weight += 8;
@ -236,7 +240,7 @@ void ConvDw3x3Int8Row(int8_t *output, int8_t *buffer, const int8_t *input, const
InitInputBuffer(buffer, input_ptr, conv_param, block_input_h, block_input_w);
ConvDw3x3Int8Block(output_ptr, buffer, weight_ptr, bias_ptr, 0, 64, 64, ih_offset, conv_param->input_channel_,
block_output_h, block_output_w, in_zp, out_zp, out_multiplier, left_shift, right_shift,
acc_min, acc_max);
acc_min, acc_max, conv_param->stride_h_);
output_ptr += 64;
input_ptr += 64;
weight_ptr += 64;
@ -246,7 +250,7 @@ void ConvDw3x3Int8Row(int8_t *output, int8_t *buffer, const int8_t *input, const
ConvDw3x3Int8Block(output_ptr, input_ptr, weight_ptr, bias_ptr, c, conv_param->input_channel_,
conv_param->input_channel_, conv_param->input_w_ * conv_param->input_channel_,
conv_param->input_channel_, block_output_h, block_output_w, in_zp, out_zp, out_multiplier,
left_shift, right_shift, acc_min, acc_max);
left_shift, right_shift, acc_min, acc_max, conv_param->stride_h_);
output += block_output_w * conv_param->input_channel_;
input += conv_param->stride_w_ * block_output_w * conv_param->input_channel_;
}
@ -255,19 +259,20 @@ void ConvDw3x3Int8Row(int8_t *output, int8_t *buffer, const int8_t *input, const
if (left_width > 0) {
ConvDw3x3Int8Block(output, input, weight, bias, 0, conv_param->input_channel_, conv_param->input_channel_,
conv_param->input_w_ * conv_param->input_channel_, conv_param->input_channel_, block_output_h,
left_width, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max);
left_width, in_zp, out_zp, out_multiplier, left_shift, right_shift, acc_min, acc_max,
conv_param->stride_h_);
}
}
void ConvDw3x3Int8(int8_t *output_data, int8_t *buffer, const int8_t *input_data, const int16_t *weight_data,
const int32_t *bias_data, const ConvParameter *conv_param, int task_id) {
int step_oh = UP_DIV(conv_param->output_h_, conv_param->thread_num_);
int start_oh = step_oh * task_id;
int end_oh = MSMIN(start_oh + step_oh, conv_param->output_h_);
int start_ow = MSMAX(0, conv_param->pad_l_);
int end_ow = conv_param->output_w_ - conv_param->pad_l_;
start_oh = MSMAX(start_oh, conv_param->pad_u_);
end_oh = MSMIN(conv_param->output_h_ - conv_param->pad_u_, end_oh);
const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding,
int task_id) {
int output_h = sliding->bottom_ - sliding->top_;
int step_oh = UP_DIV(output_h, conv_param->thread_num_);
int start_oh = step_oh * task_id + sliding->top_;
int end_oh = MSMIN(start_oh + step_oh, sliding->bottom_);
int start_ow = sliding->left_;
int end_ow = sliding->right_;
int block_output_h = 1;
int block_output_w = conv_param->stride_w_ == 1 ? 30 : 14;
@ -293,6 +298,7 @@ void ConvDw3x3Int8(int8_t *output_data, int8_t *buffer, const int8_t *input_data
}
}
#ifndef ENABLE_ARM64
void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int height,
int width, int in_kh_step, int in_kw_step, int channel, int8_t in_zp, int32_t out_zp,
int out_multiplier, int left_shift, int right_shift, int32_t acc_min, int32_t acc_max) {
@ -329,6 +335,7 @@ void ConvDw3x3BorderPixelInt8(int8_t *dst, const int8_t *src, const int16_t *wei
}
}
}
#endif
void ConvDw3x3BorderInt8(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, int top,
int bottom, int left, int right, const ConvParameter *conv_param,

@ -33,7 +33,8 @@ void ConvDw3x3PadInt8(int8_t *output_data, const int8_t *input_data, const int16
const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding);
void ConvDw3x3Int8(int8_t *output_data, int8_t *buffer, const int8_t *input_data, const int16_t *weight_data,
const int32_t *bias_data, const ConvParameter *conv_param, int task_id);
const int32_t *bias_data, const ConvParameter *conv_param, const SlidingWindowParam *sliding,
int task_id);
void ConvDwSWInt8(int8_t *output_data, const int8_t *input_data, const int16_t *weight_data, const int32_t *bias_data,
int8_t *input_zp, int32_t *output_zp, const ConvParameter *conv_param,

@ -125,7 +125,7 @@ int ConvolutionDepthwise3x3Int8CPUKernel::ReSize() {
int ConvolutionDepthwise3x3Int8CPUKernel::Execute(int task_id) {
auto buffer = buffer_ + 64 * 10 * 10 * task_id;
ConvDw3x3Int8(output_ptr_, buffer, input_ptr_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), conv_param_,
task_id);
sliding_, task_id);
return RET_OK;
}
@ -167,7 +167,8 @@ int ConvolutionDepthwise3x3Int8CPUKernel::Run() {
auto output_tensor = out_tensors_.at(kOutputIndex);
output_ptr_ = reinterpret_cast<int8_t *>(output_tensor->MutableData());
if (conv_param_->pad_l_ == 1 && conv_param_->pad_u_ == 1) {
if (sliding_->top_ > 0 || sliding_->bottom_ < conv_param_->output_h_ || sliding_->left_ > 0 ||
sliding_->right_ < conv_param_->output_w_) {
ConvDw3x3PadInt8(output_ptr_, input_ptr_, packed_weight_, reinterpret_cast<int32_t *>(bias_data_), conv_param_,
sliding_);
}

Loading…
Cancel
Save