|
|
|
@ -21,6 +21,7 @@
|
|
|
|
|
#include <utility>
|
|
|
|
|
#include "tools/common/graph_util.h"
|
|
|
|
|
#include "src/common/utils.h"
|
|
|
|
|
#include "tools/common/protobuf_utils.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace lite {
|
|
|
|
@ -54,36 +55,7 @@ std::vector<int32_t> OnnxModelParser::GetDimsFromOnnxValue(const onnx::ValueInfo
|
|
|
|
|
return dims;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS OnnxModelParser::ReadOnnxModelFromBinary(const std::string &modelFile,
|
|
|
|
|
google::protobuf::Message *onnx_model) {
|
|
|
|
|
std::unique_ptr<char> onnx_file(new (std::nothrow) char[PATH_MAX]{0});
|
|
|
|
|
#ifdef _WIN32
|
|
|
|
|
if (_fullpath(onnx_file.get(), modelFile.c_str(), 1024) == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "get realpath " << modelFile << " fail";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
if (realpath(modelFile.c_str(), onnx_file.get()) == nullptr) {
|
|
|
|
|
MS_LOG(ERROR) << "get realpath " << modelFile << " fail";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
int fd = open(onnx_file.get(), O_RDONLY);
|
|
|
|
|
google::protobuf::io::FileInputStream input(fd);
|
|
|
|
|
google::protobuf::io::CodedInputStream code_input(&input);
|
|
|
|
|
code_input.SetTotalBytesLimit(INT_MAX, 536870912);
|
|
|
|
|
bool ret = onnx_model->ParseFromCodedStream(&code_input);
|
|
|
|
|
if (!ret) {
|
|
|
|
|
MS_LOG(ERROR) << "load onnx file failed";
|
|
|
|
|
return RET_ERROR;
|
|
|
|
|
}
|
|
|
|
|
(void)close(fd);
|
|
|
|
|
onnx_file.release();
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph,
|
|
|
|
|
TensorCache *tensor_cache) {
|
|
|
|
|
STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache) {
|
|
|
|
|
MS_LOG(DEBUG) << "set onnx constant tensors";
|
|
|
|
|
for (const auto &onnx_const_value : onnx_graph.initializer()) {
|
|
|
|
|
int index;
|
|
|
|
@ -119,11 +91,8 @@ STATUS OnnxModelParser::SetGraphConstTensor(const onnx::GraphProto &onnx_graph,
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto,
|
|
|
|
|
const std::string &name,
|
|
|
|
|
const TensorType &type,
|
|
|
|
|
TensorCache *tensor_cache,
|
|
|
|
|
int *index) {
|
|
|
|
|
STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const TensorType &type,
|
|
|
|
|
TensorCache *tensor_cache, int *index) {
|
|
|
|
|
auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.type().tensor_type().elem_type()));
|
|
|
|
|
if (data_type == kTypeUnknown) {
|
|
|
|
|
MS_LOG(ERROR) << "not support onnx data type "
|
|
|
|
@ -143,11 +112,8 @@ STATUS OnnxModelParser::AddValueInfo(const onnx::ValueInfoProto &proto,
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto,
|
|
|
|
|
const std::string &name,
|
|
|
|
|
const TensorType &type,
|
|
|
|
|
TensorCache *tensor_cache,
|
|
|
|
|
int *index) {
|
|
|
|
|
STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std::string &name, const TensorType &type,
|
|
|
|
|
TensorCache *tensor_cache, int *index) {
|
|
|
|
|
auto data_type = GetDataTypeFromOnnx(static_cast<onnx::TensorProto_DataType>(proto.data_type()));
|
|
|
|
|
if (data_type == kTypeUnknown) {
|
|
|
|
|
MS_LOG(ERROR) << "not support onnx data type " << static_cast<onnx::TensorProto_DataType>(proto.data_type());
|
|
|
|
@ -174,8 +140,7 @@ STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto,
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph,
|
|
|
|
|
schema::MetaGraphT *graph,
|
|
|
|
|
STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph,
|
|
|
|
|
TensorCache *tensor_cache) {
|
|
|
|
|
for (const auto &input_value : onnx_graph.input()) {
|
|
|
|
|
auto ret = tensor_cache->FindTensor(input_value.name());
|
|
|
|
@ -192,8 +157,7 @@ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph,
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
|
|
|
|
|
schema::MetaGraphT *graph,
|
|
|
|
|
STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph,
|
|
|
|
|
TensorCache *tensor_cache) {
|
|
|
|
|
for (const auto &output_value : onnx_graph.output()) {
|
|
|
|
|
int index;
|
|
|
|
@ -207,10 +171,8 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph,
|
|
|
|
|
const onnx::NodeProto &onnx_node,
|
|
|
|
|
schema::MetaGraphT *graph,
|
|
|
|
|
TensorCache *tensor_cache) {
|
|
|
|
|
void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
|
|
|
|
schema::MetaGraphT *graph, TensorCache *tensor_cache) {
|
|
|
|
|
std::unique_ptr<schema::CNodeT> dst_op_1 = std::make_unique<schema::CNodeT>();
|
|
|
|
|
dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0);
|
|
|
|
|
ParseOnnxNodeAttr(onnx_graph, onnx_node, "MatMul", dst_op_1.get());
|
|
|
|
@ -231,8 +193,7 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph,
|
|
|
|
|
graph->nodes.emplace_back(std::move(dst_op_2));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
|
|
|
|
|
TensorCache *tensor_cache) {
|
|
|
|
|
STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) {
|
|
|
|
|
// convert GivenTensorFill node to a weight/bias tensor
|
|
|
|
|
auto ret = tensor_cache->FindTensor(onnx_node.output(0));
|
|
|
|
|
if (ret < 0) {
|
|
|
|
@ -284,10 +245,8 @@ STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node,
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
|
|
|
|
|
const onnx::NodeProto &onnx_node,
|
|
|
|
|
schema::CNodeT *dst_op,
|
|
|
|
|
schema::TensorT *dst_tensor,
|
|
|
|
|
STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
|
|
|
|
schema::CNodeT *dst_op, schema::TensorT *dst_tensor,
|
|
|
|
|
TensorCache *tensor_cache) {
|
|
|
|
|
// change op_type() to name(), that is unique
|
|
|
|
|
dst_op->name = onnx_node.op_type() + "_" + onnx_node.output(0);
|
|
|
|
@ -319,11 +278,8 @@ STATUS OnnxModelParser::ParseOnnxNodeToDstOp(const onnx::GraphProto &onnx_graph,
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph,
|
|
|
|
|
const onnx::NodeProto &onnx_node,
|
|
|
|
|
schema::CNodeT *dst_op,
|
|
|
|
|
schema::TensorT *dst_tensor,
|
|
|
|
|
TensorCache *tensor_cache) {
|
|
|
|
|
void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
|
|
|
|
schema::CNodeT *dst_op, schema::TensorT *dst_tensor, TensorCache *tensor_cache) {
|
|
|
|
|
MS_ASSERT(dst_op != nullptr);
|
|
|
|
|
MS_ASSERT(tensor_cache != nullptr);
|
|
|
|
|
std::vector<string> quant_node_name;
|
|
|
|
@ -380,10 +336,8 @@ void OnnxModelParser::SetOpQuantParams(const onnx::GraphProto &onnx_graph,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph,
|
|
|
|
|
const onnx::NodeProto &onnx_node,
|
|
|
|
|
const string &onnx_op_type,
|
|
|
|
|
schema::CNodeT *dst_op) {
|
|
|
|
|
STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
|
|
|
|
|
const string &onnx_op_type, schema::CNodeT *dst_op) {
|
|
|
|
|
auto node_parser = OnnxNodeParserRegistry::GetInstance()->GetNodeParser(onnx_op_type);
|
|
|
|
|
if (node_parser == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "not find " << onnx_op_type << ", node parser is nullptr";
|
|
|
|
@ -392,10 +346,8 @@ STATUS OnnxModelParser::ParseOnnxNodeAttr(const onnx::GraphProto &onnx_graph,
|
|
|
|
|
return node_parser->Parse(onnx_graph, onnx_node, dst_op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs,
|
|
|
|
|
schema::CNodeT *dst_op,
|
|
|
|
|
const onnx::NodeProto &onnx_node,
|
|
|
|
|
TensorCache *tensor_cache) {
|
|
|
|
|
STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs, schema::CNodeT *dst_op,
|
|
|
|
|
const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) {
|
|
|
|
|
for (const auto &onnx_node_input : node_inputs) {
|
|
|
|
|
auto index = tensor_cache->FindTensor(onnx_node_input);
|
|
|
|
|
if (index < 0) {
|
|
|
|
@ -408,8 +360,7 @@ STATUS OnnxModelParser::SetOpInputIndex(const std::vector<string> &node_inputs,
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs,
|
|
|
|
|
schema::CNodeT *dst_op,
|
|
|
|
|
STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs, schema::CNodeT *dst_op,
|
|
|
|
|
TensorCache *tensor_cache) {
|
|
|
|
|
for (const auto &onnx_node_output : node_outputs) {
|
|
|
|
|
auto index = tensor_cache->FindTensor(onnx_node_output);
|
|
|
|
@ -424,8 +375,7 @@ STATUS OnnxModelParser::SetOpOutputIndex(const std::vector<string> &node_outputs
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value,
|
|
|
|
|
schema::TensorT *tensor) {
|
|
|
|
|
STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_value, schema::TensorT *tensor) {
|
|
|
|
|
size_t data_count = 1;
|
|
|
|
|
std::for_each(tensor->dims.begin(), tensor->dims.end(), [&data_count](int dim) { data_count *= dim; });
|
|
|
|
|
size_t data_size = 0;
|
|
|
|
@ -484,8 +434,7 @@ STATUS OnnxModelParser::CopyOnnxTensorData(const onnx::TensorProto &onnx_const_v
|
|
|
|
|
return RET_OK;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
STATUS OnnxModelParser::SetAllTensors(const TensorCache &tensor_cache,
|
|
|
|
|
schema::MetaGraphT *graphDef) {
|
|
|
|
|
STATUS OnnxModelParser::SetAllTensors(const TensorCache &tensor_cache, schema::MetaGraphT *graphDef) {
|
|
|
|
|
std::vector<schema::TensorT *> tensors = tensor_cache.GetCachedTensor();
|
|
|
|
|
for (auto iter : tensors) {
|
|
|
|
|
std::unique_ptr<schema::TensorT> temp(iter);
|
|
|
|
@ -507,17 +456,16 @@ void OnnxModelParser::FindGraphInputAndConst(const onnx::GraphProto &onnx_graph)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile,
|
|
|
|
|
const std::string &weightFile,
|
|
|
|
|
const QuantType &quantType) {
|
|
|
|
|
schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
|
|
|
|
|
const QuantType &quantType) {
|
|
|
|
|
if (ValidateFileStr(modelFile, ".onnx") != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Input illegal: modelFile must be *.onnx";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto dst_graph = std::make_unique<schema::MetaGraphT>();
|
|
|
|
|
|
|
|
|
|
onnx::ModelProto onnx_model;
|
|
|
|
|
if (ReadOnnxModelFromBinary(modelFile, &onnx_model) != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "read onnx model fail";
|
|
|
|
|
if (ReadProtoFromBinaryFile((const char *)modelFile.c_str(), &onnx_model) != RET_OK) {
|
|
|
|
|
MS_LOG(ERROR) << "Read onnx model file failed, model path: " << modelFile;
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
const onnx::GraphProto &onnx_graph = onnx_model.graph();
|
|
|
|
@ -531,6 +479,7 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile,
|
|
|
|
|
MS_LOG(ERROR) << "SetGraphConstTensor failed";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
auto dst_graph = std::make_unique<schema::MetaGraphT>();
|
|
|
|
|
// init onnx model graph input tensor
|
|
|
|
|
if (SetGraphInputTensor(onnx_graph, dst_graph.get(), &tensor_cache)) {
|
|
|
|
|
MS_LOG(ERROR) << "SetGraphInputTensor failed";
|
|
|
|
|