diff --git a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc index ec708411a2..d9653a9e0d 100644 --- a/mindspore/lite/src/common/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/src/common/anf_exporter/anf_exporter.cc @@ -177,31 +177,44 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) { if (node->quantType == schema::QuantType_PostTraining) { MS_LOG(INFO) << "node: " << node->name << " add QuantParam"; // activation - auto activate_index = node->inputIndex[0]; - auto tensor_input = metaGraphT->allTensors[activate_index].get(); auto input_quant_params = primitiveT_value->GetInputQuantParams(); - if (input_quant_params.empty()) { - MS_LOG(WARNING) << "node: " << node->name - << " input quant params is empty"; - } else { + auto node_type = primitiveT_value->GetPrimitiveT()->value.type; + for (int i = 0; i < input_quant_params.size(); i++) { + if (i >= node->inputIndex.size()) { + MS_LOG(ERROR) << "node: " << node->name << " input has " << input_quant_params.size() + << " quant_params; but only " << node->inputIndex.size() << " input"; + break; + } + auto activate_index = node->inputIndex[i]; + auto tensor_input = metaGraphT->allTensors[activate_index].get(); std::unique_ptr input_quant_param = - std::make_unique(input_quant_params[0]); + std::make_unique(input_quant_params[i]); + MS_LOG(DEBUG) << "[input]node: " << node->name << " scale: " << input_quant_param->scale + << " zp: " << input_quant_param->zeroPoint; tensor_input->quantParams.emplace_back(std::move(input_quant_param)); + if (!(node_type == schema::PrimitiveType_QuantDTypeCast && + primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->srcT == kNumberTypeFloat32)) { + tensor_input->dataType = kNumberTypeInt8; + } } - tensor_input->dataType = kNumberTypeInt8; + // output auto output_index = node->outputIndex[0]; auto tensor_output = metaGraphT->allTensors[output_index].get(); auto output_quant_params = primitiveT_value->GetOutputQuantParams(); if (output_quant_params.empty()) { - MS_LOG(WARNING) << "node: " << node->name - << " output quant params is empty"; + MS_LOG(WARNING) << "node: " << node->name << " output quant params is empty"; } else { std::unique_ptr output_quant_param = - std::make_unique(output_quant_params[0]); + std::make_unique(output_quant_params[0]); + MS_LOG(DEBUG) << "[output]node: " << node->name << " scale: " << output_quant_param->scale + << " zp: " << output_quant_param->zeroPoint; tensor_output->quantParams.emplace_back(std::move(output_quant_param)); } - tensor_output->dataType = kNumberTypeInt8; + if (!(node_type == schema::PrimitiveType_QuantDTypeCast && + primitiveT_value->GetPrimitiveT()->value.AsQuantDTypeCast()->dstT == kNumberTypeFloat32)) { + tensor_output->dataType = kNumberTypeInt8; + } // // TensorType // valuePtr = primitive->GetAttr(kInputTensorDataType); // if (valuePtr != nullptr) { diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index aa402415e7..20bb8519d7 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -64,6 +64,16 @@ int LiteSession::ConvertTensors(const lite::Model *model) { // no copy data, do copy when call LiteKernel::Init dstTensor->SetData(const_cast(srcTensor->data()->data())); } + auto quant_params = srcTensor->quantParams(); + if (quant_params != nullptr) { + for (int j = 0; j < quant_params->size(); j++) { + tensor::QuantArg quant_arg{}; + quant_arg.scale = quant_params->Get(j)->scale(); + quant_arg.zeroPoint = quant_params->Get(j)->zeroPoint(); + dstTensor->AddQuantParam(quant_arg); + } + } + this->tensors.emplace_back(dstTensor); } return RET_OK; diff --git a/mindspore/lite/src/ops/quant_dtype_cast.cc b/mindspore/lite/src/ops/quant_dtype_cast.cc index 93855a89c4..82a9fa6548 100644 --- a/mindspore/lite/src/ops/quant_dtype_cast.cc +++ b/mindspore/lite/src/ops/quant_dtype_cast.cc @@ -30,6 +30,7 @@ int QuantDTypeCast::InferShape(std::vector inputs_, std::vecto auto param = primitive->value_as_QuantDTypeCast(); MS_ASSERT(input->data_type() == param->srcT); output->set_data_type(static_cast(param->dstT())); + output->SetFormat(input->GetFormat()); return RET_OK; } } // namespace mindspore::lite diff --git a/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc b/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc index 2ecf52ee12..0e49e78b4c 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/quant_dtype_cast.cc @@ -62,7 +62,7 @@ int QuantDTypeCastCPUKernel::Init() { } inverse_ = true; } else { - MS_LOG(ERROR) << "param data type not supported."; + MS_LOG(ERROR) << "param data type not supported:" << " src: " << param->srcT << " dst: " << param->dstT; return RET_ERROR; } @@ -148,7 +148,6 @@ kernel::LiteKernel *CpuQuantDTypeCastFp32KernelCreator(const std::vector 127) { + quant_values[i] = 127; + } else if (temp < -128) { + quant_values[i] = -128; + } else { + quant_values[i] = (int8_t)temp; + } } return NNACL_OK; } diff --git a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc index a79072281e..83db1e2b02 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/node/weight_format_pass.cc @@ -166,6 +166,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) { return -1; } } + MS_LOG(DEBUG) << "weight_tensor_format: " << weightTensor->format; return 0; } else if (fmkType == converter::FmkType_ONNX) { switch (node->quantType) { @@ -217,7 +218,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { auto opType = node->primitive->value.type; if (opType != schema::PrimitiveType_Conv2D && opType != schema::PrimitiveType_DepthwiseConv2D && opType != schema::PrimitiveType_DeConv2D && opType != schema::PrimitiveType_DeDepthwiseConv2D) { - return 0; + return RET_OK; } MS_ASSERT(node->inputIndex.size() >= 2); @@ -225,7 +226,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { MS_ASSERT(subGraph->allTensors.size() > weightIndex); auto &weightTensor = subGraph->allTensors[weightIndex]; MS_ASSERT(weightTensor->dataType == kNumberTypeInt8); // DataType_DT_FLOAT - STATUS status; + STATUS status = RET_OK; if (opType == schema::PrimitiveType_Conv2D) { // weight should be HWCK if (weightTensor->format == schema::Format_KCHW) { // from caffe if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { @@ -238,11 +239,12 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { status = TransFilterFormat(weightTensor.get(), kKCHW2HWCK); } } else if (weightTensor->format == schema::Format_KHWC) { // from onnx - if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { - status = TransFilterFormat(weightTensor.get(), kKHWC2HWCK); - } else { - status = TransFilterFormat(weightTensor.get(), kKHWC2HWCK); - } + return RET_OK; +// if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { +// status = TransFilterFormat(weightTensor.get(), kKHWC2HWCK); +// } else { +// status = TransFilterFormat(weightTensor.get(), kKHWC2HWCK); +// } } else if (weightTensor->format == schema::Format_HWCK) { // from tf return 0; } else { @@ -273,8 +275,8 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) { } else if (weightTensor->format == schema::Format_HWCK) { // from tf return 0; } else if (weightTensor->format == schema::Format_CHWK) { // from onnx - if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { - status = TransFilterFormat(weightTensor.get(), kCHWK2HWCK); + if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) { + status = TransFilterFormat(weightTensor.get(), kCHWK2KHWC); } else { status = TransFilterFormat(weightTensor.get(), kCHWK2HWCK); } diff --git a/mindspore/lite/tools/converter/quantizer/post_training.cc b/mindspore/lite/tools/converter/quantizer/post_training.cc index fe4e16bdaa..bf149ff9cf 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training.cc +++ b/mindspore/lite/tools/converter/quantizer/post_training.cc @@ -54,7 +54,7 @@ struct DivergInfo { size_t bit_num; int quant_max = 255; int quant_min = 0; - DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max = 255, int quant_min = 0) { + DivergInfo(CNodePtr cnode, int bins, size_t bits, int quant_max, int quant_min) { this->cnode = cnode; this->bin_num = bins; this->bit_num = bits; @@ -81,6 +81,9 @@ struct DivergInfo { STATUS UpdateHistogram(const std::vector &data, const std::vector &shape) { for (auto value : data) { + if (value == 0) { + continue; + } int bin_index = std::min(static_cast(std::fabs(value) / this->interval), bin_num - 1); this->histogram[bin_index]++; } @@ -470,8 +473,10 @@ STATUS Calibrator::ReadConfig() { Calibrator::Calibrator(string path, size_t bitNum, int quantMax, int quantMin) : config_path_(path), bit_num_(bitNum), quant_max_(quantMax), quant_min_(quantMin) {} -PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, int bit_num, TypeId target_type) +PostTrainingQuantizer::PostTrainingQuantizer(FuncGraphPtr graph, string path, int bit_num, TypeId target_type, + bool per_channel) : Quantizer(graph) { + this->per_channel_ = per_channel; this->bit_num = bit_num; this->target_type_ = target_type; if (target_type == kNumberTypeInt8) { @@ -533,7 +538,7 @@ STATUS PostTrainingQuantizer::DoWeightQuant(AnfNodePtr node) { } auto parameter = std::dynamic_pointer_cast(node); ParamValueLitePtr paramValue = std::dynamic_pointer_cast(parameter->default_param()); - auto status = QuantFilter(paramValue, QuantType_PostTraining, quant_max, quant_min, bit_num); + auto status = QuantFilter(paramValue, QuantType_PostTraining, quant_max, quant_min, bit_num, per_channel_); if (status != RET_OK) { MS_LOG(ERROR) << "QuantFilter failed: " << status; return status; @@ -670,18 +675,32 @@ STATUS PostTrainingQuantizer::QuantNode() { MS_LOG(ERROR) << "PrimitiveT_value is nullptr"; continue; } - if (input_scale.find(cnode) == input_scale.end()) { primitiveT_value->SetQuantType(schema::QuantType_QUANT_NONE); continue; } auto input_vec = cnode->inputs(); auto op_name = cnode->fullname_with_scope(); + auto op_type = primitiveT_value->GetPrimitiveT()->value.type; MS_LOG(INFO) << "OpName: " << op_name; - if (input_vec.size() <= 3 && op_name != "Conv2D" && op_name != "DepthwiseConv2D") { - MS_LOG(INFO) << "todo(x): "; - // int32_t qnodeOutputZeropoint = outputZeropoint[cnode]; - // p->AddAttr(kInputTensorDataType, MakeValue((int)targetType)); + if (op_type != PrimitiveType_Conv2D && op_type != PrimitiveType_DepthwiseConv2D) { + for (auto i = 1; i < cnode->inputs().size(); i++) { + auto input_node = cnode->input(i); + if (!input_node->isa()) { + MS_LOG(WARNING) << "node: " << cnode_name << " input " << i << " not a cnode"; + continue; + } + auto input_cnode = std::dynamic_pointer_cast(input_node); + auto input_cnode_primitiveT_value = GetValueNode>(input_cnode->input(0)); + if (input_cnode_primitiveT_value == nullptr) { + MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": " + << " PrimitiveTValue is null"; + continue; + } + for (auto &quant_param : input_cnode_primitiveT_value->GetOutputQuantParams()) { + primitiveT_value->AddInputQuantParam(quant_param); + } + } } else { // do input quant double scale = input_scale[cnode]; diff --git a/mindspore/lite/tools/converter/quantizer/post_training.h b/mindspore/lite/tools/converter/quantizer/post_training.h index d000df53ef..06273396b8 100644 --- a/mindspore/lite/tools/converter/quantizer/post_training.h +++ b/mindspore/lite/tools/converter/quantizer/post_training.h @@ -55,15 +55,18 @@ struct ConfigParam { class PostTrainingQuantizer : public Quantizer { public: - PostTrainingQuantizer(FuncGraphPtr graph, std::string path, int bit_num, TypeId target_type = kNumberTypeInt8); + PostTrainingQuantizer(FuncGraphPtr graph, std::string path, int bit_num, TypeId target_type = kNumberTypeInt8, + bool per_channel = false); STATUS DoQuantize(FuncGraphPtr funcGraph) override; size_t bit_num; - int quant_max{255}; - int quant_min{0}; + int quant_max{127}; + int quant_min{-128}; private: + bool per_channel_; + TypeId target_type_{kNumberTypeInt8}; std::unique_ptr calibrator_; diff --git a/mindspore/lite/tools/converter/quantizer/quant_cast.cc b/mindspore/lite/tools/converter/quantizer/quant_cast.cc index 0d4857c182..f924cda6ef 100644 --- a/mindspore/lite/tools/converter/quantizer/quant_cast.cc +++ b/mindspore/lite/tools/converter/quantizer/quant_cast.cc @@ -25,10 +25,11 @@ namespace mindspore::lite::quant { ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type, const std::vector &quant_params) { std::unique_ptr primitive = std::make_unique(); schema::QuantDTypeCastT quant_dtype_cast; - quant_dtype_cast.srcT = src_type; // kNumberTypeUInt8; + quant_dtype_cast.srcT = src_type; // kNumberTypeInt8; quant_dtype_cast.dstT = dst_type; // kNumberTypeFloat32; primitive->value.Set(quant_dtype_cast); auto primTValue = std::make_shared(primitive.release()); + primTValue->SetQuantType(schema::QuantType_PostTraining); for (auto &quant_param : quant_params) { primTValue->AddInputQuantParam(quant_param); } @@ -52,7 +53,7 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { if (first) { if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32) { auto value_node = - NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8, primitiveT_value->GetInputQuantParams()); + NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, primitiveT_value->GetInputQuantParams()); std::vector op_inputs = {value_node, cnode->input(1)}; auto quant_cast_cnode = graph->NewCNode(op_inputs); quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast"); @@ -82,11 +83,11 @@ STATUS QuantCast::Run(FuncGraphPtr graph) { ValueNodePtr value_node = nullptr; if (curnode_quant_type == schema::QuantType_PostTraining && input_cnode_quant_type == schema::QuantType_QUANT_NONE) { - value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8, - input_cnode_primitiveT_value->GetInputQuantParams()); + value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeInt8, + primitiveT_value->GetInputQuantParams()); } else if (curnode_quant_type == schema::QuantType_QUANT_NONE && input_cnode_quant_type == schema::QuantType_PostTraining) { - value_node = NewQuantCastValueNode(kNumberTypeUInt8, kNumberTypeFloat32, + value_node = NewQuantCastValueNode(kNumberTypeInt8, kNumberTypeFloat32, input_cnode_primitiveT_value->GetInputQuantParams()); } if (value_node == nullptr) { diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc index 0e0ba14f93..151bdacd52 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -98,7 +98,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const { static const std::vector uint8OpList = { schema::PrimitiveType_Nchw2Nhwc, schema::PrimitiveType_Nhwc2Nchw, schema::PrimitiveType_Conv2D, schema::PrimitiveType_DepthwiseConv2D, schema::PrimitiveType_Add, schema::PrimitiveType_Pooling, - schema::PrimitiveType_Concat, schema::PrimitiveType_SoftMax, schema::PrimitiveType_Reshape, + schema::PrimitiveType_Concat, /*schema::PrimitiveType_SoftMax,*/ schema::PrimitiveType_Reshape, schema::PrimitiveType_Activation}; return IsContain(uint8OpList, type); } @@ -242,64 +242,122 @@ STATUS CalQuantizationParams(std::unique_ptr &quantParam, double return RET_OK; } -STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum) { +STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, size_t bitNum, + bool per_channel) { + if (per_channel) { + // per channel auto dims = weightPtr->tensor_shape(); if (dims.size() < 1) { - MS_LOG(ERROR) << "weight dims size error"; - return RET_ERROR; + MS_LOG(ERROR) << "weight dims size error"; + return RET_ERROR; } - uint32_t channels = dims[0]; + // todo(x) + uint32_t channels = dims[3]; if (channels == 0) { - MS_LOG(ERROR) << "channels error 0"; - return RET_ERROR; + MS_LOG(ERROR) << "channels error 0"; + return RET_ERROR; } size_t shapeSize = weightPtr->tensor_shape_size(); size_t oneFilterSize = shapeSize / channels; auto *rawDatas = reinterpret_cast(weightPtr->tensor_addr()); if (rawDatas == nullptr) { - MS_LOG(ERROR) << "rawDatas is nullptr"; - return RET_ERROR; + MS_LOG(ERROR) << "rawDatas is nullptr"; + return RET_ERROR; } weightPtr->quant_param().clear(); - vector qDatas(shapeSize); + vector qDatas(shapeSize); for (uint32_t i = 0; i < channels; i++) { - float min = 0; - float max = 0; - // find min and max - for (uint32_t j = 0; j < oneFilterSize; j++) { - min = std::min(min, rawDatas[j + i * oneFilterSize]); - max = std::max(max, rawDatas[j + i * oneFilterSize]); - } + float min = 0; + float max = 0; + // find min and max + for (uint32_t j = 0; j < oneFilterSize; j++) { + min = std::min(min, rawDatas[j + i * oneFilterSize]); + max = std::max(max, rawDatas[j + i * oneFilterSize]); + } + + std::unique_ptr quantParam = std::unique_ptr(new AnfQuantParam); + STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum); + if (status != RET_OK) { + MS_LOG(ERROR) << "CalQuantizationParams failed" << status; + return status; + } + // update data and datatype + for (uint32_t j = 0; j < oneFilterSize; j++) { + float rawData = rawDatas[j + i * oneFilterSize]; + auto qData = QuantizeData(rawData, quantParam.get(), quant_max, quant_min); + qDatas[j + i * oneFilterSize] = qData; + } + + weightPtr->set_quant_param(quantParam); + } + auto ret = memcpy_s(const_cast(rawDatas), weightPtr->tensor_size(), + qDatas.data(), shapeSize * sizeof(int8_t)); + if (ret != EOK) { + MS_LOG(ERROR) << "memcpy error: " << ret; + return RET_ERROR; + } + if (quantType == QuantType_WeightQuant) { + PostBitPack(const_cast(rawDatas), shapeSize, bitNum); + } - std::unique_ptr quantParam = std::unique_ptr(new AnfQuantParam); - STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum); - if (status != RET_OK) { - MS_LOG(ERROR) << "CalQuantizationParams failed" << status; - return status; - } - // update data and datatype - for (uint32_t j = 0; j < oneFilterSize; j++) { - float rawData = rawDatas[j + i * oneFilterSize]; - auto qData = QuantizeData(rawData, quantParam.get()); - qDatas[j + i * oneFilterSize] = qData; - } + weightPtr->set_tensor_type(kNumberTypeInt8); + weightPtr->set_tensor_size(shapeSize * sizeof(int8_t)); + } else { + // per layer + size_t shapeSize = weightPtr->tensor_shape_size(); + auto *rawDatas = static_cast(weightPtr->tensor_addr()); + if (rawDatas == nullptr) { + MS_LOG(ERROR) << "rawDatas is nullptr"; + return RET_ERROR; + } - weightPtr->set_quant_param(quantParam); + weightPtr->quant_param().clear(); + vector qDatas(shapeSize); + + float min = 0; + float max = 0; + for (uint32_t i = 0; i < shapeSize; i++) { + // find max min + min = std::min(min, rawDatas[i]); + max = std::max(max, rawDatas[i]); } - auto ret = memcpy_s(const_cast(rawDatas), weightPtr->tensor_size(), - qDatas.data(), shapeSize * sizeof(uint8_t)); + + std::unique_ptr quantParam = std::unique_ptr(new AnfQuantParam); + STATUS status = CalQuantizationParams(quantParam, min, max, false, quant_max, quant_min, bitNum); + if (status != RET_OK) { + MS_LOG(ERROR) << "CalQuantizationParams failed" << status; + return status; + } + // update data and datatype + for (uint32_t i = 0; i < shapeSize; i++) { + float rawData = rawDatas[i]; + auto quant_data = std::round(rawData / quantParam->scale + quantParam->zeroPoint); + if (quant_data > quant_max) { + qDatas[i] = quant_max; + } else if (quant_data < quant_min) { + qDatas[i] = quant_min; + } else { + qDatas[i] = static_cast(quant_data); + } + } + + weightPtr->set_quant_param(quantParam); + auto ret = memcpy_s(rawDatas, weightPtr->tensor_size() * sizeof(int8_t), + qDatas.data(), shapeSize * sizeof(int8_t)); if (ret != EOK) { - MS_LOG(ERROR) << "memcpy error: " << ret; - return RET_ERROR; + MS_LOG(ERROR) << "memcpy error: " << ret; + return RET_ERROR; } if (quantType == QuantType_WeightQuant) { - PostBitPack(const_cast(rawDatas), shapeSize, bitNum); + PostBitPack(rawDatas, shapeSize, bitNum); } weightPtr->set_tensor_type(kNumberTypeInt8); weightPtr->set_tensor_size(shapeSize * sizeof(int8_t)); + } + return RET_OK; } diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h index b310d586ae..a287473458 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.h +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -63,41 +63,30 @@ STATUS CalQuantizationParams(std::unique_ptr &quantParam, double bool narrowRange, int quant_max, int quant_min, int num_bits); template -T QuantizeData(const float originData, const AnfQuantParam *quantParam) { +T QuantizeData(float originData, const AnfQuantParam *quantParam, int quant_max, int quant_min) { MS_ASSERT(quantParam != nullptr); MS_ASSERT(quantParam->inited); const auto scale = quantParam->scale; - const auto zeroPoint = quantParam->zeroPoint; - const auto numBit = quantParam->numBits; + const int zeroPoint = quantParam->zeroPoint; const auto narrowRange = quantParam->narrowRange; - const double maxLimit = static_cast((1 << (unsigned int)numBit) - 1 - zeroPoint) * scale; - double minLimit; - if (narrowRange) { - minLimit = static_cast(1 - zeroPoint) * scale; - } else { - minLimit = static_cast(0 - zeroPoint) * scale; - } + const int maxLimit = quant_max; + const int minLimit = quant_min; + return [maxLimit, minLimit, zeroPoint, scale, narrowRange, originData] { - double tmp = 0.0f; - if (originData > maxLimit) { - tmp = maxLimit; - } else if (originData < minLimit) { - tmp = minLimit; - } else { - tmp = originData; - } - auto quantData = static_cast(std::round(tmp / scale + zeroPoint)); - if (quantData == 0 && narrowRange) { - quantData++; + int quant_data = std::round(originData / scale + zeroPoint); + if (quant_data > maxLimit) { + quant_data = maxLimit; + } else if (quant_data < minLimit) { + quant_data = minLimit; } - return quantData; + return static_cast(quant_data); }(); } void CalFakeNode(const AnfNodePtr &inTensor); STATUS QuantFilter(ParamValueLitePtr &weightPtr, QuantType quantType, int quant_max, int quant_min, - size_t bitNum = UINT8_QUANTIZATION); + size_t bitNum = UINT8_QUANTIZATION, bool per_channel = false); STATUS PostBitPack(float *weights, size_t shapeSize, size_t bitNum = UINT8_QUANTIZATION);