Add int8 conv weight per channel

pull/4432/head
fuzhiye 5 years ago
parent c7b50bcdd2
commit 72953307d9

@ -32,6 +32,7 @@
using mindspore::lite::Context;
using mindspore::schema::PadMode;
using mindspore::schema::QuantType;
static constexpr int kPerTensor = 1;
namespace mindspore::kernel {
class ConvolutionBaseCPUKernel : public LiteKernel {
@ -49,7 +50,14 @@ class ConvolutionBaseCPUKernel : public LiteKernel {
int ReSize() override { return 0; }
int Run() override { return 0; }
virtual int CheckLayout(lite::tensor::Tensor *input_tensor);
int SetIfAsymmetric();
int SetIfPerChannel();
int MallocQuantParam();
int SetQuantParam();
int SetInputTensorQuantParam();
int SetFilterTensorQuantParam();
int SetOutputTensorQuantParam();
int SetQuantMultiplier();
void FreeQuantParam();
protected:
@ -59,9 +67,9 @@ class ConvolutionBaseCPUKernel : public LiteKernel {
void *nhwc4_input_ = nullptr;
const Context *ctx_;
ConvParameter *conv_param_;
ConvQuantArg *conv_quant_arg_;
LayoutConvertor convert_func_;
};
bool CheckSupportFP16();
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_CONVOLUTION_BASE_H_

@ -69,8 +69,8 @@ int ConvolutionInt8CPUKernel::InitWeightBias() {
int kernel_plane = kernel_h * kernel_w;
int plane_c4 = UP_DIV(kernel_plane, C4NUM);
int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * plane_c4 * C4NUM;
int32_t filter_zp = conv_param_->conv_quant_arg_.quant_args_[1][0].zp_;
int32_t input_zp = conv_param_->conv_quant_arg_.quant_args_[0][0].zp_;
auto filter_arg = conv_param_->conv_quant_arg_.filter_quant_args_;
int32_t input_zp = conv_param_->conv_quant_arg_.input_quant_args_[0].zp_;
// init weight
auto origin_weight = reinterpret_cast<int8_t *>(in_tensors_.at(kWeightIndex)->Data());
@ -99,8 +99,14 @@ int ConvolutionInt8CPUKernel::InitWeightBias() {
}
auto *bias_data = reinterpret_cast<int32_t *>(bias_data_);
int c4_kernel_plane_size = kernel_plane * ic4 * C4NUM;
for (int i = 0; i < out_channel; i++) {
bias_data[i] += filter_zp * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp;
if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
for (int i = 0; i < out_channel; i++) {
bias_data[i] += filter_arg[i].zp_ * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp;
}
} else {
for (int i = 0; i < out_channel; i++) {
bias_data[i] += filter_arg[0].zp_ * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp;
}
}
free(weight_sum);
return RET_OK;
@ -125,7 +131,13 @@ int ConvolutionInt8CPUKernel::InitTmpBuffer() {
memset(packed_input_, 0, conv_param_->input_batch_ * packed_input_size);
/*=============================input_sum_============================*/
input_sum_ = reinterpret_cast<int32_t *>(malloc(tile_num_ * thread_count_ * sizeof(int32_t)));
size_t input_sum_size;
if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
input_sum_size = conv_param_->output_channel_ * tile_num_ * thread_count_ * sizeof(int32_t);
} else {
input_sum_size = tile_num_ * thread_count_ * sizeof(int32_t);
}
input_sum_ = reinterpret_cast<int32_t *>(malloc(input_sum_size));
if (input_sum_ == nullptr) {
MS_LOG(ERROR) << "malloc input_sum_ failed.";
return RET_ERROR;
@ -168,8 +180,8 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() {
int oc4 = UP_DIV(out_channel, C4NUM);
int kernel_plane = kernel_h * kernel_w;
int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * kernel_plane;
int32_t filter_zp = conv_param_->conv_quant_arg_.quant_args_[1][0].zp_;
int32_t input_zp = conv_param_->conv_quant_arg_.quant_args_[0][0].zp_;
auto filter_arg = conv_param_->conv_quant_arg_.filter_quant_args_;
int32_t input_zp = conv_param_->conv_quant_arg_.input_quant_args_[0].zp_;
// init weight
auto origin_weight = reinterpret_cast<int8_t *>(in_tensors_.at(kWeightIndex)->Data());
@ -178,9 +190,9 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() {
MS_LOG(ERROR) << "malloc packed_weight_ failed.";
return RET_ERROR;
}
memset(packed_weight_, filter_zp, pack_weight_size);
memset(packed_weight_, 0, pack_weight_size);
auto *weight_sum = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t) * out_channel));
for (int i = 0; i < out_channel; i++) weight_sum[i] = filter_zp * ic4 * C4NUM * kernel_plane;
for (int i = 0; i < out_channel; i++) weight_sum[i] = 0;
PackWeightInt8Opt(origin_weight, conv_param_, packed_weight_, weight_sum);
// init bias
@ -198,8 +210,14 @@ int ConvolutionInt8CPUKernel::InitWeightBiasOpt() {
}
auto *bias_data = reinterpret_cast<int32_t *>(bias_data_);
int c4_kernel_plane_size = kernel_plane * ic4 * C4NUM;
for (int i = 0; i < out_channel; i++) {
bias_data[i] += filter_zp * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp;
if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
for (int i = 0; i < out_channel; i++) {
bias_data[i] += filter_arg[i].zp_ * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp;
}
} else {
for (int i = 0; i < out_channel; i++) {
bias_data[i] += filter_arg[0].zp_ * input_zp * c4_kernel_plane_size - weight_sum[i] * input_zp;
}
}
free(weight_sum);
return RET_OK;
@ -223,7 +241,13 @@ int ConvolutionInt8CPUKernel::InitTmpBufferOpt() {
memset(packed_input_, 0, conv_param_->input_batch_ * packed_input_size);
/*=============================input_sum_============================*/
input_sum_ = reinterpret_cast<int32_t *>(malloc(tile_num_ * thread_count_ * sizeof(int32_t)));
size_t input_sum_size;
if (conv_quant_arg_->per_channel_ & FILTER_PER_CHANNEL) {
input_sum_size = conv_param_->output_channel_ * tile_num_ * thread_count_ * sizeof(int32_t);
} else {
input_sum_size = tile_num_ * thread_count_ * sizeof(int32_t);
}
input_sum_ = reinterpret_cast<int32_t *>(malloc(input_sum_size));
if (input_sum_ == nullptr) {
MS_LOG(ERROR) << "malloc input_sum_ failed.";
return RET_ERROR;

@ -77,7 +77,7 @@ void DepthwiseBorderInt8(int8_t *dst, const int16_t *src, const int16_t *weight,
dst_kernel, src_kernel, weight_kernel, bias, end_kh - start_kh, end_kw - start_kw, sliding->in_kh_step_,
sliding->in_kw_step_, conv_param->kernel_w_, conv_param->conv_quant_arg_.quant_multiplier_[0],
conv_param->conv_quant_arg_.left_shift_[0], conv_param->conv_quant_arg_.right_shift_[0],
conv_param->conv_quant_arg_.quant_args_[2][0].zp_, conv_param->conv_quant_arg_.out_act_min_[0],
conv_param->conv_quant_arg_.output_quant_args_[0].zp_, conv_param->conv_quant_arg_.out_act_min_[0],
conv_param->conv_quant_arg_.out_act_max_[0]);
dst_kernel += sliding->block_channel_;
@ -168,15 +168,15 @@ void ConvDwInt8(int8_t *output_data, const int16_t *input_data, const int16_t *w
sliding->in_sw_step_ * sizeof(int16_t), sliding->in_kh_step_ * sizeof(int16_t),
sliding->in_kw_step_ * sizeof(int16_t), conv_param->conv_quant_arg_.quant_multiplier_[0],
conv_param->conv_quant_arg_.left_shift_[0], conv_param->conv_quant_arg_.right_shift_[0],
conv_param->conv_quant_arg_.quant_args_[2][0].zp_, conv_param->conv_quant_arg_.out_act_min_[0],
conv_param->conv_quant_arg_.out_act_max_[0]);
conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]);
#else
DepthwiseCenterInt8(
out_t, in_t, weight, bias, sliding->bottom_ - sliding->top_, sliding->right_ - sliding->left_,
conv_param->kernel_h_, conv_param->kernel_w_, sliding->out_h_step_, sliding->block_channel_,
sliding->in_sh_step_, sliding->in_sw_step_, sliding->in_kh_step_, sliding->in_kw_step_,
conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0],
conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.quant_args_[2][0].zp_,
conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]);
#endif
}
@ -333,7 +333,7 @@ void DeconvDwInt8(int8_t *output_data, int32_t *output_buffer, const int16_t *in
DeconvDepthwisePostFuncInt8(
dst_data, output_buffer, bias, sliding->block_channel_, conv_param,
conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0],
conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.quant_args_[2][0].zp_,
conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]);
} // output C4 loop
src += sliding->in_step_;

@ -22,10 +22,10 @@
void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
int ic4, size_t kernel_plane, size_t output_channel, const int32_t *input_sum,
ConvParameter *conv_param) {
int32_t shift_before = conv_param->conv_quant_arg_.left_shift_[0];
int32_t shift_after = conv_param->conv_quant_arg_.right_shift_[0];
int32_t out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_[0];
int32_t out_zp = conv_param->conv_quant_arg_.quant_args_[2][0].zp_;
int32_t *shift_before = conv_param->conv_quant_arg_.left_shift_;
int32_t *shift_after = conv_param->conv_quant_arg_.right_shift_;
int32_t *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_;
int32_t out_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_;
int32_t act_min = conv_param->conv_quant_arg_.out_act_min_[0];
int32_t act_max = conv_param->conv_quant_arg_.out_act_max_[0];
#ifdef __aarch64__
@ -63,14 +63,49 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const in
} // in c4num loop
} // ic4 loop
} // kernel_plane loop
tmp_dst[dst_tile_offset] -= input_sum[n];
int result = tmp_dst[dst_tile_offset] + bias[oc];
result = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before), out_multiplier), -shift_after);
result += out_zp;
result = result > act_min ? result : act_min;
result = result < act_max ? result : act_max;
dst[dst_tile_offset] = (int8_t)result;
if (!(conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
int result = tmp_dst[dst_tile_offset] + bias[oc];
result = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[oc]), out_multiplier[oc]),
-shift_after[oc]);
result += out_zp;
result = result > act_min ? result : act_min;
result = result < act_max ? result : act_max;
dst[dst_tile_offset] = (int8_t)result;
} else if (!(conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
!(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
int result = tmp_dst[dst_tile_offset] + bias[oc];
result = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[0]), out_multiplier[0]),
-shift_after[0]);
result += out_zp;
result = result > act_min ? result : act_min;
result = result < act_max ? result : act_max;
dst[dst_tile_offset] = (int8_t)result;
} else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
!(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
tmp_dst[dst_tile_offset] -= input_sum[n];
int result = tmp_dst[dst_tile_offset] + bias[oc];
result = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[0]), out_multiplier[0]),
-shift_after[0]);
result += out_zp;
result = result > act_min ? result : act_min;
result = result < act_max ? result : act_max;
dst[dst_tile_offset] = (int8_t)result;
} else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
tmp_dst[dst_tile_offset] -= input_sum[n * output_channel + oc];
int result = tmp_dst[dst_tile_offset] + bias[oc];
result = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[oc]), out_multiplier[oc]),
-shift_after[oc]);
result += out_zp;
result = result > act_min ? result : act_min;
result = result < act_max ? result : act_max;
dst[dst_tile_offset] = (int8_t)result;
}
} // tile_num loop
} // output_channel loop
#endif
@ -79,10 +114,10 @@ void IndirectGemmInt8(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const in
void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const int8_t *weight, const int32_t *bias,
int ic4, size_t kernel_plane, size_t output_channel, const int32_t *input_sum,
ConvParameter *conv_param, GEMM_FUNC gemm_func) {
int32_t shift_before = conv_param->conv_quant_arg_.left_shift_[0];
int32_t shift_after = conv_param->conv_quant_arg_.right_shift_[0];
int32_t out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_[0];
int32_t out_zp = conv_param->conv_quant_arg_.quant_args_[2][0].zp_;
int32_t *shift_before = conv_param->conv_quant_arg_.left_shift_;
int32_t *shift_after = conv_param->conv_quant_arg_.right_shift_;
int32_t *out_multiplier = conv_param->conv_quant_arg_.quant_multiplier_;
int32_t out_zp = conv_param->conv_quant_arg_.output_quant_args_[0].zp_;
int32_t act_min = conv_param->conv_quant_arg_.out_act_min_[0];
int32_t act_max = conv_param->conv_quant_arg_.out_act_max_[0];
if (gemm_func != NULL) {
@ -113,14 +148,49 @@ void IndirectGemmInt8Opt(int8_t *dst, int32_t *tmp_dst, const int8_t *src, const
} // in c4num loop
} // ic4 loop
} // kernel_plane loop
tmp_dst[dst_tile_offset] -= input_sum[n];
int result = tmp_dst[dst_tile_offset] + bias[oc];
result = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before), out_multiplier), -shift_after);
result += out_zp;
result = result > act_min ? result : act_min;
result = result < act_max ? result : act_max;
dst[dst_tile_offset] = (int8_t)result;
if (!(conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
int result = tmp_dst[dst_tile_offset] + bias[oc];
result = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[oc]), out_multiplier[oc]),
-shift_after[oc]);
result += out_zp;
result = result > act_min ? result : act_min;
result = result < act_max ? result : act_max;
dst[dst_tile_offset] = (int8_t)result;
} else if (!(conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
!(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
int result = tmp_dst[dst_tile_offset] + bias[oc];
result = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[0]), out_multiplier[0]),
-shift_after[0]);
result += out_zp;
result = result > act_min ? result : act_min;
result = result < act_max ? result : act_max;
dst[dst_tile_offset] = (int8_t)result;
} else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
!(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
tmp_dst[dst_tile_offset] -= input_sum[n];
int result = tmp_dst[dst_tile_offset] + bias[oc];
result = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[0]), out_multiplier[0]),
-shift_after[0]);
result += out_zp;
result = result > act_min ? result : act_min;
result = result < act_max ? result : act_max;
dst[dst_tile_offset] = (int8_t)result;
} else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
tmp_dst[dst_tile_offset] -= input_sum[n * output_channel + oc];
int result = tmp_dst[dst_tile_offset] + bias[oc];
result = RoundingDivideByPOT(
SaturatingRoundingDoublingHighMul(result * (1 << (unsigned int)shift_before[oc]), out_multiplier[oc]),
-shift_after[oc]);
result += out_zp;
result = result > act_min ? result : act_min;
result = result < act_max ? result : act_max;
dst[dst_tile_offset] = (int8_t)result;
}
} // tile_num loop
} // output_channel loop
}
@ -182,7 +252,7 @@ void ConvInt8(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight, c
int out_h = conv_param->output_h_;
int out_w = conv_param->output_w_;
int out_channel = conv_param->output_channel_;
int32_t input_zp = conv_param->conv_quant_arg_.quant_args_[0][0].zp_;
int32_t input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_;
int tile_n = conv_param->tile_num_;
int thread_count = conv_param->thread_num_;
@ -238,7 +308,7 @@ void ConvInt8Opt(int8_t *input_data, int8_t *packed_input, int8_t *packed_weight
int out_h = conv_param->output_h_;
int out_w = conv_param->output_w_;
int out_channel = conv_param->output_channel_;
int32_t input_zp = conv_param->conv_quant_arg_.quant_args_[0][0].zp_;
int32_t input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_;
int tile_n = conv_param->tile_num_;
int thread_count = conv_param->thread_num_;
int output_count = out_h * out_w;

@ -19,8 +19,8 @@
/**
 * Int8 deconvolution core: multiplies the packed (row8-aligned) input matrix by
 * the packed (col8-aligned) weight matrix into the int32 output accumulator.
 *
 * @param input      packed int8 input, row8 x deep
 * @param weight     packed int8 weight, deep x col8
 * @param output     int32 accumulator, row8 x col8 (written by MatMulInt8)
 * @param row8       row count, rounded up to a multiple of 8
 * @param col8       column count, rounded up to a multiple of 8
 * @param deep       reduction (inner) dimension
 * @param conv_param carries the per-tensor quant args; only the input and
 *                   filter zero points are consumed here
 * @return NNACL_OK always
 */
int DeConvInt8(const int8_t *input, const int8_t *weight, int32_t *output, size_t row8, size_t col8, size_t deep,
               ConvParameter *conv_param) {
  // Zero points come from the named per-tensor arrays (input_quant_args_ /
  // filter_quant_args_). The duplicated legacy call that still read the removed
  // quant_args_[i][0] layout has been dropped: it referenced a deleted member
  // and would have executed the matmul twice.
  MatMulInt8(input, weight, output, row8, col8, deep, conv_param->conv_quant_arg_.input_quant_args_[0].zp_,
             conv_param->conv_quant_arg_.filter_quant_args_[0].zp_);
  return NNACL_OK;
}
@ -65,7 +65,7 @@ int DeConvPostInt8(const int32_t *src, const int32_t *bias, int32_t *tmp, int8_t
PostFuncInt8(tmp, bias, out, output_channel, output_plane, UP_ROUND(output_plane, 8),
conv_param->conv_quant_arg_.quant_multiplier_[0], conv_param->conv_quant_arg_.left_shift_[0],
conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.quant_args_[2][0].zp_,
conv_param->conv_quant_arg_.right_shift_[0], conv_param->conv_quant_arg_.output_quant_args_[0].zp_,
conv_param->conv_quant_arg_.out_act_min_[0], conv_param->conv_quant_arg_.out_act_max_[0]);
return NNACL_OK;
}

@ -115,7 +115,6 @@ void PackWeightInt8Opt(int8_t *weight_data, ConvParameter *conv_param, int8_t *p
int oc4 = UP_DIV(out_channel, C4NUM);
int ic4 = UP_DIV(in_channel, C4NUM);
int kernel_plane = kernel_h * kernel_w;
int32_t filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_;
int pack_weight_size = oc4 * ic4 * C4NUM * C4NUM * kernel_plane;
int unit_size = C4NUM * C4NUM;
int block_size = pack_weight_size / oc4;
@ -143,7 +142,7 @@ void PackWeightInt8Opt(int8_t *weight_data, ConvParameter *conv_param, int8_t *p
if (packed_data_ptr[0] == -128) {
packed_data_ptr[0] = -127;
}
weight_sum[j * C4NUM + k] += (int32_t)(packed_data_ptr[0] - filter_zp);
weight_sum[j * C4NUM + k] += (int32_t)(packed_data_ptr[0]);
}
} // kernel block loop
} // inchannel block loop
@ -241,7 +240,7 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real
int32_t *input_sum, ConvParameter *conv_param) {
// input format : nhwc
int tile_num = conv_param->tile_num_;
int32_t filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_;
QuantArg *filter_arg = conv_param->conv_quant_arg_.filter_quant_args_;
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
@ -292,7 +291,18 @@ void Im2ColPackUnitInt8(const int8_t *input_data, int8_t *packed_input, int real
} // channel_block loop
} // kernel_w loop
} // kernel_h loop
input_sum[i] = input_accumulator * filter_zp;
if (!(conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC)) {
return;
} else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
int cal_num_offset = i * conv_param->output_channel_;
for (int l = 0; l < conv_param->output_channel_; ++l) {
input_sum[cal_num_offset + l] = input_accumulator * filter_arg[i].zp_;
}
} else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
!(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
input_sum[i] = input_accumulator * filter_arg[0].zp_;
}
} // tile num loop
}
@ -300,7 +310,7 @@ void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int r
int32_t *input_sum, ConvParameter *conv_param) {
// input format : nhwc
int tile_num = conv_param->tile_num_;
int32_t filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_;
QuantArg *filter_arg = conv_param->conv_quant_arg_.filter_quant_args_;
int kernel_h = conv_param->kernel_h_;
int kernel_w = conv_param->kernel_w_;
int stride_h = conv_param->stride_h_;
@ -348,13 +358,23 @@ void Im2ColPackUnitInt8Opt(const int8_t *input_data, int8_t *packed_input, int r
int block_offset = j * tile_num * ic4 * C4NUM + i * C4NUM;
for (int c = 0; c < ic4; c++) {
int ic4_offset = block_offset + c * tile_num * C4NUM;
input_accumulator += (packed_input + ic4_offset)[0];
input_accumulator += (packed_input + ic4_offset)[1];
input_accumulator += (packed_input + ic4_offset)[2];
input_accumulator += (packed_input + ic4_offset)[3];
for (int k = 0; k < C4NUM; ++k) {
input_accumulator += (packed_input + ic4_offset)[k];
}
}
}
if (!(conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC)) {
return;
} else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
int cal_num_offset = i * conv_param->output_channel_;
for (int l = 0; l < conv_param->output_channel_; ++l) {
input_sum[cal_num_offset + l] = input_accumulator * filter_arg[i].zp_;
}
} else if ((conv_param->conv_quant_arg_.asymmetric_ & FILTER_ASYMMETRIC) &&
!(conv_param->conv_quant_arg_.per_channel_ & FILTER_PER_CHANNEL)) {
input_sum[i] = input_accumulator * filter_arg[0].zp_;
}
input_sum[i] = input_accumulator * filter_zp;
} // tile num loop
}
@ -387,7 +407,7 @@ void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight
int input_channel = conv_param->input_channel_;
int ic8 = UP_DIV(input_channel, C8NUM);
int output_channel = conv_param->output_channel_;
int filter_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_;
QuantArg *filter_zp = conv_param->conv_quant_arg_.filter_quant_args_;
int kernel_plane = conv_param->kernel_h_ * conv_param->kernel_w_;
for (int k = 0; k < kernel_plane; k++) {
@ -401,7 +421,7 @@ void PackWeightToC8Int8(const int8_t *origin_weight_data, int16_t *packed_weight
int c8_block_rem = i % C8NUM;
int src_ic_offset = src_oc_offset + i;
int dst_ic_offset = dst_oc_offset + c8_block_num * kernel_plane * C8NUM + c8_block_rem;
(packed_weight_data + dst_ic_offset)[0] = (int16_t)((origin_weight_data + src_ic_offset)[0] - filter_zp);
(packed_weight_data + dst_ic_offset)[0] = (int16_t)((origin_weight_data + src_ic_offset)[0] - filter_zp[o].zp_);
}
}
}
@ -806,7 +826,7 @@ void MatrixPack(const float *src, float *dst, int row, int ic4, int stride) {
}
void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter *conv_param) {
int input_zp = conv_param->conv_quant_arg_.quant_args_[0][0].zp_;
int input_zp = conv_param->conv_quant_arg_.input_quant_args_[0].zp_;
int ic4 = UP_DIV(conv_param->input_channel_, C4NUM);
int unit = conv_param->input_h_ * conv_param->input_w_;
@ -824,7 +844,7 @@ void PackDepthwiseInt8Input(const int8_t *src, int16_t *dst, const ConvParameter
}
void PackDepthwiseInt8Weight(const int8_t *origin_weight, int16_t *packed_weight_, const ConvParameter *conv_param) {
int weight_zp = conv_param->conv_quant_arg_.quant_args_[1][0].zp_;
int weight_zp = conv_param->conv_quant_arg_.filter_quant_args_[0].zp_;
int unit = conv_param->kernel_h_ * conv_param->kernel_w_;
for (int c = 0; c < conv_param->output_channel_; c++) {
int c4_block_num = c / C4NUM;

@ -17,25 +17,37 @@
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_QUANTIZATION_QUANTIZE_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_NNACL_QUANTIZATION_QUANTIZE_H_
#include <stdint.h>
#include <math.h>
#include <stdlib.h>
#include <limits.h>
#include "nnacl/op_base.h"
#define INPUT_ASYMMETRIC 0b001
#define FILTER_ASYMMETRIC 0b010
#define OUTPUT_ASYMMETRIC 0b100
#define INPUT_PER_CHANNEL 0b001
#define FILTER_PER_CHANNEL 0b010
#define OUTPUT_PER_CHANNEL 0b100
// Quantization parameters for one tensor (or, in per-channel mode, one
// channel): the affine mapping is real_value = scale_ * (quantized - zp_).
typedef struct QuantArg {
  double scale_;  // quantization scale
  int32_t zp_;    // zero point
} QuantArg;
// Aggregated quantization state for a convolution: per-tensor (or per-channel)
// QuantArg arrays for input/filter/output, plus the precomputed fixed-point
// requantization parameters (multiplier and shifts) and output activation
// clamp range.
typedef struct ConvQuantArg {
  // NOTE(review): legacy indexed layout ([0]=input, [1]=filter, [2]=output);
  // the named arrays below supersede it — confirm remaining users before
  // removing this member.
  QuantArg **quant_args_;
  QuantArg *input_quant_args_;   // input tensor quant args (length input_arg_num_)
  QuantArg *filter_quant_args_;  // filter quant args; length > 1 when per-channel
  QuantArg *output_quant_args_;  // output tensor quant args (length output_arg_num_)
  double *real_multiplier_;      // float requant multiplier(s), source of the fields below
  int32_t *left_shift_;          // fixed-point requant: pre-multiply left shift(s)
  int32_t *right_shift_;         // fixed-point requant: rounding right shift(s)
  int32_t *quant_multiplier_;    // fixed-point requant: int32 multiplier(s)
  int32_t *out_act_min_;         // output activation clamp lower bound
  int32_t *out_act_max_;         // output activation clamp upper bound
  size_t input_arg_num_;
  size_t filter_arg_num_;        // kPerTensor (1) unless filter is per-channel
  size_t output_arg_num_;
  uint8_t asymmetric_;           // bitmask of *_ASYMMETRIC flags (input/filter/output)
  uint8_t per_channel_;          // bitmask of *_PER_CHANNEL flags (input/filter/output)
} ConvQuantArg;
typedef struct ConcatQuantArg {

@ -65,7 +65,7 @@ void Conv3x3Int8FilterTransform(const int16_t *weight_data, int16_t *trans_weigh
int kernel_plane);
void Conv3x3Uint8OutputUnit(const int32_t *gemm_out, const int32_t *bias_data, int8_t *output_data, bool h_not_bound,
bool w_not_bound, int output_w, int real_num, ConvParameter *conv_param);
bool w_not_bound, int output_w, int real_num, int oc_start, ConvParameter *conv_param);
void Conv3x3Uint8OutputTransform(const int32_t *gemm_out, int8_t *out_data, const int32_t *bias_data, int start_index,
int real_cal_num, int out_w_block, ConvParameter *conv_param);

Loading…
Cancel
Save