!6931 MSLITE fix onnx weightquant and add tflite custom op

Merge pull request !6931 from 徐安越/master
pull/6931/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 83d13d4c87

@ -207,6 +207,11 @@ union PrimitiveType {
LshProjection,
HashtableLookup,
SkipGram,
CustomPredict,
CustomNormalize,
CustomExtractFeatures,
AudioSpectrogram,
Mfcc,
}
enum QuantType: int {

@ -963,3 +963,27 @@ table SkipGram {
maxSkipSize : int;
ngramSize : int;
}
table CustomPredict {
outputNum : int;
weightThreshold : float;
}
table CustomNormalize {
}
table CustomExtractFeatures {
}
table AudioSpectrogram {
windowSize : int;
stride : int;
magSquare : bool;
}
table Mfcc {
freqUpperLimit : float;
freqLowerLimit : float;
filterBankChannelNum : int;
dctCoeffNum : int;
}

@ -172,9 +172,11 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
}
void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::MetaGraphT *graph, TensorCache *tensor_cache) {
schema::MetaGraphT *graph, TensorCache *tensor_cache,
const QuantType &quant_type) {
std::unique_ptr<schema::CNodeT> dst_op_1 = std::make_unique<schema::CNodeT>();
dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0);
dst_op_1->quantType = quant_type;
ParseOnnxNodeAttr(onnx_graph, onnx_node, "MatMul", dst_op_1.get());
auto matmul_output_id = "Gemm_MatMul_" + onnx_node.output(0);
std::vector<string> matmul_inputs{onnx_node.input(0), onnx_node.input(1)};
@ -185,6 +187,7 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, cons
std::unique_ptr<schema::CNodeT> dst_op_2 = std::make_unique<schema::CNodeT>();
dst_op_2->name = "Gemm_BiasAdd_" + onnx_node.output(0);
dst_op_2->quantType = quant_type;
ParseOnnxNodeAttr(onnx_graph, onnx_node, "BiasAdd", dst_op_2.get());
std::vector<string> biasadd_inputs{matmul_output_id, onnx_node.input(2)};
std::vector<string> biasadd_outputs{onnx_node.output(0)};
@ -343,8 +346,6 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const
}
if (findQuantParams == needQuantParams) {
dst_op->quantType = schema::QuantType_AwareTraining;
} else {
dst_op->quantType = schema::QuantType_QUANT_NONE;
}
}
@ -520,7 +521,7 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
}
if (onnx_node.op_type() == "Gemm") {
if (status == RET_OK) {
ParseOnnxGemmNode(onnx_graph, onnx_node, dst_graph.get(), &tensor_cache);
ParseOnnxGemmNode(onnx_graph, onnx_node, dst_graph.get(), &tensor_cache, quantType);
}
continue;
} else if (onnx_node.op_type() == "Int8GivenIntTensorFill" || onnx_node.op_type() == "Int8GivenTensorFill") {

@ -65,7 +65,7 @@ class OnnxModelParser : public ModelParser {
const QuantType &quantType);
void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::MetaGraphT *graph, TensorCache *tensor_cache);
schema::MetaGraphT *graph, TensorCache *tensor_cache, const QuantType &quant_type);
STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache);

@ -23,26 +23,14 @@
namespace mindspore {
namespace lite {
STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
MS_LOG(DEBUG) << "parse TfliteCustomParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
STATUS TfliteCustomParser::DetectPostProcess(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
const std::unique_ptr<tflite::OperatorT> &tflite_op) {
std::unique_ptr<schema::DetectionPostProcessT> attr = std::make_unique<schema::DetectionPostProcessT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
const auto &custom_attr = tflite_op->custom_options;
auto attr_map = flexbuffers::GetRoot(custom_attr).AsMap();
attr->format = schema::Format::Format_NHWC;
attr->inputSize = tflite_op->inputs.size();
@ -73,7 +61,115 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
op->primitive->value.type = schema::PrimitiveType_DetectionPostProcess;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS TfliteCustomParser::AudioSpectrogram(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
const std::unique_ptr<tflite::OperatorT> &tflite_op) {
std::unique_ptr<schema::AudioSpectrogramT> attr = std::make_unique<schema::AudioSpectrogramT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
auto attr_map = flexbuffers::GetRoot(custom_attr).AsMap();
attr->windowSize = attr_map["window_size"].AsInt64();
attr->stride = attr_map["stride"].AsInt64();
attr->magSquare = attr_map["magnitude_squared"].AsBool();
op->primitive->value.type = schema::PrimitiveType_AudioSpectrogram;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS TfliteCustomParser::Mfcc(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
const std::unique_ptr<tflite::OperatorT> &tflite_op) {
std::unique_ptr<schema::MfccT> attr = std::make_unique<schema::MfccT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
auto attr_map = flexbuffers::GetRoot(custom_attr).AsMap();
attr->freqUpperLimit = attr_map["upper_frequency_limit"].AsInt64();
attr->freqLowerLimit = attr_map["lower_frequency_limit"].AsInt64();
attr->filterBankChannelNum = attr_map["filterbank_channel_count"].AsInt64();
attr->dctCoeffNum = attr_map["dct_coefficient_count"].AsInt64();
op->primitive->value.type = schema::PrimitiveType_Mfcc;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS TfliteCustomParser::Predict(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
const std::unique_ptr<tflite::OperatorT> &tflite_op) {
std::unique_ptr<schema::CustomPredictT> attr = std::make_unique<schema::CustomPredictT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->outputNum = reinterpret_cast<const int *>(custom_attr.data())[0];
attr->weightThreshold = reinterpret_cast<const float *>(custom_attr.data())[1];
op->primitive->value.type = schema::PrimitiveType_CustomPredict;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS TfliteCustomParser::Normalize(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
const std::unique_ptr<tflite::OperatorT> &tflite_op) {
std::unique_ptr<schema::CustomNormalizeT> attr = std::make_unique<schema::CustomNormalizeT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_CustomNormalize;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS TfliteCustomParser::ExtractFeatures(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
const std::unique_ptr<tflite::OperatorT> &tflite_op) {
std::unique_ptr<schema::CustomExtractFeaturesT> attr = std::make_unique<schema::CustomExtractFeaturesT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
op->primitive->value.type = schema::PrimitiveType_CustomExtractFeatures;
op->primitive->value.value = attr.release();
return RET_OK;
}
STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
MS_LOG(DEBUG) << "parse TfliteCustomParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
const auto &custom_attr = tflite_op->custom_options;
const auto &opcode_index = tflite_op->opcode_index;
const auto &custom_type = tflite_model->operator_codes[opcode_index]->custom_code;
int status = RET_OK;
if (custom_type == "TFLite_Detection_PostProcess") {
status = DetectPostProcess(custom_attr, op, tflite_op);
} else if (custom_type == "Predict") {
status = Predict(custom_attr, op, tflite_op);
} else if (custom_type == "Normalize") {
status = Normalize(custom_attr, op, tflite_op);
} else if (custom_type == "ExtractFeatures") {
status = ExtractFeatures(custom_attr, op, tflite_op);
} else if (custom_type == "AudioSpectrogram") {
status = AudioSpectrogram(custom_attr, op, tflite_op);
} else {
MS_LOG(ERROR) << "the custom op hasn't been supported now";
status = RET_NOT_FIND_OP;
}
if (status != RET_OK) {
return status;
}
for (size_t i = 0; i < tflite_op->inputs.size(); ++i) {
AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
@ -82,7 +178,7 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(),
schema::Format::Format_NHWC);
}
return RET_OK;
return status;
}
TfliteNodeRegister g_tfliteCustomParser("Custom", new TfliteCustomParser());

@ -31,6 +31,24 @@ class TfliteCustomParser : public TfliteNodeParser {
STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
STATUS DetectPostProcess(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
const std::unique_ptr<tflite::OperatorT> &tflite_op);
STATUS AudioSpectrogram(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
const std::unique_ptr<tflite::OperatorT> &tflite_op);
STATUS Mfcc(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
const std::unique_ptr<tflite::OperatorT> &tflite_op);
STATUS Predict(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
const std::unique_ptr<tflite::OperatorT> &tflite_op);
STATUS Normalize(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
const std::unique_ptr<tflite::OperatorT> &tflite_op);
STATUS ExtractFeatures(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
const std::unique_ptr<tflite::OperatorT> &tflite_op);
};
} // namespace lite
} // namespace mindspore

@ -47,14 +47,11 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const char *m
STATUS TfliteModelParser::CopyConstTensorData(const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
const tflite::TensorT *tflite_tensor, schema::TensorT *tensor) {
auto count = 1;
std::for_each(tflite_tensor->shape.begin(), tflite_tensor->shape.end(), [&](int32_t sha) { count *= sha; });
auto data_size = count * GetDataTypeSize(TypeId(tensor->dataType));
auto buffer_idx = tflite_tensor->buffer;
if (!tflite_model_buffer[buffer_idx]->data.empty()) {
auto data_size = tflite_model_buffer[buffer_idx]->data.size();
tensor->data.resize(data_size);
if (memcpy_s(tensor->data.data(), tensor->data.size(), tflite_model_buffer[buffer_idx]->data.data(),
tflite_model_buffer[buffer_idx]->data.size())) {
if (memcpy_s(tensor->data.data(), data_size, tflite_model_buffer[buffer_idx]->data.data(), data_size) != EOK) {
MS_LOG(ERROR) << "memcpy tensor data failed";
return RET_MEMORY_FAILED;
}

@ -120,6 +120,9 @@ std::map<tflite::BuiltinOperator, std::string> tfMsOpTypeMap{
{tflite::BuiltinOperator_MIRROR_PAD, "MirrorPad"},
{tflite::BuiltinOperator_NEG, "Neg"},
{tflite::BuiltinOperator_PRELU, "PRELU"},
{tflite::BuiltinOperator_HASHTABLE_LOOKUP, "HashtableLookup"},
{tflite::BuiltinOperator_LSH_PROJECTION, "LshProjection"},
{tflite::BuiltinOperator_SKIP_GRAM, "SKipGram"},
};
std::map<tflite::ActivationFunctionType, schema::ActivationType> tfMsActivationFunctionMap{
@ -134,7 +137,7 @@ std::map<int, TypeId> type_map = {
{tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, {tflite::TensorType_INT32, TypeId::kNumberTypeInt32},
{tflite::TensorType_INT16, TypeId::kNumberTypeInt16}, {tflite::TensorType_INT8, TypeId::kNumberTypeInt8},
{tflite::TensorType_INT64, TypeId::kNumberTypeInt64}, {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8},
{tflite::TensorType_BOOL, TypeId::kNumberTypeBool},
{tflite::TensorType_BOOL, TypeId::kNumberTypeBool}, {tflite::TensorType_STRING, TypeId::kObjectTypeString},
};
schema::ActivationType GetActivationFunctionType(tflite::ActivationFunctionType tfliteAFType) {

@ -117,7 +117,7 @@ int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, co
}
} else {
if (EOK != memcpy_s(add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) {
MS_LOG(ERROR) << "memset_s conv_bias_data failed";
MS_LOG(ERROR) << "memcpy_s conv_bias_data failed";
delete[] add_bias_data;
return lite::RET_MEMORY_FAILED;
}

Loading…
Cancel
Save