!5818 [MSLITE][Develop] int8 conv 1x1 support per weight output-channel on x86

Merge pull request !5818 from ling/sr
pull/5818/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 55adbbeae8

File diff suppressed because it is too large Load Diff

@ -54,13 +54,16 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
ConvParameter *conv_param, GEMM_FUNC gemm_func);
// int8 convolution 1x1
void Conv1x1PreOpt(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
size_t output_channel, size_t plane_size, ConvParameter *conv_param);
void Conv1x1PreOptPeroc(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
size_t output_channel, size_t plane_size, int32_t *filter_zp, size_t inputsum_stride);
void Conv1x1PreOptPert(const int8_t *src_input, int8_t *packed_input, int32_t *input_sum, size_t input_channel,
size_t plane_size, ConvParameter *conv_param);
void Conv1x1Int8(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep16, ConvParameter *conv_param);
const int32_t *bias, int row, int col, int deep16, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param);
void Conv1x1Int8Opt(const int8_t *packed_input, const int8_t *packed_weight, int8_t *dst, const int32_t *input_sum,
const int32_t *bias, int row, int col, int deep4, ConvParameter *conv_param,
MATMUL_OPT_R_FUNC matmul_func);
const int32_t *bias, int row, int col, int deep4, int32_t *left_shift, int32_t *right_shift,
int32_t *multiplier, ConvParameter *conv_param, MATMUL_OPT_R_FUNC matmul_func);
// int8 convolution 3x3
void Conv3x3Int8(int16_t *input_data, int16_t *transed_weight, const int32_t *bias_data, int8_t *output_data,

@ -186,8 +186,9 @@ void MatMulInt8_16x4(const int8_t *a, const int8_t *b, int *dst, int row_4, int
void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row, size_t col, size_t deep_16,
size_t stride, const int32_t *input_sum, const int32_t *bias, int32_t *left_shift,
int32_t *right_shift, int32_t *multiplier, int32_t output_zp, int32_t mini, int32_t maxi,
bool per_channel) {
/* row4x16-major * row16x4-major => (int8)row-major : per-channel */
bool peroc) {
/* support per-layer && weight per-channel */
/* row4x16-major * row16x4-major => (int8)row-major*/
for (int r = 0; r < row; r++) {
for (int c = 0; c < col; c++) {
int r4div = r / C4NUM, r4mod = r % C4NUM;
@ -200,12 +201,13 @@ void MatMulInt8_16x4_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row
size_t bi = c4div * deep_16 * C4NUM + d16div * C4NUM * C16NUM + c4mod * C16NUM + d16mod;
value = value + a[ai] * b[bi];
}
int32_t cur_input_sum = per_channel ? input_sum[c4div * UP_ROUND(row, C4NUM) + r * C4NUM + c4mod] : input_sum[r];
int32_t cur_input_sum =
peroc ? input_sum[c4div * UP_ROUND(row, C4NUM) * C4NUM + r * C4NUM + c4mod] : input_sum[r];
value -= cur_input_sum;
value += bias[c];
int32_t cur_left_shift = per_channel ? left_shift[c] : left_shift[0];
int32_t cur_right_shift = per_channel ? right_shift[c] : right_shift[0];
int32_t cur_multiplier = per_channel ? multiplier[c] : multiplier[0];
int32_t cur_left_shift = peroc ? left_shift[c] : left_shift[0];
int32_t cur_right_shift = peroc ? right_shift[c] : right_shift[0];
int32_t cur_multiplier = peroc ? multiplier[c] : multiplier[0];
value = MultiplyByQuantizedMultiplier(value, cur_multiplier, cur_left_shift, cur_right_shift) + output_zp;
value = MSMIN(maxi, value);
value = MSMAX(mini, value);
@ -232,7 +234,8 @@ void MatMulInt8_8x8_r(const int8_t *a, const int8_t *b, int8_t *dst, size_t row,
size_t bi = c8div * deep_4 * C8NUM + d4div * C8NUM * C4NUM + c8mod * C4NUM + d4mod;
value = value + a[ai] * b[bi];
}
int32_t cur_input_sum = per_channel ? input_sum[c8div * UP_ROUND(row, C8NUM) + r * C8NUM + c8mod] : input_sum[r];
int32_t cur_input_sum =
per_channel ? input_sum[c8div * UP_ROUND(row, C8NUM) * C8NUM + r * C8NUM + c8mod] : input_sum[r];
value -= cur_input_sum;
value += bias[c];
int32_t cur_left_shift = per_channel ? left_shift[c] : left_shift[0];

@ -17,6 +17,7 @@
#ifndef MINDSPORE_LITE_NNACL_INT8_MATMUL_H_
#define MINDSPORE_LITE_NNACL_INT8_MATMUL_H_
#include <stdio.h>
#include <string.h>
#include "nnacl/op_base.h"
#include "nnacl/matmul_parameter.h"

@ -188,7 +188,7 @@ void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParam
}
void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp, size_t row4, size_t col16) {
/* optimize normal -> same layout */
/* normal matmul : 4x16 * 16x4 -> 4x4 */
#ifdef ENABLE_ARM64
asm volatile(
"mov x10, %[src] \n"
@ -260,62 +260,158 @@ void PackInputSum16x4PerLayer(const int8_t *src, int32_t *dst, int32_t filter_zp
return;
}
void PackInputSum16x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
size_t plane_size, ConvParameter *conv_param) {
void PackInputSum16x4PerChannel(const int8_t *input_value, int32_t *input_sum, int32_t *filter_zp_ptr,
size_t plane_size, size_t input_channel, size_t output_channel) {
size_t hw4 = UP_ROUND(plane_size, C4NUM);
size_t ic16 = UP_ROUND(input_channel, C16NUM);
if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) {
PackInputSum16x4PerLayer(input_value, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16);
} else {
for (int ri = 0; ri < plane_size; ri++) {
int ri4div = ri / C4NUM, ri4mod = ri % C4NUM;
for (int ci = 0; ci < output_channel; ci++) {
int32_t tmp_sum_value = 0;
int ci4div = ci / C4NUM, ci4mod = ci % C4NUM;
int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[ci].zp_;
for (int di = 0; di < input_channel; di++) {
size_t di16div = di / C16NUM, di16mod = di % C16NUM;
int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod;
tmp_sum_value += input_value[src_index];
}
int dst_index = ci4div * C4NUM * hw4 + ri * C4NUM + ci4mod;
input_sum[dst_index] = tmp_sum_value * filter_zp;
#ifdef ENABLE_ARM64
size_t oc_div4 = output_channel / C4NUM * C4NUM;
size_t oc_res4 = output_channel - oc_div4;
size_t inputsun_stride = hw4 * C4NUM * 4 - C4NUM * C4NUM * 4;
asm volatile(
"mov x10, %[input_value] \n"
"mov x11, %[input_sum] \n"
"mov x15, %[filter_zp_ptr] \n"
"mov x0, #0 \n" // row 4 count
"1: \n"
"cmp x0, %[hw4] \n"
"beq 11f \n"
"add x0, x0, #4\n"
"dup v10.4s, wzr \n"
"mov x2, #0 \n" // input deep count
"mov x16, x15 \n"
"2: \n"
"cmp x2, %[ic16] \n"
"beq 3f \n"
"add x2, x2, #16 \n"
"ld1 {v0.16b}, [x10], #16\n"
"ld1 {v1.16b}, [x10], #16\n"
"ld1 {v2.16b}, [x10], #16\n"
"ld1 {v3.16b}, [x10], #16\n"
"saddlp v4.8h, v0.16b \n"
"saddlp v5.8h, v1.16b \n"
"saddlp v6.8h, v2.16b \n"
"saddlp v7.8h, v3.16b \n"
"saddlp v0.4S, v4.8h \n"
"saddlp v1.4S, v5.8h \n"
"saddlp v2.4S, v6.8h \n"
"saddlp v3.4S, v7.8h \n"
"addv s4, v0.4S \n"
"addv s5, v1.4S \n"
"addv s6, v2.4S \n"
"addv s7, v3.4S \n"
"mov v0.s[0], v4.s[0] \n"
"mov v0.s[1], v5.s[0] \n"
"mov v0.s[2], v6.s[0] \n"
"mov v0.s[3], v7.s[0] \n"
"add v10.4s, v10.4s, v0.4s \n"
"b 2b \n"
"3: \n"
"mov x12, x11 \n" // tmp inputsm inputsum hw
"add x11, x11, #64 \n"
"mov x4, #0 \n" // oc count
"dup v1.4s, v10.s[0] \n"
"dup v2.4s, v10.s[1] \n"
"dup v3.4s, v10.s[2] \n"
"dup v4.4s, v10.s[3] \n"
"4: \n"
"cmp x4, %[oc_div4] \n"
"beq 6f \n"
"add x4, x4, #4\n"
"ld1 {v15.4s}, [x16], #16\n"
"mul v16.4s, v15.4s, v1.4s \n"
"mul v17.4s, v15.4s, v2.4s \n"
"mul v18.4s, v15.4s, v3.4s \n"
"mul v19.4s, v15.4s, v4.4s \n"
"st1 {v16.4s}, [x12], #16 \n"
"st1 {v17.4s}, [x12], #16 \n"
"st1 {v18.4s}, [x12], #16 \n"
"st1 {v19.4s}, [x12], #16 \n"
"add x12, x12, %[inputsun_stride] \n"
"b 4b \n"
"6: \n"
"cmp %[oc_res4], #0\n"
"beq 1b \n"
"dup v15.4s, wzr \n"
"cmp %[oc_res4], #1\n"
"beq 7f \n"
"cmp %[oc_res4], #2\n"
"beq 8f \n"
"cmp %[oc_res4], #3\n"
"beq 9f \n"
"7: \n"
"ld1 {v15.s}[0], [x16] \n"
"b 10f \n"
"8: \n"
"ld1 {v15.h}[0], [x16] \n"
"b 10f \n"
"9: \n"
"ld1 {v15.h}[0], [x16] \n"
"add x16, x16, #8 \n"
"ld1 {v15.s}[2], [x16] \n"
"b 10f \n"
"10: \n"
"mul v16.4s, v15.4s, v1.4s \n"
"mul v17.4s, v15.4s, v2.4s \n"
"mul v18.4s, v15.4s, v3.4s \n"
"mul v19.4s, v15.4s, v4.4s \n"
"st1 {v16.4s}, [x12], #16 \n"
"st1 {v17.4s}, [x12], #16 \n"
"st1 {v18.4s}, [x12], #16 \n"
"st1 {v19.4s}, [x12], #16 \n"
"b 1b \n"
"11: \n"
:
: [ input_value ] "r"(input_value), [ input_sum ] "r"(input_sum), [ filter_zp_ptr ] "r"(filter_zp_ptr),
[ hw4 ] "r"(hw4), [ ic16 ] "r"(ic16), [ oc_div4 ] "r"(oc_div4), [ oc_res4 ] "r"(oc_res4),
[ inputsun_stride ] "r"(inputsun_stride)
: "x0", "x2", "x4", "x10", "x11", "x12", "x15", "x16", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v10", "v15",
"v16", "v17", "v18", "v19");
#else
for (int ri = 0; ri < plane_size; ri++) {
int ri4div = ri / C4NUM, ri4mod = ri % C4NUM;
for (int ci = 0; ci < output_channel; ci++) {
int32_t tmp_sum_value = 0;
int ci4div = ci / C4NUM, ci4mod = ci % C4NUM;
int32_t filter_zp = filter_zp_ptr[ci];
for (int di = 0; di < input_channel; di++) {
size_t di16div = di / C16NUM, di16mod = di % C16NUM;
int src_index = ri4div * C4NUM * ic16 + di16div * C16NUM * C4NUM + ri4mod * C16NUM + di16mod;
tmp_sum_value += input_value[src_index];
}
int dst_index = ci4div * C4NUM * hw4 + ri * C4NUM + ci4mod;
input_sum[dst_index] = tmp_sum_value * filter_zp;
}
}
#endif
return;
}
void PackInputSum8x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
size_t plane_size, ConvParameter *conv_param) {
size_t hw8 = UP_ROUND(plane_size, C8NUM);
size_t ic4 = UP_ROUND(input_channel, C4NUM);
void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, int32_t *filter_zp, ConvParameter *conv_param) {
size_t hw4 = UP_ROUND(conv_param->input_h_ * conv_param->input_w_, C4NUM);
size_t ic16 = UP_ROUND(conv_param->input_channel_, C16NUM);
if (conv_param->conv_quant_arg_.filter_arg_num_ == 1) {
for (int r = 0; r < hw8; r++) {
int32_t tmp_value = 0;
for (int c = 0; c < ic4; c++) {
int r8div = r / C8NUM, r8mod = r % C8NUM, c4div = c / C4NUM, c4mod = c % C4NUM;
int src_index = r8div * C8NUM * ic4 + c4div * C8NUM * C4NUM + r8mod * C4NUM + c4mod;
tmp_value += input_value[src_index];
}
input_sum[r] = tmp_value * conv_param->conv_quant_arg_.filter_quant_args_[0].zp_;
}
PackInputSum16x4PerLayer(input, input_sum, conv_param->conv_quant_arg_.filter_quant_args_[0].zp_, hw4, ic16);
} else {
for (int ri = 0; ri < plane_size; ri++) {
int ri8div = ri / C8NUM, ri8mod = ri % C8NUM;
for (int ci = 0; ci < output_channel; ci++) {
int32_t tmp_sum_value = 0;
int ci8div = ci / C8NUM, ci8mod = ci % C8NUM;
int32_t filter_zp = conv_param->conv_quant_arg_.filter_quant_args_[ci].zp_;
for (int di = 0; di < input_channel; di++) {
size_t di4div = di / C4NUM, di4mod = di % C4NUM;
int src_index = ri8div * C8NUM * ic4 + di4div * C8NUM * C4NUM + ri8mod * C4NUM + di4mod;
tmp_sum_value += input_value[src_index];
}
int dst_index = ci8div * C8NUM * hw8 + ri * C8NUM + ci8mod;
input_sum[dst_index] = tmp_sum_value * filter_zp;
}
}
PackInputSum16x4PerChannel(input, input_sum, filter_zp, conv_param->input_h_ * conv_param->input_w_,
conv_param->input_channel_, conv_param->output_channel_);
}
return;
}

@ -17,6 +17,7 @@
#ifndef MINDSPORE_LITE_NNACL_PACK_H_
#define MINDSPORE_LITE_NNACL_PACK_H_
#include <stdio.h>
#ifdef ENABLE_NEON
#include <arm_neon.h>
#endif
@ -41,8 +42,7 @@ void Conv1x1InputPack(const void *src_ptr, void *dst_ptr, ConvParameter *conv_pa
void Pack1x1WeightFp32(const float *weight_data, float *packed_weight, ConvParameter *conv_param);
void PackInputSum16x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
size_t plane_size, ConvParameter *conv_param);
void PackInputSum16x4Int8(const int8_t *input, int32_t *input_sum, int32_t *filter_zp, ConvParameter *conv_param);
void PackInputSum8x4Int8(const int8_t *input_value, int32_t *input_sum, size_t input_channel, size_t output_channel,
size_t plane_size, ConvParameter *conv_param);

@ -316,14 +316,12 @@ int ConvolutionBaseCPUKernel::SetQuantParam() {
MS_LOG(ERROR) << "Set Quant Multiplier Failed.";
return ret;
}
// now only consider per tensor for output
bool relu = conv_param_->act_type_ == ActType_Relu;
bool relu6 = conv_param_->act_type_ == ActType_Relu6;
CalculateActivationRangeQuantized(relu, relu6, conv_param_->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param_->conv_quant_arg_.output_quant_args_[0].scale_,
&conv_param_->conv_quant_arg_.out_act_min_[0],
&conv_param_->conv_quant_arg_.out_act_max_[0]);
return RET_OK;
}
int ConvolutionBaseCPUKernel::RestoreFilter(lite::Tensor *input_tensor) {

@ -16,6 +16,7 @@
#include "src/runtime/kernel/arm/int8/convolution_1x1_int8.h"
#include "src/runtime/runtime_api.h"
#include "src/common/file_utils.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_MEMORY_FAILED;
@ -41,6 +42,10 @@ Convolution1x1Int8CPUKernel::~Convolution1x1Int8CPUKernel() {
free(packed_weight_);
packed_weight_ = nullptr;
}
if (filter_peroc_ && filter_zp_ptr_ != nullptr) {
free(filter_zp_ptr_);
filter_zp_ptr_ = nullptr;
}
FreeResizeBuf();
FreeQuantParam();
}
@ -54,7 +59,7 @@ void Convolution1x1Int8CPUKernel::FreeResizeBuf() {
}
void Convolution1x1Int8CPUKernel::CheckSupportOptimize() {
support_optimize_ = true;
support_optimize_ = false;
matmul_func_ = MatMulInt8_8x8_r;
#ifdef ENABLE_ARM64
void *optimize_op_handler = OptimizeModule::GetInstance()->optimized_op_handler_;
@ -73,6 +78,10 @@ void Convolution1x1Int8CPUKernel::CheckSupportOptimize() {
support_optimize_ = false;
matmul_func_ = nullptr;
}
if (filter_peroc_) {
support_optimize_ = false;
}
#endif
return;
}
@ -118,14 +127,23 @@ int Convolution1x1Int8CPUKernel::InitWeightBias() {
int32_t input_zp = conv_param_->conv_quant_arg_.input_quant_args_[0].zp_;
for (int oc = 0; oc < output_channel; oc++) {
int32_t weight_sum_value = 0;
int32_t filter_zp = (conv_param_->conv_quant_arg_.filter_arg_num_ == 1)
? conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_
: conv_param_->conv_quant_arg_.filter_quant_args_[oc].zp_;
int32_t filter_zp = (filter_peroc_) ? conv_param_->conv_quant_arg_.filter_quant_args_[oc].zp_
: conv_param_->conv_quant_arg_.filter_quant_args_[0].zp_;
for (int ic = 0; ic < input_channel; ic++) {
weight_sum_value += weight[oc * input_channel + ic];
}
bias_data[oc] += filter_zp * input_zp * input_channel - weight_sum_value * input_zp;
}
if (filter_peroc_) {
filter_zp_ptr_ = reinterpret_cast<int32_t *>(malloc(output_channel * sizeof(int32_t)));
if (filter_zp_ptr_ == nullptr) {
return RET_ERROR;
}
for (int fi = 0; fi < output_channel; fi++) {
filter_zp_ptr_[fi] = conv_param_->conv_quant_arg_.filter_quant_args_[fi].zp_;
}
}
return RET_OK;
}
@ -136,14 +154,16 @@ int Convolution1x1Int8CPUKernel::Init() {
return RET_ERROR;
}
CheckSupportOptimize();
auto ret = SetQuantParam();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Set quant param failed.";
return ret;
}
filter_peroc_ = (conv_param_->conv_quant_arg_.filter_arg_num_ != 1);
CheckSupportOptimize();
ret = InitWeightBias();
if (ret != RET_OK) {
MS_LOG(ERROR) << "Init weight bias failed.";
@ -229,14 +249,17 @@ void Convolution1x1Int8CPUKernel::Pre1x1Trans(int8_t *src_input, int8_t *src_out
ParallelLaunch(THREAD_POOL_DEFAULT, Convolution1x1Int8Pre, this, thread_count_hw_);
} else {
RowMajor2Row16x4MajorInt8(input_ptr_, packed_input_, matmul_param_->row_, matmul_param_->deep_);
PackInputSum16x4Int8(packed_input_, input_sum_, matmul_param_->deep_, matmul_param_->col_, matmul_param_->row_,
conv_param_);
PackInputSum16x4Int8(packed_input_, input_sum_, filter_zp_ptr_, conv_param_);
}
return;
}
int Convolution1x1Int8CPUKernel::RunImpl(int task_id) {
int32_t *cur_input_sum = input_sum_;
int32_t *cur_left_shift = conv_param_->conv_quant_arg_.left_shift_;
int32_t *cur_right_shift = conv_param_->conv_quant_arg_.right_shift_;
int32_t *cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_;
if (support_optimize_) {
int cur_stride = thread_stride_ * C8NUM;
int res_stride = matmul_param_->col_ - task_id * thread_stride_ * C8NUM;
@ -244,10 +267,17 @@ int Convolution1x1Int8CPUKernel::RunImpl(int task_id) {
if (cur_oc <= 0) {
return RET_OK;
}
if (filter_peroc_) {
cur_input_sum = input_sum_ + task_id * matmul_param_->row_8_ * thread_stride_ * C8NUM;
cur_left_shift = conv_param_->conv_quant_arg_.left_shift_ + task_id * thread_stride_ * C8NUM;
cur_right_shift = conv_param_->conv_quant_arg_.right_shift_ + task_id * thread_stride_ * C8NUM;
cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_ + task_id * thread_stride_ * C8NUM;
}
Conv1x1Int8Opt(packed_input_, packed_weight_ + task_id * thread_stride_ * C8NUM * matmul_param_->deep_4_,
output_ptr_ + task_id * thread_stride_ * C8NUM, input_sum_,
output_ptr_ + task_id * thread_stride_ * C8NUM, cur_input_sum,
reinterpret_cast<int32_t *>(bias_data_) + task_id * thread_stride_ * C8NUM, matmul_param_->row_,
cur_oc, matmul_param_->deep_4_, conv_param_, matmul_func_);
cur_oc, matmul_param_->deep_4_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_,
matmul_func_);
} else {
int cur_stride = thread_stride_ * C4NUM;
int res_stride = matmul_param_->col_ - task_id * thread_stride_ * C4NUM;
@ -255,10 +285,16 @@ int Convolution1x1Int8CPUKernel::RunImpl(int task_id) {
if (cur_oc <= 0) {
return RET_OK;
}
if (filter_peroc_) {
cur_input_sum = input_sum_ + task_id * matmul_param_->row_4_ * thread_stride_ * C4NUM;
cur_left_shift = conv_param_->conv_quant_arg_.left_shift_ + task_id * thread_stride_ * C4NUM;
cur_right_shift = conv_param_->conv_quant_arg_.right_shift_ + task_id * thread_stride_ * C4NUM;
cur_multiplier = conv_param_->conv_quant_arg_.quant_multiplier_ + task_id * thread_stride_ * C4NUM;
}
Conv1x1Int8(packed_input_, packed_weight_ + task_id * thread_stride_ * C4NUM * matmul_param_->deep_16_,
output_ptr_ + task_id * thread_stride_ * C4NUM, input_sum_,
output_ptr_ + task_id * thread_stride_ * C4NUM, cur_input_sum,
reinterpret_cast<int32_t *>(bias_data_) + task_id * thread_stride_ * C4NUM, matmul_param_->row_, cur_oc,
matmul_param_->deep_16_, conv_param_);
matmul_param_->deep_16_, cur_left_shift, cur_right_shift, cur_multiplier, conv_param_);
}
return RET_OK;
}
@ -270,10 +306,18 @@ int Convolution1x1Int8CPUKernel::RunPre(int task_id) {
if (cur_hw <= 0) {
return RET_OK;
}
Conv1x1PreOpt(input_ptr_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_,
packed_input_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_4_,
input_sum_ + task_id * thread_stride_hw_ * C8NUM, matmul_param_->deep_, matmul_param_->col_, cur_hw,
conv_param_);
if (filter_peroc_) {
Conv1x1PreOptPeroc(input_ptr_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_,
packed_input_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_4_,
input_sum_ + task_id * thread_stride_hw_ * C8NUM * C8NUM, matmul_param_->deep_,
matmul_param_->col_, cur_hw, filter_zp_ptr_, matmul_param_->row_8_ * C8NUM);
} else {
Conv1x1PreOptPert(input_ptr_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_,
packed_input_ + task_id * thread_stride_hw_ * C8NUM * matmul_param_->deep_4_,
input_sum_ + task_id * thread_stride_hw_ * C8NUM, matmul_param_->deep_, cur_hw, conv_param_);
}
return RET_OK;
}

@ -56,7 +56,8 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel {
void CheckSupportOptimize();
private:
int32_t *input_sum_ = nullptr; /* per-channel: oc4 format */
int32_t *input_sum_ = nullptr; /* per-channel: oc4 format */
int32_t *filter_zp_ptr_ = nullptr; /* oc - per - channel */
int8_t *packed_weight_ = nullptr;
int8_t *packed_input_ = nullptr;
int8_t *input_ptr_ = nullptr;
@ -70,6 +71,7 @@ class Convolution1x1Int8CPUKernel : public ConvolutionBaseCPUKernel {
MatMulParameter *matmul_param_ = nullptr;
MATMUL_OPT_R_FUNC matmul_func_ = nullptr;
bool support_optimize_ = false;
bool filter_peroc_ = false;
};
} // namespace mindspore::kernel

@ -397,10 +397,9 @@ kernel::LiteKernel *CpuConvInt8KernelCreator(const std::vector<lite::Tensor *> &
int dilation_h = conv_param->dilation_h_;
int dilation_w = conv_param->dilation_w_;
kernel::LiteKernel *kernel;
auto filter_quant_size = inputs[kWeightIndex]->GetQuantParams().size();
if (kernel_h == 3 && kernel_w == 3 && stride_h == 1 && stride_w == 1 && dilation_h == 1 && dilation_w == 1) {
kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
} else if (kernel_h == 1 && kernel_w == 1 && filter_quant_size == 1) {
} else if (kernel_h == 1 && kernel_w == 1) {
kernel = new (std::nothrow) kernel::Convolution1x1Int8CPUKernel(opParameter, inputs, outputs, ctx, primitive);
} else {
kernel = new (std::nothrow) kernel::ConvolutionInt8CPUKernel(opParameter, inputs, outputs, ctx, primitive);

Loading…
Cancel
Save