@@ -33,7 +33,7 @@ STATUS QuantParamCalcer::ComputeConstQuantParam(const schema::TensorT &tensor, Q
return RET_OK;
}
if (tensor.dataType != TypeId::kNumberTypeFloat) {
// MS_LOGW("Const Tensor without quantParam should has float dataType, in fact: %d", tensor.dataType);
MS_LOG(WARNING) << "Const Tensor without quantParam should have float dataType, in fact: " << tensor.dataType;
return RET_ERROR;
}
const auto *constData = reinterpret_cast<const float *>(tensor.data.data());
@@ -53,7 +53,7 @@ STATUS QuantParamCalcer::ComputeConstQuantParam(const schema::TensorT &tensor, Q
isQuantExact &= (constData[i] == min || constData[i] == max);
}
if (!isQuantExact) {
// //MS_LOGD("compute quantParam for const tensor may be a cause of poor inference accuracy");
MS_LOG(DEBUG) << "compute quantParam for const tensor may be a cause of poor inference accuracy";
}
return quant::CalQuantizationParams(quantParam, min, max);
}
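// A minimal sketch of the min/max -> (scale, zeroPoint) mapping that quant::CalQuantizationParams
// presumably performs for asymmetric 8-bit quantization of the const tensor above. The helper name,
// the numBits/narrowRange defaults, and the clamping are illustrative assumptions, not the
// converter's actual implementation.
#include <algorithm>
#include <cmath>
#include <cstdint>

struct SketchQuantParam {
  double scale = 1.0;
  int32_t zeroPoint = 0;
};

inline SketchQuantParam SketchCalQuantizationParams(float min, float max, int numBits = 8, bool narrowRange = false) {
  // The float range must contain 0 so that the real value 0.0f maps exactly onto one integer code.
  min = std::min(min, 0.0f);
  max = std::max(max, 0.0f);
  const int32_t qmin = narrowRange ? 1 : 0;
  const int32_t qmax = (1 << numBits) - 1;
  SketchQuantParam p;
  if (max - min < 1e-8f) {  // degenerate range, e.g. an all-zero const tensor
    p.scale = 1.0;
    p.zeroPoint = qmin;
    return p;
  }
  p.scale = (static_cast<double>(max) - min) / (qmax - qmin);
  p.zeroPoint = static_cast<int32_t>(std::round(qmin - min / p.scale));
  p.zeroPoint = std::min(std::max(p.zeroPoint, qmin), qmax);
  return p;
}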
@@ -80,7 +80,7 @@ int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
if (tensor->refCount == schema::NodeType_ValueNode && !IsContain(graph->inputIndex, node.inputIndex.at(i))) {
auto status = ComputeConstQuantParam((*tensor), quantParam.get());
if (status != RET_OK) {
// MS_LOGW("ComputeConstQuantParam failed: %d", status);
MS_LOG(WARNING) << "ComputeConstQuantParam failed: " << status;
return status;
}
tensor->quantParams.front() = std::move(quantParam);
@@ -110,15 +110,15 @@ int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
int CommonCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) {
auto status = QuantParamCalcer::Calc(subGraph, node);
if (status != RET_OK) {
// MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status);
MS_LOG(WARNING) << "Call QuantParamCalcer::Calc failed: " << status;
return status;
}
if (inputParamDone != node.inputIndex.size()) {
MS_LOG(ERROR) << "Can not determine inputTensor quantParam, node " << node.name.c_str();
MS_LOG(ERROR) << "Can not determine inputTensor quantParam, node " << node.name;
return RET_ERROR;
}
if (outputParamDone != node.outputIndex.size()) {
MS_LOG(ERROR) << "Can not determine outputTensor quantParam, node " << node.name.c_str();
MS_LOG(ERROR) << "Can not determine outputTensor quantParam, node " << node.name;
return RET_ERROR;
}
return RET_OK;
@@ -127,7 +127,7 @@ int CommonCalcer::Calc(MetaGraphT *subGraph, const CNodeT &node) {
int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
auto status = QuantParamCalcer::Calc(graph, node);
if (status != RET_OK) {
// MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status);
MS_LOG(WARNING) << "Call QuantParamCalcer::Calc failed: " << status;
return status;
}
if (inputParamDone != node.inputIndex.size()) {
@@ -137,7 +137,7 @@ int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
auto outputQuantParam = GetTensorQuantParam(outTensor);
MS_ASSERT(outputQuantParam != nullptr);
if (!outputQuantParam->inited) {
// MS_LOGW("Can not determine inputTensor quantParam from outputTensor for node %s", node.name.c_str());
MS_LOG(WARNING) << "Can not determine inputTensor quantParam from outputTensor for node " << node.name;
return RET_ERROR;
}
for (unsigned int i : node.inputIndex) {
@@ -157,7 +157,7 @@ int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
MS_ASSERT(inTensor != nullptr);
auto inQuantParam = GetTensorQuantParam(inTensor);
if (!inQuantParam->inited) {
// MS_LOGW("Can not determine outputTensor quantParam from inputTensor for node %s", node.name.c_str());
MS_LOG(WARNING) << "Can not determine outputTensor quantParam from inputTensor for node " << node.name;
return RET_ERROR;
}
for (size_t i = 0; i < node.outputIndex.size(); i++) {
@@ -182,12 +182,12 @@ class CalcConcat : public QuantParamCalcer {
MS_ASSERT(node.outputIndex.size() == 1);
auto status = QuantParamCalcer::Calc(graph, node);
if (status != RET_OK) {
// MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status);
MS_LOG(WARNING) << "Call QuantParamCalcer::Calc failed: " << status;
return status;
}

if (inputParamDone != node.inputIndex.size()) {
// MS_LOGW("Can not determine concat inputTensor quantParam, node %s", node.name.c_str());
MS_LOG(WARNING) << "Can not determine concat inputTensor quantParam, node " << node.name;
return RET_ERROR;
}

@@ -228,7 +228,7 @@ class CalcConcat : public QuantParamCalcer {

status = quant::CalQuantizationParams(outQuantParam.get(), minMin, maxMax, narrowRange, numBits);
if (status != RET_OK) {
// MS_LOGW("in aware quantization run CalQuantizationParams failed!");
MS_LOG(WARNING) << "in aware quantization run CalQuantizationParams failed!";
return RET_ERROR;
}
outputParamDone++;
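// For Concat, the output range in the hunk above is the envelope of all input ranges
// (minMin over the input minima, maxMax over the input maxima) before it is handed to
// CalQuantizationParams. A minimal sketch of that envelope; the names are illustrative only
// and the helper assumes at least one input range.
#include <algorithm>
#include <utility>
#include <vector>

inline std::pair<float, float> RangeEnvelope(const std::vector<std::pair<float, float>> &ranges) {
  float minMin = ranges.front().first;
  float maxMax = ranges.front().second;
  for (const auto &r : ranges) {
    minMin = std::min(minMin, r.first);   // smallest of all input minima
    maxMax = std::max(maxMax, r.second);  // largest of all input maxima
  }
  return {minMin, maxMax};
}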
@@ -247,12 +247,12 @@ class CalcAdd : public QuantParamCalcer {
MS_ASSERT(node.outputIndex.size() == 1);
auto status = QuantParamCalcer::Calc(graph, node);
if (status != RET_OK) {
// MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status);
MS_LOG(WARNING) << "Call QuantParamCalcer::Calc failed: " << status;
return status;
}

if (inputParamDone != 2) {
// MS_LOGW("Can not determine add inputTensor quantParam, node %s", node.name.c_str());
MS_LOG(WARNING) << "Can not determine add inputTensor quantParam, node " << node.name;
return RET_ERROR;
}
if (outputParamDone != 1) {
@@ -277,7 +277,7 @@ class CalcAdd : public QuantParamCalcer {
biasTensor = &tensor1;
paramTensor = &tensor0;
} else {
// MS_LOGW("Can not determine add outputTensor quantParam, node %s", node.name.c_str());
MS_LOG(WARNING) << "Can not determine add outputTensor quantParam, node " << node.name;
return RET_ERROR;
}
auto quantParam = GetTensorQuantParam(*paramTensor);
@@ -292,7 +292,7 @@ class CalcAdd : public QuantParamCalcer {
auto *bias = static_cast<float *>(oriTensorData);
status = quant::CalQuantizationParams(outQuantParam.get(), min + (*bias), max + (*bias));
if (status != RET_OK) {
// MS_LOGW("in aware quantization run CalQuantizationParams failed!");
MS_LOG(WARNING) << "in aware quantization run CalQuantizationParams failed!";
return RET_ERROR;
}
} else if ((*biasTensor)->dataType == TypeId::kNumberTypeUInt8) {
@@ -301,11 +301,11 @@ class CalcAdd : public QuantParamCalcer {
auto *bias = static_cast<uint8_t *>(oriTensorData);
status = quant::CalQuantizationParams(outQuantParam.get(), min + (*bias), max + (*bias));
if (status != RET_OK) {
// MS_LOGW("in aware quantization run CalQuantizationParams failed!");
MS_LOG(WARNING) << "in aware quantization run CalQuantizationParams failed!";
return RET_ERROR;
}
} else {
// MS_LOGW("Unsupported tensor dataType: %d", (*biasTensor)->dataType);
MS_LOG(WARNING) << "Unsupported tensor dataType: " << (*biasTensor)->dataType;
return RET_ERROR;
}
}
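// The CalcAdd branches above derive the output range of "x + b" (b a scalar constant read from the
// bias tensor) by shifting both ends of the input range by the same amount, [min, max] + b =
// [min + b, max + b], before calling CalQuantizationParams. A minimal sketch of that range shift;
// the names are illustrative only.
struct SketchRange {
  float min;
  float max;
};

inline SketchRange ShiftRangeByScalar(const SketchRange &in, float bias) {
  // Adding a constant translates the whole range without changing its width, so (up to any
  // clamp-to-zero adjustment) the scale is unchanged and only the zero point moves.
  return SketchRange{in.min + bias, in.max + bias};
}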
@@ -323,12 +323,12 @@ class CalcRealDiv : public QuantParamCalcer {
MS_ASSERT(node.outputIndex.size() == 1);
auto status = QuantParamCalcer::Calc(graph, node);
if (status != RET_OK) {
// MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status);
MS_LOG(WARNING) << "Call QuantParamCalcer::Calc failed: " << status;
return status;
}

if (inputParamDone != 2) {
// MS_LOGW("Can not determine realdiv inputTensor quantParam, node %s", node.name.c_str());
MS_LOG(WARNING) << "Can not determine realdiv inputTensor quantParam, node " << node.name;
return RET_ERROR;
}
if (outputParamDone != 1) {
@@ -354,7 +354,7 @@ class CalcRealDiv : public QuantParamCalcer {
MS_ASSERT(*div != 0);
status = quant::CalQuantizationParams(outQuantParam.get(), min / (*div), max / (*div));
if (status != RET_OK) {
// MS_LOGW("in aware quantization run CalQuantizationParams failed!");
MS_LOG(WARNING) << "in aware quantization run CalQuantizationParams failed!";
return RET_ERROR;
}
} else if (tensor1->dataType == TypeId::kNumberTypeUInt8) {
@@ -363,16 +363,16 @@ class CalcRealDiv : public QuantParamCalcer {
auto *div = static_cast<uint8_t *>(oriTensorData);
status = quant::CalQuantizationParams(outQuantParam.get(), min / (*div), max / (*div));
if (status != RET_OK) {
// MS_LOGW("in aware quantization run CalQuantizationParams failed!");
MS_LOG(WARNING) << "in aware quantization run CalQuantizationParams failed!";
return RET_ERROR;
}
} else {
// MS_LOGW("Unsupported tensor dataType: %d", tensor1->dataType);
MS_LOG(WARNING) << "Unsupported tensor dataType: " << tensor1->dataType;
return RET_ERROR;
}
}
} else {
// MS_LOGW("Can not determine realDiv outputTensor quantParam, node %s", node.name.c_str());
MS_LOG(WARNING) << "Can not determine realDiv outputTensor quantParam, node " << node.name;
return RET_ERROR;
}
}
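// The CalcRealDiv branches above assume the divisor is a non-zero scalar constant d (see the
// MS_ASSERT on *div), so the output range of "x / d" is obtained by dividing both ends of the
// input range: [min, max] / d = [min / d, max / d]. A minimal sketch; the name is illustrative
// only, and for a negative d the two ends would have to swap, which is not handled here.
inline void DivideRangeByScalar(float *min, float *max, float d) {
  // Scale both ends of the range by the (assumed positive, non-zero) divisor.
  *min /= d;
  *max /= d;
}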
@@ -389,18 +389,18 @@ class CalcToSet : public QuantParamCalcer {
MS_ASSERT(node.outputIndex.size() == 1);
auto status = QuantParamCalcer::Calc(graph, node);
if (status != RET_OK) {
// MS_LOGW("Call QuantParamCalcer::Calc failed: %d", status);
MS_LOG(WARNING) << "Call QuantParamCalcer::Calc failed: " << status;
return status;
}
// input
if (inputParamDone != node.inputIndex.size()) {
// MS_LOGW("Can not determine inputTensor quantParam, node %s", node.name.c_str());
MS_LOG(WARNING) << "Can not determine inputTensor quantParam, node " << node.name;
return RET_ERROR;
}
// output
std::unique_ptr<QuantParamT> quantParam(new (std::nothrow) QuantParamT());
if (quantParam == nullptr) {
// MS_LOGW("new QuantParamT failed");
MS_LOG(WARNING) << "new QuantParamT failed";
return RET_ERROR;
}
quantParam->scale = (max - min) / 256;
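// CalcToSet above pins the output quant param to a fixed [min, max] range instead of inferring it,
// using scale = (max - min) / 256 for an 8-bit output as shown in the hunk. A minimal sketch of
// filling such a fixed-range parameter; it assumes max > min, and the zero-point formula is an
// illustrative assumption since the hunk does not show it.
#include <cmath>
#include <cstdint>

inline void FillFixedRangeQuantParam(float min, float max, double *scale, int32_t *zeroPoint) {
  *scale = (static_cast<double>(max) - min) / 256;
  *zeroPoint = static_cast<int32_t>(std::round(-min / *scale));  // integer code for the real value 0.0f
}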
@@ -486,7 +486,6 @@ QuantParamCalcRegister::QuantParamCalcRegister() {
_registerMap[schema::PrimitiveType_FullConnection] = commonCalcer.get();
_registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer.get();
_registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer.get();
// todo
// detection_postprocess op's quant param will not infer only fetch from preNode or postNode
// because we will not insert quantTransNode after this node in tflite_graph_8bit model if input data is float.
// if quantTransNode is inserted after detection_postprocess node, there will be some errors