From a1fea80b074d739ad08d8765f8dc1817d8c590e2 Mon Sep 17 00:00:00 2001 From: yankai Date: Tue, 22 Sep 2020 15:47:46 +0800 Subject: [PATCH] fix quant --- mindspore/lite/src/ops/add.cc | 6 +++ mindspore/lite/src/ops/conv2d.cc | 2 +- mindspore/lite/src/ops/deconv2d.cc | 2 +- mindspore/lite/src/ops/depthwise_conv2d.cc | 2 +- mindspore/lite/src/ops/matmul.cc | 2 +- mindspore/lite/src/ops/primitive_c.cc | 33 +++++++----- mindspore/lite/src/ops/primitive_c.h | 7 +-- .../lite/tools/anf_exporter/anf_exporter.cc | 54 +++++++++---------- mindspore/lite/tools/common/node_util.cc | 32 +++++------ .../graph/dtype_trans_pass.cc | 7 ++- .../converter/quantizer/calc_quant_param.cc | 11 ++-- 11 files changed, 89 insertions(+), 69 deletions(-) diff --git a/mindspore/lite/src/ops/add.cc b/mindspore/lite/src/ops/add.cc index d01d1e16e3..251ebf4df7 100644 --- a/mindspore/lite/src/ops/add.cc +++ b/mindspore/lite/src/ops/add.cc @@ -46,6 +46,12 @@ int Add::UnPackAttr(const Primitive &prim, const std::vector &inputs return RET_ERROR; } } + if (GetQuantType() == schema::QuantType_AwareTraining) { + std::vector> vecInputQuantParam; + std::vector> vecOutputQuantParam; + PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); + SetOutputQuantParam(vecOutputQuantParam); + } return RET_OK; } diff --git a/mindspore/lite/src/ops/conv2d.cc b/mindspore/lite/src/ops/conv2d.cc index 1dccb681a3..3816ed700a 100644 --- a/mindspore/lite/src/ops/conv2d.cc +++ b/mindspore/lite/src/ops/conv2d.cc @@ -260,7 +260,7 @@ int Conv2D::UnPackAttr(const Primitive &prim, const std::vector &inp if (GetQuantType() == schema::QuantType_AwareTraining) { std::vector> vecInputQuantParam; std::vector> vecOutputQuantParam; - PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam); + PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); SetInputQuantParam(vecInputQuantParam); SetOutputQuantParam(vecOutputQuantParam); } diff --git a/mindspore/lite/src/ops/deconv2d.cc b/mindspore/lite/src/ops/deconv2d.cc index 5dba9d21a1..369c23be60 100644 --- a/mindspore/lite/src/ops/deconv2d.cc +++ b/mindspore/lite/src/ops/deconv2d.cc @@ -130,7 +130,7 @@ int DeConv2D::UnPackAttr(const Primitive &prim, const std::vector &i if (GetQuantType() == schema::QuantType_AwareTraining) { std::vector> vecInputQuantParam; std::vector> vecOutputQuantParam; - PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam); + PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); SetInputQuantParam(vecInputQuantParam); SetOutputQuantParam(vecOutputQuantParam); } diff --git a/mindspore/lite/src/ops/depthwise_conv2d.cc b/mindspore/lite/src/ops/depthwise_conv2d.cc index e5864bb24f..820ebb10c5 100644 --- a/mindspore/lite/src/ops/depthwise_conv2d.cc +++ b/mindspore/lite/src/ops/depthwise_conv2d.cc @@ -140,7 +140,7 @@ int DepthwiseConv2D::UnPackAttr(const Primitive &prim, const std::vector> vecInputQuantParam; std::vector> vecOutputQuantParam; - PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam); + PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); SetInputQuantParam(vecInputQuantParam); SetOutputQuantParam(vecOutputQuantParam); } diff --git a/mindspore/lite/src/ops/matmul.cc b/mindspore/lite/src/ops/matmul.cc index fc6f01e043..62b40f15ea 100644 --- a/mindspore/lite/src/ops/matmul.cc +++ b/mindspore/lite/src/ops/matmul.cc @@ -60,7 +60,7 @@ int MatMul::UnPackAttr(const Primitive &prim, const std::vector &inp if (GetQuantType() == schema::QuantType_AwareTraining) { std::vector> vecInputQuantParam; std::vector> vecOutputQuantParam; - PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam); + PopulaterQuantParam(prim, &vecInputQuantParam, &vecOutputQuantParam, inputs); SetInputQuantParam(vecInputQuantParam); SetOutputQuantParam(vecOutputQuantParam); } diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 7b4f8ce457..135409e2b8 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -158,7 +158,8 @@ void PrimitiveC::CalQuantParam(const double &mean, const double &stdDev, float * void PrimitiveC::PopulaterQuantParam(const Primitive &prim, std::vector> *vecInputQuantParam, - std::vector> *vecOutputQuantParam) { + std::vector> *vecOutputQuantParam, + const std::vector &inputs) { auto narrow_range = prim.GetAttr("narrow_range"); bool narrowRangeQuantParam = GetValue(narrow_range); auto num_bits = prim.GetAttr("num_bits"); @@ -179,12 +180,14 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim, } else { auto inputMin = prim.GetAttr("input_minq"); auto inputMax = prim.GetAttr("input_maxq"); - auto inputMinPtr = inputMin->cast(); - auto inputMaxPtr = inputMax->cast(); - float *minBuf = static_cast(inputMinPtr->data_c()); - float *maxBuf = static_cast(inputMaxPtr->data_c()); - quantParam.min = *minBuf; - quantParam.max = *maxBuf; + if (inputMin != nullptr && inputMax != nullptr) { + auto inputMinPtr = inputMin->cast(); + auto inputMaxPtr = inputMax->cast(); + float *minBuf = static_cast(inputMinPtr->data_c()); + float *maxBuf = static_cast(inputMaxPtr->data_c()); + quantParam.min = *minBuf; + quantParam.max = *maxBuf; + } } quant::CalQuantizationParams(&quantParam, quantParam.min, quantParam.max, narrowRangeQuantParam, numbitsRangeQuantParam); @@ -212,13 +215,15 @@ void PrimitiveC::PopulaterQuantParam(const Primitive &prim, vecInputQuantParam->emplace_back(quants); } - quants.clear(); - quantParam.min = 0.0; - quantParam.max = 0.0; - quantParam.zeroPoint = 0; - quantParam.scale = vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(0).scale; - quants.emplace_back(quantParam); - vecInputQuantParam->emplace_back(quants); + if (vecInputQuantParam->size() == kDoubleNum) { + quants.clear(); + quantParam.min = 0.0; + quantParam.max = 0.0; + quantParam.zeroPoint = 0; + quantParam.scale = vecInputQuantParam->at(0).at(0).scale * vecInputQuantParam->at(1).at(0).scale; + quants.emplace_back(quantParam); + vecInputQuantParam->emplace_back(quants); + } quants.clear(); auto outputMin = prim.GetAttr("output_minq"); diff --git a/mindspore/lite/src/ops/primitive_c.h b/mindspore/lite/src/ops/primitive_c.h index cda2659d04..526300eedd 100644 --- a/mindspore/lite/src/ops/primitive_c.h +++ b/mindspore/lite/src/ops/primitive_c.h @@ -39,8 +39,8 @@ constexpr uint32_t kDoubleNum = 2; constexpr uint32_t kMultiNum = 3; constexpr uint32_t kDimension_4d = 4; -const std::set kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt8, kNumberTypeInt32, - kNumberTypeFloat32, kNumberTypeFloat16}; +const std::set kSupportDataType = {kNumberTypeUInt8, kNumberTypeInt8, kNumberTypeInt32, kNumberTypeFloat32, + kNumberTypeFloat16}; #ifdef PRIMITIVE_WRITEABLE using TensorPtr = std::shared_ptr; @@ -119,7 +119,8 @@ class PrimitiveC : public mindspore::Primitive { static std::shared_ptr Create(const Primitive &prim, const std::vector &inputs, const schema::QuantType &quantType); void PopulaterQuantParam(const Primitive &prim, std::vector> *vecInputQuantParam, - std::vector> *vecOutputQuantParam); + std::vector> *vecOutputQuantParam, + const std::vector &inputs); void CalQuantParam(const double &mean, const double &stdDev, float *mMin, float *mMax); protected: diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index ba7829630c..877b2d5868 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -98,29 +98,28 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr &me // activation auto input_quant_params = primitive->GetInputQuantParams(); auto node_type = (schema::PrimitiveType)primitive->Type(); - if (input_quant_params.empty()) { - MS_LOG(DEBUG) << "node: " << dst_node->name << " input quant params is empty"; - return RET_OK; - } - for (size_t i = 0; i < input_quant_params.size(); i++) { - if (i >= dst_node->inputIndex.size()) { - MS_LOG(ERROR) << "node: " << dst_node->name << " input has " << input_quant_params.size() - << " quant_params; but only " << dst_node->inputIndex.size() << " input"; - return RET_PARAM_INVALID; - } - auto activate_index = dst_node->inputIndex[i]; - auto tensor_input = meta_graph->allTensors[activate_index].get(); - if (tensor_input->quantParams.empty()) { - for (auto input_quant_param : input_quant_params[i]) { - std::unique_ptr input_quant_param_ptr = - std::make_unique(input_quant_param); - MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale - << " zp: " << input_quant_param_ptr->zeroPoint; - tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr)); + if (!input_quant_params.empty()) { + for (size_t i = 0; i < input_quant_params.size(); i++) { + if (i >= dst_node->inputIndex.size()) { + MS_LOG(ERROR) << "node: " << dst_node->name << " input has " << input_quant_params.size() + << " quant_params; but only " << dst_node->inputIndex.size() << " input"; + return RET_PARAM_INVALID; + } + auto activate_index = dst_node->inputIndex[i]; + auto tensor_input = meta_graph->allTensors[activate_index].get(); + if (tensor_input->quantParams.empty()) { + for (auto input_quant_param : input_quant_params[i]) { + std::unique_ptr input_quant_param_ptr = + std::make_unique(input_quant_param); + MS_LOG(DEBUG) << "[input][" << i << "]node: " << dst_node->name << " scale: " << input_quant_param_ptr->scale + << " zp: " << input_quant_param_ptr->zeroPoint; + tensor_input->quantParams.emplace_back(std::move(input_quant_param_ptr)); + } } } + } else { + MS_LOG(DEBUG) << "node: " << dst_node->name << " input quant params is empty"; } - // output auto output_index = dst_node->outputIndex[0]; auto tensor_output = meta_graph->allTensors[output_index].get(); @@ -171,7 +170,7 @@ void AnfExporter::SetGraphInputIndex(const std::unique_ptr & } int AnfExporter::SetGraphoutputIndex(const CNodePtr &cnode, const std::unique_ptr &meta_graphT, - schema::CNodeT *return_node) { + schema::CNodeT *return_node) { MS_ASSERT(nullptr != meta_graph); MS_ASSERT(nullptr != return_node); for (size_t i = 1; i < cnode->inputs().size(); i++) { @@ -210,9 +209,9 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &func_graph, bool kee if (primitive_c->Type() == schema::PrimitiveType_TupleGetItem || primitive_c->Type() == schema::PrimitiveType_MakeTuple #ifdef SUPPORT_TRAIN - || primitive_c->Type() == schema::PrimitiveType_Depend + || primitive_c->Type() == schema::PrimitiveType_Depend #endif - ) { + ) { continue; } RemoveIfMakeTuple(cnode); @@ -403,8 +402,7 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr input_anode, if (value_track->isa()) { shape.push_back((GetValue(value_track))); } else { - MS_LOG(ERROR) << "Value type is ValueSequence is not integer, it is " - << value_track->ToString() << "."; + MS_LOG(ERROR) << "Value type is ValueSequence is not integer, it is " << value_track->ToString() << "."; } } if (shape.size()) { @@ -417,10 +415,10 @@ int AnfExporter::ConvertInputValueNode(std::shared_ptr input_anode, node_id_map_[valueNode->fullname_with_scope()] = meta_graphT->allTensors.size(); output_cnode->inputIndex.emplace_back(meta_graphT->allTensors.size()); meta_graphT->allTensors.emplace_back(std::move(paramTensor)); - } - } else { - MS_LOG(ERROR) << "Value type is ValueSequence not supported - " << valueAbstract->type_name() << "."; } + } else { + MS_LOG(ERROR) << "Value type is ValueSequence not supported - " << valueAbstract->type_name() << "."; + } #endif } else if (value->isa()) { MS_LOG(INFO) << "Value is a number."; diff --git a/mindspore/lite/tools/common/node_util.cc b/mindspore/lite/tools/common/node_util.cc index 402dbea35b..38bc85d7b6 100644 --- a/mindspore/lite/tools/common/node_util.cc +++ b/mindspore/lite/tools/common/node_util.cc @@ -54,8 +54,7 @@ static const std::vector nhwcOpDualInputList = { static const std::vector nhwcOpAllInputList = { #ifdef SUPPORT_TRAIN - schema::PrimitiveType_PoolingGrad, - schema::PrimitiveType_ActivationGrad + schema::PrimitiveType_PoolingGrad, schema::PrimitiveType_ActivationGrad #endif }; @@ -66,20 +65,21 @@ static const std::vector fp32FullOpList = { static const std::vector int8NeedNhwcOpList = {}; static const std::vector int8OpList = { - schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, - schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D, - schema::PrimitiveType_Add, schema::PrimitiveType_Pooling, - schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax, - schema::PrimitiveType_Reshape, schema::PrimitiveType_Activation, - schema::PrimitiveType_Resize, schema::PrimitiveType_FullConnection, - schema::PrimitiveType_ArgMax, schema::PrimitiveType_ArgMin, - schema::PrimitiveType_BatchNorm, schema::PrimitiveType_FusedBatchNorm, - schema::PrimitiveType_BiasAdd, schema::PrimitiveType_Div, - schema::PrimitiveType_Mul, schema::PrimitiveType_Slice, - schema::PrimitiveType_SoftMax, schema::PrimitiveType_Split, - schema::PrimitiveType_Squeeze, schema::PrimitiveType_Sub, - schema::PrimitiveType_TopK, schema::PrimitiveType_Unsqueeze, - schema::PrimitiveType_MatMul, schema::PrimitiveType_Pad}; + schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, + schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D, + schema::PrimitiveType_Add, schema::PrimitiveType_Pooling, + schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax, + schema::PrimitiveType_Reshape, schema::PrimitiveType_Activation, + schema::PrimitiveType_Resize, schema::PrimitiveType_FullConnection, + schema::PrimitiveType_ArgMax, schema::PrimitiveType_ArgMin, + schema::PrimitiveType_BatchNorm, schema::PrimitiveType_FusedBatchNorm, + schema::PrimitiveType_BiasAdd, schema::PrimitiveType_Div, + schema::PrimitiveType_Mul, schema::PrimitiveType_Slice, + schema::PrimitiveType_SoftMax, schema::PrimitiveType_Split, + schema::PrimitiveType_Squeeze, schema::PrimitiveType_Sub, + schema::PrimitiveType_StridedSlice, schema::PrimitiveType_TopK, + schema::PrimitiveType_Unsqueeze, schema::PrimitiveType_MatMul, + schema::PrimitiveType_Pad}; static const std::vector needInsertOpList = { schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat, diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc index f738512c98..b84d57d6df 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc @@ -16,6 +16,7 @@ #include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h" #include +#include #include "tools/common/converter_op_utils.h" #include "tools/common/node_util.h" #include "src/common/common.h" @@ -26,6 +27,9 @@ namespace lite { #define kMinInputNum 1 #define kOutputNum 1 +static const std::set NoNeedDtypeTransList = { + PrimitiveType_QuantDTypeCast, PrimitiveType_Nchw2Nhwc, PrimitiveType_Nhwc2Nchw}; + STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) { MS_ASSERT(graph != nullptr); @@ -134,7 +138,8 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { if (IsContain(GetInt8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) { continue; } - if (GetCNodeTType(**iter) == PrimitiveType_QuantDTypeCast) { + auto iterType = GetCNodeTType(**iter); + if (NoNeedDtypeTransList.find(iterType) != NoNeedDtypeTransList.end()) { continue; } bool needInsertPost = true; diff --git a/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc b/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc index 1c85cc5788..8440053ba0 100644 --- a/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc +++ b/mindspore/lite/tools/converter/quantizer/calc_quant_param.cc @@ -167,7 +167,11 @@ int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) { auto &outTensor = graph->allTensors.at(node.outputIndex.at(i)); MS_ASSERT(outTensor != nullptr); auto outQuantParam = GetTensorQuantParam(outTensor); - if (outQuantParam == nullptr || outQuantParam->inited) { + if (outQuantParam == nullptr) { + outTensor->quantParams.emplace_back(std::move(inQuantParam)); + continue; + } + if (outQuantParam->inited) { continue; } outTensor->quantParams.front() = std::move(inQuantParam); @@ -232,7 +236,7 @@ class CalcConcat : public QuantParamCalcer { MS_LOG(WARNING) << "in aware quantization run CalQuantizationParams failed!"; return RET_ERROR; } - outTensor->quantParams.front() = std::move(outQuantParam); + outTensor->quantParams.emplace_back(std::move(outQuantParam)); outputParamDone++; } @@ -417,7 +421,7 @@ class CalcToSet : public QuantParamCalcer { MS_ASSERT(graph->allTensors.size() > node.outputIndex.front()); auto &outTensor = graph->allTensors.at(node.outputIndex.front()); MS_ASSERT(outTensor != nullptr); - outTensor->quantParams.front() = std::move(quantParam); + outTensor->quantParams.emplace_back(std::move(quantParam)); outputParamDone++; } return RET_OK; @@ -475,6 +479,7 @@ QuantParamCalcRegister::QuantParamCalcRegister() { _registerMap[schema::PrimitiveType_Pooling] = linearCalcer; _registerMap[schema::PrimitiveType_Resize] = linearCalcer; _registerMap[schema::PrimitiveType_Reshape] = linearCalcer; + _registerMap[schema::PrimitiveType_StridedSlice] = linearCalcer; _registerMap[schema::PrimitiveType_Shape] = linearCalcer; _registerMap[schema::PrimitiveType_SoftMax] = std::make_shared(0, 1); _registerMap[schema::PrimitiveType_Squeeze] = linearCalcer;