/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef GE_GRAPH_PASSES_TRANSOP_WITHOUT_RESHAPE_FUSION_PASS_H_ #define GE_GRAPH_PASSES_TRANSOP_WITHOUT_RESHAPE_FUSION_PASS_H_ #include #include #include "inc/graph_pass.h" namespace ge { /// /// Transform operators depth fusion /// class TransOpWithoutReshapeFusionPass : public GraphPass { public: TransOpWithoutReshapeFusionPass() {} virtual ~TransOpWithoutReshapeFusionPass() {} graphStatus Run(ge::ComputeGraphPtr graph) override; private: void SetRemainNode(const vector> &nodes_anchor); bool FormatContinuousCheck(const OutDataAnchorPtr &out_anchor, const InDataAnchorPtr &in_anchor); void RemoveNousedNodes(const ComputeGraphPtr &graph); void GetBeginOutDescAndEndInDesc(const int index, GeTensorDesc &out_desc, GeTensorDesc &in_desc); void GetFormatTransferDesc(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc, GeTensorDesc &format_transfer_input, GeTensorDesc &format_transfer_output); void GetCastOpDesc(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc, GeTensorDesc &cast_input, GeTensorDesc &cast_output); graphStatus FormatFusion(const int index, OpDescPtr &format_transfer_op, int32_t &fusion_op_count, bool &fusion_continue); graphStatus DataTypeFusion(const int index, OpDescPtr &cast_op, int32_t &fusion_op_count); void GetOutDataPeerInControlAnchors(const size_t index, vector> &out_data_peer_in_control_anchors); void GetInControlPeerOutControlAnchors( const size_t index, vector> &in_control_peer_out_control_anchors); void GetOutControlPeerAnchors( const size_t index, vector> &out_control_peer_in_control_anchors, vector> &out_control_peer_in_data_anchors); graphStatus TransOpFuse(const ComputeGraphPtr &graph); bool OpAccuracyAbilityCheck(const OpDescPtr &op_desc); graphStatus GetSubGraphsBetweenNormalNode( const OutDataAnchorPtr &out_anchor, vector> >& sub_graphs_out, vector> &nodes_list ); graphStatus GetSubGraphNodesInfo(); void GetControlAnchors(); graphStatus InsertNewTransOp(const ComputeGraphPtr &graph, const OpDescPtr &cast_op, const OpDescPtr &format_transfer_op, const int index, const bool insert_cast_first); void EraseInvalidAnchorsPair(); graphStatus RelinkNodesWhenDescNotChanged(const pair &begin_anchors_pair, const pair &end_anchors_pair, const int index); OpDescPtr GetFormatTransferOp(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc); OpDescPtr GetCastOp(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc); graphStatus TransOpFuseHandle(const ge::ComputeGraphPtr &graph, const int index); graphStatus AddTransNode(const ComputeGraphPtr &graph, const OpDescPtr &transop, NodePtr &trans_node); bool DescEqualCheck(ConstGeTensorDescPtr &desc_src, ConstGeTensorDescPtr &desc_dst) const; bool ShapeEqualCheck(const GeShape &src, const GeShape &dst) const; bool InsertCastFirstCheck(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc) const; graphStatus RelinkControlEdge(const int index, const OutDataAnchorPtr &out_anchor, const vector &new_trans_nodes); graphStatus GetTransNode(const ComputeGraphPtr &graph, const OpDescPtr &cast_op, const OpDescPtr &format_transfer_op, const bool insert_cast_first, std::vector &new_trans_nodes); void UpdateOutputName(const OutDataAnchorPtr &out_anchor, const InDataAnchorPtr &old_peer_in_anchor, const NodePtr &in_owner_node); void UpdateInputName(const OutDataAnchorPtr &old_peer_out_anchor, const InDataAnchorPtr &in_anchor, const NodePtr &out_owner_node); graphStatus RelinkControlEdgesWhenDescNotChanged(const pair &begin_anchors_pair, const pair &end_anchors_pair, const int index); graphStatus RelinkSubGraphControlEdges(const pair &begin_anchors_pair, const pair &end_anchors_pair, const int index); /// /// judge whether an operator is a transform op or not /// @param node /// @return True or False /// static bool IsTransOp(const NodePtr &node); static bool FusionFormatSupport(Format format); vector>> sub_graph_anchors_; vector> sub_graph_nodes_; vector transop_num_count_; vector sub_graph_has_reshape_node_; vector> in_control_peer_out_control_anchors_; vector> out_control_peer_in_control_anchors_; vector> out_control_peer_in_data_anchors_; vector> out_data_peer_in_control_anchors_; vector sub_graph_has_control_edge_; vector sub_graph_has_out_data_peer_in_control_edge_; }; } // namespace ge #endif // GE_GRAPH_PASSES_TRANSOP_WITHOUT_RESHAPE_FUSION_PASS_H_