|
|
|
@ -72,21 +72,21 @@ int MatMulInt8Coder::ReSize(CoderContext *const context) {
|
|
|
|
|
|
|
|
|
|
std::vector<QuantArg> params = input_tensor_->quant_params();
|
|
|
|
|
MS_CHECK_TRUE(!params.empty(), "params is empty");
|
|
|
|
|
quant_params_.input.zp_ = params.front().zeroPoint;
|
|
|
|
|
quant_params_.input.scale_ = static_cast<float>(params.front().scale);
|
|
|
|
|
quant_params_.input_.zp_ = params.front().zeroPoint;
|
|
|
|
|
quant_params_.input_.scale_ = static_cast<float>(params.front().scale);
|
|
|
|
|
|
|
|
|
|
params = filter_tensor_->quant_params();
|
|
|
|
|
MS_CHECK_TRUE(!params.empty(), "params is empty");
|
|
|
|
|
quant_params_.weight.zp_ = params.front().zeroPoint;
|
|
|
|
|
quant_params_.weight.scale_ = static_cast<float>(params.front().scale);
|
|
|
|
|
quant_params_.weight_.zp_ = params.front().zeroPoint;
|
|
|
|
|
quant_params_.weight_.scale_ = static_cast<float>(params.front().scale);
|
|
|
|
|
|
|
|
|
|
params = output_tensor_->quant_params();
|
|
|
|
|
MS_CHECK_TRUE(!params.empty(), "params is empty");
|
|
|
|
|
quant_params_.output.zp_ = params.front().zeroPoint;
|
|
|
|
|
quant_params_.output.scale_ = static_cast<float>(params.front().scale);
|
|
|
|
|
double real_multiplier = quant_params_.input.scale_ * quant_params_.weight.scale_ / quant_params_.output.scale_;
|
|
|
|
|
QuantizeRoundParameterWithDoublePrecision(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift,
|
|
|
|
|
&quant_params_.right_shift);
|
|
|
|
|
quant_params_.output_.zp_ = params.front().zeroPoint;
|
|
|
|
|
quant_params_.output_.scale_ = static_cast<float>(params.front().scale);
|
|
|
|
|
double real_multiplier = quant_params_.input_.scale_ * quant_params_.weight_.scale_ / quant_params_.output_.scale_;
|
|
|
|
|
QuantizeRoundParameterWithDoublePrecision(real_multiplier, quant_params_.quant_multiplier_, quant_params_.left_shift_,
|
|
|
|
|
quant_params_.right_shift_);
|
|
|
|
|
if (params_->b_const_) {
|
|
|
|
|
NNaclInt8Serializer init_code;
|
|
|
|
|
if (bias_ptr_) {
|
|
|
|
@ -99,15 +99,31 @@ int MatMulInt8Coder::ReSize(CoderContext *const context) {
|
|
|
|
|
init_code.CodeMallocExpression(b_pack_batch_ptr_, b_pack_batch_ptr_size_);
|
|
|
|
|
init_code.CodeFunction("memset", b_pack_batch_ptr_, 0, b_pack_batch_ptr_size_);
|
|
|
|
|
|
|
|
|
|
init_code << "int tmp_weight_zp = " << quant_params_.weight.zp_ << ";\n";
|
|
|
|
|
init_code << "int tmp_weight_zp = " << quant_params_.weight_.zp_ << ";\n";
|
|
|
|
|
init_code.CodeFunction("InitIn8MatrixB", filter_tensor_->data_c(), weight_bias_sums_batch_, b_pack_batch_ptr_,
|
|
|
|
|
params_->batch, params_->deep_, params_->col_, params_->col_4_, params_->deep_16_,
|
|
|
|
|
quant_params_.input.zp_, "&tmp_weight_zp", bias_ptr_, params_->b_transpose_);
|
|
|
|
|
quant_params_.input_.zp_, "&tmp_weight_zp", bias_ptr_, params_->b_transpose_);
|
|
|
|
|
context->AppendInitCode(init_code.str());
|
|
|
|
|
}
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Destructor: releases the three quantization buffers held in quant_params_
// (quant_multiplier_, right_shift_, left_shift_). They are released with
// free(), matching the malloc() calls in Init().
MatMulInt8Coder::~MatMulInt8Coder() {
  // Frees one buffer if it is set and resets the member to nullptr, so a
  // buffer that was never allocated is left untouched.
  auto release_buffer = [](auto *&buffer) {
    if (buffer == nullptr) {
      return;
    }
    free(buffer);
    buffer = nullptr;
  };

  // Same release order as before: multiplier, right shift, left shift.
  release_buffer(quant_params_.quant_multiplier_);
  release_buffer(quant_params_.right_shift_);
  release_buffer(quant_params_.left_shift_);
}
|
|
|
|
|
|
|
|
|
|
int MatMulInt8Coder::Init() {
|
|
|
|
|
params_ = reinterpret_cast<MatMulParameter *>(parameter_);
|
|
|
|
|
filter_tensor_ = input_tensors_.at(kWeightIndex);
|
|
|
|
@ -118,6 +134,14 @@ int MatMulInt8Coder::Init() {
|
|
|
|
|
MS_CHECK_PTR(bias_tensor_->data_c());
|
|
|
|
|
}
|
|
|
|
|
params_->b_const_ = (filter_tensor_->data_c() != nullptr);
|
|
|
|
|
|
|
|
|
|
quant_params_.quant_multiplier_ = reinterpret_cast<int32_t *>(malloc(1 * sizeof(int32_t)));
|
|
|
|
|
MS_CHECK_PTR(quant_params_.quant_multiplier_);
|
|
|
|
|
quant_params_.left_shift_ = reinterpret_cast<int32_t *>(malloc(1 * sizeof(int32_t)));
|
|
|
|
|
MS_CHECK_PTR(quant_params_.left_shift_);
|
|
|
|
|
quant_params_.right_shift_ = reinterpret_cast<int32_t *>(malloc(1 * sizeof(int32_t)));
|
|
|
|
|
MS_CHECK_PTR(quant_params_.right_shift_);
|
|
|
|
|
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -143,11 +167,11 @@ int MatMulInt8Coder::DoCode(CoderContext *const context) {
|
|
|
|
|
if (cur_oc <= 0) {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
code << "int tmp_weight_zp = " << quant_params_.weight.zp_ << ";\n";
|
|
|
|
|
code << "int tmp_weight_zp = " << quant_params_.weight_.zp_ << ";\n";
|
|
|
|
|
if (!params_->b_const_) {
|
|
|
|
|
code.CodeFunction("InitIn8MatrixB", filter_tensor_->data_c(), weight_bias_sums_batch_, b_pack_batch_ptr_,
|
|
|
|
|
params_->batch, params_->deep_, params_->col_, params_->col_4_, params_->deep_16_,
|
|
|
|
|
quant_params_.input.zp_, "&tmp_weight_zp", bias_ptr_, params_->b_transpose_);
|
|
|
|
|
quant_params_.input_.zp_, "&tmp_weight_zp", bias_ptr_, params_->b_transpose_);
|
|
|
|
|
}
|
|
|
|
|
std::string b_batch_str = allocator_->GetRuntimeAddr(b_pack_batch_ptr_);
|
|
|
|
|
std::string weight_bias_sums_batch_str = allocator_->GetRuntimeAddr(weight_bias_sums_batch_);
|
|
|
|
@ -157,10 +181,10 @@ int MatMulInt8Coder::DoCode(CoderContext *const context) {
|
|
|
|
|
code << " int8_t* cur_a_ptr = " << a_ptr_str << " + i * " << a_stride << ";\n";
|
|
|
|
|
if (params_->a_transpose_) {
|
|
|
|
|
code.CodeFunction("RowMajor2Col16x4MajorInt8", "cur_a_ptr", params_->deep_, params_->row_, a_pack_ptr_);
|
|
|
|
|
code.CodeFunction("CalcInputSums", "cur_a_ptr", params_->deep_, quant_params_.weight.zp_, input_sums_, ColMajor);
|
|
|
|
|
code.CodeFunction("CalcInputSums", "cur_a_ptr", params_->deep_, quant_params_.weight_.zp_, input_sums_, ColMajor);
|
|
|
|
|
} else {
|
|
|
|
|
code.CodeFunction("RowMajor2Row16x4MajorInt8", "cur_a_ptr", a_pack_ptr_, params_->row_, params_->deep_);
|
|
|
|
|
code.CodeFunction("CalcInputSums", "cur_a_ptr", params_->row_, params_->deep_, quant_params_.weight.zp_,
|
|
|
|
|
code.CodeFunction("CalcInputSums", "cur_a_ptr", params_->row_, params_->deep_, quant_params_.weight_.zp_,
|
|
|
|
|
input_sums_, RowMajor);
|
|
|
|
|
}
|
|
|
|
|
code << " b_pack_ptr_ = " << b_batch_str << " + i * " << params_->col_4_ * params_->deep_16_ << ";\n";
|
|
|
|
@ -171,18 +195,18 @@ int MatMulInt8Coder::DoCode(CoderContext *const context) {
|
|
|
|
|
code << " int8_t* cur_b = b_pack_ptr_ + " << task_id * thread_stride_ * C4NUM * params_->deep_16_ << ";\n";
|
|
|
|
|
code << " int32_t* cur_bias = weight_bias_sums_ + " << task_id * thread_stride_ * C4NUM << ";\n";
|
|
|
|
|
code << " int8_t *cur_c = c_ptr_ + " << task_id * thread_stride_ * C4NUM << ";\n";
|
|
|
|
|
code << " static const int left_shift = " << quant_params_.left_shift << ";\n";
|
|
|
|
|
code << " static const int right_shift = " << quant_params_.right_shift << ";\n";
|
|
|
|
|
code << " static const int quant_multiplier = " << quant_params_.quant_multiplier << ";\n";
|
|
|
|
|
code << " static const int left_shift = " << quant_params_.left_shift_[0] << ";\n";
|
|
|
|
|
code << " static const int right_shift = " << quant_params_.right_shift_[0] << ";\n";
|
|
|
|
|
code << " static const int quant_multiplier = " << quant_params_.quant_multiplier_[0] << ";\n";
|
|
|
|
|
if (target_ == kARM64) {
|
|
|
|
|
code.CodeFunction("MatmulInt8Neon64", "cur_a_ptr", "cur_b", "cur_c", params_->row_4_, cur_oc * C4NUM,
|
|
|
|
|
params_->deep_16_, input_sums_, "cur_bias", INT8_MIN, INT8_MAX, quant_params_.output.zp_,
|
|
|
|
|
params_->deep_16_, input_sums_, "cur_bias", INT8_MIN, INT8_MAX, quant_params_.output_.zp_,
|
|
|
|
|
"&quant_multiplier", "&left_shift", "&right_shift", params_->row_, cur_oc_res, params_->col_,
|
|
|
|
|
false);
|
|
|
|
|
} else {
|
|
|
|
|
code.CodeFunction("MatMulInt8_16x4_r", "cur_a_ptr", "cur_b", "cur_c", params_->row_, cur_oc_res, params_->deep_16_,
|
|
|
|
|
params_->col_, input_sums_, "cur_bias", "&left_shift", "&right_shift", "&quant_multiplier",
|
|
|
|
|
quant_params_.output.zp_, INT8_MIN, INT8_MAX, false);
|
|
|
|
|
quant_params_.output_.zp_, INT8_MIN, INT8_MAX, false);
|
|
|
|
|
}
|
|
|
|
|
code << "}\n";
|
|
|
|
|
MS_LOG(DEBUG) << "FullConnectionInt8Coder has been called";
|
|
|
|
|