You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
graphengine/metadef/graph/utils/ge_ir_utils.h

207 lines
7.8 KiB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef COMMON_GRAPH_UTILS_GE_IR_UTILS_H_
#define COMMON_GRAPH_UTILS_GE_IR_UTILS_H_
#include <google/protobuf/map.h>
#include <google/protobuf/repeated_field.h>
#include <google/protobuf/stubs/port.h>
#include <graph/anchor.h>
#include <graph/debug/ge_log.h>
#include <graph/debug/ge_util.h>
#include <graph/detail/attributes_holder.h>
#include <graph/ge_tensor.h>
#include <graph/graph.h>
#include <graph/model.h>
#include <graph/node.h>
#include <graph/utils/graph_utils.h>
#include <graph/utils/type_utils.h>
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "proto/ge_ir.pb.h"
#include "proto/onnx.pb.h"
namespace ge {
// Length of the ", " separator that the ToString helpers below place between elements.
constexpr int kOffsetToString = 2;
///
/// @ingroup ge_ir_utils
/// @brief Format a protobuf RepeatedField as "[e0, e1, ..., eN]".
/// @param [in] rpd_field repeated scalar field; each element is written with operator<<
/// @return "[]" for an empty field, otherwise the elements separated by ", " inside square brackets
///
template <typename T>
const std::string ToString(const google::protobuf::RepeatedField<T> &rpd_field) {
  std::stringstream ss;
  ss << "[";
  // Write the separator before every element except the first, so there is no trailing ", " to
  // trim. The previous trim used length() - 2, which underflows size_t for an empty field and
  // only produced "[]" because substr() clamps an oversized count.
  const char *separator = "";
  for (const T &x : rpd_field) {
    ss << separator << x;
    separator = ", ";
  }
  ss << "]";
  return ss.str();
}
///
/// @ingroup ge_ir_utils
/// @brief Format a protobuf RepeatedPtrField as "[e0, e1, ..., eN]".
/// @param [in] rpd_ptr_field repeated message/string field; each element is written with operator<<
/// @return "[]" for an empty field, otherwise the elements separated by ", " inside square brackets
///
template <typename T>
const std::string ToString(const google::protobuf::RepeatedPtrField<T> &rpd_ptr_field) {
  std::stringstream ss;
  ss << "[";
  // Write the separator before every element except the first, so there is no trailing ", " to
  // trim. The previous trim used length() - 2, which underflows size_t for an empty field and
  // only produced "[]" because substr() clamps an oversized count.
  const char *separator = "";
  for (const T &x : rpd_ptr_field) {
    ss << separator << x;
    separator = ", ";
  }
  ss << "]";
  return ss.str();
}
///
/// @ingroup ge_ir_utils
/// @brief Compare two values with operator== and log a GRAPH_FAILED error on mismatch.
/// @param [in] l_value left-hand value
/// @param [in] r_value right-hand value
/// @param [in] log_info_tag text inserted into the error message when the values differ
/// @return true if l_value == r_value, false otherwise
///
template <typename T>
bool IsEqual(const T &l_value, const T &r_value, const std::string &log_info_tag) {
  // static_cast mirrors the contextual bool conversion of the original `if (l_value == r_value)`.
  const bool matched = static_cast<bool>(l_value == r_value);
  if (!matched) {
    GELOGE(GRAPH_FAILED, "Check failed with %s", log_info_tag.c_str());
  }
  return matched;
}
/// Conversion between GE models (ge::Model) and ONNX protobuf models (onnx::ModelProto).
/// Only declarations appear in this header; the conversion logic is implemented elsewhere.
class OnnxUtils {
public:
/// Dump verbosity levels.
/// NOTE(review): the effect of each level is not visible in this header. Judging by the names,
/// DUMP_WITH_OUT_DATA / DUMP_WITH_OUT_DESC leave some content out of the dump, and DUMP_LEVEL_END
/// is presumably an end-of-range sentinel — confirm against the implementation and callers.
enum DumpLevel { NO_DUMP = 0, DUMP_ALL = 1, DUMP_WITH_OUT_DATA = 2, DUMP_WITH_OUT_DESC = 3, DUMP_LEVEL_END };
/// Convert a GE model into an ONNX ModelProto.
/// @param [in] model source GE model
/// @param [out] model_proto destination ONNX message
/// @return bool status; presumably true on success — confirm in the implementation
static bool ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelProto &model_proto);
/// Convert an ONNX ModelProto back into a GE model.
/// @param [in] model_proto source ONNX message
/// @param [out] model destination GE model
/// @return bool status; presumably true on success — confirm in the implementation
static bool ConvertModelProtoToGeModel(const onnx::ModelProto &model_proto, ge::Model &model);
private:
// Part 1: from IR convert to ONNX Protobuf
// AddAttrProto overloads: add an attribute `name` of kind `type` to node_proto. The value is passed
// either as an untyped pointer (pointee type presumably chosen by `type` — confirm in the
// implementation) or as a protobuf repeated field.
// NOTE(review): the repeated-field overloads take `data` by value, so each call copies the field;
// a const reference would avoid the copy but requires changing the definitions as well.
static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type,
const std::string &name, void *data);
static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type,
const std::string &name, ::google::protobuf::RepeatedField<::google::protobuf::int64> data);
static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type,
const std::string &name, ::google::protobuf::RepeatedField<bool> data);
static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type,
const std::string &name, ::google::protobuf::RepeatedField<float> data);
static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type,
const std::string &name, ::google::protobuf::RepeatedPtrField<::std::string> data);
// Add attributes describing `node` itself to node_proto (which members are used is decided by the
// implementation).
static void AddAttrProtoFromNodeMembers(const NodePtr &node, onnx::NodeProto *node_proto);
// Add one (attribute name, GeAttrValue) pair to node_proto.
static void AddAttrProtoFromAttribute(const std::pair<const std::string, ge::GeAttrValue> &string_attr_value,
onnx::NodeProto *node_proto);
// Add attributes for op_desc's input and output descriptions to node_proto.
static void AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc);
// Add the entries of an AttrDef map to node_proto. prefix/suffix default to empty and are presumably
// applied to the generated attribute names — confirm in the implementation.
static void AddAttrProtoForAttrsFromAttrMap(const ::google::protobuf::Map<std::string, ge::proto::AttrDef> &attr_map,
onnx::NodeProto *node_proto, const std::string &prefix = "",
const std::string &suffix = "");
// Add the attributes held by op_def to node_proto.
static void AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, onnx::NodeProto *node_proto);
// Map a GE data type to the corresponding ONNX tensor data type.
static onnx::TensorProto_DataType EncodeDataType(ge::DataType data_type);
// Encode node links in a form meant for the Netron viewer (per the name) — confirm the exact format.
static void EncodeNodeLinkForNetronVisual(const NodePtr &node, onnx::NodeProto *node_proto);
// Encode node's links / descriptor / the whole node into node_proto; each returns a bool status.
static bool EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto);
static bool EncodeNodeDesc(const NodePtr &node, onnx::NodeProto *node_proto);
static bool EncodeNode(const NodePtr &node, onnx::NodeProto *node_proto);
// Fill an ONNX tensor type from node.
static void EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_Tensor *tensor_type);
// Fill an ONNX ValueInfoProto from node n.
static void EncodeValueInfo(const NodePtr &n, onnx::ValueInfoProto *v);
// Encode a whole compute graph into graph_proto; returns a bool status.
static bool EncodeGraph(const ConstComputeGraphPtr &graph, onnx::GraphProto *graph_proto);
/// Part 2: from ONNX Protobuf convert to IR
/// Describes node's link relationships
/// One edge: output `src_out_index` of the node named `src_node_name` feeds input `dst_in_index`
/// of `dst_node` (whose name is also kept in `dst_node_name`).
/// NOTE(review): the int32_t members have no initializer; make sure every construction site sets them.
struct NodeLinkInfo {
std::string src_node_name;
int32_t src_out_index;
NodePtr dst_node;
int32_t dst_in_index;
std::string dst_node_name;
};
// Parse node name and index
// Split node_name_index into node_name and index; returns false presumably when the string cannot be
// parsed. The exact "<name><sep><index>" format is defined by the implementation — confirm there.
static bool ParseNameIndex(const std::string &node_name_index, std::string &node_name, int32_t &index);
// Map an ONNX tensor data type back to the GE data type.
static ge::DataType DecodeDataType(onnx::TensorProto_DataType data_type);
// DecodeAttribute overloads: read the value of attr_proto into a list of strings, a list of ints,
// a single int, or a single string.
static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector<std::string> &strings);
static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector<int64_t> &ints);
static void DecodeAttribute(const onnx::AttributeProto &attr_proto, int64_t &value);
static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::string &value);
// Apply attr_proto to op_desc's output description / input description / both at position `index`
// (presumably the inverse of AddAttrProtoForOpInAndOutDesc — confirm).
static void DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto,
const std::string &attr_name_for_output_desc, int32_t index,
OpDescPtr &op_desc);
static void DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto,
const std::string &attr_name_for_input_desc, int32_t index,
OpDescPtr &op_desc);
static void DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto,
const std::string &attr_name_for_input_output_desc, int32_t index,
OpDescPtr &op_desc);
// Apply attr_proto to an OpDef message or to an OpDesc.
static void DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def);
static void DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc);
// Restore the single edge described by item; returns a bool status.
static bool DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr);
// Restore edges for the nodes in node_proto_vector, resolving node names through node_map.
static bool DecodeNodeLink(const std::vector<onnx::NodeProto> &node_proto_vector,
const std::map<std::string, NodePtr> &node_map);
// Decode node_proto into an OpDesc; returns a bool status.
static bool DecodeNodeDesc(const onnx::NodeProto *node_proto, OpDescPtr &node);
// Decode graph_proto into graph. recursion_depth is presumably used to bound nested (sub)graph
// decoding — confirm in the implementation.
static bool DecodeGraph(int recursion_depth, const onnx::GraphProto &graph_proto, ComputeGraphPtr &graph);
};
} // namespace ge
#endif // COMMON_GRAPH_UTILS_GE_IR_UTILS_H_