!8046 [MS][LITE]while op parser model

Merge pull request !8046 from yefeng/while_op_parser_model
pull/8046/MERGE
mindspore-ci-bot 5 years ago committed by Gitee
commit 28297a549b

@ -29,13 +29,20 @@ struct MS_API Model {
Uint32Vector output_indices_;
};
using NodePtrVector = std::vector<Node *>;
struct SubGraph {
String name_;
Uint32Vector input_indices_;
Uint32Vector output_indices_;
Uint32Vector node_indices_;
Uint32Vector tensor_indices_;
};
using SubGraphPtrVector = std::vector<SubGraph *>;
String name_;
String version_;
TensorPtrVector all_tensors_;
Uint32Vector input_indices_;
Uint32Vector output_indices_;
NodePtrVector nodes_;
NodePtrVector all_nodes_;
char *buf;
SubGraphPtrVector sub_graphs_;
/// \brief Static method to create a Model pointer.
///

@ -256,6 +256,14 @@ table CNode {
quantType: QuantType = QUANT_NONE;
}
table SubGraph {
name:string;
inputIndices: [uint];
outputIndices: [uint];
nodeIndices: [uint];
tensorIndices: [uint];
}
table MetaGraph {
name: string;
version: string;
@ -265,7 +273,7 @@ table MetaGraph {
mempoolSize: uint;
nodes: [CNode];
allTensors: [Tensor]; // weight + input + output
subGraph : [MetaGraph];
subGraph : [SubGraph];
}
root_type MetaGraph;

@ -26,11 +26,12 @@ namespace mindspore {
namespace lite {
std::vector<size_t> GetGraphInputNodes(const lite::Model *model) {
MS_ASSERT(model != nullptr);
MS_ASSERT(!(model->sub_graphs_.empty()));
std::vector<size_t> ret;
for (auto graph_in_index : model->input_indices_) {
auto node_size = model->nodes_.size();
for (auto graph_in_index : model->sub_graphs_.front()->input_indices_) {
auto node_size = model->all_nodes_.size();
for (size_t j = 0; j < node_size; ++j) {
auto node = model->nodes_[j];
auto node = model->all_nodes_[j];
MS_ASSERT(node != nullptr);
if (std::any_of(node->input_indices_.begin(), node->input_indices_.end(),
[&](const uint32_t &node_in_index) { return node_in_index == graph_in_index; })) {
@ -46,10 +47,10 @@ std::vector<size_t> GetGraphInputNodes(const lite::Model *model) {
std::vector<size_t> GetGraphOutputNodes(const lite::Model *model) {
MS_ASSERT(model != nullptr);
std::vector<size_t> ret;
for (auto graph_out_index : model->output_indices_) {
auto node_size = model->nodes_.size();
for (auto graph_out_index : model->sub_graphs_.front()->output_indices_) {
auto node_size = model->all_nodes_.size();
for (size_t j = 0; j < node_size; ++j) {
auto node = model->nodes_[j];
auto node = model->all_nodes_[j];
MS_ASSERT(node != nullptr);
if (std::any_of(node->output_indices_.begin(), node->output_indices_.end(),
[&](const uint32_t &node_out_index) { return node_out_index == graph_out_index; })) {
@ -65,9 +66,9 @@ std::vector<size_t> GetGraphOutputNodes(const lite::Model *model) {
std::vector<size_t> GetLinkedPostNodeIdx(const lite::Model *model, const size_t tensor_idx) {
MS_ASSERT(model != nullptr);
std::vector<size_t> post_node_idxes;
auto nodes_size = model->nodes_.size();
auto nodes_size = model->all_nodes_.size();
for (size_t i = 0; i < nodes_size; ++i) {
auto node = model->nodes_[i];
auto node = model->all_nodes_[i];
if (node == nullptr) {
continue;
}

@ -43,7 +43,7 @@ static bool WeightTensorNeedCopy(const lite::Model *model, const uint32_t tensor
MS_ASSERT(model != nullptr);
auto post_node_idxes = GetLinkedPostNodeIdx(model, tensor_idx);
return std::none_of(post_node_idxes.begin(), post_node_idxes.end(), [&](const size_t &post_node_idx) {
auto node = model->nodes_[post_node_idx];
auto node = model->all_nodes_[post_node_idx];
MS_ASSERT(node != nullptr);
return IsContain(packed_op, static_cast<schema::PrimitiveType>(node->primitive_->Type()));
});
@ -126,9 +126,10 @@ int LiteSession::ConvertTensors(const lite::Model *model) {
void LiteSession::InitGraphInputTensors(const lite::Model *model) {
MS_ASSERT(model != nullptr);
auto graph_in_size = model->input_indices_.size();
MS_ASSERT(!(model->sub_graphs_.empty()));
auto graph_in_size = model->sub_graphs_.front()->input_indices_.size();
for (size_t i = 0; i < graph_in_size; ++i) {
auto in_tensor_idx = model->input_indices_[i];
auto in_tensor_idx = model->sub_graphs_.front()->input_indices_[i];
MS_ASSERT(in_tensor_idx < this->tensors_.size());
auto *in_tensor = this->tensors_.at(in_tensor_idx);
MS_ASSERT(in_tensor != nullptr);
@ -148,9 +149,9 @@ void LiteSession::InitGraphOutputTensors(const lite::Model *model) {
MS_ASSERT(model != nullptr);
MS_ASSERT(this->outputs_.empty());
MS_ASSERT(meta_graph != nullptr);
auto graph_out_size = model->output_indices_.size();
auto graph_out_size = model->sub_graphs_.front()->output_indices_.size();
for (size_t i = 0; i < graph_out_size; ++i) {
auto out_tensor_idx = model->output_indices_[i];
auto out_tensor_idx = model->sub_graphs_.front()->output_indices_[i];
MS_ASSERT(out_tensor_idx < this->tensors_.size());
auto *out_tensor = this->tensors_.at(out_tensor_idx);
MS_ASSERT(out_tensor != nullptr);
@ -162,9 +163,9 @@ void LiteSession::InitGraphInputMap(const lite::Model *model) {
MS_ASSERT(model != nullptr);
MS_ASSERT(this->input_map_.empty());
auto graph_input_node_indexes = GetGraphInputNodes(model);
auto graph_in_size = model->input_indices_.size();
auto graph_in_size = model->sub_graphs_.front()->input_indices_.size();
for (auto in_node_index : graph_input_node_indexes) {
auto in_node = model->nodes_[in_node_index];
auto in_node = model->all_nodes_[in_node_index];
MS_ASSERT(in_node != nullptr);
MS_ASSERT(this->input_map_.find(in_node->name()->str()) == this->input_map_.end());
auto in_size = in_node->input_indices_.size();
@ -172,7 +173,7 @@ void LiteSession::InitGraphInputMap(const lite::Model *model) {
auto in_tensor_index = size_t(in_node->input_indices_[i]);
bool is_graph_input = false;
for (size_t j = 0; j < graph_in_size; ++j) {
if (in_tensor_index == model->input_indices_[j]) {
if (in_tensor_index == model->sub_graphs_.front()->input_indices_[j]) {
is_graph_input = true;
break;
}
@ -194,10 +195,11 @@ void LiteSession::InitGraphInputMap(const lite::Model *model) {
void LiteSession::InitGraphOutputNodeMap(const lite::Model *model) {
MS_ASSERT(model != nullptr);
MS_ASSERT(!(model->sub_graphs_.empty()));
auto graph_output_node_indexes = GetGraphOutputNodes(model);
auto graph_out_size = model->output_indices_.size();
auto graph_out_size = model->sub_graphs_.front()->output_indices_.size();
for (auto out_node_index : graph_output_node_indexes) {
auto out_node = model->nodes_[out_node_index];
auto out_node = model->all_nodes_[out_node_index];
MS_ASSERT(out_node != nullptr);
MS_ASSERT(this->output_map_.find(out_node->name()->str()) == this->output_map_.end());
auto out_size = out_node->output_indices_.size();
@ -205,7 +207,7 @@ void LiteSession::InitGraphOutputNodeMap(const lite::Model *model) {
auto out_tensor_index = out_node->output_indices_[i];
bool is_graph_output = false;
for (size_t j = 0; j < graph_out_size; ++j) {
if (out_tensor_index == model->output_indices_[j]) {
if (out_tensor_index == model->sub_graphs_.front()->output_indices_[j]) {
is_graph_output = true;
break;
}
@ -227,18 +229,18 @@ void LiteSession::InitGraphOutputNodeMap(const lite::Model *model) {
void LiteSession::InitGraphOutputTensorNames(const lite::Model *model) {
MS_ASSERT(model != nullptr);
MS_ASSERT(this->output_tensor_names_.empty());
auto out_size = model->output_indices_.size();
auto out_size = model->sub_graphs_.front()->output_indices_.size();
for (size_t i = 0; i < out_size; ++i) {
this->output_tensor_names_.emplace_back(std::to_string(model->output_indices_[i]));
this->output_tensor_names_.emplace_back(std::to_string(model->sub_graphs_.front()->output_indices_[i]));
}
}
void LiteSession::InitGraphOutputTensorMap(const lite::Model *model) {
MS_ASSERT(model != nullptr);
MS_ASSERT(this->output_tensor_map_.empty());
auto graph_out_size = model->output_indices_.size();
auto graph_out_size = model->sub_graphs_.front()->output_indices_.size();
for (size_t i = 0; i < graph_out_size; ++i) {
size_t graph_out_index = model->output_indices_[i];
size_t graph_out_index = model->sub_graphs_.front()->output_indices_[i];
MS_ASSERT(graph_out_index < this->tensors_.size());
auto *out_tensor = this->tensors_.at(graph_out_index);
if (out_tensor == nullptr) {

@ -30,16 +30,22 @@ void Model::Free() {
void Model::Destroy() {
Free();
auto nodes_size = this->nodes_.size();
auto nodes_size = this->all_nodes_.size();
for (size_t i = 0; i < nodes_size; ++i) {
auto node = this->nodes_[i];
auto node = this->all_nodes_[i];
MS_ASSERT(node != nullptr);
MS_ASSERT(node->primitive_ != nullptr);
delete node->primitive_;
node->primitive_ = nullptr;
delete node;
}
this->nodes_.clear();
this->all_nodes_.clear();
auto sub_graph_size = this->sub_graphs_.size();
for (size_t i = 0; i < sub_graph_size; ++i) {
auto sub_graph = this->sub_graphs_[i];
delete sub_graph;
}
}
Model::~Model() { Destroy(); }

@ -53,7 +53,7 @@ bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model) {
node->output_indices_.push_back(size_t(c_node->outputIndex()->GetAs<uint32_t>(j)));
}
}
model->nodes_.push_back(node);
model->all_nodes_.push_back(node);
}
return true;
}
@ -71,6 +71,66 @@ bool ConvertTensors(const schema::MetaGraph *meta_graph, Model *model) {
return true;
}
int ConvertSubGraph(const schema::SubGraph *sub_graph, Model *model) {
MS_ASSERT(model != nullptr);
MS_ASSERT(sub_graph != nullptr);
auto *sub_graph_temp = new (std::nothrow) Model::SubGraph();
if (sub_graph_temp == nullptr) {
MS_LOG(ERROR) << "new subGraph fail!";
return RET_ERROR;
}
sub_graph_temp->name_ = sub_graph->name()->c_str();
auto in_count = sub_graph->inputIndices()->size();
for (uint32_t i = 0; i < in_count; ++i) {
sub_graph_temp->input_indices_.push_back(size_t(sub_graph->inputIndices()->GetAs<uint32_t>(i)));
}
auto out_count = sub_graph->outputIndices()->size();
for (uint32_t i = 0; i < out_count; ++i) {
sub_graph_temp->output_indices_.push_back(size_t(sub_graph->outputIndices()->GetAs<uint32_t>(i)));
}
auto node_count = sub_graph->nodeIndices()->size();
for (uint32_t i = 0; i < node_count; ++i) {
sub_graph_temp->node_indices_.push_back(size_t(sub_graph->nodeIndices()->GetAs<uint32_t>(i)));
}
auto tensor_count = sub_graph->nodeIndices()->size();
for (uint32_t i = 0; i < tensor_count; ++i) {
sub_graph_temp->tensor_indices_.push_back(size_t(sub_graph->tensorIndices()->GetAs<uint32_t>(i)));
}
model->sub_graphs_.push_back(sub_graph_temp);
return RET_OK;
}
int MetaGraphMappingSubGraph(const mindspore::schema::MetaGraph *meta_graph, Model *model) {
MS_ASSERT(model != nullptr);
MS_ASSERT(meta_graph != nullptr);
auto *sub_graph_temp = new (std::nothrow) Model::SubGraph();
if (sub_graph_temp == nullptr) {
MS_LOG(ERROR) << "new subGraph fail!";
return RET_ERROR;
}
if (meta_graph->name() != nullptr) {
sub_graph_temp->name_ = meta_graph->name()->c_str();
}
auto in_count = meta_graph->inputIndex()->size();
for (uint32_t i = 0; i < in_count; ++i) {
sub_graph_temp->input_indices_.push_back(size_t(meta_graph->inputIndex()->GetAs<uint32_t>(i)));
}
auto out_count = meta_graph->outputIndex()->size();
for (uint32_t i = 0; i < out_count; ++i) {
sub_graph_temp->output_indices_.push_back(size_t(meta_graph->outputIndex()->GetAs<uint32_t>(i)));
}
auto node_count = meta_graph->nodes()->size();
for (uint32_t i = 0; i < node_count; ++i) {
sub_graph_temp->node_indices_.push_back(i);
}
auto tensor_count = meta_graph->nodes()->size();
for (uint32_t i = 0; i < tensor_count; ++i) {
sub_graph_temp->tensor_indices_.push_back(i);
}
model->sub_graphs_.push_back(sub_graph_temp);
return RET_OK;
}
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
if (model_buf == nullptr) {
MS_LOG(ERROR) << "The model buf is nullptr";
@ -121,15 +181,6 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
MS_LOG(WARNING) << "model version is " << model->version_ << ", inference version is " << Version() << " not equal";
}
auto in_count = meta_graph->inputIndex()->size();
for (uint32_t i = 0; i < in_count; ++i) {
model->input_indices_.push_back(size_t(meta_graph->inputIndex()->GetAs<uint32_t>(i)));
}
auto out_count = meta_graph->outputIndex()->size();
for (uint32_t i = 0; i < out_count; ++i) {
model->output_indices_.push_back(size_t(meta_graph->outputIndex()->GetAs<uint32_t>(i)));
}
if (!ConvertNodes(meta_graph, model)) {
delete model;
return nullptr;
@ -139,6 +190,28 @@ Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf) {
delete model;
return nullptr;
}
if (meta_graph->subGraph() == nullptr) {
int ret = MetaGraphMappingSubGraph(meta_graph, model);
if (ret != RET_OK) {
MS_LOG(ERROR) << "converter old version model wrong.";
return nullptr;
}
} else {
auto sub_graphs = meta_graph->subGraph();
auto sub_graph_size = sub_graphs->size();
for (size_t i = 0; i < sub_graph_size; i++) {
auto sub_graph = sub_graphs->GetAs<schema::SubGraph>(i);
int ret = ConvertSubGraph(sub_graph, model);
if (ret != RET_OK) {
MS_LOG(ERROR) << "converter subgraph wrong.";
return nullptr;
}
}
}
if (model->sub_graphs_.empty()) {
return nullptr;
}
return model;
}
} // namespace mindspore::lite

@ -24,6 +24,10 @@ bool ConvertNodes(const schema::MetaGraph *meta_graph, Model *model);
bool ConvertTensors(const schema::MetaGraph *meta_graph, Model *model);
int ConvertSubGraph(const schema::SubGraph *sub_graph, Model *model);
int MetaGraphMappingSubGraph(const mindspore::schema::MetaGraph *meta_graph, Model *model);
Model *ImportFromBuffer(const char *model_buf, size_t size, bool take_buf);
} // namespace mindspore::lite
#endif // MINDSPORE_LITE_SRC_MODEL_COMMON_H_

@ -94,9 +94,9 @@ int Scheduler::InferShape(const lite::Model *model, std::vector<Tensor *> *tenso
MS_ASSERT(model != nullptr);
MS_ASSERT(tensors != nullptr);
bool infer_shape_interrupt = false;
uint32_t kernelCount = model->nodes_.size();
uint32_t kernelCount = model->all_nodes_.size();
for (uint32_t i = 0; i < kernelCount; ++i) {
auto node = model->nodes_[i];
auto node = model->all_nodes_[i];
MS_ASSERT(node != nullptr);
std::vector<Tensor *> inputs;
std::vector<Tensor *> outputs;
@ -137,10 +137,10 @@ int Scheduler::BuildKernels(const lite::Model *model, std::vector<Tensor *> *ten
std::vector<kernel::LiteKernel *> *kernels) {
MS_ASSERT(model != nullptr);
MS_ASSERT(tensors != nullptr);
uint32_t kernelCount = model->nodes_.size();
uint32_t kernelCount = model->all_nodes_.size();
auto graph_output_node_indexes = GetGraphOutputNodes(model);
for (uint32_t i = 0; i < kernelCount; ++i) {
auto node = model->nodes_[i];
auto node = model->all_nodes_[i];
MS_ASSERT(node != nullptr);
std::vector<Tensor *> inputs;
std::vector<Tensor *> outputs;

@ -18,6 +18,7 @@
#include "src/common/log_adapter.h"
#include "include/errorcode.h"
#include "src/common/graph_util.h"
#include "src/model_common.h"
namespace mindspore::lite {
@ -61,15 +62,6 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) {
if (meta_graph->version() != nullptr) {
model->version_ = meta_graph->version()->c_str();
}
auto in_count = meta_graph->inputIndex()->size();
for (uint32_t i = 0; i < in_count; ++i) {
model->input_indices_.push_back(size_t(meta_graph->inputIndex()->GetAs<uint32_t>(i)));
}
auto out_count = meta_graph->outputIndex()->size();
for (uint32_t i = 0; i < out_count; ++i) {
model->output_indices_.push_back(size_t(meta_graph->outputIndex()->GetAs<uint32_t>(i)));
}
if (!ConvertNodes(meta_graph, model)) {
delete model;
return nullptr;
@ -79,6 +71,25 @@ TrainModel *TrainModel::Import(const char *model_buf, size_t size) {
delete model;
return nullptr;
}
if (meta_graph->subGraph() == nullptr) {
int ret = MetaGraphMappingSubGraph(meta_graph, model);
if (ret != RET_OK) {
MS_LOG(ERROR) << "converter old version model wrong.";
return nullptr;
}
} else {
auto sub_graphs = meta_graph->subGraph();
auto sub_graph_size = sub_graphs->size();
for (size_t i = 0; i < sub_graph_size; i++) {
auto sub_graph = sub_graphs->GetAs<schema::SubGraph>(i);
int ret = ConvertSubGraph(sub_graph, model);
if (ret != RET_OK) {
MS_LOG(ERROR) << "converter subgraph wrong.";
return nullptr;
}
}
}
return model;
}

@ -141,7 +141,7 @@ STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std
return RET_OK;
}
STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph,
STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph,
TensorCache *tensor_cache) {
for (const auto &input_value : onnx_graph.input()) {
auto ret = tensor_cache->FindTensor(input_value.name());
@ -152,13 +152,13 @@ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph,
return status;
}
MS_LOG(DEBUG) << "input_value name: " << input_value.name() << ", graph input index: " << index;
graph->inputIndex.emplace_back(static_cast<uint32_t>(index));
graph->inputIndices.emplace_back(static_cast<uint32_t>(index));
}
}
return RET_OK;
}
STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph,
STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph,
TensorCache *tensor_cache) {
for (const auto &output_value : onnx_graph.output()) {
int index;
@ -170,15 +170,15 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
return status;
}
}
graph->outputIndex.emplace_back(index);
graph->outputIndices.emplace_back(index);
MS_LOG(DEBUG) << "output_value name: " << output_value.name() << ", graph output index: " << index;
}
return RET_OK;
}
void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::MetaGraphT *graph, TensorCache *tensor_cache,
const QuantType &quant_type) {
schema::SubGraphT *sub_graph, schema::MetaGraphT *graph,
TensorCache *tensor_cache, const QuantType &quant_type) {
std::unique_ptr<schema::CNodeT> dst_op_1 = std::make_unique<schema::CNodeT>();
dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0);
dst_op_1->quantType = quant_type;
@ -189,6 +189,7 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, cons
SetOpInputIndex(matmul_inputs, dst_op_1.get(), onnx_node, tensor_cache);
SetOpOutputIndex(matmul_outputs, dst_op_1.get(), tensor_cache);
graph->nodes.emplace_back(std::move(dst_op_1));
sub_graph->nodeIndices.push_back(graph->nodes.size() - 1);
std::unique_ptr<schema::CNodeT> dst_op_2 = std::make_unique<schema::CNodeT>();
dst_op_2->name = "Gemm_BiasAdd_" + onnx_node.output(0);
@ -199,6 +200,7 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, cons
SetOpInputIndex(biasadd_inputs, dst_op_2.get(), onnx_node, tensor_cache);
SetOpOutputIndex(biasadd_outputs, dst_op_2.get(), tensor_cache);
graph->nodes.emplace_back(std::move(dst_op_2));
sub_graph->nodeIndices.push_back(graph->nodes.size() - 1);
}
STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) {
@ -511,16 +513,20 @@ STATUS OnnxModelParser::ParseSubgraph(schema::CNodeT *dst_op, const onnx::NodePr
return RET_NULL_PTR;
}
attr->subGraphIndex = subGraphNum;
auto sub_graph = std::make_unique<schema::MetaGraphT>();
sub_graph.reset(ParseGraph(onnx_node.attribute().at(0).g(), quantType));
auto sub_graph = std::make_unique<schema::SubGraphT>();
int ret = ParseGraph(dst_graph, sub_graph.get(), onnx_node.attribute().at(0).g(), quantType);
dst_graph->subGraph.push_back(std::move(sub_graph));
subGraphNum += 1;
if (ret != RET_OK) {
return ret;
}
dst_op->primitive->value.type = schema::PrimitiveType_Loop;
dst_op->primitive->value.value = attr.release();
return RET_OK;
}
schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_graph, const QuantType &quantType) {
int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph,
const onnx::GraphProto &onnx_graph, const QuantType &quantType) {
TensorCache tensor_cache;
// dst_graph->name = onnx_graph.name(); // this is not used
// find out input names and const names
@ -530,15 +536,16 @@ schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_gra
if (status != RET_OK) {
MS_LOG(ERROR) << "SetGraphConstTensor failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
return RET_ERROR;
}
auto dst_graph = std::make_unique<schema::MetaGraphT>();
// init onnx model graph input tensor
status = SetGraphInputTensor(onnx_graph, dst_graph.get(), &tensor_cache);
status = SetGraphInputTensor(onnx_graph, dst_sub_graph, &tensor_cache);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetGraphInputTensor failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
return RET_ERROR;
}
// init op node input/output tensor, and dst_op attr
@ -550,7 +557,7 @@ schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_gra
}
if (onnx_node.op_type() == "Gemm") {
if (status == RET_OK) {
ParseOnnxGemmNode(onnx_graph, onnx_node, dst_graph.get(), &tensor_cache, quantType);
ParseOnnxGemmNode(onnx_graph, onnx_node, dst_sub_graph, dst_graph, &tensor_cache, quantType);
}
continue;
} else if (onnx_node.op_type() == "Int8GivenIntTensorFill" || onnx_node.op_type() == "Int8GivenTensorFill") {
@ -566,30 +573,31 @@ schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_gra
std::unique_ptr<schema::CNodeT> dst_op = std::make_unique<schema::CNodeT>();
std::unique_ptr<schema::TensorT> dst_tensor = std::make_unique<schema::TensorT>();
status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType,
dst_graph.get());
status_node =
ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType, dst_graph);
if (status_node != RET_OK) {
status = (status == RET_OK ? status_node : status);
continue;
}
dst_graph->nodes.emplace_back(std::move(dst_op));
dst_sub_graph->nodeIndices.push_back((dst_graph->nodes.size() - 1));
}
if (status != RET_OK) {
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
for (auto &tensor : tensor_cache.GetCachedTensor()) {
delete tensor;
}
return nullptr;
return RET_ERROR;
}
// init onnx model graph output tensor
status = SetGraphOutputTensor(onnx_graph, dst_graph.get(), &tensor_cache);
status = SetGraphOutputTensor(onnx_graph, dst_sub_graph, &tensor_cache);
if (status != RET_OK) {
MS_LOG(ERROR) << "SetGraphOutputTensor failed";
ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
return nullptr;
return RET_ERROR;
}
SetAllTensors(tensor_cache, dst_graph.get());
return dst_graph.release();
SetAllTensors(tensor_cache, dst_graph);
return RET_OK;
}
schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
@ -612,12 +620,29 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
const onnx::GraphProto &onnx_graph = onnx_model.graph();
MS_LOG(INFO) << "model producer name: " << onnx_model.producer_name() << ", graph name: " << onnx_graph.name();
schema::MetaGraphT *dst_graph = ParseGraph(onnx_graph, quantType);
if (dst_graph == nullptr) {
auto dst_graph = std::make_unique<schema::MetaGraphT>();
auto dst_sub_graph = std::make_unique<schema::SubGraphT>();
int ret = ParseGraph(dst_graph.get(), dst_sub_graph.get(), onnx_graph, quantType);
dst_graph->subGraph.push_back(std::move(dst_sub_graph));
subGraphNum += 1;
if (ret == RET_ERROR) {
return nullptr;
}
dst_graph->name = GetModelName(modelFile);
return dst_graph;
std::vector<uint32_t> input_temp_index;
for (size_t i = 0; i < dst_graph->subGraph.front()->inputIndices.size(); i++) {
input_temp_index.push_back(dst_graph->subGraph.front()->inputIndices[i]);
}
dst_graph->inputIndex = input_temp_index;
std::vector<uint32_t> output_temp_index;
for (size_t i = 0; i < dst_graph->subGraph.front()->outputIndices.size(); i++) {
output_temp_index.push_back(dst_graph->subGraph.front()->outputIndices[i]);
}
dst_graph->outputIndex = output_temp_index;
return dst_graph.release();
}
} // namespace lite

@ -41,7 +41,10 @@ class OnnxModelParser : public ModelParser {
virtual ~OnnxModelParser();
schema::MetaGraphT *ParseGraph(const onnx::GraphProto &graph, const QuantType &quantType = QuantType_QUANT_NONE);
// schema::MetaGraphT *ParseGraph(const onnx::GraphProto &graph, const QuantType &quantType = QuantType_QUANT_NONE);
int ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph, const onnx::GraphProto &onnx_graph,
const QuantType &quantType);
schema::MetaGraphT *ParseToFb(const std::string &modelFile, const std::string &weightFile,
const QuantType &quantType = QuantType_QUANT_NONE) override;
@ -52,9 +55,9 @@ class OnnxModelParser : public ModelParser {
STATUS SetGraphConstTensor(const onnx::GraphProto &onnx_graph, TensorCache *tensor_cache);
STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache);
STATUS SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph, TensorCache *tensor_cache);
STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache);
STATUS SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph, TensorCache *tensor_cache);
STATUS AddValueInfo(const onnx::ValueInfoProto &proto, const std::string &name, const Category &type,
TensorCache *tensor_cache, int *index);
@ -67,7 +70,8 @@ class OnnxModelParser : public ModelParser {
const QuantType &quantType, schema::MetaGraphT *dst_graph);
void ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
schema::MetaGraphT *graph, TensorCache *tensor_cache, const QuantType &quant_type);
schema::SubGraphT *sub_graph, schema::MetaGraphT *graph, TensorCache *tensor_cache,
const QuantType &quant_type);
STATUS ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache);

Loading…
Cancel
Save