|
|
|
@ -439,61 +439,52 @@ class CalcActivation : public QuantParamCalcer {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
QuantParamCalcRegister::~QuantParamCalcRegister() {
|
|
|
|
|
for (auto ite : _registerMap) {
|
|
|
|
|
if (ite.second != nullptr) {
|
|
|
|
|
delete ite.second;
|
|
|
|
|
ite.second = nullptr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
QuantParamCalcRegister::~QuantParamCalcRegister() {}
|
|
|
|
|
|
|
|
|
|
QuantParamCalcRegister::QuantParamCalcRegister() {
|
|
|
|
|
bool hasError = false;
|
|
|
|
|
std::unique_ptr<QuantParamCalcer> baseCalcer(new (std::nothrow) QuantParamCalcer());
|
|
|
|
|
std::shared_ptr<QuantParamCalcer> baseCalcer = std::make_shared<QuantParamCalcer>();
|
|
|
|
|
if (baseCalcer == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new QuantParamCalcer failed";
|
|
|
|
|
hasError = true;
|
|
|
|
|
}
|
|
|
|
|
std::unique_ptr<CommonCalcer> commonCalcer(new (std::nothrow) CommonCalcer());
|
|
|
|
|
std::shared_ptr<QuantParamCalcer> commonCalcer = std::make_shared<CommonCalcer>();
|
|
|
|
|
if (commonCalcer == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new commonCalcer failed";
|
|
|
|
|
hasError = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<LinearCalcer> linearCalcer(new (std::nothrow) LinearCalcer());
|
|
|
|
|
std::shared_ptr<QuantParamCalcer> linearCalcer = std::make_shared<LinearCalcer>();
|
|
|
|
|
if (linearCalcer == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "new linearCalcer failed";
|
|
|
|
|
hasError = true;
|
|
|
|
|
}
|
|
|
|
|
if (!hasError) {
|
|
|
|
|
_registerMap[schema::PrimitiveType_Concat] = new CalcConcat();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Activation] = new CalcActivation();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Add] = new CalcAdd();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Mul] = commonCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Conv2D] = commonCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_DepthwiseConv2D] = commonCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Pooling] = linearCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Resize] = linearCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Reshape] = linearCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Shape] = linearCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_SoftMax] = new CalcToSet(0, 1);
|
|
|
|
|
_registerMap[schema::PrimitiveType_Squeeze] = linearCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_RealDiv] = new CalcRealDiv();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Reduce] = commonCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_BiasAdd] = commonCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Mean] = linearCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Transpose] = linearCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_MatMul] = commonCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_FullConnection] = commonCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer.get();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Concat] = std::make_shared<CalcConcat>();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Activation] = std::make_shared<CalcActivation>();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Add] = std::make_shared<CalcAdd>();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Mul] = commonCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_Conv2D] = commonCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_DepthwiseConv2D] = commonCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_Pooling] = linearCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_Resize] = linearCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_Reshape] = linearCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_Shape] = linearCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_SoftMax] = std::make_shared<CalcToSet>(0, 1);
|
|
|
|
|
_registerMap[schema::PrimitiveType_Squeeze] = linearCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_RealDiv] = std::make_shared<CalcRealDiv>();
|
|
|
|
|
_registerMap[schema::PrimitiveType_Reduce] = commonCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_BiasAdd] = commonCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_Mean] = linearCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_Transpose] = linearCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_MatMul] = commonCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_FullConnection] = commonCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_Nchw2Nhwc] = linearCalcer;
|
|
|
|
|
_registerMap[schema::PrimitiveType_Nhwc2Nchw] = linearCalcer;
|
|
|
|
|
// detection_postprocess op's quant param will not infer only fetch from preNode or postNode
|
|
|
|
|
// because we will not insert quantTransNode after this node in tflite_graph_8bit model if input data is float.
|
|
|
|
|
// if quantTransNode is inserted after detection_postprocess node, there will be some errors
|
|
|
|
|
_registerMap[schema::PrimitiveType_DetectionPostProcess] = baseCalcer.get();
|
|
|
|
|
baseCalcer.release();
|
|
|
|
|
linearCalcer.release();
|
|
|
|
|
commonCalcer.release();
|
|
|
|
|
_registerMap[schema::PrimitiveType_DetectionPostProcess] = baseCalcer;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -502,7 +493,7 @@ QuantParamCalcRegister *QuantParamCalcRegister::GetInstance() {
|
|
|
|
|
return &instance;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
QuantParamCalcer *QuantParamCalcRegister::GetQuantParamCalcer(schema::PrimitiveType opType) {
|
|
|
|
|
std::shared_ptr<QuantParamCalcer> QuantParamCalcRegister::GetQuantParamCalcer(schema::PrimitiveType opType) {
|
|
|
|
|
auto it = _registerMap.find(opType);
|
|
|
|
|
if (it != _registerMap.end()) {
|
|
|
|
|
return it->second;
|
|
|
|
|