/** * Copyright 2019-2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ #define COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include "proto/ge_ir.pb.h" #include "proto/onnx.pb.h" namespace ge { const int kOffsetToString = 2; /// /// @ingroup ge_ir_utils /// @brief RepeatedField->String /// @param [in] const rpd_field RepeatedField /// @return String /// template const std::string ToString(const google::protobuf::RepeatedField &rpd_field) { std::stringstream ss; ss << "["; for (const T &x : rpd_field) { ss << x; ss << ", "; } std::string str_ret = ss.str().substr(0, ss.str().length() - kOffsetToString); str_ret += "]"; return str_ret; } /// /// @ingroup ge_ir_utils /// @brief RepeatedPtrField->String /// @param [in] const rpd_field RepeatedPtrField /// @return String /// template const std::string ToString(const google::protobuf::RepeatedPtrField &rpd_ptr_field) { std::stringstream ss; ss << "["; for (const T &x : rpd_ptr_field) { ss << x; ss << ", "; } std::string str_ret = ss.str().substr(0, ss.str().length() - kOffsetToString); str_ret += "]"; return str_ret; } /// /// @ingroup ge_ir_utils /// @brief check, if not equal, log with tag /// @param [in] const left_value, right_value reference, log_info_tag /// @return bool /// template bool IsEqual(const T &l_value, const T &r_value, const std::string &log_info_tag) { if (l_value == r_value) { return true; } else { GELOGE(GRAPH_FAILED, "Check failed with %s", log_info_tag.c_str()); return false; } } class OnnxUtils { public: enum DumpLevel { NO_DUMP = 0, DUMP_ALL = 1, DUMP_WITH_OUT_DATA = 2, DUMP_WITH_OUT_DESC = 3, DUMP_LEVEL_END }; static bool ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelProto &model_proto); static bool ConvertModelProtoToGeModel(const onnx::ModelProto &model_proto, ge::Model &model); private: // Part 1: from IR convert to ONNX Protobuf static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const std::string &name, void *data); static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const std::string &name, ::google::protobuf::RepeatedField<::google::protobuf::int64> data); static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const std::string &name, ::google::protobuf::RepeatedField data); static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const std::string &name, ::google::protobuf::RepeatedField data); static void AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const std::string &name, ::google::protobuf::RepeatedPtrField<::std::string> data); static void AddAttrProtoFromNodeMembers(const NodePtr &node, onnx::NodeProto *node_proto); static void AddAttrProtoFromAttribute(const std::pair &string_attr_value, onnx::NodeProto *node_proto); static void AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc); static void AddAttrProtoForAttrsFromAttrMap(const ::google::protobuf::Map &attr_map, onnx::NodeProto *node_proto, const std::string &prefix = "", const std::string &suffix = ""); static void AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, onnx::NodeProto *node_proto); static onnx::TensorProto_DataType EncodeDataType(ge::DataType data_type); static void EncodeNodeLinkForNetronVisual(const NodePtr &node, onnx::NodeProto *node_proto); static bool EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto); static bool EncodeNodeDesc(const NodePtr &node, onnx::NodeProto *node_proto); static bool EncodeNode(const NodePtr &node, onnx::NodeProto *node_proto); static void EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_Tensor *tensor_type); static void EncodeValueInfo(const NodePtr &n, onnx::ValueInfoProto *v); static bool EncodeGraph(const ConstComputeGraphPtr &graph, onnx::GraphProto *graph_proto); /// Part 2: from ONNX Protobuf convert to IR /// Describes node's link relationships struct NodeLinkInfo { std::string src_node_name; int32_t src_out_index; NodePtr dst_node; int32_t dst_in_index; std::string dst_node_name; }; // Parse node name and index static bool ParseNameIndex(const std::string &node_name_index, std::string &node_name, int32_t &index); static ge::DataType DecodeDataType(onnx::TensorProto_DataType data_type); static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector &strings); static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::vector &ints); static void DecodeAttribute(const onnx::AttributeProto &attr_proto, int64_t &value); static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::string &value); static void DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto, const std::string &attr_name_for_output_desc, int32_t index, OpDescPtr &op_desc); static void DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto, const std::string &attr_name_for_input_desc, int32_t index, OpDescPtr &op_desc); static void DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, const std::string &attr_name_for_input_output_desc, int32_t index, OpDescPtr &op_desc); static void DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def); static void DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc); static bool DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr); static bool DecodeNodeLink(const std::vector &node_proto_vector, const std::map &node_map); static bool DecodeNodeDesc(const onnx::NodeProto *node_proto, OpDescPtr &node); static bool DecodeGraph(int recursion_depth, const onnx::GraphProto &graph_proto, ComputeGraphPtr &graph); }; } // namespace ge #endif // COMMON_GRAPH_UTILS_GE_IR_UTILS_H_