!8124 change shema of quant_params to reduce model size

Merge pull request !8124 from xutianchun/1102
pull/8124/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit e51fcc5af5

@ -32,10 +32,9 @@ table QuantParam {
narrowRange: bool = true;
numBits: int = 8;
inited: bool = false;
varCorr: double = 1;
meanCorr: double = 0;
varCorr: float = 1;
meanCorr: float = 0;
dstDtype: int = 32;
clusters: [float];
}
table Tensor {
@ -49,6 +48,7 @@ table Tensor {
offset: int;
data: [ubyte];
quantParams: [QuantParam];
quantClusters: [float];
}
union PrimitiveType {

@ -107,15 +107,17 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
quant_arg.var_corr = quant_params->Get(j)->varCorr();
quant_arg.mean_corr = quant_params->Get(j)->meanCorr();
quant_arg.inited = quant_params->Get(j)->inited();
auto quant_clusters = quant_params->Get(j)->clusters();
if (quant_clusters != nullptr) {
for (size_t k = 0; k < quant_clusters->size(); k++) {
quant_arg.clusters.emplace_back(quant_clusters->Get(k));
}
}
dstTensor->AddQuantParam(quant_arg);
}
}
auto quant_clusters = srcTensor->quantClusters();
if (quant_clusters != nullptr) {
std::vector<float> clusters;
for (size_t j = 0; j < quant_clusters->size(); j++) {
clusters.push_back(quant_clusters->Get(j));
}
dstTensor->SetQuantClusters(clusters);
}
this->tensors_.emplace_back(dstTensor);
}

@ -79,11 +79,12 @@ class DequantUtil {
}
} else {
auto quant_param = input_tensor->GetQuantParams();
auto quant_clusters = input_tensor->GetQuantClusters();
auto param = quant_param.front();
auto scale = param.scale;
auto zero_point = param.zeroPoint;
for (int64_t j = 0; j < input_tensor->ElementsNum(); j++) {
if (param.clusters.size() != 0) {
if (!quant_clusters.empty()) {
int8_t index = quant_datas[j];
if (index > INT8_MAX || index < INT8_MIN) {
MS_LOG(ERROR) << "KMeans param quant is error.";

@ -367,6 +367,10 @@ void Tensor::AddQuantParam(const QuantArg &quant_arg) { this->quant_params_.push
std::vector<QuantArg> Tensor::GetQuantParams() const { return this->quant_params_; }
std::vector<float> Tensor::GetQuantClusters() const { return this->quant_clusters_; }
void Tensor::SetQuantClusters(const std::vector<float> &clusters) { this->quant_clusters_ = clusters; }
std::vector<tensor::MSTensor *> TensorVectorCast(const std::vector<Tensor *> &src) {
std::vector<tensor::MSTensor *> target(src.size());
std::transform(src.begin(), src.end(), target.begin(), [](Tensor *t) { return dynamic_cast<tensor::MSTensor *>(t); });

@ -33,8 +33,8 @@ namespace lite {
struct QuantArg {
double scale;
int32_t zeroPoint;
double var_corr{1};
double mean_corr{0};
float var_corr{1};
float mean_corr{0};
bool inited;
std::vector<float> clusters{};
};
@ -119,6 +119,10 @@ class Tensor : public mindspore::tensor::MSTensor {
std::vector<QuantArg> GetQuantParams() const;
std::vector<float> GetQuantClusters() const;
void SetQuantClusters(const std::vector<float> &clusters);
bool IsConst();
bool IsScalar();
@ -138,6 +142,7 @@ class Tensor : public mindspore::tensor::MSTensor {
Category category_;
size_t ref_count_ = 0;
std::vector<QuantArg> quant_params_;
std::vector<float> quant_clusters_;
mindspore::lite::Allocator *allocator_ = nullptr;
};

@ -449,7 +449,6 @@ std::vector<int8_t> KMeans(float *data, size_t elem_count, size_t k, size_t epoc
error = error_cur;
}
// update data
quantParam->clusters = clusters;
return clusters_index;
}

@ -130,7 +130,7 @@ T QuantizeData(float originData, const schema::QuantParamT &quantParam, int quan
}
template <typename T>
STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primitive_c, QuantType quantType,
int quant_max, int quant_min, size_t bitNum, bool per_channel) {
int quant_max, int quant_min, size_t bitNum, bool per_channel, bool k_means = false) {
auto dims = weight->tensor_shape();
auto op_type = (schema::PrimitiveType)primitive_c->Type();
if (per_channel) {
@ -208,7 +208,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
average_raw += raw_data;
}
}
if (quantType == QuantType_WeightQuant && quant_param.clusters.size() == 0) {
if (quantType == QuantType_WeightQuant && !k_means) {
// mean
average_dequant = average_dequant / one_filter_size;
average_raw = average_raw / one_filter_size;
@ -256,7 +256,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
}
schema::QuantParamT quant_param;
if (quant_param.clusters.size() == 0) {
if (!k_means) {
STATUS status = CalQuantizationParams(&quant_param, min, max, false, quant_max, quant_min, bitNum);
if (status != RET_OK) {
MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
@ -267,7 +267,7 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr<PrimitiveC> primiti
// update data and datatype
for (uint32_t i = 0; i < elem_count; i++) {
float raw_data = raw_datas[i];
if (quant_param.clusters.size() == 0) {
if (!k_means) {
auto quant_data = QuantizeData<T>(raw_data, quant_param, quant_max, quant_min);
quant_datas[i] = quant_data;
}

Loading…
Cancel
Save