From 40cb36efbfb48e915a7b02e1f1c73adda2de6c8b Mon Sep 17 00:00:00 2001
From: xuanyue
Date: Sun, 27 Sep 2020 15:59:04 +0800
Subject: [PATCH] fix onnx weightquant and add tflite custom op

---
 mindspore/lite/schema/model.fbs               |   5 +
 mindspore/lite/schema/ops.fbs                 |  24 ++++
 .../parser/onnx/onnx_model_parser.cc          |   9 +-
 .../converter/parser/onnx/onnx_model_parser.h |   2 +-
 .../parser/tflite/tflite_custom_parser.cc     | 126 +++++++++++++++---
 .../parser/tflite/tflite_custom_parser.h      |  18 +++
 .../parser/tflite/tflite_model_parser.cc      |   7 +-
 .../converter/parser/tflite/tflite_util.cc    |   5 +-
 .../optimizer/fusion/conv_biasadd_fusion.cc   |   2 +-
 9 files changed, 171 insertions(+), 27 deletions(-)

diff --git a/mindspore/lite/schema/model.fbs b/mindspore/lite/schema/model.fbs
index 094d236c2a..ba529639d0 100644
--- a/mindspore/lite/schema/model.fbs
+++ b/mindspore/lite/schema/model.fbs
@@ -207,6 +207,11 @@ union PrimitiveType {
     LshProjection,
     HashtableLookup,
     SkipGram,
+    CustomPredict,
+    CustomNormalize,
+    CustomExtractFeatures,
+    AudioSpectrogram,
+    Mfcc,
 }
 
 enum QuantType: int {
diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs
index 60cf1095f9..75fbb7f844 100644
--- a/mindspore/lite/schema/ops.fbs
+++ b/mindspore/lite/schema/ops.fbs
@@ -963,3 +963,27 @@ table SkipGram {
     maxSkipSize : int;
     ngramSize : int;
 }
+
+table CustomPredict {
+    outputNum : int;
+    weightThreshold : float;
+}
+
+table CustomNormalize {
+}
+
+table CustomExtractFeatures {
+}
+
+table AudioSpectrogram {
+    windowSize : int;
+    stride : int;
+    magSquare : bool;
+}
+
+table Mfcc {
+    freqUpperLimit : float;
+    freqLowerLimit : float;
+    filterBankChannelNum : int;
+    dctCoeffNum : int;
+}
\ No newline at end of file
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
index a367b67605..d3d3e48fcb 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc
@@ -172,9 +172,11 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
 }
 
 void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
-                                        schema::MetaGraphT *graph, TensorCache *tensor_cache) {
+                                        schema::MetaGraphT *graph, TensorCache *tensor_cache,
+                                        const QuantType &quant_type) {
   std::unique_ptr<schema::CNodeT> dst_op_1 = std::make_unique<schema::CNodeT>();
   dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0);
+  dst_op_1->quantType = quant_type;
   ParseOnnxNodeAttr(onnx_graph, onnx_node, "MatMul", dst_op_1.get());
   auto matmul_output_id = "Gemm_MatMul_" + onnx_node.output(0);
   std::vector<std::string> matmul_inputs{onnx_node.input(0), onnx_node.input(1)};
@@ -185,6 +187,7 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, cons
 
   std::unique_ptr<schema::CNodeT> dst_op_2 = std::make_unique<schema::CNodeT>();
   dst_op_2->name = "Gemm_BiasAdd_" + onnx_node.output(0);
+  dst_op_2->quantType = quant_type;
   ParseOnnxNodeAttr(onnx_graph, onnx_node, "BiasAdd", dst_op_2.get());
   std::vector<std::string> biasadd_inputs{matmul_output_id, onnx_node.input(2)};
   std::vector<std::string> biasadd_outputs{onnx_node.output(0)};
@@ -343,8 +346,6 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const
   }
   if (findQuantParams == needQuantParams) {
     dst_op->quantType = schema::QuantType_AwareTraining;
-  } else {
-    dst_op->quantType = schema::QuantType_QUANT_NONE;
   }
 }
 
@@ -520,7 +521,7 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
     }
     if (onnx_node.op_type() == "Gemm") {
       if (status == RET_OK) {
-        ParseOnnxGemmNode(onnx_graph, onnx_node, dst_graph.get(), &tensor_cache);
+        ParseOnnxGemmNode(onnx_graph, onnx_node, dst_graph.get(), &tensor_cache, quantType);
       }
       continue;
     } else if (onnx_node.op_type() == "Int8GivenIntTensorFill" || onnx_node.op_type() == "Int8GivenTensorFill") {
diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h
index 7b7b952a8c..abc4bccf8b 100644
--- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h
+++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h
@@ -65,7 +65,7 @@ class OnnxModelParser : public ModelParser {
                     const QuantType &quantType);
 
   void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
-                         schema::MetaGraphT *graph, TensorCache *tensor_cache);
+                         schema::MetaGraphT *graph, TensorCache *tensor_cache, const QuantType &quant_type);
 
   STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache);
 
diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc
index de61798b34..91d1a0fe68 100644
--- a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc
+++ b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc
@@ -23,26 +23,14 @@
 
 namespace mindspore {
 namespace lite {
-STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
-                                 const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
-  MS_LOG(DEBUG) << "parse TfliteCustomParser";
-  if (op == nullptr) {
-    MS_LOG(ERROR) << "op is null";
-    return RET_NULL_PTR;
-  }
-  op->primitive = std::make_unique<schema::PrimitiveT>();
-  if (op->primitive == nullptr) {
-    MS_LOG(ERROR) << "op->primitive is null";
-    return RET_NULL_PTR;
-  }
-
+STATUS TfliteCustomParser::DetectPostProcess(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                                             const std::unique_ptr<tflite::OperatorT> &tflite_op) {
   std::unique_ptr<schema::DetectionPostProcessT> attr = std::make_unique<schema::DetectionPostProcessT>();
   if (attr == nullptr) {
     MS_LOG(ERROR) << "new op failed";
     return RET_NULL_PTR;
   }
-
-  const auto &custom_attr = tflite_op->custom_options;
   auto attr_map = flexbuffers::GetRoot(custom_attr).AsMap();
   attr->format = schema::Format::Format_NHWC;
   attr->inputSize = tflite_op->inputs.size();
@@ -73,7 +61,115 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
 
   op->primitive->value.type = schema::PrimitiveType_DetectionPostProcess;
   op->primitive->value.value = attr.release();
+  return RET_OK;
+}
 
+STATUS TfliteCustomParser::AudioSpectrogram(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                                            const std::unique_ptr<tflite::OperatorT> &tflite_op) {
+  std::unique_ptr<schema::AudioSpectrogramT> attr = std::make_unique<schema::AudioSpectrogramT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return RET_NULL_PTR;
+  }
+  auto attr_map = flexbuffers::GetRoot(custom_attr).AsMap();
+  attr->windowSize = attr_map["window_size"].AsInt64();
+  attr->stride = attr_map["stride"].AsInt64();
+  attr->magSquare = attr_map["magnitude_squared"].AsBool();
+
+  op->primitive->value.type = schema::PrimitiveType_AudioSpectrogram;
+  op->primitive->value.value = attr.release();
+  return RET_OK;
+}
+
+STATUS TfliteCustomParser::Mfcc(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                                const std::unique_ptr<tflite::OperatorT> &tflite_op) {
+  std::unique_ptr<schema::MfccT> attr = std::make_unique<schema::MfccT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return RET_NULL_PTR;
+  }
+  auto attr_map = flexbuffers::GetRoot(custom_attr).AsMap();
+  attr->freqUpperLimit = attr_map["upper_frequency_limit"].AsInt64();
+  attr->freqLowerLimit = attr_map["lower_frequency_limit"].AsInt64();
+  attr->filterBankChannelNum = attr_map["filterbank_channel_count"].AsInt64();
+  attr->dctCoeffNum = attr_map["dct_coefficient_count"].AsInt64();
+
+  op->primitive->value.type = schema::PrimitiveType_Mfcc;
+  op->primitive->value.value = attr.release();
+  return RET_OK;
+}
+
+STATUS TfliteCustomParser::Predict(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                                   const std::unique_ptr<tflite::OperatorT> &tflite_op) {
+  std::unique_ptr<schema::CustomPredictT> attr = std::make_unique<schema::CustomPredictT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return RET_NULL_PTR;
+  }
+  attr->outputNum = reinterpret_cast<const int32_t *>(custom_attr.data())[0];
+  attr->weightThreshold = reinterpret_cast<const float *>(custom_attr.data())[1];
+  op->primitive->value.type = schema::PrimitiveType_CustomPredict;
+  op->primitive->value.value = attr.release();
+  return RET_OK;
+}
+
+STATUS TfliteCustomParser::Normalize(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                                     const std::unique_ptr<tflite::OperatorT> &tflite_op) {
+  std::unique_ptr<schema::CustomNormalizeT> attr = std::make_unique<schema::CustomNormalizeT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return RET_NULL_PTR;
+  }
+  op->primitive->value.type = schema::PrimitiveType_CustomNormalize;
+  op->primitive->value.value = attr.release();
+  return RET_OK;
+}
+
+STATUS TfliteCustomParser::ExtractFeatures(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                                           const std::unique_ptr<tflite::OperatorT> &tflite_op) {
+  std::unique_ptr<schema::CustomExtractFeaturesT> attr = std::make_unique<schema::CustomExtractFeaturesT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return RET_NULL_PTR;
+  }
+  op->primitive->value.type = schema::PrimitiveType_CustomExtractFeatures;
+  op->primitive->value.value = attr.release();
+  return RET_OK;
+}
+
+STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
+                                 const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) {
+  MS_LOG(DEBUG) << "parse TfliteCustomParser";
+  if (op == nullptr) {
+    MS_LOG(ERROR) << "op is null";
+    return RET_NULL_PTR;
+  }
+  op->primitive = std::make_unique<schema::PrimitiveT>();
+  if (op->primitive == nullptr) {
+    MS_LOG(ERROR) << "op->primitive is null";
+    return RET_NULL_PTR;
+  }
+  const auto &custom_attr = tflite_op->custom_options;
+  const auto &opcode_index = tflite_op->opcode_index;
+  const auto &custom_type = tflite_model->operator_codes[opcode_index]->custom_code;
+  int status = RET_OK;
+  if (custom_type == "TFLite_Detection_PostProcess") {
+    status = DetectPostProcess(custom_attr, op, tflite_op);
+  } else if (custom_type == "Predict") {
+    status = Predict(custom_attr, op, tflite_op);
+  } else if (custom_type == "Normalize") {
+    status = Normalize(custom_attr, op, tflite_op);
+  } else if (custom_type == "ExtractFeatures") {
+    status = ExtractFeatures(custom_attr, op, tflite_op);
+  } else if (custom_type == "AudioSpectrogram") {
+    status = AudioSpectrogram(custom_attr, op, tflite_op);
+  } else {
+    MS_LOG(ERROR) << "the custom op is not supported yet";
+    status = RET_NOT_FIND_OP;
+  }
+  if (status != RET_OK) {
+    return status;
+  }
   for (size_t i = 0; i < tflite_op->inputs.size(); ++i) {
     AddOpInput(op, tensors_info, tflite_op->inputs[i], tflite_model->subgraphs[0]->tensors.size(),
                schema::Format::Format_NHWC);
@@ -82,7 +178,7 @@ STATUS TfliteCustomParser::Parse(TfliteTensorsInfo *tensors_info, const std::uni
     AddOpOutput(op, tensors_info, tflite_op->outputs[i], tflite_model->subgraphs[0]->tensors.size(),
                 schema::Format::Format_NHWC);
   }
-  return RET_OK;
+  return status;
 }
 
 TfliteNodeRegister g_tfliteCustomParser("Custom", new TfliteCustomParser());
diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.h b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.h
index b39188ae4f..91a0c7a669 100644
--- a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.h
+++ b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.h
@@ -31,6 +31,24 @@ class TfliteCustomParser : public TfliteNodeParser {
 
   STATUS Parse(TfliteTensorsInfo *tensors_info, const std::unique_ptr<tflite::OperatorT> &tflite_op,
                const std::unique_ptr<tflite::ModelT> &tflite_model, schema::CNodeT *op) override;
+
+  STATUS DetectPostProcess(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                           const std::unique_ptr<tflite::OperatorT> &tflite_op);
+
+  STATUS AudioSpectrogram(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                          const std::unique_ptr<tflite::OperatorT> &tflite_op);
+
+  STATUS Mfcc(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+              const std::unique_ptr<tflite::OperatorT> &tflite_op);
+
+  STATUS Predict(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                 const std::unique_ptr<tflite::OperatorT> &tflite_op);
+
+  STATUS Normalize(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                   const std::unique_ptr<tflite::OperatorT> &tflite_op);
+
+  STATUS ExtractFeatures(const std::vector<uint8_t> &custom_attr, schema::CNodeT *op,
+                         const std::unique_ptr<tflite::OperatorT> &tflite_op);
 };
 }  // namespace lite
 }  // namespace mindspore
diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
index 1837a8d39f..11a5151623 100644
--- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
+++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc
@@ -47,14 +47,11 @@ std::unique_ptr<tflite::ModelT> TfliteModelParser::ReadTfliteModel(const char *m
 
 STATUS TfliteModelParser::CopyConstTensorData(const std::vector<std::unique_ptr<tflite::BufferT>> &tflite_model_buffer,
                                               const tflite::TensorT *tflite_tensor, schema::TensorT *tensor) {
-  auto count = 1;
-  std::for_each(tflite_tensor->shape.begin(), tflite_tensor->shape.end(), [&](int32_t sha) { count *= sha; });
-  auto data_size = count * GetDataTypeSize(TypeId(tensor->dataType));
   auto buffer_idx = tflite_tensor->buffer;
   if (!tflite_model_buffer[buffer_idx]->data.empty()) {
+    auto data_size = tflite_model_buffer[buffer_idx]->data.size();
     tensor->data.resize(data_size);
-    if (memcpy_s(tensor->data.data(), tensor->data.size(), tflite_model_buffer[buffer_idx]->data.data(),
-                 tflite_model_buffer[buffer_idx]->data.size())) {
+    if (memcpy_s(tensor->data.data(), data_size, tflite_model_buffer[buffer_idx]->data.data(), data_size) != EOK) {
       MS_LOG(ERROR) << "memcpy tensor data failed";
       return RET_MEMORY_FAILED;
     }
diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc
index 25a166d093..0b103d0d12 100644
--- a/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc
+++ b/mindspore/lite/tools/converter/parser/tflite/tflite_util.cc
@@ -120,6 +120,9 @@ std::map<tflite::BuiltinOperator, std::string> tfMsOpTypeMap{
   {tflite::BuiltinOperator_MIRROR_PAD, "MirrorPad"},
   {tflite::BuiltinOperator_NEG, "Neg"},
   {tflite::BuiltinOperator_PRELU, "PRELU"},
+  {tflite::BuiltinOperator_HASHTABLE_LOOKUP, "HashtableLookup"},
+  {tflite::BuiltinOperator_LSH_PROJECTION, "LshProjection"},
+  {tflite::BuiltinOperator_SKIP_GRAM, "SkipGram"},
 };
 
 std::map<tflite::ActivationFunctionType, schema::ActivationType> tfMsActivationFunctionMap{
@@ -134,7 +137,7 @@ std::map<tflite::TensorType, TypeId> type_map = {
   {tflite::TensorType_FLOAT16, TypeId::kNumberTypeFloat16}, {tflite::TensorType_INT32, TypeId::kNumberTypeInt32},
   {tflite::TensorType_INT16, TypeId::kNumberTypeInt16},     {tflite::TensorType_INT8, TypeId::kNumberTypeInt8},
   {tflite::TensorType_INT64, TypeId::kNumberTypeInt64},     {tflite::TensorType_UINT8, TypeId::kNumberTypeUInt8},
-  {tflite::TensorType_BOOL, TypeId::kNumberTypeBool},
+  {tflite::TensorType_BOOL, TypeId::kNumberTypeBool},       {tflite::TensorType_STRING, TypeId::kObjectTypeString},
 };
 
 schema::ActivationType GetActivationFunctionType(tflite::ActivationFunctionType tfliteAFType) {
diff --git a/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc
index 5c32e579e1..0d5c5bc948 100644
--- a/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc
+++ b/mindspore/lite/tools/optimizer/fusion/conv_biasadd_fusion.cc
@@ -117,7 +117,7 @@ int GenConvNewBias(const FuncGraphPtr &func_graph, const CNodePtr &conv_node, co
     }
   } else {
     if (EOK != memcpy_s(add_bias_data, kernel_nums * sizeof(float), add_weight_data, kernel_nums * sizeof(float))) {
-      MS_LOG(ERROR) << "memset_s conv_bias_data failed";
+      MS_LOG(ERROR) << "memcpy_s conv_bias_data failed";
      delete[] add_bias_data;
       return lite::RET_MEMORY_FAILED;
     }
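
Note (illustration only, not part of the commit): most of the TFLite custom ops added above carry their attributes as a flexbuffers-encoded map in OperatorT::custom_options, and the new parser functions read that map with flexbuffers::GetRoot(...).AsMap(). The standalone sketch below round-trips the same keys the AudioSpectrogram parser reads; the numeric values are made-up examples, and it only assumes the flatbuffers/flexbuffers headers already used by the converter.

// Illustration only: encode and decode a flexbuffers attribute map the way a
// TFLite "AudioSpectrogram" custom op stores it in custom_options and the new
// TfliteCustomParser::AudioSpectrogram reads it back. Key names follow the
// patch; the values are arbitrary examples.
#include <cstdint>
#include <iostream>
#include <vector>

#include "flatbuffers/flexbuffers.h"

int main() {
  // Build a custom_options blob: a flexbuffers map holding the op's attributes.
  flexbuffers::Builder fbb;
  fbb.Map([&]() {
    fbb.Int("window_size", 1024);         // example value
    fbb.Int("stride", 512);               // example value
    fbb.Bool("magnitude_squared", true);  // example value
  });
  fbb.Finish();
  const std::vector<uint8_t> custom_attr = fbb.GetBuffer();

  // Decode it with the same access pattern the parser uses.
  auto attr_map = flexbuffers::GetRoot(custom_attr).AsMap();
  std::cout << "window_size=" << attr_map["window_size"].AsInt64()
            << " stride=" << attr_map["stride"].AsInt64()
            << " magnitude_squared=" << attr_map["magnitude_squared"].AsBool()
            << std::endl;
  return 0;
}

Predict is the exception in the patch: its custom_options are raw bytes rather than a flexbuffers map, which is why that parser reinterprets the buffer directly instead of going through an attribute map.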