You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
graphengine/ge/graph/passes/transop_without_reshape_fus...

154 lines
6.4 KiB

5 years ago
/**
* Copyright 2020 Huawei Technologies Co., Ltd
5 years ago
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_GRAPH_PASSES_TRANSOP_WITHOUT_RESHAPE_FUSION_PASS_H_
#define GE_GRAPH_PASSES_TRANSOP_WITHOUT_RESHAPE_FUSION_PASS_H_
#include <vector>
#include <utility>
5 years ago
#include "inc/graph_pass.h"
namespace ge {
///
/// Transform operators depth fusion
///
class TransOpWithoutReshapeFusionPass : public GraphPass {
public:
TransOpWithoutReshapeFusionPass() {}
virtual ~TransOpWithoutReshapeFusionPass() {}
graphStatus Run(ge::ComputeGraphPtr graph) override;
private:
void SetRemainNode(const vector<pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_anchor);
bool FormatContinuousCheck(const OutDataAnchorPtr &out_anchor, const InDataAnchorPtr &in_anchor);
void RemoveNousedNodes(const ComputeGraphPtr &graph);
void GetBeginOutDescAndEndInDesc(const int index, GeTensorDesc &out_desc, GeTensorDesc &in_desc);
4 years ago
void GetFormatTransferDesc(const GeTensorDesc &out_desc,
const GeTensorDesc &in_desc,
GeTensorDesc &format_transfer_input,
GeTensorDesc &format_transfer_output);
5 years ago
4 years ago
void GetCastOpDesc(const GeTensorDesc &out_desc,
const GeTensorDesc &in_desc,
GeTensorDesc &cast_input,
5 years ago
GeTensorDesc &cast_output);
4 years ago
graphStatus FormatFusion(const int index,
OpDescPtr &format_transfer_op,
int32_t &fusion_op_count,
5 years ago
bool &fusion_continue);
graphStatus DataTypeFusion(const int index, OpDescPtr &cast_op, int32_t &fusion_op_count);
void GetOutDataPeerInControlAnchors(const size_t index,
vector<vector<InControlAnchorPtr>> &out_data_peer_in_control_anchors);
4 years ago
void GetInControlPeerOutControlAnchors(
const size_t index,
vector<vector<OutControlAnchorPtr>> &in_control_peer_out_control_anchors);
5 years ago
4 years ago
void GetOutControlPeerAnchors(
const size_t index,
vector<vector<InControlAnchorPtr>> &out_control_peer_in_control_anchors,
vector<vector<InDataAnchorPtr>> &out_control_peer_in_data_anchors);
5 years ago
graphStatus TransOpFuse(const ComputeGraphPtr &graph);
bool OpAccuracyAbilityCheck(const OpDescPtr &op_desc);
graphStatus GetSubGraphsBetweenNormalNode(
4 years ago
const OutDataAnchorPtr &out_anchor,
vector<vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>>
>& sub_graphs_out,
vector<std::pair<OutDataAnchorPtr, InDataAnchorPtr>> &nodes_list
);
5 years ago
graphStatus GetSubGraphNodesInfo();
void GetControlAnchors();
graphStatus InsertNewTransOp(const ComputeGraphPtr &graph, const OpDescPtr &cast_op,
4 years ago
const OpDescPtr &format_transfer_op, const int index,
const bool insert_cast_first);
5 years ago
void EraseInvalidAnchorsPair();
graphStatus RelinkNodesWhenDescNotChanged(const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair,
const int index);
OpDescPtr GetFormatTransferOp(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc);
OpDescPtr GetCastOp(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc);
graphStatus TransOpFuseHandle(const ge::ComputeGraphPtr &graph, const int index);
graphStatus AddTransNode(const ComputeGraphPtr &graph, const OpDescPtr &transop, NodePtr &trans_node);
bool DescEqualCheck(ConstGeTensorDescPtr &desc_src, ConstGeTensorDescPtr &desc_dst) const;
bool ShapeEqualCheck(const GeShape &src, const GeShape &dst) const;
bool InsertCastFirstCheck(const GeTensorDesc &out_desc, const GeTensorDesc &in_desc) const;
graphStatus RelinkControlEdge(const int index, const OutDataAnchorPtr &out_anchor,
const vector<NodePtr> &new_trans_nodes);
4 years ago
graphStatus GetTransNode(const ComputeGraphPtr &graph,
const OpDescPtr &cast_op,
const OpDescPtr &format_transfer_op,
const bool insert_cast_first,
std::vector<NodePtr> &new_trans_nodes);
5 years ago
void UpdateOutputName(const OutDataAnchorPtr &out_anchor, const InDataAnchorPtr &old_peer_in_anchor,
const NodePtr &in_owner_node);
void UpdateInputName(const OutDataAnchorPtr &old_peer_out_anchor, const InDataAnchorPtr &in_anchor,
const NodePtr &out_owner_node);
graphStatus RelinkControlEdgesWhenDescNotChanged(const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair,
const int index);
graphStatus RelinkSubGraphControlEdges(const pair<OutDataAnchorPtr, InDataAnchorPtr> &begin_anchors_pair,
const pair<OutDataAnchorPtr, InDataAnchorPtr> &end_anchors_pair,
const int index);
///
/// judge whether an operator is a transform op or not
/// @param node
/// @return True or False
///
static bool IsTransOp(const NodePtr &node);
static bool FusionFormatSupport(Format format);
4 years ago
vector<vector<pair<OutDataAnchorPtr, InDataAnchorPtr>>>
sub_graph_anchors_;
5 years ago
vector<vector<NodePtr>> sub_graph_nodes_;
vector<int> transop_num_count_;
vector<bool> sub_graph_has_reshape_node_;
vector<vector<OutControlAnchorPtr>> in_control_peer_out_control_anchors_;
vector<vector<InControlAnchorPtr>> out_control_peer_in_control_anchors_;
vector<vector<InDataAnchorPtr>> out_control_peer_in_data_anchors_;
vector<vector<InControlAnchorPtr>> out_data_peer_in_control_anchors_;
vector<bool> sub_graph_has_control_edge_;
vector<bool> sub_graph_has_out_data_peer_in_control_edge_;
};
} // namespace ge
#endif // GE_GRAPH_PASSES_TRANSOP_WITHOUT_RESHAPE_FUSION_PASS_H_
4 years ago