|
|
|
@ -253,11 +253,11 @@ STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) {
|
|
|
|
|
}
|
|
|
|
|
auto status = RET_ERROR;
|
|
|
|
|
if (type_id_ == kNumberTypeInt8) {
|
|
|
|
|
status =
|
|
|
|
|
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
|
|
|
|
|
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
|
|
|
|
false, 1);
|
|
|
|
|
} else if (type_id_ == kNumberTypeInt16) {
|
|
|
|
|
status =
|
|
|
|
|
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
|
|
|
|
|
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
|
|
|
|
false, 1);
|
|
|
|
|
}
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "QuantFilter failed : " << status;
|
|
|
|
@ -316,11 +316,11 @@ STATUS WeightQuantizer::DoLstmQuntize(CNodePtr cnode) {
|
|
|
|
|
}
|
|
|
|
|
auto status = RET_ERROR;
|
|
|
|
|
if (type_id_ == kNumberTypeInt8) {
|
|
|
|
|
status =
|
|
|
|
|
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
|
|
|
|
|
status = QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
|
|
|
|
false, 3);
|
|
|
|
|
} else if (type_id_ == kNumberTypeInt16) {
|
|
|
|
|
status = QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_,
|
|
|
|
|
false);
|
|
|
|
|
false, 3);
|
|
|
|
|
}
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "QuantFilter failed : " << status;
|
|
|
|
@ -340,10 +340,10 @@ STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) {
|
|
|
|
|
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
|
|
|
|
|
MS_ASSERT(primitive_c != nullptr);
|
|
|
|
|
|
|
|
|
|
auto weight_h = cnode->input(1);
|
|
|
|
|
auto first_input = cnode->input(1);
|
|
|
|
|
ParameterPtr param_node;
|
|
|
|
|
ParamValueLitePtr param_value;
|
|
|
|
|
GetLiteParameter(weight_h, ¶m_node, ¶m_value);
|
|
|
|
|
GetLiteParameter(first_input, ¶m_node, ¶m_value);
|
|
|
|
|
if (param_node == nullptr || param_value == nullptr || param_value->tensor_type() != TypeId::kNumberTypeFloat32) {
|
|
|
|
|
MS_LOG(INFO) << "This Gather op " << cnode->fullname_with_scope() << " can not quant weight";
|
|
|
|
|
return RET_OK;
|
|
|
|
@ -358,10 +358,10 @@ STATUS WeightQuantizer::DoGatherQuntize(CNodePtr cnode) {
|
|
|
|
|
auto status = RET_ERROR;
|
|
|
|
|
if (type_id_ == kNumberTypeInt8) {
|
|
|
|
|
status =
|
|
|
|
|
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
|
|
|
|
|
QuantFilter<int8_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false, 0);
|
|
|
|
|
} else if (type_id_ == kNumberTypeInt16) {
|
|
|
|
|
status =
|
|
|
|
|
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false);
|
|
|
|
|
QuantFilter<int16_t>(param_value, primitive_c, QuantType_WeightQuant, quant_max_, quant_min_, bit_num_, false, 0);
|
|
|
|
|
}
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "QuantFilter failed : " << status;
|
|
|
|
@ -510,7 +510,7 @@ STATUS WeightQuantizer::RunFp32Graph(FuncGraphPtr func_graph) {
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) {
|
|
|
|
|
STATUS WeightQuantizer::DoMixedQuant(FuncGraphPtr func_graph) {
|
|
|
|
|
// 0.2 Parse input calib files
|
|
|
|
|
auto status = CollectCalibInputs(config_param_.image_paths, config_param_.batch_count, &images_);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
@ -652,7 +652,7 @@ STATUS WeightQuantizer::DoMiexedQuant(FuncGraphPtr func_graph) {
|
|
|
|
|
delete quant_sm.model;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
// 3. compare betwen quant and fp32
|
|
|
|
|
// 3. compare between quant and fp32
|
|
|
|
|
auto quant_outputs = quant_session->GetOutputs();
|
|
|
|
|
mean_error += CompareOutputData<float>(fp32_output_tensors_[i], quant_outputs);
|
|
|
|
|
} // end_for: calib data loop
|
|
|
|
@ -690,8 +690,8 @@ STATUS WeightQuantizer::DoFixedQuant(FuncGraphPtr func_graph) {
|
|
|
|
|
for (auto &cnode : func_graph->GetOrderedCnodes()) {
|
|
|
|
|
auto primitive_c = GetValueNode<std::shared_ptr<PrimitiveC>>(cnode->input(0));
|
|
|
|
|
if (primitive_c == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "primitive_c is nullptr";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
MS_LOG(DEBUG) << cnode->fullname_with_scope() << " : primitive_c is nullptr";
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
auto op_name = cnode->fullname_with_scope();
|
|
|
|
|
auto op_type = (schema::PrimitiveType)primitive_c->Type();
|
|
|
|
@ -744,7 +744,7 @@ STATUS WeightQuantizer::DoQuantize(FuncGraphPtr func_graph) {
|
|
|
|
|
quant_min_ = -(1 << (unsigned int)(this->bit_num_ - 1));
|
|
|
|
|
type_id_ = kNumberTypeInt8;
|
|
|
|
|
MS_LOG(INFO) << "Do mixed bit quantization";
|
|
|
|
|
return DoMiexedQuant(func_graph);
|
|
|
|
|
return DoMixedQuant(func_graph);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return DoFixedQuant(func_graph);
|
|
|
|
|