!7862 [MSLITE][Develop] optimize arm cpu int8 op conv dw 3x3, add stride 2

Merge pull request !7862 from yangruoqi713/lite
pull/7862/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 23b07aee4c

@ -24,8 +24,8 @@ ConvDw3x3Int8Corner:
dup v26.4s, w9 // out_zp
ldr x9, [sp, #8]
dup v27.4s, w9 // out_multiplier
ldr x9, [sp, #16]
dup v28.4s, w9 // left_shift
ldr x8, [sp, #16]
dup v28.4s, w8 // left_shift
ldr x9, [sp, #24]
dup v29.4s, w9 // right_shift
ldr x9, [sp, #32]
@ -85,26 +85,24 @@ ConvDw3x3Int8Corner:
smlal v23.4s, v3.4h, v7.4h
ld1 {v6.8h}, [x12], x13
smlal2 v24.4s, v3.8h, v7.8h
ld1 {v3.8b}, [x11], x5
ssubl v3.8h, v3.8b, v25.8b
ld1 {v7.8h}, [x12], x13
cbz w8, RightShiftLoop
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
b AddZpLoop
and v21.16b, v29.16b, v23.16b
sshr v21.4s, v21.4s, #31
sqadd v23.4s, v23.4s, v21.4s
srshl v23.4s, v23.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
ld1 {v3.8b}, [x11], x5
ssubl v3.8h, v3.8b, v25.8b
ld1 {v7.8h}, [x12], x13
RightShiftLoop:
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
AddZpLoop:
add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s
@ -135,21 +133,20 @@ ConvDw3x3Int8Corner:
smlal v23.4s, v3.4h, v7.4h
smlal2 v24.4s, v3.8h, v7.8h
cbz w8, RightShift
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
b AddZp
and v21.16b, v29.16b, v23.16b
sshr v21.4s, v21.4s, #31
sqadd v23.4s, v23.4s, v21.4s
srshl v23.4s, v23.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
RightShift:
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
AddZp:
add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s

@ -24,8 +24,8 @@ ConvDw3x3Int8Horizontal:
dup v26.4s, w9 // out_zp
ldr x9, [sp, #8]
dup v27.4s, w9 // out_multiplier
ldr x9, [sp, #16]
dup v28.4s, w9 // left_shift
ldr x8, [sp, #16]
dup v28.4s, w8 // left_shift
ldr x9, [sp, #24]
dup v29.4s, w9 // right_shift
ldr x9, [sp, #32]
@ -109,26 +109,24 @@ ConvDw3x3Int8Horizontal:
smlal v23.4s, v17.4h, v19.4h
ld1 {v18.8h}, [x16], x13
smlal2 v24.4s, v17.8h, v19.8h
ld1 {v17.8b}, [x15], x5
ssubl v17.8h, v17.8b, v25.8b
ld1 {v19.8h}, [x16], x13
cbz w8, RightShiftLoop
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
b AddZpLoop
and v21.16b, v29.16b, v23.16b
sshr v21.4s, v21.4s, #31
sqadd v23.4s, v23.4s, v21.4s
srshl v23.4s, v23.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
ld1 {v17.8b}, [x15], x5
ssubl v17.8h, v17.8b, v25.8b
ld1 {v19.8h}, [x16], x13
RightShiftLoop:
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
AddZpLoop:
add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s
@ -163,21 +161,20 @@ ConvDw3x3Int8Horizontal:
smlal v23.4s, v17.4h, v19.4h
smlal2 v24.4s, v17.8h, v19.8h
cbz w8, RightShift
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
b AddZp
and v21.16b, v29.16b, v23.16b
sshr v21.4s, v21.4s, #31
sqadd v23.4s, v23.4s, v21.4s
srshl v23.4s, v23.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
RightShift:
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
AddZp:
add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s

File diff suppressed because it is too large Load Diff

@ -24,8 +24,8 @@ ConvDw3x3Int8Vertical:
dup v26.4s, w9 // out_zp
ldr x9, [sp, #8]
dup v27.4s, w9 // out_multiplier
ldr x9, [sp, #16]
dup v28.4s, w9 // left_shift
ldr x8, [sp, #16]
dup v28.4s, w8 // left_shift
ldr x9, [sp, #24]
dup v29.4s, w9 // right_shift
ldr x9, [sp, #32]
@ -105,26 +105,24 @@ ConvDw3x3Int8Vertical:
smlal v23.4s, v17.4h, v19.4h
ld1 {v18.8h}, [x10], x13
smlal2 v24.4s, v17.8h, v19.8h
ld1 {v17.8b}, [x11], x5
ssubl v17.8h, v17.8b, v25.8b
ld1 {v19.8h}, [x12], x13
cbz w8, RightShiftLoop
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
b AddZpLoop
and v21.16b, v29.16b, v23.16b
sshr v21.4s, v21.4s, #31
sqadd v23.4s, v23.4s, v21.4s
srshl v23.4s, v23.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
ld1 {v17.8b}, [x11], x5
ssubl v17.8h, v17.8b, v25.8b
ld1 {v19.8h}, [x12], x13
RightShiftLoop:
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
AddZpLoop:
add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s
@ -159,21 +157,20 @@ ConvDw3x3Int8Vertical:
smlal v23.4s, v17.4h, v19.4h
smlal2 v24.4s, v17.8h, v19.8h
cbz w8, RightShift
sqshl v23.4s, v23.4s, v28.4s
sqshl v24.4s, v24.4s, v28.4s
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
b AddZp
and v21.16b, v29.16b, v23.16b
sshr v21.4s, v21.4s, #31
sqadd v23.4s, v23.4s, v21.4s
srshl v23.4s, v23.4s, v29.4s
and v22.16b, v29.16b, v24.16b
sshr v22.4s, v22.4s, #31
sqadd v24.4s, v24.4s, v22.4s
srshl v24.4s, v24.4s, v29.4s
RightShift:
sqrdmulh v23.4s, v23.4s, v27.4s
sqrdmulh v24.4s, v24.4s, v27.4s
sqrshl v23.4s, v23.4s, v29.4s
sqrshl v24.4s, v24.4s, v29.4s
AddZp:
add v23.4s, v23.4s, v26.4s
add v24.4s, v24.4s, v26.4s
smax v23.4s, v23.4s, v30.4s

@ -77,6 +77,10 @@ void ConvDw3x3Int8Neon64(int8_t *output, const int8_t *input, const int16_t *wei
int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp,
int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min,
int32_t acc_max);
// Stride-2 variant of the 3x3 depthwise int8 window kernel.
// It takes the same argument list as ConvDw3x3Int8Neon64; ConvDw3x3Int8Block calls it under
// ENABLE_ARM64 whenever stride != 1 and passes exactly the arguments it would pass to the stride-1 kernel.
// NOTE(review): the implementation is not visible in this view (presumably the assembly file whose
// diff is suppressed) - confirm the argument order against it before changing this prototype.
void ConvDw3x3Int8Stride2(int8_t *output, const int8_t *input, const int16_t *weight, const int32_t *bias,
int input_col_size, int input_row_size, int channel, int output_h, int output_w, int8_t in_zp,
int32_t out_zp, int out_multiplier, int left_shift, int right_shift, int32_t acc_min,
int32_t acc_max);
// Prototype of the ConvDw3x3Int8Corner assembly routine.
// Every scalar argument is declared size_t. The assembly hunk for this routine loads out_multiplier,
// left_shift and right_shift from [sp, #8], [sp, #16] and [sp, #24] respectively, so the order of the
// trailing parameters here is what those stack offsets rely on.
// NOTE(review): after this change the assembly keeps left_shift in w8 and branches on it being zero
// (cbz w8, ...), choosing between a left-shift path and a right-shift path.
void ConvDw3x3Int8Corner(int8_t *dst, const int8_t *src, const int16_t *weight, const int32_t *bias, size_t in_kh_step,
size_t in_kw_step, size_t channel, size_t in_zp, size_t out_zp, size_t out_multiplier,
size_t left_shift, size_t right_shift, size_t acc_min, size_t acc_max);

@ -139,11 +139,22 @@ void ConvDwInt8(int8_t *output_data, int32_t *row_buffer, const int8_t *input_da
/*conv depthwise int8 end*/
/*conv depthwise 3x3 int8 begin*/
// Returns true when the specialised 3x3 depthwise int8 path may be used for conv_param.
// NOTE(review): this span is a rendered diff with the +/- markers stripped. The removed version
// (two parameters, stride 1 only, `channel % 8`) and the added version (one parameter, stride 1 or 2,
// `conv_param->input_channel_ % 8`) are interleaved below, so the text is not compilable as shown.
bool CheckIfUse3X3(const ConvParameter *conv_param, int channel) {
bool use_3x3 = conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 && conv_param->stride_h_ == 1 &&
conv_param->stride_w_ == 1 && (conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) &&
// Added version starts here: the channel count now comes from conv_param->input_channel_.
bool CheckIfUse3X3(const ConvParameter *conv_param) {
bool use_3x3 = conv_param->kernel_h_ == 3 && conv_param->kernel_w_ == 3 &&
// Stride may be 1 or 2, but must be the same in both directions.
(conv_param->stride_h_ == 1 || conv_param->stride_h_ == 2) &&
(conv_param->stride_w_ == 1 || conv_param->stride_w_ == 2) &&
conv_param->stride_h_ == conv_param->stride_w_ &&
// Top and left padding must each be 0 or 1 and equal to one another.
(conv_param->pad_u_ == 0 || conv_param->pad_u_ == 1) &&
(conv_param->pad_l_ == 0 || conv_param->pad_l_ == 1) && conv_param->pad_u_ == conv_param->pad_l_ &&
// The next two lines are the removed and the added form of the same final clause.
conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && (channel % 8 == 0);
conv_param->dilation_h_ == 1 && conv_param->dilation_w_ == 1 && (conv_param->input_channel_ % 8 == 0);
if (!use_3x3) {
return false;
}
// Index of the last output column / row.
const int out_w = conv_param->output_w_ - 1;
const int out_h = conv_param->output_h_ - 1;
// One past the last input column / row touched by the window of the last output element,
// measured in unpadded input coordinates: last_out * stride - pad + kernel.
const int in_w = out_w * conv_param->stride_w_ - conv_param->pad_l_ + conv_param->kernel_w_;
const int in_h = out_h * conv_param->stride_h_ - conv_param->pad_u_ + conv_param->kernel_h_;
// Accept only if that window overruns the right / bottom input edge by at most pad_l_ / pad_u_ elements.
// NOTE(review): this reads input_*/output_* from conv_param, so callers must have filled them in first.
use_3x3 = in_w <= (conv_param->input_w_ + conv_param->pad_l_) && in_h <= (conv_param->input_h_ + conv_param->pad_u_);
return use_3x3;
}
@ -206,8 +217,14 @@ void ConvDw3x3Int8Block(int8_t *output, const int8_t *buffer, const int16_t *wei
int32_t acc_max, int stride) {
for (; start_c <= end_c - 8; start_c += 8) {
#ifdef ENABLE_ARM64
ConvDw3x3Int8Neon64(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp,
out_multiplier, left_shift, right_shift, acc_min, acc_max);
if (stride == 1) {
ConvDw3x3Int8Neon64(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp,
out_multiplier, left_shift, right_shift, acc_min, acc_max);
} else {
ConvDw3x3Int8Stride2(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp,
out_multiplier, left_shift, right_shift, acc_min, acc_max);
}
#else
ConvDw3x3Int8Window(output, buffer, weight, bias, col_size, row_size, channel, output_h, output_w, in_zp, out_zp,
out_multiplier, left_shift, right_shift, acc_min, acc_max, stride);

@ -24,7 +24,7 @@
extern "C" {
#endif
// Declaration of the 3x3 depthwise int8 eligibility check.
// NOTE(review): rendered diff with +/- markers stripped - the first line is the removed two-parameter
// declaration and the second is the added one-parameter declaration that replaces it.
bool CheckIfUse3X3(const ConvParameter *conv_param, int channel);
bool CheckIfUse3X3(const ConvParameter *conv_param);
// Depthwise convolution int8 entry point (generic, non-3x3 path).
// NOTE(review): the second parameter is named output_row here, while the definition's hunk header names
// it row_buffer - consider aligning the names. task_id presumably selects the calling thread's share of
// the work; confirm against the definition, which is not visible in this view.
void ConvDwInt8(int8_t *output_data, int32_t *output_row, const int8_t *input_data, const int16_t *weight_data,
const int32_t *bias_data, const ConvParameter *conv_param, int task_id);

@ -168,22 +168,26 @@ kernel::LiteKernel *CpuConvDwInt8KernelCreator(const std::vector<lite::Tensor *>
const mindspore::lite::PrimitiveC *primitive) {
MS_ASSERT(opParameter != nullptr);
MS_ASSERT(desc.type == schema::PrimitiveType_DepthwiseConv2D);
kernel::LiteKernel *kernel;
kernel::LiteKernel *kernel = nullptr;
auto act_quant_size =
MSMAX(inputs[kInputIndex]->GetQuantParams().size(), outputs[kOutputIndex]->GetQuantParams().size());
if (act_quant_size == 1) { // per tensor
auto conv_parm = reinterpret_cast<ConvParameter *>(opParameter);
auto channel = inputs[kWeightIndex]->shape()[0];
auto conv_param = reinterpret_cast<ConvParameter *>(opParameter);
if (primitive != nullptr && primitive->GetInferFlag()) {
conv_param->input_h_ = inputs[kInputIndex]->Height();
conv_param->input_w_ = inputs[kInputIndex]->Width();
conv_param->input_channel_ = inputs[kInputIndex]->Channel();
conv_param->output_h_ = outputs[kOutputIndex]->Height();
conv_param->output_w_ = outputs[kOutputIndex]->Width();
}
auto weight_quant_size = inputs[kWeightIndex]->GetQuantParams().size();
if (CheckIfUse3X3(conv_parm, channel) && weight_quant_size == 1) {
if (CheckIfUse3X3(conv_param) && weight_quant_size == 1) {
#ifdef ENABLE_ARM64
kernel =
new (std::nothrow) kernel::ConvolutionDepthwise3x3Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
#else
kernel =
new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
#endif
} else {
}
if (kernel == nullptr) {
kernel =
new (std::nothrow) kernel::ConvolutionDepthwiseInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
}

@ -1,2 +0,0 @@
ssd-10.onnx
efficientnet-lite4-11.onnx
Loading…
Cancel
Save