|
|
|
@ -293,13 +293,13 @@ STATUS AwareQuantizer::GenerateQuantParam() {
|
|
|
|
|
MS_ASSERT(graph->inputIndex.size() == 1);
|
|
|
|
|
// set graphInputNode input
|
|
|
|
|
for (auto graphInputIndex : graph->inputIndex) {
|
|
|
|
|
auto status = mInputArray->SetInputArrayQP(graph.get(), graphInputIndex);
|
|
|
|
|
auto status = mInputArray->SetInputArrayQP(graph, graphInputIndex);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "SetInputArrayQP failed";
|
|
|
|
|
return status;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto status = GenerateDefaultQuantParam(graph.get());
|
|
|
|
|
auto status = GenerateDefaultQuantParam(graph);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "GenerateDefaultQuantParam failed";
|
|
|
|
|
return status;
|
|
|
|
@ -319,7 +319,7 @@ STATUS AwareQuantizer::GenerateQuantParam() {
|
|
|
|
|
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
|
|
|
|
|
node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE);
|
|
|
|
|
} else {
|
|
|
|
|
status = quantParamCalcer->Calc(graph.get(), *node);
|
|
|
|
|
status = quantParamCalcer->Calc(graph, *node);
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
|
|
|
|
|
node->quantType = schema::QuantType_QUANT_NONE;
|
|
|
|
@ -349,27 +349,27 @@ STATUS AwareQuantizer::DoQuantize() {
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
// quant weight
|
|
|
|
|
status = QuantConvWeight(graph.get(), node.get());
|
|
|
|
|
status = QuantConvWeight(graph, node.get());
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "QuantConvWeight failed!";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
// quant bias
|
|
|
|
|
if (inputIndexes.size() == 3) {
|
|
|
|
|
status = QuantConvBias(graph.get(), node.get());
|
|
|
|
|
status = QuantConvBias(graph, node.get());
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "QuantConvBias failed!";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) {
|
|
|
|
|
status = QuantDetectionPostProcessConstTensor(graph.get(), node.get());
|
|
|
|
|
status = QuantDetectionPostProcessConstTensor(graph, node.get());
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
} else if (GetCNodeTType(*node) == schema::PrimitiveType_Add) {
|
|
|
|
|
status = QuantAddConstTensor(graph.get(), node.get());
|
|
|
|
|
status = QuantAddConstTensor(graph, node.get());
|
|
|
|
|
if (status != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "QuantAddConstTensor failed!";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|