From 99d6d7a07a9951c26aaa82fbcd177c7f7a4e9ef6 Mon Sep 17 00:00:00 2001 From: xuanyue Date: Tue, 1 Sep 2020 09:49:35 +0800 Subject: [PATCH] custom parser tflite --- build.sh | 2 +- cmake/package_lite.cmake | 3 +- mindspore/lite/schema/ops.fbs | 1 + mindspore/lite/src/ops/primitive_c.cc | 5 +++ mindspore/lite/test/run_benchmark_nets.sh | 2 +- .../lite/tools/anf_exporter/anf_exporter.cc | 3 +- .../graph/dtype_trans_pass.cc | 19 +++++----- .../converter/parser/tflite/CMakeLists.txt | 3 ++ .../parser/tflite/tflite_custom_parser.cc | 35 +++++++++++++++++-- 9 files changed, 54 insertions(+), 19 deletions(-) diff --git a/build.sh b/build.sh index 6ce67d0ee5..25cae3a244 100755 --- a/build.sh +++ b/build.sh @@ -428,7 +428,7 @@ build_flatbuffer() { if [[ ! -f "${FLATC}" ]]; then git submodule update --init --recursive third_party/flatbuffers cd ${BASEPATH}/third_party/flatbuffers - rm -rf build && mkdir -pv build && cd build && cmake .. && make -j$THREAD_NUM + rm -rf build && mkdir -pv build && cd build && cmake -DFLATBUFFERS_BUILD_SHAREDLIB=ON .. && make -j$THREAD_NUM gene_flatbuffer fi if [[ "${INC_BUILD}" == "off" ]]; then diff --git a/cmake/package_lite.cmake b/cmake/package_lite.cmake index 6f220fa8a9..a8753f2204 100644 --- a/cmake/package_lite.cmake +++ b/cmake/package_lite.cmake @@ -67,6 +67,7 @@ else () install(FILES ${TOP_DIR}/mindspore/lite/build/src/libmindspore-lite.so DESTINATION ${LIB_DIR_RUN_X86} COMPONENT ${RUN_X86_COMPONENT_NAME}) install(FILES ${TOP_DIR}/third_party/protobuf/build/lib/libprotobuf.so.19.0.0 DESTINATION ${PROTOBF_DIR}/lib RENAME libprotobuf.so.19 COMPONENT ${COMPONENT_NAME}) + install(FILES ${TOP_DIR}/third_party/flatbuffers/build/libflatbuffers.so.1.11.0 DESTINATION ${FLATBF_DIR}/lib RENAME libflatbuffers.so.1 COMPONENT ${COMPONENT_NAME}) endif () if (CMAKE_SYSTEM_NAME MATCHES "Windows") @@ -89,4 +90,4 @@ else () set(CPACK_PACKAGE_DIRECTORY ${TOP_DIR}/output/tmp) endif() set(CPACK_PACKAGE_CHECKSUM SHA256) -include(CPack) \ No newline at end of file +include(CPack) diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 752d018d4f..aa51fbfe2b 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -355,6 +355,7 @@ table DetectionPostProcess { MaxClassesPreDetection: long; NumClasses: long; UseRegularNms: bool; + OutQuantized: bool; } table FullConnection { diff --git a/mindspore/lite/src/ops/primitive_c.cc b/mindspore/lite/src/ops/primitive_c.cc index 94a8b26527..2f6150cb7c 100644 --- a/mindspore/lite/src/ops/primitive_c.cc +++ b/mindspore/lite/src/ops/primitive_c.cc @@ -120,6 +120,7 @@ #include "src/ops/tuple_get_item.h" #include "src/ops/l2_norm.h" #include "src/ops/sparse_to_dense.h" +#include "src/ops/detection_post_process.h" namespace mindspore { namespace lite { @@ -467,6 +468,8 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitiveT(mindspore::schema::PrimitiveT return new L2Norm(primitive); case schema::PrimitiveType_SparseToDense: return new SparseToDense(primitive); + case schema::PrimitiveType_DetectionPostProcess: + return new DetectionPostProcess(primitive); default: MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromSchemaPrimitiveT : " << schema::EnumNamePrimitiveType(op_type); @@ -679,6 +682,8 @@ PrimitiveC *PrimitiveC::UnPackFromSchemaPrimitive(const schema::Primitive *primi return NewPrimitiveC(primitive); case schema::PrimitiveType_SparseToDense: return NewPrimitiveC(primitive); + case schema::PrimitiveType_DetectionPostProcess: + return NewPrimitiveC(primitive); default: MS_LOG(ERROR) << "Unsupported primitive type in UnPackFromSchemaPrimitive : " << schema::EnumNamePrimitiveType(op_type); diff --git a/mindspore/lite/test/run_benchmark_nets.sh b/mindspore/lite/test/run_benchmark_nets.sh index 811e1dfc32..abc8c9bf7a 100644 --- a/mindspore/lite/test/run_benchmark_nets.sh +++ b/mindspore/lite/test/run_benchmark_nets.sh @@ -377,7 +377,7 @@ tar -zxf mindspore-lite-${version}-runtime-x86-${process_unit_x86}.tar.gz || exi tar -zxf mindspore-lite-${version}-converter-ubuntu.tar.gz || exit 1 cd ${convertor_path}/mindspore-lite-${version}-converter-ubuntu || exit 1 cp converter/converter_lite ./ || exit 1 -export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:./lib/:./third_party/protobuf/lib +export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:./lib/:./third_party/protobuf/lib:./third_party/flatbuffers/lib # Convert the models cd ${convertor_path}/mindspore-lite-${version}-converter-ubuntu || exit 1 diff --git a/mindspore/lite/tools/anf_exporter/anf_exporter.cc b/mindspore/lite/tools/anf_exporter/anf_exporter.cc index 07f4178664..af2ac37975 100644 --- a/mindspore/lite/tools/anf_exporter/anf_exporter.cc +++ b/mindspore/lite/tools/anf_exporter/anf_exporter.cc @@ -98,8 +98,7 @@ int AnfExporter::ConvertQuantParam(const std::unique_ptr &me auto output_quant_params = primitive->GetOutputQuantParams(); if (output_quant_params.empty()) { if (node_type != schema::PrimitiveType_QuantDTypeCast) { - MS_LOG(ERROR) << "node: " << dst_node->name << " output quant params is empty"; - return RET_ERROR; + MS_LOG(DEBUG) << "node: " << dst_node->name << " output quant params is empty"; } } else { for (auto output_quant_param : output_quant_params[0]) { diff --git a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc index d8fbdd846f..d39b0caa16 100644 --- a/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc +++ b/mindspore/lite/tools/converter/legacy_optimizer/graph/dtype_trans_pass.cc @@ -75,10 +75,9 @@ STATUS DTypeTransPass::DoModelInputDTypeTrans(schema::MetaGraphT *graph) { } for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { - auto &node = *iter; - auto nodeName = node->name; - for (size_t inputIndexIdx = 0; inputIndexIdx < node->inputIndex.size(); inputIndexIdx++) { - if (node->inputIndex.at(inputIndexIdx) == graphInIdx) { + auto nodeName = (*iter)->name; + for (size_t inputIndexIdx = 0; inputIndexIdx < (*iter)->inputIndex.size(); inputIndexIdx++) { + if ((*iter)->inputIndex.at(inputIndexIdx) == graphInIdx) { STATUS status = RET_OK; // insert dtype cast node between input tensor and input node @@ -108,11 +107,10 @@ STATUS DTypeTransPass::DoModelOutputDTypeTrans(schema::MetaGraphT *graph) { auto &graphOutIdxes = graph->outputIndex; for (auto graphOutIdx : graphOutIdxes) { for (auto iter = graph->nodes.begin(); iter != graph->nodes.end(); iter++) { - auto &node = *iter; - auto nodeName = node->name; + auto nodeName = (*iter)->name; MS_ASSERT(node != nullptr); - for (size_t outputIndexIdx = 0; outputIndexIdx < node->outputIndex.size(); outputIndexIdx++) { - if (node->outputIndex.at(outputIndexIdx) == graphOutIdx) { + for (size_t outputIndexIdx = 0; outputIndexIdx < (*iter)->outputIndex.size(); outputIndexIdx++) { + if ((*iter)->outputIndex.at(outputIndexIdx) == graphOutIdx) { // insert transNode STATUS status = RET_OK; if (inputDataDType == TypeId::kNumberTypeFloat) { @@ -139,7 +137,6 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { if (IsContain(GetUint8OpList(), GetCNodeTType(**iter)) && (*iter)->quantType == QuantType_AwareTraining) { continue; } - auto &node = *iter; if (GetCNodeTType(**iter) == PrimitiveType_QuantDTypeCast) { continue; } @@ -147,8 +144,8 @@ STATUS DTypeTransPass::DoNodeInoutDTypeTrans(schema::MetaGraphT *graph) { if (GetCNodeTType(**iter) == PrimitiveType_Shape) { needInsertPost = false; } - auto nodeName = node->name; - if (node->inputIndex.size() < kMinInputNum) { + auto nodeName = (*iter)->name; + if ((*iter)->inputIndex.size() < kMinInputNum) { MS_LOG(ERROR) << "Op " << nodeName.c_str() << " should have " << kMinInputNum << " input tensor at least"; return RET_ERROR; } diff --git a/mindspore/lite/tools/converter/parser/tflite/CMakeLists.txt b/mindspore/lite/tools/converter/parser/tflite/CMakeLists.txt index 03f9b3670b..cd8ef76b09 100644 --- a/mindspore/lite/tools/converter/parser/tflite/CMakeLists.txt +++ b/mindspore/lite/tools/converter/parser/tflite/CMakeLists.txt @@ -1,6 +1,9 @@ file(GLOB_RECURSE TFLITE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} *.cc ) +ADD_DEFINITIONS(-DFLATBUFFERS_LOCALE_INDEPENDENT=1) +find_library(FLATBUFFERS_LIBRARY flatbuffers HINTS ${TOP_DIR}/third_party/flatbuffers/build) add_library(tflite_parser_mid OBJECT ${TFLITE_SRC_LIST} ) +target_link_libraries(tflite_parser_mid ${FLATBUFFERS_LIBRARY}) diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc index fcdb7599f3..1ee985cf01 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_custom_parser.cc @@ -18,6 +18,8 @@ #include #include #include +#include "flatbuffers/flatbuffers.h" +#include "flatbuffers/flexbuffers.h" namespace mindspore { namespace lite { @@ -39,15 +41,42 @@ STATUS TfliteCustomParser::Parse(const std::unique_ptr &tflit return RET_NULL_PTR; } - std::unique_ptr attr = std::make_unique(); + std::unique_ptr attr = std::make_unique(); if (attr == nullptr) { MS_LOG(ERROR) << "new op failed"; return RET_NULL_PTR; } const auto &custom_attr = tflite_op->custom_options; - attr->custom = custom_attr; - op->primitive->value.type = schema::PrimitiveType_Custom; + auto attr_map = flexbuffers::GetRoot(custom_attr).AsMap(); + attr->format = schema::Format_NHWC; + attr->inputSize = tflite_op->inputs.size(); + attr->hScale = attr_map["h_scale"].AsFloat(); + attr->wScale = attr_map["w_scale"].AsFloat(); + attr->xScale = attr_map["x_scale"].AsFloat(); + attr->yScale = attr_map["y_scale"].AsFloat(); + attr->NmsIouThreshold = attr_map["nms_iou_threshold"].AsFloat(); + attr->NmsScoreThreshold = attr_map["nms_score_threshold"].AsFloat(); + attr->MaxDetections = attr_map["max_detections"].AsInt32(); + if (attr_map["detections_per_class"].IsNull()) { + attr->DetectionsPreClass = 100; + } else { + attr->DetectionsPreClass = attr_map["detections_per_class"].AsInt32(); + } + attr->MaxClassesPreDetection = attr_map["max_classes_per_detection"].AsInt32(); + attr->NumClasses = attr_map["num_classes"].AsInt32(); + if (attr_map["use_regular_nms"].IsNull()) { + attr->UseRegularNms = false; + } else { + attr->UseRegularNms = attr_map["use_regular_nms"].AsBool(); + } + if (attr_map["_output_quantized"].IsNull()) { + attr->OutQuantized = false; + } else { + attr->OutQuantized = attr_map["_output_quantized"].AsBool(); + } + + op->primitive->value.type = schema::PrimitiveType_DetectionPostProcess; op->primitive->value.value = attr.release(); for (size_t i = 0; i < tflite_op->inputs.size(); ++i) {