From 357b597b4f4bde07c216492232c678f0d9324b19 Mon Sep 17 00:00:00 2001 From: xuanyue Date: Tue, 29 Sep 2020 16:05:15 +0800 Subject: [PATCH] add onnx parser and adjust the way of printing unsupport op --- mindspore/lite/schema/ops.fbs | 3 ++- .../anf_importer/import_from_protobuf.cc | 11 ++++++--- mindspore/lite/tools/common/storage.cc | 1 - mindspore/lite/tools/converter/converter.cc | 1 - .../lite/tools/converter/converter_context.h | 11 +++++++-- .../parser/caffe/caffe_model_parser.cc | 7 +++++- .../onnx/onnx_arithmetic_operation_parser.cc | 24 +++++++++++++++++++ .../onnx/onnx_arithmetic_operation_parser.h | 6 +++++ .../converter/parser/onnx/onnx_clip_parser.cc | 12 +++++++--- .../parser/onnx/onnx_model_parser.cc | 7 +++++- .../parser/tflite/tflite_model_parser.cc | 7 +++++- 11 files changed, 76 insertions(+), 14 deletions(-) diff --git a/mindspore/lite/schema/ops.fbs b/mindspore/lite/schema/ops.fbs index 509b4b298b..4928f59c72 100644 --- a/mindspore/lite/schema/ops.fbs +++ b/mindspore/lite/schema/ops.fbs @@ -58,7 +58,8 @@ enum ActivationType : byte { THRESHOLDRELU = 14, LINEAR = 15, HARD_TANH = 16, - UNKNOW = 17 + SIGN = 17, + UNKNOW = 18 } enum ActivationGradType : byte { NO_ACTIVATION = 0, diff --git a/mindspore/lite/tools/anf_importer/import_from_protobuf.cc b/mindspore/lite/tools/anf_importer/import_from_protobuf.cc index a904b0f623..97ea1fbdc4 100644 --- a/mindspore/lite/tools/anf_importer/import_from_protobuf.cc +++ b/mindspore/lite/tools/anf_importer/import_from_protobuf.cc @@ -595,10 +595,14 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out for (int i = 0; i < node_proto.input_size(); ++i) { const std::string &input_name = node_proto.input(i); if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { - MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed"; - return nullptr; + if (!interrupt) { + MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed"; + interrupt = true; + } + inputs.push_back(nullptr); + } else { + inputs.push_back(anfnode_build_map_[input_name]); } - inputs.push_back(anfnode_build_map_[input_name]); } auto primitivec_ptr = PrimitiveC::Create(*prim, inputs, quantType); if (primitivec_ptr == nullptr || interrupt) { @@ -714,6 +718,7 @@ int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncG MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); CNodePtr cnode_ptr = nullptr; int status = RET_OK; + NoSupportOp::GetInstance()->SetFmkType("MINDIR"); for (int i = 0; i < importProto.node_size(); ++i) { const onnx::NodeProto &node_proto = importProto.node(i); const std::string &node_type = node_proto.op_type(); diff --git a/mindspore/lite/tools/common/storage.cc b/mindspore/lite/tools/common/storage.cc index 33f4a62a7e..abf20506f8 100644 --- a/mindspore/lite/tools/common/storage.cc +++ b/mindspore/lite/tools/common/storage.cc @@ -34,7 +34,6 @@ int Storage::Save(const schema::MetaGraphT &graph, const std::string &outputPath return RET_ERROR; } if (access((outputPath + ".ms").c_str(), F_OK) == 0) { - MS_LOG(WARNING) << "this file " << outputPath << ".ms has been existed"; chmod((outputPath + ".ms").c_str(), S_IWUSR); } std::ofstream output(outputPath + ".ms", std::ofstream::binary); diff --git a/mindspore/lite/tools/converter/converter.cc b/mindspore/lite/tools/converter/converter.cc index 382aaf0202..d69c7372c3 100644 --- a/mindspore/lite/tools/converter/converter.cc +++ b/mindspore/lite/tools/converter/converter.cc @@ -65,7 +65,6 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) { FuncGraphPtr graph = nullptr; if (flag->fmk == converter::FmkType_MS) { MS_ASSERT(nullptr != modelImporter); - modelImporter->Import(flag->quantType); int status = modelImporter->Import(flag->quantType); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); graph = modelImporter->GetResult(); diff --git a/mindspore/lite/tools/converter/converter_context.h b/mindspore/lite/tools/converter/converter_context.h index da99538afb..8796807571 100644 --- a/mindspore/lite/tools/converter/converter_context.h +++ b/mindspore/lite/tools/converter/converter_context.h @@ -50,16 +50,23 @@ class NoSupportOp { static NoSupportOp noSupportOp; return &noSupportOp; } + void SetFmkType(const std::string &fmk_type) { fmkType = fmk_type; } void InsertOp(const std::string &op_name) { noSupportOps.insert(op_name); } void PrintOps() const { - for (auto &op_name : noSupportOps) { - MS_LOG(ERROR) << "The op " << op_name << " hasn't been supported"; + if (!noSupportOps.empty()) { + MS_LOG(ERROR) << "==========================================="; + MS_LOG(ERROR) << "UNSUPPORT OP LIST:"; + for (auto &op_name : noSupportOps) { + MS_LOG(ERROR) << "FMKTYPE: " << fmkType << ", OP TYPE: " << op_name; + } + MS_LOG(ERROR) << "==========================================="; } } private: NoSupportOp() { noSupportOps.clear(); } std::set noSupportOps; + std::string fmkType; }; } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc index 7c4aff82b0..6951714bc8 100644 --- a/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/caffe/caffe_model_parser.cc @@ -80,6 +80,7 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co return nullptr; } + NoSupportOp::GetInstance()->SetFmkType("CAFFE"); status = ParseLayer(proto, weight, &tensorCache, metaGraph.get(), quantType); if (status != RET_OK) { MS_LOG(ERROR) << "ParseLayer failed " << status; @@ -242,7 +243,11 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caff auto status_node = nodeParser->Parse(layer, layerP, op.get(), &weightVec); if (status_node != RET_OK) { interrupt = true; - MS_LOG(ERROR) << "Parse weight for " << layer.name() << " Failed!"; + if (status_node == RET_NOT_SUPPORT) { + NoSupportOp::GetInstance()->InsertOp(layer.type()); + } else { + MS_LOG(ERROR) << "Parse weight for " << layer.name() << " Failed!"; + } status = (status == RET_OK ? RET_NOT_FIND_OP : status); continue; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc index 6c66d07666..6bd54d597d 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.cc @@ -559,6 +559,29 @@ STATUS OnnxTanhParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod return RET_OK; } +STATUS OnnxSignParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) { + MS_LOG(DEBUG) << "onnx TanhParser"; + if (op == nullptr) { + MS_LOG(ERROR) << "op is null"; + return RET_NULL_PTR; + } + op->primitive = std::make_unique(); + if (op->primitive == nullptr) { + MS_LOG(ERROR) << "op->primitive is null"; + return RET_NULL_PTR; + } + + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + attr->type = schema::ActivationType_SIGN; + op->primitive->value.type = schema::PrimitiveType_Activation; + op->primitive->value.value = attr.release(); + return RET_OK; +} + OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser()); OnnxNodeRegistrar g_onnxInt8AddParser("Int8Add", new OnnxAddParser()); OnnxNodeRegistrar g_onnxSubParser("Sub", new OnnxSubParser()); @@ -584,5 +607,6 @@ OnnxNodeRegistrar g_onnxTanParser("Tan", new OnnxTanParser()); OnnxNodeRegistrar g_onnxAtanParser("Atan", new OnnxAtanParser()); OnnxNodeRegistrar g_onnxAsinParser("Asin", new OnnxAsinParser()); OnnxNodeRegistrar g_onnxTanhParser("Tanh", new OnnxTanhParser()); +OnnxNodeRegistrar g_onnxSignParser("Sign", new OnnxTanhParser()); } // namespace lite } // namespace mindspore diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h index d2761083b5..dd1fec8083 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_arithmetic_operation_parser.h @@ -165,6 +165,12 @@ class OnnxTanhParser : public OnnxNodeParser { OnnxTanhParser() : OnnxNodeParser("Tanh") {} STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; }; + +class OnnxSignParser : public OnnxNodeParser { + public: + OnnxSignParser() : OnnxNodeParser("Sign") {} + STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; +}; } // namespace lite } // namespace mindspore #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc index b6016a8bcc..398b5af837 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_clip_parser.cc @@ -47,12 +47,18 @@ STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod return RET_NULL_PTR; } attr->type = schema::ActivationType_RELU6; - op->primitive->value.type = schema::PrimitiveType_Activation; op->primitive->value.value = attr.release(); } else { - MS_LOG(ERROR) << "only support convert clip(0,6) to relu6, other value is not supported"; - return RET_ERROR; + std::unique_ptr attr = std::make_unique(); + if (attr == nullptr) { + MS_LOG(ERROR) << "new op failed"; + return RET_NULL_PTR; + } + attr->max = max; + attr->min = min; + op->primitive->value.type = schema::PrimitiveType_Clip; + op->primitive->value.value = attr.release(); } return RET_OK; } diff --git a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc index d3d3e48fcb..ccd6170669 100644 --- a/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/onnx/onnx_model_parser.cc @@ -271,7 +271,11 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, auto status = node_parser->Parse(onnx_graph, onnx_node, dst_op); if (status != RET_OK) { interrupt = true; - MS_LOG(ERROR) << "parser onnx node " << onnx_node.op_type() << " attr failed"; + if (status == RET_NOT_SUPPORT) { + NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type()); + } else { + MS_LOG(ERROR) << "parser onnx node " << onnx_node.op_type() << " attr failed"; + } return status; } // set op input index @@ -514,6 +518,7 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con return nullptr; } // init op node input/output tensor, and dst_op attr + NoSupportOp::GetInstance()->SetFmkType("ONNX"); for (const auto &onnx_node : onnx_graph.node()) { int status_node = RET_OK; if (onnx_node.op_type() == "Constant") { diff --git a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc index 11a5151623..5352209da2 100644 --- a/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc +++ b/mindspore/lite/tools/converter/parser/tflite/tflite_model_parser.cc @@ -96,6 +96,7 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr &tflit const QuantType &quant_type, schema::MetaGraphT *sub_graph) { int idx = 0; int status = RET_OK; + NoSupportOp::GetInstance()->SetFmkType("TFLITE"); for (const auto &tflite_op : tflite_subgraph->operators) { auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; auto op_type = GetMSOpType(tflite_op_type); @@ -119,7 +120,11 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr &tflit if (status == RET_OK) { status = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, op.get()); if (status != RET_OK) { - MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed"; + if (status == RET_NOT_SUPPORT) { + NoSupportOp::GetInstance()->InsertOp(op_type); + } else { + MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed"; + } continue; }