[MSLITE][Develop] optimize arm cpu int8 op conv dw 3x3, add stride 2 assembly

pull/7862/head
yangruoqi713 5 years ago
parent 9fc0218c56
commit 6273bdaedb

@ -24,8 +24,8 @@ ConvDw3x3Int8Corner:
dup v26.4s, w9 // out_zp dup v26.4s, w9 // out_zp
ldr x9, [sp, #8] ldr x9, [sp, #8]
dup v27.4s, w9 // out_multiplier dup v27.4s, w9 // out_multiplier
ldr x9, [sp, #16] ldr x8, [sp, #16]
dup v28.4s, w9 // left_shift dup v28.4s, w8 // left_shift
ldr x9, [sp, #24] ldr x9, [sp, #24]
dup v29.4s, w9 // right_shift dup v29.4s, w9 // right_shift
ldr x9, [sp, #32] ldr x9, [sp, #32]
@ -85,26 +85,24 @@ ConvDw3x3Int8Corner:
smlal v23.4s, v3.4h, v7.4h smlal v23.4s, v3.4h, v7.4h
ld1 {v6.8h}, [x12], x13 ld1 {v6.8h}, [x12], x13
smlal2 v24.4s, v3.8h, v7.8h smlal2 v24.4s, v3.8h, v7.8h
ld1 {v3.8b}, [x11], x5
ssubl v3.8h, v3.8b, v25.8b
ld1 {v7.8h}, [x12], x13
cbz w8, RightShiftLoop
sqshl v23.4s, v23.4s, v28.4s sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s sqrdmulh v24.4s, v24.4s, v27.4s
b AddZpLoop
and v21.16b, v29.16b, v23.16b RightShiftLoop:
sshr v21.4s, v21.4s, #31 sqrdmulh v23.4s, v23.4s, v27.4s
sqadd v23.4s, v23.4s, v21.4s sqrdmulh v24.4s, v24.4s, v27.4s
srshl v23.4s, v23.4s, v29.4s sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
ld1 {v3.8b}, [x11], x5
ssubl v3.8h, v3.8b, v25.8b
ld1 {v7.8h}, [x12], x13
AddZpLoop:
add v23.4s, v23.4s, v26.4s add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s smax v23.4s, v23.4s, v30.4s
@ -135,21 +133,20 @@ ConvDw3x3Int8Corner:
smlal v23.4s, v3.4h, v7.4h smlal v23.4s, v3.4h, v7.4h
smlal2 v24.4s, v3.8h, v7.8h smlal2 v24.4s, v3.8h, v7.8h
cbz w8, RightShift
sqshl v23.4s, v23.4s, v28.4s sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s sqrdmulh v24.4s, v24.4s, v27.4s
b AddZp
and v21.16b, v29.16b, v23.16b RightShift:
sshr v21.4s, v21.4s, #31 sqrdmulh v23.4s, v23.4s, v27.4s
sqadd v23.4s, v23.4s, v21.4s sqrdmulh v24.4s, v24.4s, v27.4s
srshl v23.4s, v23.4s, v29.4s sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
AddZp:
add v23.4s, v23.4s, v26.4s add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s smax v23.4s, v23.4s, v30.4s

@ -24,8 +24,8 @@ ConvDw3x3Int8Horizontal:
dup v26.4s, w9 // out_zp dup v26.4s, w9 // out_zp
ldr x9, [sp, #8] ldr x9, [sp, #8]
dup v27.4s, w9 // out_multiplier dup v27.4s, w9 // out_multiplier
ldr x9, [sp, #16] ldr x8, [sp, #16]
dup v28.4s, w9 // left_shift dup v28.4s, w8 // left_shift
ldr x9, [sp, #24] ldr x9, [sp, #24]
dup v29.4s, w9 // right_shift dup v29.4s, w9 // right_shift
ldr x9, [sp, #32] ldr x9, [sp, #32]
@ -109,26 +109,24 @@ ConvDw3x3Int8Horizontal:
smlal v23.4s, v17.4h, v19.4h smlal v23.4s, v17.4h, v19.4h
ld1 {v18.8h}, [x16], x13 ld1 {v18.8h}, [x16], x13
smlal2 v24.4s, v17.8h, v19.8h smlal2 v24.4s, v17.8h, v19.8h
ld1 {v17.8b}, [x15], x5
ssubl v17.8h, v17.8b, v25.8b
ld1 {v19.8h}, [x16], x13
cbz w8, RightShiftLoop
sqshl v23.4s, v23.4s, v28.4s sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s sqrdmulh v24.4s, v24.4s, v27.4s
b AddZpLoop
and v21.16b, v29.16b, v23.16b RightShiftLoop:
sshr v21.4s, v21.4s, #31 sqrdmulh v23.4s, v23.4s, v27.4s
sqadd v23.4s, v23.4s, v21.4s sqrdmulh v24.4s, v24.4s, v27.4s
srshl v23.4s, v23.4s, v29.4s sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
ld1 {v17.8b}, [x15], x5
ssubl v17.8h, v17.8b, v25.8b
ld1 {v19.8h}, [x16], x13
AddZpLoop:
add v23.4s, v23.4s, v26.4s add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s smax v23.4s, v23.4s, v30.4s
@ -163,21 +161,20 @@ ConvDw3x3Int8Horizontal:
smlal v23.4s, v17.4h, v19.4h smlal v23.4s, v17.4h, v19.4h
smlal2 v24.4s, v17.8h, v19.8h smlal2 v24.4s, v17.8h, v19.8h
cbz w8, RightShift
sqshl v23.4s, v23.4s, v28.4s sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s sqrdmulh v24.4s, v24.4s, v27.4s
b AddZp
and v21.16b, v29.16b, v23.16b RightShift:
sshr v21.4s, v21.4s, #31 sqrdmulh v23.4s, v23.4s, v27.4s
sqadd v23.4s, v23.4s, v21.4s sqrdmulh v24.4s, v24.4s, v27.4s
srshl v23.4s, v23.4s, v29.4s sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
AddZp:
add v23.4s, v23.4s, v26.4s add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s smax v23.4s, v23.4s, v30.4s

File diff suppressed because it is too large Load Diff

@ -24,8 +24,8 @@ ConvDw3x3Int8Vertical:
dup v26.4s, w9 // out_zp dup v26.4s, w9 // out_zp
ldr x9, [sp, #8] ldr x9, [sp, #8]
dup v27.4s, w9 // out_multiplier dup v27.4s, w9 // out_multiplier
ldr x9, [sp, #16] ldr x8, [sp, #16]
dup v28.4s, w9 // left_shift dup v28.4s, w8 // left_shift
ldr x9, [sp, #24] ldr x9, [sp, #24]
dup v29.4s, w9 // right_shift dup v29.4s, w9 // right_shift
ldr x9, [sp, #32] ldr x9, [sp, #32]
@ -105,26 +105,24 @@ ConvDw3x3Int8Vertical:
smlal v23.4s, v17.4h, v19.4h smlal v23.4s, v17.4h, v19.4h
ld1 {v18.8h}, [x10], x13 ld1 {v18.8h}, [x10], x13
smlal2 v24.4s, v17.8h, v19.8h smlal2 v24.4s, v17.8h, v19.8h
ld1 {v17.8b}, [x11], x5
ssubl v17.8h, v17.8b, v25.8b
ld1 {v19.8h}, [x12], x13
cbz w8, RightShiftLoop
sqshl v23.4s, v23.4s, v28.4s sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s sqrdmulh v24.4s, v24.4s, v27.4s
b AddZpLoop
and v21.16b, v29.16b, v23.16b RightShiftLoop:
sshr v21.4s, v21.4s, #31 sqrdmulh v23.4s, v23.4s, v27.4s
sqadd v23.4s, v23.4s, v21.4s sqrdmulh v24.4s, v24.4s, v27.4s
srshl v23.4s, v23.4s, v29.4s sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
ld1 {v17.8b}, [x11], x5
ssubl v17.8h, v17.8b, v25.8b
ld1 {v19.8h}, [x12], x13
AddZpLoop:
add v23.4s, v23.4s, v26.4s add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s smax v23.4s, v23.4s, v30.4s
@ -159,21 +157,20 @@ ConvDw3x3Int8Vertical:
smlal v23.4s, v17.4h, v19.4h smlal v23.4s, v17.4h, v19.4h
smlal2 v24.4s, v17.8h, v19.8h smlal2 v24.4s, v17.8h, v19.8h
cbz w8, RightShift
sqshl v23.4s, v23.4s, v28.4s sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s sqrdmulh v24.4s, v24.4s, v27.4s
b AddZp
and v21.16b, v29.16b, v23.16b RightShift:
sshr v21.4s, v21.4s, #31 sqrdmulh v23.4s, v23.4s, v27.4s
sqadd v23.4s, v23.4s, v21.4s sqrdmulh v24.4s, v24.4s, v27.4s
srshl v23.4s, v23.4s, v29.4s sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
AddZp:
add v23.4s, v23.4s, v26.4s add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s smax v23.4s, v23.4s, v30.4s

@ -77,6 +77,10 @@ void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *wei
int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp, int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp,
int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min, int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min,
int32_t acc_max); int32_t acc_max);
void ConvDw3x3Int8Stride2(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias,
int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp,
int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min,
int32_t acc_max);
void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step, void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step,
size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, size_t out_multiplier, size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, size_t out_multiplier,
size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max); size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max);

@ -139,11 +139,22 @@ void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_da
/*conv depthwise int8 end*/ /*conv depthwise int8 end*/
/*conv depthwise 3x3 int8 begin*/ /*conv depthwise 3x3 int8 begin*/
bool CheckIfUse3X3(const ConvParameter *conv_param, int channel) { bool CheckIfUse3X3(const ConvParameter *conv_param) {
bool use_3x3 = conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && conv_param->stride_h_ == 1 && bool use_3x3 = conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 &&
conv_param->stride_w_ == 1 && (conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) && (conv_param->stride_h_ == 1 || conv_param->stride_h_ == 2) &&
(conv_param->stride_w_ == 1 || conv_param->stride_w_ == 2) &&
conv_param->stride_h_ == conv_param->stride_w_ &&
(conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) &&
(conv_param->pad_l_ == 0 || conv_param->pad_l_ == 1) && conv_param->pad_u_ == conv_param->pad_l_ && (conv_param->pad_l_ == 0 || conv_param->pad_l_ == 1) && conv_param->pad_u_ == conv_param->pad_l_ &&
conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && (channel % 8 == 0); conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && (conv_param->input_channel_ % 8 == 0);
if (!use_3x3) {
return false;
}
const int out_w = conv_param->output_w_ - 1;
const int out_h = conv_param->output_h_ - 1;
const int in_w = out_w * conv_param->stride_w_ - conv_param->pad_l_ + conv_param->kernel_w_;
const int in_h = out_h * conv_param->stride_h_ - conv_param->pad_u_ + conv_param->kernel_h_;
use_3x3 = in_w <= (conv_param->input_w_ + conv_param->pad_l_) && in_h <= (conv_param->input_h_ + conv_param->pad_u_);
return use_3x3; return use_3x3;
} }
@ -206,8 +217,14 @@ void ConvDw3x3Int8Block(int8_t *output, const int8_t *buffer, const int16_t *wei
int32_t acc_max, int stride) { int32_t acc_max, int stride) {
for (; start_c <= end_c - 8; start_c += 8) { for (; start_c <= end_c - 8; start_c += 8) {
#ifdef ENABLE_ARM64 #ifdef ENABLE_ARM64
ConvDw3x3Int8Neon64(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp, if (stride == 1) {
out_multiplier, left_shift, right_shift, acc_min, acc_max); ConvDw3x3Int8Neon64(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp,
out_multiplier, left_shift, right_shift, acc_min, acc_max);
} else {
ConvDw3x3Int8Stride2(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp,
out_multiplier, left_shift, right_shift, acc_min, acc_max);
}
#else #else
ConvDw3x3Int8Window(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp, ConvDw3x3Int8Window(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp,
out_multiplier, left_shift, right_shift, acc_min, acc_max, stride); out_multiplier, left_shift, right_shift, acc_min, acc_max, stride);

@ -24,7 +24,7 @@
extern "C" { extern "C" {
#endif #endif
bool CheckIfUse3X3(const ConvParameter *conv_param, int channel); bool CheckIfUse3X3(const ConvParameter *conv_param);
void ConvDwInt8(int8_t *output_data, int32_t *output_row, const int8_t *input_data, const int16_t *weight_data, void ConvDwInt8(int8_t *output_data, int32_t *output_row, const int8_t *input_data, const int16_t *weight_data,
const int32_t *bias_data, const ConvParameter *conv_param, int task_id); const int32_t *bias_data, const ConvParameter *conv_param, int task_id);

@ -168,22 +168,26 @@ kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector<lite::Tensor *>
const mindspore::lite::PrimitiveC *primitive) { const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr); MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D); MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
kernel::LiteKernel *kernel; kernel::LiteKernel *kernel = nullptr;
auto act_quant_size = auto act_quant_size =
MSMAX(inputs[kInputIndex]->GetQuantParams().size(), outputs[kOutputIndex]->GetQuantParams().size()); MSMAX(inputs[kInputIndex]->GetQuantParams().size(), outputs[kOutputIndex]->GetQuantParams().size());
if (act_quant_size == 1) { // per tensor if (act_quant_size == 1) { // per tensor
auto conv_parm = reinterpret_cast<ConvParameter *>(opParameter); auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
auto channel = inputs[kWeightIndex]->shape()[0]; if (primitive != nullptr && primitive->GetInferFlag()) {
conv_param->input_h_ = inputs[kInputIndex]->Height();
conv_param->input_w_ = inputs[kInputIndex]->Width();
conv_param->input_channel_ = inputs[kInputIndex]->Channel();
conv_param->output_h_ = outputs[kOutputIndex]->Height();
conv_param->output_w_ = outputs[kOutputIndex]->Width();
}
auto weight_quant_size = inputs[kWeightIndex]->GetQuantParams().size(); auto weight_quant_size = inputs[kWeightIndex]->GetQuantParams().size();
if (CheckIfUse3X3(conv_parm, channel) && weight_quant_size == 1) { if (CheckIfUse3X3(conv_param) && weight_quant_size == 1) {
#ifdef ENABLE_ARM64 #ifdef ENABLE_ARM64
kernel = kernel =
new (std::nothrow) kernel::ConvolutionDepthwise3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive); new (std::nothrow) kernel::ConvolutionDepthwise3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
#else
kernel =
new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
#endif #endif
} else { }
if (kernel == nullptr) {
kernel = kernel =
new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive); new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
} }

@ -1,2 +0,0 @@
ssd-10.onnx
efficientnet-lite4-11.onnx
Loading…
Cancel
Save