From 5a6c358d2ecc1ee282c96285b24872c52db84f28 Mon Sep 17 00:00:00 2001 From: hangq Date: Fri, 28 Aug 2020 21:40:29 +0800 Subject: [PATCH] fix onnx | mindir read protobuf bug in windows --- mindspore/lite/test/CMakeLists.txt | 1 + .../anf_importer/import_from_protobuf.cc | 26 +---- .../protobuf_utils.cc} | 11 +- .../protobuf_utils.h} | 7 +- mindspore/lite/tools/converter/CMakeLists.txt | 1 + .../converter/parser/caffe/CMakeLists.txt | 1 - .../converter/parser/caffe/caffe_converter.cc | 1 - .../parser/caffe/caffe_model_parser.cc | 37 +++--- .../parser/onnx/onnx_model_parser.cc | 107 +++++------------- .../converter/parser/onnx/onnx_model_parser.h | 101 ++++++----------- 10 files changed, 88 insertions(+), 205 deletions(-) rename mindspore/lite/tools/{converter/parser/caffe/caffe_parse_utils.cc => common/protobuf_utils.cc} (88%) rename mindspore/lite/tools/{converter/parser/caffe/caffe_parse_utils.h => common/protobuf_utils.h} (85%) mode change 100755 => 100644 mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc diff --git a/mindspore/lite/test/CMakeLists.txt b/mindspore/lite/test/CMakeLists.txt index 8e34bfe1df..74028c8b3d 100644 --- a/mindspore/lite/test/CMakeLists.txt +++ b/mindspore/lite/test/CMakeLists.txt @@ -230,6 +230,7 @@ if(BUILD_CONVERTER) ${TEST_LITE_SRC} ${TEST_CASE_TFLITE_PARSERS_SRC} ${TOP_DIR}/mindspore/core/utils/flags.cc + ${LITE_DIR}/tools/common/protobuf_utils.cc ${LITE_DIR}/tools/converter/optimizer.cc ${LITE_DIR}/tools/converter/anf_transform.cc ${LITE_DIR}/tools/converter/graphdef_transform.cc diff --git a/mindspore/lite/tools/anf_importer/import_from_protobuf.cc b/mindspore/lite/tools/anf_importer/import_from_protobuf.cc index e033296d21..85f8465c49 100644 --- a/mindspore/lite/tools/anf_importer/import_from_protobuf.cc +++ b/mindspore/lite/tools/anf_importer/import_from_protobuf.cc @@ -27,7 +27,6 @@ #include #include "src/ops/primitive_c.h" #include "frontend/operator/ops.h" -#include "google/protobuf/io/zero_copy_stream_impl.h" #include "include/errorcode.h" #include "ir/anf.h" #include "ir/func_graph.h" @@ -37,6 +36,7 @@ #include "src/param_value_lite.h" #include "tools/converter/parser/onnx/onnx.pb.h" #include "utils/log_adapter.h" +#include "tools/common/protobuf_utils.h" using string = std::string; using int32 = int32_t; @@ -651,31 +651,11 @@ int AnfImporterFromProtobuf::Import(const schema::QuantType &quantType) { } onnx::ModelProto *AnfImporterFromProtobuf::ReadOnnxFromBinary(const std::string &model_path) { - std::unique_ptr onnx_file(new (std::nothrow) char[PATH_MAX]{0}); -#ifdef _WIN32 - if (_fullpath(onnx_file.get(), model_path.c_str(), 1024) == nullptr) { - MS_LOG(ERROR) << "open file failed."; - return nullptr; - } -#else - if (realpath(model_path.c_str(), onnx_file.get()) == nullptr) { - MS_LOG(ERROR) << "open file failed."; - return nullptr; - } -#endif - int fd = open(onnx_file.get(), O_RDONLY); - google::protobuf::io::FileInputStream input(fd); - google::protobuf::io::CodedInputStream code_input(&input); - code_input.SetTotalBytesLimit(INT_MAX, 536870912); auto onnx_model = new onnx::ModelProto; - bool ret = onnx_model->ParseFromCodedStream(&code_input); - if (!ret) { - MS_LOG(ERROR) << "load onnx file failed"; - delete onnx_model; + if (ReadProtoFromBinaryFile((const char *)model_path.c_str(), onnx_model) != RET_OK) { + MS_LOG(ERROR) << "Read onnx model file failed, model path: " << model_path; return nullptr; } - (void)close(fd); - MS_LOG(INFO) << "enter ReadProtoFromBinary success!" << std::endl; return onnx_model; } diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.cc b/mindspore/lite/tools/common/protobuf_utils.cc similarity index 88% rename from mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.cc rename to mindspore/lite/tools/common/protobuf_utils.cc index b1222376f6..4023845d8f 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.cc +++ b/mindspore/lite/tools/common/protobuf_utils.cc @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h" +#include "tools/common/protobuf_utils.h" #include #include #include "google/protobuf/io/zero_copy_stream_impl.h" @@ -37,15 +37,14 @@ bool ReadProtoFromCodedInputStream(google::protobuf::io::CodedInputStream *coded return proto->ParseFromCodedStream(coded_stream); } -STATUS ReadProtoFromText(const char *file, - google::protobuf::Message *message) { +STATUS ReadProtoFromText(const char *file, google::protobuf::Message *message) { if (file == nullptr || message == nullptr) { return RET_ERROR; } std::string realPath = RealPath(file); if (realPath.empty()) { - MS_LOG(ERROR) << "Proto file path " << file <<" is not valid"; + MS_LOG(ERROR) << "Proto file path " << file << " is not valid"; return RET_ERROR; } @@ -67,8 +66,7 @@ STATUS ReadProtoFromText(const char *file, return RET_OK; } -STATUS ReadProtoFromBinaryFile(const char *file, - google::protobuf::Message *message) { +STATUS ReadProtoFromBinaryFile(const char *file, google::protobuf::Message *message) { if (file == nullptr || message == nullptr) { return RET_ERROR; } @@ -100,4 +98,3 @@ STATUS ReadProtoFromBinaryFile(const char *file, } } // namespace lite } // namespace mindspore - diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h b/mindspore/lite/tools/common/protobuf_utils.h similarity index 85% rename from mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h rename to mindspore/lite/tools/common/protobuf_utils.h index 4cb6ff6b3b..501a34a142 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h +++ b/mindspore/lite/tools/common/protobuf_utils.h @@ -29,13 +29,10 @@ namespace lite { bool ReadProtoFromCodedInputStream(google::protobuf::io::CodedInputStream *coded_stream, google::protobuf::Message *proto); -STATUS ReadProtoFromText(const char *file, - google::protobuf::Message *message); +STATUS ReadProtoFromText(const char *file, google::protobuf::Message *message); -STATUS ReadProtoFromBinaryFile(const char *file, - google::protobuf::Message *message); +STATUS ReadProtoFromBinaryFile(const char *file, google::protobuf::Message *message); } // namespace lite } // namespace mindspore #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_CAFFE_CAFFE_PARSE_UTILS_H_ - diff --git a/mindspore/lite/tools/converter/CMakeLists.txt b/mindspore/lite/tools/converter/CMakeLists.txt index cb443a356d..de1130fc47 100644 --- a/mindspore/lite/tools/converter/CMakeLists.txt +++ b/mindspore/lite/tools/converter/CMakeLists.txt @@ -94,6 +94,7 @@ file(GLOB_RECURSE CONVERTER_SRC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_SOURCE_DIR}/../common/graph_util.cc ${CMAKE_CURRENT_SOURCE_DIR}/../common/node_util.cc ${CMAKE_CURRENT_SOURCE_DIR}/../common/tensor_util.cc + ${CMAKE_CURRENT_SOURCE_DIR}/../common/protobuf_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/../common/flag_parser.cc ${CMAKE_CURRENT_SOURCE_DIR}/../common/storage.cc ${CMAKE_CURRENT_SOURCE_DIR}/../../src/ir/primitive_t_value.cc diff --git a/mindspore/lite/tools/converter/parser/caffe/CMakeLists.txt b/mindspore/lite/tools/converter/parser/caffe/CMakeLists.txt index f468c1e2c5..bf27a111ba 100644 --- a/mindspore/lite/tools/converter/parser/caffe/CMakeLists.txt +++ b/mindspore/lite/tools/converter/parser/caffe/CMakeLists.txt @@ -15,7 +15,6 @@ add_library(caffe_parser_mid OBJECT ${CMAKE_CURRENT_SOURCE_DIR}/caffe_model_parser.cc ${CMAKE_CURRENT_SOURCE_DIR}/caffe_node_parser.cc ${CMAKE_CURRENT_SOURCE_DIR}/caffe_node_parser_registry.cc - ${CMAKE_CURRENT_SOURCE_DIR}/caffe_parse_utils.cc ${CMAKE_CURRENT_SOURCE_DIR}/caffe_pooling_parser.cc ${CMAKE_CURRENT_SOURCE_DIR}/caffe_power_parser.cc ${CMAKE_CURRENT_SOURCE_DIR}/caffe_prelu_parser.cc diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_converter.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_converter.cc index 16056fa39d..0a87989886 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_converter.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_converter.cc @@ -15,7 +15,6 @@ */ #include "mindspore/lite/tools/converter/parser/caffe/caffe_converter.h" -#include "mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h" namespace mindspore { namespace lite { diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc index 387337c67b..5f75fc1ee4 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc @@ -14,14 +14,14 @@ * limitations under the License. */ -#include "mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.h" +#include "tools/converter/parser/caffe/caffe_model_parser.h" #include #include #include -#include "mindspore/lite/tools/converter/parser/caffe/caffe_node_parser_registry.h" -#include "mindspore/lite/tools/converter/parser/caffe/caffe_parse_utils.h" -#include "mindspore/lite/tools/converter/parser/caffe/caffe_inspector.h" +#include "tools/converter/parser/caffe/caffe_node_parser_registry.h" +#include "tools/converter/parser/caffe/caffe_inspector.h" #include "tools/common/graph_util.h" +#include "tools/common/protobuf_utils.h" namespace mindspore { namespace lite { @@ -31,9 +31,8 @@ CaffeModelParser::~CaffeModelParser() {} const std::set CaffeModelParser::skipedLayerType = {"Dropout"}; -schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, - const std::string &weightFile, - const QuantType &quantType) { +schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, + const QuantType &quantType) { if (ValidateFileStr(modelFile, ".prototxt") != RET_OK) { MS_LOG(ERROR) << "INPUT ILLEGAL: modelFile must be *.prototxt"; return nullptr; @@ -89,8 +88,7 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, return metaGraph.release(); } -STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer, - schema::CNodeT *op, +STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache) { for (int i = 0; i < layer.bottom_size(); i++) { int index = tensorCache->FindTensor(layer.bottom(i)); @@ -104,8 +102,7 @@ STATUS CaffeModelParser::SetOpInputIdx(const caffe::LayerParameter &layer, return RET_OK; } -STATUS CaffeModelParser::SetOpOutputIdx(const caffe::LayerParameter &layer, - schema::CNodeT *op, +STATUS CaffeModelParser::SetOpOutputIdx(const caffe::LayerParameter &layer, schema::CNodeT *op, TensorCache *tensorCache) { for (int i = 0; i < layer.top_size(); i++) { std::unique_ptr msTensor = std::make_unique(); @@ -114,8 +111,7 @@ STATUS CaffeModelParser::SetOpOutputIdx(const caffe::LayerParameter &layer, return RET_OK; } -STATUS CaffeModelParser::SetWeightTensor(const std::vector &weightVec, - schema::CNodeT *op, +STATUS CaffeModelParser::SetWeightTensor(const std::vector &weightVec, schema::CNodeT *op, TensorCache *tensorCache) { for (auto iter : weightVec) { op->inputIndex.emplace_back(tensorCache->AddTensor("Weight", iter, CONST)); @@ -123,8 +119,7 @@ STATUS CaffeModelParser::SetWeightTensor(const std::vector &w return RET_OK; } -STATUS CaffeModelParser::SetAllTensors(const TensorCache &tensorCache, - schema::MetaGraphT *subGraphDef) { +STATUS CaffeModelParser::SetAllTensors(const TensorCache &tensorCache, schema::MetaGraphT *subGraphDef) { std::vector tensors = tensorCache.GetCachedTensor(); for (auto iter : tensors) { std::unique_ptr temp(iter); @@ -133,8 +128,7 @@ STATUS CaffeModelParser::SetAllTensors(const TensorCache &tensorCache, return RET_OK; } -STATUS CaffeModelParser::SetGraphTensorIndex(const caffe::NetParameter &proto, - TensorCache *tensorCache, +STATUS CaffeModelParser::SetGraphTensorIndex(const caffe::NetParameter &proto, TensorCache *tensorCache, schema::MetaGraphT *subGraphDef) { CaffeInspector caffeInspector; caffeInspector.InspectModel(proto); @@ -160,10 +154,8 @@ STATUS CaffeModelParser::SetGraphTensorIndex(const caffe::NetParameter &proto, return RET_OK; } -STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, - const caffe::NetParameter &weight, - TensorCache *tensorCache, - schema::MetaGraphT *subGraphDef) { +STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caffe::NetParameter &weight, + TensorCache *tensorCache, schema::MetaGraphT *subGraphDef) { for (int i = 0; i < proto.layer_size(); i++) { auto layer = proto.layer(i); @@ -235,8 +227,7 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, return RET_OK; } -STATUS CaffeModelParser::GetModelInput(const caffe::NetParameter &proto, - TensorCache *tensorCache) { +STATUS CaffeModelParser::GetModelInput(const caffe::NetParameter &proto, TensorCache *tensorCache) { for (int i = 0; i < proto.input_size(); i++) { if (proto.input_dim_size() <= 0) { continue; diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc old mode 100755 new mode 100644 index 926573940f..1327074f11 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -21,6 +21,7 @@ #include #include "tools/common/graph_util.h" #include "src/common/utils.h" +#include "tools/common/protobuf_utils.h" namespace mindspore { namespace lite { @@ -54,36 +55,7 @@ std::vector OnnxModelParser::GetDimsFromOnnxValue(const onnx::ValueInfo return dims; } -STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile, - google::protobuf::Message *onnx_model) { - std::unique_ptr onnx_file(new (std::nothrow) char[PATH_MAX]{0}); -#ifdef _WIN32 - if (_fullpath(onnx_file.get(), modelFile.c_str(), 1024) == nullptr) { - MS_LOG(ERROR) << "get realpath " << modelFile << " fail"; - return RET_ERROR; - } -#else - if (realpath(modelFile.c_str(), onnx_file.get()) == nullptr) { - MS_LOG(ERROR) << "get realpath " << modelFile << " fail"; - return RET_ERROR; - } -#endif - int fd = open(onnx_file.get(), O_RDONLY); - google::protobuf::io::FileInputStream input(fd); - google::protobuf::io::CodedInputStream code_input(&input); - code_input.SetTotalBytesLimit(INT_MAX, 536870912); - bool ret = onnx_model->ParseFromCodedStream(&code_input); - if (!ret) { - MS_LOG(ERROR) << "load onnx file failed"; - return RET_ERROR; - } - (void)close(fd); - onnx_file.release(); - return RET_OK; -} - -STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, - TensorCache *tensor_cache) { +STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache) { MS_LOG(DEBUG) << "set onnx constant tensors"; for (const auto &onnx_const_value : onnx_graph.initializer()) { int index; @@ -119,11 +91,8 @@ STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, return RET_OK; } -STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, - const std::string &name, - const TensorType &type, - TensorCache *tensor_cache, - int *index) { +STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const TensorType &type, + TensorCache *tensor_cache, int *index) { auto data_type = GetDataTypeFromOnnx(static_cast(proto.type().tensor_type().elem_type())); if (data_type == kTypeUnknown) { MS_LOG(ERROR) << "not support onnx data type " @@ -143,11 +112,8 @@ STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, return RET_OK; } -STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, - const std::string &name, - const TensorType &type, - TensorCache *tensor_cache, - int *index) { +STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const TensorType &type, + TensorCache *tensor_cache, int *index) { auto data_type = GetDataTypeFromOnnx(static_cast(proto.data_type())); if (data_type == kTypeUnknown) { MS_LOG(ERROR) << "not support onnx data type " << static_cast(proto.data_type()); @@ -174,8 +140,7 @@ STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, return RET_OK; } -STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, - schema::MetaGraphT *graph, +STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache) { for (const auto &input_value : onnx_graph.input()) { auto ret = tensor_cache->FindTensor(input_value.name()); @@ -192,8 +157,7 @@ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, return RET_OK; } -STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, - schema::MetaGraphT *graph, +STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache) { for (const auto &output_value : onnx_graph.output()) { int index; @@ -207,10 +171,8 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, return RET_OK; } -void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, - const onnx::NodeProto &onnx_node, - schema::MetaGraphT *graph, - TensorCache *tensor_cache) { +void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + schema::MetaGraphT *graph, TensorCache *tensor_cache) { std::unique_ptr dst_op_1 = std::make_unique(); dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0); ParseOnnxNodeAttr(onnx_graph, onnx_node, "MatMul", dst_op_1.get()); @@ -231,8 +193,7 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, graph->nodes.emplace_back(std::move(dst_op_2)); } -STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, - TensorCache *tensor_cache) { +STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) { // convert GivenTensorFill node to a weight/bias tensor auto ret = tensor_cache->FindTensor(onnx_node.output(0)); if (ret < 0) { @@ -284,10 +245,8 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, return RET_OK; } -STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, - const onnx::NodeProto &onnx_node, - schema::CNodeT *dst_op, - schema::TensorT *dst_tensor, +STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache) { // change op_type() to name(), that is unique dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0); @@ -319,11 +278,8 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, return RET_OK; } -void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, - const onnx::NodeProto &onnx_node, - schema::CNodeT *dst_op, - schema::TensorT *dst_tensor, - TensorCache *tensor_cache) { +void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache) { MS_ASSERT(dst_op != nullptr); MS_ASSERT(tensor_cache != nullptr); std::vector quant_node_name; @@ -380,10 +336,8 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, } } -STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, - const onnx::NodeProto &onnx_node, - const string &onnx_op_type, - schema::CNodeT *dst_op) { +STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + const string &onnx_op_type, schema::CNodeT *dst_op) { auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_op_type); if (node_parser == nullptr) { MS_LOG(EXCEPTION) << "not find " << onnx_op_type << ", node parser is nullptr"; @@ -392,10 +346,8 @@ STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, return node_parser->Parse(onnx_graph, onnx_node, dst_op); } -STATUS OnnxModelParser::SetOpInputIndex(const std::vector &node_inputs, - schema::CNodeT *dst_op, - const onnx::NodeProto &onnx_node, - TensorCache *tensor_cache) { +STATUS OnnxModelParser::SetOpInputIndex(const std::vector &node_inputs, schema::CNodeT *dst_op, + const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) { for (const auto &onnx_node_input : node_inputs) { auto index = tensor_cache->FindTensor(onnx_node_input); if (index < 0) { @@ -408,8 +360,7 @@ STATUS OnnxModelParser::SetOpInputIndex(const std::vector &node_inputs, return RET_OK; } -STATUS OnnxModelParser::SetOpOutputIndex(const std::vector &node_outputs, - schema::CNodeT *dst_op, +STATUS OnnxModelParser::SetOpOutputIndex(const std::vector &node_outputs, schema::CNodeT *dst_op, TensorCache *tensor_cache) { for (const auto &onnx_node_output : node_outputs) { auto index = tensor_cache->FindTensor(onnx_node_output); @@ -424,8 +375,7 @@ STATUS OnnxModelParser::SetOpOutputIndex(const std::vector &node_outputs return RET_OK; } -STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, - schema::TensorT *tensor) { +STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, schema::TensorT *tensor) { size_t data_count = 1; std::for_each(tensor->dims.begin(), tensor->dims.end(), [&data_count](int dim) { data_count *= dim; }); size_t data_size = 0; @@ -484,8 +434,7 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v return RET_OK; } -STATUS OnnxModelParser::SetAllTensors(const TensorCache &tensor_cache, - schema::MetaGraphT *graphDef) { +STATUS OnnxModelParser::SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef) { std::vector tensors = tensor_cache.GetCachedTensor(); for (auto iter : tensors) { std::unique_ptr temp(iter); @@ -507,17 +456,16 @@ void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph) } } -schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, - const std::string &weightFile, - const QuantType &quantType) { +schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile, + const QuantType &quantType) { if (ValidateFileStr(modelFile, ".onnx") != RET_OK) { MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx"; return nullptr; } - auto dst_graph = std::make_unique(); + onnx::ModelProto onnx_model; - if (ReadOnnxModelFromBinary(modelFile, &onnx_model) != RET_OK) { - MS_LOG(ERROR) << "read onnx model fail"; + if (ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model) != RET_OK) { + MS_LOG(ERROR) << "Read onnx model file failed, model path: " << modelFile; return nullptr; } const onnx::GraphProto &onnx_graph = onnx_model.graph(); @@ -531,6 +479,7 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, MS_LOG(ERROR) << "SetGraphConstTensor failed"; return nullptr; } + auto dst_graph = std::make_unique(); // init onnx model graph input tensor if (SetGraphInputTensor(onnx_graph, dst_graph.get(), &tensor_cache)) { MS_LOG(ERROR) << "SetGraphInputTensor failed"; diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h index 4ce0615bd5..e227dec4fc 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.h @@ -41,78 +41,47 @@ class OnnxModelParser : public ModelParser { virtual ~OnnxModelParser(); schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile, - const QuantType &quantType = QuantType_QUANT_NONE) override; + const QuantType &quantType = QuantType_QUANT_NONE) override; private: TypeId GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type); std::vector GetDimsFromOnnxValue(const onnx::ValueInfoProto &onnx_value); - STATUS ReadOnnxModelFromBinary(const std::string &modelFile, - google::protobuf::Message *model_proto); - - STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, - TensorCache *tensor_cache); - - STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, - schema::MetaGraphT *graph, - TensorCache *tensor_cache); - - STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, - schema::MetaGraphT *graph, - TensorCache *tensor_cache); - - STATUS AddValueInfo(const onnx::ValueInfoProto &proto, - const std::string &name, - const TensorType &type, - TensorCache *tensor_cache, - int *index); - - STATUS AddTensorProto(const onnx::TensorProto &proto, - const std::string &name, - const TensorType &type, - TensorCache *tensor_cache, - int *index); - - STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, - const onnx::NodeProto &onnx_node, - schema::CNodeT *dst_op, - schema::TensorT *dst_tensor, - TensorCache *tensor_cache); - - void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, - const onnx::NodeProto &onnx_node, - schema::MetaGraphT *graph, - TensorCache *tensor_cache); - - STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, - TensorCache *tensor_cache); - - STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, - const onnx::NodeProto &onnx_node, - const string &onnx_op_type, - schema::CNodeT *dst_op); - - void SetOpQuantParams(const onnx::GraphProto &onnx_graph, - const onnx::NodeProto &onnx_node, - schema::CNodeT *dst_op, - schema::TensorT *dst_tensor, - TensorCache *tensor_cache); - - STATUS SetOpInputIndex(const std::vector &node_inputs, - schema::CNodeT *dst_op, - const onnx::NodeProto &onnx_node, - TensorCache *tensor_cache); - - STATUS SetOpOutputIndex(const std::vector &node_outputs, - schema::CNodeT *dst_op, - TensorCache *tensor_cache); - - STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, - schema::TensorT *tensor); - - STATUS SetAllTensors(const TensorCache &tensor_cache, - schema::MetaGraphT *graphDef); + STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache); + + STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache); + + STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache); + + STATUS AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const TensorType &type, + TensorCache *tensor_cache, int *index); + + STATUS AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const TensorType &type, + TensorCache *tensor_cache, int *index); + + STATUS ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache); + + void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + schema::MetaGraphT *graph, TensorCache *tensor_cache); + + STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache); + + STATUS ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, + const string &onnx_op_type, schema::CNodeT *dst_op); + + void SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *dst_op, + schema::TensorT *dst_tensor, TensorCache *tensor_cache); + + STATUS SetOpInputIndex(const std::vector &node_inputs, schema::CNodeT *dst_op, + const onnx::NodeProto &onnx_node, TensorCache *tensor_cache); + + STATUS SetOpOutputIndex(const std::vector &node_outputs, schema::CNodeT *dst_op, TensorCache *tensor_cache); + + STATUS CopyOnnxTensorData(const onnx::TensorProto &onnx_init_value, schema::TensorT *tensor); + + STATUS SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef); void FindGraphInputAndConst(const onnx::GraphProto &onnx_graph);