|
|
|
@ -28,7 +28,7 @@
|
|
|
|
|
#include "schema/inner/model_generated.h"
|
|
|
|
|
#include "src/ir/tensor.h"
|
|
|
|
|
#include "src/common/anf_exporter/anf_exporter.h"
|
|
|
|
|
#include "tools/converter/quantizer/post_training.h"
|
|
|
|
|
#include "tools/converter/quantizer/post_training_quantizer.h"
|
|
|
|
|
#include "tools/converter/quantizer/quantize_util.h"
|
|
|
|
|
#include "src/common/common.h"
|
|
|
|
|
#include "utils/log_adapter.h"
|
|
|
|
@ -54,7 +54,10 @@ struct DivergInfo {
|
|
|
|
|
size_t bit_num;
|
|
|
|
|
int quant_max = 255;
|
|
|
|
|
int quant_min = 0;
|
|
|
|
|
DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min) {
|
|
|
|
|
std::string method_x = kMethodKL;
|
|
|
|
|
|
|
|
|
|
DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min, const std::string &method_x) {
|
|
|
|
|
this->method_x = method_x;
|
|
|
|
|
this->cnode = cnode;
|
|
|
|
|
this->bin_num = bins;
|
|
|
|
|
this->bit_num = bits;
|
|
|
|
@ -99,6 +102,12 @@ struct DivergInfo {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS ComputeThreshold() {
|
|
|
|
|
if (method_x == kMethodMaxMin) {
|
|
|
|
|
this->best_T = std::max(fabs(this->max), fabs(this->min));
|
|
|
|
|
MS_LOG(DEBUG) << "using MAX_MIN, T: " << this->best_T;
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
constexpr int quant_bint_nums = 128;
|
|
|
|
|
int threshold = quant_bint_nums;
|
|
|
|
|
float min_kl = FLT_MAX;
|
|
|
|
@ -200,46 +209,32 @@ struct DivergInfo {
|
|
|
|
|
threshold = i;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "Best threshold bin index: " << threshold;
|
|
|
|
|
this->best_T = (static_cast<float>(threshold) + 0.5f) * this->interval;
|
|
|
|
|
MS_LOG(DEBUG) << cnode->fullname_with_scope() << " Best threshold bin index: " << threshold
|
|
|
|
|
<< " T: " << best_T
|
|
|
|
|
<< " max: " << std::max(fabs(this->max), fabs(this->min));
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::pair<CNodePtr, float> GetScale() {
|
|
|
|
|
float max_value = this->best_T;
|
|
|
|
|
float min_value = -max_value;
|
|
|
|
|
|
|
|
|
|
MS_ASSERT(quant_max - quant_min != 0);
|
|
|
|
|
double scale = (max_value - min_value) / (quant_max - quant_min);
|
|
|
|
|
float scale = (max_value - min_value) / (quant_max - quant_min);
|
|
|
|
|
MS_ASSERT(scale != 0);
|
|
|
|
|
return std::make_pair(this->cnode, scale);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::pair<CNodePtr, int32_t> GetZeropoint() {
|
|
|
|
|
float max_value = this->best_T;
|
|
|
|
|
float min_value = -max_value;
|
|
|
|
|
MS_ASSERT(quant_max - quant_min != 0);
|
|
|
|
|
float scale = (max_value - min_value) / (quant_max - quant_min);
|
|
|
|
|
|
|
|
|
|
auto quant_min_float = static_cast<float>(quant_min);
|
|
|
|
|
auto quant_max_float = static_cast<float>(quant_max);
|
|
|
|
|
MS_ASSERT(scale != 0);
|
|
|
|
|
const float zero_point_from_min = quant_min_float - min_value / scale;
|
|
|
|
|
// const float zero_point_from_max = quant_max_float - max_value / scale;
|
|
|
|
|
int zero_point;
|
|
|
|
|
if (zero_point_from_min < quant_min_float) {
|
|
|
|
|
zero_point = quant_min;
|
|
|
|
|
} else if (zero_point_from_min > quant_max_float) {
|
|
|
|
|
zero_point = quant_max;
|
|
|
|
|
} else {
|
|
|
|
|
zero_point = static_cast<int>(std::round(zero_point_from_min));
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "zero point:" << zero_point;
|
|
|
|
|
int zero_point = 0;
|
|
|
|
|
if (quant_min == 0 && quant_max == 255) {
|
|
|
|
|
zero_point = 128;
|
|
|
|
|
} else if (quant_min == -128 && quant_max == 127) {
|
|
|
|
|
zero_point = 0;
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(ERROR) << "unexpectd quant range, quant_min: " << quant_min << " quant_max: " << quant_max;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return std::make_pair(this->cnode, zero_point);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -356,9 +351,9 @@ STATUS Calibrator::AddQuantizedOp(CNodePtr node) {
|
|
|
|
|
}
|
|
|
|
|
string node_name = node->fullname_with_scope();
|
|
|
|
|
std::unique_ptr<DivergInfo> input_diverg =
|
|
|
|
|
std::unique_ptr<DivergInfo>(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_));
|
|
|
|
|
std::unique_ptr<DivergInfo>(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_, config_param_.method_x));
|
|
|
|
|
std::unique_ptr<DivergInfo> output_diverg =
|
|
|
|
|
std::unique_ptr<DivergInfo>(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_));
|
|
|
|
|
std::unique_ptr<DivergInfo>(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_, config_param_.method_x));
|
|
|
|
|
|
|
|
|
|
input_diverg_info_.insert(std::make_pair(string(node_name), std::move(input_diverg)));
|
|
|
|
|
output_diverg_info_.insert(std::make_pair(string(node_name), std::move(output_diverg)));
|
|
|
|
@ -383,13 +378,13 @@ STATUS Calibrator::GenerateInputData(const int index, mindspore::tensor::MSTenso
|
|
|
|
|
MS_LOG(INFO) << "read image: " << path;
|
|
|
|
|
size_t size;
|
|
|
|
|
char *binBuf = ReadFile(path.c_str(), &size);
|
|
|
|
|
|
|
|
|
|
// auto *rawinputDatas = reinterpret_cast<const float *>(binBuf);
|
|
|
|
|
// auto mobilenet_input = const_cast<float *>(rawinputDatas);
|
|
|
|
|
auto data = tensor->MutableData();
|
|
|
|
|
if (size != tensor->Size()) {
|
|
|
|
|
MS_LOG(ERROR) << "the input data is not consistent with model input, file_size: " << size
|
|
|
|
|
<< " input tensor size: " << tensor->Size();
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
memcpy(data, binBuf, size);
|
|
|
|
|
|
|
|
|
|
// tensor->SetData(mobilenet_input);
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -457,13 +452,20 @@ STATUS Calibrator::ReadConfig() {
|
|
|
|
|
config_param_.batch_count = std::stoul(value);
|
|
|
|
|
} else if (key == "thread_num") {
|
|
|
|
|
config_param_.thread_num = std::stoul(value);
|
|
|
|
|
} else if (key == "method_x") {
|
|
|
|
|
if (value != kMethodKL && value != kMethodMaxMin) {
|
|
|
|
|
MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value.";
|
|
|
|
|
} else {
|
|
|
|
|
config_param_.method_x = value;
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(WARNING) << "unsupported parameter";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "image_path: " << config_param_.image_path << " "
|
|
|
|
|
<< "batch_count: " << config_param_.batch_count << " "
|
|
|
|
|
<< "thread_num: " << config_param_.thread_num;
|
|
|
|
|
MS_LOG(DEBUG) << "image_path: " << config_param_.image_path << " "
|
|
|
|
|
<< "batch_count: " << config_param_.batch_count << " "
|
|
|
|
|
<< "mothod_x: " << config_param_.method_x << " "
|
|
|
|
|
<< "thread_num: " << config_param_.thread_num;
|
|
|
|
|
|
|
|
|
|
delete[] resolved_path;
|
|
|
|
|
fs.close();
|
|
|
|
@ -615,7 +617,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(std::shared_ptr<PrimitiveTValue> input
|
|
|
|
|
quant_datas[i] = quant_data;
|
|
|
|
|
}
|
|
|
|
|
auto ret =
|
|
|
|
|
memcpy_s(bias_param->tensor_addr(), shape_size * sizeof(int32_t), quant_datas, shape_size * sizeof(int32_t));
|
|
|
|
|
memcpy_s(bias_param->tensor_addr(), bias_param->tensor_size(), quant_datas, shape_size * sizeof(int32_t));
|
|
|
|
|
if (ret != EOK) {
|
|
|
|
|
MS_LOG(ERROR) << "memcpy_s failed.";
|
|
|
|
|
delete[] quant_datas;
|
|
|
|
@ -805,14 +807,6 @@ STATUS PostTrainingQuantizer::DoInference() {
|
|
|
|
|
MS_LOG(ERROR) << "generate input data from images failed!";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
/**
|
|
|
|
|
* struct CallBackParam {
|
|
|
|
|
std::string nodeType;
|
|
|
|
|
NODE_ID nodeName;
|
|
|
|
|
std::unordered_set<NODE_ID> depends;
|
|
|
|
|
int opExecResult;
|
|
|
|
|
};
|
|
|
|
|
*/
|
|
|
|
|
mindspore::session::KernelCallBack beforeCallBack =
|
|
|
|
|
[&](const std::vector<mindspore::tensor::MSTensor *> &beforeInputs,
|
|
|
|
|
const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs,
|
|
|
|
@ -916,9 +910,26 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr funcGraph) {
|
|
|
|
|
MS_LOG(ERROR) << "do pre process failed!";
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// anf -- fb
|
|
|
|
|
auto meta_graph = Export(funcGraph);
|
|
|
|
|
if (meta_graph == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "Export to meta_graph return nullptr";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// transform
|
|
|
|
|
GraphDefTransform transform;
|
|
|
|
|
transform.SetGraphDef(meta_graph);
|
|
|
|
|
flags.quantType = schema::QuantType_QUANT_NONE;
|
|
|
|
|
status = transform.Transform(flags);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "FBTransform model failed " << status;
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "start create session";
|
|
|
|
|
flatbuffers::FlatBufferBuilder builder(1024);
|
|
|
|
|
auto offset = schema::MetaGraph::Pack(builder, Export(funcGraph));
|
|
|
|
|
auto offset = schema::MetaGraph::Pack(builder, meta_graph);
|
|
|
|
|
builder.Finish(offset);
|
|
|
|
|
size_t size = builder.GetSize();
|
|
|
|
|
auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer());
|