From 022c530de02456d4e26130e38fe3e11ec6cfdb00 Mon Sep 17 00:00:00 2001 From: xutianchun Date: Mon, 26 Oct 2020 18:56:52 +0800 Subject: [PATCH] fix post training code --- .../quantizer/post_training_quantizer.cc | 23 ++++++++++++------- .../quantizer/post_training_quantizer.h | 2 +- .../tools/converter/quantizer/quantize_util.h | 6 ++++- 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc index dd1f010de3..5d391765be 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.cc @@ -565,7 +565,7 @@ PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, in } STATUS PostTrainingQuantizer::DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, - std::shared_ptr lite_primitive, const size_t &index) { + std::shared_ptr lite_primitive) { schema::QuantParamT quant_param; quant_param.scale = scale; quant_param.zeroPoint = zeropoint; @@ -573,8 +573,9 @@ STATUS PostTrainingQuantizer::DoQuantInput(double scale, int32_t zeropoint, stru quant_param.min = max_min->min; quant_param.numBits = bit_num; quant_param.narrowRange = false; + quant_param.inited = true; std::vector quant_params = {quant_param}; - lite_primitive->SetInputQuantParam(index, quant_params); + lite_primitive->AddInputQuantParam(quant_params); return RET_OK; } @@ -589,7 +590,7 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct quant_param.narrowRange = false; quant_param.inited = true; std::vector quant_params = {quant_param}; - lite_primitive->SetOutputQuantParam(0, quant_params); + lite_primitive->AddOutputQuantParam(quant_params); return RET_OK; } @@ -642,7 +643,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr(bias_default_param); auto active_weight_quant_params = primitive_c->GetInputQuantParams(); - if (active_weight_quant_params.size() != 3) { + if (active_weight_quant_params.size() != 2) { MS_LOG(ERROR) << "unexpected active_weight_quant_params size: " << active_weight_quant_params.size(); return RET_ERROR; } @@ -721,7 +722,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptrSetInputQuantParam(2, quant_params); + primitive_c->AddInputQuantParam(quant_params); auto ret = memcpy_s(bias_param->tensor_addr(), bias_param->tensor_size(), quant_datas, shape_size * sizeof(int32_t)); if (ret != EOK) { MS_LOG(ERROR) << "memcpy_s failed."; @@ -832,19 +833,19 @@ STATUS PostTrainingQuantizer::QuantNode() { } if (input_cnode_primitive_c->IsOutputQuantParamsInited()) { auto quant_param = input_cnode_primitive_c->GetOutputQuantParams().front(); - primitive_c->SetInputQuantParam(i - 1, quant_param); + primitive_c->AddInputQuantParam(quant_param); } else { // do input quant double scale = input_scale[cnode]; int32_t zp = input_zero_point[cnode]; - DoQuantInput(scale, zp, &input_min_max[cnode], primitive_c, i - 1); + DoQuantInput(scale, zp, &input_min_max[cnode], primitive_c); } } } else { // do input quant double scale = input_scale[cnode]; int32_t convInputzeropoint = input_zero_point[cnode]; - DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitive_c, 0); + DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitive_c); // do weight quant auto weight = cnode->input(2); bool perchannel = per_channel_; @@ -916,6 +917,12 @@ STATUS PostTrainingQuantizer::PreProcess() { if (strategy.CanOpPostQuantized(anf)) { calibrator_->AddQuantizedOp(cnode); } + auto primitive_c = GetValueNode>(cnode->input(0)); + if (primitive_c == nullptr) { + MS_LOG(ERROR) << cnode->fullname_with_scope() << " primitive is null"; + continue; + } + primitive_c->ClearInputOutputQuantParam(); } return RET_OK; } diff --git a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h index 9a9a0aae5d..037b2f1593 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h +++ b/mindspore/lite/tools/converter/quantizer/post_training_quantizer.h @@ -107,7 +107,7 @@ class PostTrainingQuantizer : public Quantizer { STATUS QuantNode(); STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, - std::shared_ptr lite_primitive, const size_t &index); + std::shared_ptr lite_primitive); STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr); STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr primitive_c, bool perchannel); diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h index ed1fa0cb85..4fbacb854b 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.h +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -283,7 +283,11 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primiti MS_LOG(ERROR) << "quant_params empty"; return RET_ERROR; } - primitive_c->SetInputQuantParam(WEIGHT_INDEX, quant_params); + if (quantType == QuantType_PostTraining) { + primitive_c->AddInputQuantParam(quant_params); + } else { + primitive_c->SetInputQuantParam(WEIGHT_INDEX, quant_params); + } return RET_OK; }