@@ -15,12 +15,12 @@
  */
 
 #include "tools/converter/parser/onnx/onnx_model_parser.h"
+#include <algorithm>
 #include <cfloat>
 #include <unordered_map>
-#include <algorithm>
 #include <utility>
+#include "tools/common/graph_util.h"
 #include "src/common/utils.h"
-#include "tools/common/graph_util.h"
 #include "tools/common/protobuf_utils.h"
 
 namespace mindspore {
@@ -36,7 +36,8 @@ static const std::unordered_map<int, mindspore::TypeId> TYPE_MAP = {
   {onnx::TensorProto_DataType_UINT32, mindspore::kNumberTypeUInt32},
   {onnx::TensorProto_DataType_INT64, mindspore::kNumberTypeInt64},
   {onnx::TensorProto_DataType_FLOAT16, mindspore::kNumberTypeFloat16},
-  {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32}};
+  {onnx::TensorProto_DataType_FLOAT, mindspore::kNumberTypeFloat32},
+  {onnx::TensorProto_DataType_BOOL, mindspore::kNumberTypeBool}};
 
 TypeId OnnxModelParser::GetDataTypeFromOnnx(onnx::TensorProto_DataType onnx_type) {
   auto iter = TYPE_MAP.find(onnx_type);
@@ -161,10 +162,14 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
                                              TensorCache *tensor_cache) {
   for (const auto &output_value : onnx_graph.output()) {
     int index;
+    if (tensor_cache->FindTensor(output_value.name()) != -1) {
+      index = tensor_cache->FindTensor(output_value.name());
+    } else {
     const auto status = AddValueInfo(output_value, output_value.name(), OP_OUTPUT, tensor_cache, &index);
     if (status != RET_OK) {
       return status;
     }
+    }
     graph->outputIndex.emplace_back(index);
     MS_LOG(DEBUG) << "output_value name: " << output_value.name() << ", graph output index: " << index;
   }
@@ -250,7 +255,8 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
 
 STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
                                              schema::CNodeT *dst_op, schema::TensorT *dst_tensor,
-                                             TensorCache *tensor_cache, const QuantType &quantType) {
+                                             TensorCache *tensor_cache, const QuantType &quantType,
+                                             schema::MetaGraphT *dst_graph) {
   // change op_type() to name(), that is unique
   static bool interrupt = false;
   dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0);
@@ -260,6 +266,16 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
                 << onnx_node.input_size();
   // get the real op type
   SetOpQuantParams(onnx_graph, onnx_node, dst_op, dst_tensor, tensor_cache);
+  if (onnx_node.op_type() == "Loop") {
+    NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
+    interrupt = true;
+    return RET_NOT_FIND_OP;
+    int status = ParseLoopAttr(dst_op, onnx_node, quantType, dst_graph);
+    if (status != RET_OK || interrupt) {
+      interrupt = true;
+      return status;
+    }
+  } else {
   auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_node.op_type());
   if (node_parser == nullptr || interrupt) {
     interrupt = true;
@@ -271,13 +287,14 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
   auto status = node_parser->Parse(onnx_graph, onnx_node, dst_op);
   if (status != RET_OK) {
     interrupt = true;
-    if (status == RET_NOT_SUPPORT) {
+    if (status == RET_NOT_FIND_OP) {
       NoSupportOp::GetInstance()->InsertOp(onnx_node.op_type());
     } else {
       MS_LOG(ERROR) << "parser onnx node " << onnx_node.op_type() << " attr failed";
     }
     return status;
   }
+  }
   // set op input index
   std::vector<string> node_inputs;
   (void)node_inputs.insert(node_inputs.begin(), onnx_node.input().begin(), onnx_node.input().end());
@@ -366,7 +383,7 @@ STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs,
                                         const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) {
   for (const auto &onnx_node_input : node_inputs) {
     if (onnx_node_input != "") {
-      auto index = tensor_cache->FindTensor(onnx_node_input);
+      int index = tensor_cache->FindTensor(onnx_node_input);
       if (index < 0) {
         MS_LOG(ERROR) << "input " << onnx_node_input << " of node " << onnx_node.name() << " can't be found";
         return RET_ERROR;
@@ -428,6 +445,9 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
       }
       for (size_t i = 0; i < data_count; ++i) {
         if (in_data[i] > static_cast<int64_t>(INT32_MAX) || in_data[i] < static_cast<int64_t>(INT32_MIN)) {
+          if (llabs(in_data[i]) == INT64_MAX || in_data[i] == INT64_MIN) {
+            buffer[i] = in_data[i] > 0 ? INT32_MAX : INT32_MIN;
+          }
           MS_LOG(ERROR) << "int64 data " << in_data[i] << "too big to fit into int32";
           return RET_ERROR;
         } else {
@@ -438,6 +458,7 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
       break;
     case kNumberTypeUInt8:
     case kNumberTypeInt8:
+    case kNumberTypeBool:
       data_size = data_count * sizeof(uint8_t);
       tensor_data = onnx_const_value.raw_data().data();
       break;
@@ -446,7 +467,7 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
       return RET_ERROR;
   }
   tensor->data.resize(data_size);
-  if (memcpy_s(static_cast<void *>(tensor->data.data()), data_size, tensor_data, data_size) != 0) {
+  if (data_size != 0 && memcpy_s(static_cast<void *>(tensor->data.data()), data_size, tensor_data, data_size) != 0) {
     MS_LOG(ERROR) << "memcpy_s failed";
     return RET_ERROR;
   }
@@ -475,30 +496,39 @@ void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph)
   }
 }
 
-schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
-                                               const QuantType &quantType) {
-  int status = ValidateFileStr(modelFile, ".onnx");
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx";
-    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
-    return nullptr;
-  }
-
-  onnx::ModelProto onnx_model;
-  status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model);
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "Read onnx model file failed, model path: " << modelFile;
-    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
-    return nullptr;
-  }
-  const onnx::GraphProto &onnx_graph = onnx_model.graph();
-  MS_LOG(INFO) << "model producer name: " << onnx_model.producer_name() << ", graph name: " << onnx_graph.name();
+STATUS OnnxModelParser::ParseLoopAttr(schema::CNodeT *dst_op, const onnx::NodeProto &onnx_node,
+                                      const QuantType &quantType, schema::MetaGraphT *dst_graph) {
+  MS_LOG(DEBUG) << "onnx LoopParser";
+  if (dst_op == nullptr) {
+    MS_LOG(ERROR) << "op is null";
+    return RET_NULL_PTR;
+  }
+  dst_op->primitive = std::make_unique<schema::PrimitiveT>();
+  if (dst_op->primitive == nullptr) {
+    MS_LOG(ERROR) << "op->primitive is null";
+    return RET_NULL_PTR;
+  }
+  std::unique_ptr<schema::LoopT> attr = std::make_unique<schema::LoopT>();
+  if (attr == nullptr) {
+    MS_LOG(ERROR) << "new op failed";
+    return RET_NULL_PTR;
+  }
+  attr->subGraphIndex = subGraphNum;
+  auto sub_graph = std::make_unique<schema::MetaGraphT>();
+  sub_graph.reset(ParseGraph(onnx_node.attribute().at(0).g(), quantType));
+  dst_graph->subGraph.push_back(std::move(sub_graph));
+  subGraphNum += 1;
+  dst_op->primitive->value.type = schema::PrimitiveType_Loop;
+  dst_op->primitive->value.value = attr.release();
+  return RET_OK;
+}
+schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_graph, const QuantType &quantType) {
   TensorCache tensor_cache;
   // dst_graph->name = onnx_graph.name(); // this is not used
   // find out input names and const names
   FindGraphInputAndConst(onnx_graph);
   // set const tensor
-  status = SetGraphConstTensor(onnx_graph, &tensor_cache);
+  int status = SetGraphConstTensor(onnx_graph, &tensor_cache);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "SetGraphConstTensor failed";
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
@@ -512,13 +542,7 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
     return nullptr;
   }
-  // init onnx model graph output tensor
-  status = SetGraphOutputTensor(onnx_graph, dst_graph.get(), &tensor_cache);
-  if (status != RET_OK) {
-    MS_LOG(ERROR) << "SetGraphOutputTensor failed";
-    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
-    return nullptr;
-  }
+
   // init op node input/output tensor, and dst_op attr
   NoSupportOp::GetInstance()->SetFmkType("ONNX");
   for (const auto &onnx_node : onnx_graph.node()) {
@@ -544,7 +568,8 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
 
     std::unique_ptr<schema::CNodeT> dst_op = std::make_unique<schema::CNodeT>();
     std::unique_ptr<schema::TensorT> dst_tensor = std::make_unique<schema::TensorT>();
-    status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType);
+    status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType,
+                                       dst_graph.get());
     if (status_node != RET_OK) {
       status = (status == RET_OK ? status_node : status);
       continue;
@@ -558,9 +583,42 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
     }
     return nullptr;
   }
+  // init onnx model graph output tensor
+  status = SetGraphOutputTensor(onnx_graph, dst_graph.get(), &tensor_cache);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "SetGraphOutputTensor failed";
+    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
+    return nullptr;
+  }
   SetAllTensors(tensor_cache, dst_graph.get());
-  dst_graph->name = GetModelName(modelFile);
   return dst_graph.release();
 }
+schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
+                                               const QuantType &quantType) {
+  int status = ValidateFileStr(modelFile, ".onnx");
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx";
+    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
+    return nullptr;
+  }
+
+  onnx::ModelProto onnx_model;
+  status = ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model);
+  if (status != RET_OK) {
+    MS_LOG(ERROR) << "Read onnx model file failed, model path: " << modelFile;
+    ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
+    return nullptr;
+  }
+  const onnx::GraphProto &onnx_graph = onnx_model.graph();
+  MS_LOG(INFO) << "model producer name: " << onnx_model.producer_name() << ", graph name: " << onnx_graph.name();
+
+  schema::MetaGraphT *dst_graph = ParseGraph(onnx_graph, quantType);
+  if (dst_graph == nullptr) {
+    return nullptr;
+  }
+  dst_graph->name = GetModelName(modelFile);
+  return dst_graph;
+}
+
 }  // namespace lite
 }  // namespace mindspore