|
|
|
@ -122,8 +122,10 @@ void AnfConvPopulater::CalQuantParam(const double &mean, const double &stdDev, f
|
|
|
|
|
*mMax = static_cast<float>((qmax - mean) / stdDev);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
|
|
|
|
|
std::vector<std::vector<schema::QuantParamT>> *vecQuantParam) {
|
|
|
|
|
void AnfConvPopulater::PopulaterQuantParam(
|
|
|
|
|
const PrimitivePtr &prim,
|
|
|
|
|
std::vector<std::vector<schema::QuantParamT>> *vecInputQuantParam,
|
|
|
|
|
std::vector<std::vector<schema::QuantParamT>> *vecOutputQuantParam) {
|
|
|
|
|
auto narrow_range = prim->GetAttr("narrow_range");
|
|
|
|
|
bool narrowRangeQuantParam = GetValue<bool>(narrow_range);
|
|
|
|
|
auto num_bits = prim->GetAttr("num_bits");
|
|
|
|
@ -154,7 +156,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
|
|
|
|
|
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
|
|
|
|
|
numbitsRangeQuantParam);
|
|
|
|
|
quants.emplace_back(quantParam);
|
|
|
|
|
vecQuantParam->emplace_back(quants);
|
|
|
|
|
vecInputQuantParam->emplace_back(quants);
|
|
|
|
|
|
|
|
|
|
quants.clear();
|
|
|
|
|
int biasQuantSize = 0;
|
|
|
|
@ -173,7 +175,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
|
|
|
|
|
numbitsRangeQuantParam);
|
|
|
|
|
quants.emplace_back(quantParam);
|
|
|
|
|
}
|
|
|
|
|
vecQuantParam->emplace_back(quants);
|
|
|
|
|
vecInputQuantParam->emplace_back(quants);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
quants.clear();
|
|
|
|
@ -181,10 +183,12 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
|
|
|
|
|
quantParam.min = 0.0;
|
|
|
|
|
quantParam.max = 0.0;
|
|
|
|
|
quantParam.zeroPoint = 0;
|
|
|
|
|
quantParam.scale = vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale;
|
|
|
|
|
|
|
|
|
|
quantParam.scale =
|
|
|
|
|
vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(i).scale;
|
|
|
|
|
quants.emplace_back(quantParam);
|
|
|
|
|
}
|
|
|
|
|
vecQuantParam->emplace_back(quants);
|
|
|
|
|
vecInputQuantParam->emplace_back(quants);
|
|
|
|
|
|
|
|
|
|
quants.clear();
|
|
|
|
|
auto outputMin = prim->GetAttr("output_minq");
|
|
|
|
@ -199,7 +203,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim,
|
|
|
|
|
quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam,
|
|
|
|
|
numbitsRangeQuantParam);
|
|
|
|
|
quants.emplace_back(quantParam);
|
|
|
|
|
vecQuantParam->emplace_back(quants);
|
|
|
|
|
vecOutputQuantParam->emplace_back(quants);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -215,10 +219,13 @@ int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primit
|
|
|
|
|
PopulaterConv2DSingleGroup(prim, primitive, group);
|
|
|
|
|
}
|
|
|
|
|
primitiveTValuePtr->SetPrimitiveT(primitive.release());
|
|
|
|
|
|
|
|
|
|
if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) {
|
|
|
|
|
std::vector<std::vector<schema::QuantParamT>> vecQuantParam;
|
|
|
|
|
PopulaterQuantParam(prim, &vecQuantParam);
|
|
|
|
|
primitiveTValuePtr->SetInputQuantParam(vecQuantParam);
|
|
|
|
|
std::vector<std::vector<schema::QuantParamT>> vecInputQuantParam;
|
|
|
|
|
std::vector<std::vector<schema::QuantParamT>> vecOutputQuantParam;
|
|
|
|
|
PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam);
|
|
|
|
|
primitiveTValuePtr->SetInputQuantParam(vecInputQuantParam);
|
|
|
|
|
primitiveTValuePtr->SetOutputQuantParam(vecOutputQuantParam);
|
|
|
|
|
}
|
|
|
|
|
return 0;
|
|
|
|
|
}
|
|
|
|
|