add onnx parser and adjust the way of printing unsupport op

pull/7019/head
xuanyue 4 years ago
parent 5a83415f07
commit 357b597b4f

@ -58,7 +58,8 @@ enum ActivationType : byte {
THRESHOLDRELU = 14, THRESHOLDRELU = 14,
LINEAR = 15, LINEAR = 15,
HARD_TANH = 16, HARD_TANH = 16,
UNKNOW = 17 SIGN = 17,
UNKNOW = 18
} }
enum ActivationGradType : byte { enum ActivationGradType : byte {
NO_ACTIVATION = 0, NO_ACTIVATION = 0,

@ -595,10 +595,14 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
for (int i = 0; i < node_proto.input_size(); ++i) { for (int i = 0; i < node_proto.input_size(); ++i) {
const std::string &input_name = node_proto.input(i); const std::string &input_name = node_proto.input(i);
if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) { if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) {
MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed"; if (!interrupt) {
return nullptr; MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed";
interrupt = true;
}
inputs.push_back(nullptr);
} else {
inputs.push_back(anfnode_build_map_[input_name]);
} }
inputs.push_back(anfnode_build_map_[input_name]);
} }
auto primitivec_ptr = PrimitiveC::Create(*prim, inputs, quantType); auto primitivec_ptr = PrimitiveC::Create(*prim, inputs, quantType);
if (primitivec_ptr == nullptr || interrupt) { if (primitivec_ptr == nullptr || interrupt) {
@ -714,6 +718,7 @@ int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncG
MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size(); MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size();
CNodePtr cnode_ptr = nullptr; CNodePtr cnode_ptr = nullptr;
int status = RET_OK; int status = RET_OK;
NoSupportOp::GetInstance()->SetFmkType("MINDIR");
for (int i = 0; i < importProto.node_size(); ++i) { for (int i = 0; i < importProto.node_size(); ++i) {
const onnx::NodeProto &node_proto = importProto.node(i); const onnx::NodeProto &node_proto = importProto.node(i);
const std::string &node_type = node_proto.op_type(); const std::string &node_type = node_proto.op_type();

@ -34,7 +34,6 @@ int Storage::Save(const schema::MetaGraphT &graph, const std::string &outputPath
return RET_ERROR; return RET_ERROR;
} }
if (access((outputPath + ".ms").c_str(), F_OK) == 0) { if (access((outputPath + ".ms").c_str(), F_OK) == 0) {
MS_LOG(WARNING) << "this file " << outputPath << ".ms has been existed";
chmod((outputPath + ".ms").c_str(), S_IWUSR); chmod((outputPath + ".ms").c_str(), S_IWUSR);
} }
std::ofstream output(outputPath + ".ms", std::ofstream::binary); std::ofstream output(outputPath + ".ms", std::ofstream::binary);

@ -65,7 +65,6 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
FuncGraphPtr graph = nullptr; FuncGraphPtr graph = nullptr;
if (flag->fmk == converter::FmkType_MS) { if (flag->fmk == converter::FmkType_MS) {
MS_ASSERT(nullptr != modelImporter); MS_ASSERT(nullptr != modelImporter);
modelImporter->Import(flag->quantType);
int status = modelImporter->Import(flag->quantType); int status = modelImporter->Import(flag->quantType);
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status); ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
graph = modelImporter->GetResult(); graph = modelImporter->GetResult();

@ -50,16 +50,23 @@ class NoSupportOp {
static NoSupportOp noSupportOp; static NoSupportOp noSupportOp;
return &noSupportOp; return &noSupportOp;
} }
void SetFmkType(const std::string &fmk_type) { fmkType = fmk_type; }
void InsertOp(const std::string &op_name) { noSupportOps.insert(op_name); } void InsertOp(const std::string &op_name) { noSupportOps.insert(op_name); }
void PrintOps() const { void PrintOps() const {
for (auto &op_name : noSupportOps) { if (!noSupportOps.empty()) {
MS_LOG(ERROR) << "The op " << op_name << " hasn't been supported"; MS_LOG(ERROR) << "===========================================";
MS_LOG(ERROR) << "UNSUPPORT OP LIST:";
for (auto &op_name : noSupportOps) {
MS_LOG(ERROR) << "FMKTYPE: " << fmkType << ", OP TYPE: " << op_name;
}
MS_LOG(ERROR) << "===========================================";
} }
} }
private: private:
NoSupportOp() { noSupportOps.clear(); } NoSupportOp() { noSupportOps.clear(); }
std::set<std::string> noSupportOps; std::set<std::string> noSupportOps;
std::string fmkType;
}; };
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -80,6 +80,7 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co
return nullptr; return nullptr;
} }
NoSupportOp::GetInstance()->SetFmkType("CAFFE");
status = ParseLayer(proto, weight, &tensorCache, metaGraph.get(), quantType); status = ParseLayer(proto, weight, &tensorCache, metaGraph.get(), quantType);
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "ParseLayer failed " << status; MS_LOG(ERROR) << "ParseLayer failed " << status;
@ -242,7 +243,11 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caff
auto status_node = nodeParser->Parse(layer, layerP, op.get(), &weightVec); auto status_node = nodeParser->Parse(layer, layerP, op.get(), &weightVec);
if (status_node != RET_OK) { if (status_node != RET_OK) {
interrupt = true; interrupt = true;
MS_LOG(ERROR) << "Parse weight for " << layer.name() << " Failed!"; if (status_node == RET_NOT_SUPPORT) {
NoSupportOp::GetInstance()->InsertOp(layer.type());
} else {
MS_LOG(ERROR) << "Parse weight for " << layer.name() << " Failed!";
}
status = (status == RET_OK ? RET_NOT_FIND_OP : status); status = (status == RET_OK ? RET_NOT_FIND_OP : status);
continue; continue;
} }

@ -559,6 +559,29 @@ STATUS OnnxTanhParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
return RET_OK; return RET_OK;
} }
STATUS OnnxSignParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx TanhParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->type = schema::ActivationType_SIGN;
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
return RET_OK;
}
OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser()); OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser());
OnnxNodeRegistrar g_onnxInt8AddParser("Int8Add", new OnnxAddParser()); OnnxNodeRegistrar g_onnxInt8AddParser("Int8Add", new OnnxAddParser());
OnnxNodeRegistrar g_onnxSubParser("Sub", new OnnxSubParser()); OnnxNodeRegistrar g_onnxSubParser("Sub", new OnnxSubParser());
@ -584,5 +607,6 @@ OnnxNodeRegistrar g_onnxTanParser("Tan", new OnnxTanParser());
OnnxNodeRegistrar g_onnxAtanParser("Atan", new OnnxAtanParser()); OnnxNodeRegistrar g_onnxAtanParser("Atan", new OnnxAtanParser());
OnnxNodeRegistrar g_onnxAsinParser("Asin", new OnnxAsinParser()); OnnxNodeRegistrar g_onnxAsinParser("Asin", new OnnxAsinParser());
OnnxNodeRegistrar g_onnxTanhParser("Tanh", new OnnxTanhParser()); OnnxNodeRegistrar g_onnxTanhParser("Tanh", new OnnxTanhParser());
OnnxNodeRegistrar g_onnxSignParser("Sign", new OnnxTanhParser());
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore

@ -165,6 +165,12 @@ class OnnxTanhParser : public OnnxNodeParser {
OnnxTanhParser() : OnnxNodeParser("Tanh") {} OnnxTanhParser() : OnnxNodeParser("Tanh") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override; STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
}; };
class OnnxSignParser : public OnnxNodeParser {
public:
OnnxSignParser() : OnnxNodeParser("Sign") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
};
} // namespace lite } // namespace lite
} // namespace mindspore } // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H #endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H

@ -47,12 +47,18 @@ STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
return RET_NULL_PTR; return RET_NULL_PTR;
} }
attr->type = schema::ActivationType_RELU6; attr->type = schema::ActivationType_RELU6;
op->primitive->value.type = schema::PrimitiveType_Activation; op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release(); op->primitive->value.value = attr.release();
} else { } else {
MS_LOG(ERROR) << "only support convert clip(0,6) to relu6, other value is not supported"; std::unique_ptr<schema::ClipT> attr = std::make_unique<schema::ClipT>();
return RET_ERROR; if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->max = max;
attr->min = min;
op->primitive->value.type = schema::PrimitiveType_Clip;
op->primitive->value.value = attr.release();
} }
return RET_OK; return RET_OK;
} }

@ -271,7 +271,11 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
auto status = node_parser->Parse(onnx_graph, onnx_node, dst_op); auto status = node_parser->Parse(onnx_graph, onnx_node, dst_op);
if (status != RET_OK) { if (status != RET_OK) {
interrupt = true; interrupt = true;
MS_LOG(ERROR) << "parser onnx node " << onnx_node.op_type() << " attr failed"; if (status == RET_NOT_SUPPORT) {
NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
} else {
MS_LOG(ERROR) << "parser onnx node " << onnx_node.op_type() << " attr failed";
}
return status; return status;
} }
// set op input index // set op input index
@ -514,6 +518,7 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
return nullptr; return nullptr;
} }
// init op node input/output tensor, and dst_op attr // init op node input/output tensor, and dst_op attr
NoSupportOp::GetInstance()->SetFmkType("ONNX");
for (const auto &onnx_node : onnx_graph.node()) { for (const auto &onnx_node : onnx_graph.node()) {
int status_node = RET_OK; int status_node = RET_OK;
if (onnx_node.op_type() == "Constant") { if (onnx_node.op_type() == "Constant") {

@ -96,6 +96,7 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit
const QuantType &quant_type, schema::MetaGraphT *sub_graph) { const QuantType &quant_type, schema::MetaGraphT *sub_graph) {
int idx = 0; int idx = 0;
int status = RET_OK; int status = RET_OK;
NoSupportOp::GetInstance()->SetFmkType("TFLITE");
for (const auto &tflite_op : tflite_subgraph->operators) { for (const auto &tflite_op : tflite_subgraph->operators) {
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code; auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
auto op_type = GetMSOpType(tflite_op_type); auto op_type = GetMSOpType(tflite_op_type);
@ -119,7 +120,11 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit
if (status == RET_OK) { if (status == RET_OK) {
status = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, op.get()); status = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, op.get());
if (status != RET_OK) { if (status != RET_OK) {
MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed"; if (status == RET_NOT_SUPPORT) {
NoSupportOp::GetInstance()->InsertOp(op_type);
} else {
MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed";
}
continue; continue;
} }

Loading…
Cancel
Save