fix post quantization && fix tflite reshape parser

pull/3891/head
xutianchun 5 years ago
parent bac1781539
commit 8f334af0e0

@@ -118,6 +118,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
         std::make_unique<schema::QuantParamT>(input_quant_params[0]);
       tensor_input->quantParams.emplace_back(std::move(input_quant_param));
     }
+    tensor_input->dataType = kNumberTypeInt8;
     // output
     auto output_index = node->outputIndex[0];
     auto tensor_output = metaGraphT->allTensors[output_index].get();
@@ -129,6 +130,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
         std::make_unique<schema::QuantParamT>(output_quant_params[0]);
       tensor_output->quantParams.emplace_back(std::move(output_quant_param));
     }
+    tensor_output->dataType = kNumberTypeInt8;
     // // TensorType
     // valuePtr = primitive->GetAttr(kInputTensorDataType);
     // if (valuePtr != nullptr) {
@@ -210,17 +212,18 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta
         paramTensor->data.resize(paramValue->tensor_size());
         memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size());
       }
-      // for (auto &ite : paramValue->quant_param()) {
-      //   auto quantPar = std::make_unique<schema::QuantParamT>();
-      //   quantPar->scale = ite->scale;
-      //   quantPar->zeroPoint = ite->zeroPoint;
-      //   quantPar->min = ite->min;
-      //   quantPar->max = ite->max;
-      //   quantPar->narrowRange = ite->narrowRange;
-      //   quantPar->inited = ite->inited;
-      //   quantPar->numBits = ite->numBits;
-      //   paramTensor->quantParams.emplace_back(std::move(quantPar));
-      // }
+      for (auto &ite : paramValue->quant_param()) {
+        auto quantPar = std::make_unique<schema::QuantParamT>();
+        quantPar->scale = ite->scale;
+        quantPar->zeroPoint = ite->zeroPoint;
+        quantPar->min = ite->min;
+        quantPar->max = ite->max;
+        quantPar->narrowRange = ite->narrowRange;
+        quantPar->inited = ite->inited;
+        quantPar->numBits = ite->numBits;
+        paramTensor->quantParams.emplace_back(std::move(quantPar));
+        paramTensor->dataType = paramValue->tensor_type();
+      }
       nodeIdMap[paramNode->fullname_with_scope()] = meta_graph->allTensors.size();
       fbNode->inputIndex.emplace_back(meta_graph->allTensors.size());
       meta_graph->allTensors.emplace_back(std::move(paramTensor));

@@ -140,6 +140,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) {
       return new lite::Flatten(const_cast<schema::Primitive *>(srcPrim));
     case schema::PrimitiveType_MatMul:
       return new lite::MatMul(const_cast<schema::Primitive *>(srcPrim));
+    case schema::PrimitiveType_QuantDTypeCast:
+      return new lite::QuantDTypeCast(const_cast<schema::Primitive *>(srcPrim));
     default:
       break;
   }

@@ -57,6 +57,25 @@ int Reshape::CalNewShape(const tensor::Tensor *in_tensor, std::vector<int> *out_
   return RET_OK;
 }
 
+template <typename T>
+void CalShape(const T *data, const std::vector<tensor::Tensor *> &inputs, std::vector<int> *out_shape, int shape_size) {
+  int input_count = inputs[0]->ElementsNum();
+  int index = 0;
+  int size = 1;
+  for (size_t i = 0; i < shape_size; i++) {
+    if (data[i] == -1) {
+      index = i;
+    } else {
+      size *= data[i];
+    }
+    out_shape->push_back(data[i]);
+  }
+  if (data[index] == -1) {
+    (*out_shape)[index] = input_count / size;
+  }
+}
+
 int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
   MS_ASSERT(this->primitive != nullptr);
   auto input = inputs_.front();
@@ -69,31 +88,23 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
   std::vector<int> out_shape;
   if (inputs_.size() == kDoubleNum) {
     auto shape_tensor = inputs_.at(1);
-    size_t shape_size = shape_tensor->shape().size();
+    size_t shape_size = shape_tensor->ElementsNum();
     switch (shape_tensor->data_type()) {
       case kNumberTypeInt8: {
         auto data = reinterpret_cast<int8_t *>(shape_tensor->Data());
-        for (size_t i = 0; i < shape_size; i++) {
-          out_shape.push_back(data[i]);
-        }
+        CalShape<int8_t>(data, inputs_, &out_shape, shape_size);
       } break;
       case kNumberTypeInt32: {
         auto data = reinterpret_cast<int32_t *>(shape_tensor->Data());
-        for (size_t i = 0; i < shape_size; i++) {
-          out_shape.push_back(data[i]);
-        }
+        CalShape<int32_t>(data, inputs_, &out_shape, shape_size);
       } break;
       case kNumberTypeFloat: {
         auto data = reinterpret_cast<float *>(shape_tensor->Data());
-        for (size_t i = 0; i < shape_size; i++) {
-          out_shape.push_back(data[i]);
-        }
+        CalShape<float>(data, inputs_, &out_shape, shape_size);
       } break;
       case kNumberTypeUInt32: {
         auto data = reinterpret_cast<uint32_t *>(shape_tensor->Data());
-        for (size_t i = 0; i < shape_size; i++) {
-          out_shape.push_back(data[i]);
-        }
+        CalShape<uint32_t>(data, inputs_, &out_shape, shape_size);
       } break;
       default: {
         MS_LOG(ERROR) << "Reshape weight tensor has unsupported dataType: " << shape_tensor->data_type();

@@ -215,7 +215,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
   MS_ASSERT(node != nullptr);
   auto opType = node->primitive->value.type;
   if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D &&
-      opType != schema::PrimitiveType_DeConv2D) {
+      opType != schema::PrimitiveType_DeConv2D && opType != schema::PrimitiveType_DeDepthwiseConv2D) {
     return 0;
   }
@@ -230,7 +230,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
     if (weightTensor->dataType == kNumberTypeInt8) {  // DataType_DT_UINT8) {
       MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format
                     << weightTensor->dataType;
-      status = TransFilterFormat<uint8_t>(weightTensor.get(), kKCHW2HWCK);
+      status = TransFilterFormat<int8_t>(weightTensor.get(), kKCHW2HWCK);
     } else {
       MS_LOG(DEBUG) << "--weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format
                     << weightTensor->dataType;
@@ -238,7 +238,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
     }
   } else if (weightTensor->format == schema::Format_KHWC) {  // from onnx
     if (weightTensor->dataType == kNumberTypeInt8) {  // DataType_DT_UINT8) {
-      status = TransFilterFormat<uint8_t>(weightTensor.get(), kKHWC2HWCK);
+      status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK);
     } else {
       status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK);
     }

@@ -31,14 +31,23 @@ STATUS TfliteReshapeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfli
   const auto &tfliteAttr = tfliteOp->builtin_options.AsReshapeOptions();
   if (tfliteAttr == nullptr) {
-    MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
-    return RET_NULL_PTR;
-  }
-
-  attr->format = schema::Format_NHWC;
-  attr->shape.resize(tfliteAttr->new_shape.size());
-  for (size_t i = 0; i < tfliteAttr->new_shape.size(); ++i) {
-    attr->shape[i] = tfliteAttr->new_shape[i];
+    if (tfliteOp->inputs.size() < 2) {
+      MS_LOG(ERROR) << "expected two input tensors, but got: " << tfliteOp->inputs.size();
+      return RET_ERROR;
+    }
+    auto shape_tensor_index = tfliteOp->inputs[1];
+    const auto &shape_tensor = tfliteTensors[shape_tensor_index];
+    std::vector<tflite::TensorT *> shape_tensors{shape_tensor.get()};
+    if (RET_OK != ParseWeight(shape_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) {
+      MS_LOG(ERROR) << "parse shape tensor error";
+      return RET_ERROR;
+    }
+  } else {
+    attr->format = schema::Format_NHWC;
+    attr->shape.resize(tfliteAttr->new_shape.size());
+    for (size_t i = 0; i < tfliteAttr->new_shape.size(); ++i) {
+      attr->shape[i] = tfliteAttr->new_shape[i];
+    }
   }
 
   if (op != nullptr) {
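Note: TFLite emits Reshape in two forms: the target shape can live in the ReshapeOptions.new_shape attribute, or, when the options table is absent, in a second input tensor. The parser now handles both instead of failing on the second form. A hedged sketch of the dispatch with simplified types (ResolveReshapeDims is an illustrative name, not the parser's API):

    #include <iostream>
    #include <optional>
    #include <vector>

    // Illustrative only: prefer the attribute shape; otherwise fall back to
    // the shape carried by the operator's second input tensor.
    std::vector<int> ResolveReshapeDims(const std::optional<std::vector<int>> &attr_new_shape,
                                        const std::vector<std::vector<int>> &input_tensors) {
      if (attr_new_shape.has_value()) {
        return *attr_new_shape;  // shape stored in ReshapeOptions
      }
      if (input_tensors.size() < 2) {
        return {};               // neither form present: error path
      }
      return input_tensors[1];   // shape stored as input #1
    }

    int main() {
      // No options table; the second input holds {-1, 4}.
      auto dims = ResolveReshapeDims(std::nullopt, {{1, 2, 3, 4}, {-1, 4}});
      for (int d : dims) std::cout << d << " ";  // prints: -1 4
      std::cout << "\n";
      return 0;
    }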

@@ -230,6 +230,13 @@ struct DivergInfo {
     } else {
       zero_point = static_cast<int>(std::round(zero_point_from_min));
     }
+    MS_LOG(DEBUG) << "zero point:" << zero_point;
+    if (quant_min == 0 && quant_max == 255) {
+      zero_point = 128;
+    } else if (quant_min == -128 && quant_max == 127) {
+      zero_point = 0;
+    }
     return std::make_pair(this->cnode, zero_point);
   }
 };
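Note: the added override pins the zero point to the midpoint of the quantized range (128 for unsigned 8-bit, 0 for signed 8-bit), discarding the value derived from the observed minimum, which in effect forces a symmetric zero point for the two standard 8-bit ranges. A small worked example with a made-up activation range:

    #include <cmath>
    #include <iostream>

    int main() {
      // Hypothetical activation range observed during calibration.
      double min = -2.0, max = 6.0;
      int quant_min = 0, quant_max = 255;
      double scale = (max - min) / (quant_max - quant_min);
      // Asymmetric zero point, as computed before the override:
      int zp = static_cast<int>(std::round(quant_min - min / scale));
      std::cout << "derived zero point: " << zp << "\n";  // 64
      // With the override above, the [0, 255] range is always pinned to 128.
      std::cout << "pinned zero point: 128\n";
      return 0;
    }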
@@ -466,11 +473,6 @@ Calibrator::Calibrator(string path, size_t bitNum, int quantMax, int quantMin)
 PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, int bit_num, TypeId target_type)
     : Quantizer(graph) {
   this->bit_num = bit_num;
-  calibrator_ = std::unique_ptr<Calibrator>(new Calibrator(path, this->bit_num, quant_max, quant_min));
-  if (calibrator_ == nullptr) {
-    MS_LOG(ERROR) << "creat calibrator failed!";
-    return;
-  }
   this->target_type_ = target_type;
   if (target_type == kNumberTypeInt8) {
     quant_max = (1 << (this->bit_num - 1)) - 1;  // 127
@@ -481,6 +483,11 @@ PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, in
   } else {
     MS_LOG(ERROR) << "unsupported quant value type: " << target_type;
   }
+  calibrator_ = std::unique_ptr<Calibrator>(new Calibrator(path, this->bit_num, quant_max, quant_min));
+  if (calibrator_ == nullptr) {
+    MS_LOG(ERROR) << "creat calibrator failed!";
+    return;
+  }
 }
 
 STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct MaxMin *max_min,
@@ -526,7 +533,7 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr node) {
   }
   auto parameter = std::dynamic_pointer_cast<Parameter>(node);
   ParamValueLitePtr paramValue = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param());
-  auto status = QuantFilter(paramValue, QuantType_PostTraining, bit_num);
+  auto status = QuantFilter(paramValue, QuantType_PostTraining, quant_max, quant_min, bit_num);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "QuantFilter failed: " << status;
     return status;
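Note: moving the Calibrator construction below the range setup matters because quant_max and quant_min are plain members with default values (0, per the header diff below); the old order handed those defaults to the Calibrator before the 8-bit range was computed. A stripped-down repro of the ordering bug (all names hypothetical, and the quant_min formula is an assumption based on the 127 comment in the hunk above):

    #include <iostream>
    #include <memory>

    // Hypothetical mock of the calibrator, capturing the range it was given.
    struct MockCalibrator {
      int qmax, qmin;
      MockCalibrator(int quant_max, int quant_min) : qmax(quant_max), qmin(quant_min) {}
    };

    struct MockQuantizer {
      int quant_max{0};
      int quant_min{0};
      std::unique_ptr<MockCalibrator> calibrator_;
      MockQuantizer(int bit_num, bool construct_last) {
        if (!construct_last) {  // old order: captures the 0/0 defaults
          calibrator_ = std::make_unique<MockCalibrator>(quant_max, quant_min);
        }
        quant_max = (1 << (bit_num - 1)) - 1;  // 127
        quant_min = -(1 << (bit_num - 1));     // -128
        if (construct_last) {  // new order: captures the real int8 range
          calibrator_ = std::make_unique<MockCalibrator>(quant_max, quant_min);
        }
      }
    };

    int main() {
      std::cout << MockQuantizer(8, false).calibrator_->qmax << "\n";  // 0   (bug)
      std::cout << MockQuantizer(8, true).calibrator_->qmax << "\n";   // 127 (fixed)
      return 0;
    }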

@@ -64,7 +64,7 @@ class PostTrainingQuantizer : public Quantizer {
   int quant_min{0};
 
  private:
-  TypeId target_type_{kNumberTypeUInt8};
+  TypeId target_type_{kNumberTypeInt8};
   std::unique_ptr<Calibrator> calibrator_;

@@ -22,13 +22,16 @@
 namespace mindspore::lite::quant {
 
-ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type) {
+ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector<schema::QuantParamT> &quant_params) {
   std::unique_ptr<schema::PrimitiveT> primitive = std::make_unique<schema::PrimitiveT>();
   schema::QuantDTypeCastT quant_dtype_cast;
   quant_dtype_cast.srcT = src_type;  // kNumberTypeUInt8;
   quant_dtype_cast.dstT = dst_type;  // kNumberTypeFloat32;
   primitive->value.Set(quant_dtype_cast);
   auto primTValue = std::make_shared<PrimitiveTValue>(primitive.release());
+  for (auto &quant_param : quant_params) {
+    primTValue->AddInputQuantParam(quant_param);
+  }
   return NewValueNode(primTValue);
 }
@@ -48,7 +51,8 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
     }
     if (first) {
       if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) {
-        auto value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8);
+        auto value_node =
+          NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8, primitiveT_value->GetInputQuantParams());
         std::vector<AnfNodePtr> op_inputs = {value_node, cnode->input(1)};
         auto quant_cast_cnode = graph->NewCNode(op_inputs);
         quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast");
@ -78,10 +82,12 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {

ValueNodePtr value_node = nullptr; ValueNodePtr value_node = nullptr;
if (curnode_quant_type == schema::QuantType_PostTraining && if (curnode_quant_type == schema::QuantType_PostTraining &&
input_cnode_quant_type == schema::QuantType_QUANT_NONE) { input_cnode_quant_type == schema::QuantType_QUANT_NONE) {
value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8); value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8,
input_cnode_primitiveT_value->GetInputQuantParams());
} else if (curnode_quant_type == schema::QuantType_QUANT_NONE && } else if (curnode_quant_type == schema::QuantType_QUANT_NONE &&
input_cnode_quant_type == schema::QuantType_PostTraining) { input_cnode_quant_type == schema::QuantType_PostTraining) {
value_node = NewQuantCastValueNode(kNumberTypeUInt8, kNumberTypeFloat32); value_node = NewQuantCastValueNode(kNumberTypeUInt8, kNumberTypeFloat32,
input_cnode_primitiveT_value->GetInputQuantParams());
} }
if (value_node == nullptr) { if (value_node == nullptr) {
MS_LOG(WARNING) << "value_node is null! " MS_LOG(WARNING) << "value_node is null! "

@@ -190,7 +190,7 @@ void CalFakeNode(const AnfNodePtr &inTensor) {
 }
 
 STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double mMin,
-                             double mMax, bool narrowRange, int numBits) {
+                             double mMax, bool narrowRange, int quant_max, int quant_min, int num_bits) {
   MS_ASSERT(quantParam != nullptr);
   if (mMin > 0.0f) {
     MS_LOG(ERROR) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision";
@@ -215,28 +215,17 @@ STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double
     quantParam->scale = 0.0f;
     quantParam->zeroPoint = 0;
     quantParam->narrowRange = narrowRange;
-    quantParam->numBits = numBits;
+    quantParam->numBits = num_bits;
     return RET_OK;
   }
-  int quantMin = narrowRange ? 1 : 0;
-  int quantMax = (1 << (unsigned int)numBits) - 1;
-  auto quantMinFloat = static_cast<double>(quantMin);
-  auto quantMaxFloat = static_cast<double>(quantMax);
+  auto quantMinFloat = static_cast<double>(quant_min);
+  auto quantMaxFloat = static_cast<double>(quant_max);
   double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
   const double zeroPointFromMin = quantMinFloat - mMin / scale;
-  const double zeroPointFromMax = quantMaxFloat - mMax / scale;
-  const double zpFromMinError = std::abs(quantMinFloat) + std::abs(mMin / scale);
-  const double zpFromMaxError = std::abs(quantMaxFloat) + std::abs(mMax / scale);
-  const double zpDouble = zpFromMinError < zpFromMaxError ? zeroPointFromMin : zeroPointFromMax;
-  int zeroPoint;
-  if (zpDouble < quantMinFloat) {
-    zeroPoint = quantMin;
-  } else if (zpDouble > quantMaxFloat) {
-    zeroPoint = quantMax;
-  } else {
-    zeroPoint = static_cast<int32_t>(std::round(zpDouble));
-  }
+  // const double zeroPointFromMax = quantMaxFloat - mMax / scale;
+  int zeroPoint = static_cast<int32_t>(std::round(zeroPointFromMin));
   // The zero point should always be in the range of quantized value,
   // [qmin, qmax].
   MS_ASSERT(zeroPoint >= quantMin);
@@ -247,12 +236,12 @@ STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double
   quantParam->scale = scale;
   quantParam->zeroPoint = zeroPoint;
   quantParam->narrowRange = narrowRange;
-  quantParam->numBits = numBits;
+  quantParam->numBits = num_bits;
   return RET_OK;
 }
 
-STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, size_t bitNum) {
+STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum) {
   auto dims = weightPtr->tensor_shape();
   if (dims.size() < 1) {
     MS_LOG(ERROR) << "weight dims size error";
@@ -284,7 +273,7 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, size_t bit
   }
 
   std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam);
-  STATUS status = CalQuantizationParams(quantParam, min, max, false, bitNum);
+  STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
     return status;
@@ -308,8 +297,8 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, size_t bit
     PostBitPack(const_cast<float *>(rawDatas), shapeSize, bitNum);
   }
 
-  weightPtr->set_tensor_type(kNumberTypeUInt8);
-  weightPtr->set_tensor_size(shapeSize * sizeof(uint8_t));
+  weightPtr->set_tensor_type(kNumberTypeInt8);
+  weightPtr->set_tensor_size(shapeSize * sizeof(int8_t));
   return RET_OK;
 }

@@ -60,7 +60,7 @@ class QuantStrategy {
 };
 
 STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double mMin, double mMax,
-                             bool narrowRange = false, int numBits = UINT8_QUANTIZATION);
+                             bool narrowRange, int quant_max, int quant_min, int num_bits);
 
 template <typename T>
 T QuantizeData(const float originData, const AnfQuantParam *quantParam) {
@@ -96,7 +96,7 @@ T QuantizeData(const float originData, const AnfQuantParam *quantParam) {
 void CalFakeNode(const AnfNodePtr &inTensor);
 
-STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType = QuantType_AwareTraining,
+STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min,
                    size_t bitNum = UINT8_QUANTIZATION);
 
 STATUS PostBitPack(float *weights, size_t shapeSize, size_t bitNum = UINT8_QUANTIZATION);

@@ -81,7 +81,7 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
     }
 
     ParamValueLitePtr paramValue = std::static_pointer_cast<ParamValueLite>(paramNode->default_param());
-    auto status = QuantFilter(paramValue, QuantType_WeightQuant, bitNum);
+    auto status = QuantFilter(paramValue, QuantType_WeightQuant, 127, -128, bitNum);
     if (status != RET_OK) {
       MS_LOG(ERROR) << "QuantFilter failed : " << status;
       return status;
@@ -120,7 +120,7 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
       MS_LOG(ERROR) << "No valid input param node !";
       continue;
     }
-    auto status = QuantFilter(paramValue, QuantType_WeightQuant, bitNum);
+    auto status = QuantFilter(paramValue, QuantType_WeightQuant, 127, -128, bitNum);
    if (status != RET_OK) {
      MS_LOG(ERROR) << "QunatFilter failed" << status;
      return RET_ERROR;
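Note: the hard-coded 127 and -128 passed to QuantFilter are the signed 8-bit limits; for other bit widths they would follow from bitNum the same way PostTrainingQuantizer derives them. A one-liner sketch of that derivation (assuming the usual two's-complement range):

    #include <iostream>

    int main() {
      int bitNum = 8;
      int quant_max = (1 << (bitNum - 1)) - 1;  // 127
      int quant_min = -(1 << (bitNum - 1));     // -128
      std::cout << quant_min << " .. " << quant_max << "\n";
      return 0;
    }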
