From 89dbcab9f25fb17c9ae172907d35c610bf15aeaf Mon Sep 17 00:00:00 2001 From: guohongzilong <2713219276@qq.com> Date: Mon, 2 Nov 2020 17:21:23 +0800 Subject: [PATCH] mix bit pack and unpack --- mindspore/lite/src/lite_session.cc | 16 +++- .../src/runtime/kernel/arm/base/dequant.cc | 15 ++++ .../src/runtime/kernel/arm/base/dequant.h | 60 +++++++++++++ mindspore/lite/src/tensor.h | 1 + .../tools/converter/quantizer/CMakeLists.txt | 1 - .../tools/converter/quantizer/bitpacking.h | 74 ++++++++++++++++ .../converter/quantizer/general_bitpacking.cc | 84 ------------------- .../converter/quantizer/general_bitpacking.h | 43 ---------- .../converter/quantizer/quantize_util.cc | 26 +----- .../tools/converter/quantizer/quantize_util.h | 31 ++++++- .../converter/quantizer/weight_quantizer.cc | 36 ++++---- 11 files changed, 209 insertions(+), 178 deletions(-) create mode 100644 mindspore/lite/tools/converter/quantizer/bitpacking.h delete mode 100644 mindspore/lite/tools/converter/quantizer/general_bitpacking.cc delete mode 100644 mindspore/lite/tools/converter/quantizer/general_bitpacking.h diff --git a/mindspore/lite/src/lite_session.cc b/mindspore/lite/src/lite_session.cc index 5658efd3cb..105dc5d52b 100644 --- a/mindspore/lite/src/lite_session.cc +++ b/mindspore/lite/src/lite_session.cc @@ -27,6 +27,7 @@ #include "src/common/graph_util.h" #include "src/kernel_registry.h" #include "src/model_common.h" +#include "mindspore/lite/src/runtime/kernel/arm/base/dequant.h" namespace mindspore { namespace lite { @@ -95,13 +96,26 @@ int LiteSession::ConvertTensors(const lite::Model *model) { memcpy(dst_data, srcTensor->data()->data(), dstTensor->Size()); copyed_tensor_idxes_.emplace_back(i); } else { - dstTensor->set_data(const_cast(srcTensor->data()->data())); + int pack_size = srcTensor->data()->size(); + int org_size = dstTensor->Size(); + if (pack_size != org_size && (dataType == kNumberTypeInt8 || dataType == kNumberTypeInt16)) { + auto ret = dstTensor->MallocData(); + if (ret != RET_OK) { + MS_LOG(ERROR) << "Malloc data for " << i << "tensor failed "; + delete dstTensor; + return RET_ERROR; + } + kernel::DequantUtil::UnPackToInt(srcTensor, dstTensor->MutableData()); + } else { + dstTensor->set_data(const_cast(srcTensor->data()->data())); + } } } auto quant_params = srcTensor->quantParams(); if (quant_params != nullptr) { for (size_t j = 0; j < quant_params->size(); j++) { QuantArg quant_arg{}; + quant_arg.bitNum = quant_params->Get(j)->numBits(); quant_arg.scale = quant_params->Get(j)->scale(); quant_arg.zeroPoint = quant_params->Get(j)->zeroPoint(); quant_arg.var_corr = quant_params->Get(j)->varCorr(); diff --git a/mindspore/lite/src/runtime/kernel/arm/base/dequant.cc b/mindspore/lite/src/runtime/kernel/arm/base/dequant.cc index f90dbad0d1..32a8c2e776 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/dequant.cc +++ b/mindspore/lite/src/runtime/kernel/arm/base/dequant.cc @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include "src/runtime/kernel/arm/base/dequant.h" namespace mindspore::kernel { @@ -32,4 +33,18 @@ float *DequantUtil::DequantWeight(lite::Tensor *input_tensor) { return DequantData(input_tensor); } } + +void DequantUtil::UnPackToInt(const schema::Tensor *input_tensor, void *unpack_int_data) { + auto quant_params = input_tensor->quantParams(); + if (quant_params == nullptr) { + MS_LOG(ERROR) << "low bits quantparams is empty."; + return; + } + int origin_bit = quant_params->Get(0)->numBits(); + if (origin_bit < 8 && origin_bit > 0) { + UnPackUtil(input_tensor, origin_bit, unpack_int_data); + } else if (origin_bit < 16 && origin_bit > 8) { + UnPackUtil(input_tensor, origin_bit, unpack_int_data); + } +} } // namespace mindspore::kernel diff --git a/mindspore/lite/src/runtime/kernel/arm/base/dequant.h b/mindspore/lite/src/runtime/kernel/arm/base/dequant.h index 934ee09cb5..b58bf36936 100644 --- a/mindspore/lite/src/runtime/kernel/arm/base/dequant.h +++ b/mindspore/lite/src/runtime/kernel/arm/base/dequant.h @@ -18,6 +18,8 @@ #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_BASE_DEQUANT_H_ #include +#include +#include #include "src/lite_kernel.h" #include "src/common/utils.h" #include "src/tensor.h" @@ -27,6 +29,8 @@ class DequantUtil { public: static float *DequantWeight(lite::Tensor *input_tensor); + static void UnPackToInt(const schema::Tensor *input_tensor, void *weight_unpack_data); + template static float *DequantData(lite::Tensor *input_tensor) { const auto *quant_datas = static_cast(input_tensor->MutableData()); @@ -98,6 +102,62 @@ class DequantUtil { } return dequant_datas; } + + private: + template + static void UnPackData(int origin_bit, const T2 &packed_data, std::queue *unpack_bit_data, void *unpack_int, + size_t *count, bool is_last) { + T2 uint_result = 0; + T1 result = 0; + UnPackFromUintToOrigin(packed_data, unpack_bit_data); + while (static_cast(unpack_bit_data->size()) >= origin_bit) { + for (int k = 0; k < origin_bit; k++) { + bool bit_tmp = unpack_bit_data->front(); + uint_result = (static_cast(bit_tmp) << k) + uint_result; + unpack_bit_data->pop(); + } + result = uint_result - static_cast(pow(2, origin_bit - 1)); + (static_cast(unpack_int))[*count] = result; + uint_result = 0; + (*count)++; + } + if (is_last) { + int remainder = unpack_bit_data->size(); + for (int i = 0; i < remainder; i++) { + bool bit = unpack_bit_data->front(); + uint_result = (static_cast(bit) << i) + uint_result; + unpack_bit_data->pop(); + } + result = static_cast(uint_result - static_cast(pow(2, origin_bit - 1))); + (static_cast(unpack_int))[*count] = result; + } + } + + template + static void UnPackUtil(const schema::Tensor *input_tensor, int origin_bit, void *unpack_int_data) { + auto weight_data = input_tensor->data()->data(); + int pack_size = + input_tensor->dataType() == kNumberTypeInt8 ? input_tensor->data()->size() : input_tensor->data()->size() / 2; + std::queue unpack_bit_data; + size_t count = 0; + for (int i = 0; i < pack_size; ++i) { + T2 pack_data = (static_cast(static_cast(weight_data)))[i]; + bool is_last = i == pack_size - 1; + UnPackData(origin_bit, pack_data, &unpack_bit_data, unpack_int_data, &count, is_last); + } + } + + template + static void UnPackFromUintToOrigin(const T2 &packed_data, std::queue *unpack_bit_data) { + auto n = packed_data; + size_t bit_count = 0; + while (bit_count < sizeof(T2) * 8) { + bool a = n % 2; + n = n >> 1; + bit_count++; + unpack_bit_data->push(a); + } + } }; } // namespace mindspore::kernel diff --git a/mindspore/lite/src/tensor.h b/mindspore/lite/src/tensor.h index 93bfe75ec3..d67ee93183 100644 --- a/mindspore/lite/src/tensor.h +++ b/mindspore/lite/src/tensor.h @@ -37,6 +37,7 @@ struct QuantArg { double mean_corr{0}; bool inited; std::vector clusters{}; + int bitNum; }; class Tensor : public mindspore::tensor::MSTensor { diff --git a/mindspore/lite/tools/converter/quantizer/CMakeLists.txt b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt index ba8d4ebbeb..33d052708b 100644 --- a/mindspore/lite/tools/converter/quantizer/CMakeLists.txt +++ b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt @@ -8,7 +8,6 @@ file(GLOB QUANTIZER ${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc ${CMAKE_CURRENT_SOURCE_DIR}/aware_quantizer.cc ${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc - ${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc ${CMAKE_CURRENT_SOURCE_DIR}/post_training_quantizer.cc ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc ${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc diff --git a/mindspore/lite/tools/converter/quantizer/bitpacking.h b/mindspore/lite/tools/converter/quantizer/bitpacking.h new file mode 100644 index 0000000000..c93a09cab6 --- /dev/null +++ b/mindspore/lite/tools/converter/quantizer/bitpacking.h @@ -0,0 +1,74 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER__GENERAL_BITPACKING_H +#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER__GENERAL_BITPACKING_H +#include +#include +#include +#include +#include + +namespace mindspore { +namespace lite { +class BitPack { + public: + ~BitPack() = default; + + template + static void BitPacking(int bit_num, const std::vector &origin_data_vec, std::vector *packed_data_vec) { + std::stack bit_data_vec; + for (size_t i = 0; i < origin_data_vec.size(); i++) { + T2 tmp = origin_data_vec[i] + static_cast(pow(2, bit_num - 1)); + DoBinary(bit_num, tmp, &bit_data_vec, packed_data_vec); + } + size_t remain_bit_data = bit_data_vec.size(); + if (sizeof(T1) * 8 > remain_bit_data && remain_bit_data > 0) { + for (size_t i = 0; i < sizeof(T1) * 8 - remain_bit_data; i++) { + bit_data_vec.push(0); + } + PackFromOriginToUint(&bit_data_vec, packed_data_vec); + } + } + + private: + template + static void PackFromOriginToUint(std::stack *ans, std::vector *packed_data_vec) { + uint32_t result = 0; + for (size_t i = 0; i < sizeof(T2) * 8; i++) { + bool bit_tmp = ans->top(); + result = (result << 1) + static_cast(bit_tmp); + ans->pop(); + } + packed_data_vec->push_back(result); + } + + template + static void DoBinary(int bin_num, T2 n, std::stack *ans, std::vector *packed_data_vec) { + for (int bit_count = 0; bit_count < bin_num; bit_count++) { + bool a = n % 2; + n = n / 2; + ans->push(a); + if (ans->size() == sizeof(T2) * 8) { + PackFromOriginToUint(ans, packed_data_vec); + } + } + } +}; +} // namespace lite +} // namespace mindspore + +#endif diff --git a/mindspore/lite/tools/converter/quantizer/general_bitpacking.cc b/mindspore/lite/tools/converter/quantizer/general_bitpacking.cc deleted file mode 100644 index 448924129f..0000000000 --- a/mindspore/lite/tools/converter/quantizer/general_bitpacking.cc +++ /dev/null @@ -1,84 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tools/converter/quantizer/general_bitpacking.h" - -namespace mindspore { -namespace lite { -BitPack::BitPack(const uint8_t &bitnum) { this->bitnum = bitnum; } -void BitPack::UnPackFromUint8ToOrigin(uint8_t &n, std::queue &unpackBitData) { - int bitCount = 0; - while (bitCount < 8) { - bool a = n % 2; - n = n >> 1; - bitCount++; - unpackBitData.push(a); - } -} -void BitPack::UnPack(uint8_t bitnum, uint8_t &packedData, std::vector &originData, - std::queue &unpackBitData) { - UnPackFromUint8ToOrigin(packedData, unpackBitData); - // std::queue unpackBitTmpData; - - while (unpackBitData.size() > bitnum) { - uint32_t result = 0; - for (int k = 0; k < bitnum; k++) { - bool bitTmp = unpackBitData.front(); - result = (result << 1) + static_cast(bitTmp); - unpackBitData.pop(); - } - originData.push_back(result); - } -} -void BitPack::PackFromOriginToUint8(std::stack &ans, std::vector &packedDataVec) { - uint32_t result = 0; - for (size_t i = 0; i < 8; i++) { - bool bit_tmp = ans.top(); - result = (result << 1) + static_cast(bit_tmp); - ans.pop(); - } - packedDataVec.push_back(result); -} -void BitPack::DoBinary(uint8_t &n, std::stack &ans, std::vector &packedDataVec) { - int bitCount = 0; - while (bitCount < bitnum) { - bool a = n / (1 << (unsigned int)(bitnum - bitCount - 1)); - n = n - a * (1 << (unsigned int)(bitnum - bitCount - 1)); - bitCount++; - ans.push(a); - if (ans.size() == 8) { - PackFromOriginToUint8(ans, packedDataVec); - } - } -} - -void BitPack::BitPacking(const std::vector &originDataVec, std::vector &packedDataVec) { - std::stack bitDataVec; - for (size_t i = 0; i < originDataVec.size(); i++) { - uint8_t tmp = originDataVec[i]; - DoBinary(tmp, bitDataVec, packedDataVec); - } - - size_t remainBitData = bitDataVec.size(); - if (8 > remainBitData && remainBitData > 0) { - for (size_t i = 0; i < 8 - remainBitData; i++) { - bitDataVec.push(0); - } - PackFromOriginToUint8(bitDataVec, packedDataVec); - } -} -} // namespace lite -} // namespace mindspore diff --git a/mindspore/lite/tools/converter/quantizer/general_bitpacking.h b/mindspore/lite/tools/converter/quantizer/general_bitpacking.h deleted file mode 100644 index af91d6648e..0000000000 --- a/mindspore/lite/tools/converter/quantizer/general_bitpacking.h +++ /dev/null @@ -1,43 +0,0 @@ -/** - * Copyright 2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER__GENERAL_BITPACKING_H -#define MINDSPORE_LITE_TOOLS_CONVERTER_QUANTIZER__GENERAL_BITPACKING_H -#include -#include -#include -#include -#include - -namespace mindspore { -namespace lite { -class BitPack { - public: - explicit BitPack(const uint8_t &bitbum = 8); - ~BitPack() = default; - void BitPacking(const std::vector &originDataVec, std::vector &packedDataVec); - void UnPack(uint8_t bitnum, uint8_t &packedData, std::vector &originData, std::queue &unpackBitData); - - private: - void UnPackFromUint8ToOrigin(uint8_t &n, std::queue &unpackBitData); - void PackFromOriginToUint8(std::stack &ans, std::vector &packedDataVec); - void DoBinary(uint8_t &n, std::stack &ans, std::vector &packed_data_vec); - uint8_t bitnum; -}; -} // namespace lite -} // namespace mindspore - -#endif diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc index d9cad086c5..d3247dae5a 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc @@ -22,7 +22,7 @@ #include #include #include "src/ops/primitive_c.h" -#include "mindspore/lite/tools/converter/quantizer/general_bitpacking.h" +#include "mindspore/lite/tools/converter/quantizer/bitpacking.h" #include "src/common/utils.h" #include "abstract/abstract_value.h" #include "securec/include/securec.h" @@ -292,30 +292,6 @@ STATUS CalQuantizationParams(schema::QuantParamT *quantParam, double mMin, doubl return RET_OK; } -STATUS PostBitPack(float *weight, size_t shapeSize, size_t bitNum) { - auto *rawDatas = reinterpret_cast(weight); - vector qDatas(rawDatas, rawDatas + shapeSize); - vector qDatas_packed; - if (bitNum < 8 && bitNum > 1) { - BitPack weight_bitpack(bitNum); - weight_bitpack.BitPacking(qDatas, qDatas_packed); - if (EOK != memcpy_s(rawDatas, shapeSize, &qDatas_packed[0], shapeSize)) { - MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas_packed failed"; - return RET_ERROR; - } - } else if (bitNum == 8) { - if (EOK != memcpy_s(rawDatas, shapeSize, &qDatas[0], shapeSize)) { - MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas failed"; - return RET_ERROR; - } - } else { - MS_LOG(ERROR) << "bitNum must be between 0 and 8 : " << bitNum; - return RET_ERROR; - } - - return RET_OK; -} - static bool SearchLowerBound(const std::vector &data, const size_t &index, const float &max_tmp, float *min_tmp, size_t *min_idx) { size_t length = data.size(); diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.h b/mindspore/lite/tools/converter/quantizer/quantize_util.h index 2afe92f07f..4165820ed5 100644 --- a/mindspore/lite/tools/converter/quantizer/quantize_util.h +++ b/mindspore/lite/tools/converter/quantizer/quantize_util.h @@ -34,6 +34,7 @@ #include "base/base.h" #include "ir/primitive.h" #include "abstract/dshape.h" +#include "tools/converter/quantizer/bitpacking.h" namespace mindspore { namespace lite { @@ -279,6 +280,34 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primiti } weight->set_tensor_size(elem_count * sizeof(T)); } + + // do bit pack + if (bitNum != 8 && bitNum != 16) { + std::vector data{}; + for (size_t i = 0; i < quant_datas.size(); ++i) { + data.emplace_back((static_cast(quant_datas[i]))); + } + if (bitNum > 0 && bitNum < 8) { + std::vector pack_data{}; + BitPack::BitPacking(bitNum, data, &pack_data); + auto ret = memcpy_s(raw_datas, weight->tensor_size(), pack_data.data(), pack_data.size() * sizeof(uint8_t)); + if (ret != EOK) { + MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas_packed failed"; + return RET_ERROR; + } + weight->set_tensor_size(pack_data.size() * sizeof(uint8_t)); + } else if (bitNum > 8 && bitNum < 16) { + std::vector pack_data{}; + BitPack::BitPacking(bitNum, data, &pack_data); + auto ret = memcpy_s(raw_datas, weight->tensor_size(), pack_data.data(), pack_data.size() * sizeof(uint16_t)); + if (ret != EOK) { + MS_LOG(ERROR) << "PostBitPack memcpy_s qDatas_packed failed"; + return RET_ERROR; + } + weight->set_tensor_size(pack_data.size() * sizeof(uint16_t)); + } + } + if (quant_params.empty()) { MS_LOG(ERROR) << "quant_params empty"; return RET_ERROR; @@ -291,8 +320,6 @@ STATUS QuantFilter(ParamValueLitePtr weight, std::shared_ptr primiti return RET_OK; } -STATUS PostBitPack(float *weights, size_t shapeSize, size_t bitNum = UINT8_QUANTIZATION); - schema::PrimitiveType NodePrimitiveType(CNodePtr cnode); } // namespace quant } // namespace lite diff --git a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc index 04ba288305..07bf87948c 100644 --- a/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc +++ b/mindspore/lite/tools/converter/quantizer/weight_quantizer.cc @@ -48,8 +48,13 @@ STATUS WeightQuantizer::WeightQuantInputCheck(const converter::Flags *config) { MS_LOG(ERROR) << "quantSize must be valid pos num."; return RET_ERROR; } - if (!WeightQuantizer::IsPosNum(config->bitNum) || (config->bitNum != "8" && config->bitNum != "16")) { - MS_LOG(ERROR) << "bitNum must be valid pos num, current only support 8 or 16 bit weight quant."; + if (!WeightQuantizer::IsPosNum(config->bitNum)) { + MS_LOG(ERROR) << "bitNum must be valid pos num."; + return RET_ERROR; + } + int bitNum = std::stoi(config->bitNum); + if (bitNum <= 0 || bitNum > 16) { + MS_LOG(ERROR) << "bitNum should be more than 0 and less than 16 currently."; return RET_ERROR; } return RET_OK; @@ -63,10 +68,13 @@ WeightQuantizer::WeightQuantizer(FuncGraphPtr graph, const string &weightSize, mStrategy.reset(new QuantStrategy(quantSize, convQuantWeightChannelThreshold)); quant_max = (1 << (unsigned int)(this->bitNum - 1)) - 1; quant_min = -(1 << (unsigned int)(this->bitNum - 1)); - if (this->bitNum == 8) { + // parse type_id + if (this->bitNum > 0 && this->bitNum <= 8) { type_id = kNumberTypeInt8; - } else if (this->bitNum == 16) { + } else if (this->bitNum <= 16) { type_id = kNumberTypeInt16; + } else { + MS_LOG(ERROR) << "invalid input bits"; } } @@ -100,7 +108,6 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list &nodes) { MS_LOG(ERROR) << "model weight data type invalid which is " << param_value->tensor_type(); return RET_ERROR; } - auto status = RET_ERROR; if (type_id == kNumberTypeInt8) { status = QuantFilter(param_value, primitive_c, QuantType_WeightQuant, quant_max, quant_min, bitNum, true); @@ -127,7 +134,6 @@ STATUS WeightQuantizer::DoConvQuantize(const std::list &nodes) { abstractTensor->element()->set_type(TypeIdToType(type_id)); primitive_c->SetQuantType(schema::QuantType_WeightQuant); } - return RET_OK; } @@ -136,7 +142,6 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list &nodes) { if (!mStrategy->CanMulOpQuantized(node)) { continue; } - auto already_quant = false; ParamValueLitePtr param_value = nullptr; ParameterPtr param_node = nullptr; for (size_t i = 1; i < node->size(); i++) { @@ -146,16 +151,8 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list &nodes) { if ((param_node != nullptr) && param_node->has_default()) { param_value = std::static_pointer_cast(param_node->default_param()); if ((param_value == nullptr) || (param_value->tensor_size() == 0) || - (param_value->tensor_addr() == nullptr)) { - param_value = nullptr; - continue; - } else if (param_value->tensor_type() == mindspore::kNumberTypeInt8 || - param_value->tensor_type() == mindspore::kNumberTypeInt16) { - MS_LOG(INFO) << "the node: " << node->fullname_with_scope() << " input_i: " << i << "has been " - << " quantized"; - already_quant = true; - break; - } else if (param_value->tensor_type() != mindspore::kNumberTypeFloat32) { + (param_value->tensor_addr() == nullptr) || + (param_value->tensor_type() != mindspore::kNumberTypeFloat32)) { param_value = nullptr; continue; } else { @@ -164,11 +161,6 @@ STATUS WeightQuantizer::DoMulQuantize(const std::list &nodes) { } } } - - if (already_quant) { - continue; - } - if (param_value == nullptr) { MS_LOG(ERROR) << "No valid input param node !"; return RET_ERROR;