add onnx parser and adjust the way of printing unsupport op

pull/7019/head
xuanyue 4 years ago
parent 5a83415f07
commit 357b597b4f

@ -58,7 +58,8 @@ enum ActivationType : byte {
THRESHOLDRELU = 14,
LINEAR = 15,
HARD_TANH = 16,
UNKNOW = 17
SIGN = 17,
UNKNOW = 18
}
enum ActivationGradType : byte {
NO_ACTIVATION = 0,

@ -595,10 +595,14 @@ CNodePtr AnfImporterFromProtobuf::BuildCNodeForFuncGraph(const FuncGraphPtr &out
for (int i = 0; i < node_proto.input_size(); ++i) {
const std::string &input_name = node_proto.input(i);
if (anfnode_build_map_.find(input_name) == anfnode_build_map_.end()) {
MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed";
return nullptr;
if (!interrupt) {
MS_LOG(ERROR) << node_name << " input " << i << input_name << "can't find in nodes have parsed";
interrupt = true;
}
inputs.push_back(nullptr);
} else {
inputs.push_back(anfnode_build_map_[input_name]);
}
inputs.push_back(anfnode_build_map_[input_name]);
}
auto primitivec_ptr = PrimitiveC::Create(*prim, inputs, quantType);
if (primitivec_ptr == nullptr || interrupt) {
@ -714,6 +718,7 @@ int AnfImporterFromProtobuf::ImportNodesForGraph(const FuncGraphPtr &outputFuncG
MS_LOG(INFO) << "The CNdoe size : " << importProto.node_size();
CNodePtr cnode_ptr = nullptr;
int status = RET_OK;
NoSupportOp::GetInstance()->SetFmkType("MINDIR");
for (int i = 0; i < importProto.node_size(); ++i) {
const onnx::NodeProto &node_proto = importProto.node(i);
const std::string &node_type = node_proto.op_type();

@ -34,7 +34,6 @@ int Storage::Save(const schema::MetaGraphT &graph, const std::string &outputPath
return RET_ERROR;
}
if (access((outputPath + ".ms").c_str(), F_OK) == 0) {
MS_LOG(WARNING) << "this file " << outputPath << ".ms has been existed";
chmod((outputPath + ".ms").c_str(), S_IWUSR);
}
std::ofstream output(outputPath + ".ms", std::ofstream::binary);

@ -65,7 +65,6 @@ MetaGraphT *Converter::Convert(const converter::Flags *flag) {
FuncGraphPtr graph = nullptr;
if (flag->fmk == converter::FmkType_MS) {
MS_ASSERT(nullptr != modelImporter);
modelImporter->Import(flag->quantType);
int status = modelImporter->Import(flag->quantType);
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
graph = modelImporter->GetResult();

@ -50,16 +50,23 @@ class NoSupportOp {
static NoSupportOp noSupportOp;
return &noSupportOp;
}
void SetFmkType(const std::string &fmk_type) { fmkType = fmk_type; }
void InsertOp(const std::string &op_name) { noSupportOps.insert(op_name); }
void PrintOps() const {
for (auto &op_name : noSupportOps) {
MS_LOG(ERROR) << "The op " << op_name << " hasn't been supported";
if (!noSupportOps.empty()) {
MS_LOG(ERROR) << "===========================================";
MS_LOG(ERROR) << "UNSUPPORT OP LIST:";
for (auto &op_name : noSupportOps) {
MS_LOG(ERROR) << "FMKTYPE: " << fmkType << ", OP TYPE: " << op_name;
}
MS_LOG(ERROR) << "===========================================";
}
}
private:
NoSupportOp() { noSupportOps.clear(); }
std::set<std::string> noSupportOps;
std::string fmkType;
};
} // namespace lite
} // namespace mindspore

@ -80,6 +80,7 @@ schema::MetaGraphT *CaffeModelParser::ParseToFb(const std::string &modelFile, co
return nullptr;
}
NoSupportOp::GetInstance()->SetFmkType("CAFFE");
status = ParseLayer(proto, weight, &tensorCache, metaGraph.get(), quantType);
if (status != RET_OK) {
MS_LOG(ERROR) << "ParseLayer failed " << status;
@ -242,7 +243,11 @@ STATUS CaffeModelParser::ParseLayer(const caffe::NetParameter &proto, const caff
auto status_node = nodeParser->Parse(layer, layerP, op.get(), &weightVec);
if (status_node != RET_OK) {
interrupt = true;
MS_LOG(ERROR) << "Parse weight for " << layer.name() << " Failed!";
if (status_node == RET_NOT_SUPPORT) {
NoSupportOp::GetInstance()->InsertOp(layer.type());
} else {
MS_LOG(ERROR) << "Parse weight for " << layer.name() << " Failed!";
}
status = (status == RET_OK ? RET_NOT_FIND_OP : status);
continue;
}

@ -559,6 +559,29 @@ STATUS OnnxTanhParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
return RET_OK;
}
STATUS OnnxSignParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) {
MS_LOG(DEBUG) << "onnx TanhParser";
if (op == nullptr) {
MS_LOG(ERROR) << "op is null";
return RET_NULL_PTR;
}
op->primitive = std::make_unique<schema::PrimitiveT>();
if (op->primitive == nullptr) {
MS_LOG(ERROR) << "op->primitive is null";
return RET_NULL_PTR;
}
std::unique_ptr<schema::ActivationT> attr = std::make_unique<schema::ActivationT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->type = schema::ActivationType_SIGN;
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
return RET_OK;
}
OnnxNodeRegistrar g_onnxAddParser("Add", new OnnxAddParser());
OnnxNodeRegistrar g_onnxInt8AddParser("Int8Add", new OnnxAddParser());
OnnxNodeRegistrar g_onnxSubParser("Sub", new OnnxSubParser());
@ -584,5 +607,6 @@ OnnxNodeRegistrar g_onnxTanParser("Tan", new OnnxTanParser());
OnnxNodeRegistrar g_onnxAtanParser("Atan", new OnnxAtanParser());
OnnxNodeRegistrar g_onnxAsinParser("Asin", new OnnxAsinParser());
OnnxNodeRegistrar g_onnxTanhParser("Tanh", new OnnxTanhParser());
OnnxNodeRegistrar g_onnxSignParser("Sign", new OnnxTanhParser());
} // namespace lite
} // namespace mindspore

@ -165,6 +165,12 @@ class OnnxTanhParser : public OnnxNodeParser {
OnnxTanhParser() : OnnxNodeParser("Tanh") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
};
class OnnxSignParser : public OnnxNodeParser {
public:
OnnxSignParser() : OnnxNodeParser("Sign") {}
STATUS Parse(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node, schema::CNodeT *op) override;
};
} // namespace lite
} // namespace mindspore
#endif // MINDSPORE_LITE_TOOLS_CONVERTER_PARSER_ONNX_ARITHMETIC_OPREATION_PARSER_H

@ -47,12 +47,18 @@ STATUS OnnxClipParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
return RET_NULL_PTR;
}
attr->type = schema::ActivationType_RELU6;
op->primitive->value.type = schema::PrimitiveType_Activation;
op->primitive->value.value = attr.release();
} else {
MS_LOG(ERROR) << "only support convert clip(0,6) to relu6, other value is not supported";
return RET_ERROR;
std::unique_ptr<schema::ClipT> attr = std::make_unique<schema::ClipT>();
if (attr == nullptr) {
MS_LOG(ERROR) << "new op failed";
return RET_NULL_PTR;
}
attr->max = max;
attr->min = min;
op->primitive->value.type = schema::PrimitiveType_Clip;
op->primitive->value.value = attr.release();
}
return RET_OK;
}

@ -271,7 +271,11 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
auto status = node_parser->Parse(onnx_graph, onnx_node, dst_op);
if (status != RET_OK) {
interrupt = true;
MS_LOG(ERROR) << "parser onnx node " << onnx_node.op_type() << " attr failed";
if (status == RET_NOT_SUPPORT) {
NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
} else {
MS_LOG(ERROR) << "parser onnx node " << onnx_node.op_type() << " attr failed";
}
return status;
}
// set op input index
@ -514,6 +518,7 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
return nullptr;
}
// init op node input/output tensor, and dst_op attr
NoSupportOp::GetInstance()->SetFmkType("ONNX");
for (const auto &onnx_node : onnx_graph.node()) {
int status_node = RET_OK;
if (onnx_node.op_type() == "Constant") {

@ -96,6 +96,7 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit
const QuantType &quant_type, schema::MetaGraphT *sub_graph) {
int idx = 0;
int status = RET_OK;
NoSupportOp::GetInstance()->SetFmkType("TFLITE");
for (const auto &tflite_op : tflite_subgraph->operators) {
auto tflite_op_type = (tflite_model->operator_codes[tflite_op->opcode_index])->builtin_code;
auto op_type = GetMSOpType(tflite_op_type);
@ -119,7 +120,11 @@ STATUS TfliteModelParser::ConvertOp(const std::unique_ptr<tflite::ModelT> &tflit
if (status == RET_OK) {
status = node_parser->Parse(&tensorsInfo, tflite_op, tflite_model, op.get());
if (status != RET_OK) {
MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed";
if (status == RET_NOT_SUPPORT) {
NoSupportOp::GetInstance()->InsertOp(op_type);
} else {
MS_LOG(ERROR) << "node " << op_type.c_str() << " parser failed";
}
continue;
}

Loading…
Cancel
Save