Aware training: support partial quant

Let aware-training models be partially quantized: add a QuantDtypeCastFusion pass that strips redundant QuantDTypeCast nodes, detect const tensors by their data instead of refCount sentinels, and skip tensors whose quant params are absent instead of asserting on them.

pull/6200/head
cjh9368 4 years ago
parent 73b5ffb99c
commit 1cd9445087

@@ -203,6 +203,7 @@ if(BUILD_CONVERTER)
${LITE_DIR}/tools/optimizer/fusion/conv_scale_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/conv_bn_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/constant_folding_fusion.cc
${LITE_DIR}/tools/optimizer/fusion/quant_dtype_cast_fusion.cc
)
endif()
### train

@@ -75,7 +75,7 @@ static const std::vector<schema::PrimitiveType> int8OpList = {
schema::PrimitiveType_SoftMax, schema::PrimitiveType_Split,
schema::PrimitiveType_Squeeze, schema::PrimitiveType_Sub,
schema::PrimitiveType_TopK, schema::PrimitiveType_Unsqueeze,
schema::PrimitiveType_MatMul};
schema::PrimitiveType_MatMul, schema::PrimitiveType_Pad};
static const std::vector<schema::PrimitiveType> needInsertOpList = {
schema::PrimitiveType_Eltwise, schema::PrimitiveType_Activation, schema::PrimitiveType_Concat,

@@ -61,6 +61,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
../optimizer/fusion/conv_scale_fusion.cc
../optimizer/fusion/conv_bn_fusion.cc
../optimizer/fusion/constant_folding_fusion.cc
../optimizer/fusion/quant_dtype_cast_fusion.cc
)
add_subdirectory(../anf_importer anf_importer)

@@ -24,6 +24,7 @@
#include "tools/optimizer/fusion/conv_scale_fusion.h"
#include "tools/optimizer/fusion/conv_bn_fusion.h"
#include "tools/optimizer/fusion/constant_folding_fusion.h"
#include "tools/optimizer/fusion/quant_dtype_cast_fusion.h"
#include "tools/converter/quantizer/post_training_quantizer.h"
#include "tools/converter/quantizer/quant_cast.h"
#include "tools/converter/quantizer/weight_quantizer.h"
@@ -43,6 +44,10 @@ FuncGraphPtr AnfTransform::Transform(const FuncGraphPtr &old_graph, const conver
// for now - training does not support fused operations
if (config != nullptr && config->trainModel == false) {
// remove QuantDtypeCast nodes when quantType is AwareTraining
if (config->quantType == QuantType_AwareTraining) {
pm->AddPass(std::make_shared<opt::QuantDtypeCastFusion>());
}
pm->AddPass(std::make_shared<opt::ConvBiasaddFusion>());
pm->AddPass(std::make_shared<opt::ConvBatchNormFusion>());
pm->AddPass(std::make_shared<opt::ConvScaleFusion>());

@@ -102,7 +102,7 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
// generate and infer quant parameters
{
if (mQuantizer != nullptr) {
if (fbQuantizer != nullptr) {
Optimizer topologicalOptimizer;
topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
status = topologicalOptimizer.Run(graphDefT);
@@ -110,14 +110,13 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
return status;
}
if (!(this->graphDefT->fmkType == converter::FmkType_TF &&
this->graphDefT->nodes.front()->quantType == QuantType::QuantType_AwareTraining)) {
status = mQuantizer->GenerateQuantParam();
if (ctx.quantType == QuantType_AwareTraining) {
status = fbQuantizer->GenerateQuantParam();
if (status != RET_OK) {
MS_LOG(ERROR) << "GenerateQuantParam failed";
return status;
}
status = mQuantizer->DetermineNodeQuantType();
status = fbQuantizer->DetermineNodeQuantType();
if (status != RET_OK) {
MS_LOG(ERROR) << "DetermineNodeQuant failed";
return status;

@@ -151,7 +151,7 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i));
auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i));
auto &graphInIdxes = graph->inputIndex;
if (preTensor->nodeType == NodeType_ValueNode && !IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) {
if (!preTensor->data.empty() && !IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) {
continue;
}
iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status);

@@ -46,7 +46,7 @@ STATUS TfliteDequantizeParser::Parse(const std::unique_ptr<tflite::OperatorT> &t
MS_LOG(ERROR) << "output tensor is null";
return RET_NULL_PTR;
}
if (GetTfliteDataType(in_tensor->type) == kNumberTypeInt8) {
if (GetTfliteDataType(in_tensor->type) == kNumberTypeInt8 || GetTfliteDataType(in_tensor->type) == kNumberTypeUInt8) {
std::unique_ptr<schema::QuantDTypeCastT> attr = std::make_unique<schema::QuantDTypeCastT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";

@@ -77,8 +77,7 @@ void TfliteModelParser::SetTensorQuantParam(const std::unique_ptr<tflite::Tensor
}
// change quant param min to 0 to fit ms-lite ops
if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8
&& tensor->dataType == TypeId::kNumberTypeInt8) {
if (GetTfliteDataType(tflite_tensor->type) == TypeId::kNumberTypeUInt8 && tensor->data.empty()) {
quant_param->zeroPoint = quant_param->zeroPoint - 128;
}
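
For context on the `- 128` rebase above: a tflite uint8 tensor and an ms-lite int8 tensor describe the same real values, and an int8 code is just the uint8 code shifted down by 128, so the zero point must shift by the same amount. A minimal standalone illustration (plain C++, not MindSpore code):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // uint8 codes span [0, 255]; int8 codes span [-128, 127].
  // q_int8 = q_uint8 - 128, so the zero point rebases the same way.
  int32_t uint8_zero_point = 128;                    // a common asymmetric default
  int32_t int8_zero_point = uint8_zero_point - 128;  // becomes 0 in the int8 domain
  std::cout << "int8 zero point: " << int8_zero_point << "\n";
  return 0;
}
```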

@@ -115,11 +115,6 @@ STATUS AwareQuantizer::GenerateQuantParam() {
return status;
}
}
auto status = GenerateDefaultQuantParam(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "GenerateDefaultQuantParam failed";
return status;
}
auto *quantParamRegister = QuantParamCalcRegister::GetInstance();
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
@@ -135,7 +130,7 @@ STATUS AwareQuantizer::GenerateQuantParam() {
<< ", type: " << GetCNodeTTypeName(*node).c_str() << " set node to QuantNone and skip";
node->quantType = static_cast<schema::QuantType>(QuantType_QUANT_NONE);
} else {
status = quantParamCalcer->Calc(graph, *node);
auto status = quantParamCalcer->Calc(graph, *node);
if (status != RET_OK) {
MS_LOG(ERROR) << "quantParamCalcer failed: " << status << " node: " << node->name.c_str();
node->quantType = schema::QuantType_QUANT_NONE;
@@ -167,19 +162,25 @@ STATUS AwareQuantizer::DoQuantize() {
return RET_ERROR;
}
// quant weight
auto &weightTensor = graph->allTensors.at(node->inputIndex.at(1));
if (!weightTensor->quantParams.empty() && weightTensor->quantParams.at(0)->inited) {
status = QuantConvWeight(graph, node.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantConvWeight failed!";
return RET_ERROR;
}
}
// quant bias
if (inputIndexes.size() == 3) {
auto &biasTensor = graph->allTensors.at(node->inputIndex.at(2));
if (!biasTensor->quantParams.empty() && biasTensor->quantParams.at(0)->inited) {
status = QuantConvBias(graph, node.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "QuantConvBias failed!";
return RET_ERROR;
}
}
}
} else if (GetCNodeTType(*node) == schema::PrimitiveType_DetectionPostProcess) {
status = QuantDetectionPostProcessConstTensor(graph, node.get());
if (status != RET_OK) {
@@ -376,18 +377,6 @@ STATUS AwareQuantizer::DetermineNodeQuantType() {
for (auto &node : graph->nodes) {
MS_ASSERT(node != nullptr);
bool canQuant = true;
for (auto &inTensorIdx : node->inputIndex) {
MS_ASSERT(graph->allTensors.size() > inTensorIdx);
auto &inTensor = graph->allTensors.at(inTensorIdx);
MS_ASSERT(inTensor != nullptr);
if (inTensor->quantParams.empty() || inTensor->quantParams.front() == nullptr ||
!inTensor->quantParams.front()->inited) {
canQuant = false;
break;
}
}
if (canQuant) {
for (auto &outTensorIdx : node->outputIndex) {
MS_ASSERT(graph->allTensors.size() > outTensorIdx);
auto &outTensor = graph->allTensors.at(outTensorIdx);
@@ -398,7 +387,7 @@ STATUS AwareQuantizer::DetermineNodeQuantType() {
break;
}
}
}
if (canQuant && IsContain(GetUint8OpList(), GetCNodeTType(*node))) {
node->quantType = schema::QuantType_AwareTraining;
} else {

@@ -70,6 +70,9 @@ int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
auto &tensor = graph->allTensors.at(node.inputIndex.at(i));
MS_ASSERT(tensor != nullptr);
auto quantParam = GetTensorQuantParam(tensor);
if (quantParam == nullptr) {
continue;
}
if (quantParam->inited) { // inited
inputParamDone++;
continue;
@@ -77,8 +80,7 @@ int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(i));
MS_ASSERT(tensor != nullptr);
if (tensor->refCount == schema::NodeType::NodeType_ValueNode &&
!IsContain(graph->inputIndex, node.inputIndex.at(i))) {
if (!tensor->data.empty() && !IsContain(graph->inputIndex, node.inputIndex.at(i))) {
auto status = ComputeConstQuantParam((*tensor), quantParam.get());
if (status != RET_OK) {
MS_LOG(WARNING) << "ComputeConstQuantParam failed: " << status;
@@ -95,13 +97,12 @@ int QuantParamCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
auto &tensor = graph->allTensors.at(i);
MS_ASSERT(tensor != nullptr);
auto quantParam = GetTensorQuantParam(tensor);
MS_ASSERT(quantParam != nullptr);
if (quantParam->inited) { // inited
if (quantParam != nullptr && quantParam->inited) { // inited
outputParamDone++;
continue;
}
if (tensor->refCount == 999) {
if (!tensor->data.empty()) {
MS_ASSERT(false);
}
}
@@ -146,10 +147,10 @@ int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
auto &inTensor = graph->allTensors.at(i);
MS_ASSERT(inTensor != nullptr);
auto inQuantParam = GetTensorQuantParam(inTensor);
if (inQuantParam->inited) {
if (inQuantParam == nullptr || inQuantParam->inited) {
continue;
}
inTensor->quantParams.front() = std::move(inQuantParam);
inTensor->quantParams.front() = std::move(outputQuantParam);
}
}
if (outputParamDone != node.outputIndex.size()) {
@@ -157,7 +158,7 @@ int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
auto &inTensor = graph->allTensors.at(node.inputIndex.at(0));
MS_ASSERT(inTensor != nullptr);
auto inQuantParam = GetTensorQuantParam(inTensor);
if (!inQuantParam->inited) {
if (inQuantParam == nullptr || !inQuantParam->inited) {
MS_LOG(WARNING) << "Can not determine outputTensor quantParam from inputTensor for node %s" << node.name;
return RET_ERROR;
}
@@ -166,10 +167,10 @@ int LinearCalcer::Calc(MetaGraphT *graph, const CNodeT &node) {
auto &outTensor = graph->allTensors.at(node.outputIndex.at(i));
MS_ASSERT(outTensor != nullptr);
auto outQuantParam = GetTensorQuantParam(outTensor);
if (outQuantParam->inited) {
if (outQuantParam == nullptr || outQuantParam->inited) {
continue;
}
outTensor->quantParams.front() = std::move(outQuantParam);
outTensor->quantParams.front() = std::move(inQuantParam);
}
}
return RET_OK;
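
The two loops in this hunk encode one propagation rule: for a range-preserving op, whichever side lacks an initialized quant param inherits it from the other side (inputs from the output param, outputs from input 0). A standalone sketch of that rule with hypothetical simplified types, not the MindSpore implementation:

```cpp
#include <optional>

// Hypothetical stand-in for schema::QuantParamT.
struct QParam { double scale = 1.0; int zeroPoint = 0; bool inited = false; };

// Range-preserving ops (Reshape, Transpose, ...) keep scale/zeroPoint
// unchanged end to end, so a missing param on either side can be copied
// from the opposite side.
void PropagateLinear(std::optional<QParam> *in, std::optional<QParam> *out) {
  if ((!in->has_value() || !(*in)->inited) && out->has_value() && (*out)->inited) *in = *out;
  if ((!out->has_value() || !(*out)->inited) && in->has_value() && (*in)->inited) *out = *in;
}
```
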
@@ -225,13 +226,14 @@ class CalcConcat : public QuantParamCalcer {
MS_ASSERT(graph->allTensors.size() > node.outputIndex.front());
auto &outTensor = graph->allTensors.at(node.outputIndex.front());
MS_ASSERT(outTensor != nullptr);
auto outQuantParam = GetTensorQuantParam(outTensor);
auto outQuantParam = std::make_unique<QuantParamT>();
status = quant::CalQuantizationParams(outQuantParam.get(), minMin, maxMax, narrowRange, numBits);
if (status != RET_OK) {
MS_LOG(WARNING) << "in aware quantization run CalQuantizationParams failed!";
return RET_ERROR;
}
outTensor->quantParams.front() = std::move(outQuantParam);
outputParamDone++;
}
@@ -261,7 +263,7 @@ class CalcAdd : public QuantParamCalcer {
MS_ASSERT(graph->allTensors.size() > node.outputIndex.front());
auto &outTensor = graph->allTensors.at(node.outputIndex.front());
MS_ASSERT(outTensor != nullptr);
auto outQuantParam = GetTensorQuantParam(outTensor);
auto outQuantParam = std::make_unique<QuantParamT>();
MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0));
auto &tensor0 = graph->allTensors.at(node.inputIndex.at(0));
@@ -271,10 +273,10 @@ class CalcAdd : public QuantParamCalcer {
MS_ASSERT(tensor1 != nullptr);
auto biasTensor = &tensor0;
auto paramTensor = &tensor1;
if (tensor0->refCount == 999 && (tensor0->dims.empty() || tensor0->dims.size() == 1)) {
if (!tensor0->data.empty() && (tensor0->dims.empty() || tensor0->dims.size() == 1)) {
biasTensor = &tensor0;
paramTensor = &tensor1;
} else if (tensor1->refCount == 999 && (tensor1->dims.empty() || tensor1->dims.size() == 1)) {
} else if (!tensor1->data.empty() && (tensor1->dims.empty() || tensor1->dims.size() == 1)) {
biasTensor = &tensor1;
paramTensor = &tensor0;
} else {
@@ -310,6 +312,7 @@ class CalcAdd : public QuantParamCalcer {
return RET_ERROR;
}
}
outTensor->quantParams.front() = std::move(outQuantParam);
}
return RET_OK;
}
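
The `refCount == 999` sentinel replaced throughout this commit was an indirect test for "is this tensor a constant?"; `!tensor->data.empty()` asks that directly, since only const tensors (weights, biases) carry serialized bytes in the MetaGraph while activation tensors stay empty. Sketched as a predicate over a hypothetical mirror of the one relevant field:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical mirror of the schema::TensorT field the new check relies on.
struct TensorT { std::vector<uint8_t> data; };

// Const tensors carry their weight bytes; activations have no stored data.
inline bool IsConstTensor(const TensorT &t) { return !t.data.empty(); }
```
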
@@ -337,13 +340,13 @@ class CalcRealDiv : public QuantParamCalcer {
MS_ASSERT(graph->allTensors.size() > node.outputIndex.front());
auto &outTensor = graph->allTensors.at(node.outputIndex.front());
MS_ASSERT(outTensor != nullptr);
auto outQuantParam = GetTensorQuantParam(outTensor);
auto outQuantParam = std::make_unique<QuantParamT>();
MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(0));
MS_ASSERT(graph->allTensors.size() > node.inputIndex.at(1));
auto &tensor1 = graph->allTensors.at(node.inputIndex.at(1));
MS_ASSERT(tensor1 != nullptr);
if (tensor1->refCount == 999 && (tensor1->dims.empty() || tensor1->dims.size() == 1)) {
if (!tensor1->data.empty() && (tensor1->dims.empty() || tensor1->dims.size() == 1)) {
auto quantParam = GetTensorQuantParam(tensor1);
auto min = quantParam->min;
auto max = quantParam->max;
@@ -371,6 +374,7 @@ class CalcRealDiv : public QuantParamCalcer {
MS_LOG(WARNING) << "Unsupported tensor dataType: " << tensor1->dataType;
return RET_ERROR;
}
outTensor->quantParams.front() = std::move(outQuantParam);
}
} else {
MS_LOG(WARNING) << "Can not determine realDiv outputTensor quantParam, node " << node.name;
@@ -399,7 +403,8 @@ class CalcToSet : public QuantParamCalcer {
return RET_ERROR;
}
// output
std::unique_ptr<QuantParamT> quantParam(new (std::nothrow) QuantParamT());
if (outputParamDone != node.outputIndex.size()) {
std::unique_ptr<QuantParamT> quantParam = std::make_unique<QuantParamT>();
if (quantParam == nullptr) {
MS_LOG(WARNING) << "new QuantParamT failed";
return RET_ERROR;
@@ -414,6 +419,8 @@ class CalcToSet : public QuantParamCalcer {
auto &outTensor = graph->allTensors.at(node.outputIndex.front());
MS_ASSERT(outTensor != nullptr);
outTensor->quantParams.front() = std::move(quantParam);
outputParamDone++;
}
return RET_OK;
}
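
CalQuantizationParams, called by CalcConcat, CalcAdd, and CalcToSet above, turns a float range into a scale/zero-point pair. A sketch of the standard int8 affine formulas it is assumed to follow (not copied from the MindSpore source):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>

struct AffineParam { double scale; int32_t zeroPoint; };

// Standard affine quantization: map [min, max] (forced to contain 0)
// onto the signed integer grid [qmin, qmax].
AffineParam CalcAffineParam(double min, double max, bool narrowRange, int numBits) {
  const int32_t qmin = narrowRange ? -((1 << (numBits - 1)) - 1) : -(1 << (numBits - 1));
  const int32_t qmax = (1 << (numBits - 1)) - 1;
  min = std::min(min, 0.0);
  max = std::max(max, 0.0);
  double scale = (max - min) / static_cast<double>(qmax - qmin);
  if (scale == 0.0) scale = 1.0;  // degenerate all-zero range
  const auto zeroPoint = static_cast<int32_t>(std::round(qmin - min / scale));
  return {scale, zeroPoint};
}
```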

@@ -357,6 +357,14 @@ bool IsPoolingNode(const BaseRef &n) {
return false;
}
bool IsQuantNode(const BaseRef &n) {
if (utils::isa<CNodePtr>(n) || utils::isa<ValueNodePtr>(n)) {
auto type = opt::GetCNodeType(n);
return type == schema::PrimitiveType_QuantDTypeCast;
}
return false;
}
bool CheckIsAllInputsParam(const AnfNodePtr &node) {
if (utils::isa<CNode>(node)) {
auto cnode = node->cast<CNodePtr>();

@@ -58,6 +58,8 @@ bool IsConvNode(const BaseRef &n);
bool IsPoolingNode(const BaseRef &n);
bool IsQuantNode(const BaseRef &n);
bool CheckIsAllInputsParam(const AnfNodePtr &node);
size_t GetOutputTensorNum(const AnfNodePtr &node);

@@ -0,0 +1,47 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/optimizer/fusion/quant_dtype_cast_fusion.h"
#include <memory>
#include "src/ops/primitive_c.h"
#include "src/ops/conv2d.h"
#include "src/ops/depthwise_conv2d.h"
#include "src/ops/activation.h"
#include "schema/inner/model_generated.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore::opt {
namespace {
constexpr size_t kActivationInputsLength = 2;
}
const BaseRef QuantDtypeCastFusion::DefinePattern() const {
auto quant_var = std::make_shared<CondVar>(IsQuantNode);
auto input_var = std::make_shared<Var>();
return VectorRef({quant_var, input_var});
}
const AnfNodePtr QuantDtypeCastFusion::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_LOG(DEBUG) << "quant dtype cast fusion pass process";
CheckIfFuncGraphIsNull(func_graph);
CheckIfAnfNodeIsNull(node);
auto act_node = node->cast<CNodePtr>();
CheckIfCNodeIsNull(act_node);
CheckInputSize(act_node, kActivationInputsLength);
AnfNodePtr pre_node = act_node->input(1);
CheckIfAnfNodeIsNull(pre_node);
return pre_node;
}
} // namespace mindspore::opt
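
This pass is registered by the AnfTransform hunk near the top of this commit; a condensed view of that wiring (a fragment mirroring the diff, not standalone code):

```cpp
// From AnfTransform::Transform: the fusion runs only for aware-training
// models, stripping QuantDTypeCast nodes the training graph no longer needs.
if (config->quantType == QuantType_AwareTraining) {
  pm->AddPass(std::make_shared<opt::QuantDtypeCastFusion>());
}
```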

@@ -0,0 +1,35 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LITE_QUANT_DTYPE_CAST_FUSION_H
#define LITE_QUANT_DTYPE_CAST_FUSION_H
#include <string>
#include "backend/optimizer/common/optimizer.h"
#include "schema/inner/model_generated.h"
namespace mindspore {
namespace opt {
class QuantDtypeCastFusion : public PatternProcessPass {
public:
explicit QuantDtypeCastFusion(bool multigraph = true, const std::string &name = "quant_dtype_cast_fusion")
: PatternProcessPass(name, multigraph) {}
~QuantDtypeCastFusion() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
} // namespace opt
} // namespace mindspore
#endif // LITE_QUANT_DTYPE_CAST_FUSION_H