post training quantization

pull/4325/head
xutianchun 5 years ago
parent 13c2b23356
commit f32ff96ecd

@ -118,7 +118,7 @@ int QuantDTypeCastCPUKernel::Run() {
int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->Data()); int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->Data());
} }
int ret = LiteBackendParallelLaunch(QuantDTypeCastRun, this, thread_n_num_); auto ret = LiteBackendParallelLaunch(QuantDTypeCastRun, this, thread_n_num_);
if (ret != RET_OK) { if (ret != RET_OK) {
MS_LOG(ERROR) << "Scale error error_code[" << ret << "]"; MS_LOG(ERROR) << "Scale error error_code[" << ret << "]";
return RET_ERROR; return RET_ERROR;

@ -39,7 +39,7 @@ class SoftmaxCPUKernel : public SoftmaxBaseCPUKernel {
int Run() override; int Run() override;
private: private:
float *sum_data_; float *sum_data_ = nullptr;
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel

@ -31,7 +31,7 @@
#include "src/common/anf_importer/import_from_protobuf.h" #include "src/common/anf_importer/import_from_protobuf.h"
#include "tools/converter/parser/onnx/onnx.pb.h" #include "tools/converter/parser/onnx/onnx.pb.h"
#include "tools/converter/quantizer/weight_quantizer.h" #include "tools/converter/quantizer/weight_quantizer.h"
#include "tools/converter/quantizer/post_training.h" #include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h" #include "tools/converter/quantizer/quant_cast.h"
namespace mindspore { namespace mindspore {
@ -94,6 +94,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
CreateQuantizer(graph, flag); CreateQuantizer(graph, flag);
if (mQuantizer != nullptr) { if (mQuantizer != nullptr) {
mQuantizer->flags = *flag;
auto status = mQuantizer->DoQuantize(graph); auto status = mQuantizer->DoQuantize(graph);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "Quant failed " << status; MS_LOG(ERROR) << "Quant failed " << status;

@ -277,8 +277,9 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
} else if (weightTensor->format == schema::Format_CHWK) { // from onnx } else if (weightTensor->format == schema::Format_CHWK) { // from onnx
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
status = TransFilterFormat<int8_t>(weightTensor.get(), kCHWK2KHWC); status = TransFilterFormat<int8_t>(weightTensor.get(), kCHWK2KHWC);
MS_LOG(DEBUG) << node->name << " weight trans format: CHWK->KHWC";
} else { } else {
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2HWCK); status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
} }
} else if (weightTensor->format == schema::Format_KCHW) { } else if (weightTensor->format == schema::Format_KCHW) {
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
@ -291,8 +292,8 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
return -1; return -1;
} }
if (status == 0) { if (status == 0) {
node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NCHW; node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC;
weightTensor->format = schema::Format_HWCK; weightTensor->format = schema::Format_KHWC;
} else { } else {
MS_LOG(WARNING) << "TransFilter %ToHWCK failed, node : " MS_LOG(WARNING) << "TransFilter %ToHWCK failed, node : "
<< (weightTensor->format == schema::Format_CHWK ? "CHWK" : "CKHW"), << (weightTensor->format == schema::Format_CHWK ? "CHWK" : "CKHW"),

@ -4,15 +4,12 @@ include_directories(${3RD_DIR}/flatbuffers/include)
include_directories(${3RD_DIR}/opencv/build/include/opencv4) include_directories(${3RD_DIR}/opencv/build/include/opencv4)
add_library(quantizer_mid OBJECT add_library(quantizer_mid OBJECT
#${CMAKE_CURRENT_SOURCE_DIR}/calc_quant_param.cc
${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc ${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc
#${CMAKE_CURRENT_SOURCE_DIR}/aware_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc ${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc ${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc ${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc
${CMAKE_CURRENT_SOURCE_DIR}/post_training.cc ${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc
#${CMAKE_CURRENT_SOURCE_DIR}/../proto/post_training/post_training.pb.cc
) )
if(ENABLE_ASAN) if(ENABLE_ASAN)

@ -28,7 +28,7 @@
#include "schema/inner/model_generated.h" #include "schema/inner/model_generated.h"
#include "src/ir/tensor.h" #include "src/ir/tensor.h"
#include "src/common/anf_exporter/anf_exporter.h" #include "src/common/anf_exporter/anf_exporter.h"
#include "tools/converter/quantizer/post_training.h" #include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quantize_util.h" #include "tools/converter/quantizer/quantize_util.h"
#include "src/common/common.h" #include "src/common/common.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
@ -54,7 +54,10 @@ struct DivergInfo {
size_t bit_num; size_t bit_num;
int quant_max = 255; int quant_max = 255;
int quant_min = 0; int quant_min = 0;
DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min) { std::string method_x = kMethodKL;
DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min, const std::string &method_x) {
this->method_x = method_x;
this->cnode = cnode; this->cnode = cnode;
this->bin_num = bins; this->bin_num = bins;
this->bit_num = bits; this->bit_num = bits;
@ -99,6 +102,12 @@ struct DivergInfo {
} }
STATUS ComputeThreshold() { STATUS ComputeThreshold() {
if (method_x == kMethodMaxMin) {
this->best_T = std::max(fabs(this->max), fabs(this->min));
MS_LOG(DEBUG) << "using MAX_MIN, T: " << this->best_T;
return RET_OK;
}
constexpr int quant_bint_nums = 128; constexpr int quant_bint_nums = 128;
int threshold = quant_bint_nums; int threshold = quant_bint_nums;
float min_kl = FLT_MAX; float min_kl = FLT_MAX;
@ -200,46 +209,32 @@ struct DivergInfo {
threshold = i; threshold = i;
} }
} }
MS_LOG(DEBUG) << "Best threshold bin index: " << threshold;
this->best_T = (static_cast<float>(threshold) + 0.5f) * this->interval; this->best_T = (static_cast<float>(threshold) + 0.5f) * this->interval;
MS_LOG(DEBUG) << cnode->fullname_with_scope() << " Best threshold bin index: " << threshold
<< " T: " << best_T
<< " max: " << std::max(fabs(this->max), fabs(this->min));
return RET_OK; return RET_OK;
} }
std::pair<CNodePtr, float> GetScale() { std::pair<CNodePtr, float> GetScale() {
float max_value = this->best_T; float max_value = this->best_T;
float min_value = -max_value; float min_value = -max_value;
MS_ASSERT(quant_max - quant_min != 0); MS_ASSERT(quant_max - quant_min != 0);
double scale = (max_value - min_value) / (quant_max - quant_min); float scale = (max_value - min_value) / (quant_max - quant_min);
MS_ASSERT(scale != 0); MS_ASSERT(scale != 0);
return std::make_pair(this->cnode, scale); return std::make_pair(this->cnode, scale);
} }
std::pair<CNodePtr, int32_t> GetZeropoint() { std::pair<CNodePtr, int32_t> GetZeropoint() {
float max_value = this->best_T; int zero_point = 0;
float min_value = -max_value;
MS_ASSERT(quant_max - quant_min != 0);
float scale = (max_value - min_value) / (quant_max - quant_min);
auto quant_min_float = static_cast<float>(quant_min);
auto quant_max_float = static_cast<float>(quant_max);
MS_ASSERT(scale != 0);
const float zero_point_from_min = quant_min_float - min_value / scale;
// const float zero_point_from_max = quant_max_float - max_value / scale;
int zero_point;
if (zero_point_from_min < quant_min_float) {
zero_point = quant_min;
} else if (zero_point_from_min > quant_max_float) {
zero_point = quant_max;
} else {
zero_point = static_cast<int>(std::round(zero_point_from_min));
}
MS_LOG(DEBUG) << "zero point:" << zero_point;
if (quant_min == 0 && quant_max == 255) { if (quant_min == 0 && quant_max == 255) {
zero_point = 128; zero_point = 128;
} else if (quant_min == -128 && quant_max == 127) { } else if (quant_min == -128 && quant_max == 127) {
zero_point = 0; zero_point = 0;
} else {
MS_LOG(ERROR) << "unexpectd quant range, quant_min: " << quant_min << " quant_max: " << quant_max;
} }
return std::make_pair(this->cnode, zero_point); return std::make_pair(this->cnode, zero_point);
} }
}; };
@ -356,9 +351,9 @@ STATUS Calibrator::AddQuantizedOp(CNodePtr node) {
} }
string node_name = node->fullname_with_scope(); string node_name = node->fullname_with_scope();
std::unique_ptr<DivergInfo> input_diverg = std::unique_ptr<DivergInfo> input_diverg =
std::unique_ptr<DivergInfo>(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_)); std::unique_ptr<DivergInfo>(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_, config_param_.method_x));
std::unique_ptr<DivergInfo> output_diverg = std::unique_ptr<DivergInfo> output_diverg =
std::unique_ptr<DivergInfo>(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_)); std::unique_ptr<DivergInfo>(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_, config_param_.method_x));
input_diverg_info_.insert(std::make_pair(string(node_name), std::move(input_diverg))); input_diverg_info_.insert(std::make_pair(string(node_name), std::move(input_diverg)));
output_diverg_info_.insert(std::make_pair(string(node_name), std::move(output_diverg))); output_diverg_info_.insert(std::make_pair(string(node_name), std::move(output_diverg)));
@ -383,13 +378,13 @@ STATUS Calibrator::GenerateInputData(const int index, mindspore::tensor::MSTenso
MS_LOG(INFO) << "read image: " << path; MS_LOG(INFO) << "read image: " << path;
size_t size; size_t size;
char *binBuf = ReadFile(path.c_str(), &size); char *binBuf = ReadFile(path.c_str(), &size);
// auto *rawinputDatas = reinterpret_cast<const float *>(binBuf);
// auto mobilenet_input = const_cast<float *>(rawinputDatas);
auto data = tensor->MutableData(); auto data = tensor->MutableData();
if (size != tensor->Size()) {
MS_LOG(ERROR) << "the input data is not consistent with model input, file_size: " << size
<< " input tensor size: " << tensor->Size();
return RET_ERROR;
}
memcpy(data, binBuf, size); memcpy(data, binBuf, size);
// tensor->SetData(mobilenet_input);
return RET_OK; return RET_OK;
} }
@ -457,12 +452,19 @@ STATUS Calibrator::ReadConfig() {
config_param_.batch_count = std::stoul(value); config_param_.batch_count = std::stoul(value);
} else if (key == "thread_num") { } else if (key == "thread_num") {
config_param_.thread_num = std::stoul(value); config_param_.thread_num = std::stoul(value);
} else if (key == "method_x") {
if (value != kMethodKL && value != kMethodMaxMin) {
MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value.";
} else {
config_param_.method_x = value;
}
} else { } else {
MS_LOG(WARNING) << "unsupported parameter"; MS_LOG(WARNING) << "unsupported parameter";
} }
} }
MS_LOG(INFO) << "image_path: " << config_param_.image_path << " " MS_LOG(DEBUG) << "image_path: " << config_param_.image_path << " "
<< "batch_count: " << config_param_.batch_count << " " << "batch_count: " << config_param_.batch_count << " "
<< "mothod_x: " << config_param_.method_x << " "
<< "thread_num: " << config_param_.thread_num; << "thread_num: " << config_param_.thread_num;
delete[] resolved_path; delete[] resolved_path;
@ -615,7 +617,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(std::shared_ptr<PrimitiveTValue> input
quant_datas[i] = quant_data; quant_datas[i] = quant_data;
} }
auto ret = auto ret =
memcpy_s(bias_param->tensor_addr(), shape_size * sizeof(int32_t), quant_datas, shape_size * sizeof(int32_t)); memcpy_s(bias_param->tensor_addr(), bias_param->tensor_size(), quant_datas, shape_size * sizeof(int32_t));
if (ret != EOK) { if (ret != EOK) {
MS_LOG(ERROR) << "memcpy_s failed."; MS_LOG(ERROR) << "memcpy_s failed.";
delete[] quant_datas; delete[] quant_datas;
@ -805,14 +807,6 @@ STATUS PostTrainingQuantizer::DoInference() {
MS_LOG(ERROR) << "generate input data from images failed!"; MS_LOG(ERROR) << "generate input data from images failed!";
return RET_ERROR; return RET_ERROR;
} }
/**
* struct CallBackParam {
std::string nodeType;
NODE_ID nodeName;
std::unordered_set<NODE_ID> depends;
int opExecResult;
};
*/
mindspore::session::KernelCallBack beforeCallBack = mindspore::session::KernelCallBack beforeCallBack =
[&](const std::vector<mindspore::tensor::MSTensor *> &beforeInputs, [&](const std::vector<mindspore::tensor::MSTensor *> &beforeInputs,
const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs, const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs,
@ -916,9 +910,26 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr funcGraph) {
MS_LOG(ERROR) << "do pre process failed!"; MS_LOG(ERROR) << "do pre process failed!";
return status; return status;
} }
// anf -- fb
auto meta_graph = Export(funcGraph);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Export to meta_graph return nullptr";
return RET_ERROR;
}
// transform
GraphDefTransform transform;
transform.SetGraphDef(meta_graph);
flags.quantType = schema::QuantType_QUANT_NONE;
status = transform.Transform(flags);
if (status != RET_OK) {
MS_LOG(ERROR) << "FBTransform model failed " << status;
return RET_ERROR;
}
MS_LOG(INFO) << "start create session"; MS_LOG(INFO) << "start create session";
flatbuffers::FlatBufferBuilder builder(1024); flatbuffers::FlatBufferBuilder builder(1024);
auto offset = schema::MetaGraph::Pack(builder, Export(funcGraph)); auto offset = schema::MetaGraph::Pack(builder, meta_graph);
builder.Finish(offset); builder.Finish(offset);
size_t size = builder.GetSize(); size_t size = builder.GetSize();
auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer()); auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer());

@ -46,10 +46,14 @@ enum ImageFormat {
BGR = 2, BGR = 2,
}; };
const char kMethodMaxMin[] = "MAX_MIN";
const char kMethodKL[] = "KL";
struct ConfigParam { struct ConfigParam {
// ImageFormat imageFormat; // ImageFormat imageFormat;
std::string image_path; std::string image_path;
uint32_t batch_count; uint32_t batch_count{100};
std::string method_x{kMethodKL};
uint32_t thread_num; uint32_t thread_num;
}; };
@ -115,6 +119,8 @@ class Calibrator {
uint32_t GetThreadNum() const { return config_param_.thread_num; } uint32_t GetThreadNum() const { return config_param_.thread_num; }
std::string GetMethodX() const { return config_param_.method_x; }
STATUS AddQuantizedOp(CNodePtr node); STATUS AddQuantizedOp(CNodePtr node);
STATUS RecordMaxValue(std::string opName, std::vector<float> data, STATUS RecordMaxValue(std::string opName, std::vector<float> data,

@ -89,7 +89,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0)); auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
if (primitiveT_value == nullptr) { if (primitiveT_value == nullptr) {
MS_LOG(ERROR) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope(); MS_LOG(WARNING) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope();
return false; return false;
} }
@ -344,7 +344,7 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
} }
weightPtr->set_quant_param(quantParam); weightPtr->set_quant_param(quantParam);
auto ret = memcpy_s(rawDatas, weightPtr->tensor_size() * sizeof(int8_t), auto ret = memcpy_s(rawDatas, weightPtr->tensor_size(),
qDatas.data(), shapeSize * sizeof(int8_t)); qDatas.data(), shapeSize * sizeof(int8_t));
if (ret != EOK) { if (ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret; MS_LOG(ERROR) << "memcpy error: " << ret;

@ -24,6 +24,7 @@
#include "include/model.h" #include "include/model.h"
#include "base/base.h" #include "base/base.h"
#include "src/param_value_lite.h" #include "src/param_value_lite.h"
#include "tools/converter/converter_flags.h"
namespace mindspore { namespace mindspore {
namespace lite { namespace lite {
@ -52,6 +53,7 @@ class Quantizer {
virtual STATUS DoQuantize(FuncGraphPtr funcGraph) = 0; virtual STATUS DoQuantize(FuncGraphPtr funcGraph) = 0;
mindspore::lite::converter::Flags flags;
protected: protected:
FuncGraphPtr funcGraph = nullptr; FuncGraphPtr funcGraph = nullptr;
}; };

Loading…
Cancel
Save