You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
1179 lines
50 KiB
1179 lines
50 KiB
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
#include "graph/utils/ge_ir_utils.h"

#include <cerrno>
#include <cstdint>
#include <cstdlib>
#include <utility>

#include "framework/common/debug/ge_log.h"
|
|
|
|
namespace {
|
|
const char *const kControlAnchorIndex = ":-1";
|
|
const char *const kNodeTypeForSubgraph = "subgraph";
|
|
const char *const kPrefixForInputDesc = "input_desc_attr_";
|
|
const char *const kPrefixForOutputDesc = "output_desc_attr_";
|
|
const char *const kDumpGEGraph = "DUMP_GE_GRAPH";
|
|
const int8_t kMaxRecursionDepth = 10;
|
|
const char *const kDumpGeGraph = std::getenv(kDumpGEGraph);
|
|
const int64_t kDumpLevel = (kDumpGeGraph != nullptr) ? std::strtol(kDumpGeGraph, nullptr, 10) : ge::OnnxUtils::NO_DUMP;
|
|
const int64_t kInputPrefixLength = 5;
|
|
const int64_t kOutputPrefixLength = 6;
|
|
using AttrDefPair = ::google::protobuf::MapPair<std::string, ge::proto::AttrDef>;
|
|
} // namespace
|
|
|
|
namespace ge {
|
|
// Part 1: convert from IR to ONNX Protobuf
|
|
static const std::map<ge::DataType, onnx::TensorProto_DataType> kGeDataTypeToOnnxMap = {
|
|
{DT_INT64, onnx::TensorProto_DataType_INT64}, {DT_UINT64, onnx::TensorProto_DataType_UINT64},
|
|
{DT_FLOAT, onnx::TensorProto_DataType_FLOAT}, {DT_INT32, onnx::TensorProto_DataType_INT32},
|
|
{DT_UINT32, onnx::TensorProto_DataType_UINT32}, {DT_INT8, onnx::TensorProto_DataType_INT8},
|
|
{DT_UINT8, onnx::TensorProto_DataType_UINT8}, {DT_INT16, onnx::TensorProto_DataType_INT16},
|
|
{DT_UINT16, onnx::TensorProto_DataType_UINT16}, {DT_FLOAT16, onnx::TensorProto_DataType_FLOAT16},
|
|
{DT_DOUBLE, onnx::TensorProto_DataType_DOUBLE}, {DT_BOOL, onnx::TensorProto_DataType_BOOL},
|
|
};
|
|
|
|
/// Translate a GE data type into the matching ONNX tensor element type.
/// @param data_type GE data type to translate.
/// @return The ONNX type from kGeDataTypeToOnnxMap, or TensorProto_DataType_UNDEFINED (after logging a
///         warning) when the GE type has no entry in the table.
onnx::TensorProto_DataType OnnxUtils::EncodeDataType(DataType data_type) {
  const auto found = kGeDataTypeToOnnxMap.find(data_type);
  if (found == kGeDataTypeToOnnxMap.end()) {
    GELOGW("EncodeDataType: datatype not support %u", data_type);
    return onnx::TensorProto_DataType_UNDEFINED;
  }
  return found->second;
}
|
|
|
|
/// Append one (name, GeAttrValue) pair to `node_proto` as an ONNX attribute.
///
/// Float, int and string values and their list forms are translated. For any other value type a warning
/// is logged; the attribute has already been appended by then, so it remains with only its name set.
/// The status returned by GeAttrValue::GetValue is deliberately ignored: on failure the attribute carries
/// the zero/empty default of the local it was read into.
///
/// @param string_attr_value Attribute name and value to translate.
/// @param node_proto        Destination node; when nullptr an error is logged and nothing is added.
void OnnxUtils::AddAttrProtoFromAttribute(const std::pair<const std::string, ge::GeAttrValue> &string_attr_value,
                                          onnx::NodeProto *node_proto) {
  if (node_proto == nullptr) {
    GELOGE(FAILED, "Node proto is nullptr.");
    return;
  }
  auto onnx_attr = node_proto->add_attribute();
  if (onnx_attr == nullptr) {
    GELOGE(GRAPH_FAILED, "attr is nullptr.");
    return;
  }
  onnx_attr->set_name(string_attr_value.first);

  // Work on a private copy of the value, as the accessors below are called on a non-const object.
  auto ge_value = string_attr_value.second;
  const auto ge_type = ge_value.GetValueType();
  switch (ge_type) {
    case GeAttrValue::VT_FLOAT: {
      GeAttrValue::FLOAT scalar = 0;
      (void)ge_value.GetValue(scalar);
      onnx_attr->set_type(onnx::AttributeProto_AttributeType_FLOAT);
      onnx_attr->set_f(scalar);
      break;
    }
    case GeAttrValue::VT_LIST_FLOAT: {
      GeAttrValue::LIST_FLOAT elements = {};
      (void)ge_value.GetValue(elements);
      onnx_attr->set_type(onnx::AttributeProto_AttributeType_FLOATS);
      for (const auto &element : elements) {
        onnx_attr->add_floats(element);
      }
      break;
    }
    case GeAttrValue::VT_INT: {
      GeAttrValue::INT scalar = 0;
      (void)ge_value.GetValue(scalar);
      onnx_attr->set_type(onnx::AttributeProto_AttributeType_INT);
      onnx_attr->set_i(scalar);
      break;
    }
    case GeAttrValue::VT_LIST_INT: {
      GeAttrValue::LIST_INT elements = {};
      (void)ge_value.GetValue(elements);
      onnx_attr->set_type(onnx::AttributeProto_AttributeType_INTS);
      for (const auto &element : elements) {
        onnx_attr->add_ints(element);
      }
      break;
    }
    case GeAttrValue::VT_STRING: {
      GeAttrValue::STR text;
      (void)ge_value.GetValue(text);
      onnx_attr->set_type(onnx::AttributeProto_AttributeType_STRING);
      onnx_attr->set_s(text);
      break;
    }
    case GeAttrValue::VT_LIST_STRING: {
      GeAttrValue::LIST_STR elements = {};
      (void)ge_value.GetValue(elements);
      onnx_attr->set_type(onnx::AttributeProto_AttributeType_STRINGS);
      for (const auto &element : elements) {
        onnx_attr->add_strings(element);
      }
      break;
    }
    default:
      GELOGW("GeAttrValue ValueType: %u is not supported for now", ge_type);
      break;
  }
}
|
|
|
|
/// Append an attribute to `node_proto`, reading its payload through an untyped pointer.
///
/// `data` must point to an object whose type matches `type`:
///   - FLOAT   -> float
///   - FLOATS  -> std::vector<float>
///   - INT     -> int64_t
///   - INTS    -> std::vector<int64_t>
///   - STRING  -> std::string
///   - STRINGS -> std::vector<std::string>
/// Passing a pointer to any other type is undefined behaviour, because the pointer is cast back without
/// any check. For every other `type` a warning is logged and the attribute is left with only its name set.
///
/// @param node_proto Destination node; when nullptr an error is logged and nothing is added.
/// @param type       ONNX attribute type that selects how `data` is interpreted.
/// @param name       Attribute name.
/// @param data       Pointer to the payload; when nullptr an error is logged and nothing is added.
void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
                             void *data) {
  if (node_proto == nullptr) {
    GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str());
    return;
  }
  // Reject a missing payload before touching node_proto, so no half-built attribute is left behind.
  if (data == nullptr) {
    GELOGE(FAILED, "Attribute %s data is nullptr.", name.c_str());
    return;
  }
  auto attr = node_proto->add_attribute();
  if (attr == nullptr) {
    GELOGE(GRAPH_FAILED, "attr is nullptr.");
    return;
  }
  attr->set_name(name);
  switch (type) {
    case onnx::AttributeProto_AttributeType_FLOAT:
      attr->set_type(onnx::AttributeProto_AttributeType_FLOAT);
      attr->set_f(*static_cast<const float *>(data));
      break;

    case onnx::AttributeProto_AttributeType_FLOATS:
      attr->set_type(onnx::AttributeProto_AttributeType_FLOATS);
      for (const auto &v : *static_cast<const std::vector<float> *>(data)) {
        attr->add_floats(v);
      }
      break;

    case onnx::AttributeProto_AttributeType_INT:
      attr->set_type(onnx::AttributeProto_AttributeType_INT);
      attr->set_i(*static_cast<const int64_t *>(data));
      break;

    case onnx::AttributeProto_AttributeType_INTS:
      attr->set_type(onnx::AttributeProto_AttributeType_INTS);
      for (const auto &v : *static_cast<const std::vector<int64_t> *>(data)) {
        attr->add_ints(v);
      }
      break;

    case onnx::AttributeProto_AttributeType_STRING:
      attr->set_type(onnx::AttributeProto_AttributeType_STRING);
      attr->set_s(*static_cast<const std::string *>(data));
      break;

    case onnx::AttributeProto_AttributeType_STRINGS:
      attr->set_type(onnx::AttributeProto_AttributeType_STRINGS);
      for (const auto &v : *static_cast<const std::vector<std::string> *>(data)) {
        attr->add_strings(v);
      }
      break;

    default:
      GELOGW("AttributeProto AttributeType: %u is not supported for now", type);
      break;
  }
}
|
|
|
|
/// Append an integer-list attribute built from a protobuf repeated int64 field.
/// An empty `data` adds no attribute at all.
///
/// @param node_proto Destination node; when nullptr an error is logged and nothing is added.
/// @param type       Attribute type recorded on the new attribute (the values always go into `ints`).
/// @param name       Attribute name.
/// @param data       Values to copy into the attribute.
void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
                             ::google::protobuf::RepeatedField<::google::protobuf::int64> data) {
  if (node_proto == nullptr) {
    GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str());
    return;
  }
  if (data.empty()) {
    return;
  }
  auto onnx_attr = node_proto->add_attribute();
  if (onnx_attr == nullptr) {
    GELOGE(GRAPH_FAILED, "attr is nullptr.");
    return;
  }
  onnx_attr->set_name(name);
  onnx_attr->set_type(type);
  for (const auto &element : data) {
    onnx_attr->add_ints(element);
  }
}
|
|
|
|
/// Append an integer-list attribute built from a protobuf repeated bool field; each bool is widened to
/// int64_t. An empty `data` adds no attribute at all.
///
/// @param node_proto Destination node; when nullptr an error is logged and nothing is added.
/// @param type       Attribute type recorded on the new attribute (the values always go into `ints`).
/// @param name       Attribute name.
/// @param data       Flags to copy into the attribute.
void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
                             ::google::protobuf::RepeatedField<bool> data) {
  if (node_proto == nullptr) {
    GELOGE(FAILED, "Node proto %s is nullptr.", name.c_str());
    return;
  }
  if (data.empty()) {
    return;
  }
  auto onnx_attr = node_proto->add_attribute();
  if (onnx_attr == nullptr) {
    GELOGE(GRAPH_FAILED, "attr is nullptr.");
    return;
  }
  onnx_attr->set_name(name);
  onnx_attr->set_type(type);
  for (const auto &flag : data) {
    onnx_attr->add_ints(static_cast<int64_t>(flag));
  }
}
|
|
|
|
/// Append a float-list attribute built from a protobuf repeated float field.
/// An empty `data` adds no attribute at all.
///
/// @param node_proto Destination node; when nullptr an error is logged and nothing is added.
/// @param type       Attribute type recorded on the new attribute (the values always go into `floats`).
/// @param name       Attribute name.
/// @param data       Values to copy into the attribute.
void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
                             ::google::protobuf::RepeatedField<float> data) {
  if (node_proto == nullptr) {
    GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str());
    return;
  }
  if (data.empty()) {
    return;
  }
  auto onnx_attr = node_proto->add_attribute();
  if (onnx_attr == nullptr) {
    GELOGE(GRAPH_FAILED, "attr is nullptr.");
    return;
  }
  onnx_attr->set_name(name);
  onnx_attr->set_type(type);
  for (const auto &element : data) {
    onnx_attr->add_floats(element);
  }
}
|
|
|
|
/// Append a string-list attribute built from a protobuf repeated string field.
/// An empty `data` adds no attribute at all.
///
/// @param node_proto Destination node; when nullptr an error is logged and nothing is added.
/// @param type       Attribute type recorded on the new attribute (the values always go into `strings`).
/// @param name       Attribute name.
/// @param data       Strings to copy into the attribute.
void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name,
                             ::google::protobuf::RepeatedPtrField<::std::string> data) {
  if (node_proto == nullptr) {
    GELOGE(FAILED, "Node proto %s is nullptr.", name.c_str());
    return;
  }
  if (data.empty()) {
    return;
  }
  auto onnx_attr = node_proto->add_attribute();
  if (onnx_attr == nullptr) {
    GELOGE(GRAPH_FAILED, "attr is nullptr.");
    return;
  }
  onnx_attr->set_name(name);
  onnx_attr->set_type(type);
  for (const auto &text : data) {
    onnx_attr->add_strings(text);
  }
}
|
|
|
|
/// Dump every input and output tensor description of `op_desc` as attributes of `node_proto`.
///
/// Two count attributes ("input_desc_nums", "output_desc_nums") are always written. Each per-tensor
/// attribute is named "<input|output>_desc_<field>:<index>". Entries of a tensor descriptor's attribute
/// map are forwarded to AddAttrProtoForAttrsFromAttrMap with the matching prefix and a ":<index>" suffix.
/// A null description, or a description without a protobuf tensor descriptor, is logged and skipped; in
/// the latter case the dtype/shape/layout attributes have already been written.
///
/// @param node_proto Destination node; must not be nullptr.
/// @param op_desc    Operator description to read; must not be nullptr.
void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc) {
  if (node_proto == nullptr || op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "node_proto or op_desc is nullptr");
    return;
  }

  // The void* overload of AddAttrProto reinterprets its payload as int64_t / std::string /
  // std::vector<int64_t>. These adapters convert the value to exactly that type first, so an object of a
  // different width (for example an unsigned count) is never read through the wrong pointer type.
  const auto add_int = [node_proto](const std::string &key, int64_t value) {
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, key, &value);
  };
  const auto add_str = [node_proto](const std::string &key, std::string value) {
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, key, &value);
  };
  const auto add_ints = [node_proto](const std::string &key, std::vector<int64_t> values) {
    AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, key, &values);
  };

  // Input describes
  const auto size_in = op_desc->GetAllInputsSize();
  add_int("input_desc_nums", static_cast<int64_t>(size_in));
  for (uint32_t i = 0; i < size_in; i++) {
    const auto input_desc = op_desc->GetInputDescPtrDfault(i);
    if (input_desc == nullptr) {
      GELOGW("Input desc is nullptr");
      continue;
    }
    const std::string suffix = ":" + std::to_string(i);
    add_str("input_desc_dtype" + suffix, TypeUtils::DataTypeToSerialString(input_desc->GetDataType()));
    add_str("input_desc_origin_dtype" + suffix, TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()));
    add_ints("input_desc_shape" + suffix, input_desc->GetShape().GetDims());
    add_ints("input_desc_origin_shape" + suffix, input_desc->GetOriginShape().GetDims());
    add_str("input_desc_layout" + suffix, TypeUtils::FormatToSerialString(input_desc->GetFormat()));
    add_str("input_desc_origin_layout" + suffix, TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()));

    const auto tensor_descriptor = input_desc->tensor_descriptor_.GetProtoMsg();
    if (tensor_descriptor == nullptr) {
      GELOGW("Tensor descriptor is nullptr");
      continue;
    }
    add_int("input_desc_size" + suffix, tensor_descriptor->size());
    add_int("input_desc_weight_size" + suffix, tensor_descriptor->weight_size());
    add_int("input_desc_reuse_input" + suffix, static_cast<int64_t>(tensor_descriptor->reuse_input()));
    add_int("input_desc_output_tensor" + suffix, static_cast<int64_t>(tensor_descriptor->output_tensor()));
    add_str("input_desc_device_type" + suffix, tensor_descriptor->device_type());
    add_int("input_desc_input_tensor" + suffix, static_cast<int64_t>(tensor_descriptor->input_tensor()));
    add_int("input_desc_real_dim_cnt" + suffix, tensor_descriptor->real_dim_cnt());
    add_int("input_desc_data_offset" + suffix, tensor_descriptor->data_offset());
    add_int("input_desc_cmps_size" + suffix, tensor_descriptor->cmps_size());
    add_str("input_desc_cmps_tab" + suffix, tensor_descriptor->cmps_tab());
    add_int("input_desc_cmps_tab_offset" + suffix, tensor_descriptor->cmps_tab_offset());
    AddAttrProtoForAttrsFromAttrMap(tensor_descriptor->attr(), node_proto, kPrefixForInputDesc, suffix);
  }

  // Output describes
  const auto size_out = op_desc->GetOutputsSize();
  add_int("output_desc_nums", static_cast<int64_t>(size_out));
  for (uint32_t i = 0; i < size_out; i++) {
    const auto output_desc = op_desc->GetOutputDescPtr(i);
    if (output_desc == nullptr) {
      GELOGW("Output desc is nullptr");
      continue;
    }
    const std::string suffix = ":" + std::to_string(i);
    add_str("output_desc_dtype" + suffix, TypeUtils::DataTypeToSerialString(output_desc->GetDataType()));
    add_str("output_desc_origin_dtype" + suffix, TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()));
    add_ints("output_desc_shape" + suffix, output_desc->GetShape().GetDims());
    add_ints("output_desc_origin_shape" + suffix, output_desc->GetOriginShape().GetDims());
    add_str("output_desc_layout" + suffix, TypeUtils::FormatToSerialString(output_desc->GetFormat()));
    add_str("output_desc_origin_layout" + suffix, TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()));

    const auto tensor_descriptor = output_desc->tensor_descriptor_.GetProtoMsg();
    if (tensor_descriptor == nullptr) {
      GELOGW("Tensor descriptor is nullptr");
      continue;
    }
    add_int("output_desc_size" + suffix, tensor_descriptor->size());
    add_int("output_desc_weight_size" + suffix, tensor_descriptor->weight_size());
    add_str("output_desc_device_type" + suffix, tensor_descriptor->device_type());
    add_int("output_desc_real_dim_cnt" + suffix, tensor_descriptor->real_dim_cnt());
    AddAttrProtoForAttrsFromAttrMap(tensor_descriptor->attr(), node_proto, kPrefixForOutputDesc, suffix);
  }
}
|
|
|
|
void OnnxUtils::AddAttrProtoForAttrsFromAttrMap(
|
|
const ::google::protobuf::Map<std::string, ::ge::proto::AttrDef> &attr_map, onnx::NodeProto *node_proto,
|
|
const std::string &prefix, const std::string &suffix) {
|
|
for (const auto &item : attr_map) {
|
|
auto attr_name = item.first;
|
|
auto attr_def = item.second;
|
|
auto attr_type = attr_def.value_case();
|
|
if (attr_type == ge::proto::AttrDef::kT) {
|
|
const auto &tensor_def = attr_def.t();
|
|
const auto &tensor_desc = tensor_def.desc();
|
|
auto data_type = ge::proto::DataType_Name(tensor_desc.dtype());
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_desc_dtype" + suffix,
|
|
&data_type);
|
|
auto dims = tensor_desc.shape().dim();
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + "_desc_shape" + suffix,
|
|
dims);
|
|
auto layout = tensor_desc.layout();
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_desc_layout" + suffix,
|
|
&layout);
|
|
auto device_type = tensor_desc.device_type();
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING,
|
|
prefix + attr_name + "_desc_device_type" + suffix, &device_type);
|
|
if (kDumpLevel == DUMP_ALL) {
|
|
auto data = tensor_def.data();
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_data" + suffix,
|
|
&data);
|
|
}
|
|
}
|
|
if (attr_type == ge::proto::AttrDef::kS) {
|
|
if (kDumpLevel == DUMP_ALL) {
|
|
auto str_value = attr_def.s();
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + suffix, &str_value);
|
|
}
|
|
}
|
|
if (attr_type == ge::proto::AttrDef::kI) {
|
|
auto int_value = attr_def.i();
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value);
|
|
}
|
|
if (attr_type == ge::proto::AttrDef::kF) {
|
|
auto float_value = attr_def.f();
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOAT, prefix + attr_name + suffix, &float_value);
|
|
}
|
|
if (attr_type == ge::proto::AttrDef::kB) {
|
|
auto int_value = static_cast<int64_t>(attr_def.b());
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value);
|
|
}
|
|
if (attr_type == ge::proto::AttrDef::kList) {
|
|
const auto &list_value = attr_def.list();
|
|
auto list_value_type = list_value.val_type();
|
|
if (list_value_type ==
|
|
ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_STRING) {
|
|
if (kDumpLevel == DUMP_ALL) {
|
|
const auto &strings = list_value.s();
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, prefix + attr_name + suffix, strings);
|
|
}
|
|
}
|
|
if (list_value_type ==
|
|
ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT) {
|
|
const auto &floats = list_value.f();
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOATS, prefix + attr_name + suffix, floats);
|
|
}
|
|
if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_INT) {
|
|
const auto &ints = list_value.i();
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, ints);
|
|
}
|
|
if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_BOOL) {
|
|
const auto &bools = list_value.b();
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, bools);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
void OnnxUtils::AddAttrProtoFromNodeMembers(const NodePtr &node, onnx::NodeProto *node_proto) {
|
|
if (node == nullptr) {
|
|
GELOGE(GRAPH_FAILED, "node is nullptr");
|
|
return;
|
|
}
|
|
// 1.Attributes added from node's methods
|
|
auto send_list = node->send_event_id_list_;
|
|
if (!send_list.empty()) {
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "send_event_id_list", &send_list);
|
|
}
|
|
auto recv_list = node->recv_event_id_list_;
|
|
if (!recv_list.empty()) {
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "recv_event_id_list", &recv_list);
|
|
}
|
|
auto op_desc = node->op_;
|
|
if (op_desc != nullptr) {
|
|
// for input_name_idx_ in opdesc
|
|
auto input_name_2_indexs = op_desc->GetAllInputName();
|
|
::google::protobuf::RepeatedPtrField<::std::string> input_names;
|
|
::google::protobuf::RepeatedField<::google::protobuf::int64> input_indexes;
|
|
for (const auto &input_name_2_index : input_name_2_indexs) {
|
|
std::string input_name = input_name_2_index.first;
|
|
input_names.Add(std::move(input_name));
|
|
input_indexes.Add(input_name_2_index.second);
|
|
}
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "_input_name_key", input_names);
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "_input_name_value", input_indexes);
|
|
// 2.Attributes added from node's op_(message OpDef)
|
|
// Input and out describes
|
|
AddAttrProtoForOpInAndOutDesc(node_proto, op_desc);
|
|
// Others
|
|
auto op_def = op_desc->op_def_.GetProtoMsg();
|
|
if (op_def != nullptr) {
|
|
auto id = op_def->id();
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "id", &id);
|
|
auto stream_id = op_def->stream_id();
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, "stream_id", &stream_id);
|
|
const auto &input_name = op_def->input_name();
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "input_name", input_name);
|
|
const auto &src_name = op_def->src_name();
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "src_name", src_name);
|
|
const auto &src_index = op_def->src_index();
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "src_index", src_index);
|
|
const auto &dst_name = op_def->dst_name();
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, "dst_name", dst_name);
|
|
const auto &dst_index = op_def->dst_index();
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "dst_index", dst_index);
|
|
const auto &input_i = op_def->input_i();
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "input_i", input_i);
|
|
const auto &output_i = op_def->output_i();
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "output_i", output_i);
|
|
const auto &workspace = op_def->workspace();
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace", workspace);
|
|
const auto &workspace_bytes = op_def->workspace_bytes();
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace_bytes", workspace_bytes);
|
|
const auto &is_input_const = op_def->is_input_const();
|
|
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "is_input_const", is_input_const);
|
|
const auto &op_def_attr_map = op_def->attr();
|
|
AddAttrProtoForAttrsFromAttrMap(op_def_attr_map, node_proto);
|
|
} else {
|
|
GELOGE(FAILED, "Opdef is nullptr");
|
|
return;
|
|
}
|
|
} else {
|
|
GELOGE(FAILED, "Opdesc is nullptr");
|
|
return;
|
|
}
|
|
}
|
|
|
|
bool OnnxUtils::EncodeNodeDesc(const NodePtr &node, onnx::NodeProto *node_proto) {
|
|
if ((node == nullptr) || (node_proto == nullptr)) {
|
|
GELOGE(GRAPH_FAILED, "EncodeOpDesc: Input Para Node Invalid");
|
|
return false;
|
|
}
|
|
|
|
// 2.Encode map<string, GeAttrValue> attrs_ to AttributeProto
|
|
for (auto &node_attr : node->attrs_) {
|
|
AddAttrProtoFromAttribute(node_attr, node_proto);
|
|
}
|
|
// 3.Encode ge::Node members to AttributeProto
|
|
AddAttrProtoFromNodeMembers(node, node_proto);
|
|
return true;
|
|
}
|
|
|
|
void OnnxUtils::EncodeNodeLinkForNetronVisual(const NodePtr &node, onnx::NodeProto *node_proto) {
|
|
if ((node == nullptr) || (node_proto == nullptr)) {
|
|
GELOGE(GRAPH_FAILED, "EncodeNodeLinkForNetronVisual: Input Para Node Invalid");
|
|
return;
|
|
}
|
|
const auto &node_name = node->GetName();
|
|
for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
|
|
if ((out_data_anchor != nullptr) && (!out_data_anchor->GetPeerInDataAnchors().empty())) {
|
|
node_proto->add_output(node_name + ":" + std::to_string(out_data_anchor->GetIdx()));
|
|
}
|
|
}
|
|
auto out_control_anchor = node->GetOutControlAnchor();
|
|
if ((out_control_anchor != nullptr) && (!out_control_anchor->GetPeerInControlAnchors().empty())) {
|
|
node_proto->add_output(node_name + kControlAnchorIndex);
|
|
}
|
|
}
|
|
|
|
bool OnnxUtils::EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto) {
|
|
if ((node == nullptr) || (node_proto == nullptr)) {
|
|
GELOGE(GRAPH_FAILED, "EncodeNodeLink: Input Para Node Invalid");
|
|
return false;
|
|
}
|
|
node_proto->clear_input();
|
|
// 1. Add input by in data edge
|
|
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
|
|
auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
|
|
if ((peer_out_anchor != nullptr) && (peer_out_anchor->GetOwnerNode() != nullptr)) {
|
|
node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" +
|
|
std::to_string(peer_out_anchor->GetIdx()));
|
|
} else {
|
|
// Add "" input
|
|
node_proto->add_input("");
|
|
}
|
|
}
|
|
|
|
// 2. Add input by in control edge
|
|
auto in_control_anchor = node->GetInControlAnchor();
|
|
if (in_control_anchor != nullptr) {
|
|
auto peer_out_anchors = in_control_anchor->GetPeerOutControlAnchors();
|
|
for (const auto &peer_out_anchor : peer_out_anchors) {
|
|
if (peer_out_anchor->GetOwnerNode()) {
|
|
node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + kControlAnchorIndex);
|
|
}
|
|
}
|
|
} else {
|
|
GELOGE(FAILED, "Incontrol anchor is nullptr");
|
|
return false;
|
|
}
|
|
|
|
// 3. Add output for Netron visual support
|
|
EncodeNodeLinkForNetronVisual(node, node_proto);
|
|
return true;
|
|
}
|
|
|
|
bool OnnxUtils::EncodeNode(const NodePtr &node, onnx::NodeProto *node_proto) {
|
|
if ((node == nullptr) || (node_proto == nullptr)) {
|
|
GELOGE(GRAPH_FAILED, "EncodeNode: Input Para Node Invalid");
|
|
return false;
|
|
}
|
|
// 1. Encode name and type
|
|
node_proto->set_name(node->GetName());
|
|
/// Netron believes that some operators, such as the activation operator of softplus, only have one input,
|
|
/// while the link relation of control anchor may exist in ge, resulting in two inputs. Therefore, "ge:" prefix
|
|
/// is added to correctly display the link relation at the expense of some color features
|
|
node_proto->set_op_type("ge:" + node->GetType());
|
|
|
|
if (kDumpLevel != DUMP_WITH_OUT_DESC) {
|
|
// 2.for attr
|
|
if (!EncodeNodeDesc(node, node_proto)) {
|
|
GELOGE(GRAPH_FAILED, "Encode NodeDesc: %s failed", node->GetName().c_str());
|
|
return false;
|
|
}
|
|
}
|
|
// 3.for link info
|
|
return EncodeNodeLink(node, node_proto);
|
|
}
|
|
|
|
void OnnxUtils::EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_Tensor *tensor_type) {
|
|
if ((node == nullptr) || (tensor_type == nullptr)) {
|
|
GELOGE(GRAPH_FAILED, "EncodeTypeProtoTensorType: Input Para Node or tensor_type Invalid");
|
|
return;
|
|
}
|
|
const auto &op_desc = node->GetOpDesc();
|
|
if (op_desc != nullptr) {
|
|
uint32_t size_out = static_cast<uint32_t>(op_desc->GetOutputsSize());
|
|
if (size_out > 0) {
|
|
for (uint32_t i = 0; i < size_out; i++) {
|
|
const ConstGeTensorDescPtr &ge_tensor = op_desc->GetOutputDescPtr(i);
|
|
if (ge_tensor != nullptr) {
|
|
auto ge_data_type = ge_tensor->GetDataType();
|
|
auto onnx_data_type = EncodeDataType(ge_data_type);
|
|
tensor_type->set_elem_type(onnx_data_type);
|
|
onnx::TensorShapeProto *shape = tensor_type->mutable_shape();
|
|
if (shape != nullptr) {
|
|
for (auto d : ge_tensor->GetShape().GetDims()) {
|
|
auto dim = shape->add_dim();
|
|
dim->set_dim_value(d);
|
|
}
|
|
} else {
|
|
GELOGW("Shape is nullptr");
|
|
continue;
|
|
}
|
|
} else {
|
|
GELOGW("Ge tensor is nullptr");
|
|
continue;
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
GELOGW("OpDesc Is Empty, nodeName %s nodeType %s", node->GetName().c_str(), node->GetType().c_str());
|
|
return;
|
|
}
|
|
}
|
|
|
|
void OnnxUtils::EncodeValueInfo(const NodePtr &node, onnx::ValueInfoProto *value_info_proto) {
|
|
if ((node == nullptr) || (value_info_proto == nullptr)) {
|
|
GELOGE(GRAPH_FAILED, "EncodeValueInfo: Input Para Node or value_info_proto Invalid");
|
|
return;
|
|
}
|
|
value_info_proto->set_name(node->GetName());
|
|
onnx::TypeProto *t = value_info_proto->mutable_type();
|
|
onnx::TypeProto_Tensor *tensor_type = t->mutable_tensor_type();
|
|
EncodeTypeProtoTensorType(node, tensor_type);
|
|
}
|
|
|
|
bool OnnxUtils::EncodeGraph(const ConstComputeGraphPtr &graph, onnx::GraphProto *graph_proto) {
|
|
if ((graph == nullptr) || (graph_proto == nullptr)) {
|
|
GELOGE(GRAPH_FAILED, "EncodeGraph: Input para Invalid");
|
|
return false;
|
|
}
|
|
graph_proto->set_name(graph->GetName());
|
|
// 1. Add graph inputs
|
|
for (const auto &input : graph->GetInputNodes()) {
|
|
auto value_info_proto = graph_proto->add_input();
|
|
EncodeValueInfo(input, value_info_proto);
|
|
}
|
|
// 2. Add graph outputs
|
|
for (const auto &output : graph->GetOutputNodes()) {
|
|
auto value_info_proto = graph_proto->add_output();
|
|
EncodeValueInfo(output, value_info_proto);
|
|
}
|
|
// 3. Add nodes
|
|
for (const auto &node : graph->GetDirectNode()) {
|
|
if (!EncodeNode(node, graph_proto->add_node())) {
|
|
GELOGW("EncodeNode failed");
|
|
continue;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
/// Convert a GE model into an ONNX ModelProto.
///
/// Version, IR version and producer name are copied, then the root compute graph is encoded. Each
/// non-null subgraph is represented in the root graph by an extra node of type kNodeTypeForSubgraph whose
/// "graph" attribute holds the encoded subgraph. Subgraph failures are logged as warnings and skipped.
///
/// @param model       Source model.
/// @param model_proto Destination model.
/// @return false when the compute graph or the destination graph is unavailable, or when encoding the
///         root graph fails; true otherwise.
bool OnnxUtils::ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelProto &model_proto) {
  model_proto.set_model_version(model.GetVersion());
  model_proto.set_ir_version(onnx::IR_VERSION);
  model_proto.set_producer_name(model.GetName());

  auto &graph = model.graph_;
  auto root_graph = GraphUtils::GetComputeGraph(graph);
  if (root_graph == nullptr) {
    GELOGE(GRAPH_FAILED, "GetComputeGraph: return nullptr");
    return false;
  }
  auto root_proto = model_proto.mutable_graph();
  if (root_proto == nullptr) {
    GELOGE(GRAPH_FAILED, "mutable_graph: %s return nullptr", root_graph->GetName().c_str());
    return false;
  }
  if (!EncodeGraph(root_graph, root_proto)) {
    GELOGE(GRAPH_FAILED, "EncodeGraph: %s fail", root_graph->GetName().c_str());
    return false;
  }

  // For subgraphs: a subgraph is represented by a node
  for (const auto &subgraph : root_graph->GetAllSubgraphs()) {
    if (subgraph == nullptr) {
      GELOGW("Graph: %s subgraph is nullptr, skip EncodeGraph", root_graph->GetName().c_str());
      continue;
    }
    auto holder_node = root_proto->add_node();
    if (holder_node == nullptr) {
      GELOGW("Node proto is nullptr");
      continue;
    }
    holder_node->set_name(subgraph->GetName());
    holder_node->set_op_type(kNodeTypeForSubgraph);

    auto graph_attr = holder_node->add_attribute();
    graph_attr->set_name("graph");
    graph_attr->set_type(onnx::AttributeProto_AttributeType_GRAPH);
    auto subgraph_proto = graph_attr->mutable_g();
    if (subgraph_proto == nullptr) {
      GELOGW("Sub graph proto is nullptr");
      continue;
    }
    if (!EncodeGraph(subgraph, subgraph_proto)) {
      GELOGW("Encode sub graph: %s fail", subgraph->GetName().c_str());
    }
  }
  return true;
}
|
|
|
|
// Part 2: convert from ONNX Protobuf to IR
|
|
static std::map<onnx::TensorProto_DataType, ge::DataType> onnxDataTypeToGeMap = {
|
|
{onnx::TensorProto_DataType_INT64, DT_INT64}, {onnx::TensorProto_DataType_UINT64, DT_UINT64},
|
|
{onnx::TensorProto_DataType_FLOAT, DT_FLOAT}, {onnx::TensorProto_DataType_INT32, DT_INT32},
|
|
{onnx::TensorProto_DataType_UINT32, DT_UINT32}, {onnx::TensorProto_DataType_INT8, DT_INT8},
|
|
{onnx::TensorProto_DataType_UINT8, DT_UINT8}, {onnx::TensorProto_DataType_INT16, DT_INT16},
|
|
{onnx::TensorProto_DataType_UINT16, DT_UINT16}, {onnx::TensorProto_DataType_FLOAT16, DT_FLOAT16},
|
|
{onnx::TensorProto_DataType_DOUBLE, DT_DOUBLE}, {onnx::TensorProto_DataType_BOOL, DT_BOOL},
|
|
};
|
|
|
|
// Translates an ONNX tensor element type into the matching GE data type via
// onnxDataTypeToGeMap. A type with no table entry is reported with a warning
// and mapped to ge::DT_UNDEFINED.
ge::DataType OnnxUtils::DecodeDataType(onnx::TensorProto_DataType data_type) {
  const auto found = onnxDataTypeToGeMap.find(data_type);
  if (found == onnxDataTypeToGeMap.end()) {
    GELOGW("DecodeDataType: datatype not support %u", data_type);
    return ge::DT_UNDEFINED;
  }
  return found->second;
}
|
|
|
|
// Splits a "<name>:<index>" string at its LAST ':' into a name and a decimal
// index.
//
// @param node_name_index  text to split, e.g. "conv1:0" or "conv1:-1".
// @param node_name        receives everything before the last ':'.
// @param index            receives the parsed index.
// @return true on success. false when there is no ':' or when the text after
//         it is not a complete decimal integer that fits in int32_t (empty,
//         trailing characters, out of range).
//
// On failure neither output parameter is modified; callers in this file rely
// on `index` keeping its initial value in that case.
bool OnnxUtils::ParseNameIndex(const std::string &node_name_index, std::string &node_name, int32_t &index) {
  const auto sep = node_name_index.rfind(':');
  if (sep == std::string::npos) {
    return false;
  }

  const std::string index_str = node_name_index.substr(sep + 1);
  char *parse_end = nullptr;
  const auto parsed = std::strtol(index_str.c_str(), &parse_end, 10);
  // strtol must have consumed the whole suffix; this also rejects an empty
  // suffix and one containing an embedded NUL.
  if ((parse_end == index_str.c_str()) || (parse_end != index_str.c_str() + index_str.size())) {
    return false;
  }
  // Reject values that would be silently truncated by the narrowing cast.
  if ((parsed < INT32_MIN) || (parsed > INT32_MAX)) {
    return false;
  }

  node_name = node_name_index.substr(0, sep);
  index = static_cast<int32_t>(parsed);
  return true;
}
|
|
|
|
bool OnnxUtils::DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr) {
|
|
if (node_ptr == nullptr) {
|
|
GELOGE(GRAPH_FAILED, "DecodeNodeLinkImp: node_ptr is nullptr");
|
|
return false;
|
|
}
|
|
// Data edge
|
|
if (item.src_out_index >= 0) {
|
|
auto src_anchor = node_ptr->GetOutDataAnchor(item.src_out_index);
|
|
auto dst_anchor = item.dst_node->GetInDataAnchor(item.dst_in_index);
|
|
if ((src_anchor == nullptr) || (dst_anchor == nullptr)) {
|
|
GELOGE(GRAPH_FAILED, "Get data anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index,
|
|
item.dst_node_name.c_str(), item.dst_in_index);
|
|
return false;
|
|
}
|
|
if (src_anchor->LinkTo(dst_anchor) != GRAPH_SUCCESS) {
|
|
GELOGE(GRAPH_FAILED, "Data Anchor: src_anchor->LinkTo(dst_anchor) failed");
|
|
return false;
|
|
}
|
|
// Control edge
|
|
} else {
|
|
auto src_anchor = node_ptr->GetOutControlAnchor();
|
|
auto dst_anchor = item.dst_node->GetInControlAnchor();
|
|
if ((src_anchor == nullptr) || (dst_anchor == nullptr)) {
|
|
GELOGE(GRAPH_FAILED, "Get control anchor failed %s:%d, %s:%d ", item.src_node_name.c_str(), item.src_out_index,
|
|
item.dst_node_name.c_str(), item.dst_in_index);
|
|
return false;
|
|
}
|
|
if (src_anchor->LinkTo(dst_anchor) != GRAPH_SUCCESS) {
|
|
GELOGE(GRAPH_FAILED, "Control Anchor: src_anchor->LinkTo(dst_anchor) failed");
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// Rebuilds the edges of a decoded graph.
//
// For every node proto, each entry of its input list is parsed as
// "<source node name>:<source output index>" and linked to the node through
// DecodeNodeLinkImp. Entries that ParseNameIndex rejects create no edge.
// The destination input slot advances after every entry whose index is
// non-negative, including rejected entries, whose index stays at its initial
// value 0. Entries with a negative index (control edges) do not advance it.
//
// @param node_proto_vector  node protos whose inputs are to be linked.
// @param node_map           node name -> already created graph node.
// @return false (after logging) when a destination or source node is missing
//         or null, or when creating a link fails; true otherwise.
bool OnnxUtils::DecodeNodeLink(const std::vector<onnx::NodeProto> &node_proto_vector,
                               const std::map<std::string, NodePtr> &node_map) {
  for (const auto &node_proto : node_proto_vector) {
    const auto &dst_name = node_proto.name();
    const auto dst_it = node_map.find(dst_name);
    if ((dst_it == node_map.end()) || (dst_it->second == nullptr)) {
      GELOGE(GRAPH_FAILED, "destination node: %s find failed or is nullptr", dst_name.c_str());
      return false;
    }

    int32_t next_data_slot = 0;
    for (const auto &input_ref : node_proto.input()) {
      std::string src_name;
      int32_t src_out_index = 0;
      if (ParseNameIndex(input_ref, src_name, src_out_index)) {
        const auto src_it = node_map.find(src_name);
        if (src_it == node_map.end()) {
          GELOGE(GRAPH_FAILED, "find src node: %s failed", src_name.c_str());
          return false;
        }
        auto src_node = src_it->second;
        if (src_node == nullptr) {
          GELOGE(GRAPH_FAILED, "src node: %s is nullptr", src_name.c_str());
          return false;
        }
        auto link = NodeLinkInfo{src_name, src_out_index, dst_it->second, next_data_slot, node_proto.name()};
        if (!DecodeNodeLinkImp(link, src_node)) {
          GELOGE(GRAPH_FAILED, "DecodeNodeLinkImp node: %s failed", src_name.c_str());
          return false;
        }
      }
      // Only non-negative indices consume a destination input slot.
      if (src_out_index >= 0) {
        ++next_data_slot;
      }
    }
  }
  return true;
}
|
|
|
|
// Appends every string stored in a STRINGS attribute to `strings`.
// If the attribute has any other type, an error is logged and `strings` is
// left unchanged.
void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector<std::string> &strings) {
  if (attr_proto.type() != onnx::AttributeProto_AttributeType_STRINGS) {
    GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str());
    return;
  }
  for (const auto &entry : attr_proto.strings()) {
    strings.push_back(entry);
  }
}
|
|
|
|
// Copies the payload of a STRING attribute into `value`.
// If the attribute has any other type, an error is logged and `value` is
// left unchanged.
void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, std::string &value) {
  const bool is_string = (attr_proto.type() == onnx::AttributeProto_AttributeType_STRING);
  if (!is_string) {
    GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str());
    return;
  }
  value = attr_proto.s();
}
|
|
|
|
// Appends every integer stored in an INTS attribute to `ints`.
// If the attribute has any other type, an error is logged and `ints` is
// left unchanged.
void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector<int64_t> &ints) {
  if (attr_proto.type() != onnx::AttributeProto_AttributeType_INTS) {
    GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str());
    return;
  }
  for (const auto entry : attr_proto.ints()) {
    ints.push_back(entry);
  }
}
|
|
|
|
// Copies the payload of an INT attribute into `value`.
// If the attribute has any other type, an error is logged and `value` is
// left unchanged.
void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, int64_t &value) {
  const bool is_int = (attr_proto.type() == onnx::AttributeProto_AttributeType_INT);
  if (!is_int) {
    GELOGE(GRAPH_FAILED, "Attribute %s call wrong decode attribute function", attr_proto.name().c_str());
    return;
  }
  value = attr_proto.i();
}
|
|
|
|
// Applies one decoded "input_desc_*" attribute to input tensor desc `index`
// of `op_desc`.
//
// @param attr_proto                attribute carrying the value.
// @param attr_name_for_input_desc  attribute name with the ":<index>" suffix
//                                  already stripped, e.g. "input_desc_shape".
// @param index                     position of the input desc to update.
// @param op_desc                   op description to modify.
//
// Handled names: input_desc_dtype, input_desc_shape, input_desc_layout,
// input_desc_origin_shape, input_desc_origin_layout, input_desc_size and
// input_desc_data_offset. Any other name is ignored. Nothing is changed (and
// an error is logged) when `op_desc` is null, the input desc does not exist,
// or the desc has no underlying proto message for the size/offset fields.
void OnnxUtils::DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto,
                                               const std::string &attr_name_for_input_desc, int32_t index,
                                               OpDescPtr &op_desc) {
  if (op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "DecodeNodeAttributeForOpInDesc: op_desc is nullptr");
    return;
  }
  // Look the desc up once instead of once per use.
  auto input_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(index));
  if (input_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableInputDesc(static_cast<uint32_t>(index)) is nullptr",
           op_desc->GetName().c_str(), attr_name_for_input_desc.c_str());
    return;
  }

  if (attr_name_for_input_desc == "input_desc_dtype") {
    auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s());
    input_desc->SetDataType(data_type);
  } else if (attr_name_for_input_desc == "input_desc_shape") {
    std::vector<std::int64_t> ints;
    DecodeAttribute(attr_proto, ints);
    GeShape ge_shape(ints);
    input_desc->SetShape(ge_shape);
  } else if (attr_name_for_input_desc == "input_desc_layout") {
    auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s());
    input_desc->SetFormat(data_format);
  } else if (attr_name_for_input_desc == "input_desc_origin_shape") {
    std::vector<std::int64_t> ints;
    DecodeAttribute(attr_proto, ints);
    GeShape ge_shape(ints);
    input_desc->SetOriginShape(ge_shape);
  } else if (attr_name_for_input_desc == "input_desc_origin_layout") {
    auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s());
    input_desc->SetOriginFormat(data_format);
  } else if ((attr_name_for_input_desc == "input_desc_size") ||
             (attr_name_for_input_desc == "input_desc_data_offset")) {
    // Both fields live directly on the tensor descriptor proto message.
    auto tensor_descriptor = input_desc->tensor_descriptor_.GetProtoMsg();
    if (tensor_descriptor == nullptr) {
      GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]input tensor descriptor proto msg is nullptr",
             op_desc->GetName().c_str(), attr_name_for_input_desc.c_str());
      return;
    }
    int64_t value = 0;
    DecodeAttribute(attr_proto, value);
    if (attr_name_for_input_desc == "input_desc_size") {
      tensor_descriptor->set_size(value);
    } else {
      tensor_descriptor->set_data_offset(value);
    }
  }
}
|
|
|
|
// Applies one decoded "output_desc_*" attribute to output tensor desc `index`
// of `op_desc`.
//
// @param attr_proto                 attribute carrying the value.
// @param attr_name_for_output_desc  attribute name with the ":<index>" suffix
//                                   already stripped, e.g. "output_desc_shape".
// @param index                      position of the output desc to update.
// @param op_desc                    op description to modify.
//
// Handled names: output_desc_dtype, output_desc_shape, output_desc_layout,
// output_desc_origin_shape, output_desc_origin_layout, output_desc_size and
// output_desc_data_offset. Any other name is ignored. Nothing is changed (and
// an error is logged) when `op_desc` is null, the output desc does not exist,
// or the desc has no underlying proto message for the size/offset fields.
void OnnxUtils::DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto,
                                                const std::string &attr_name_for_output_desc, int32_t index,
                                                OpDescPtr &op_desc) {
  if (op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "DecodeNodeAttributeForOpOutDesc: op_desc is nullptr");
    return;
  }
  // Look the desc up once instead of once per use.
  auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(index));
  if (output_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableOutputDesc(static_cast<uint32_t>(index)) is nullptr",
           op_desc->GetName().c_str(), attr_name_for_output_desc.c_str());
    return;
  }

  if (attr_name_for_output_desc == "output_desc_dtype") {
    auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s());
    output_desc->SetDataType(data_type);
  } else if (attr_name_for_output_desc == "output_desc_shape") {
    std::vector<std::int64_t> ints;
    DecodeAttribute(attr_proto, ints);
    GeShape ge_shape(ints);
    output_desc->SetShape(ge_shape);
  } else if (attr_name_for_output_desc == "output_desc_layout") {
    auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s());
    output_desc->SetFormat(data_format);
  } else if (attr_name_for_output_desc == "output_desc_origin_shape") {
    std::vector<std::int64_t> ints;
    DecodeAttribute(attr_proto, ints);
    GeShape ge_shape(ints);
    output_desc->SetOriginShape(ge_shape);
  } else if (attr_name_for_output_desc == "output_desc_origin_layout") {
    auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s());
    output_desc->SetOriginFormat(data_format);
  } else if ((attr_name_for_output_desc == "output_desc_size") ||
             (attr_name_for_output_desc == "output_desc_data_offset")) {
    // Both fields live directly on the tensor descriptor proto message.
    auto tensor_descriptor = output_desc->tensor_descriptor_.GetProtoMsg();
    if (tensor_descriptor == nullptr) {
      GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]output tensor descriptor proto msg is nullptr",
             op_desc->GetName().c_str(), attr_name_for_output_desc.c_str());
      return;
    }
    int64_t value = 0;
    DecodeAttribute(attr_proto, value);
    if (attr_name_for_output_desc == "output_desc_size") {
      tensor_descriptor->set_size(value);
    } else {
      tensor_descriptor->set_data_offset(value);
    }
  }
}
|
|
|
|
// Routes an indexed tensor-desc attribute to the input or the output decoder,
// depending on whether its name begins with "input" or "output". Names with
// neither prefix are ignored. Logs an error and does nothing when `op_desc`
// is null.
void OnnxUtils::DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto,
                                                     const std::string &attr_name_for_input_output_desc, int32_t index,
                                                     OpDescPtr &op_desc) {
  if (op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "op_desc is nullptr");
    return;
  }

  const bool targets_input = (attr_name_for_input_output_desc.compare(0, kInputPrefixLength, "input") == 0);
  if (targets_input) {
    DecodeNodeAttributeForOpInDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc);
    return;
  }

  const bool targets_output = (attr_name_for_input_output_desc.compare(0, kOutputPrefixLength, "output") == 0);
  if (targets_output) {
    DecodeNodeAttributeForOpOutDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc);
  }
}
|
|
|
|
// Stores `attr_proto` as an integer attribute in the attr map of `op_def`,
// keyed by the attribute's name.
//
// The value is read through the INT overload of DecodeAttribute, so an
// attribute of another type is stored as 0 (after that overload logs an
// error). The entry is added with insert(), so an existing entry with the
// same name is not replaced.
void OnnxUtils::DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def) {
  int64_t decoded = 0;
  DecodeAttribute(attr_proto, decoded);

  ge::proto::AttrDef attr_def;
  attr_def.set_i(decoded);

  auto target_map = op_def.mutable_attr();
  target_map->insert(AttrDefPair(attr_proto.name(), attr_def));
}
|
|
|
|
// Applies a single node attribute to `op_desc`.
//
// Attribute names of the form "<name>:<index>" (as accepted by
// ParseNameIndex) are treated as per-tensor input/output desc attributes and
// forwarded to DecodeNodeAttributeForOpInAndOutDesc. Every other name is
// matched against the op-level attributes handled here: id, stream_id,
// src_name, dst_name, src_index, dst_index, fusion_scope, input_i and
// output_i. Unknown names are ignored.
//
// Logs an error and returns without changes when `op_desc` is null, or when
// "fusion_scope" is seen but the op has no underlying OpDef proto message.
void OnnxUtils::DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc) {
  if (op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "DecodeNodeAttributeForOpDesc: op_desc is nullptr");
    return;
  }
  const auto &attr_name = attr_proto.name();
  std::string attr_name_for_input_output_desc;
  int32_t index = 0;
  if (ParseNameIndex(attr_name, attr_name_for_input_output_desc, index)) {
    // Update input and output desc
    DecodeNodeAttributeForOpInAndOutDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc);
    return;
  }

  if (attr_name == "id") {
    op_desc->SetId(attr_proto.i());
  } else if (attr_name == "stream_id") {
    op_desc->SetStreamId(attr_proto.i());
  } else if (attr_name == "src_name") {
    std::vector<std::string> strings;
    DecodeAttribute(attr_proto, strings);
    op_desc->SetSrcName(strings);
  } else if (attr_name == "dst_name") {
    std::vector<std::string> strings;
    DecodeAttribute(attr_proto, strings);
    op_desc->SetDstName(strings);
  } else if (attr_name == "src_index") {
    std::vector<std::int64_t> ints;
    DecodeAttribute(attr_proto, ints);
    op_desc->SetSrcIndex(ints);
  } else if (attr_name == "dst_index") {
    std::vector<std::int64_t> ints;
    DecodeAttribute(attr_proto, ints);
    op_desc->SetDstIndex(ints);
  } else if (attr_name == "fusion_scope") {
    // The proto message pointer is dereferenced, so make sure it exists first.
    auto op_def = op_desc->op_def_.GetProtoMsg();
    if (op_def == nullptr) {
      GELOGE(GRAPH_FAILED, "DecodeNodeAttributeForOpDesc: op_def of %s is nullptr", op_desc->GetName().c_str());
      return;
    }
    DecodeNodeAttributeForOpDef(attr_proto, *op_def);
  } else if (attr_name == "input_i") {
    std::vector<std::int64_t> ints;
    DecodeAttribute(attr_proto, ints);
    op_desc->SetInputOffset(ints);
  } else if (attr_name == "output_i") {
    std::vector<std::int64_t> ints;
    DecodeAttribute(attr_proto, ints);
    op_desc->SetOutputOffset(ints);
  }
}
|
|
|
|
// Populates `op_desc` from an ONNX node proto.
//
// Steps:
//   1. Name is copied verbatim; the type is the part of op_type after the
//      first ':' (the prefix before it is dropped).
//   2. "input_desc_nums" / "output_desc_nums" attributes determine how many
//      empty input / output tensor descs are added.
//   3. Every attribute is then applied via DecodeNodeAttributeForOpDesc.
//
// @return false when either pointer is null or op_type contains no ':';
//         true otherwise. A failure to add an individual desc is only logged.
bool OnnxUtils::DecodeNodeDesc(const onnx::NodeProto *node_proto, OpDescPtr &op_desc) {
  if (op_desc == nullptr || node_proto == nullptr) {
    GELOGE(GRAPH_FAILED, " Op_desc is nullptr or node_proto is nullptr");
    return false;
  }

  // 1. Decode node_proto name and type
  op_desc->SetName(node_proto->name());
  const auto &prefixed_type = node_proto->op_type();
  const auto colon_pos = prefixed_type.find(':');
  if (colon_pos == std::string::npos) {
    return false;
  }
  op_desc->SetType(prefixed_type.substr(colon_pos + 1));

  // 2. Add empty input and output desc
  for (const auto &attr : node_proto->attribute()) {
    const auto &name = attr.name();
    if (name == "input_desc_nums") {
      const auto input_count = attr.i();
      for (int64_t n = 0; n < input_count; ++n) {
        GeTensorDesc empty_desc;
        GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(empty_desc) == GRAPH_SUCCESS, continue, "Add inputdesc failed.");
      }
    } else if (name == "output_desc_nums") {
      const auto output_count = attr.i();
      for (int64_t n = 0; n < output_count; ++n) {
        GeTensorDesc empty_desc;
        GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(empty_desc) == GRAPH_SUCCESS, continue, "Add outputdesc failed.");
      }
    }
  }

  // 3. Decode node_proto attributes
  for (const auto &attr : node_proto->attribute()) {
    DecodeNodeAttributeForOpDesc(attr, op_desc);
  }
  return true;
}
|
|
|
|
// Rebuilds a ComputeGraph from an ONNX graph proto.
//
// @param recursion_depth  current nesting level; decoding is aborted once it
//                         exceeds kMaxRecursionDepth.
// @param graph_proto      graph to decode.
// @param graph            receives the newly created graph (it is assigned
//                         before decoding of the nodes starts, so it may hold
//                         a partially built graph when false is returned).
// @return false (after logging) on excessive recursion, allocation failure,
//         a malformed or undecodable subgraph node, a node desc that fails to
//         decode, a link failure, or a graph input/output naming an unknown
//         node; true otherwise.
//
// Nodes whose op_type is kNodeTypeForSubgraph are placeholders: their first
// attribute must be a GRAPH attribute, which is decoded recursively and added
// as a subgraph. All other nodes become real nodes of `graph`.
bool OnnxUtils::DecodeGraph(int recursion_depth, const onnx::GraphProto &graph_proto, ComputeGraphPtr &graph) {
  if (recursion_depth > kMaxRecursionDepth) {
    GELOGE(GRAPH_FAILED, "DecodeGraph: recursion depth is too large, abort");
    return false;
  }

  graph = ComGraphMakeShared<ge::ComputeGraph>(graph_proto.name());
  GE_CHK_BOOL_EXEC(graph != nullptr, return false, "ComputeGraph make shared failed");
  /// 1. Decode all nodes first, node should include input
  /// and output nodes and nodes which represent sub graphs
  std::map<std::string, NodePtr> node_map;
  std::vector<onnx::NodeProto> node_proto_vector;
  if (graph_proto.node_size() > 0) {
    node_proto_vector.reserve(static_cast<size_t>(graph_proto.node_size()));
  }
  for (const auto &node_proto : graph_proto.node()) {
    if (node_proto.op_type() == kNodeTypeForSubgraph) {
      // a. nodes represent sub graphs
      // Such a node is expected to carry exactly one attribute of type
      // AttributeProto_AttributeType_GRAPH. Indexing attribute(0) on a node
      // with no attributes is out of range, so reject it explicitly.
      if (node_proto.attribute_size() == 0) {
        GELOGE(GRAPH_FAILED, "Decode sub graph %s failed: node has no attribute", node_proto.name().c_str());
        return false;
      }
      ComputeGraphPtr compute_graph;
      const auto &node_attr = node_proto.attribute(0);
      if ((node_attr.type() == onnx::AttributeProto_AttributeType_GRAPH) &&
          DecodeGraph(recursion_depth + 1, node_attr.g(), compute_graph)) {
        (void)graph->AddSubGraph(compute_graph);
      } else {
        GELOGE(GRAPH_FAILED, "Decode sub graph %s failed with node type:%d", node_proto.name().c_str(),
               node_attr.type());
        return false;
      }
    } else {
      // b. direct nodes in graph
      node_proto_vector.push_back(node_proto);
      OpDescPtr op_desc = ComGraphMakeShared<OpDesc>();
      // b.1 For node desc (DecodeNodeDesc also rejects a null op_desc)
      if (!DecodeNodeDesc(&node_proto, op_desc)) {
        GELOGE(GRAPH_FAILED, "Decode node desc %s failed ", node_proto.name().c_str());
        return false;
      }
      auto node = graph->AddNode(op_desc);
      node_map.insert(std::make_pair(node_proto.name(), node));
    }
  }
  /// We get all nodes in graph here
  /// b.2 For node link
  if (!DecodeNodeLink(node_proto_vector, node_map)) {
    GELOGE(GRAPH_FAILED, "Decode node link failed");
    return false;
  }

  // 2. Add inputs nodes for graph
  for (const auto &input : graph_proto.input()) {
    const auto &input_node_name = input.name();
    auto input_node_item = node_map.find(input_node_name);
    if (input_node_item == node_map.end()) {
      GELOGE(GRAPH_FAILED, "cannot find graph's input node %s in node_", input_node_name.c_str());
      return false;
    }
    auto ret = graph->AddInputNode(input_node_item->second);
    GE_CHK_BOOL_EXEC(ret != nullptr, continue, "Add inputnode failed");
  }
  // 3. Add outputs nodes for graph
  for (const auto &output : graph_proto.output()) {
    const auto &output_node_name = output.name();
    auto output_node_item = node_map.find(output_node_name);
    if (output_node_item == node_map.end()) {
      GELOGE(GRAPH_FAILED, "cannot find graph's output node %s in node_", output_node_name.c_str());
      return false;
    }
    auto ret = graph->AddOutputNode(output_node_item->second);
    if (ret == nullptr) {
      GELOGW("Add outputnode failed,out put node is %s", output_node_name.c_str());
      continue;
    }
  }
  return true;
}
|
|
|
|
// Rebuilds a GE `model` from an ONNX ModelProto.
//
// The producer name becomes the model name and the model version is narrowed
// to uint32_t. The proto's graph is decoded into a compute graph (recursion
// depth starts at 0 for this top-level call) and wrapped into the model's
// graph.
//
// @return false when decoding the graph fails; name and version have already
//         been written to `model` at that point. true otherwise.
bool OnnxUtils::ConvertModelProtoToGeModel(const onnx::ModelProto &model_proto, ge::Model &model) {
  model.name_ = model_proto.producer_name();
  model.version_ = static_cast<uint32_t>(model_proto.model_version());

  ComputeGraphPtr decoded_graph;
  const bool decoded = DecodeGraph(0, model_proto.graph(), decoded_graph);
  if (!decoded) {
    GELOGE(GRAPH_FAILED, "Decode compute graph from graph_proto failed");
    return false;
  }

  model.graph_ = GraphUtils::CreateGraphFromComputeGraph(decoded_graph);
  return true;
}
|
|
} // namespace ge
|