[MSLITE][MICRO] delete MatMul Quant args

Branch: pull/12206/head
Author: ling, 4 years ago
Parent: 79675d612e
Commit: f937653fcb

@@ -72,21 +72,21 @@ int MatMulInt8Coder::ReSize(CoderContext *const context) {
   std::vector<QuantArg> params = input_tensor_->quant_params();
   MS_CHECK_TRUE(!params.empty(), "params is empty");
-  quant_params_.input.zp_ = params.front().zeroPoint;
-  quant_params_.input.scale_ = static_cast<float>(params.front().scale);
+  quant_params_.input_.zp_ = params.front().zeroPoint;
+  quant_params_.input_.scale_ = static_cast<float>(params.front().scale);
   params = filter_tensor_->quant_params();
   MS_CHECK_TRUE(!params.empty(), "params is empty");
-  quant_params_.weight.zp_ = params.front().zeroPoint;
-  quant_params_.weight.scale_ = static_cast<float>(params.front().scale);
+  quant_params_.weight_.zp_ = params.front().zeroPoint;
+  quant_params_.weight_.scale_ = static_cast<float>(params.front().scale);
   params = output_tensor_->quant_params();
   MS_CHECK_TRUE(!params.empty(), "params is empty");
-  quant_params_.output.zp_ = params.front().zeroPoint;
-  quant_params_.output.scale_ = static_cast<float>(params.front().scale);
-  double real_multiplier = quant_params_.input.scale_ * quant_params_.weight.scale_ / quant_params_.output.scale_;
-  QuantizeRoundParameterWithDoublePrecision(real_multiplier, &quant_params_.quant_multiplier, &quant_params_.left_shift,
-                                            &quant_params_.right_shift);
+  quant_params_.output_.zp_ = params.front().zeroPoint;
+  quant_params_.output_.scale_ = static_cast<float>(params.front().scale);
+  double real_multiplier = quant_params_.input_.scale_ * quant_params_.weight_.scale_ / quant_params_.output_.scale_;
+  QuantizeRoundParameterWithDoublePrecision(real_multiplier, quant_params_.quant_multiplier_, quant_params_.left_shift_,
+                                            quant_params_.right_shift_);
   if (params_->b_const_) {
     NNaclInt8Serializer init_code;
     if (bias_ptr_) {
@@ -99,15 +99,31 @@ int MatMulInt8Coder::ReSize(CoderContext *const context) {
     init_code.CodeMallocExpression(b_pack_batch_ptr_, b_pack_batch_ptr_size_);
     init_code.CodeFunction("memset", b_pack_batch_ptr_, 0, b_pack_batch_ptr_size_);
-    init_code << "int tmp_weight_zp = " << quant_params_.weight.zp_ << ";\n";
+    init_code << "int tmp_weight_zp = " << quant_params_.weight_.zp_ << ";\n";
     init_code.CodeFunction("InitIn8MatrixB", filter_tensor_->data_c(), weight_bias_sums_batch_, b_pack_batch_ptr_,
                            params_->batch, params_->deep_, params_->col_, params_->col_4_, params_->deep_16_,
-                           quant_params_.input.zp_, "&tmp_weight_zp", bias_ptr_, params_->b_transpose_);
+                           quant_params_.input_.zp_, "&tmp_weight_zp", bias_ptr_, params_->b_transpose_);
     context->AppendInitCode(init_code.str());
   }
   return RET_OK;
 }
+
+MatMulInt8Coder::~MatMulInt8Coder() {
+  if (quant_params_.quant_multiplier_ != nullptr) {
+    free(quant_params_.quant_multiplier_);
+    quant_params_.quant_multiplier_ = nullptr;
+  }
+  if (quant_params_.right_shift_ != nullptr) {
+    free(quant_params_.right_shift_);
+    quant_params_.right_shift_ = nullptr;
+  }
+  if (quant_params_.left_shift_ != nullptr) {
+    free(quant_params_.left_shift_);
+    quant_params_.left_shift_ = nullptr;
+  }
+  return;
+}
+
 int MatMulInt8Coder::Init() {
   params_ = reinterpret_cast<MatMulParameter *>(parameter_);
   filter_tensor_ = input_tensors_.at(kWeightIndex);
@@ -118,6 +134,14 @@ int MatMulInt8Coder::Init() {
     MS_CHECK_PTR(bias_tensor_->data_c());
   }
   params_->b_const_ = (filter_tensor_->data_c() != nullptr);
+  quant_params_.quant_multiplier_ = reinterpret_cast<int32_t *>(malloc(1 * sizeof(int32_t)));
+  MS_CHECK_PTR(quant_params_.quant_multiplier_);
+  quant_params_.left_shift_ = reinterpret_cast<int32_t *>(malloc(1 * sizeof(int32_t)));
+  MS_CHECK_PTR(quant_params_.left_shift_);
+  quant_params_.right_shift_ = reinterpret_cast<int32_t *>(malloc(1 * sizeof(int32_t)));
+  MS_CHECK_PTR(quant_params_.right_shift_);
   return RET_OK;
 }
@@ -143,11 +167,11 @@ int MatMulInt8Coder::DoCode(CoderContext *const context) {
   if (cur_oc <= 0) {
     return RET_OK;
   }
-  code << "int tmp_weight_zp = " << quant_params_.weight.zp_ << ";\n";
+  code << "int tmp_weight_zp = " << quant_params_.weight_.zp_ << ";\n";
   if (!params_->b_const_) {
     code.CodeFunction("InitIn8MatrixB", filter_tensor_->data_c(), weight_bias_sums_batch_, b_pack_batch_ptr_,
                       params_->batch, params_->deep_, params_->col_, params_->col_4_, params_->deep_16_,
-                      quant_params_.input.zp_, "&tmp_weight_zp", bias_ptr_, params_->b_transpose_);
+                      quant_params_.input_.zp_, "&tmp_weight_zp", bias_ptr_, params_->b_transpose_);
   }
   std::string b_batch_str = allocator_->GetRuntimeAddr(b_pack_batch_ptr_);
   std::string weight_bias_sums_batch_str = allocator_->GetRuntimeAddr(weight_bias_sums_batch_);
@@ -157,10 +181,10 @@ int MatMulInt8Coder::DoCode(CoderContext *const context) {
   code << "  int8_t* cur_a_ptr = " << a_ptr_str << " + i * " << a_stride << ";\n";
   if (params_->a_transpose_) {
     code.CodeFunction("RowMajor2Col16x4MajorInt8", "cur_a_ptr", params_->deep_, params_->row_, a_pack_ptr_);
-    code.CodeFunction("CalcInputSums", "cur_a_ptr", params_->deep_, quant_params_.weight.zp_, input_sums_, ColMajor);
+    code.CodeFunction("CalcInputSums", "cur_a_ptr", params_->deep_, quant_params_.weight_.zp_, input_sums_, ColMajor);
   } else {
     code.CodeFunction("RowMajor2Row16x4MajorInt8", "cur_a_ptr", a_pack_ptr_, params_->row_, params_->deep_);
-    code.CodeFunction("CalcInputSums", "cur_a_ptr", params_->row_, params_->deep_, quant_params_.weight.zp_,
+    code.CodeFunction("CalcInputSums", "cur_a_ptr", params_->row_, params_->deep_, quant_params_.weight_.zp_,
                       input_sums_, RowMajor);
   }
   code << "  b_pack_ptr_ = " << b_batch_str << " + i * " << params_->col_4_ * params_->deep_16_ << ";\n";
@@ -171,18 +195,18 @@ int MatMulInt8Coder::DoCode(CoderContext *const context) {
   code << "  int8_t* cur_b = b_pack_ptr_ + " << task_id * thread_stride_ * C4NUM * params_->deep_16_ << ";\n";
   code << "  int32_t* cur_bias = weight_bias_sums_ + " << task_id * thread_stride_ * C4NUM << ";\n";
   code << "  int8_t *cur_c = c_ptr_ + " << task_id * thread_stride_ * C4NUM << ";\n";
-  code << "  static const int left_shift = " << quant_params_.left_shift << ";\n";
-  code << "  static const int right_shift = " << quant_params_.right_shift << ";\n";
-  code << "  static const int quant_multiplier = " << quant_params_.quant_multiplier << ";\n";
+  code << "  static const int left_shift = " << quant_params_.left_shift_[0] << ";\n";
+  code << "  static const int right_shift = " << quant_params_.right_shift_[0] << ";\n";
+  code << "  static const int quant_multiplier = " << quant_params_.quant_multiplier_[0] << ";\n";
   if (target_ == kARM64) {
     code.CodeFunction("MatmulInt8Neon64", "cur_a_ptr", "cur_b", "cur_c", params_->row_4_, cur_oc * C4NUM,
-                      params_->deep_16_, input_sums_, "cur_bias", INT8_MIN, INT8_MAX, quant_params_.output.zp_,
+                      params_->deep_16_, input_sums_, "cur_bias", INT8_MIN, INT8_MAX, quant_params_.output_.zp_,
                       "&quant_multiplier", "&left_shift", "&right_shift", params_->row_, cur_oc_res, params_->col_,
                       false);
   } else {
     code.CodeFunction("MatMulInt8_16x4_r", "cur_a_ptr", "cur_b", "cur_c", params_->row_, cur_oc_res, params_->deep_16_,
                       params_->col_, input_sums_, "cur_bias", "&left_shift", "&right_shift", "&quant_multiplier",
-                      quant_params_.output.zp_, INT8_MIN, INT8_MAX, false);
+                      quant_params_.output_.zp_, INT8_MIN, INT8_MAX, false);
   }
   code << "}\n";
   MS_LOG(DEBUG) << "FullConnectionInt8Coder has been called";

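The arithmetic these hunks rewire is the requantization setup in ReSize(): real_multiplier = input_scale * weight_scale / output_scale, which QuantizeRoundParameterWithDoublePrecision decomposes into an int32 fixed-point multiplier plus left/right shift counts. The renamed struct members are already pointers, which is why the `&` operators disappear from that call. A minimal sketch of the standard gemmlowp/TFLite-style decomposition, for orientation only (this is not the nnacl implementation):

```cpp
#include <cmath>
#include <cstdint>

// Express real_multiplier as quantized_multiplier * 2^(left_shift - right_shift),
// where quantized_multiplier is a Q31 fixed-point value in [2^30, 2^31).
void DecomposeRealMultiplier(double real_multiplier, int32_t *quantized_multiplier,
                             int32_t *left_shift, int32_t *right_shift) {
  if (real_multiplier <= 0.0) {
    *quantized_multiplier = *left_shift = *right_shift = 0;
    return;
  }
  int shift = 0;
  const double q = std::frexp(real_multiplier, &shift);  // real = q * 2^shift, q in [0.5, 1)
  auto q_fixed = static_cast<int64_t>(std::round(q * (1LL << 31)));
  if (q_fixed == (1LL << 31)) {  // rounding pushed q up to exactly 1.0
    q_fixed /= 2;
    ++shift;
  }
  *quantized_multiplier = static_cast<int32_t>(q_fixed);
  *left_shift = shift > 0 ? shift : 0;    // pre-multiplication scale-up
  *right_shift = shift > 0 ? 0 : -shift;  // post-multiplication rounding shift
}
```

With per-tensor quantization only element [0] of each destination array is written, which matches the single-element mallocs added to Init() above.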
@@ -39,7 +39,7 @@ class MatMulInt8Coder final : public OperatorCoder {
   Tensor *filter_tensor_{nullptr};
   Tensor *bias_tensor_{nullptr};
   MatMulParameter *params_{nullptr};
-  MatmulQuantArg quant_params_{0};
+  MatmulQuantParameter quant_params_{0};
   size_t a_pack_ptr_size_{0};
   int8_t *a_pack_ptr_{nullptr};
   size_t b_pack_batch_ptr_size_{0};

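One subtlety in this header hunk: `quant_params_{0}` still zero-fills the struct after the type swap, so the three int32_t * members start out null and the destructor added above stays safe even if Init() returns between the mallocs. A standalone sketch of that ownership contract, with hypothetical names (Owner and QuantBuffers are not from the repo):

```cpp
#include <cstdint>
#include <cstdlib>

// Pointer-bearing quant parameters, reduced to the three members at issue.
struct QuantBuffers {
  int32_t *quant_multiplier_;
  int32_t *left_shift_;
  int32_t *right_shift_;
};

class Owner {
 public:
  // num_channels is 1 for per-tensor quantization, as in the diff.
  bool Init(int num_channels) {
    bufs_.quant_multiplier_ = static_cast<int32_t *>(malloc(num_channels * sizeof(int32_t)));
    if (bufs_.quant_multiplier_ == nullptr) return false;  // destructor is still safe
    bufs_.left_shift_ = static_cast<int32_t *>(malloc(num_channels * sizeof(int32_t)));
    if (bufs_.left_shift_ == nullptr) return false;
    bufs_.right_shift_ = static_cast<int32_t *>(malloc(num_channels * sizeof(int32_t)));
    return bufs_.right_shift_ != nullptr;
  }
  // Copying would double-free; a production class would delete copy operations.
  ~Owner() {
    free(bufs_.quant_multiplier_);  // free(nullptr) is a no-op
    free(bufs_.left_shift_);
    free(bufs_.right_shift_);
  }

 private:
  QuantBuffers bufs_{};  // value-initialization: every pointer starts as nullptr
};

int main() {
  Owner owner;
  return owner.Init(1) ? 0 : 1;  // cleanup runs in ~Owner on both paths
}
```

The committed destructor also null-checks before each free and resets the pointer afterwards; since free(nullptr) is a no-op, that is defensive rather than required.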
@@ -20,7 +20,6 @@
 #include "coder/log.h"
 namespace mindspore::lite::micro::nnacl {
 void NNaclInt8Serializer::CodeStruct(const std::string &name, const ConvParameter &conv_parameter) {
   const ConvQuantArg &quant_arg = conv_parameter.conv_quant_arg_;
   std::string quant_arg_in = name + "_quant_arg_in";
@@ -195,10 +194,11 @@ void NNaclInt8Serializer::CodeStruct(const std::string &name, const ReshapeQuantArg
                  reshape_quant_arg.output_activation_min_, reshape_quant_arg.output_activation_max_);
 }
-void NNaclInt8Serializer::CodeStruct(const std::string &name, const MatmulQuantArg &matmul_quant_arg) {
-  CodeBaseStruct("MatmulQuantArg", name, matmul_quant_arg.input, matmul_quant_arg.weight, matmul_quant_arg.output,
-                 matmul_quant_arg.out_act_min, matmul_quant_arg.out_act_max, matmul_quant_arg.left_shift,
-                 matmul_quant_arg.right_shift, matmul_quant_arg.quant_multiplier);
+void NNaclInt8Serializer::CodeStruct(const std::string &name, const MatmulQuantParameter &matmul_quant_arg) {
+  CodeBaseStruct("MatmulQuantParameter", name, matmul_quant_arg.input_, matmul_quant_arg.weight_,
+                 matmul_quant_arg.output_, matmul_quant_arg.out_act_min_, matmul_quant_arg.out_act_max_,
+                 matmul_quant_arg.left_shift_[0], matmul_quant_arg.right_shift_[0],
+                 matmul_quant_arg.quant_multiplier_[0]);
 }
 void NNaclInt8Serializer::CodeStruct(const std::string &name, const SubQuantArg &sub_quant_arg) {

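The renamed overload still funnels everything through CodeBaseStruct, which writes a brace-initialized constant of the named type into the generated source; passing left_shift_[0], right_shift_[0], and quant_multiplier_[0] keeps the emitted initializer scalar even though the struct now holds arrays. A rough sketch of what such a variadic helper does (assumed shape, not the real Serializer code):

```cpp
#include <sstream>
#include <string>

// Emits: const <type> <name> = {v0, v1, ...};
// The real CodeBaseStruct lives in the micro coder's Serializer; this
// reimplementation is illustrative only.
template <typename... Args>
std::string CodeBaseStructSketch(const std::string &type, const std::string &name, Args... values) {
  std::ostringstream oss;
  oss << "const " << type << " " << name << " = {";
  const char *sep = "";
  ((oss << sep << values, sep = ", "), ...);  // C++17 comma fold over the member values
  oss << "};\n";
  return oss.str();
}
```

For example, CodeBaseStructSketch("MatmulQuantParameter", "m_quant", -128, 127, 10, 3, 1073741824) yields `const MatmulQuantParameter m_quant = {-128, 127, 10, 3, 1073741824};`.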
@@ -33,7 +33,6 @@
 #include "nnacl/int8/relux_int8.h"
 namespace mindspore::lite::micro::nnacl {
 class NNaclInt8Serializer : public Serializer {
  public:
   NNaclInt8Serializer() = default;
@@ -53,11 +52,10 @@ class NNaclInt8Serializer : public Serializer {
   void CodeStruct(const std::string &name, const ::QuantMulArg &quant_mul_arg);
   void CodeStruct(const std::string &name, const ReduceQuantArg &reduce_quant_arg);
   void CodeStruct(const std::string &name, const ReshapeQuantArg &reshape_quant_arg);
-  void CodeStruct(const std::string &name, const MatmulQuantArg &matmul_quant_arg);
+  void CodeStruct(const std::string &name, const MatmulQuantParameter &matmul_quant_arg);
   void CodeStruct(const std::string &name, const SubQuantArg &sub_quant_arg);
   void CodeStruct(const std::string &name, const DivQuantArg &div_quant_arg);
   void CodeStruct(const std::string &name, const ReluXQuantArg &relu_quant_arg);
 };
 }  // namespace mindspore::lite::micro::nnacl
 #endif  // MINDSPORE_LITE_MICRO_CODER_OPCODERS_SERIALIZERS_NNACL_INT8_SERIALIZER_H_

@@ -65,6 +65,7 @@ typedef struct MatMulParameter {
 typedef struct MatmulQuantParameter {
   QuantArg input_;
   QuantArg weight_;
+  QuantArg output_;
   int32_t out_act_min_;
   int32_t out_act_max_;
@@ -75,15 +76,4 @@ typedef struct MatmulQuantParameter {
   int32_t *quant_multiplier_;
 } MatmulQuantParameter;
 
-typedef struct MatmulQuantArg {
-  QuantArg input;
-  QuantArg weight;
-  QuantArg output;
-  int32_t out_act_min;
-  int32_t out_act_max;
-  int32_t left_shift;
-  int32_t right_shift;
-  int32_t quant_multiplier;
-} MatmulQuantArg;
-
 #endif  // MINDSPORE_LITE_NNACL_MATMUL_H_

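This last hunk is the commit's namesake: the scalar-field MatmulQuantArg is deleted, and MatmulQuantParameter, whose shift and multiplier members are int32_t arrays, becomes the single quant-parameter type for matmul. The array layout is what per-channel quantization needs; per-tensor users such as the coder above read element [0]. A hedged sketch of how a requantization routine consumes either layout (gemmlowp-style helpers with illustrative names, not nnacl's API):

```cpp
#include <cstdint>

// Rounding-doubling high multiply: round(a * b / 2^31), saturated.
int32_t SaturatingRoundingDoublingHighMul(int32_t a, int32_t b) {
  if (a == INT32_MIN && b == INT32_MIN) return INT32_MAX;  // the one overflow case
  const int64_t ab = static_cast<int64_t>(a) * b;
  const int64_t nudge = ab >= 0 ? (1LL << 30) : (1 - (1LL << 30));
  return static_cast<int32_t>((ab + nudge) / (1LL << 31));
}

// Arithmetic right shift that rounds to nearest, ties away from zero.
int32_t RoundingRightShift(int32_t x, int32_t exponent) {
  const int32_t mask = (1 << exponent) - 1;
  const int32_t remainder = x & mask;
  const int32_t threshold = (mask >> 1) + (x < 0 ? 1 : 0);
  return (x >> exponent) + (remainder > threshold);
}

// acc: raw int32 accumulator for output channel oc. Per-tensor callers pass
// per_channel = false, so element [0] is used, mirroring the diff. The
// left-shifted accumulator is assumed not to overflow (as in gemmlowp).
int32_t Requantize(int32_t acc, const int32_t *multiplier, const int32_t *left_shift,
                   const int32_t *right_shift, int oc, bool per_channel) {
  const int i = per_channel ? oc : 0;
  acc = SaturatingRoundingDoublingHighMul(acc * (1 << left_shift[i]), multiplier[i]);
  return RoundingRightShift(acc, right_shift[i]);
}
```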