You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
graphengine/metadef/graph/utils/ge_ir_utils.cc

1179 lines
50 KiB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "graph/utils/ge_ir_utils.h"
#include <utility>
#include "framework/common/debug/ge_log.h"
namespace {
// Suffix appended to a node name to reference its control anchor (data anchors use ":<index>" instead).
const char *const kControlAnchorIndex = ":-1";
// op_type given to the placeholder NodeProto that stands for a subgraph in the exported ONNX model.
const char *const kNodeTypeForSubgraph = "subgraph";
// Prefixes for attributes copied from an input/output tensor descriptor's own attr map.
const char *const kPrefixForInputDesc = "input_desc_attr_";
const char *const kPrefixForOutputDesc = "output_desc_attr_";
// Name of the environment variable that selects the dump level.
const char *const kDumpGEGraph = "DUMP_GE_GRAPH";
// NOTE(review): not referenced in the visible part of this file; presumably bounds a recursion further down — confirm.
const int8_t kMaxRecursionDepth = 10;
// The next two constants are dynamically initialized once, in declaration order, at static-init time.
// kDumpLevel reads kDumpGeGraph, so these two lines must keep their relative order.
const char *const kDumpGeGraph = std::getenv(kDumpGEGraph);
// Dump level parsed base-10 from DUMP_GE_GRAPH; NO_DUMP when the variable is unset.
// A non-numeric value parses to 0 (std::strtol semantics).
const int64_t kDumpLevel = (kDumpGeGraph != nullptr) ? std::strtol(kDumpGeGraph, nullptr, 10) : ge::OnnxUtils::NO_DUMP;
// Lengths of the words "input" (5) and "output" (6); presumably used to strip those prefixes later in the file.
const int64_t kInputPrefixLength = 5;
const int64_t kOutputPrefixLength = 6;
// Shorthand for one entry of a protobuf map<string, AttrDef>.
using AttrDefPair = ::google::protobuf::MapPair<std::string, ge::proto::AttrDef>;
} // namespace
namespace ge {
// Part 1: from IR convert to ONNX Protobuf
// GE data type -> ONNX tensor element type. A GE type missing from this table has no ONNX encoding.
static const std::map<ge::DataType, onnx::TensorProto_DataType> kGeDataTypeToOnnxMap = {
  {DT_INT64, onnx::TensorProto_DataType_INT64},
  {DT_UINT64, onnx::TensorProto_DataType_UINT64},
  {DT_FLOAT, onnx::TensorProto_DataType_FLOAT},
  {DT_INT32, onnx::TensorProto_DataType_INT32},
  {DT_UINT32, onnx::TensorProto_DataType_UINT32},
  {DT_INT8, onnx::TensorProto_DataType_INT8},
  {DT_UINT8, onnx::TensorProto_DataType_UINT8},
  {DT_INT16, onnx::TensorProto_DataType_INT16},
  {DT_UINT16, onnx::TensorProto_DataType_UINT16},
  {DT_FLOAT16, onnx::TensorProto_DataType_FLOAT16},
  {DT_DOUBLE, onnx::TensorProto_DataType_DOUBLE},
  {DT_BOOL, onnx::TensorProto_DataType_BOOL},
};
/// Translate a GE data type into its ONNX counterpart.
/// @param data_type GE data type to translate
/// @return the mapped ONNX type, or TensorProto_DataType_UNDEFINED (after a warning) when unsupported
onnx::TensorProto_DataType OnnxUtils::EncodeDataType(DataType data_type) {
  const auto found = kGeDataTypeToOnnxMap.find(data_type);
  if (found == kGeDataTypeToOnnxMap.end()) {
    GELOGW("EncodeDataType: datatype not support %u", data_type);
    return onnx::TensorProto_DataType_UNDEFINED;
  }
  return found->second;
}
/// Append one ONNX AttributeProto to node_proto, built from a (name, GeAttrValue) pair.
/// Only float/int/string scalars and their list forms are encoded. For any other value type the
/// attribute is still appended, carrying only its name, and a warning is logged.
/// @param string_attr_value attribute name and value (e.g. one entry of a node's attribute map)
/// @param node_proto        destination node; nothing happens when it is nullptr
void OnnxUtils::AddAttrProtoFromAttribute(const std::pair<const std::string, ge::GeAttrValue> &string_attr_value,
                                          onnx::NodeProto *node_proto) {
  if (node_proto == nullptr) {
    GELOGE(FAILED, "Node proto is nullptr.");
    return;
  }
  auto attr = node_proto->add_attribute();
  if (attr == nullptr) {
    GELOGE(GRAPH_FAILED, "attr is nullptr.");
    return;
  }
  attr->set_name(string_attr_value.first);
  // Only read access is needed, so bind by reference instead of copying the GeAttrValue,
  // which can hold whole lists (and possibly larger payloads).
  const auto &attr_value = string_attr_value.second;
  const auto value_type = attr_value.GetValueType();
  switch (value_type) {
    case GeAttrValue::VT_FLOAT: {
      GeAttrValue::FLOAT data_f = 0;
      (void)attr_value.GetValue(data_f);
      attr->set_f(data_f);
      attr->set_type(onnx::AttributeProto_AttributeType_FLOAT);
      break;
    }
    case GeAttrValue::VT_LIST_FLOAT: {
      GeAttrValue::LIST_FLOAT data_fs = {};
      (void)attr_value.GetValue(data_fs);
      attr->set_type(onnx::AttributeProto_AttributeType_FLOATS);
      for (const auto v : data_fs) {
        attr->add_floats(v);
      }
      break;
    }
    case GeAttrValue::VT_INT: {
      GeAttrValue::INT data_i = 0;
      (void)attr_value.GetValue(data_i);
      attr->set_type(onnx::AttributeProto_AttributeType_INT);
      attr->set_i(data_i);
      break;
    }
    case GeAttrValue::VT_LIST_INT: {
      GeAttrValue::LIST_INT data_is = {};
      (void)attr_value.GetValue(data_is);
      attr->set_type(onnx::AttributeProto_AttributeType_INTS);
      for (const auto v : data_is) {
        attr->add_ints(v);
      }
      break;
    }
    case GeAttrValue::VT_STRING: {
      GeAttrValue::STR data_s;
      (void)attr_value.GetValue(data_s);
      attr->set_type(onnx::AttributeProto_AttributeType_STRING);
      attr->set_s(data_s);
      break;
    }
    case GeAttrValue::VT_LIST_STRING: {
      GeAttrValue::LIST_STR data_ss = {};
      (void)attr_value.GetValue(data_ss);
      attr->set_type(onnx::AttributeProto_AttributeType_STRINGS);
      for (const auto &v : data_ss) {
        attr->add_strings(v);
      }
      break;
    }
    default:
      GELOGW("GeAttrValue ValueType: %u is not supported for now", value_type);
      break;
  }
}
/// Append one ONNX attribute named `name` to node_proto, reading its value through an untyped pointer.
/// The object behind `data` must match `type`:
///   FLOAT -> float, FLOATS -> std::vector<float>, INT -> int64_t, INTS -> std::vector<int64_t>,
///   STRING -> std::string, STRINGS -> std::vector<std::string>.
/// Any other `type` leaves an attribute that has only a name, and logs a warning.
/// Nothing is appended when node_proto or data is nullptr.
void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
                             void *data) {
  if (node_proto == nullptr) {
    GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str());
    return;
  }
  // Every supported branch dereferences data; reject a null payload before an attribute is appended.
  if (data == nullptr) {
    GELOGE(FAILED, "Data of attr %s is nullptr.", name.c_str());
    return;
  }
  auto attr = node_proto->add_attribute();
  if (attr == nullptr) {
    GELOGE(GRAPH_FAILED, "attr is nullptr.");
    return;
  }
  attr->set_name(name);
  switch (type) {
    case onnx::AttributeProto_AttributeType_FLOAT:
      attr->set_f(*static_cast<const float *>(data));
      attr->set_type(onnx::AttributeProto_AttributeType_FLOAT);
      break;
    case onnx::AttributeProto_AttributeType_FLOATS:
      attr->set_type(onnx::AttributeProto_AttributeType_FLOATS);
      for (const auto v : *static_cast<const std::vector<float> *>(data)) {
        attr->add_floats(v);
      }
      break;
    case onnx::AttributeProto_AttributeType_INT:
      attr->set_type(onnx::AttributeProto_AttributeType_INT);
      attr->set_i(*static_cast<const int64_t *>(data));
      break;
    case onnx::AttributeProto_AttributeType_INTS:
      attr->set_type(onnx::AttributeProto_AttributeType_INTS);
      for (const auto v : *static_cast<const std::vector<int64_t> *>(data)) {
        attr->add_ints(v);
      }
      break;
    case onnx::AttributeProto_AttributeType_STRING:
      attr->set_type(onnx::AttributeProto_AttributeType_STRING);
      attr->set_s(*static_cast<const std::string *>(data));
      break;
    case onnx::AttributeProto_AttributeType_STRINGS:
      attr->set_type(onnx::AttributeProto_AttributeType_STRINGS);
      for (const auto &v : *static_cast<const std::vector<std::string> *>(data)) {
        attr->add_strings(v);
      }
      break;
    default:
      GELOGW("AttributeProto AttributeType: %u is not supported for now", type);
      break;
  }
}
/// Append a list attribute named `name` whose values come from an int64 repeated field.
/// An empty field appends nothing. Values are stored in the `ints` slot; `type` is recorded unchanged.
void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
                             ::google::protobuf::RepeatedField<::google::protobuf::int64> data) {
  if (node_proto == nullptr) {
    GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str());
    return;
  }
  if (data.empty()) {
    return;
  }
  auto *const ints_attr = node_proto->add_attribute();
  if (ints_attr == nullptr) {
    GELOGE(GRAPH_FAILED, "attr is nullptr.");
    return;
  }
  ints_attr->set_name(name);
  for (const ::google::protobuf::int64 value : data) {
    ints_attr->add_ints(value);
  }
  ints_attr->set_type(type);
}
/// Append a list attribute named `name` from a bool repeated field, storing each flag as 0/1 in `ints`.
/// An empty field appends nothing; `type` is recorded unchanged.
void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
                             ::google::protobuf::RepeatedField<bool> data) {
  if (node_proto == nullptr) {
    GELOGE(FAILED, "Node proto %s is nullptr.", name.c_str());
    return;
  }
  if (data.empty()) {
    return;
  }
  auto *const flags_attr = node_proto->add_attribute();
  if (flags_attr == nullptr) {
    GELOGE(GRAPH_FAILED, "attr is nullptr.");
    return;
  }
  flags_attr->set_name(name);
  for (const bool flag : data) {
    flags_attr->add_ints(flag ? int64_t{1} : int64_t{0});
  }
  flags_attr->set_type(type);
}
/// Append a list attribute named `name` whose values come from a float repeated field.
/// An empty field appends nothing. Values go into the `floats` slot; `type` is recorded unchanged.
void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
                             ::google::protobuf::RepeatedField<float> data) {
  if (node_proto == nullptr) {
    GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str());
    return;
  }
  if (data.empty()) {
    return;
  }
  auto *const floats_attr = node_proto->add_attribute();
  if (floats_attr == nullptr) {
    GELOGE(GRAPH_FAILED, "attr is nullptr.");
    return;
  }
  floats_attr->set_name(name);
  for (const float value : data) {
    floats_attr->add_floats(value);
  }
  floats_attr->set_type(type);
}
/// Append a list attribute named `name` whose values come from a string repeated field.
/// An empty field appends nothing. Values go into the `strings` slot; `type` is recorded unchanged.
void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
                             ::google::protobuf::RepeatedPtrField<::std::string> data) {
  if (node_proto == nullptr) {
    GELOGE(FAILED, "Node proto %s is nullptr.", name.c_str());
    return;
  }
  if (data.empty()) {
    return;
  }
  auto *const strings_attr = node_proto->add_attribute();
  if (strings_attr == nullptr) {
    GELOGE(GRAPH_FAILED, "attr is nullptr.");
    return;
  }
  strings_attr->set_name(name);
  for (const auto &text : data) {
    strings_attr->add_strings(text);
  }
  strings_attr->set_type(type);
}
/// Flatten an op's input and output tensor descriptions into ONNX attributes on node_proto.
/// First emits "input_desc_nums"/"output_desc_nums". Then, for each tensor index i, emits attributes named
/// "<input|output>_desc_<field>:<i>": dtype, shape and layout with their origin variants, plus fields read
/// from the tensor's proto descriptor. These are followed by the descriptor's own attr map, prefixed with
/// kPrefixForInputDesc/kPrefixForOutputDesc and suffixed with ":<i>".
/// A null tensor description or null proto descriptor is skipped with a warning.
void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc) {
  if (node_proto == nullptr || op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "node_proto or op_desc is nullptr");
    return;
  }
  // Input describes.
  // The void* INT overload of AddAttrProto reads its argument through an int64_t*. The count is therefore
  // widened into a real int64_t rather than passing the address of the size getter's unsigned result,
  // which may be narrower than int64_t.
  const auto size_in = op_desc->GetAllInputsSize();
  int64_t input_desc_nums = static_cast<int64_t>(size_in);
  AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_nums", &input_desc_nums);
  for (uint32_t i = 0; i < size_in; i++) {
    auto input_desc = op_desc->GetInputDescPtrDfault(i);
    if (input_desc == nullptr) {
      GELOGW("Input desc is nullptr");
      continue;
    }
    const std::string suffix = ":" + std::to_string(i);
    auto data_type = TypeUtils::DataTypeToSerialString(input_desc->GetDataType());
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "input_desc_dtype" + suffix, &data_type);
    auto data_type_origin = TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType());
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "input_desc_origin_dtype" + suffix,
                 &data_type_origin);
    auto dims = input_desc->GetShape().GetDims();
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "input_desc_shape" + suffix, &dims);
    auto dims_origin = input_desc->GetOriginShape().GetDims();
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "input_desc_origin_shape" + suffix,
                 &dims_origin);
    auto layout = TypeUtils::FormatToSerialString(input_desc->GetFormat());
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "input_desc_layout" + suffix, &layout);
    auto layout_origin = TypeUtils::FormatToSerialString(input_desc->GetOriginFormat());
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "input_desc_origin_layout" + suffix,
                 &layout_origin);
    auto tensor_descriptor = input_desc->tensor_descriptor_.GetProtoMsg();
    if (tensor_descriptor == nullptr) {
      GELOGW("Tensor descriptor is nullptr");
      continue;
    }
    auto size = tensor_descriptor->size();
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_size" + suffix, &size);
    auto weight_size = tensor_descriptor->weight_size();
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_weight_size" + suffix,
                 &weight_size);
    // Boolean descriptor fields are exported as 0/1 INT attributes.
    auto reuse_input_int = static_cast<int64_t>(tensor_descriptor->reuse_input());
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_reuse_input" + suffix,
                 &reuse_input_int);
    auto output_tensor_int = static_cast<int64_t>(tensor_descriptor->output_tensor());
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_output_tensor" + suffix,
                 &output_tensor_int);
    auto device_type = tensor_descriptor->device_type();
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "input_desc_device_type" + suffix,
                 &device_type);
    auto input_tensor_int = static_cast<int64_t>(tensor_descriptor->input_tensor());
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_input_tensor" + suffix,
                 &input_tensor_int);
    auto real_dim_cnt = tensor_descriptor->real_dim_cnt();
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_real_dim_cnt" + suffix,
                 &real_dim_cnt);
    auto data_offset = tensor_descriptor->data_offset();
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_data_offset" + suffix,
                 &data_offset);
    auto cmps_size = tensor_descriptor->cmps_size();
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_cmps_size" + suffix, &cmps_size);
    auto cmps_tab = tensor_descriptor->cmps_tab();
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "input_desc_cmps_tab" + suffix, &cmps_tab);
    auto cmps_tab_offset = tensor_descriptor->cmps_tab_offset();
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "input_desc_cmps_tab_offset" + suffix,
                 &cmps_tab_offset);
    const auto &tensor_desc_map = tensor_descriptor->attr();
    AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForInputDesc, suffix);
  }
  // Output describes (count widened to int64_t for the same reason as above).
  const auto size_out = op_desc->GetOutputsSize();
  int64_t output_desc_nums = static_cast<int64_t>(size_out);
  AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "output_desc_nums", &output_desc_nums);
  for (uint32_t i = 0; i < size_out; i++) {
    auto output_desc = op_desc->GetOutputDescPtr(i);
    if (output_desc == nullptr) {
      GELOGW("Output desc is nullptr");
      continue;
    }
    const std::string suffix = ":" + std::to_string(i);
    auto data_type = TypeUtils::DataTypeToSerialString(output_desc->GetDataType());
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "output_desc_dtype" + suffix, &data_type);
    auto origin_data_type = TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType());
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "output_desc_origin_dtype" + suffix,
                 &origin_data_type);
    auto dims = output_desc->GetShape().GetDims();
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "output_desc_shape" + suffix, &dims);
    auto dims_origin = output_desc->GetOriginShape().GetDims();
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "output_desc_origin_shape" + suffix,
                 &dims_origin);
    auto layout = TypeUtils::FormatToSerialString(output_desc->GetFormat());
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "output_desc_layout" + suffix, &layout);
    auto layout_origin = TypeUtils::FormatToSerialString(output_desc->GetOriginFormat());
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "output_desc_origin_layout" + suffix,
                 &layout_origin);
    auto tensor_descriptor = output_desc->tensor_descriptor_.GetProtoMsg();
    if (tensor_descriptor == nullptr) {
      GELOGW("Tensor descriptor is nullptr");
      continue;
    }
    auto size = tensor_descriptor->size();
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "output_desc_size" + suffix, &size);
    auto weight_size = tensor_descriptor->weight_size();
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "output_desc_weight_size" + suffix,
                 &weight_size);
    auto device_type = tensor_descriptor->device_type();
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "output_desc_device_type" + suffix,
                 &device_type);
    auto real_dim_cnt = tensor_descriptor->real_dim_cnt();
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "output_desc_real_dim_cnt" + suffix,
                 &real_dim_cnt);
    const auto &tensor_desc_map = tensor_descriptor->attr();
    AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForOutputDesc, suffix);
  }
}
/// Copy entries of a GE proto attribute map onto node_proto as ONNX attributes named prefix + key + suffix.
/// - Always exported: int, float, bool (as 0/1 int), and int/float/bool lists.
/// - Exported only when kDumpLevel == DUMP_ALL: strings, string lists, and raw tensor data.
/// - Tensor attributes become separate "_desc_dtype", "_desc_shape", "_desc_layout" and
///   "_desc_device_type" attributes, plus "_data" at DUMP_ALL.
/// Other value kinds are ignored.
void OnnxUtils::AddAttrProtoForAttrsFromAttrMap(
    const ::google::protobuf::Map<std::string, ::ge::proto::AttrDef> &attr_map, onnx::NodeProto *node_proto,
    const std::string &prefix, const std::string &suffix) {
  for (const auto &item : attr_map) {
    const std::string &attr_name = item.first;
    // Bind by reference: copying an AttrDef deep-copies everything it holds, including tensor data.
    const auto &attr_def = item.second;
    switch (attr_def.value_case()) {
      case ge::proto::AttrDef::kT: {
        const auto &tensor_def = attr_def.t();
        const auto &tensor_desc = tensor_def.desc();
        // The void* overload of AddAttrProto needs mutable objects, so string values are copied into locals.
        auto data_type = ge::proto::DataType_Name(tensor_desc.dtype());
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
                     prefix + attr_name + "_desc_dtype" + suffix, &data_type);
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS,
                     prefix + attr_name + "_desc_shape" + suffix, tensor_desc.shape().dim());
        auto layout = tensor_desc.layout();
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
                     prefix + attr_name + "_desc_layout" + suffix, &layout);
        auto device_type = tensor_desc.device_type();
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
                     prefix + attr_name + "_desc_device_type" + suffix, &device_type);
        if (kDumpLevel == DUMP_ALL) {
          auto data = tensor_def.data();
          AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_data" + suffix,
                       &data);
        }
        break;
      }
      case ge::proto::AttrDef::kS: {
        if (kDumpLevel == DUMP_ALL) {
          auto str_value = attr_def.s();
          AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + suffix,
                       &str_value);
        }
        break;
      }
      case ge::proto::AttrDef::kI: {
        auto int_value = attr_def.i();
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value);
        break;
      }
      case ge::proto::AttrDef::kF: {
        auto float_value = attr_def.f();
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOAT, prefix + attr_name + suffix,
                     &float_value);
        break;
      }
      case ge::proto::AttrDef::kB: {
        auto int_value = static_cast<int64_t>(attr_def.b());
        AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value);
        break;
      }
      case ge::proto::AttrDef::kList: {
        const auto &list_value = attr_def.list();
        switch (list_value.val_type()) {
          case ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_STRING:
            if (kDumpLevel == DUMP_ALL) {
              AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, prefix + attr_name + suffix,
                           list_value.s());
            }
            break;
          case ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT:
            AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOATS, prefix + attr_name + suffix,
                         list_value.f());
            break;
          case ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_INT:
            AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix,
                         list_value.i());
            break;
          case ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_BOOL:
            AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix,
                         list_value.b());
            break;
          default:
            break;
        }
        break;
      }
      default:
        break;
    }
  }
}
void OnnxUtils::AddAttrProtoFromNodeMembers(const NodePtr &node, onnx::NodeProto *node_proto) {
if (node == nullptr) {
GELOGE(GRAPH_FAILED, "node is nullptr");
return;
}
// 1.Attributes added from node's methods
auto send_list = node->send_event_id_list_;
if (!send_list.empty()) {
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "send_event_id_list", &send_list);
}
auto recv_list = node->recv_event_id_list_;
if (!recv_list.empty()) {
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "recv_event_id_list", &recv_list);
}
auto op_desc = node->op_;
if (op_desc != nullptr) {
// for input_name_idx_ in opdesc
auto input_name_2_indexs = op_desc->GetAllInputName();
::google::protobuf::RepeatedPtrField<::std::string> input_names;
::google::protobuf::RepeatedField<::google::protobuf::int64> input_indexes;
for (const auto &input_name_2_index : input_name_2_indexs) {
std::string input_name = input_name_2_index.first;
input_names.Add(std::move(input_name));
input_indexes.Add(input_name_2_index.second);
}
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "_input_name_key", input_names);
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "_input_name_value", input_indexes);
// 2.Attributes added from node's op_(message OpDef)
// Input and out describes
AddAttrProtoForOpInAndOutDesc(node_proto, op_desc);
// Others
auto op_def = op_desc->op_def_.GetProtoMsg();
if (op_def != nullptr) {
auto id = op_def->id();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "id", &id);
auto stream_id = op_def->stream_id();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "stream_id", &stream_id);
const auto &input_name = op_def->input_name();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "input_name", input_name);
const auto &src_name = op_def->src_name();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "src_name", src_name);
const auto &src_index = op_def->src_index();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "src_index", src_index);
const auto &dst_name = op_def->dst_name();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "dst_name", dst_name);
const auto &dst_index = op_def->dst_index();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "dst_index", dst_index);
const auto &input_i = op_def->input_i();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "input_i", input_i);
const auto &output_i = op_def->output_i();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "output_i", output_i);
const auto &workspace = op_def->workspace();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace", workspace);
const auto &workspace_bytes = op_def->workspace_bytes();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace_bytes", workspace_bytes);
const auto &is_input_const = op_def->is_input_const();
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "is_input_const", is_input_const);
const auto &op_def_attr_map = op_def->attr();
AddAttrProtoForAttrsFromAttrMap(op_def_attr_map, node_proto);
} else {
GELOGE(FAILED, "Opdef is nullptr");
return;
}
} else {
GELOGE(FAILED, "Opdesc is nullptr");
return;
}
}
/// Encode a node's descriptive data into node_proto: every entry of its attribute map, then its members.
/// @return false only when node or node_proto is nullptr.
bool OnnxUtils::EncodeNodeDesc(const NodePtr &node, onnx::NodeProto *node_proto) {
  const bool bad_input = (node == nullptr) || (node_proto == nullptr);
  if (bad_input) {
    GELOGE(GRAPH_FAILED, "EncodeOpDesc: Input Para Node Invalid");
    return false;
  }
  // 2. Each (name, GeAttrValue) entry of attrs_ becomes one AttributeProto.
  for (const auto &node_attr : node->attrs_) {
    AddAttrProtoFromAttribute(node_attr, node_proto);
  }
  // 3. ge::Node members (event lists, OpDesc/OpDef data) become further AttributeProtos.
  AddAttrProtoFromNodeMembers(node, node_proto);
  return true;
}
/// Add node_proto outputs so Netron can draw edges. Each connected out-data anchor adds "<name>:<idx>";
/// a connected out-control anchor adds "<name>:-1". Anchors with no consumer are skipped.
void OnnxUtils::EncodeNodeLinkForNetronVisual(const NodePtr &node, onnx::NodeProto *node_proto) {
  if ((node == nullptr) || (node_proto == nullptr)) {
    GELOGE(GRAPH_FAILED, "EncodeNodeLinkForNetronVisual: Input Para Node Invalid");
    return;
  }
  const std::string &owner_name = node->GetName();
  for (const auto &data_anchor : node->GetAllOutDataAnchors()) {
    const bool has_consumer = (data_anchor != nullptr) && !data_anchor->GetPeerInDataAnchors().empty();
    if (!has_consumer) {
      continue;
    }
    node_proto->add_output(owner_name + ":" + std::to_string(data_anchor->GetIdx()));
  }
  const auto control_anchor = node->GetOutControlAnchor();
  if ((control_anchor == nullptr) || control_anchor->GetPeerInControlAnchors().empty()) {
    return;
  }
  node_proto->add_output(owner_name + kControlAnchorIndex);
}
bool OnnxUtils::EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto) {
if ((node == nullptr) || (node_proto == nullptr)) {
GELOGE(GRAPH_FAILED, "EncodeNodeLink: Input Para Node Invalid");
return false;
}
node_proto->clear_input();
// 1. Add input by in data edge
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
if ((peer_out_anchor != nullptr) && (peer_out_anchor->GetOwnerNode() != nullptr)) {
node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" +
std::to_string(peer_out_anchor->GetIdx()));
} else {
// Add "" input
node_proto->add_input("");
}
}
// 2. Add input by in control edge
auto in_control_anchor = node->GetInControlAnchor();
if (in_control_anchor != nullptr) {
auto peer_out_anchors = in_control_anchor->GetPeerOutControlAnchors();
for (const auto &peer_out_anchor : peer_out_anchors) {
if (peer_out_anchor->GetOwnerNode()) {
node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + kControlAnchorIndex);
}
}
} else {
GELOGE(FAILED, "Incontrol anchor is nullptr");
return false;
}
// 3. Add output for Netron visual support
EncodeNodeLinkForNetronVisual(node, node_proto);
return true;
}
/// Encode one GE node into node_proto: its name, its "ge:"-prefixed type, its descriptive attributes
/// (skipped when kDumpLevel == DUMP_WITH_OUT_DESC), and its edges.
/// @return false on invalid arguments or when the description or link encoding fails.
bool OnnxUtils::EncodeNode(const NodePtr &node, onnx::NodeProto *node_proto) {
  if ((node == nullptr) || (node_proto == nullptr)) {
    GELOGE(GRAPH_FAILED, "EncodeNode: Input Para Node Invalid");
    return false;
  }
  // 1. Name and type.
  node_proto->set_name(node->GetName());
  // Netron assumes some operators (e.g. the softplus activation) have a single input, but GE control edges
  // can add another. The "ge:" prefix makes Netron draw the links correctly, at the cost of its
  // per-operator coloring.
  node_proto->set_op_type("ge:" + node->GetType());
  // 2. Attributes, unless the dump level asks to omit descriptions.
  const bool with_desc = (kDumpLevel != DUMP_WITH_OUT_DESC);
  if (with_desc && !EncodeNodeDesc(node, node_proto)) {
    GELOGE(GRAPH_FAILED, "Encode NodeDesc: %s failed", node->GetName().c_str());
    return false;
  }
  // 3. Link info.
  return EncodeNodeLink(node, node_proto);
}
/// Fill an ONNX tensor type (element type and shape) from the output descriptions of a node's OpDesc.
/// Null output descriptions are skipped with a warning. A missing OpDesc only logs a warning.
void OnnxUtils::EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_Tensor *tensor_type) {
  if ((node == nullptr) || (tensor_type == nullptr)) {
    GELOGE(GRAPH_FAILED, "EncodeTypeProtoTensorType: Input Para Node or tensor_type Invalid");
    return;
  }
  const auto &op_desc = node->GetOpDesc();
  if (op_desc == nullptr) {
    GELOGW("OpDesc Is Empty, nodeName %s nodeType %s", node->GetName().c_str(), node->GetType().c_str());
    return;
  }
  // NOTE(review): every output overwrites elem_type and appends its dims to the same shape. For a node with
  // several outputs the result is the last output's dtype with all outputs' dims concatenated; confirm intent.
  const auto output_count = static_cast<uint32_t>(op_desc->GetOutputsSize());
  for (uint32_t idx = 0U; idx < output_count; ++idx) {
    const ConstGeTensorDescPtr &ge_tensor = op_desc->GetOutputDescPtr(idx);
    if (ge_tensor == nullptr) {
      GELOGW("Ge tensor is nullptr");
      continue;
    }
    tensor_type->set_elem_type(EncodeDataType(ge_tensor->GetDataType()));
    onnx::TensorShapeProto *shape = tensor_type->mutable_shape();
    if (shape == nullptr) {
      GELOGW("Shape is nullptr");
      continue;
    }
    for (const auto dim_value : ge_tensor->GetShape().GetDims()) {
      shape->add_dim()->set_dim_value(dim_value);
    }
  }
}
/// Describe a graph input/output node as an ONNX ValueInfoProto: its name plus a tensor type built from
/// its output descriptions.
void OnnxUtils::EncodeValueInfo(const NodePtr &node, onnx::ValueInfoProto *value_info_proto) {
  if ((node == nullptr) || (value_info_proto == nullptr)) {
    GELOGE(GRAPH_FAILED, "EncodeValueInfo: Input Para Node or value_info_proto Invalid");
    return;
  }
  value_info_proto->set_name(node->GetName());
  EncodeTypeProtoTensorType(node, value_info_proto->mutable_type()->mutable_tensor_type());
}
/// Encode a compute graph into graph_proto: its name, its input and output ValueInfos, and every direct node.
/// A node that fails to encode only triggers a warning; the partially filled NodeProto stays in graph_proto.
/// @return false only when graph or graph_proto is nullptr.
bool OnnxUtils::EncodeGraph(const ConstComputeGraphPtr &graph, onnx::GraphProto *graph_proto) {
  if ((graph == nullptr) || (graph_proto == nullptr)) {
    GELOGE(GRAPH_FAILED, "EncodeGraph: Input para Invalid");
    return false;
  }
  graph_proto->set_name(graph->GetName());
  // 1. Graph inputs.
  for (const auto &input_node : graph->GetInputNodes()) {
    EncodeValueInfo(input_node, graph_proto->add_input());
  }
  // 2. Graph outputs.
  for (const auto &output_node : graph->GetOutputNodes()) {
    EncodeValueInfo(output_node, graph_proto->add_output());
  }
  // 3. Nodes.
  for (const auto &node : graph->GetDirectNode()) {
    if (!EncodeNode(node, graph_proto->add_node())) {
      GELOGW("EncodeNode failed");
    }
  }
  return true;
}
/// Convert a GE model into an ONNX ModelProto. The root graph is encoded into model_proto.graph.
/// Each subgraph becomes an extra node of op_type kNodeTypeForSubgraph whose "graph" attribute holds the
/// encoded subgraph. Problems with individual subgraphs only produce warnings.
/// @return false when the compute graph cannot be obtained or the root graph fails to encode.
bool OnnxUtils::ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelProto &model_proto) {
  model_proto.set_model_version(model.GetVersion());
  model_proto.set_ir_version(onnx::IR_VERSION);
  model_proto.set_producer_name(model.GetName());
  const auto compute_graph = GraphUtils::GetComputeGraph(model.graph_);
  if (compute_graph == nullptr) {
    GELOGE(GRAPH_FAILED, "GetComputeGraph: return nullptr");
    return false;
  }
  auto *const graph_proto = model_proto.mutable_graph();
  if (graph_proto == nullptr) {
    GELOGE(GRAPH_FAILED, "mutable_graph: %s return nullptr", compute_graph->GetName().c_str());
    return false;
  }
  if (!EncodeGraph(compute_graph, graph_proto)) {
    GELOGE(GRAPH_FAILED, "EncodeGraph: %s fail", compute_graph->GetName().c_str());
    return false;
  }
  // Subgraphs are flattened into the root graph, each represented by a placeholder node.
  for (const auto &sub_compute_graph : compute_graph->GetAllSubgraphs()) {
    if (sub_compute_graph == nullptr) {
      GELOGW("Graph: %s subgraph is nullptr, skip EncodeGraph", compute_graph->GetName().c_str());
      continue;
    }
    auto *const holder_node = graph_proto->add_node();
    if (holder_node == nullptr) {
      GELOGW("Node proto is nullptr");
      continue;
    }
    holder_node->set_name(sub_compute_graph->GetName());
    holder_node->set_op_type(kNodeTypeForSubgraph);
    auto *const graph_attr = holder_node->add_attribute();
    graph_attr->set_name("graph");
    graph_attr->set_type(onnx::AttributeProto_AttributeType_GRAPH);
    auto *const sub_graph_proto = graph_attr->mutable_g();
    if (sub_graph_proto == nullptr) {
      GELOGW("Sub graph proto is nullptr");
      continue;
    }
    if (!EncodeGraph(sub_compute_graph, sub_graph_proto)) {
      GELOGW("Encode sub graph: %s fail", sub_compute_graph->GetName().c_str());
    }
  }
  return true;
}
// Part 2: from ONNX Protobuf convert to IR
// ONNX tensor element type -> GE data type (the inverse of the encoding table used for export).
static std::map<onnx::TensorProto_DataType, ge::DataType> onnxDataTypeToGeMap = {
  {onnx::TensorProto_DataType_INT64, DT_INT64},
  {onnx::TensorProto_DataType_UINT64, DT_UINT64},
  {onnx::TensorProto_DataType_FLOAT, DT_FLOAT},
  {onnx::TensorProto_DataType_INT32, DT_INT32},
  {onnx::TensorProto_DataType_UINT32, DT_UINT32},
  {onnx::TensorProto_DataType_INT8, DT_INT8},
  {onnx::TensorProto_DataType_UINT8, DT_UINT8},
  {onnx::TensorProto_DataType_INT16, DT_INT16},
  {onnx::TensorProto_DataType_UINT16, DT_UINT16},
  {onnx::TensorProto_DataType_FLOAT16, DT_FLOAT16},
  {onnx::TensorProto_DataType_DOUBLE, DT_DOUBLE},
  {onnx::TensorProto_DataType_BOOL, DT_BOOL},
};
/// Translate an ONNX tensor element type into the GE data type.
/// @return the mapped GE type, or DT_UNDEFINED (after a warning) when unsupported
ge::DataType OnnxUtils::DecodeDataType(onnx::TensorProto_DataType data_type) {
  const auto found = onnxDataTypeToGeMap.find(data_type);
  if (found == onnxDataTypeToGeMap.end()) {
    GELOGW("DecodeDataType: datatype not support %u", data_type);
    return ge::DT_UNDEFINED;
  }
  return found->second;
}
/// Split "<node_name>:<index>" at its last ':'.
/// The index is parsed base-10 with std::strtol, so text that is not a number yields 0.
/// @return false (outputs untouched) when the string contains no ':'.
bool OnnxUtils::ParseNameIndex(const std::string &node_name_index, std::string &node_name, int32_t &index) {
  const auto colon_pos = node_name_index.rfind(':');
  if (colon_pos == std::string::npos) {
    return false;
  }
  node_name.assign(node_name_index, 0, colon_pos);
  const std::string index_text = node_name_index.substr(colon_pos + 1);
  index = static_cast<int32_t>(std::strtol(index_text.c_str(), nullptr, 10));
  return true;
}
/// Create one edge from node_ptr to item.dst_node.
/// A non-negative item.src_out_index creates a data edge (src out anchor -> dst in anchor item.dst_in_index).
/// A negative index creates a control edge.
/// @return false when either node is nullptr, an anchor cannot be found, or linking fails.
bool OnnxUtils::DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr) {
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "DecodeNodeLinkImp: node_ptr is nullptr");
    return false;
  }
  // Both branches below dereference the destination node.
  if (item.dst_node == nullptr) {
    GELOGE(GRAPH_FAILED, "DecodeNodeLinkImp: dst node %s is nullptr", item.dst_node_name.c_str());
    return false;
  }
  // Data edge
  if (item.src_out_index >= 0) {
    auto src_anchor = node_ptr->GetOutDataAnchor(item.src_out_index);
    auto dst_anchor = item.dst_node->GetInDataAnchor(item.dst_in_index);
    if ((src_anchor == nullptr) || (dst_anchor == nullptr)) {
      GELOGE(GRAPH_FAILED, "Get data anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index,
             item.dst_node_name.c_str(), item.dst_in_index);
      return false;
    }
    if (src_anchor->LinkTo(dst_anchor) != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "Data Anchor: src_anchor->LinkTo(dst_anchor) failed");
      return false;
    }
    return true;
  }
  // Control edge
  auto src_anchor = node_ptr->GetOutControlAnchor();
  auto dst_anchor = item.dst_node->GetInControlAnchor();
  if ((src_anchor == nullptr) || (dst_anchor == nullptr)) {
    GELOGE(GRAPH_FAILED, "Get control anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index,
           item.dst_node_name.c_str(), item.dst_in_index);
    return false;
  }
  if (src_anchor->LinkTo(dst_anchor) != GRAPH_SUCCESS) {
    GELOGE(GRAPH_FAILED, "Control Anchor: src_anchor->LinkTo(dst_anchor) failed");
    return false;
  }
  return true;
}
/// Recreate graph edges from the "<src>:<index>" inputs of every NodeProto, resolving node names
/// through node_map.
/// @return false as soon as a destination or source node is missing/null, or an edge cannot be linked.
bool OnnxUtils::DecodeNodeLink(const std::vector<onnx::NodeProto> &node_proto_vector,
                               const std::map<std::string, NodePtr> &node_map) {
  for (const auto &node_proto : node_proto_vector) {
    const auto &node_name = node_proto.name();
    const auto dst_iter = node_map.find(node_name);
    if ((dst_iter == node_map.end()) || (dst_iter->second == nullptr)) {
      GELOGE(GRAPH_FAILED, "destination node: %s find failed or is nullptr", node_name.c_str());
      return false;
    }
    // Counts data inputs only. Control inputs (index -1, written as ":-1") do not occupy an in-data
    // anchor slot; an input with no ':' (e.g. "") keeps index 0 and still advances the slot.
    int32_t dst_index = 0;
    for (const auto &input : node_proto.input()) {
      std::string src_name;
      int32_t src_index = 0;
      if (ParseNameIndex(input, src_name, src_index)) {
        const auto src_iter = node_map.find(src_name);
        if (src_iter == node_map.end()) {
          GELOGE(GRAPH_FAILED, "find src node: %s failed", src_name.c_str());
          return false;
        }
        auto src_node = src_iter->second;
        if (src_node == nullptr) {
          GELOGE(GRAPH_FAILED, "src node: %s is nullptr", src_name.c_str());
          return false;
        }
        const auto link = NodeLinkInfo{src_name, src_index, dst_iter->second, dst_index, node_proto.name()};
        if (!DecodeNodeLinkImp(link, src_node)) {
          GELOGE(GRAPH_FAILED, "DecodeNodeLinkImp node: %s failed", src_name.c_str());
          return false;
        }
      }
      if (src_index >= 0) {
        ++dst_index;
      }
    }
  }
  return true;
}
/// @brief Append every entry of a STRINGS attribute to `strings`.
/// Logs an error and leaves `strings` untouched if the attribute has a different type.
void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector<std::string> &strings) {
  if (attr_proto.type() == onnx::AttributeProto_AttributeType_STRINGS) {
    for (const auto &item : attr_proto.strings()) {
      strings.push_back(item);
    }
    return;
  }
  GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str());
}
/// @brief Copy the value of a STRING attribute into `value`.
/// Logs an error and leaves `value` untouched if the attribute has a different type.
void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, std::string &value) {
  const bool is_string = (attr_proto.type() == onnx::AttributeProto_AttributeType_STRING);
  if (is_string) {
    value = attr_proto.s();
    return;
  }
  GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str());
}
/// @brief Append every entry of an INTS attribute to `ints`.
/// Logs an error and leaves `ints` untouched if the attribute has a different type.
void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector<int64_t> &ints) {
  if (attr_proto.type() == onnx::AttributeProto_AttributeType_INTS) {
    for (const auto item : attr_proto.ints()) {
      ints.push_back(item);
    }
    return;
  }
  GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str());
}
/// @brief Copy the value of a scalar INT attribute into `value`.
/// Logs an error and leaves `value` untouched if the attribute has a different type.
void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, int64_t &value) {
  const bool is_int = (attr_proto.type() == onnx::AttributeProto_AttributeType_INT);
  if (is_int) {
    value = attr_proto.i();
    return;
  }
  GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str());
}
/// @brief Apply one "input_desc_*" attribute to input desc number `index` of `op_desc`.
/// Recognised names: input_desc_dtype, input_desc_shape, input_desc_layout, input_desc_origin_shape,
/// input_desc_origin_layout, input_desc_size, input_desc_data_offset. Other names are ignored.
/// @param attr_proto                attribute carrying the value
/// @param attr_name_for_input_desc  attribute name with the ":<index>" suffix already stripped
/// @param index                     input slot; cast to uint32_t, so a negative value becomes a large
///                                  index (presumably rejected by MutableInputDesc — verify)
/// @param op_desc                   op whose input desc is updated (assumed non-null; checked by caller)
void OnnxUtils::DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto,
                                               const std::string &attr_name_for_input_desc, int32_t index,
                                               OpDescPtr &op_desc) {
  auto input_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(index));
  if (input_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableInputDesc(static_cast<uint32_t>(index)) is nullptr",
           op_desc->GetName().c_str(), attr_name_for_input_desc.c_str());
    return;
  }
  const std::string &name = attr_name_for_input_desc;
  if (name == "input_desc_dtype") {
    input_desc->SetDataType(TypeUtils::SerialStringToDataType(attr_proto.s()));
  } else if (name == "input_desc_shape") {
    std::vector<std::int64_t> dims;
    DecodeAttribute(attr_proto, dims);
    GeShape shape(dims);
    input_desc->SetShape(shape);
  } else if (name == "input_desc_layout") {
    input_desc->SetFormat(TypeUtils::SerialStringToFormat(attr_proto.s()));
  } else if (name == "input_desc_origin_shape") {
    std::vector<std::int64_t> dims;
    DecodeAttribute(attr_proto, dims);
    GeShape origin_shape(dims);
    input_desc->SetOriginShape(origin_shape);
  } else if (name == "input_desc_origin_layout") {
    input_desc->SetOriginFormat(TypeUtils::SerialStringToFormat(attr_proto.s()));
  } else if (name == "input_desc_size") {
    // Size and data offset have no GeTensorDesc setter used here; they are written into the underlying proto.
    auto proto_msg = input_desc->tensor_descriptor_.GetProtoMsg();
    int64_t size_value = 0;
    DecodeAttribute(attr_proto, size_value);
    proto_msg->set_size(size_value);
  } else if (name == "input_desc_data_offset") {
    auto proto_msg = input_desc->tensor_descriptor_.GetProtoMsg();
    int64_t offset_value = 0;
    DecodeAttribute(attr_proto, offset_value);
    proto_msg->set_data_offset(offset_value);
  }
}
/// @brief Apply one "output_desc_*" attribute to output desc number `index` of `op_desc`.
/// Recognised names: output_desc_dtype, output_desc_shape, output_desc_layout, output_desc_origin_shape,
/// output_desc_origin_layout, output_desc_size, output_desc_data_offset. Other names are ignored.
/// @param attr_proto                 attribute carrying the value
/// @param attr_name_for_output_desc  attribute name with the ":<index>" suffix already stripped
/// @param index                      output slot; cast to uint32_t, so a negative value becomes a large
///                                   index (presumably rejected by MutableOutputDesc — verify)
/// @param op_desc                    op whose output desc is updated (assumed non-null; checked by caller)
void OnnxUtils::DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto,
                                                const std::string &attr_name_for_output_desc, int32_t index,
                                                OpDescPtr &op_desc) {
  auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
  if (output_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableOutputDesc(static_cast<uint32_t>(index)) is nullptr",
           op_desc->GetName().c_str(), attr_name_for_output_desc.c_str());
    return;
  }
  const std::string &name = attr_name_for_output_desc;
  if (name == "output_desc_dtype") {
    output_desc->SetDataType(TypeUtils::SerialStringToDataType(attr_proto.s()));
  } else if (name == "output_desc_shape") {
    std::vector<std::int64_t> dims;
    DecodeAttribute(attr_proto, dims);
    GeShape shape(dims);
    output_desc->SetShape(shape);
  } else if (name == "output_desc_layout") {
    output_desc->SetFormat(TypeUtils::SerialStringToFormat(attr_proto.s()));
  } else if (name == "output_desc_origin_shape") {
    std::vector<std::int64_t> dims;
    DecodeAttribute(attr_proto, dims);
    GeShape origin_shape(dims);
    output_desc->SetOriginShape(origin_shape);
  } else if (name == "output_desc_origin_layout") {
    output_desc->SetOriginFormat(TypeUtils::SerialStringToFormat(attr_proto.s()));
  } else if (name == "output_desc_size") {
    // Size and data offset are written directly into the underlying tensor descriptor proto.
    auto proto_msg = output_desc->tensor_descriptor_.GetProtoMsg();
    int64_t size_value = 0;
    DecodeAttribute(attr_proto, size_value);
    proto_msg->set_size(size_value);
  } else if (name == "output_desc_data_offset") {
    auto proto_msg = output_desc->tensor_descriptor_.GetProtoMsg();
    int64_t offset_value = 0;
    DecodeAttribute(attr_proto, offset_value);
    proto_msg->set_data_offset(offset_value);
  }
}
/// @brief Route an indexed tensor-desc attribute to the input or output decoder.
/// Names starting with "input" go to DecodeNodeAttributeForOpInDesc and names starting with
/// "output" go to DecodeNodeAttributeForOpOutDesc. Anything else is ignored.
/// @param op_desc  target op; a null pointer is logged and ignored
void OnnxUtils::DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto,
                                                     const std::string &attr_name_for_input_output_desc,
                                                     int32_t index, OpDescPtr &op_desc) {
  if (op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "op_desc is nullptr");
    return;
  }
  const std::string &desc_attr_name = attr_name_for_input_output_desc;
  const bool is_input = (desc_attr_name.compare(0, kInputPrefixLength, "input") == 0);
  if (is_input) {
    DecodeNodeAttributeForOpInDesc(attr_proto, desc_attr_name, index, op_desc);
    return;
  }
  const bool is_output = (desc_attr_name.compare(0, kOutputPrefixLength, "output") == 0);
  if (is_output) {
    DecodeNodeAttributeForOpOutDesc(attr_proto, desc_attr_name, index, op_desc);
  }
}
/// @brief Store a scalar INT attribute into the OpDef attribute map under the attribute's own name.
/// A type mismatch is logged by DecodeAttribute, and the stored value then stays 0.
/// The entry is added with Map::insert, which does not replace an existing entry of the same name.
void OnnxUtils::DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def) {
  auto attrs = op_def.mutable_attr();
  int64_t decoded_value = 0;
  DecodeAttribute(attr_proto, decoded_value);
  ge::proto::AttrDef attr_def;
  attr_def.set_i(decoded_value);
  attrs->insert(AttrDefPair(attr_proto.name(), attr_def));
}
/// @brief Apply one node attribute to `op_desc`.
/// Names of the form "<desc_attr>:<index>" update a specific input/output tensor desc.
/// Plain names update op-level fields (id, stream_id, src/dst names and indices, input/output
/// offsets). "fusion_scope" is copied into the OpDef attribute map. Unrecognised names are ignored.
/// @param op_desc  target op; a null pointer is logged and ignored
void OnnxUtils::DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc) {
  if (op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "DecodeNodeAttributeForOpDesc: op_desc is nullptr");
    return;
  }
  const auto &attr_name = attr_proto.name();

  // Indexed tensor-desc attribute, e.g. "input_desc_shape:0".
  std::string desc_attr_name;
  int32_t desc_index = 0;
  if (ParseNameIndex(attr_name, desc_attr_name, desc_index)) {
    DecodeNodeAttributeForOpInAndOutDesc(attr_proto, desc_attr_name, desc_index, op_desc);
    return;
  }

  // Op-level attributes.
  if (attr_name == "id") {
    op_desc->SetId(attr_proto.i());
  } else if (attr_name == "stream_id") {
    op_desc->SetStreamId(attr_proto.i());
  } else if (attr_name == "src_name") {
    std::vector<std::string> names;
    DecodeAttribute(attr_proto, names);
    op_desc->SetSrcName(names);
  } else if (attr_name == "dst_name") {
    std::vector<std::string> names;
    DecodeAttribute(attr_proto, names);
    op_desc->SetDstName(names);
  } else if (attr_name == "src_index") {
    std::vector<std::int64_t> indices;
    DecodeAttribute(attr_proto, indices);
    op_desc->SetSrcIndex(indices);
  } else if (attr_name == "dst_index") {
    std::vector<std::int64_t> indices;
    DecodeAttribute(attr_proto, indices);
    op_desc->SetDstIndex(indices);
  } else if (attr_name == "fusion_scope") {
    DecodeNodeAttributeForOpDef(attr_proto, *op_desc->op_def_.GetProtoMsg());
  } else if (attr_name == "input_i") {
    std::vector<std::int64_t> offsets;
    DecodeAttribute(attr_proto, offsets);
    op_desc->SetInputOffset(offsets);
  } else if (attr_name == "output_i") {
    std::vector<std::int64_t> offsets;
    DecodeAttribute(attr_proto, offsets);
    op_desc->SetOutputOffset(offsets);
  }
}
/// @brief Populate `op_desc` from a node proto: name, type, empty input/output descs, then attributes.
/// The proto op_type must contain ':'. Only the text after the first ':' becomes the op type.
/// @param node_proto  source node (null is rejected)
/// @param op_desc     destination op desc (null is rejected)
/// @return false on null arguments or when op_type has no ':' separator; true otherwise.
bool OnnxUtils::DecodeNodeDesc(const onnx::NodeProto *node_proto, OpDescPtr &op_desc) {
  if (op_desc == nullptr || node_proto == nullptr) {
    GELOGE(GRAPH_FAILED, " Op_desc is nullptr or node_proto is nullptr");
    return false;
  }

  // 1. Name and type (op_type is "<prefix>:<type>").
  op_desc->SetName(node_proto->name());
  const std::string &prefixed_type = node_proto->op_type();
  const auto colon_pos = prefixed_type.find(':');
  if (colon_pos == std::string::npos) {
    return false;
  }
  const std::string op_type = prefixed_type.substr(colon_pos + 1);
  op_desc->SetType(op_type);

  // 2. Pre-create empty descs so the indexed "*_desc_*:<n>" attributes below have a slot to fill.
  for (const auto &attr : node_proto->attribute()) {
    const auto &attr_name = attr.name();
    if (attr_name == "input_desc_nums") {
      const auto input_count = attr.i();
      for (int64_t n = 0; n < input_count; ++n) {
        GeTensorDesc empty_desc;
        GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(empty_desc) == GRAPH_SUCCESS, continue, "Add inputdesc failed.");
      }
    } else if (attr_name == "output_desc_nums") {
      const auto output_count = attr.i();
      for (int64_t n = 0; n < output_count; ++n) {
        GeTensorDesc empty_desc;
        GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(empty_desc) == GRAPH_SUCCESS, continue, "Add outputdesc failed.");
      }
    }
  }

  // 3. Decode every attribute (including the *_desc_nums ones, which fall through as unrecognised).
  for (const auto &attr : node_proto->attribute()) {
    DecodeNodeAttributeForOpDesc(attr, op_desc);
  }
  return true;
}
/// @brief Decode an ONNX GraphProto into a newly created ComputeGraph.
/// Nodes with op_type "subgraph" are decoded recursively from their first attribute, which must be a
/// GRAPH attribute, and attached as subgraphs. All other nodes become graph nodes. Edges are then
/// rebuilt, and the proto's input/output lists are registered as graph input/output nodes.
/// @param recursion_depth  current nesting depth; decoding fails once it exceeds kMaxRecursionDepth
/// @param graph_proto      proto to decode
/// @param graph            receives the new graph (assigned even if decoding later fails)
/// @return true on success; false on excessive nesting, allocation failure, malformed subgraph nodes,
///         node/link decoding errors, or graph inputs/outputs that name unknown nodes.
bool OnnxUtils::DecodeGraph(int recursion_depth, const onnx::GraphProto &graph_proto, ComputeGraphPtr &graph) {
  if (recursion_depth > kMaxRecursionDepth) {
    GELOGE(GRAPH_FAILED, "DecodeGraph: recursion depth is too large, abort");
    return false;
  }
  graph = ComGraphMakeShared<ge::ComputeGraph>(graph_proto.name());
  GE_CHK_BOOL_EXEC(graph != nullptr, return false, "ComputeGraph make shared failed");
  /// 1. Decode all nodes first, node should include input
  /// and output nodes and nodes which represent sub graphs
  std::map<std::string, NodePtr> node_map;
  std::vector<onnx::NodeProto> node_proto_vector;
  node_proto_vector.reserve(static_cast<size_t>(graph_proto.node_size()));
  for (const auto &node_proto : graph_proto.node()) {
    // a. nodes represent sub graphs
    if (node_proto.op_type() == kNodeTypeForSubgraph) {
      // A subgraph node must carry its graph as attribute 0. Reject malformed protos before indexing:
      // RepeatedPtrField::Get does not bounds-check in release builds.
      if (node_proto.attribute_size() < 1) {
        GELOGE(GRAPH_FAILED, "Decode sub graph %s failed: subgraph node has no attribute", node_proto.name().c_str());
        return false;
      }
      ComputeGraphPtr compute_graph;
      const auto &node_attr = node_proto.attribute(0);
      if ((node_attr.type() == onnx::AttributeProto_AttributeType_GRAPH) &&
          DecodeGraph(recursion_depth + 1, node_attr.g(), compute_graph)) {
        // Best-effort attach; a failed AddSubGraph is deliberately not treated as fatal here.
        (void)graph->AddSubGraph(compute_graph);
      } else {
        GELOGE(GRAPH_FAILED, "Decode sub graph %s failed with node type:%d", node_proto.name().c_str(),
               node_attr.type());
        return false;
      }
      // b. direct nodes in graph
    } else {
      node_proto_vector.push_back(node_proto);
      OpDescPtr op_desc = ComGraphMakeShared<OpDesc>();
      // b.1 For node desc (DecodeNodeDesc also rejects a null op_desc from a failed allocation)
      if (!DecodeNodeDesc(&node_proto, op_desc)) {
        GELOGE(GRAPH_FAILED, "Decode node desc %s failed ", node_proto.name().c_str());
        return false;
      }
      auto node = graph->AddNode(op_desc);
      if (node == nullptr) {
        // Fail here with a precise message instead of storing a null node that only surfaces later.
        GELOGE(GRAPH_FAILED, "Add node %s to graph failed", node_proto.name().c_str());
        return false;
      }
      node_map.insert(std::make_pair(node_proto.name(), node));
    }
  }
  /// We get all nodes in graph here
  /// b.2 For node link
  if (!DecodeNodeLink(node_proto_vector, node_map)) {
    GELOGE(GRAPH_FAILED, "Decode node link failed");
    return false;
  }
  // 2. Add inputs nodes for graph
  for (const auto &input : graph_proto.input()) {
    const auto &input_node_name = input.name();
    auto input_node_item = node_map.find(input_node_name);
    if (input_node_item == node_map.end()) {
      GELOGE(GRAPH_FAILED, "cannot find graph's input node %s in node_", input_node_name.c_str());
      return false;
    }
    auto ret = graph->AddInputNode(input_node_item->second);
    GE_CHK_BOOL_EXEC(ret != nullptr, continue, "Add inputnode failed");
  }
  // 3. Add outputs nodes for graph
  for (const auto &output : graph_proto.output()) {
    const auto &output_node_name = output.name();
    auto output_node_item = node_map.find(output_node_name);
    if (output_node_item == node_map.end()) {
      GELOGE(GRAPH_FAILED, "cannot find graph's output node %s in node_", output_node_name.c_str());
      return false;
    }
    auto ret = graph->AddOutputNode(output_node_item->second);
    if (ret == nullptr) {
      GELOGW("Add outputnode failed,out put node is %s", output_node_name.c_str());
      continue;
    }
  }
  return true;
}
/// @brief Fill a ge::Model from an ONNX ModelProto.
/// The model name comes from producer_name and the version from model_version (narrowed to uint32_t).
/// Both are written before the graph is decoded, so they are set even when decoding fails.
/// @return false if the top-level graph cannot be decoded; model.graph_ is then left unchanged.
bool OnnxUtils::ConvertModelProtoToGeModel(const onnx::ModelProto &model_proto, ge::Model &model) {
  model.name_ = model_proto.producer_name();
  model.version_ = static_cast<uint32_t>(model_proto.model_version());

  ComputeGraphPtr decoded_graph;
  const int top_level_depth = 0;
  const bool decoded = DecodeGraph(top_level_depth, model_proto.graph(), decoded_graph);
  if (!decoded) {
    GELOGE(GRAPH_FAILED, "Decode compute graph from graph_proto failed");
    return false;
  }
  model.graph_ = GraphUtils::CreateGraphFromComputeGraph(decoded_graph);
  return true;
}
} // namespace ge