You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
138 lines
6.2 KiB
138 lines
6.2 KiB
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
#ifndef GE_GRAPH_PASSES_TRANSOP_WITHOUT_RESHAPE_FUSION_PASS_H_
|
|
#define GE_GRAPH_PASSES_TRANSOP_WITHOUT_RESHAPE_FUSION_PASS_H_
|
|
|
|
#include <vector>
|
|
#include <utility>
|
|
#include "inc/graph_pass.h"
|
|
|
|
namespace ge {
|
|
///
|
|
/// Transform operators depth fusion
|
|
///
|
|
class TransOpWithoutReshapeFusionPass : public GraphPass {
|
|
public:
|
|
TransOpWithoutReshapeFusionPass() {}
|
|
virtual ~TransOpWithoutReshapeFusionPass() {}
|
|
|
|
graphStatus Run(ge::ComputeGraphPtr graph) override;
|
|
|
|
private:
|
|
void SetRemainNode(const vector<pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_anchor);
|
|
bool FormatContinuousCheck(const OutDataAnchorPtr &out_anchor, const InDataAnchorPtr &in_anchor);
|
|
void RemoveNousedNodes(const ComputeGraphPtr &graph);
|
|
void GetBeginOutDescAndEndInDesc(const int index, GeTensorDesc &out_desc, GeTensorDesc &in_desc);
|
|
|
|
void GetFormatTransferDesc(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc,
|
|
GeTensorDesc &format_transfer_input, GeTensorDesc &format_transfer_output);
|
|
|
|
void GetCastOpDesc(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc, GeTensorDesc &cast_input,
|
|
GeTensorDesc &cast_output);
|
|
|
|
graphStatus FormatFusion(const int index, OpDescPtr &format_transfer_op, int32_t &fusion_op_count,
|
|
bool &fusion_continue);
|
|
|
|
graphStatus DataTypeFusion(const int index, OpDescPtr &cast_op, int32_t &fusion_op_count);
|
|
|
|
void GetOutDataPeerInControlAnchors(const size_t index,
|
|
vector<vector<InControlAnchorPtr>> &out_data_peer_in_control_anchors);
|
|
|
|
void GetInControlPeerOutControlAnchors(const size_t index,
|
|
vector<vector<OutControlAnchorPtr>> &in_control_peer_out_control_anchors);
|
|
|
|
void GetOutControlPeerAnchors(const size_t index,
|
|
vector<vector<InControlAnchorPtr>> &out_control_peer_in_control_anchors,
|
|
vector<vector<InDataAnchorPtr>> &out_control_peer_in_data_anchors);
|
|
|
|
graphStatus TransOpFuse(const ComputeGraphPtr &graph);
|
|
|
|
bool OpAccuracyAbilityCheck(const OpDescPtr &op_desc);
|
|
|
|
graphStatus GetSubGraphsBetweenNormalNode(
|
|
const OutDataAnchorPtr &out_anchor, vector<vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>>> &sub_graphs_out,
|
|
vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_list);
|
|
|
|
graphStatus GetSubGraphNodesInfo();
|
|
|
|
void GetControlAnchors();
|
|
|
|
graphStatus InsertNewTransOp(const ComputeGraphPtr &graph, const OpDescPtr &cast_op,
|
|
const OpDescPtr &format_transfer_op, const int index, const bool insert_cast_first);
|
|
|
|
void EraseInvalidAnchorsPair();
|
|
|
|
graphStatus RelinkNodesWhenDescNotChanged(const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair,
|
|
const int index);
|
|
|
|
OpDescPtr GetFormatTransferOp(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc);
|
|
|
|
OpDescPtr GetCastOp(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc);
|
|
|
|
graphStatus TransOpFuseHandle(const ge::ComputeGraphPtr &graph, const int index);
|
|
|
|
graphStatus AddTransNode(const ComputeGraphPtr &graph, const OpDescPtr &transop, NodePtr &trans_node);
|
|
|
|
bool DescEqualCheck(ConstGeTensorDescPtr &desc_src, ConstGeTensorDescPtr &desc_dst) const;
|
|
|
|
bool ShapeEqualCheck(const GeShape &src, const GeShape &dst) const;
|
|
|
|
bool InsertCastFirstCheck(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc) const;
|
|
|
|
graphStatus RelinkControlEdge(const int index, const OutDataAnchorPtr &out_anchor,
|
|
const vector<NodePtr> &new_trans_nodes);
|
|
|
|
graphStatus GetTransNode(const ComputeGraphPtr &graph, const OpDescPtr &cast_op, const OpDescPtr &format_transfer_op,
|
|
const bool insert_cast_first, std::vector<NodePtr> &new_trans_nodes);
|
|
|
|
void UpdateOutputName(const OutDataAnchorPtr &out_anchor, const InDataAnchorPtr &old_peer_in_anchor,
|
|
const NodePtr &in_owner_node);
|
|
void UpdateInputName(const OutDataAnchorPtr &old_peer_out_anchor, const InDataAnchorPtr &in_anchor,
|
|
const NodePtr &out_owner_node);
|
|
|
|
graphStatus RelinkControlEdgesWhenDescNotChanged(const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair,
|
|
const int index);
|
|
|
|
graphStatus RelinkSubGraphControlEdges(const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
|
|
const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair,
|
|
const int index);
|
|
///
|
|
/// judge whether an operator is a transform op or not
|
|
/// @param node
|
|
/// @return True or False
|
|
///
|
|
static bool IsTransOp(const NodePtr &node);
|
|
|
|
static bool FusionFormatSupport(Format format);
|
|
|
|
vector<vector<pair<OutDataAnchorPtr, InDataAnchorPtr>>> sub_graph_anchors_;
|
|
vector<vector<NodePtr>> sub_graph_nodes_;
|
|
vector<int> transop_num_count_;
|
|
vector<bool> sub_graph_has_reshape_node_;
|
|
vector<vector<OutControlAnchorPtr>> in_control_peer_out_control_anchors_;
|
|
vector<vector<InControlAnchorPtr>> out_control_peer_in_control_anchors_;
|
|
vector<vector<InDataAnchorPtr>> out_control_peer_in_data_anchors_;
|
|
vector<vector<InControlAnchorPtr>> out_data_peer_in_control_anchors_;
|
|
vector<bool> sub_graph_has_control_edge_;
|
|
vector<bool> sub_graph_has_out_data_peer_in_control_edge_;
|
|
};
|
|
} // namespace ge
|
|
|
|
#endif // GE_GRAPH_PASSES_TRANSOP_WITHOUT_RESHAPE_FUSION_PASS_H_
|