Fix post-training quantization: append quant params instead of setting them by index

pull/7780/head
xutianchun 4 years ago
parent 0fb165ceb7
commit 022c530de0

@ -565,7 +565,7 @@ PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, in
} }
STATUS PostTrainingQuantizer::DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, STATUS PostTrainingQuantizer::DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min,
std::shared_ptr<PrimitiveC> lite_primitive, const size_t &index) { std::shared_ptr<PrimitiveC> lite_primitive) {
schema::QuantParamT quant_param; schema::QuantParamT quant_param;
quant_param.scale = scale; quant_param.scale = scale;
quant_param.zeroPoint = zeropoint; quant_param.zeroPoint = zeropoint;
@ -573,8 +573,9 @@ STATUS PostTrainingQuantizer::DoQuantInput(double scale, int32_t zeropoint, stru
quant_param.min = max_min->min; quant_param.min = max_min->min;
quant_param.numBits = bit_num; quant_param.numBits = bit_num;
quant_param.narrowRange = false; quant_param.narrowRange = false;
quant_param.inited = true;
std::vector<schema::QuantParamT> quant_params = {quant_param}; std::vector<schema::QuantParamT> quant_params = {quant_param};
lite_primitive->SetInputQuantParam(index, quant_params); lite_primitive->AddInputQuantParam(quant_params);
return RET_OK; return RET_OK;
} }
@ -589,7 +590,7 @@ STATUS PostTrainingQuantizer::DoQuantOutput(double scale, int zeropoint, struct
quant_param.narrowRange = false; quant_param.narrowRange = false;
quant_param.inited = true; quant_param.inited = true;
std::vector<schema::QuantParamT> quant_params = {quant_param}; std::vector<schema::QuantParamT> quant_params = {quant_param};
lite_primitive->SetOutputQuantParam(0, quant_params); lite_primitive->AddOutputQuantParam(quant_params);
return RET_OK; return RET_OK;
} }
@ -642,7 +643,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi
auto bias_param = std::dynamic_pointer_cast<ParamValueLite>(bias_default_param); auto bias_param = std::dynamic_pointer_cast<ParamValueLite>(bias_default_param);
auto active_weight_quant_params = primitive_c->GetInputQuantParams(); auto active_weight_quant_params = primitive_c->GetInputQuantParams();
if (active_weight_quant_params.size() != 3) { if (active_weight_quant_params.size() != 2) {
MS_LOG(ERROR) << "unexpected active_weight_quant_params size: " << active_weight_quant_params.size(); MS_LOG(ERROR) << "unexpected active_weight_quant_params size: " << active_weight_quant_params.size();
return RET_ERROR; return RET_ERROR;
} }
@ -721,7 +722,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(AnfNodePtr bias, std::shared_ptr<Primi
auto quant_data = (int32_t)std::round(raw_datas[i] / bias_scale_tmp); auto quant_data = (int32_t)std::round(raw_datas[i] / bias_scale_tmp);
quant_datas[i] = quant_data; quant_datas[i] = quant_data;
} }
primitive_c->SetInputQuantParam(2, quant_params); primitive_c->AddInputQuantParam(quant_params);
auto ret = memcpy_s(bias_param->tensor_addr(), bias_param->tensor_size(), quant_datas, shape_size * sizeof(int32_t)); auto ret = memcpy_s(bias_param->tensor_addr(), bias_param->tensor_size(), quant_datas, shape_size * sizeof(int32_t));
if (ret != EOK) { if (ret != EOK) {
MS_LOG(ERROR) << "memcpy_s failed."; MS_LOG(ERROR) << "memcpy_s failed.";
@ -832,19 +833,19 @@ STATUS PostTrainingQuantizer::QuantNode() {
} }
if (input_cnode_primitive_c->IsOutputQuantParamsInited()) { if (input_cnode_primitive_c->IsOutputQuantParamsInited()) {
auto quant_param = input_cnode_primitive_c->GetOutputQuantParams().front(); auto quant_param = input_cnode_primitive_c->GetOutputQuantParams().front();
primitive_c->SetInputQuantParam(i - 1, quant_param); primitive_c->AddInputQuantParam(quant_param);
} else { } else {
// do input quant // do input quant
double scale = input_scale[cnode]; double scale = input_scale[cnode];
int32_t zp = input_zero_point[cnode]; int32_t zp = input_zero_point[cnode];
DoQuantInput(scale, zp, &input_min_max[cnode], primitive_c, i - 1); DoQuantInput(scale, zp, &input_min_max[cnode], primitive_c);
} }
} }
} else { } else {
// do input quant // do input quant
double scale = input_scale[cnode]; double scale = input_scale[cnode];
int32_t convInputzeropoint = input_zero_point[cnode]; int32_t convInputzeropoint = input_zero_point[cnode];
DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitive_c, 0); DoQuantInput(scale, convInputzeropoint, &input_min_max[cnode], primitive_c);
// do weight quant // do weight quant
auto weight = cnode->input(2); auto weight = cnode->input(2);
bool perchannel = per_channel_; bool perchannel = per_channel_;
@ -916,6 +917,12 @@ STATUS PostTrainingQuantizer::PreProcess() {
if (strategy.CanOpPostQuantized(anf)) { if (strategy.CanOpPostQuantized(anf)) {
calibrator_->AddQuantizedOp(cnode); calibrator_->AddQuantizedOp(cnode);
} }
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
if (primitive_c == nullptr) {
MS_LOG(ERROR) << cnode->fullname_with_scope() << " primitive is null";
continue;
}
primitive_c->ClearInputOutputQuantParam();
} }
return RET_OK; return RET_OK;
} }

@ -107,7 +107,7 @@ class PostTrainingQuantizer : public Quantizer {
STATUS QuantNode(); STATUS QuantNode();
STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min, STATUS DoQuantInput(double scale, int32_t zeropoint, struct MaxMin *max_min,
std::shared_ptr<PrimitiveC> lite_primitive, const size_t &index); std::shared_ptr<PrimitiveC> lite_primitive);
STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>); STATUS DoQuantOutput(double scale, int32_t zeropoint, struct MaxMin *max_min, std::shared_ptr<PrimitiveC>);
STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchannel); STATUS DoWeightQuant(AnfNodePtr weight, std::shared_ptr<PrimitiveC> primitive_c, bool perchannel);

@ -283,7 +283,11 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
MS_LOG(ERROR) << "quant_params empty"; MS_LOG(ERROR) << "quant_params empty";
return RET_ERROR; return RET_ERROR;
} }
primitive_c->SetInputQuantParam(WEIGHT_INDEX, quant_params); if (quantType == QuantType_PostTraining) {
primitive_c->AddInputQuantParam(quant_params);
} else {
primitive_c->SetInputQuantParam(WEIGHT_INDEX, quant_params);
}
return RET_OK; return RET_OK;
} }

Loading…
Cancel
Save