|
|
|
@ -358,15 +358,15 @@ static bool SearchUpperBound(const std::vector<float> &data, const size_t &index
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static float CalPercentile(const std::vector<float> &datas, const int &outlier_percent) {
|
|
|
|
|
const int size = datas.size();
|
|
|
|
|
static float CalPercentile(const std::vector<float> &data, const int &outlier_percent) {
|
|
|
|
|
const int size = data.size();
|
|
|
|
|
float val = outlier_percent / 100.0 * size;
|
|
|
|
|
int index = std::ceil(val);
|
|
|
|
|
float result;
|
|
|
|
|
if (index - val > 0) {
|
|
|
|
|
result = datas.at(index - 1);
|
|
|
|
|
result = data.at(index - 1);
|
|
|
|
|
} else {
|
|
|
|
|
result = (datas.at(index - 1) + datas.at(index)) / 2;
|
|
|
|
|
result = (data.at(index - 1) + data.at(index)) / 2;
|
|
|
|
|
}
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
@ -522,11 +522,78 @@ std::vector<std::vector<int>> DataToVectors(const string &str) {
|
|
|
|
|
return result;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_config) {
|
|
|
|
|
if (post_quant_config == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "post_quant_config is null.";
|
|
|
|
|
return RET_PARAM_INVALID;
|
|
|
|
|
void ParseInputShape(PostQuantConfig *post_quant_config, std::string raw_shape) {
|
|
|
|
|
MS_ASSERT(post_quant_config != nullptr);
|
|
|
|
|
auto ind = raw_shape.find('/');
|
|
|
|
|
while (ind != std::string::npos) {
|
|
|
|
|
auto shape = raw_shape.substr(0, ind);
|
|
|
|
|
Trim(&shape);
|
|
|
|
|
post_quant_config->input_shapes.push_back(DataToVectors(shape));
|
|
|
|
|
raw_shape = raw_shape.substr(ind + 1);
|
|
|
|
|
Trim(&raw_shape);
|
|
|
|
|
ind = raw_shape.find('/');
|
|
|
|
|
}
|
|
|
|
|
if (!raw_shape.empty()) {
|
|
|
|
|
post_quant_config->input_shapes.push_back(DataToVectors(raw_shape));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ParseImagePath(PostQuantConfig *post_quant_config, std::string raw_image_paths) {
|
|
|
|
|
MS_ASSERT(post_quant_config != nullptr);
|
|
|
|
|
auto ind = raw_image_paths.find(',');
|
|
|
|
|
while (ind != std::string::npos) {
|
|
|
|
|
auto image_path = raw_image_paths.substr(0, ind);
|
|
|
|
|
Trim(&image_path);
|
|
|
|
|
post_quant_config->image_paths.push_back(image_path);
|
|
|
|
|
raw_image_paths = raw_image_paths.substr(ind + 1);
|
|
|
|
|
Trim(&raw_image_paths);
|
|
|
|
|
ind = raw_image_paths.find(',');
|
|
|
|
|
}
|
|
|
|
|
post_quant_config->image_paths.push_back(raw_image_paths);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ParseBatchCount(PostQuantConfig *post_quant_config, std::string value) {
|
|
|
|
|
MS_ASSERT(post_quant_config != nullptr);
|
|
|
|
|
post_quant_config->batch_count = std::stoul(value);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ParseThreadNum(PostQuantConfig *post_quant_config, std::string value) {
|
|
|
|
|
MS_ASSERT(post_quant_config != nullptr);
|
|
|
|
|
post_quant_config->thread_num = std::stoul(value);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ParseMethodX(PostQuantConfig *post_quant_config, const std::string &value) {
|
|
|
|
|
MS_ASSERT(post_quant_config != nullptr);
|
|
|
|
|
if (value != kMethodKL && value != kMethodMaxMin && value != kMethodOutlier) {
|
|
|
|
|
MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value.";
|
|
|
|
|
} else {
|
|
|
|
|
post_quant_config->method_x = value;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ParseMixed(PostQuantConfig *post_quant_config, std::string value) {
|
|
|
|
|
MS_ASSERT(post_quant_config != nullptr);
|
|
|
|
|
std::for_each(value.begin(), value.end(), ::tolower);
|
|
|
|
|
if (value == "true") {
|
|
|
|
|
post_quant_config->mixed = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ParseMeanErrorThreshold(PostQuantConfig *post_quant_config, std::string value) {
|
|
|
|
|
MS_ASSERT(post_quant_config != nullptr);
|
|
|
|
|
post_quant_config->mean_error_threshold = std::stof(value);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ParseBiasCorrection(PostQuantConfig *post_quant_config, std::string value) {
|
|
|
|
|
MS_ASSERT(post_quant_config != nullptr);
|
|
|
|
|
std::for_each(value.begin(), value.end(), ::tolower);
|
|
|
|
|
if (value == "true") {
|
|
|
|
|
post_quant_config->bias_correction = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_config) {
|
|
|
|
|
MS_ASSERT(post_quant_config != nullptr);
|
|
|
|
|
|
|
|
|
|
if (config_file.empty() || config_file.length() > PATH_MAX) {
|
|
|
|
|
MS_LOG(ERROR) << "invalid config path!";
|
|
|
|
@ -552,6 +619,26 @@ STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_conf
|
|
|
|
|
MS_LOG(ERROR) << "config file open failed: " << config_file;
|
|
|
|
|
return RET_PARAM_INVALID;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string INPUT_SHAPES = "input_shapes";
|
|
|
|
|
std::string IMAGE_PATH = "image_path";
|
|
|
|
|
std::string BATCH_COUNT = "batch_count";
|
|
|
|
|
std::string THREAD_NUM = "thread_num";
|
|
|
|
|
std::string METHOD_X = "method_x";
|
|
|
|
|
std::string MIXED = "mixed";
|
|
|
|
|
std::string MEAN_ERROR_THRESHOLD = "mean_error_threshold";
|
|
|
|
|
std::string BIAS_CORRECTION = "bias_correction";
|
|
|
|
|
|
|
|
|
|
std::map<std::string, std::function<void(PostQuantConfig *, std::string)>> value_parser;
|
|
|
|
|
value_parser[INPUT_SHAPES] = ParseInputShape;
|
|
|
|
|
value_parser[IMAGE_PATH] = ParseImagePath;
|
|
|
|
|
value_parser[BATCH_COUNT] = ParseBatchCount;
|
|
|
|
|
value_parser[THREAD_NUM] = ParseThreadNum;
|
|
|
|
|
value_parser[METHOD_X] = ParseMethodX;
|
|
|
|
|
value_parser[MIXED] = ParseMixed;
|
|
|
|
|
value_parser[MEAN_ERROR_THRESHOLD] = ParseMeanErrorThreshold;
|
|
|
|
|
value_parser[BIAS_CORRECTION] = ParseBiasCorrection;
|
|
|
|
|
|
|
|
|
|
std::string line;
|
|
|
|
|
while (std::getline(fs, line)) {
|
|
|
|
|
Trim(&line);
|
|
|
|
@ -567,54 +654,9 @@ STATUS ParseConfigFile(std::string config_file, PostQuantConfig *post_quant_conf
|
|
|
|
|
auto value = line.substr(index + 1);
|
|
|
|
|
Trim(&key);
|
|
|
|
|
Trim(&value);
|
|
|
|
|
if (key == "image_path") {
|
|
|
|
|
auto &raw_image_paths = value;
|
|
|
|
|
auto ind = raw_image_paths.find(',');
|
|
|
|
|
while (ind != std::string::npos) {
|
|
|
|
|
auto image_path = raw_image_paths.substr(0, ind);
|
|
|
|
|
Trim(&image_path);
|
|
|
|
|
post_quant_config->image_paths.push_back(image_path);
|
|
|
|
|
raw_image_paths = raw_image_paths.substr(ind + 1);
|
|
|
|
|
Trim(&raw_image_paths);
|
|
|
|
|
ind = raw_image_paths.find(',');
|
|
|
|
|
}
|
|
|
|
|
post_quant_config->image_paths.push_back(raw_image_paths);
|
|
|
|
|
} else if (key == "batch_count") {
|
|
|
|
|
post_quant_config->batch_count = std::stoul(value);
|
|
|
|
|
} else if (key == "thread_num") {
|
|
|
|
|
post_quant_config->thread_num = std::stoul(value);
|
|
|
|
|
} else if (key == "method_x") {
|
|
|
|
|
if (value != kMethodKL && value != kMethodMaxMin && value != kMethodOutlier) {
|
|
|
|
|
MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value.";
|
|
|
|
|
} else {
|
|
|
|
|
post_quant_config->method_x = value;
|
|
|
|
|
}
|
|
|
|
|
} else if (key == "bias_correction") {
|
|
|
|
|
std::for_each(value.begin(), value.end(), ::tolower);
|
|
|
|
|
if (value == "true") {
|
|
|
|
|
post_quant_config->bias_correction = true;
|
|
|
|
|
}
|
|
|
|
|
} else if (key == "mixed") {
|
|
|
|
|
std::for_each(value.begin(), value.end(), ::tolower);
|
|
|
|
|
if (value == "true") {
|
|
|
|
|
post_quant_config->mixed = true;
|
|
|
|
|
}
|
|
|
|
|
} else if (key == "mean_error_threshold") {
|
|
|
|
|
post_quant_config->mean_error_threshold = std::stof(value);
|
|
|
|
|
} else if (key == "input_shapes") {
|
|
|
|
|
auto &raw_shape = value;
|
|
|
|
|
auto ind = raw_shape.find('/');
|
|
|
|
|
while (ind != std::string::npos) {
|
|
|
|
|
auto shape = raw_shape.substr(0, ind);
|
|
|
|
|
Trim(&shape);
|
|
|
|
|
post_quant_config->input_shapes.push_back(DataToVectors(shape));
|
|
|
|
|
raw_shape = raw_shape.substr(ind + 1);
|
|
|
|
|
Trim(&raw_shape);
|
|
|
|
|
ind = raw_shape.find('/');
|
|
|
|
|
}
|
|
|
|
|
if (!raw_shape.empty()) {
|
|
|
|
|
post_quant_config->input_shapes.push_back(DataToVectors(raw_shape));
|
|
|
|
|
}
|
|
|
|
|
auto it = value_parser.find(key);
|
|
|
|
|
if (it != value_parser.end()) {
|
|
|
|
|
it->second(post_quant_config, value);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(WARNING) << "unsupported parameter: " << key;
|
|
|
|
|
}
|
|
|
|
@ -881,4 +923,24 @@ STATUS UpdateTensorDataAndSize(ParamValueLitePtr weight, void *quant_datas, int
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void GetMaxMinPerchannel(int channels, int one_filter_size, int i, int elem_count, float *raw_datas,
|
|
|
|
|
bool channel_at_first, float *desired_max, float *desired_min) {
|
|
|
|
|
float min = FLT_MAX;
|
|
|
|
|
float max = -FLT_MAX;
|
|
|
|
|
// find min and max
|
|
|
|
|
for (int j = 0; j < one_filter_size; j++) {
|
|
|
|
|
auto index = j + i * one_filter_size;
|
|
|
|
|
if (!channel_at_first) {
|
|
|
|
|
index = j * channels + i;
|
|
|
|
|
}
|
|
|
|
|
if (index >= elem_count) {
|
|
|
|
|
MS_LOG(ERROR) << "over flow!";
|
|
|
|
|
}
|
|
|
|
|
min = std::min(min, raw_datas[index]);
|
|
|
|
|
max = std::max(max, raw_datas[index]);
|
|
|
|
|
}
|
|
|
|
|
*desired_max = max;
|
|
|
|
|
*desired_min = min;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace mindspore::lite::quant
|
|
|
|
|