From ad99fb342b82bad9345bfe8ada8b1f4d373ba27e Mon Sep 17 00:00:00 2001
From: xutianchun
Date: Thu, 30 Jul 2020 09:21:34 +0800
Subject: [PATCH] Insert QuantCast node after post training quantization

---
 mindspore/lite/tools/converter/converter.cc   |  12 +-
 .../tools/converter/quantizer/CMakeLists.txt  |   1 +
 .../converter/quantizer/post_training.cc      |   5 +-
 .../tools/converter/quantizer/quant_cast.cc   | 125 ++++++++++++++++++
 .../tools/converter/quantizer/quant_cast.h    |  42 ++++++
 .../converter/quantizer/quantize_util.cc      |   2 +-
 6 files changed, 182 insertions(+), 5 deletions(-)
 create mode 100644 mindspore/lite/tools/converter/quantizer/quant_cast.cc
 create mode 100644 mindspore/lite/tools/converter/quantizer/quant_cast.h

diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc
index c31fb88fd3..c816e56798 100644
--- a/mindspore/lite/tools/converter/converter.cc
+++ b/mindspore/lite/tools/converter/converter.cc
@@ -32,6 +32,7 @@
 #include "tools/converter/parser/onnx/onnx.pb.h"
 #include "tools/converter/quantizer/weight_quantizer.h"
 #include "tools/converter/quantizer/post_training.h"
+#include "tools/converter/quantizer/quant_cast.h"
 
 namespace mindspore {
 namespace lite {
@@ -90,7 +91,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
   }
 
   // auto newGraph = anfTransform->Transform(graph);
-  /*
+
   CreateQuantizer(graph, flag);
   if (mQuantizer != nullptr) {
     auto status = mQuantizer->DoQuantize(graph);
@@ -98,8 +99,15 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
       MS_LOG(ERROR) << "Quant failed " << status;
       return nullptr;
     }
+    quant::QuantCast quant_cast;
+    quant_cast.SetInputDataDType(kNumberTypeFloat32);
+    status = quant_cast.Run(graph);
+    if (status != RET_OK) {
+      MS_LOG(ERROR) << "add QuantCast error";
+      return nullptr;
+    }
   }
-  */
+
   // anf -- fb
   auto meta_graph = Export(graph);
   if (meta_graph == nullptr) {
diff --git a/mindspore/lite/tools/converter/quantizer/CMakeLists.txt b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt
index d7fba0631a..f22f952164 100644
--- a/mindspore/lite/tools/converter/quantizer/CMakeLists.txt
+++ b/mindspore/lite/tools/converter/quantizer/CMakeLists.txt
@@ -11,6 +11,7 @@ add_library(quantizer_mid OBJECT
         ${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc
         ${CMAKE_CURRENT_SOURCE_DIR}/post_training.cc
+        ${CMAKE_CURRENT_SOURCE_DIR}/quant_cast.cc
         #${CMAKE_CURRENT_SOURCE_DIR}/../proto/post_training/post_training.pb.cc
         )
 
diff --git a/mindspore/lite/tools/converter/quantizer/post_training.cc b/mindspore/lite/tools/converter/quantizer/post_training.cc
index b96c1967f3..392c3f9294 100644
--- a/mindspore/lite/tools/converter/quantizer/post_training.cc
+++ b/mindspore/lite/tools/converter/quantizer/post_training.cc
@@ -732,9 +732,10 @@ STATUS PostTrainingQuantizer::CheckTensorVec(const std::string &nodeName,
     return RET_ERROR;
   }
   tensor::Tensor *tensor = tensorVec[0];
-  if (tensor->data_type() != kNumberTypeFloat) {
+  if (tensor->data_type() != kNumberTypeFloat32) {
     //&& tensor->RefCount() != MSCONST_WEIGHT_REFCOUNT
-    MS_LOG(DEBUG) << "node: " << nodeName << " will not quantize";
+    MS_LOG(DEBUG) << "node: " << nodeName << " will not quantize" << " tensor data_type: " << tensor->data_type();
+    return RET_ERROR;
   }
   return RET_OK;
 }
diff --git a/mindspore/lite/tools/converter/quantizer/quant_cast.cc b/mindspore/lite/tools/converter/quantizer/quant_cast.cc
new file mode 100644
index 0000000000..b878774e35
--- /dev/null
+++ b/mindspore/lite/tools/converter/quantizer/quant_cast.cc
@@ -0,0 +1,125 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+#include "mindspore/lite/tools/converter/quantizer/quant_cast.h"
+#include <memory>
+#include <vector>
+#include "mindspore/lite/src/ir/primitive_t_value.h"
+
+namespace mindspore::lite::quant {
+
+// Builds a value node holding a QuantDTypeCast primitive whose srcT/dstT are src_type/dst_type.
+// Callers in this file pass TypeId constants (kNumberTypeFloat32, kNumberTypeUInt8).
+ValueNodePtr NewQuantCastValueNode(int src_type, int dst_type) {
+  std::unique_ptr<schema::PrimitiveT> primitive = std::make_unique<schema::PrimitiveT>();
+  schema::QuantDTypeCastT quant_dtype_cast;
+  quant_dtype_cast.srcT = src_type;  // kNumberTypeUInt8;
+  quant_dtype_cast.dstT = dst_type;  // kNumberTypeFloat32;
+  primitive->value.Set(quant_dtype_cast);
+  auto primTValue = std::make_shared<PrimitiveTValue>(primitive.release());
+  return NewValueNode(primTValue);
+}
+
+// Inserts QuantDTypeCast nodes into `graph` wherever data crosses a quantization boundary:
+//   - on input 1 of the first ordered CNode, when that node is post-training quantized and the
+//     configured graph input type is float32 (cast float32 -> uint8);
+//   - on every CNode input whose producing CNode has a different quant type than its consumer:
+//     QUANT_NONE feeding PostTraining gets float32 -> uint8, the reverse gets uint8 -> float32.
+// Any other combination of quant types is logged and left untouched. Always returns RET_OK.
+STATUS QuantCast::Run(FuncGraphPtr graph) {
+  MS_ASSERT(graph != nullptr);
+
+  auto cnodes = graph->GetOrderedCnodes();
+  bool first = true;
+
+  for (auto &cnode : cnodes) {
+    auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
+    // Nodes without a PrimitiveTValue are treated as not quantized.
+    auto curnode_quant_type = schema::QuantType_QUANT_NONE;
+    if (primitiveT_value == nullptr) {
+      MS_LOG(WARNING) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope();
+    } else {
+      curnode_quant_type = primitiveT_value->GetQuantType();
+    }
+    // The first ordered CNode is handled separately: a cast may be placed on its input 1, and its
+    // remaining inputs are not examined.
+    if (first) {
+      // input(1) is only read when the node has an input besides the primitive at index 0.
+      if (curnode_quant_type == schema::QuantType_PostTraining && inputDataDType == kNumberTypeFloat32 &&
+          cnode->inputs().size() > 1) {
+        auto value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8);
+        std::vector<AnfNodePtr> op_inputs = {value_node, cnode->input(1)};
+        auto quant_cast_cnode = graph->NewCNode(op_inputs);
+        quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast");
+        cnode->set_input(1, quant_cast_cnode);
+        MS_LOG(DEBUG) << "Add quant cast at front. "
+                      << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type;
+      }
+      first = false;
+      continue;
+    }
+
+    for (size_t i = 1; i < cnode->inputs().size(); i++) {
+      auto input_node = cnode->input(i);
+      if (!input_node->isa<CNode>()) {
+        continue;  // only inputs produced by another CNode are compared
+      }
+      auto input_cnode = std::dynamic_pointer_cast<CNode>(input_node);
+      auto input_cnode_primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(input_cnode->input(0));
+      if (input_cnode_primitiveT_value == nullptr) {
+        MS_LOG(DEBUG) << "input: " << i << " " << input_cnode->fullname_with_scope() << ": "
+                      << " PrimitiveTValue is null";
+        continue;
+      }
+      auto input_cnode_quant_type = input_cnode_primitiveT_value->GetQuantType();
+
+      if (curnode_quant_type != input_cnode_quant_type) {
+        ValueNodePtr value_node = nullptr;
+        if (curnode_quant_type == schema::QuantType_PostTraining &&
+            input_cnode_quant_type == schema::QuantType_QUANT_NONE) {
+          value_node = NewQuantCastValueNode(kNumberTypeFloat32, kNumberTypeUInt8);
+        } else if (curnode_quant_type == schema::QuantType_QUANT_NONE &&
+                   input_cnode_quant_type == schema::QuantType_PostTraining) {
+          value_node = NewQuantCastValueNode(kNumberTypeUInt8, kNumberTypeFloat32);
+        }
+        if (value_node == nullptr) {
+          MS_LOG(WARNING) << "value_node is null! "
+                          << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type
+                          << " input_" << i << ": " << input_cnode->fullname_with_scope()
+                          << " quant_type:" << input_cnode_quant_type;
+          continue;
+        }
+        std::vector<AnfNodePtr> op_inputs = {value_node, input_cnode};
+        auto quant_cast_cnode = graph->NewCNode(op_inputs);
+        quant_cast_cnode->set_fullname_with_scope(cnode->fullname_with_scope() + "_quant_cast");
+        cnode->set_input(i, quant_cast_cnode);
+        MS_LOG(DEBUG) << "Add quant cast. "
+                      << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type
+                      << " input_" << i << ": " << input_cnode->fullname_with_scope()
+                      << " quant_type:" << input_cnode_quant_type;
+      } else {
+        MS_LOG(DEBUG) << "No need to add quant cast. "
+                      << "cur_node: " << cnode->fullname_with_scope() << " quant_type: " << curnode_quant_type
+                      << " input_" << i << ": " << input_cnode->fullname_with_scope()
+                      << " quant_type:" << input_cnode_quant_type;
+      }
+    }
+  }
+  return RET_OK;
+}
+
+}  // namespace mindspore::lite::quant
diff --git a/mindspore/lite/tools/converter/quantizer/quant_cast.h b/mindspore/lite/tools/converter/quantizer/quant_cast.h
new file mode 100644
index 0000000000..2349dc37bb
--- /dev/null
+++ b/mindspore/lite/tools/converter/quantizer/quant_cast.h
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LITE_QUANT_CAST_H
+#define LITE_QUANT_CAST_H
+
+#include "mindspore/core/ir/anf.h"
+#include "mindspore/lite/include/errorcode.h"
+#include "mindspore/core/ir/dtype/type_id.h"
+#include "mindspore/core/ir/func_graph.h"
+
+namespace mindspore::lite::quant {
+
+// Converter pass that inserts QuantDTypeCast nodes where quantized and non-quantized nodes meet.
+class QuantCast {
+ public:
+  QuantCast() = default;
+  // Rewrites `graph` in place; see quant_cast.cc for the insertion rules.
+  STATUS Run(FuncGraphPtr graph);
+  // Data type of the graph input. Run() only adds a cast in front of the first node for kNumberTypeFloat32.
+  void SetInputDataDType(TypeId dataType) { this->inputDataDType = dataType; }
+
+ private:
+  TypeId inputDataDType = kNumberTypeFloat32;
+};
+
+}  // namespace mindspore::lite::quant
+
+#endif  // LITE_QUANT_CAST_H
diff --git a/mindspore/lite/tools/converter/quantizer/quantize_util.cc b/mindspore/lite/tools/converter/quantizer/quantize_util.cc
index 64dd0c0c18..44eecd231a 100644
--- a/mindspore/lite/tools/converter/quantizer/quantize_util.cc
+++ b/mindspore/lite/tools/converter/quantizer/quantize_util.cc
@@ -88,7 +88,7 @@ bool QuantStrategy::CanOpPostQuantized(AnfNodePtr &node) const {
 
   auto primitiveT_value = GetValueNode<std::shared_ptr<PrimitiveTValue>>(cnode->input(0));
   if (primitiveT_value == nullptr) {
-    MS_LOG(ERROR) << "PrimitiveT_value is nullptr";
+    MS_LOG(ERROR) << "PrimitiveT_value is nullptr: " << cnode->fullname_with_scope();
     return false;
   }
 