@@ -141,7 +141,7 @@ STATUS OnnxModelParser::AddTensorProto(const onnx::TensorProto &proto, const std
   return RET_OK;
 }
 
-STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph,
+STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph,
                                             TensorCache *tensor_cache) {
   for (const auto &input_value : onnx_graph.input()) {
     auto ret = tensor_cache->FindTensor(input_value.name());
@@ -152,13 +152,13 @@ STATUS OnnxModelParser::SetGraphInputTensor(const onnx::GraphProto &onnx_graph,
         return status;
       }
       MS_LOG(DEBUG) << "input_value name: " << input_value.name() << ", graph input index: " << index;
-      graph->inputIndex.emplace_back(static_cast<uint32_t>(index));
+      graph->inputIndices.emplace_back(static_cast<uint32_t>(index));
     }
   }
   return RET_OK;
 }
 
-STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::MetaGraphT *graph,
+STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph, schema::SubGraphT *graph,
                                              TensorCache *tensor_cache) {
   for (const auto &output_value : onnx_graph.output()) {
     int index;
@@ -170,15 +170,15 @@ STATUS OnnxModelParser::SetGraphOutputTensor(const onnx::GraphProto &onnx_graph,
        return status;
      }
    }
-    graph->outputIndex.emplace_back(index);
+    graph->outputIndices.emplace_back(index);
    MS_LOG(DEBUG) << "output_value name: " << output_value.name() << ", graph output index: " << index;
  }
  return RET_OK;
}
 
 void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, const onnx::NodeProto &onnx_node,
-                                        schema::MetaGraphT *graph, TensorCache *tensor_cache,
-                                        const QuantType &quant_type) {
+                                        schema::SubGraphT *sub_graph, schema::MetaGraphT *graph,
+                                        TensorCache *tensor_cache, const QuantType &quant_type) {
   std::unique_ptr<schema::CNodeT> dst_op_1 = std::make_unique<schema::CNodeT>();
   dst_op_1->name = "Gemm_MatMul_" + onnx_node.output(0);
   dst_op_1->quantType = quant_type;
@@ -189,6 +189,7 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, cons
   SetOpInputIndex(matmul_inputs, dst_op_1.get(), onnx_node, tensor_cache);
   SetOpOutputIndex(matmul_outputs, dst_op_1.get(), tensor_cache);
   graph->nodes.emplace_back(std::move(dst_op_1));
+  sub_graph->nodeIndices.push_back(graph->nodes.size() - 1);
 
   std::unique_ptr<schema::CNodeT> dst_op_2 = std::make_unique<schema::CNodeT>();
   dst_op_2->name = "Gemm_BiasAdd_" + onnx_node.output(0);
@@ -199,6 +200,7 @@ void OnnxModelParser::ParseOnnxGemmNode(const onnx::GraphProto &onnx_graph, cons
   SetOpInputIndex(biasadd_inputs, dst_op_2.get(), onnx_node, tensor_cache);
   SetOpOutputIndex(biasadd_outputs, dst_op_2.get(), tensor_cache);
   graph->nodes.emplace_back(std::move(dst_op_2));
+  sub_graph->nodeIndices.push_back(graph->nodes.size() - 1);
 }
 
 STATUS OnnxModelParser::ParseOnnxGivenFillNode(const onnx::NodeProto &onnx_node, TensorCache *tensor_cache) {
@@ -511,16 +513,20 @@ STATUS OnnxModelParser::ParseSubgraph(schema::CNodeT *dst_op, const onnx::NodePr
     return RET_NULL_PTR;
   }
   attr->subGraphIndex = subGraphNum;
-  auto sub_graph = std::make_unique<schema::MetaGraphT>();
-  sub_graph.reset(ParseGraph(onnx_node.attribute().at(0).g(), quantType));
+  auto sub_graph = std::make_unique<schema::SubGraphT>();
+  int ret = ParseGraph(dst_graph, sub_graph.get(), onnx_node.attribute().at(0).g(), quantType);
   dst_graph->subGraph.push_back(std::move(sub_graph));
   subGraphNum += 1;
+  if (ret != RET_OK) {
+    return ret;
+  }
   dst_op->primitive->value.type = schema::PrimitiveType_Loop;
   dst_op->primitive->value.value = attr.release();
   return RET_OK;
 }
 
-schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_graph, const QuantType &quantType) {
+int OnnxModelParser::ParseGraph(schema::MetaGraphT *dst_graph, schema::SubGraphT *dst_sub_graph,
+                                const onnx::GraphProto &onnx_graph, const QuantType &quantType) {
   TensorCache tensor_cache;
   // dst_graph->name = onnx_graph.name(); // this is not used
   // find out input names and const names
@@ -530,15 +536,16 @@ schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_gra
   if (status != RET_OK) {
     MS_LOG(ERROR) << "SetGraphConstTensor failed";
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
-    return nullptr;
+    return RET_ERROR;
   }
-  auto dst_graph = std::make_unique<schema::MetaGraphT>();
+
   // init onnx model graph input tensor
-  status = SetGraphInputTensor(onnx_graph, dst_graph.get(), &tensor_cache);
+
+  status = SetGraphInputTensor(onnx_graph, dst_sub_graph, &tensor_cache);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "SetGraphInputTensor failed";
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
-    return nullptr;
+    return RET_ERROR;
   }
 
   // init op node input/output tensor, and dst_op attr
@@ -550,7 +557,7 @@ schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_gra
     }
     if (onnx_node.op_type() == "Gemm") {
       if (status == RET_OK) {
-        ParseOnnxGemmNode(onnx_graph, onnx_node, dst_graph.get(), &tensor_cache, quantType);
+        ParseOnnxGemmNode(onnx_graph, onnx_node, dst_sub_graph, dst_graph, &tensor_cache, quantType);
       }
       continue;
     } else if (onnx_node.op_type() == "Int8GivenIntTensorFill" || onnx_node.op_type() == "Int8GivenTensorFill") {
@@ -566,30 +573,31 @@ schema::MetaGraphT *OnnxModelParser::ParseGraph(const onnx::GraphProto &onnx_gra
 
     std::unique_ptr<schema::CNodeT> dst_op = std::make_unique<schema::CNodeT>();
     std::unique_ptr<schema::TensorT> dst_tensor = std::make_unique<schema::TensorT>();
-    status_node = ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType,
-                                       dst_graph.get());
+    status_node =
+      ParseOnnxNodeToDstOp(onnx_graph, onnx_node, dst_op.get(), dst_tensor.get(), &tensor_cache, quantType, dst_graph);
     if (status_node != RET_OK) {
       status = (status == RET_OK ? status_node : status);
       continue;
     }
     dst_graph->nodes.emplace_back(std::move(dst_op));
+    dst_sub_graph->nodeIndices.push_back((dst_graph->nodes.size() - 1));
   }
   if (status != RET_OK) {
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
     for (auto &tensor : tensor_cache.GetCachedTensor()) {
       delete tensor;
     }
-    return nullptr;
+    return RET_ERROR;
   }
   // init onnx model graph output tensor
-  status = SetGraphOutputTensor(onnx_graph, dst_graph.get(), &tensor_cache);
+  status = SetGraphOutputTensor(onnx_graph, dst_sub_graph, &tensor_cache);
   if (status != RET_OK) {
     MS_LOG(ERROR) << "SetGraphOutputTensor failed";
     ReturnCode::GetSingleReturnCode()->UpdateReturnCode(status);
-    return nullptr;
+    return RET_ERROR;
   }
-  SetAllTensors(tensor_cache, dst_graph.get());
-  return dst_graph.release();
+  SetAllTensors(tensor_cache, dst_graph);
+  return RET_OK;
 }
 
 schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, const std::string &weightFile,
@@ -612,12 +620,29 @@ schema::MetaGraphT *OnnxModelParser::ParseToFb(const std::string &modelFile, con
   const onnx::GraphProto &onnx_graph = onnx_model.graph();
   MS_LOG(INFO) << "model producer name: " << onnx_model.producer_name() << ", graph name: " << onnx_graph.name();
 
-  schema::MetaGraphT *dst_graph = ParseGraph(onnx_graph, quantType);
-  if (dst_graph == nullptr) {
+  auto dst_graph = std::make_unique<schema::MetaGraphT>();
+  auto dst_sub_graph = std::make_unique<schema::SubGraphT>();
+  int ret = ParseGraph(dst_graph.get(), dst_sub_graph.get(), onnx_graph, quantType);
+  dst_graph->subGraph.push_back(std::move(dst_sub_graph));
+  subGraphNum += 1;
+  if (ret == RET_ERROR) {
     return nullptr;
   }
   dst_graph->name = GetModelName(modelFile);
-  return dst_graph;
+
+  std::vector<uint32_t> input_temp_index;
+  for (size_t i = 0; i < dst_graph->subGraph.front()->inputIndices.size(); i++) {
+    input_temp_index.push_back(dst_graph->subGraph.front()->inputIndices[i]);
+  }
+  dst_graph->inputIndex = input_temp_index;
+
+  std::vector<uint32_t> output_temp_index;
+  for (size_t i = 0; i < dst_graph->subGraph.front()->outputIndices.size(); i++) {
+    output_temp_index.push_back(dst_graph->subGraph.front()->outputIndices[i]);
+  }
+  dst_graph->outputIndex = output_temp_index;
+
+  return dst_graph.release();
 }
 
 } // namespace lite