add quant aware compile success

pull/4358/head
cjh9368 5 years ago
parent cfd37ca90d
commit 18c6ac9988

@ -188,7 +188,7 @@ schema::MetaGraphT *AnfExporter::Export(const FuncGraphPtr &funcGraph) {
// add quant param
node->quantType = primitiveT_value->GetQuantType();
if (node->quantType == schema::QuantType_PostTraining) {
if (node->quantType == schema::QuantType_PostTraining || node->quantType == schema::QuantType_AwareTrainning) {
MS_LOG(INFO) << "node: " << node->name << " add QuantParam";
// activation
auto input_quant_params = primitiveT_value->GetInputQuantParams();

@ -60,6 +60,17 @@ void AnfImporterFromMetaGraphT::ConverterConstTensor() {
param_value->set_tensor_addr(tensor_data);
param_value->set_tensor_size(size);
}
if (tensor->quantParams.size() > 0) {
std::unique_ptr<AnfQuantParam> quantParam = std::make_unique<AnfQuantParam>();
quantParam->scale = tensor->quantParams[0]->scale;
quantParam->zeroPoint = tensor->quantParams[0]->zeroPoint;
quantParam->min = tensor->quantParams[0]->min;
quantParam->max = tensor->quantParams[0]->max;
quantParam->narrowRange = tensor->quantParams[0]->narrowRange;
quantParam->numBits = tensor->quantParams[0]->numBits;
quantParam->inited = tensor->quantParams[0]->inited;
param_value->set_quant_param(quantParam);
}
parameter->set_default_param(param_value);
AddNode(i, parameter);
}
@ -77,6 +88,16 @@ int AnfImporterFromMetaGraphT::ConverterCNode() {
flag = true;
}
auto primTValue = std::make_shared<PrimitiveTValue>(cNode->primitive.release());
// add quant parameter
if (cNode->quantType == schema::QuantType_AwareTrainning || cNode->quantType == schema::QuantType_PostTraining) {
primTValue->SetQuantType(cNode->quantType);
for (int index : cNode->inputIndex) {
primTValue->AddInputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0]));
}
for (int index : cNode->outputIndex) {
primTValue->AddOutputQuantParam(*(meta_graph_->allTensors[index]->quantParams[0]));
}
}
cNode->primitive = nullptr;
auto value_node = NewValueNode(primTValue);

@ -28,7 +28,7 @@
namespace mindspore {
namespace lite {
OpDefCopyer GetSimpleOpCopyer() {
return [](std::unique_ptr<CNodeT> &inCNode) -> std::unique_ptr<CNodeT> {
return [](CNodeT *inCNode) -> std::unique_ptr<CNodeT> {
std::unique_ptr<CNodeT> newCNode(new CNodeT);
newCNode->name = inCNode->name;
@ -421,9 +421,13 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
}
preTensor->refCount = 0;
preTensor->data.clear();
if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
preTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
}
graphT->allTensors.emplace_back(std::move(toAddTensor));
size_t toAddTensorIdx = graphT->allTensors.size() - 1;
auto toAddNode = opDefCopyer(toAddNodeIn);
auto toAddNode = opDefCopyer(toAddNodeIn.get());
if (toAddNode == nullptr) {
MS_LOG(ERROR) << "copy toAddNodeIn failed";
*errorCode = RET_NULL_PTR;
@ -456,9 +460,13 @@ NodeIter InsertNodeBefore(schema::MetaGraphT *graphT, NodeIter existNodeIter, si
// MS_LOG(ERROR)("Copy TensorT failed");
return graphT->nodes.end();
}
if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
preTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
}
graphT->allTensors.emplace_back(std::move(toAddTensor));
size_t toAddTensorIdx = graphT->allTensors.size() - 1;
auto toAddNode = opDefCopyer(toAddNodeIn);
auto toAddNode = opDefCopyer(toAddNodeIn.get());
if (toAddNode == nullptr) {
// MS_LOG(ERROR)("copy toAddNodeIn failed");
*errorCode = RET_NULL_PTR;
@ -505,9 +513,13 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
*errorCode = RET_NULL_PTR;
return graphT->nodes.end();
}
if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
postTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
}
graphT->allTensors.emplace_back(std::move(toAddTensor));
size_t toAddTensorIdx = graphT->allTensors.size() - 1;
auto toAddNode = opDefCopyer(toAddNodeIn);
auto toAddNode = opDefCopyer(toAddNodeIn.get());
if (toAddNode == nullptr) {
// MS_LOG(ERROR)("copy toAddNodeIn failed");
*errorCode = RET_NULL_PTR;
@ -540,9 +552,13 @@ NodeIter InsertNodeAfter(schema::MetaGraphT *graphT, NodeIter existNodeIter, siz
*errorCode = RET_NULL_PTR;
return graphT->nodes.end();
}
if (toAddNodeIn->primitive->value.type == schema::PrimitiveType_QuantDTypeCast) {
postTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->srcT;
toAddTensor->dataType = toAddNodeIn->primitive->value.AsQuantDTypeCast()->dstT;
}
graphT->allTensors.emplace_back(std::move(toAddTensor));
size_t toAddTensorIdx = graphT->allTensors.size() - 1;
auto toAddNode = opDefCopyer(toAddNodeIn);
auto toAddNode = opDefCopyer(toAddNodeIn.get());
if (toAddNode == nullptr) {
// MS_LOG(ERROR)("copy toAddNodeIn failed");
*errorCode = RET_NULL_PTR;

@ -36,7 +36,7 @@ enum InsertPlace { kBefore, kAfter };
using NodeIter = std::vector<std::unique_ptr<schema::CNodeT>>::iterator;
using OpDefCopyer = std::function<std::unique_ptr<schema::CNodeT>(std::unique_ptr<schema::CNodeT> &)>;
using OpDefCopyer = std::function<std::unique_ptr<schema::CNodeT> (schema::CNodeT *)>;
OpDefCopyer GetSimpleOpCopyer();

@ -19,8 +19,29 @@
#include "tools/common/tensor_util.h"
#include "tools/common/graph_util.h"
namespace mindspore {
namespace lite {
namespace mindspore::lite {
std::unique_ptr<QuantParamT> GetTensorQuantParam(const std::unique_ptr<TensorT> &tensor) {
MS_ASSERT(tensor != nullptr);
auto &quantParams = tensor->quantParams;
if (!quantParams.empty()) {
return std::move(CopyQuantParamT(quantParams.front()));
} else {
return nullptr;
}
}
std::unique_ptr<schema::QuantParamT> CopyQuantParamT(const std::unique_ptr<schema::QuantParamT> &srcQuantParam) {
MS_ASSERT(srcQuantParam != nullptr);
std::unique_ptr<schema::QuantParamT> dstQuantParam = std::make_unique<schema::QuantParamT>();
dstQuantParam->inited = srcQuantParam->inited;
dstQuantParam->scale = srcQuantParam->scale;
dstQuantParam->zeroPoint = srcQuantParam->zeroPoint;
dstQuantParam->min = srcQuantParam->min;
dstQuantParam->max = srcQuantParam->max;
dstQuantParam->narrowRange = srcQuantParam->narrowRange;
dstQuantParam->numBits = srcQuantParam->numBits;
return std::move(dstQuantParam);
}
std::unique_ptr<QuantParamT> CopyQuantParamArrayT(const std::unique_ptr<QuantParamT> &srcQuantParamArray) {
MS_ASSERT(srcQuantParamArray != nullptr);
auto dstQuantParamArrayT = std::unique_ptr<QuantParamT>(new (std::nothrow) QuantParamT());
@ -164,6 +185,9 @@ std::unique_ptr<TensorT> CopyTensorDefT(const std::unique_ptr<TensorT> &oldTenso
newTensor->refCount = oldTensor->refCount;
newTensor->nodeType = oldTensor->nodeType;
newTensor->data = oldTensor->data;
if (!oldTensor->quantParams.empty()) {
newTensor->quantParams.emplace_back(std::move(GetTensorQuantParam(oldTensor)));
}
return std::move(newTensor);
}
@ -186,6 +210,4 @@ size_t GetShapeSize(const std::vector<int32_t> &shape) {
}
return shapeSize;
}
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite

@ -38,6 +38,9 @@ using schema::FusedBatchNormT;
using schema::Format_NCHW;
using schema::Format_NHWC;
using STATUS = int;
std::unique_ptr<QuantParamT> GetTensorQuantParam(const std::unique_ptr<TensorT> &tensor);
size_t GetElementSize(const TensorT &tensor);
size_t GetElementSize(const TypeId &dataType);
@ -50,6 +53,8 @@ std::unique_ptr<TensorT> CopyTensorDefT(const std::unique_ptr<TensorT> &);
size_t GetRefCount(schema::MetaGraphT *graphT, uint32_t tensorIdx);
std::unique_ptr<schema::QuantParamT> CopyQuantParamT(const std::unique_ptr<schema::QuantParamT> &srcQuantParam);
std::unique_ptr<schema::QuantParamT> \
CopyQuantParamArrayT(const std::unique_ptr<schema::QuantParamT> &srcQuantParamArray);

@ -101,6 +101,7 @@ target_link_libraries(converter_lite PRIVATE
node_mid
graph_pass_mid
fusion_mid
quantizer_mid
protobuf
quantizer_mid
pthread

@ -77,7 +77,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
MS_ASSERT(nullptr != modelParser);
const std::string modelFile = flag->modelFile;
const std::string weightFile = flag->weightFile;
auto meta_graph = modelParser->Parse(modelFile, weightFile);
auto meta_graph = modelParser->Parse(modelFile, weightFile, flag->quantType);
if (meta_graph == nullptr) {
MS_LOG(ERROR) << "Parse to metaGraph return nullptr";
return nullptr;
@ -118,6 +118,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
// transform
transform->SetGraphDef(meta_graph);
transform->CreateQuantizer(flag);
auto status = transform->Transform(*flag);
if (status != 0) {
MS_LOG(ERROR) << "FBTransform model failed " << status;
@ -125,6 +126,7 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
}
return meta_graph;
}
void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *flags) {
auto type = flags->quantType;
switch (type) {
@ -132,17 +134,18 @@ void Converter::CreateQuantizer(FuncGraphPtr funcGraph, const converter::Flags *
// mQuantizer.reset(new AwareQuantizer(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean));
break;
}
case mindspore::schema::QuantType_WeightQuant: {
MS_LOG(INFO) << "create WeightQuantizer!";
mQuantizer.reset(
new quant::WeightQuantizer(funcGraph, flags->quantSize, flags->convWeightQuantChannelThreshold, flags->bitNum));
break;
}
case mindspore::schema::QuantType_PostTraining: {
MS_LOG(INFO) << "create PostTrainningQuantizer!";
mQuantizer.reset(new quant::PostTrainingQuantizer(funcGraph, flags->configFile, 8));
break;
}
// case mindspore::schema::QuantType_WeightQuant: {
// MS_LOG(INFO) << "create WeightQuantizer!";
// mQuantizer.reset(
// new quant::WeightQuantizer(funcGraph, flags->quantSize, flags->convWeightQuantChannelThreshold,
// flags->bitNum));
// break;
// }
// case mindspore::schema::QuantType_PostTraining: {
// MS_LOG(INFO) << "create PostTrainningQuantizer!";
// mQuantizer.reset(new quant::PostTrainingQuantizer(funcGraph, flags->configFile, 8));
// break;
// }
case mindspore::schema::QuantType_QUANT_NONE:
MS_LOG(INFO) << "Not do quantization for model!";
break;

@ -14,8 +14,12 @@
* limitations under the License.
*/
#include <string>
#include "tools/converter/converter_flags.h"
#include <regex>
#include <string>
#include "ir/dtype/type_id.h"
namespace mindspore {
namespace lite {
@ -70,9 +74,11 @@ int Flags::Init(int argc, const char **argv) {
return 1;
}
if (this->inputInferenceTypeIn == "FLOAT") {
this->inputInferenceType = 0;
this->inputInferenceType = TypeId::kNumberTypeFloat;
} else if (this->inputInferenceTypeIn == "UINT8") {
this->inputInferenceType = 1;
this->inputInferenceType = TypeId::kNumberTypeUInt8;
} else if (this->inputInferenceTypeIn == "INT8") {
this->inputInferenceType = TypeId::kNumberTypeInt8;
} else {
std::cerr << "INPUT INVALID: inputInferenceType is invalid: %s", this->inputInferenceTypeIn.c_str();
return 1;

@ -19,6 +19,7 @@
#include <string>
#include "tools/common/flag_parser.h"
#include "ir/dtype/type_id.h"
#include "schema/inner/model_generated.h"
namespace mindspore {
@ -66,7 +67,7 @@ class Flags : public virtual mindspore::lite::FlagParser {
// used for parse aware trainning
std::string inputInferenceTypeIn;
// mindspore::predict::DataType inputInferenceType = DataType_DT_FLOAT;
int inputInferenceType = 0;
TypeId inputInferenceType = TypeId::kNumberTypeFloat;
std::string stdDev;
std::string mean;
// used for post-trainning-weight

@ -16,11 +16,13 @@
#include "tools/converter/graphdef_transform.h"
#include <iostream>
#include <memory>
#include <string>
#include "schema/model_generated.h"
#include "utils/log_adapter.h"
#include "src/common/op_utils.h"
#include "tools/converter/converter_flags.h"
#include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h"
#include "tools/converter/legacy_optimizer/fusion/conv_bn_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/conv_scale_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/conv_relu_fusion_pass.h"
@ -28,7 +30,7 @@
#include "tools/converter/legacy_optimizer/fusion/conv_biasadd_fusion_pass.h"
// #include "tools/converter/legacy_optimizer/fusion/matmul_biasadd_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/format_trans_fusion_pass.h"
// #include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
#include "tools/converter/legacy_optimizer/fusion/quant_cast_fusion_pass.h"
// #include "tools/converter/legacy_optimizer/fusion/batchnorm_fold_fusion_pass.h"
//
// #include "tools/converter/legacy_optimizer/const_fold/add_const_fold_pass.h"
@ -52,18 +54,45 @@
#include "tools/converter/legacy_optimizer/graph/isolated_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/unused_node_remove_pass.h"
#include "tools/converter/legacy_optimizer/graph/topological_sort_pass.h"
#include "tools/converter/quantizer/aware_quantizer.h"
#include "tools/converter/converter.h"
using std::string;
namespace mindspore {
namespace lite {
namespace mindspore::lite {
GraphDefTransform::GraphDefTransform() = default;
GraphDefTransform::~GraphDefTransform() = default;
void GraphDefTransform::SetGraphDef(schema::MetaGraphT *_dstDef) { graphDefT = _dstDef; }
void GraphDefTransform::CreateQuantizer(const converter::Flags *flags) {
auto type = flags->quantType;
switch (type) {
case QuantType::QuantType_AwareTrainning: {
MS_LOG(INFO) << "create AwareTrainningQuantizer!";
fbQuantizer =
std::make_unique<quant::AwareQuantizer>(graphDefT, flags->inputInferenceTypeIn, flags->stdDev, flags->mean);
break;
}
// case QuantType::QuantType_WeightQuant: {
// MS_LOGI("create WeightQuantizer!");
// mQuantizer.reset(new WeightQuantizer(graphDefT, flags->quantSize));
// break;
// }
// case QuantType_PostTraining: {
// MS_LOGI("create PostTrainningQuantizer!");
// mQuantizer.reset(new PostTrainingQuantizer(graphDefT, flags->configFile));
// break;
// }
// case QuantType::QuantType_QUANT_NONE:
// MS_LOGD("Not do quantization for model!");
// break;
default:
// MS_LOGI("will support quantizer type %s in the future!", flags->quantTypeIn.c_str());
break;
}
}
int GraphDefTransform::Transform(const converter::Flags &ctx) {
STATUS status;
// // constant folding
@ -133,6 +162,53 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
}
{
Optimizer unusedOpRemoveOptimizer;
unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass());
unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass());
status = unusedOpRemoveOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed";
return status;
}
}
// topological sorting
{
Optimizer topologicalOptimizer;
topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
status = topologicalOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
return status;
}
}
// generate and infer quant parameters
{
if (mQuantizer != nullptr) {
Optimizer topologicalOptimizer;
topologicalOptimizer.AddPass(new (std::nothrow) TopologicalSortPass());
status = topologicalOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run topologicalOptimizer graphPasses Failed";
return status;
}
if (!(this->graphDefT->fmkType == converter::FmkType_TF &&
this->graphDefT->nodes.front()->quantType == QuantType::QuantType_AwareTrainning)) {
status = mQuantizer->GenerateQuantParam();
if (status != RET_OK) {
MS_LOG(ERROR) << "GenerateQuantParam failed";
return status;
}
status = mQuantizer->DetermineNodeQuantType();
if (status != RET_OK) {
MS_LOG(ERROR) << "DetermineNodeQuant failed";
}
}
}
}
// format transform
if (ctx.formatTrans) {
Optimizer formatTransOptimizer;
@ -156,13 +232,30 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
}
{
Optimizer unusedOpRemoveOptimizer;
unusedOpRemoveOptimizer.AddPass(new UnusedNodeRemovePass());
unusedOpRemoveOptimizer.AddPass(new IsolatedNodeRemovePass());
status = unusedOpRemoveOptimizer.Run(graphDefT);
// do quantization
if (fbQuantizer != nullptr) {
status = fbQuantizer->DoQuantize();
if (status != RET_OK) {
MS_LOG(ERROR) << "DoQuantize failed!";
return status;
}
}
// insert quantNode and deQuantNode
if (ctx.quantType == QuantType_AwareTrainning) {
Optimizer quantNodeOptimizer;
auto dTypeTransPass = new (std::nothrow) DTypeTransPass();
if (dTypeTransPass == nullptr) {
MS_LOG(ERROR) << "new dTypeTransPass failed";
return RET_ERROR;
}
dTypeTransPass->SetInputDataDType(ctx.inputInferenceType);
quantNodeOptimizer.AddPass(dTypeTransPass);
quantNodeOptimizer.AddPass(new (std::nothrow) QuantCastFusionPass());
quantNodeOptimizer.AddPass(new (std::nothrow) IsolatedNodeRemovePass());
status = quantNodeOptimizer.Run(graphDefT);
if (status != RET_OK && status != RET_NO_CHANGE) {
MS_LOG(ERROR) << "Run unusedOpRemoveOptimizer graphPasses Failed";
MS_LOG(ERROR) << "Run quantNodeOptimizer graphPasses Failed";
return status;
}
}
@ -178,6 +271,4 @@ int GraphDefTransform::Transform(const converter::Flags &ctx) {
}
return RET_OK;
}
} // namespace lite
} // namespace mindspore
} // namespace mindspore::lite

@ -17,8 +17,9 @@
#ifndef MS_GRAPHDEF_TRANSFORM_H
#define MS_GRAPHDEF_TRANSFORM_H
#include <memory>
#include "tools/converter/optimizer.h"
// #include "quantizer/quantizer.h"
#include "tools/converter/quantizer/quantizer.h"
#include "schema/inner/model_generated.h"
#include "tools/common/storage.h"
#include "tools/converter/converter_flags.h"
@ -42,7 +43,8 @@ class GraphDefTransform {
schema::MetaGraphT *graphDefT = nullptr;
Optimizer *optimizer = nullptr;
// std::unique_ptr<Quantizer> mQuantizer;
std::unique_ptr<quant::Quantizer> mQuantizer;
std::unique_ptr<quant::FbQuantizer> fbQuantizer;
};
} // namespace lite
} // namespace mindspore

@ -53,7 +53,7 @@ class MatMulBiasAddFusionPass : public FusionPass {
bool transB = false;
size_t id = 0;
OpDefCopyer TransposeOpCopyer = [](const std::unique_ptr<CNodeT> &inOpDef) -> std::unique_ptr<CNodeT> {
OpDefCopyer TransposeOpCopyer = [](CNodeT *inOpDef) -> std::unique_ptr<CNodeT> {
std::unique_ptr<CNodeT> newOpDef(new (std::nothrow) CNodeT);
if (newOpDef == nullptr) {
MS_LOG(ERROR) << "new OpDefT failed";

@ -1,5 +1,6 @@
add_library(graph_pass_mid OBJECT
${CMAKE_CURRENT_SOURCE_DIR}/format_trans_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/dtype_trans_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/isolated_node_remove_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/model_input_format_preprocess_pass.cc
${CMAKE_CURRENT_SOURCE_DIR}/topological_sort_pass.cc

@ -0,0 +1,235 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tools/converter/legacy_optimizer/graph/dtype_trans_pass.h"
#include <string>
#include "tools/common/converter_op_utils.h"
#include "tools/common/node_util.h"
#include "src/common/common.h"
#include "src/common/utils.h"
namespace mindspore {
namespace lite {
#define kMinInputNum 1
#define kOutputNum 1
STATUS DTypeTransPass::Run(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
auto status = DoModelInputDTypeTrans(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoModelInputDTypeTrans error: " << status;
return status;
}
status = DoModelOutputDTypeTrans(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoModelOutputDTypeTrans error: " << status;
return status;
}
status = DoNodeInoutDTypeTrans(graph);
if (status != RET_OK) {
MS_LOG(ERROR) << "DoNodeInoutDTypeTrans error: " << status;
return status;
}
return RET_OK;
}
STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
// modify inputTensor first
auto &graphInIdxes = graph->inputIndex;
for (auto graphInIdx : graphInIdxes) {
MS_ASSERT(graph->allTensors.size() > graphInIdx);
auto &graphInTensor = graph->allTensors.at(graphInIdx);
graphInTensor->dataType = TypeId::kNumberTypeUInt8;
}
if (this->inputDataDType == TypeId::kNumberTypeInt8) {
return RET_OK;
}
if (this->inputDataDType != TypeId::kNumberTypeFloat && this->inputDataDType != TypeId::kNumberTypeUInt8) {
MS_LOG(ERROR) << "Invalid inputDataType: " << this->inputDataDType;
return RET_ERROR;
}
// insert fp2int8 node
for (auto graphInIdx : graphInIdxes) {
MS_ASSERT(graphInIdx < graph->allTensors.size());
auto &tensor = graph->allTensors.at(graphInIdx);
if (tensor->dims.size() != kNHWCDimNumber) {
continue;
}
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto &node = *iter;
auto nodeName = node->name;
for (size_t inputIndexIdx = 0; inputIndexIdx < node->inputIndex.size(); inputIndexIdx++) {
if (node->inputIndex.at(inputIndexIdx) == graphInIdx) {
STATUS status = RET_OK;
// insert dtype cast node between input tensor and input node
if (inputDataDType == TypeId::kNumberTypeFloat) {
iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, kFP32ToInt8, &status);
} else {
iter = InsertDTypeTransNode(graph, iter, kBefore, inputIndexIdx, kUInt8ToInt8, &status);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertDTypeTransNode before " << nodeName.c_str() << " failed";
return status;
}
}
}
}
}
return RET_OK;
}
STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
if (inputDataDType == TypeId::kNumberTypeInt8) {
return RET_OK;
}
MS_ASSERT(inputDataDType == TypeId::kNumberTypeFloat);
auto &graphOutIdxes = graph->outputIndex;
for (auto graphOutIdx : graphOutIdxes) {
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
auto &node = *iter;
auto nodeName = node->name;
MS_ASSERT(node != nullptr);
for (size_t outputIndexIdx = 0; outputIndexIdx < node->outputIndex.size(); outputIndexIdx++) {
if (node->outputIndex.at(outputIndexIdx) == graphOutIdx) {
// insert transNode
STATUS status = RET_OK;
if (inputDataDType == TypeId::kNumberTypeFloat) {
iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToFP32, &status);
} else {
iter = InsertDTypeTransNode(graph, iter, kAfter, outputIndexIdx, kInt8ToUInt8, &status);
}
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertDTypeTransNode after " << nodeName.c_str() << " failed";
return status;
}
break;
}
}
}
}
return RET_OK;
}
STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) {
MS_ASSERT(graph != nullptr);
// insert transNode before and after existNode
for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) {
if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTrainning) {
continue;
}
auto &node = *iter;
if (GetCNodeTType(**iter) == PrimitiveType_QuantDTypeCast) {
continue;
}
bool needInsertPost = true;
if (GetCNodeTType(**iter) == PrimitiveType_Shape) {
needInsertPost = false;
}
auto nodeName = node->name;
if (node->inputIndex.size() < kMinInputNum) {
MS_LOG(ERROR) << "Op " << nodeName.c_str() << " should have " << kMinInputNum << " input tensor at least";
return RET_ERROR;
}
STATUS status;
// insert pre
for (size_t i = 0; i < (*iter)->inputIndex.size(); i++) {
MS_ASSERT(graph->allTensors.size() > (*iter)->inputIndex.at(i));
auto &preTensor = graph->allTensors.at((*iter)->inputIndex.at(i));
auto &graphInIdxes = graph->inputIndex;
if (preTensor->nodeType == NodeType_ValueNode && !IsContain(graphInIdxes, (*iter)->inputIndex.at(i))) {
continue;
}
iter = InsertDTypeTransNode(graph, iter, kBefore, i, kInt8ToFP32, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertInt8ToFloat32Node before " << nodeName.c_str() << " failed";
return RET_ERROR;
}
}
if (needInsertPost) {
for (size_t i = 0; i < (*iter)->outputIndex.size(); i++) {
iter = InsertDTypeTransNode(graph, iter, kAfter, i, kFP32ToInt8, &status);
if (status != RET_OK) {
MS_LOG(ERROR) << "InsertFloat32ToUint8Node after " << nodeName.c_str() << " failed";
return RET_ERROR;
}
}
}
(*iter)->quantType = QuantType_QUANT_NONE;
}
return RET_OK;
}
NodeIter DTypeTransPass::InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place,
size_t inoutIdx, DTypeTransNodeType nodeType, STATUS *errorCode) {
MS_ASSERT((*existNodeIter) != nullptr);
auto existNodeName = (*existNodeIter)->name;
std::string tileName;
if (place == kBefore) {
tileName = existNodeName + "_pre";
} else {
tileName = existNodeName + "_post";
}
auto transNode = std::unique_ptr<CNodeT>(new (std::nothrow) CNodeT);
if (transNode == nullptr) {
MS_LOG(ERROR) << "new TransNode failed";
*errorCode = RET_ERROR;
return graph->nodes.end();
}
auto quantDTypeCastParam = new (std::nothrow) QuantDTypeCastT;
if (quantDTypeCastParam == nullptr) {
MS_LOG(ERROR) << "new quantDTypeCastParam failed";
*errorCode = RET_ERROR;
return graph->nodes.end();
}
transNode->primitive = std::make_unique<schema::PrimitiveT>();
transNode->primitive->value.value = quantDTypeCastParam;
transNode->primitive->value.type = PrimitiveType_QuantDTypeCast;
transNode->quantType = QuantType_AwareTrainning;
if (nodeType == kInt8ToFP32) {
quantDTypeCastParam->srcT = TypeId::kNumberTypeInt8;
quantDTypeCastParam->dstT = TypeId::kNumberTypeFloat32;
transNode->name = "int8toft32_" + tileName + std::to_string(id++);
} else if (nodeType == kFP32ToInt8) {
quantDTypeCastParam->srcT = TypeId::kNumberTypeFloat32;
quantDTypeCastParam->dstT = TypeId::kNumberTypeInt8;
transNode->name = "ft32toint8_" + tileName + std::to_string(id++);
} else if (nodeType == kUInt8ToInt8) {
quantDTypeCastParam->srcT = TypeId::kNumberTypeUInt8;
quantDTypeCastParam->dstT = TypeId::kNumberTypeInt8;
transNode->name = "uint8toint8_" + tileName + std::to_string(id++);
} else if (nodeType == kInt8ToUInt8) {
quantDTypeCastParam->srcT = TypeId::kNumberTypeInt8;
quantDTypeCastParam->dstT = TypeId::kNumberTypeUInt8;
transNode->name = "int8touint8_" + tileName + std::to_string(id++);
}
transNode->primitive->value.value = quantDTypeCastParam;
return InsertNode(graph, existNodeIter, place, inoutIdx, std::move(transNode), errorCode, castOpCopyer);
}
void DTypeTransPass::SetInputDataDType(TypeId dataType) { this->inputDataDType = dataType; }
} // namespace lite
} // namespace mindspore

@ -0,0 +1,81 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_PREDICT_DTYPE_TRANS_PASS_H
#define MINDSPORE_PREDICT_DTYPE_TRANS_PASS_H
#include <memory>
#include <utility>
#include "tools/converter/optimizer.h"
#include "tools/common/graph_util.h"
#include "tools/converter/converter_flags.h"
#include "tools/common/tensor_util.h"
namespace mindspore {
namespace lite {
enum DTypeTransNodeType { kInt8ToFP32, kFP32ToInt8, kUInt8ToInt8, kInt8ToUInt8 };
class DTypeTransPass : public GraphPass {
public:
DTypeTransPass() : id(0) {}
~DTypeTransPass() override = default;
STATUS Run(schema::MetaGraphT *graph) override;
void SetInputDataDType(TypeId dataType);
private:
STATUS DoModelInputDTypeTrans(schema::MetaGraphT *graph);
STATUS DoModelOutputDTypeTrans(schema::MetaGraphT *graph);
STATUS DoNodeInoutDTypeTrans(schema::MetaGraphT *graph);
NodeIter InsertDTypeTransNode(schema::MetaGraphT *graph, NodeIter existNodeIter, InsertPlace place, size_t inoutIdx,
DTypeTransNodeType nodeType, STATUS *errorCode);
private:
size_t id;
TypeId inputDataDType = TypeId::kNumberTypeFloat;
OpDefCopyer castOpCopyer = [](schema::CNodeT *inCNode) -> std::unique_ptr<schema::CNodeT> {
std::unique_ptr<schema::CNodeT> newCNode(new (std::nothrow) schema::CNodeT);
if (newCNode == nullptr) {
MS_LOG(ERROR) << "new CNodeT failed";
return nullptr;
}
newCNode->name = inCNode->name;
newCNode->quantType = inCNode->quantType;
newCNode->primitive = std::make_unique<schema::PrimitiveT>();
newCNode->primitive->value.type = inCNode->primitive->value.type;
auto oldQuantDTypeCastParam = inCNode->primitive->value.AsQuantDTypeCast();
auto QuantDTypeCastParam = new (std::nothrow) QuantDTypeCastT;
if (QuantDTypeCastParam == nullptr) {
MS_LOG(ERROR) << "new QuantDTypeCast failed";
return nullptr;
}
QuantDTypeCastParam->srcT = oldQuantDTypeCastParam->srcT;
QuantDTypeCastParam->dstT = oldQuantDTypeCastParam->dstT;
newCNode->primitive->value.value = QuantDTypeCastParam;
return std::move(newCNode);
};
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_PREDICT_DTYPE_TRANS_PASS_H

@ -209,6 +209,9 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
return 0;
}
// inference needed filterFormat:
// conv deconv depth dedepth
// uint8 KHWC KHWC KHWC KHWC
int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
MS_ASSERT(graphNode != nullptr);
auto &subGraph = graphNode->subGraph;
@ -227,7 +230,7 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
auto &weightTensor = subGraph->allTensors[weightIndex];
MS_ASSERT(weightTensor->dataType == kNumberTypeInt8); // DataType_DT_FLOAT
STATUS status = RET_OK;
if (opType == schema::PrimitiveType_Conv2D) { // weight should be HWCK
if (opType == schema::PrimitiveType_Conv2D) { // weight should be KHWC
if (weightTensor->format == schema::Format_KCHW) { // from caffe
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format
@ -236,58 +239,51 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
} else {
MS_LOG(DEBUG) << "--weight tensor index: %d, format: %d, datatype: " << weightIndex << weightTensor->format
<< weightTensor->dataType;
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK);
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
}
} else if (weightTensor->format == schema::Format_KHWC) { // from onnx
return RET_OK;
// if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
// status = TransFilterFormat<int8_t>(weightTensor.get(), kKHWC2HWCK);
// } else {
// status = TransFilterFormat<float>(weightTensor.get(), kKHWC2HWCK);
// }
} else if (weightTensor->format == schema::Format_HWCK) { // from tf
return 0;
} else {
} else if (weightTensor->format != schema::Format_KHWC) {
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
return -1;
}
if (status == 0) {
node->primitive->value.AsConv2D()->format = schema::Format_NHWC;
weightTensor->format = schema::Format_HWCK;
weightTensor->format = schema::Format_KHWC;
} else {
MS_LOG(WARNING) << "TransFilter %sToHWCK failed, node : "
<< (weightTensor->format == schema::Format_KCHW ? "KCHW" : "KHWC"),
node->name.c_str();
MS_LOG(WARNING) << "TransFilter %sToKHWC failed, node : "
<< (weightTensor->format == schema::Format_KHWC ? "KHWC" : "KCHW") << node->name.c_str();
// todo(00445839): consider varible weight condition
}
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be HWCK
} else if (opType == schema::PrimitiveType_DepthwiseConv2D) { // weight should be KHWC
if (weightTensor->format == schema::Format_CKHW) { // from caffe
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
MS_LOG(DEBUG) << "**weight tensor index: %d, format: %d, datatype: " << weightIndex, weightTensor->format,
weightTensor->dataType;
status = TransFilterFormat<uint8_t>(weightTensor.get(), kCKHW2HWCK);
MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
<< "datatype: " << weightTensor->dataType;
status = TransFilterFormat<int8_t>(weightTensor.get(), kCKHW2KHWC);
} else if (weightTensor->dataType == kNumberTypeUInt8) {
MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
<< "datatype: " << weightTensor->dataType;
status = TransFilterFormat<uint8_t>(weightTensor.get(), kCKHW2KHWC);
} else {
MS_LOG(DEBUG) << "--weight tensor index: %d, format: %d, datatype: " << weightIndex, weightTensor->format,
weightTensor->dataType;
status = TransFilterFormat<float>(weightTensor.get(), kCKHW2HWCK);
MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
<< "datatype: " << weightTensor->dataType;
status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC);
}
} else if (weightTensor->format == schema::Format_HWCK) { // from tf
return 0;
} else if (weightTensor->format == schema::Format_CHWK) { // from onnx
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
if (weightTensor->dataType == kNumberTypeInt8) {
MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
<< "datatype: " << weightTensor->dataType;
status = TransFilterFormat<int8_t>(weightTensor.get(), kCHWK2KHWC);
MS_LOG(DEBUG) << node->name << " weight trans format: CHWK->KHWC";
} else if (weightTensor->dataType == kNumberTypeUInt8) {
MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
<< "datatype: " << weightTensor->dataType;
status = TransFilterFormat<uint8_t>(weightTensor.get(), kCHWK2KHWC);
} else {
MS_LOG(DEBUG) << "**weight tensor index: " << weightIndex << "format: " << weightTensor->format
<< "datatype: " << weightTensor->dataType;
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
}
} else if (weightTensor->format == schema::Format_KCHW) {
if (weightTensor->dataType == kNumberTypeInt8) { // DataType_DT_UINT8) {
status = TransFilterFormat<uint8_t>(weightTensor.get(), kKCHW2HWCK);
} else {
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2HWCK);
}
} else {
} else if (weightTensor->format != schema::Format_KHWC) {
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
return -1;
}
@ -295,14 +291,13 @@ int WeightFormatPass::QuantDataFormatTrans(GraphNode *graphNode) {
node->primitive->value.AsDepthwiseConv2D()->format = schema::Format_NHWC;
weightTensor->format = schema::Format_KHWC;
} else {
MS_LOG(WARNING) << "TransFilter %ToHWCK failed, node : "
<< (weightTensor->format == schema::Format_CHWK ? "CHWK" : "CKHW"),
node->name.c_str();
MS_LOG(WARNING) << "TransFilter" << (weightTensor->format == schema::Format_KHWC ? "KHWC" : "CKHW")
<< "To KHWC failed, node : " << node->name.c_str();
// todo(00445839): consider varible weight condition
}
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be HWCK
node->primitive->value.AsDeConv2D()->format = schema::Format_NCHW;
weightTensor->format = schema::Format_CKHW;
} else { // weight should be HWCK
node->primitive->value.AsDeConv2D()->format = schema::Format_NHWC;
weightTensor->format = schema::Format_KHWC;
}
return 0;
}
@ -354,7 +349,7 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
if (graphNode->subGraph->fmkType == converter::FmkType_MS) {
weightTensor->format = schema::Format_CKHW;
}
if (weightTensor->format == schema::Format_CKHW) { // from caffe or onnx or ms
if (weightTensor->format == schema::Format_CKHW) { // from caffe or onnx or ms
status = TransFilterFormat<float>(weightTensor.get(), kCKHW2KHWC);
} else if (weightTensor->format == schema::Format_KCHW) {
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
@ -374,8 +369,8 @@ int WeightFormatPass::NonQuantDataFormatTrans(GraphNode *graphNode) {
} else if (opType == schema::PrimitiveType_DeConv2D) { // weight should be KHWC
if (weightTensor->format == schema::Format_KCHW) { // from caffe or onnx or ms
status = TransFilterFormat<float>(weightTensor.get(), kKCHW2KHWC);
} else if (weightTensor->format == schema::Format_CHWK) { // from tf
status = TransFilterFormat<float>(weightTensor.get(), kCHWK2KHWC);
} else if (weightTensor->format == schema::Format_KHWC) { // from tf
status = RET_OK;
} else {
MS_LOG(ERROR) << "Unsupported weightTensor format: " << weightTensor->format;
return -1;

@ -40,7 +40,8 @@ class ModelParser {
}
return Fb2Anf(Parse(modelFile, weightFile));
}
virtual schema::MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile) = 0;
virtual schema::MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) = 0;
public:
static FuncGraphPtr Fb2Anf(schema::MetaGraphT *meta_graph) {

@ -31,7 +31,8 @@ CaffeModelParser::~CaffeModelParser() {}
const std::set<std::string> CaffeModelParser::skipedLayerType = {"Dropout"};
schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile, const std::string &weightFile) {
schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) {
std::unique_ptr<schema::MetaGraphT> graph(new schema::MetaGraphT());
if (ValidateFileStr(modelFile, ".prototxt") != RET_OK) {
@ -91,7 +92,7 @@ schema::MetaGraphT *CaffeModelParser::Parse(const std::string &modelFile, const
// ConvertCaffeBatchNorm(graph.get());
return graph.release();
// return Fb2Anf(graph.release());
// return Fb2Anf(graph.release());
}
STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op,

@ -33,7 +33,8 @@ class CaffeModelParser : public ModelParser {
virtual ~CaffeModelParser();
MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile) override;
MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) override;
private:
void ConvertCaffeBatchNorm(MetaGraphT *meta_graphT);

@ -37,7 +37,8 @@ class OnnxModelParser : public ModelParser {
public:
OnnxModelParser();
virtual ~OnnxModelParser();
MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile) override;
MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) override;
private:
TypeId GetDateTypeFromOnnx(onnx::TensorProto_DataType onnx_type);

@ -20,7 +20,6 @@
#include "tools/common/graph_util.h"
#include "tools/common/storage.h"
#include "flatbuffers/flatbuffers.h"
#include "utils/log_adapter.h"
#include "src/common/file_utils.h"
namespace mindspore {
@ -60,42 +59,64 @@ STATUS TfliteModelParser::SetAllTensors(const TensorCache &tensor_cache, schema:
}
return RET_OK;
}
void TfliteModelParser::SetMsTensorFromTflite(const std::unique_ptr<tflite::TensorT> &tflite_tensor,
schema::TensorT *tensor) {
std::unique_ptr<schema::QuantParamT> quant_param(new QuantParamT());
if (!tflite_tensor->quantization->scale.empty()) {
quant_param->scale = tflite_tensor->quantization->scale[0];
}
STATUS TfliteModelParser::ParseTfliteQuantParams(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::OperatorT> &tflite_op) {
auto dst_op = tfliteOpMap.at(tflite_op.get());
if (!tflite_tensor->quantization->zero_point.empty()) {
quant_param->zeroPoint = tflite_tensor->quantization->zero_point[0];
}
std::vector<uint32_t> quant_params_index;
quant_params_index.insert(quant_params_index.end(), tflite_op->inputs.begin(), tflite_op->inputs.end());
quant_params_index.insert(quant_params_index.end(), tflite_op->outputs.begin(), tflite_op->outputs.end());
for (const auto &index : quant_params_index) {
const auto &tflite_tensor = tflite_subgraph->tensors[index];
if (tflite_tensor == nullptr) {
MS_LOG(ERROR) << "tensor with id = " << index <<" is null";
return RET_ERROR;
}
// change quant param min to 0 to fit ms-lite ops
if (tensor->dataType == TypeId::kNumberTypeInt8) {
quant_param->zeroPoint = quant_param->zeroPoint - 128;
}
if (!tflite_tensor->quantization->min.empty()) {
quant_param->min = tflite_tensor->quantization->min[0];
}
if (!tflite_tensor->quantization->max.empty()) {
quant_param->max = tflite_tensor->quantization->max[0];
}
quant_param->inited = true;
tensor->quantParams.clear();
tensor->quantParams.emplace_back(std::move(quant_param));
}
STATUS TfliteModelParser::ParseTfliteQuantParams(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
schema::CNodeT *op, TensorCache *tensor_cache) {
MS_ASSERT(op->outputIndex.size() == tflite_op->outputs.size());
for (size_t i = 0; i < tflite_op->inputs.size() && i < op->inputIndex.size(); i++) {
const auto &tflite_tensor = tflite_subgraph->tensors[tflite_op->inputs.at(i)];
if (tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() &&
tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty()) {
continue;
}
std::unique_ptr<schema::QuantParamT> quant_param(new schema::QuantParamT());
if (!tflite_tensor->quantization->scale.empty()) {
quant_param->scale = tflite_tensor->quantization->scale[0];
}
if (!tflite_tensor->quantization->zero_point.empty()) {
quant_param->zeroPoint = tflite_tensor->quantization->zero_point[0];
auto &inTensor = tensor_cache->GetCachedTensor().at(op->inputIndex.at(i));
if (inTensor == nullptr) {
MS_LOG(ERROR) << "Parse tflite quant params inTensor is null";
return RET_NULL_PTR;
}
if (!tflite_tensor->quantization->min.empty()) {
quant_param->min = tflite_tensor->quantization->min[0];
SetMsTensorFromTflite(tflite_tensor, inTensor);
}
for (size_t i = 0; i < tflite_op->outputs.size() && i < op->outputIndex.size(); i++) {
const auto &tflite_tensor = tflite_subgraph->tensors[tflite_op->outputs.at(i)];
if (tflite_tensor->quantization->scale.empty() && tflite_tensor->quantization->zero_point.empty() &&
tflite_tensor->quantization->min.empty() && tflite_tensor->quantization->max.empty()) {
continue;
}
if (!tflite_tensor->quantization->max.empty()) {
quant_param->max = tflite_tensor->quantization->max[0];
auto &outTensor = tensor_cache->GetCachedTensor().at(op->outputIndex.at(i));
if (outTensor == nullptr) {
MS_LOG(ERROR) << "Parse tflite quant params outTensor is null";
return RET_NULL_PTR;
}
SetMsTensorFromTflite(tflite_tensor, outTensor);
}
dst_op->quantType = schema::QuantType_AwareTrainning;
return RET_OK;
}
@ -105,11 +126,15 @@ STATUS TfliteModelParser::SetOpOutputIdx(const std::unique_ptr<tflite::SubGraphT
for (const auto &index : tflite_op->outputs) {
const auto &tflite_tensor = tflite_subgraph->tensors[index];
if (tflite_tensor == nullptr) {
MS_LOG(ERROR) << "tensor with id = " << index <<" is null";
MS_LOG(ERROR) << "tensor with id = " << index << " is null";
return RET_ERROR;
}
std::unique_ptr<schema::TensorT> tensor(new schema::TensorT());
tensor->dataType = GetTfliteDataType(tflite_tensor->type);
// change dataType to int8 to fit ms-lite op
if (tensor->dataType == TypeId::kNumberTypeUInt8) {
tensor->dataType = TypeId::kNumberTypeInt8;
}
tensor->dims = tflite_tensor->shape;
tensor->nodeType = schema::NodeType_Parameter;
auto opOutputIndex = tensorCache->AddTensor(tflite_tensor->name, tensor.release(), OP_OUTPUT);
@ -120,7 +145,8 @@ STATUS TfliteModelParser::SetOpOutputIdx(const std::unique_ptr<tflite::SubGraphT
STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::OperatorT> &tflite_op, TensorCache *tensorCache) {
const std::unique_ptr<tflite::OperatorT> &tflite_op, schema::CNodeT *op,
TensorCache *tensor_cache) {
auto op_type = GetTfliteNodeType(tflite_op, tflite_model);
std::vector<int32_t> op_inputs(tflite_op->inputs);
if (op_type == "DeConv2D") {
@ -130,12 +156,11 @@ STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &t
for (const auto &tflite_index : op_inputs) {
const auto &tflite_tensor = tflite_subgraph->tensors[tflite_index];
if (tflite_tensor == nullptr) {
MS_LOG(ERROR) << "tensor with id = " << tflite_index <<" is null";
MS_LOG(ERROR) << "tensor with id = " << tflite_index << " is null";
return RET_ERROR;
}
auto tensor_name = tflite_tensor->name;
auto op = tfliteOpMap[tflite_op.get()];
unsigned int index = tensorCache->FindTensor(tensor_name);
unsigned int index = tensor_cache->FindTensor(tensor_name);
if (index != -1) {
op->inputIndex.push_back(index);
}
@ -146,19 +171,20 @@ STATUS TfliteModelParser::SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &t
STATUS TfliteModelParser::ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
schema::MetaGraphT *subGraph,
mindspore::lite::TensorCache *tensorCache) {
schema::MetaGraphT *subGraph, mindspore::lite::TensorCache *tensorCache,
const QuantType &quantType) {
auto i = 0;
for (const auto &tflite_op : tflite_subgraph->operators) {
auto opType = GetTfliteNodeType(tflite_op, tflite_model);
std::unique_ptr<schema::CNodeT> op(new schema::CNodeT);
op->name = opType + "-" + std::to_string(i++);
op->quantType = quantType;
MS_LOG(INFO) << "parse op: " << op->name.c_str();
auto node_parser = TfliteNodeParserRegistry::GetInstance()->GetNodeParser(opType);
if (node_parser == nullptr) {
MS_LOG(ERROR) << "cannot find node parser, opType: "<< opType.c_str();
MS_LOG(ERROR) << "cannot find node parser, opType: " << opType.c_str();
continue;
// return RET_NULL_PTR;
}
@ -172,7 +198,19 @@ STATUS TfliteModelParser::ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_
status = SetOpOutputIdx(tflite_subgraph, tflite_op, op.get(), tensorCache);
if (status != RET_OK) {
MS_LOG(ERROR) << "Set Op "<< op->name.c_str() << " Output Index Failed!";
MS_LOG(ERROR) << "set op " << opType.c_str() << " output index failed";
return RET_ERROR;
}
status = SetOpInputIdx(tflite_model, tflite_subgraph, tflite_op, op.get(), tensorCache);
if (status != RET_OK) {
MS_LOG(ERROR) << "set op " << opType.c_str() << " input index failed";
return RET_ERROR;
}
status = ParseTfliteQuantParams(tflite_subgraph, tflite_op, op.get(), tensorCache);
if (status != RET_OK) {
MS_LOG(ERROR) << "parse op " << opType.c_str() << " quant parameters failed";
return RET_ERROR;
}
@ -189,8 +227,10 @@ void TfliteModelParser::SetInputTensor(const std::unique_ptr<tflite::SubGraphT>
const auto &tflite_tensor = tflite_subgraph->tensors[index];
std::unique_ptr<schema::TensorT> tensor(new schema::TensorT());
tensor->format = schema::Format_NHWC;
tensor->dataType = GetTfliteDataType(tflite_tensor->type);
tensor->nodeType = schema::NodeType_ValueNode;
tensor->dataType = GetTfliteDataType(tflite_tensor->type) != TypeId::kNumberTypeUInt8
? GetTfliteDataType(tflite_tensor->type)
: TypeId::kNumberTypeInt8;
tensor->nodeType = schema::NodeType_Parameter;
tensor->dims = tflite_tensor->shape;
tensor_cache->AddTensor(tflite_tensor->name, tensor.release(), GRAPH_INPUT);
}
@ -212,7 +252,8 @@ void TfliteModelParser::SetGraphTensorIndex(const mindspore::lite::TensorCache &
}
}
MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::string &weightFile) {
MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType) {
if (ValidateFileStr(modelFile, ".tflite") != RET_OK) {
MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.tflite";
return nullptr;
@ -224,7 +265,6 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st
MS_LOG(ERROR) << "read tflite model failed";
return nullptr;
}
if (tflite_model->subgraphs.size() != 1) {
MS_LOG(ERROR) << "read tflite model subgraphs failed";
return nullptr;
@ -238,30 +278,15 @@ MetaGraphT *TfliteModelParser::Parse(const std::string &modelFile, const std::st
// set dst subGraph op attr and tensor_cache.
std::unique_ptr<schema::MetaGraphT> subGraph(new schema::MetaGraphT);
subGraph->name = "MS_model converted by TF-Lite";
auto status = ParseOp(tflite_model, tflite_subgraph, subGraph.get(), &tensorCache);
auto status = ParseOp(tflite_model, tflite_subgraph, subGraph.get(), &tensorCache, quantType);
if (status != RET_OK) {
MS_LOG(ERROR) << "ParseOp failed.";
return nullptr;
}
for (const auto &tflite_op : tflite_subgraph->operators) {
auto status_tmp = SetOpInputIdx(tflite_model, tflite_subgraph, tflite_op, &tensorCache);
if (status_tmp != RET_OK) {
MS_LOG(ERROR) << "Set Op " << tfliteOpMap.at(tflite_op.get())->name.c_str() << " Input Index Failed!";
}
}
for (const auto &tflite_op : tflite_subgraph->operators) {
auto statusTmp = ParseTfliteQuantParams(tflite_subgraph, tflite_op);
if (statusTmp != RET_OK) {
MS_LOG(ERROR) << "ParseTfliteQuantParams " << tfliteOpMap.at(tflite_op.get())->name.c_str() << " Failed!";
}
}
SetGraphTensorIndex(tensorCache, subGraph.get());
SetAllTensors(tensorCache, subGraph.get());
return subGraph.release();
}
} // namespace lite
} // namespace mindspore

@ -40,22 +40,25 @@ class TfliteModelParser : public ModelParser {
virtual ~TfliteModelParser();
MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile);
MetaGraphT *Parse(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) override;
private:
std::unique_ptr<tflite::ModelT> ReadTfliteModelFromFlat(const char *buf);
void SetMsTensorFromTflite(const std::unique_ptr<tflite::TensorT> &tflite_tensor, schema::TensorT *tensor);
void SetInputTensor(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, TensorCache *tensor_cache);
void SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache,
schema::MetaGraphT *subGraphDef);
void SetGraphTensorIndex(const mindspore::lite::TensorCache &tensorCache, schema::MetaGraphT *subGraphDef);
STATUS ParseOp(const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph, schema::MetaGraphT *sub_graph,
TensorCache *tensor_cache);
TensorCache *tensor_cache, const QuantType &quantType);
STATUS ParseTfliteQuantParams(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::OperatorT> &tflite_op);
const std::unique_ptr<tflite::OperatorT> &tflite_op, schema::CNodeT *op,
TensorCache *tensor_cache);
std::string GetTfliteNodeType(const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model);
@ -63,13 +66,13 @@ class TfliteModelParser : public ModelParser {
STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *sub_graph);
STATUS SetOpOutputIdx(const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::OperatorT> &tflite_op,
schema::CNodeT *op,
const std::unique_ptr<tflite::OperatorT> &tflite_op, schema::CNodeT *op,
TensorCache *tensorCache);
STATUS SetOpInputIdx(const std::unique_ptr<tflite::ModelT> &tflite_model,
const std::unique_ptr<tflite::SubGraphT> &tflite_subgraph,
const std::unique_ptr<tflite::OperatorT> &tflite_op, TensorCache *tensorCache);
const std::unique_ptr<tflite::OperatorT> &tflite_op, schema::CNodeT *op,
TensorCache *tensor_cache);
std::map<std::string, schema::CNodeT *> opMap;
std::map<const tflite::OperatorT *, schema::CNodeT *> tfliteOpMap;

@ -4,7 +4,9 @@ include_directories(${3RD_DIR}/flatbuffers/include)
include_directories(${3RD_DIR}/opencv/build/include/opencv4)
add_library(quantizer_mid OBJECT
${CMAKE_CURRENT_SOURCE_DIR}/calc_quant_param.cc
${CMAKE_CURRENT_SOURCE_DIR}/quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/aware_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/weight_quantizer.cc
${CMAKE_CURRENT_SOURCE_DIR}/quantize_util.cc
${CMAKE_CURRENT_SOURCE_DIR}/general_bitpacking.cc

File diff suppressed because it is too large Load Diff

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save