|
|
|
@ -441,50 +441,62 @@ class CalcActivation : public QuantParamCalcer {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
QuantParamCalcRegister::~QuantParamCalcRegister() {
|
|
|
|
|
for (auto ite : _registerMap) {
|
|
|
|
|
if (ite.second != nullptr) {
|
|
|
|
|
delete ite.second;
|
|
|
|
|
ite.second = nullptr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
QuantParamCalcRegister::QuantParamCalcRegister() {
|
|
|
|
|
bool hasError = false;
|
|
|
|
|
auto baseCalcer = new (std::nothrow) QuantParamCalcer();
|
|
|
|
|
std::unique_ptr<QuantParamCalcer> baseCalcer(new (std::nothrow) QuantParamCalcer());
|
|
|
|
|
if (baseCalcer == nullptr) {
|
|
|
|
|
// MS_LOGW("new QuantParamCalcer failed");
|
|
|
|
|
MS_LOG(ERROR) << "new QuantParamCalcer failed";
|
|
|
|
|
hasError = true;
|
|
|
|
|
}
|
|
|
|
|
auto commonCalcer = new (std::nothrow) CommonCalcer();
|
|
|
|
|
std::unique_ptr<CommonCalcer> commonCalcer(new (std::nothrow) CommonCalcer());
|
|
|
|
|
if (commonCalcer == nullptr) {
|
|
|
|
|
// MS_LOGW("new commonCalcer failed");
|
|
|
|
|
MS_LOG(ERROR) << "new commonCalcer failed";
|
|
|
|
|
hasError = true;
|
|
|
|
|
}
|
|
|
|
|
auto linearCalcer = new (std::nothrow) LinearCalcer();
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<LinearCalcer> linearCalcer(new (std::nothrow) LinearCalcer());
|
|
|
|
|
if (linearCalcer == nullptr) {
|
|
|
|
|
// MS_LOGW("new linearCalcer failed");
|
|
|
|
|
MS_LOG(ERROR) << "new linearCalcer failed";
|
|
|
|
|
hasError = true;
|
|
|
|
|
}
|
|
|
|
|
if (!hasError) {
|
|
|
|
|
_registerMap[schema::PrimitiveType_Concat] = new CalcConcat();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Activation] = new CalcActivation();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Add] = new CalcAdd();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Mul] = commonCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_Conv2D] = commonCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_DepthwiseConv2D] = commonCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_Pooling] = linearCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_Resize] = linearCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_Reshape] = linearCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_Shape] = linearCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_Mul] = commonCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Conv2D] = commonCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_DepthwiseConv2D] = commonCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Pooling] = linearCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Resize] = linearCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Reshape] = linearCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Shape] = linearCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_SoftMax] = new CalcToSet(0, 1);
|
|
|
|
|
_registerMap[schema::PrimitiveType_Squeeze] = linearCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_Squeeze] = linearCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_RealDiv] = new CalcRealDiv();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Reduce] = commonCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_BiasAdd] = commonCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_Mean] = linearCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_Transpose] = linearCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_MatMul] = commonCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_FullConnection] = commonCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_Reduce] = commonCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_BiasAdd] = commonCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Mean] = linearCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Transpose] = linearCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_MatMul] = commonCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_FullConnection] = commonCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer.get();
|
|
|
|
|
// todo
|
|
|
|
|
// detection_postprocess op's quant param will not infer only fetch from preNode or postNode
|
|
|
|
|
// because we will not insert quantTransNode after this node in tflite_graph_8bit model if input data is float.
|
|
|
|
|
// if quantTransNode is inserted after detection_postprocess node, there will be some errors
|
|
|
|
|
_registerMap[schema::PrimitiveType_DetectionPostProcess] = baseCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_DetectionPostProcess] = baseCalcer.get();
|
|
|
|
|
baseCalcer.release();
|
|
|
|
|
linearCalcer.release();
|
|
|
|
|
commonCalcer.release();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|