post training quantization

pull/4325/head
xutianchun 5 years ago
parent 13c2b23356
commit f32ff96ecd

@@ -118,7 +118,7 @@ int QuantDTypeCastCPUKernel::Run() {
int8_ptr_ = reinterpret_cast<int8_t *>(out_tensors_[0]->Data());
}
-int ret = LiteBackendParallelLaunch(QuantDTypeCastRun, this, thread_n_num_);
+auto ret = LiteBackendParallelLaunch(QuantDTypeCastRun, this, thread_n_num_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Scale error error_code[" << ret << "]";
return RET_ERROR;

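Context for the hunk above: the dtype cast is sliced across thread_n_num_ workers, each slice handled by QuantDTypeCastRun under a task id, and Run() surfaces the first failing slice. A minimal sketch of that parallel-launch shape; all names other than the ones visible in the diff are illustrative assumptions, not the real LiteBackendParallelLaunch.

#include <thread>
#include <vector>

constexpr int RET_OK = 0;

// Task signature mirrors QuantDTypeCastRun: (task_id, user_data) -> status.
using ParallelTask = int (*)(int task_id, void *cdata);

int ParallelLaunch(ParallelTask task, void *cdata, int thread_num) {
  std::vector<std::thread> workers;
  std::vector<int> rets(thread_num, RET_OK);
  for (int i = 0; i < thread_num; ++i) {
    workers.emplace_back([&rets, task, cdata, i] { rets[i] = task(i, cdata); });
  }
  for (auto &w : workers) w.join();
  for (int r : rets) {
    if (r != RET_OK) return r;  // surface the first failing slice, as Run() does
  }
  return RET_OK;
}

Declaring ret with auto keeps the call site in sync if the launcher's return type ever changes.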
@@ -39,7 +39,7 @@ class SoftmaxCPUKernel : public SoftmaxBaseCPUKernel {
int Run() override;
private:
-float *sum_data_;
+float *sum_data_ = nullptr;
};
} // namespace mindspore::kernel

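The in-class initializer above gives sum_data_ a well-defined value even when Init() fails or never runs, so later null checks and delete[] behave predictably. A tiny sketch, with a hypothetical class, of the failure path this guards:

#include <cstddef>
#include <new>

// Hypothetical kernel: the destructor may run before Init() ever allocated.
struct SketchKernel {
  float *sum_data_ = nullptr;  // defined value even if Init() never runs
  int Init(std::size_t n) {
    sum_data_ = new (std::nothrow) float[n];
    return sum_data_ != nullptr ? 0 : -1;
  }
  ~SketchKernel() { delete[] sum_data_; }  // delete[] on nullptr is a no-op
};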
@@ -31,7 +31,7 @@
#include "src/common/anf_importer/import_from_protobuf.h"
#include "tools/converter/parser/onnx/onnx.pb.h"
#include "tools/converter/quantizer/weight_quantizer.h"
#include "tools/converter/quantizer/post_training.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
namespace mindspore {
@@ -94,6 +94,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
CreateQuantizer(graph, flag);
if (mQuantizer != nullptr) {
+mQuantizer->flags = *flag;
auto status = mQuantizer->DoQuantize(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "Quant failed " << status;

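The converter now hands the complete flag set to whichever quantizer CreateQuantizer selected before calling DoQuantize, so subclasses (the new PostTrainingQuantizer among them) can read the quant type and calibration config from flags. A hedged sketch of that hand-off; everything beyond the names visible in the diff is assumed structure:

#include <string>

// Flags/flags/DoQuantize mirror the diff; the rest is an assumption.
struct Flags { int quantType = 0; std::string configFile; };

struct QuantizerBase {
  Flags flags;  // populated by the converter before DoQuantize
  virtual int DoQuantize() = 0;
  virtual ~QuantizerBase() = default;
};

int RunQuantizePass(QuantizerBase *quantizer, const Flags &flag) {
  if (quantizer == nullptr) return 0;  // no quantization requested
  quantizer->flags = flag;             // same hand-off as mQuantizer->flags = *flag
  return quantizer->DoQuantize();
}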
@@ -277,8 +277,9 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
} else if (weightTensor->format == schema::Format_CHWK) { // from onnx
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
status = TransFilterFormat<int8_t>(weightTensor.get(), kCHWK2KHWC);
+MS_LOG(DEBUG) << node->name << " weight trans format: CHWK->KHWC";
} else {
-status = TransFilterFormat<float>(weightTensor.get(), kCHWK2HWCK);
+status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
}
} else if (weightTensor->format == schema::Format_KCHW) {
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
@@ -291,8 +292,8 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
return -1;
}
if (status == 0) {
-node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NCHW;
-weightTensor->format = schema::Format_HWCK;
+node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC;
+weightTensor->format = schema::Format_KHWC;
} else {
MS_LOG(WARNING) << "TransFilter %ToHWCK failed, node : "
<< (weightTensor->format == schema::Format_CHWK ? "CHWK" : "CKHW"),

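The two hunks above fix an asymmetry: int8 CHWK weights were already transposed to KHWC, while float weights went to HWCK; now both target KHWC, consistent with the NHWC format recorded on the depthwise-conv primitive. A generic sketch of the CHWK-to-KHWC index permutation (the real TransFilterFormat is MindSpore's own templated routine; this copy-based version only shows the index math):

#include <cstddef>
#include <vector>

// CHWK linear index: ((c*H + h)*W + w)*K + k
// KHWC linear index: ((k*H + h)*W + w)*C + c
template <typename T>
std::vector<T> TransCHWK2KHWC(const T *src, int C, int H, int W, int K) {
  std::vector<T> dst(static_cast<std::size_t>(C) * H * W * K);
  for (int c = 0; c < C; ++c)
    for (int h = 0; h < H; ++h)
      for (int w = 0; w < W; ++w)
        for (int k = 0; k < K; ++k)
          dst[((static_cast<std::size_t>(k) * H + h) * W + w) * C + c] =
              src[((static_cast<std::size_t>(c) * H + h) * W + w) * K + k];
  return dst;
}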
@@ -4,15 +4,12 @@ include_directories(${3RD_DIR}/flatbuffers/include)
include_directories(${3RD_DIR}/opencv/build/include/opencv4)
add_library(quantizer_mid OBJECT
-#${CMAKE_CURRENT_SOURCE_DIR}/calc_quant_param.cc
${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc
-#${CMAKE_CURRENT_SOURCE_DIR}/aware_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc
-${CMAKE_CURRENT_SOURCE_DIR}/post_training.cc
+${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc
-#${CMAKE_CURRENT_SOURCE_DIR}/../proto/post_training/post_training.pb.cc
)
if(ENABLE_ASAN)

@@ -28,7 +28,7 @@
#include "schema/inner/model_generated.h"
#include "src/ir/tensor.h"
#include "src/common/anf_exporter/anf_exporter.h"
#include "tools/converter/quantizer/post_training.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quantize_util.h"
#include "src/common/common.h"
#include "utils/log_adapter.h"
@@ -54,7 +54,10 @@ struct DivergInfo {
size_t bit_num;
int quant_max = 255;
int quant_min = 0;
-DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min) {
+std::string method_x = kMethodKL;
+DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min, const std::string &method_x) {
+this->method_x = method_x;
this->cnode = cnode;
this->bin_num = bins;
this->bit_num = bits;
@@ -99,6 +102,12 @@ struct DivergInfo {
}
STATUS ComputeThreshold() {
+if (method_x == kMethodMaxMin) {
+this->best_T = std::max(fabs(this->max), fabs(this->min));
+MS_LOG(DEBUG) << "using MAX_MIN, T: " << this->best_T;
+return RET_OK;
+}
constexpr int quant_bint_nums = 128;
int threshold = quant_bint_nums;
float min_kl = FLT_MAX;
@@ -200,46 +209,32 @@ struct DivergInfo {
threshold = i;
}
}
MS_LOG(DEBUG) << "Best threshold bin index: " << threshold;
this->best_T = (static_cast<float>(threshold) + 0.5f) * this->interval;
+MS_LOG(DEBUG) << cnode->fullname_with_scope() << " Best threshold bin index: " << threshold
+<< " T: " << best_T
+<< " max: " << std::max(fabs(this->max), fabs(this->min));
return RET_OK;
}
std::pair<CNodePtr, float> GetScale() {
float max_value = this->best_T;
float min_value = -max_value;
MS_ASSERT(quant_max - quant_min != 0);
-double scale = (max_value - min_value) / (quant_max - quant_min);
+float scale = (max_value - min_value) / (quant_max - quant_min);
MS_ASSERT(scale != 0);
return std::make_pair(this->cnode, scale);
}
std::pair<CNodePtr, int32_t> GetZeropoint() {
-float max_value = this->best_T;
-float min_value = -max_value;
-MS_ASSERT(quant_max - quant_min != 0);
-float scale = (max_value - min_value) / (quant_max - quant_min);
-auto quant_min_float = static_cast<float>(quant_min);
-auto quant_max_float = static_cast<float>(quant_max);
-MS_ASSERT(scale != 0);
-const float zero_point_from_min = quant_min_float - min_value / scale;
-// const float zero_point_from_max = quant_max_float - max_value / scale;
-int zero_point;
-if (zero_point_from_min < quant_min_float) {
-zero_point = quant_min;
-} else if (zero_point_from_min > quant_max_float) {
-zero_point = quant_max;
-} else {
-zero_point = static_cast<int>(std::round(zero_point_from_min));
-}
-MS_LOG(DEBUG) << "zero point:" << zero_point;
+int zero_point = 0;
+if (quant_min == 0 && quant_max == 255) {
+zero_point = 128;
+} else if (quant_min == -128 && quant_max == 127) {
+zero_point = 0;
+} else {
+MS_LOG(ERROR) << "unexpected quant range, quant_min: " << quant_min << " quant_max: " << quant_max;
+}
return std::make_pair(this->cnode, zero_point);
}
};
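With method_x = MAX_MIN the threshold is simply max(|max|, |min|), skipping the KL search; either way the resulting T is symmetric, so scale = 2T / (quant_max - quant_min) and the zero point is pinned by the integer range alone: 128 for [0, 255], 0 for [-128, 127], exactly as the rewritten GetZeropoint encodes. A small worked sketch of those formulas (values illustrative):

#include <cmath>
#include <cstdio>

// Symmetric post-training quant params from a threshold T, mirroring
// GetScale/GetZeropoint above.
struct QuantParams { float scale; int zero_point; };

QuantParams SymmetricParams(float best_T, int quant_min, int quant_max) {
  QuantParams p{};
  p.scale = (2.0f * best_T) / static_cast<float>(quant_max - quant_min);
  // A symmetric range pins the zero point to the middle of the integer range.
  p.zero_point = (quant_min == 0 && quant_max == 255) ? 128 : 0;
  return p;
}

int main() {
  auto p = SymmetricParams(6.0f, -128, 127);  // e.g. best_T = 6.0
  // quantize a real value r: q = round(r / scale) + zero_point
  int q = static_cast<int>(std::round(2.5f / p.scale)) + p.zero_point;
  std::printf("scale=%f zp=%d q(2.5)=%d\n", p.scale, p.zero_point, q);
  return 0;
}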
@@ -356,9 +351,9 @@ STATUS Calibrator::AddQuantizedOp(CNodePtr node) {
}
string node_name = node->fullname_with_scope();
std::unique_ptr<DivergInfo> input_diverg =
-std::unique_ptr<DivergInfo>(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_));
+std::unique_ptr<DivergInfo>(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_, config_param_.method_x));
std::unique_ptr<DivergInfo> output_diverg =
-std::unique_ptr<DivergInfo>(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_));
+std::unique_ptr<DivergInfo>(new DivergInfo(node, 2048, bit_num_, quant_max_, quant_min_, config_param_.method_x));
input_diverg_info_.insert(std::make_pair(string(node_name), std::move(input_diverg)));
output_diverg_info_.insert(std::make_pair(string(node_name), std::move(output_diverg)));
@@ -383,13 +378,13 @@ STATUS Calibrator::GenerateInputData(const int index, mindspore::tensor::MSTenso
MS_LOG(INFO) << "read image: " << path;
size_t size;
char *binBuf = ReadFile(path.c_str(), &size);
-// auto *rawinputDatas = reinterpret_cast<const float *>(binBuf);
-// auto mobilenet_input = const_cast<float *>(rawinputDatas);
auto data = tensor->MutableData();
+if (size != tensor->Size()) {
+MS_LOG(ERROR) << "the input data is not consistent with model input, file_size: " << size
+<< " input tensor size: " << tensor->Size();
+return RET_ERROR;
+}
memcpy(data, binBuf, size);
-// tensor->SetData(mobilenet_input);
return RET_OK;
}
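GenerateInputData drops the model-specific mobilenet cast in favor of a generic read-validate-copy: the calibration file must match the input tensor's byte size exactly before anything is copied. A standalone sketch of the same guard, with illustrative names:

#include <cstdio>
#include <cstring>
#include <fstream>
#include <vector>

// Read a raw calibration blob into a preallocated tensor buffer, refusing
// size mismatches as the diff above does.
bool FillInputFromFile(const char *path, void *tensor_data, std::size_t tensor_size) {
  std::ifstream fs(path, std::ios::binary | std::ios::ate);
  if (!fs) return false;
  const auto file_size = static_cast<std::size_t>(fs.tellg());
  if (file_size != tensor_size) {
    std::fprintf(stderr, "input file size %zu != tensor size %zu\n",
                 file_size, tensor_size);
    return false;
  }
  std::vector<char> buf(file_size);
  fs.seekg(0);
  fs.read(buf.data(), static_cast<std::streamsize>(file_size));
  std::memcpy(tensor_data, buf.data(), file_size);
  return true;
}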
@@ -457,13 +452,20 @@ STATUS Calibrator::ReadConfig() {
config_param_.batch_count = std::stoul(value);
} else if (key == "thread_num") {
config_param_.thread_num = std::stoul(value);
+} else if (key == "method_x") {
+if (value != kMethodKL && value != kMethodMaxMin) {
+MS_LOG(WARNING) << "unsupported method_x: " << value << ". Use default value.";
+} else {
+config_param_.method_x = value;
+}
} else {
MS_LOG(WARNING) << "unsupported parameter";
}
}
MS_LOG(INFO) << "image_path: " << config_param_.image_path << " "
<< "batch_count: " << config_param_.batch_count << " "
<< "thread_num: " << config_param_.thread_num;
MS_LOG(DEBUG) << "image_path: " << config_param_.image_path << " "
<< "batch_count: " << config_param_.batch_count << " "
<< "mothod_x: " << config_param_.method_x << " "
<< "thread_num: " << config_param_.thread_num;
delete[] resolved_path;
fs.close();
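ReadConfig now recognizes a method_x key restricted to KL or MAX_MIN, warning and keeping the default (KL) on anything else. A plausible calibration config built from the four keys the parser accepts; the values here are only examples:

# calibration config sketch; keys from ReadConfig above, values illustrative
image_path=/path/to/calib/images
batch_count=100
thread_num=4
method_x=MAX_MIN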
@@ -615,7 +617,7 @@ STATUS PostTrainingQuantizer::DoBiasQuant(std::shared_ptr<PrimitiveTValue> input
quant_datas[i] = quant_data;
}
auto ret =
-memcpy_s(bias_param->tensor_addr(), shape_size * sizeof(int32_t), quant_datas, shape_size * sizeof(int32_t));
+memcpy_s(bias_param->tensor_addr(), bias_param->tensor_size(), quant_datas, shape_size * sizeof(int32_t));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy_s failed.";
delete[] quant_datas;
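The memcpy_s fix above (and the matching one in QuantFilter further down) passes the destination buffer's true capacity, bias_param->tensor_size(), as the destmax argument instead of recomputing it from the element count; if the tensor buffer is ever smaller than shape_size * sizeof(int32_t), the copy now fails instead of overflowing. A reduced sketch of that securec-style contract, with illustrative names:

#include <cstddef>
#include <cstring>

// Copying is refused when count exceeds the destination's true capacity.
constexpr int kEok = 0;

int CheckedMemcpy(void *dest, std::size_t dest_max, const void *src, std::size_t count) {
  if (dest == nullptr || src == nullptr || count > dest_max) {
    return -1;  // the real memcpy_s returns a nonzero errno-style code
  }
  std::memcpy(dest, src, count);
  return kEok;
}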
@@ -805,14 +807,6 @@ STATUS PostTrainingQuantizer::DoInference() {
MS_LOG(ERROR) << "generate input data from images failed!";
return RET_ERROR;
}
-/**
-* struct CallBackParam {
-std::string nodeType;
-NODE_ID nodeName;
-std::unordered_set<NODE_ID> depends;
-int opExecResult;
-};
-*/
mindspore::session::KernelCallBack beforeCallBack =
[&](const std::vector<mindspore::tensor::MSTensor *> &beforeInputs,
const std::vector<mindspore::tensor::MSTensor *> &beforeOutputs,
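DoInference drives calibration by registering session callbacks like the one begun above, which observe every op's tensors as the float model runs; the deleted comment block merely duplicated the CallBackParam declaration from the headers. A heavily hedged sketch of such a before-callback recording per-op input ranges; the stand-in types and field names only approximate the MindSpore Lite ones:

#include <algorithm>
#include <functional>
#include <string>
#include <vector>

// Stand-ins approximating the MindSpore Lite tensor/callback-param types.
struct SketchTensor { std::vector<float> data; };
struct SketchCallBackParam { std::string node_name; std::string node_type; };

using SketchKernelCallBack =
    std::function<bool(const std::vector<SketchTensor *> &inputs,
                       const std::vector<SketchTensor *> &outputs,
                       const SketchCallBackParam &op_info)>;

// Build a before-callback that records each op's input min/max: the raw
// statistics RecordMaxValue-style bookkeeping needs for thresholds.
SketchKernelCallBack MakeRangeRecorder(
    std::function<void(const std::string &, float, float)> record) {
  return [record](const std::vector<SketchTensor *> &inputs,
                  const std::vector<SketchTensor *> &,
                  const SketchCallBackParam &op_info) {
    for (const SketchTensor *t : inputs) {
      if (t == nullptr || t->data.empty()) continue;
      float lo = t->data[0], hi = t->data[0];
      for (float v : t->data) { lo = std::min(lo, v); hi = std::max(hi, v); }
      record(op_info.node_name, lo, hi);
    }
    return true;  // returning true lets the session keep executing the graph
  };
}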
@@ -916,9 +910,26 @@ STATUS PostTrainingQuantizer::DoQuantize(FuncGraphPtr funcGraph) {
MS_LOG(ERROR) << "do pre process failed!";
return status;
}
+// anf -- fb
+auto meta_graph = Export(funcGraph);
+if (meta_graph == nullptr) {
+MS_LOG(ERROR) << "Export to meta_graph return nullptr";
+return RET_ERROR;
+}
+// transform
+GraphDefTransform transform;
+transform.SetGraphDef(meta_graph);
+flags.quantType = schema::QuantType_QUANT_NONE;
+status = transform.Transform(flags);
+if (status != RET_OK) {
+MS_LOG(ERROR) << "FBTransform model failed " << status;
+return RET_ERROR;
+}
MS_LOG(INFO) << "start create session";
flatbuffers::FlatBufferBuilder builder(1024);
-auto offset = schema::MetaGraph::Pack(builder, Export(funcGraph));
+auto offset = schema::MetaGraph::Pack(builder, meta_graph);
builder.Finish(offset);
size_t size = builder.GetSize();
auto *content = reinterpret_cast<const char *>(builder.GetBufferPointer());

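DoQuantize now exports the FuncGraph once, runs the standard GraphDefTransform with quantType temporarily forced to QUANT_NONE, and packs the transformed graph (instead of a second, untransformed export) into the flatbuffer for the calibration session. The Pack/Finish/GetBufferPointer sequence is the normal FlatBufferBuilder lifecycle; a generic sketch with stand-in schema types:

#include <cstddef>
#include <utility>
#include "flatbuffers/flatbuffers.h"

// Table/TableT stand in for the generated schema::MetaGraph/MetaGraphT pair.
template <typename Table, typename TableT>
std::pair<const char *, std::size_t> PackGraph(flatbuffers::FlatBufferBuilder *builder,
                                               const TableT *graph) {
  auto offset = Table::Pack(*builder, graph);  // object API -> wire format
  builder->Finish(offset);                     // writes the root table offset
  // The builder owns the buffer; copy it out before the builder is destroyed.
  return {reinterpret_cast<const char *>(builder->GetBufferPointer()),
          builder->GetSize()};
}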
@@ -46,10 +46,14 @@ enum ImageFormat {
BGR = 2,
};
+const char kMethodMaxMin[] = "MAX_MIN";
+const char kMethodKL[] = "KL";
struct ConfigParam {
// ImageFormat imageFormat;
std::string image_path;
-uint32_t batch_count;
+uint32_t batch_count{100};
+std::string method_x{kMethodKL};
uint32_t thread_num;
};
@@ -115,6 +119,8 @@ class Calibrator {
uint32_t GetThreadNum() const { return config_param_.thread_num; }
+std::string GetMethodX() const { return config_param_.method_x; }
STATUS AddQuantizedOp(CNodePtr node);
STATUS RecordMaxValue(std::string opName, std::vector<float> data,

@@ -89,7 +89,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
if (primitiveT_value == nullptr) {
MS_LOG(ERROR) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope();
MS_LOG(WARNING) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope();
return false;
}
@@ -344,7 +344,7 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_
}
weightPtr->set_quant_param(quantParam);
-auto ret = memcpy_s(rawDatas, weightPtr->tensor_size() * sizeof(int8_t),
+auto ret = memcpy_s(rawDatas, weightPtr->tensor_size(),
qDatas.data(), shapeSize * sizeof(int8_t));
if (ret != EOK) {
MS_LOG(ERROR) << "memcpy error: " << ret;

@@ -24,6 +24,7 @@
#include "include/model.h"
#include "base/base.h"
#include "src/param_value_lite.h"
#include "tools/converter/converter_flags.h"
namespace mindspore {
namespace lite {
@@ -52,6 +53,7 @@ class Quantizer {
virtual STATUS DoQuantize(FuncGraphPtr funcGraph) = 0;
+mindspore::lite::converter::Flags flags;
protected:
FuncGraphPtr funcGraph = nullptr;
};
