diff --git a/mindspore/lite/src/ir/primitive_t_value.h b/mindspore/lite/src/ir/primitive_t_value.h index 4f4ce9ac5b..170715ed2c 100644 --- a/mindspore/lite/src/ir/primitive_t_value.h +++ b/mindspore/lite/src/ir/primitive_t_value.h @@ -47,7 +47,15 @@ class PrimitiveTValue : public Value { } } - void SetInputQuantParam(std::vector> vec_quant_param) {} + + void SetInputQuantParam(const std::vector> &input_quant_param) { + this->input_quant_param_ = input_quant_param; + } + + void SetOutputQuantParam(const std::vector> &output_quant_param) { + this->output_quant_param_ = output_quant_param; + } + void AddInputQuantParam(std::vector quant_param) { this->input_quant_param_.emplace_back(quant_param); diff --git a/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.cc b/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.cc index 8bcc61f2f1..feb31753b3 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp32/nchw2nhwc.cc @@ -37,8 +37,13 @@ int Nchw2NhwcCPUKernel::Run() { auto output = out_tensors_[0]; if (input->shape().size() == 4) { - PackNCHWToNHWCFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(), - output->Channel()); + if (input->data_type() == kNumberTypeFloat32) { + PackNCHWToNHWCFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(), + output->Channel()); + } else if (input->data_type() == kNumberTypeInt8) { + PackNCHWToNHWCInt8(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(), + output->Channel()); + } } else { memcpy(output->Data(), input->Data(), input->ElementsNum() * sizeof(float)); } @@ -67,4 +72,5 @@ kernel::LiteKernel *CpuNchw2NhwcFp32KernelCreator(const std::vectorshape().size() == 4) { - PackNHWCToNCHWFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(), - output->Channel()); + if (input->data_type() == kNumberTypeFloat32) { + PackNHWCToNCHWFp32(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(), + output->Channel()); + } else if (input->data_type() == kNumberTypeInt8) { + PackNHWCToNCHWInt8(input->Data(), output->Data(), output->Batch(), output->Height() * output->Width(), + output->Channel()); + } } else { memcpy(output->Data(), input->Data(), input->ElementsNum() * sizeof(float)); } @@ -67,4 +72,5 @@ kernel::LiteKernel *CpuNhwc2NchwFp32KernelCreator(const std::vector((qmax - mean) / stdDev); } -void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim, - std::vector> *vecQuantParam) { +void AnfConvPopulater::PopulaterQuantParam( + const PrimitivePtr &prim, + std::vector> *vecInputQuantParam, + std::vector> *vecOutputQuantParam) { auto narrow_range = prim->GetAttr("narrow_range"); bool narrowRangeQuantParam = GetValue(narrow_range); auto num_bits = prim->GetAttr("num_bits"); @@ -154,7 +156,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim, quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, numbitsRangeQuantParam); quants.emplace_back(quantParam); - vecQuantParam->emplace_back(quants); + vecInputQuantParam->emplace_back(quants); quants.clear(); int biasQuantSize = 0; @@ -173,7 +175,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim, numbitsRangeQuantParam); quants.emplace_back(quantParam); } - vecQuantParam->emplace_back(quants); + vecInputQuantParam->emplace_back(quants); } quants.clear(); @@ -181,10 +183,12 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim, quantParam.min = 0.0; quantParam.max = 0.0; quantParam.zeroPoint = 0; - quantParam.scale = vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale; + + quantParam.scale = + vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(i).scale; quants.emplace_back(quantParam); } - vecQuantParam->emplace_back(quants); + vecInputQuantParam->emplace_back(quants); quants.clear(); auto outputMin = prim->GetAttr("output_minq"); @@ -199,7 +203,7 @@ void AnfConvPopulater::PopulaterQuantParam(const PrimitivePtr &prim, quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, numbitsRangeQuantParam); quants.emplace_back(quantParam); - vecQuantParam->emplace_back(quants); + vecOutputQuantParam->emplace_back(quants); } } @@ -215,10 +219,13 @@ int AnfConvPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *primit PopulaterConv2DSingleGroup(prim, primitive, group); } primitiveTValuePtr->SetPrimitiveT(primitive.release()); + if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) { - std::vector> vecQuantParam; - PopulaterQuantParam(prim, &vecQuantParam); - primitiveTValuePtr->SetInputQuantParam(vecQuantParam); + std::vector> vecInputQuantParam; + std::vector> vecOutputQuantParam; + PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam); + primitiveTValuePtr->SetInputQuantParam(vecInputQuantParam); + primitiveTValuePtr->SetOutputQuantParam(vecOutputQuantParam); } return 0; } diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.h index e4befe36df..678897fe92 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_conv_populater.h @@ -20,9 +20,10 @@ #ifndef MINDSPORE_ANF_CONV_PARSER_H #define MINDSPORE_ANF_CONV_PARSER_H -#include "tools/anf_importer/anf_populater/anf_node_populater.h" -#include #include +#include +#include "tools/anf_importer/anf_populater/anf_node_populater.h" + namespace mindspore::lite { class AnfConvPopulater : public AnfNodePopulater { public: @@ -32,12 +33,18 @@ class AnfConvPopulater : public AnfNodePopulater { const std::vector &inputs) override; private: - void PopulaterConv2DMultiGroup(const PrimitivePtr &prim, const std::unique_ptr &primitive, - const int &group); - void PopulaterConv2DSingleGroup(const PrimitivePtr &prim, const std::unique_ptr &primitive, - const int &group); - void PopulaterQuantParam(const PrimitivePtr &prim, std::vector> *vecQuantParam); - void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); + void PopulaterConv2DMultiGroup( + const PrimitivePtr &prim, + const std::unique_ptr &primitive, const int &group); + void PopulaterConv2DSingleGroup( + const PrimitivePtr &prim, + const std::unique_ptr &primitive, const int &group); + void PopulaterQuantParam( + const PrimitivePtr &prim, + std::vector> *vecInputQuantParam, + std::vector> *vecOutputQuantParam); + void CalQuantParam(const double &mean, const double &stdDev, float *mMin, + float *mMax); }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc index 874a52df5c..583e8df423 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.cc @@ -31,8 +31,10 @@ void AnfDepwiseconv2DPopulater::CalQuantParam(const double &mean, const double & *mMax = static_cast((qmax - mean) / stdDev); } -void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim, - std::vector> *vecQuantParam) { +void AnfDepwiseconv2DPopulater::PopulaterQuantParam( + const PrimitivePtr &prim, + std::vector> *vecInputQuantParam, + std::vector> *vecOutputQuantParam) { auto narrow_range = prim->GetAttr("narrow_range"); bool narrowRangeQuantParam = GetValue(narrow_range); auto num_bits = prim->GetAttr("num_bits"); @@ -63,7 +65,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim, quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, numbitsRangeQuantParam); quants.emplace_back(quantParam); - vecQuantParam->emplace_back(quants); + vecInputQuantParam->emplace_back(quants); quants.clear(); int biasQuantSize = 0; @@ -82,7 +84,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim, numbitsRangeQuantParam); quants.emplace_back(quantParam); } - vecQuantParam->emplace_back(quants); + vecInputQuantParam->emplace_back(quants); } quants.clear(); @@ -90,10 +92,12 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim, quantParam.min = 0.0; quantParam.max = 0.0; quantParam.zeroPoint = 0; - quantParam.scale = vecQuantParam->at(0).at(0).scale * vecQuantParam->at(1).at(i).scale; + + quantParam.scale = + vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(i).scale; quants.emplace_back(quantParam); } - vecQuantParam->emplace_back(quants); + vecInputQuantParam->emplace_back(quants); quants.clear(); auto outputMin = prim->GetAttr("output_minq"); @@ -108,7 +112,7 @@ void AnfDepwiseconv2DPopulater::PopulaterQuantParam(const PrimitivePtr &prim, quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, numbitsRangeQuantParam); quants.emplace_back(quantParam); - vecQuantParam->emplace_back(quants); + vecOutputQuantParam->emplace_back(quants); } } @@ -177,10 +181,12 @@ int AnfDepwiseconv2DPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValu MS_ASSERT(primitiveTValuePtr != nullptr); primitiveTValuePtr->SetPrimitiveT(primitive.release()); - if (primitiveTValuePtr->GetQuantType()) { - std::vector> vecQuantParam; - PopulaterQuantParam(prim, &vecQuantParam); - primitiveTValuePtr->SetInputQuantParam(vecQuantParam); + if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) { + std::vector> vecInputQuantParam; + std::vector> vecOutputQuantParam; + PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam); + primitiveTValuePtr->SetInputQuantParam(vecInputQuantParam); + primitiveTValuePtr->SetOutputQuantParam(vecOutputQuantParam); } return 0; } diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h index 5b58bf3b6e..005f132516 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_depthwiseconv2d_populater.h @@ -28,8 +28,12 @@ class AnfDepwiseconv2DPopulater : public AnfNodePopulater { const std::vector &inputs) override; private: - void PopulaterQuantParam(const PrimitivePtr &prim, std::vector> *vecQuantParam); - void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); + void PopulaterQuantParam( + const PrimitivePtr &prim, + std::vector> *vecInputQuantParam, + std::vector> *vecOutputQuantParam); + void CalQuantParam(const double &mean, const double &stdDev, float *mMin, + float *mMax); }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.cc b/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.cc index fe315d780e..de3a84ee37 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.cc +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.cc @@ -30,8 +30,10 @@ void AnfMatmulPopulater::CalQuantParam(const double &mean, const double &stdDev, *mMax = static_cast((qmax - mean) / stdDev); } -void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim, - std::vector> *vecQuantParam) { +void AnfMatmulPopulater::PopulaterQuantParam( + const PrimitivePtr &prim, + std::vector> *vecInputQuantParam, + std::vector> *vecOutputQuantParam) { auto narrow_range = prim->GetAttr("narrow_range"); bool narrowRangeQuantParam = GetValue(narrow_range); auto num_bits = prim->GetAttr("num_bits"); @@ -62,7 +64,7 @@ void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim, quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, numbitsRangeQuantParam); quants.emplace_back(quantParam); - vecQuantParam->emplace_back(quants); + vecInputQuantParam->emplace_back(quants); quants.clear(); auto filterMin = prim->GetAttr("filter_minq"); @@ -79,7 +81,7 @@ void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim, numbitsRangeQuantParam); quants.emplace_back(quantParam); } - vecQuantParam->emplace_back(quants); + vecInputQuantParam->emplace_back(quants); } quants.clear(); @@ -95,7 +97,7 @@ void AnfMatmulPopulater::PopulaterQuantParam(const PrimitivePtr &prim, quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, numbitsRangeQuantParam); quants.emplace_back(quantParam); - vecQuantParam->emplace_back(quants); + vecOutputQuantParam->emplace_back(quants); } } @@ -110,12 +112,13 @@ int AnfMatmulPopulater::Populate(const PrimitivePtr &prim, PrimitiveTValue *prim primitive->value.value = attr.release(); MS_ASSERT(primitiveTValuePtr != nullptr); primitiveTValuePtr->SetPrimitiveT(primitive.release()); - if (primitiveTValuePtr->GetQuantType()) { - std::vector> vecQuantParam; - PopulaterQuantParam(prim, &vecQuantParam); - primitiveTValuePtr->SetInputQuantParam(vecQuantParam); + if (primitiveTValuePtr->GetQuantType() == schema::QuantType_AwareTraining) { + std::vector> vecInputQuantParam; + std::vector> vecOutputQuantParam; + PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam); + primitiveTValuePtr->SetInputQuantParam(vecInputQuantParam); + primitiveTValuePtr->SetOutputQuantParam(vecOutputQuantParam); } - return 0; } AnfNodePopulaterRegistrar anfMatmulPopulater("Matmul", new AnfMatmulPopulater()); diff --git a/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.h b/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.h index 39b7be7f5a..d99cf57339 100644 --- a/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.h +++ b/mindspore/lite/tools/anf_importer/anf_populater/anf_matmul_populater.h @@ -26,8 +26,12 @@ class AnfMatmulPopulater : public AnfNodePopulater { const std::vector &inputs) override; private: - void PopulaterQuantParam(const PrimitivePtr &prim, std::vector> *vecQuantParam); - void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); + void PopulaterQuantParam( + const PrimitivePtr &prim, + std::vector> *vecInputQuantParam, + std::vector> *vecOutputQuantParam); + void CalQuantParam(const double &mean, const double &stdDev, float *mMin, + float *mMax); }; } // namespace mindspore::lite diff --git a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc index 90c6be2a28..9bc4e3dcf1 100644 --- a/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/aware_quantizer.cc @@ -15,20 +15,22 @@ */ #include "tools/converter/quantizer/aware_quantizer.h" + #include #include #include #include #include + #include "schema/inner/model_generated.h" -#include "utils/log_adapter.h" #include "securec/include/securec.h" -#include "tools/converter/quantizer/quantize_util.h" #include "src/common/utils.h" -#include "tools/converter/quantizer/calc_quant_param.h" -#include "tools/common/tensor_util.h" #include "tools/common/converter_op_utils.h" #include "tools/common/node_util.h" +#include "tools/common/tensor_util.h" +#include "tools/converter/quantizer/calc_quant_param.h" +#include "tools/converter/quantizer/quantize_util.h" +#include "utils/log_adapter.h" using std::string; using std::vector; @@ -42,7 +44,8 @@ struct InputArray { int numBits = 8; TypeId dataType = TypeId::kTypeUnknown; - InputArray(float mean, float stdDev, TypeId dataType = TypeId::kNumberTypeFloat) { + InputArray(float mean, float stdDev, + TypeId dataType = TypeId::kNumberTypeFloat) { this->dataType = dataType; constexpr float qmin = 0; constexpr float qmax = 255; @@ -52,7 +55,8 @@ struct InputArray { STATUS InitQuantParam() { this->quantParam = std::make_unique(); - auto status = CalQuantizationParams(quantParam.get(), mMin, mMax, narrowRange, numBits); + auto status = CalQuantizationParams(quantParam.get(), mMin, mMax, + narrowRange, numBits); if (status != RET_OK) { return status; } @@ -66,7 +70,8 @@ struct InputArray { if (!tensor->quantParams.empty()) { auto param = GetTensorQuantParam(tensor); if (param != nullptr && param->inited) { - MS_LOG(DEBUG) << "tensor " << inputTensorIdx << " already has quantParam"; + MS_LOG(DEBUG) << "tensor " << inputTensorIdx + << " already has quantParam"; return RET_OK; } tensor->quantParams.clear(); @@ -83,11 +88,14 @@ struct InputArray { }; const std::array AwareQuantizer::propagatedOps = { - {schema::PrimitiveType_Concat, schema::PrimitiveType_Resize, schema::PrimitiveType_Reshape, - schema::PrimitiveType_Squeeze, schema::PrimitiveType_RealDiv, schema::PrimitiveType_Activation, - schema::PrimitiveType_DetectionPostProcess}}; + {schema::PrimitiveType_Concat, schema::PrimitiveType_Resize, + schema::PrimitiveType_Reshape, schema::PrimitiveType_Squeeze, + schema::PrimitiveType_RealDiv, schema::PrimitiveType_Activation, + schema::PrimitiveType_DetectionPostProcess}}; -AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, const string &inputInferType, const string &stdValues, +AwareQuantizer::AwareQuantizer(schema::MetaGraphT *graph, + const string &inputInferType, + const string &stdValues, const string &meanValues) : FbQuantizer(graph) { MS_ASSERT(graph != nullptr); @@ -110,9 +118,11 @@ STATUS AwareQuantizer::RemoveFakeQuant() { // MS_LOGE("GenerateDefaultQuantParam failed: %d", status); // return RET_ERROR; // } - // for (auto iter = subGraph->nodes.begin(); iter != subGraph->nodes.end(); iter++) { + // for (auto iter = subGraph->nodes.begin(); iter != subGraph->nodes.end(); + // iter++) { // auto *node = (*iter).get(); - // if (GetCNodeTType(*node) != OpT_FakeQuantWithMinMaxVars && GetCNodeTType(*node) != OpT_FakeQuantWithMinMax) { + // if (GetCNodeTType(*node) != OpT_FakeQuantWithMinMaxVars && + // GetCNodeTType(*node) != OpT_FakeQuantWithMinMax) { // continue; // } // auto inputIndexes = node->inputIndex; @@ -144,41 +154,43 @@ STATUS AwareQuantizer::RemoveFakeQuant() { // auto *maxData = reinterpret_cast(tensor2->data.data()); // MS_ASSERT(minData != nullptr); // MS_ASSERT(maxData != nullptr); - // std::unique_ptr quantParam(new (std::nothrow) QuantParamT()); - // if (quantParam == nullptr) { + // std::unique_ptr quantParam(new (std::nothrow) + // QuantParamT()); if (quantParam == nullptr) { // MS_LOGE("new quantParam failed"); // return RET_ERROR; // } // auto realMin = (double)minData[0]; // auto realMax = (double)maxData[0]; - // status = CalQuantizationParams(quantParam.get(), realMin, realMax, narrorRange, numBits); - // if (status != RET_OK) { - // MS_LOGE("in aware quantization run CalQuantizationParams failed, node: %s", node->name.c_str()); - // return RET_ERROR; + // status = CalQuantizationParams(quantParam.get(), realMin, realMax, + // narrorRange, numBits); if (status != RET_OK) { + // MS_LOGE("in aware quantization run CalQuantizationParams failed, + // node: %s", node->name.c_str()); return RET_ERROR; // } // if (tensor0->refCount == MSCONST_WEIGHT_REFCOUNT) { // CalFakeNode(tensor0, quantParam.get()); // } - // std::unique_ptr quantParamArray(new (std::nothrow) QuantParamArrayT()); - // if (quantParamArray == nullptr) { + // std::unique_ptr quantParamArray(new (std::nothrow) + // QuantParamArrayT()); if (quantParamArray == nullptr) { // MS_LOGE("new quantParamArray failed"); // return RET_ERROR; // } // quantParamArray->param.push_back(std::move(quantParam)); // auto quantParamArrayCopy = CopyQuantParamArrayT(quantParamArray); // if (quantParamArrayCopy == nullptr) { - // MS_LOGE("CopyQuantParamArray %s return nullptr", iter->get()->name.c_str()); - // return RET_ERROR; + // MS_LOGE("CopyQuantParamArray %s return nullptr", + // iter->get()->name.c_str()); return RET_ERROR; // } // node->quantParam.emplace_back(std::move(quantParamArrayCopy)); - // node->quantParam.emplace_back(nullptr); // secondInTensor and thirdInTensor are weightTensors who have no - // preNode node->quantParam.emplace_back(nullptr); node->quantParam.emplace_back(std::move(quantParamArray)); + // node->quantParam.emplace_back(nullptr); // secondInTensor and + // thirdInTensor are weightTensors who have no preNode + // node->quantParam.emplace_back(nullptr); + // node->quantParam.emplace_back(std::move(quantParamArray)); // // // BroadCast fakeQuantNode QuantParam // status = BroadCastQuantParam(subGraph, *iter); // if (status != RET_OK) { - // MS_LOGE("BroadCastQuantParam %s failed: %d", iter->get()->name.c_str(), status); - // return status; + // MS_LOGE("BroadCastQuantParam %s failed: %d", + // iter->get()->name.c_str(), status); return status; // } // // save post node index for SetAttrToConvolution // auto postNodeIdxes = GetOutputNodeIdx(*subGraph, *node); @@ -189,10 +201,13 @@ STATUS AwareQuantizer::RemoveFakeQuant() { // return RET_ERROR; // } // // set filter param to node - // if (tensor0->refCount == MSCONST_WEIGHT_REFCOUNT && !postNodeIdxes.empty()) { + // if (tensor0->refCount == MSCONST_WEIGHT_REFCOUNT && + // !postNodeIdxes.empty()) { // auto postNode = subGraph->nodes.at(postNodeIdxes.front()).get(); - // if (GetCNodeTType(*postNode) == OpT_Conv2D || GetCNodeTType(*postNode) == OpT_DepthwiseConv2D || - // GetCNodeTType(*postNode) == OpT_DeConv2D || GetCNodeTType(*postNode) == OpT_DeDepthwiseConv2D) { + // if (GetCNodeTType(*postNode) == OpT_Conv2D || + // GetCNodeTType(*postNode) == OpT_DepthwiseConv2D || + // GetCNodeTType(*postNode) == OpT_DeConv2D || + // GetCNodeTType(*postNode) == OpT_DeDepthwiseConv2D) { // auto status = SetAttrToConvolution(subGraph.get(), postNode); // if (status != RET_OK) { // MS_LOGE("in aware quant SetAttrToConvolution failed!"); @@ -203,7 +218,8 @@ STATUS AwareQuantizer::RemoveFakeQuant() { // } // // // remove IsolatedNode - // for (auto iter = subGraph->nodes.begin(); iter != subGraph->nodes.end();) { + // for (auto iter = subGraph->nodes.begin(); iter != + // subGraph->nodes.end();) { // if ((*iter)->inputIndex.empty() && (*iter)->outputIndex.empty()) { // iter = subGraph->nodes.erase(iter); // } else { @@ -213,8 +229,8 @@ STATUS AwareQuantizer::RemoveFakeQuant() { // // set graphInputNode inputTensor quantParams // MS_ASSERT(subGraph->inputIndex.size() == 1); // for (auto graphInputIndex : subGraph->inputIndex) { - // auto linkedPostIdx = GetLinkedPostIdx(*(subGraph.get()), graphInputIndex); - // for (auto nodeIdx : linkedPostIdx) { + // auto linkedPostIdx = GetLinkedPostIdx(*(subGraph.get()), + // graphInputIndex); for (auto nodeIdx : linkedPostIdx) { // MS_ASSERT(subGraph->nodes.size() > nodeIdx); // mInputArray->SetInputArrayQP(subGraph->nodes.at(nodeIdx).get()); // } @@ -223,7 +239,8 @@ STATUS AwareQuantizer::RemoveFakeQuant() { return RET_OK; } -STATUS AwareQuantizer::GenerateDefaultQuantParam(const schema::MetaGraphT *subGraph) { +STATUS AwareQuantizer::GenerateDefaultQuantParam( + const schema::MetaGraphT *subGraph) { MS_ASSERT(subGraph != nullptr); for (const auto &tensor : subGraph->allTensors) { if (!tensor->quantParams.empty()) { @@ -235,15 +252,18 @@ STATUS AwareQuantizer::GenerateDefaultQuantParam(const schema::MetaGraphT *subGr return RET_OK; } -STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph, schema::CNodeT *node) { +STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph, + schema::CNodeT *node) { // MS_ASSERT(subGraph != nullptr); // MS_ASSERT(node != nullptr); // auto inputIndexes = node->inputIndex; - // MS_ASSERT(GetCNodeTType(*node) == OpT_Conv2D || GetCNodeTType(*node) == OpT_DepthwiseConv2D || - // GetCNodeTType(*node) == OpT_DeConv2D || GetCNodeTType(*node) == OpT_DeDepthwiseConv2D); + // MS_ASSERT(GetCNodeTType(*node) == OpT_Conv2D || GetCNodeTType(*node) == + // OpT_DepthwiseConv2D || + // GetCNodeTType(*node) == OpT_DeConv2D || GetCNodeTType(*node) == + // OpT_DeDepthwiseConv2D); // if (inputIndexes.size() < 2) { - // MS_LOGE("in aware quant %s node's input tensors is invalid(%zu)!", node->name.c_str(), inputIndexes.size()); - // return RET_ERROR; + // MS_LOGE("in aware quant %s node's input tensors is invalid(%zu)!", + // node->name.c_str(), inputIndexes.size()); return RET_ERROR; // } // TensorDefT *filterTensor = subGraph->allTensors.at(inputIndexes[1]).get(); // MS_ASSERT(filterTensor != nullptr); @@ -267,14 +287,16 @@ STATUS AwareQuantizer::SetAttrToConvolution(const schema::MetaGraphT *subGraph, // if (GetCNodeTType(*node) == OpT_DepthwiseConv2D) { // if (node->fmkType == FmkType_MS) { // node->attr.AsDepthwiseConv2D()->channelIn = (int32_t)filterDims[0]; - // node->attr.AsDepthwiseConv2D()->channelMultiplier = (int32_t)filterDims[1]; - // node->attr.AsDepthwiseConv2D()->kernelH = (int32_t)filterDims[2]; - // node->attr.AsDepthwiseConv2D()->kernelW = (int32_t)filterDims[3]; + // node->attr.AsDepthwiseConv2D()->channelMultiplier = + // (int32_t)filterDims[1]; node->attr.AsDepthwiseConv2D()->kernelH = + // (int32_t)filterDims[2]; node->attr.AsDepthwiseConv2D()->kernelW = + // (int32_t)filterDims[3]; // } else if (node->fmkType == FmkType_TF) { // node->attr.AsDepthwiseConv2D()->kernelH = (int32_t)filterDims[0]; // node->attr.AsDepthwiseConv2D()->kernelW = (int32_t)filterDims[1]; // node->attr.AsDepthwiseConv2D()->channelIn = (int32_t)filterDims[2]; - // node->attr.AsDepthwiseConv2D()->channelMultiplier = (int32_t)filterDims[3]; + // node->attr.AsDepthwiseConv2D()->channelMultiplier = + // (int32_t)filterDims[3]; // } else { // MS_LOGE("Unsupport"); // } @@ -313,15 +335,19 @@ STATUS AwareQuantizer::GenerateQuantParam() { GetCNodeTType(*node) == schema::PrimitiveType_FakeQuantWithMinMaxVars) { MS_ASSERT(false); } - auto *quantParamCalcer = quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node)); + auto *quantParamCalcer = + quantParamRegister->GetQuantParamCalcer(GetCNodeTType(*node)); if (quantParamCalcer == nullptr) { - MS_LOG(ERROR) << "Can not find QuantParamCalcer for " << node->name.c_str() - << ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip"; + MS_LOG(ERROR) << "Can not find QuantParamCalcer for " + << node->name.c_str() + << ", type: " << GetCNodeTTypeName(*node).c_str() + << " set node to QuantNone and skip"; node->quantType = static_cast(QuantType_QUANT_NONE); } else { status = quantParamCalcer->Calc(graph, *node); if (status != RET_OK) { - MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str(); + MS_LOG(ERROR) << "quantParamCalcer failed: " << status + << " node: " << node->name.c_str(); node->quantType = schema::QuantType_QUANT_NONE; } else { node->quantType = schema::QuantType_AwareTraining; @@ -345,7 +371,8 @@ STATUS AwareQuantizer::DoQuantize() { GetCNodeTType(*node) == schema::PrimitiveType_DepthwiseConv2D) { auto inputIndexes = node->inputIndex; if (inputIndexes.size() < 2) { - MS_LOG(ERROR) << node->name.c_str() << " node input has invalid inputs tensor count"; + MS_LOG(ERROR) << node->name.c_str() + << " node input has invalid inputs tensor count"; return RET_ERROR; } // quant weight @@ -362,7 +389,8 @@ STATUS AwareQuantizer::DoQuantize() { return RET_ERROR; } } - } else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) { + } else if (GetCNodeTType(*node) == + schema::PrimitiveType_DetectionPostProcess) { status = QuantDetectionPostProcessConstTensor(graph, node.get()); if (status != RET_OK) { MS_LOG(ERROR) << "QuantDetectionPostProcessConstTensor failed!"; @@ -388,7 +416,8 @@ STATUS AwareQuantizer::DoQuantize() { return RET_OK; } -STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, schema::CNodeT *node) { +STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, + schema::CNodeT *node) { MS_ASSERT(graph != nullptr); MS_ASSERT(node != nullptr); for (size_t i = 0; i < node->inputIndex.size(); i++) { @@ -407,7 +436,8 @@ STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, sche void *inData = inTensor->data.data(); auto *castedInData = static_cast(inData); for (size_t j = 0; j < constTensorShapeSize; j++) { - qDatas[j] = QuantizeData(castedInData[j], quantParam.get()); + qDatas[j] = + QuantizeData(castedInData[j], quantParam.get()); } inTensor->data = std::move(qDatas); inTensor->dataType = kNumberTypeUInt8; @@ -423,14 +453,17 @@ STATUS AwareQuantizer::QuantAddConstTensor(const schema::MetaGraphT *graph, sche return RET_OK; } -STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(const schema::MetaGraphT *subGraph, schema::CNodeT *node) { +STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor( + const schema::MetaGraphT *subGraph, schema::CNodeT *node) { MS_ASSERT(subGraph != nullptr); MS_ASSERT(node != nullptr); auto &constTensor = subGraph->allTensors.at(node->inputIndex[2]); MS_ASSERT(constTensor != nullptr); - const auto *constData = reinterpret_cast(constTensor->data.data()); + const auto *constData = + reinterpret_cast(constTensor->data.data()); - if (constTensor->refCount == 999 && constTensor->dataType == TypeId::kNumberTypeFloat) { + if (constTensor->refCount == 999 && + constTensor->dataType == TypeId::kNumberTypeFloat) { size_t constTensorShapeSize = GetShapeSize(*constTensor); std::unique_ptr quantParam = GetTensorQuantParam(constTensor); if (quantParam == nullptr) { @@ -448,7 +481,8 @@ STATUS AwareQuantizer::QuantDetectionPostProcessConstTensor(const schema::MetaGr return RET_OK; } -STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph, mindspore::schema::CNodeT *node) { +STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph, + mindspore::schema::CNodeT *node) { MS_ASSERT(graph != nullptr); MS_ASSERT(node != nullptr); auto inputIndexes = node->inputIndex; @@ -507,7 +541,8 @@ STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph, biasTensor->dataType = TypeId::kNumberTypeInt32; biasTensor->data.clear(); biasTensor->data.resize(bShapeSize * sizeof(int32_t)); - auto ret = memcpy_s(biasTensor->data.data(), bShapeSize * sizeof(int32_t), qDatas, bShapeSize * sizeof(int32_t)); + auto ret = memcpy_s(biasTensor->data.data(), bShapeSize * sizeof(int32_t), + qDatas, bShapeSize * sizeof(int32_t)); if (ret != EOK) { // MS_LOGE("memcpy_s failed: %d", ret); delete[] qDatas; @@ -517,10 +552,12 @@ STATUS AwareQuantizer::QuantConvBias(const mindspore::schema::MetaGraphT *graph, return RET_OK; } -STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schema::CNodeT *node) { +STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, + schema::CNodeT *node) { MS_ASSERT(subGraph != nullptr); MS_ASSERT(node != nullptr); - MS_ASSERT(node->quantParam.size() == node->inputIndex.size() + node->outputIndex.size()); + MS_ASSERT(node->quantParam.size() == + node->inputIndex.size() + node->outputIndex.size()); auto inputIndexes = node->inputIndex; MS_ASSERT(inputIndexes.size() >= 2); MS_ASSERT(subGraph->allTensors.size() > inputIndexes.at(1)); @@ -528,8 +565,11 @@ STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schem if (weightTensor->dataType == TypeId::kNumberTypeInt8) { return RET_OK; } - if (weightTensor->dataType != TypeId::kNumberTypeFloat && weightTensor->dataType != TypeId::kNumberTypeUInt8) { - MS_LOG(ERROR) << "conv " << node->name.c_str() << "'s weight data is not float or uint8"; + if (weightTensor->dataType != TypeId::kNumberTypeFloat32 && + weightTensor->dataType != TypeId::kNumberTypeFloat && + weightTensor->dataType != TypeId::kNumberTypeUInt8) { + MS_LOG(ERROR) << "conv " << node->name.c_str() + << "'s weight data is not float or uint8"; return RET_ERROR; } size_t wShapeSize = GetShapeSize(*(weightTensor.get())); @@ -537,7 +577,8 @@ STATUS AwareQuantizer::QuantConvWeight(const schema::MetaGraphT *subGraph, schem MS_ASSERT(node->quantParam.at(1)->param.front() != nullptr); vector qDatas(wShapeSize); auto weightQauntParam = GetTensorQuantParam(weightTensor); - if (weightTensor->dataType == TypeId::kNumberTypeFloat) { // normal awareing quant + if (weightTensor->dataType == + TypeId::kNumberTypeFloat) { // normal awareing quant auto *weightData = static_cast(oriWeightData); for (size_t j = 0; j < wShapeSize; j++) { qDatas[j] = QuantizeData(weightData[j], weightQauntParam.get()); @@ -565,7 +606,8 @@ STATUS AwareQuantizer::DetermineNodeQuantType() { MS_ASSERT(graph->allTensors.size() > inTensorIdx); auto &inTensor = graph->allTensors.at(inTensorIdx); MS_ASSERT(inTensor != nullptr); - if (inTensor->quantParams.empty() || inTensor->quantParams.front() == nullptr || + if (inTensor->quantParams.empty() || + inTensor->quantParams.front() == nullptr || !inTensor->quantParams.front()->inited) { canQuant = false; break; @@ -577,7 +619,8 @@ STATUS AwareQuantizer::DetermineNodeQuantType() { MS_ASSERT(graph->allTensors.size() > outTensorIdx); auto &outTensor = graph->allTensors.at(outTensorIdx); MS_ASSERT(outTensor != nullptr); - if (outTensor->quantParams.empty() || outTensor->quantParams.front() == nullptr || + if (outTensor->quantParams.empty() || + outTensor->quantParams.front() == nullptr || !outTensor->quantParams.front()->inited) { canQuant = false; break;