|
|
|
@ -130,7 +130,7 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan
|
|
|
|
|
}
|
|
|
|
|
template <typename T>
|
|
|
|
|
STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType,
|
|
|
|
|
int quant_max, int quant_min, size_t bitNum, bool per_channel) {
|
|
|
|
|
int quant_max, int quant_min, size_t bitNum, bool per_channel, bool k_means = false) {
|
|
|
|
|
auto dims = weight->tensor_shape();
|
|
|
|
|
auto op_type = (schema::PrimitiveType)primitive_c->Type();
|
|
|
|
|
if (per_channel) {
|
|
|
|
@ -208,7 +208,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
|
|
|
|
|
average_raw += raw_data;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if (quantType == QuantType_WeightQuant && quant_param.clusters.size() == 0) {
|
|
|
|
|
if (quantType == QuantType_WeightQuant && !k_means) {
|
|
|
|
|
// mean
|
|
|
|
|
average_dequant = average_dequant / one_filter_size;
|
|
|
|
|
average_raw = average_raw / one_filter_size;
|
|
|
|
@ -256,7 +256,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
schema::QuantParamT quant_param;
|
|
|
|
|
if (quant_param.clusters.size() == 0) {
|
|
|
|
|
if (!k_means) {
|
|
|
|
|
STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
|
|
|
|
@ -267,7 +267,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
|
|
|
|
|
// update data and datatype
|
|
|
|
|
for (uint32_t i = 0; i < elem_count; i++) {
|
|
|
|
|
float raw_data = raw_datas[i];
|
|
|
|
|
if (quant_param.clusters.size() == 0) {
|
|
|
|
|
if (!k_means) {
|
|
|
|
|
auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
|
|
|
|
|
quant_datas[i] = quant_data;
|
|
|
|
|
}
|
|
|
|
|