!3891 fix tflite reshape parser and post-training quantization

Merge pull request !3891 from xutianchun/quant_0803
pull/3891/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 28af1e5070

@ -118,6 +118,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
std::make_unique<schema::QuantParamT>(input_quant_params[0]);
tensor_input->quantParams.emplace_back(std::move(input_quant_param));
}
tensor_input->dataType = kNumberTypeInt8;
// output
auto output_index = node->outputIndex[0];
auto tensor_output = metaGraphT->allTensors[output_index].get();
@ -129,6 +130,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
std::make_unique<schema::QuantParamT>(output_quant_params[0]);
tensor_output->quantParams.emplace_back(std::move(output_quant_param));
}
tensor_output->dataType = kNumberTypeInt8;
// // TensorType
// valuePtr = primitive->GetAttr(kInputTensorDataType);
// if (valuePtr != nullptr) {
@ -210,17 +212,18 @@ void AnfExporter::SetOpInputNode(const CNodePtr &cnode, schema::MetaGraphT *meta
paramTensor->data.resize(paramValue->tensor_size());
memcpy(paramTensor->data.data(), paramValue->tensor_addr(), paramValue->tensor_size());
}
// for (auto &ite : paramValue->quant_param()) {
// auto quantPar = std::make_unique<schema::QuantParamT>();
// quantPar->scale = ite->scale;
// quantPar->zeroPoint = ite->zeroPoint;
// quantPar->min = ite->min;
// quantPar->max = ite->max;
// quantPar->narrowRange = ite->narrowRange;
// quantPar->inited = ite->inited;
// quantPar->numBits = ite->numBits;
// paramTensor->quantParams.emplace_back(std::move(quantPar));
// }
for (auto &ite : paramValue->quant_param()) {
auto quantPar = std::make_unique<schema::QuantParamT>();
quantPar->scale = ite->scale;
quantPar->zeroPoint = ite->zeroPoint;
quantPar->min = ite->min;
quantPar->max = ite->max;
quantPar->narrowRange = ite->narrowRange;
quantPar->inited = ite->inited;
quantPar->numBits = ite->numBits;
paramTensor->quantParams.emplace_back(std::move(quantPar));
paramTensor->dataType = paramValue->tensor_type();
}
nodeIdMap[paramNode->fullname_with_scope()] = meta_graph->allTensors.size();
fbNode->inputIndex.emplace_back(meta_graph->allTensors.size());
meta_graph->allTensors.emplace_back(std::move(paramTensor));

@ -140,6 +140,8 @@ lite::Primitive *ModelImpl::CopyPrimitive(const schema::Primitive *srcPrim) {
return new lite::Flatten(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_MatMul:
return new lite::MatMul(const_cast<schema::Primitive *>(srcPrim));
case schema::PrimitiveType_QuantDTypeCast:
return new lite::QuantDTypeCast(const_cast<schema::Primitive *>(srcPrim));
default:
break;
}

@ -57,6 +57,25 @@ int Reshape::CalNewShape(const tensor::Tensor *in_tensor, std::vector<int> *out_
return RET_OK;
}
template <typename T>
void CalShape(const T *data, const std::vector<tensor::Tensor *> &inputs, std::vector<int> *out_shape, int shape_size) {
  // Build the output shape from the raw shape-tensor values, resolving a
  // single -1 ("infer this dimension") entry from the element count of the
  // first input tensor (TFLite Reshape semantics).
  //   data:       raw values of the shape tensor (length shape_size)
  //   inputs:     operator inputs; inputs[0] supplies the total element count
  //   out_shape:  receives the resolved shape (entries are appended)
  //   shape_size: number of entries in data
  // NOTE(review): if more than one -1 is present, only the last one is
  // resolved — earlier -1 entries are left as-is (caller-visible behavior
  // preserved; presumably validated upstream — TODO confirm).
  int input_count = inputs[0]->ElementsNum();
  int unknown_index = -1;  // position of the (last) -1 entry, -1 if none
  int known_size = 1;      // product of all explicitly given dimensions
  // Loop index is int, matching shape_size: comparing size_t against an int
  // would be a signed/unsigned mismatch and would overrun data if shape_size
  // were ever negative.
  for (int i = 0; i < shape_size; i++) {
    if (data[i] == -1) {
      unknown_index = i;
    } else {
      known_size *= data[i];
    }
    out_shape->push_back(data[i]);
  }
  // Resolve the -1 dimension. Guard known_size != 0: a shape containing both
  // a -1 and an explicit 0 dimension would otherwise divide by zero.
  if (unknown_index >= 0 && known_size != 0) {
    (*out_shape)[unknown_index] = input_count / known_size;
  }
}
int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::Tensor *> outputs_) {
MS_ASSERT(this->primitive != nullptr);
auto input = inputs_.front();
@ -69,31 +88,23 @@ int Reshape::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tenso
std::vector<int> out_shape;
if (inputs_.size() == kDoubleNum) {
auto shape_tensor = inputs_.at(1);
size_t shape_size = shape_tensor->shape().size();
size_t shape_size = shape_tensor->ElementsNum();
switch (shape_tensor->data_type()) {
case kNumberTypeInt8: {
auto data = reinterpret_cast<int8_t *>(shape_tensor->Data());
for (size_t i = 0; i < shape_size; i++) {
out_shape.push_back(data[i]);
}
CalShape<int8_t>(data, inputs_, &out_shape, shape_size);
} break;
case kNumberTypeInt32: {
auto data = reinterpret_cast<int32_t *>(shape_tensor->Data());
for (size_t i = 0; i < shape_size; i++) {
out_shape.push_back(data[i]);
}
CalShape<int32_t>(data, inputs_, &out_shape, shape_size);
} break;
case kNumberTypeFloat: {
auto data = reinterpret_cast<float *>(shape_tensor->Data());
for (size_t i = 0; i < shape_size; i++) {
out_shape.push_back(data[i]);
}
CalShape<float>(data, inputs_, &out_shape, shape_size);
} break;
case kNumberTypeUInt32: {
auto data = reinterpret_cast<uint32_t *>(shape_tensor->Data());
for (size_t i = 0; i < shape_size; i++) {
out_shape.push_back(data[i]);
}
CalShape<uint32_t>(data, inputs_, &out_shape, shape_size);
} break;
default: {
MS_LOG(ERROR) << "Reshape weight tensor has unsupported dataType: " << shape_tensor->data_type();

@ -215,7 +215,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
MS_ASSERT(node != nullptr);
auto opType = node->primitive->value.type;
if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D &&
opType != schema::PrimitiveType_DeConv2D) {
opType != schema::PrimitiveType_DeConv2D && opType != schema::PrimitiveType_DeDepthwiseConv2D) {
return 0;
}
@ -230,7 +230,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format
<< weightTensor->dataType;
status = TransFilterFormat<uint8_t>(weightTensor.get(), kKCHW2HWCK);
status = TransFilterFormat<int8_t>(weightTensor.get(), kKCHW2HWCK);
} else {
MS_LOG(DEBUG) << "--weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format
<< weightTensor->dataType;
@ -238,7 +238,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
}
} else if (weightTensor->format == schema::Format_KHWC) { // from onnx
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
status = TransFilterFormat<uint8_t>(weightTensor.get(), kKHWC2HWCK);
status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK);
} else {
status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK);
}

@ -31,14 +31,23 @@ STATUS TfliteReshapeParser::Parse(const std::unique_ptr<tflite::OperatorT> &tfli
const auto &tfliteAttr = tfliteOp->builtin_options.AsReshapeOptions();
if (tfliteAttr == nullptr) {
MS_LOG(ERROR) << "get op: " << op->name.c_str() << " attr failed";
return RET_NULL_PTR;
}
attr->format = schema::Format_NHWC;
attr->shape.resize(tfliteAttr->new_shape.size());
for (size_t i = 0; i < tfliteAttr->new_shape.size(); ++i) {
attr->shape[i] = tfliteAttr->new_shape[i];
if (tfliteOp->inputs.size() < 2) {
MS_LOG(ERROR) << "expected two input tensors, but got: " << tfliteOp->inputs.size();
return RET_ERROR;
}
auto shape_tensor_index = tfliteOp->inputs[1];
const auto & shape_tensor = tfliteTensors[shape_tensor_index];
std::vector<tflite::TensorT *> shape_tensors{shape_tensor.get()};
if (RET_OK != ParseWeight(shape_tensors, tfliteModelBuffer, tensor_cache, schema::Format_KHWC)) {
MS_LOG(ERROR) << "parse shape tensor error";
return RET_ERROR;
}
} else {
attr->format = schema::Format_NHWC;
attr->shape.resize(tfliteAttr->new_shape.size());
for (size_t i = 0; i < tfliteAttr->new_shape.size(); ++i) {
attr->shape[i] = tfliteAttr->new_shape[i];
}
}
if (op != nullptr) {

@ -230,6 +230,13 @@ struct DivergInfo {
} else {
zero_point = static_cast<int>(std::round(zero_point_from_min));
}
MS_LOG(DEBUG) << "zero point:" << zero_point;
if (quant_min == 0 && quant_max == 255) {
zero_point = 128;
} else if (quant_min == -128 && quant_max == 127) {
zero_point = 0;
}
return std::make_pair(this->cnode, zero_point);
}
};
@ -466,11 +473,6 @@ Calibrator::Calibrator(string path, size_t bitNum, int quantMax, int quantMin)
PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, int bit_num, TypeId target_type)
: Quantizer(graph) {
this->bit_num = bit_num;
calibrator_ = std::unique_ptr<Calibrator>(new Calibrator(path, this->bit_num, quant_max, quant_min));
if (calibrator_ == nullptr) {
MS_LOG(ERROR) << "creat calibrator failed!";
return;
}
this->target_type_ = target_type;
if (target_type == kNumberTypeInt8) {
quant_max = (1 << (this->bit_num - 1)) - 1; // 127
@ -481,6 +483,11 @@ PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, in
} else {
MS_LOG(ERROR) << "unsupported quant value type: " << target_type;
}
calibrator_ = std::unique_ptr<Calibrator>(new Calibrator(path, this->bit_num, quant_max, quant_min));
if (calibrator_ == nullptr) {
MS_LOG(ERROR) << "creat calibrator failed!";
return;
}
}
STATUS PostTrainingQuantizer::DoQuantInput(double scale, int zeropoint, struct MaxMin *max_min,
@ -526,7 +533,7 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr node) {
}
auto parameter = std::dynamic_pointer_cast<Parameter>(node);
ParamValueLitePtr paramValue = std::dynamic_pointer_cast<ParamValueLite>(parameter->default_param());
auto status = QuantFilter(paramValue, QuantType_PostTraining, bit_num);
auto status = QuantFilter(paramValue, QuantType_PostTraining, quant_max, quant_min, bit_num);
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed: " << status;
return status;

@ -64,7 +64,7 @@ class PostTrainingQuantizer : public Quantizer {
int quant_min{0};
private:
TypeId target_type_{kNumberTypeUInt8};
TypeId target_type_{kNumberTypeInt8};
std::unique_ptr<Calibrator> calibrator_;

@ -22,13 +22,16 @@
namespace mindspore::lite::quant {
ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type) {
ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector<schema::QuantParamT> &quant_params) {
std::unique_ptr<schema::PrimitiveT> primitive = std::make_unique<schema::PrimitiveT>();
schema::QuantDTypeCastT quant_dtype_cast;
quant_dtype_cast.srcT = src_type; // kNumberTypeUInt8;
quant_dtype_cast.dstT = dst_type; // kNumberTypeFloat32;
primitive->value.Set(quant_dtype_cast);
auto primTValue = std::make_shared<PrimitiveTValue>(primitive.release());
for (auto &quant_param : quant_params) {
primTValue->AddInputQuantParam(quant_param);
}
return NewValueNode(primTValue);
}
@ -48,7 +51,8 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
}
if (first) {
if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) {
auto value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8);
auto value_node =
NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8, primitiveT_value->GetInputQuantParams());
std::vector<AnfNodePtr> op_inputs = {value_node, cnode->input(1)};
auto quant_cast_cnode = graph->NewCNode(op_inputs);
quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast");
@ -78,10 +82,12 @@ STATUS QuantCast::Run(FuncGraphPtr graph) {
ValueNodePtr value_node = nullptr;
if (curnode_quant_type == schema::QuantType_PostTraining &&
input_cnode_quant_type == schema::QuantType_QUANT_NONE) {
value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8);
value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8,
input_cnode_primitiveT_value->GetInputQuantParams());
} else if (curnode_quant_type == schema::QuantType_QUANT_NONE &&
input_cnode_quant_type == schema::QuantType_PostTraining) {
value_node = NewQuantCastValueNode(kNumberTypeUInt8, kNumberTypeFloat32);
value_node = NewQuantCastValueNode(kNumberTypeUInt8, kNumberTypeFloat32,
input_cnode_primitiveT_value->GetInputQuantParams());
}
if (value_node == nullptr) {
MS_LOG(WARNING) << "value_node is null! "

@ -190,7 +190,7 @@ void CalFakeNode(const AnfNodePtr &inTensor) {
}
STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double mMin,
double mMax, bool narrowRange, int numBits) {
double mMax, bool narrowRange, int quant_max, int quant_min, int num_bits) {
MS_ASSERT(quantParam != nullptr);
if (mMin > 0.0f) {
MS_LOG(ERROR) << "min " << mMin << " is bigger then 0, set to 0, this may course low precision";
@ -215,28 +215,17 @@ STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double
quantParam->scale = 0.0f;
quantParam->zeroPoint = 0;
quantParam->narrowRange = narrowRange;
quantParam->numBits = numBits;
quantParam->numBits = num_bits;
return RET_OK;
}
int quantMin = narrowRange ? 1 : 0;
int quantMax = (1 << (unsigned int)numBits) - 1;
auto quantMinFloat = static_cast<double>(quantMin);
auto quantMaxFloat = static_cast<double>(quantMax);
auto quantMinFloat = static_cast<double>(quant_min);
auto quantMaxFloat = static_cast<double>(quant_max);
double scale = (mMax - mMin) / (quantMaxFloat - quantMinFloat);
const double zeroPointFromMin = quantMinFloat - mMin / scale;
const double zeroPointFromMax = quantMaxFloat - mMax / scale;
const double zpFromMinError = std::abs(quantMinFloat) + std::abs(mMin / scale);
const double zpFromMaxError = std::abs(quantMaxFloat) + std::abs(mMax / scale);
const double zpDouble = zpFromMinError < zpFromMaxError ? zeroPointFromMin : zeroPointFromMax;
int zeroPoint;
if (zpDouble < quantMinFloat) {
zeroPoint = quantMin;
} else if (zpDouble > quantMaxFloat) {
zeroPoint = quantMax;
} else {
zeroPoint = static_cast<int32_t>(std::round(zpDouble));
}
// const double zeroPointFromMax = quantMaxFloat - mMax / scale;
int zeroPoint = static_cast<int32_t>(std::round(zeroPointFromMin));
// The zero point should always be in the range of quantized value,
// [qmin, qmax].
MS_ASSERT(zeroPoint >= quantMin);
@ -247,12 +236,12 @@ STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double
quantParam->scale = scale;
quantParam->zeroPoint = zeroPoint;
quantParam->narrowRange = narrowRange;
quantParam->numBits = numBits;
quantParam->numBits = num_bits;
return RET_OK;
}
STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, size_t bitNum) {
STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum) {
auto dims = weightPtr->tensor_shape();
if (dims.size() < 1) {
MS_LOG(ERROR) << "weight dims size error";
@ -284,7 +273,7 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, size_t bit
}
std::unique_ptr<AnfQuantParam> quantParam = std::unique_ptr<AnfQuantParam>(new AnfQuantParam);
STATUS status = CalQuantizationParams(quantParam, min, max, false, bitNum);
STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum);
if (status != RET_OK) {
MS_LOG(ERROR) << "CalQuantizationParams failed" << status;
return status;
@ -308,8 +297,8 @@ STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, size_t bit
PostBitPack(const_cast<float*>(rawDatas), shapeSize, bitNum);
}
weightPtr->set_tensor_type(kNumberTypeUInt8);
weightPtr->set_tensor_size(shapeSize * sizeof(uint8_t));
weightPtr->set_tensor_type(kNumberTypeInt8);
weightPtr->set_tensor_size(shapeSize * sizeof(int8_t));
return RET_OK;
}

@ -60,7 +60,7 @@ class QuantStrategy {
};
STATUS CalQuantizationParams(std::unique_ptr<AnfQuantParam> &quantParam, double mMin, double mMax,
bool narrowRange = false, int numBits = UINT8_QUANTIZATION);
bool narrowRange, int quant_max, int quant_min, int num_bits);
template <typename T>
T QuantizeData(const float originData, const AnfQuantParam *quantParam) {
@ -96,7 +96,7 @@ T QuantizeData(const float originData, const AnfQuantParam *quantParam) {
void CalFakeNode(const AnfNodePtr &inTensor);
STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType = QuantType_AwareTraining,
STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min,
size_t bitNum = UINT8_QUANTIZATION);
STATUS PostBitPack(float *weights, size_t shapeSize, size_t bitNum = UINT8_QUANTIZATION);

@ -81,7 +81,7 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list<CNodePtr> &nodes) {
}
ParamValueLitePtr paramValue = std::static_pointer_cast<ParamValueLite>(paramNode->default_param());
auto status = QuantFilter(paramValue, QuantType_WeightQuant, bitNum);
auto status = QuantFilter(paramValue, QuantType_WeightQuant, 127, -128, bitNum);
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantFilter failed : " << status;
return status;
@ -120,7 +120,7 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list<CNodePtr> &nodes) {
MS_LOG(ERROR) << "No valid input param node !";
continue;
}
auto status = QuantFilter(paramValue, QuantType_WeightQuant, bitNum);
auto status = QuantFilter(paramValue, QuantType_WeightQuant, 127, -128, bitNum);
if (status != RET_OK) {
MS_LOG(ERROR) << "QunatFilter failed" << status;
return RET_ERROR;

Loading…
Cancel
Save