You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
200 lines
5.9 KiB
200 lines
5.9 KiB
5 years ago
|
/**
|
||
|
* Copyright 2019-2020 Huawei Technologies Co., Ltd
|
||
|
*
|
||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||
|
* you may not use this file except in compliance with the License.
|
||
|
* You may obtain a copy of the License at
|
||
|
*
|
||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
|
* See the License for the specific language governing permissions and
|
||
|
* limitations under the License.
|
||
|
*/
|
||
|
|
||
|
#ifndef GE_GRAPH_PASSES_FOR_PASS_H
|
||
|
#define GE_GRAPH_PASSES_FOR_PASS_H
|
||
|
|
||
|
#include "graph/passes/base_pass.h"
|
||
|
|
||
|
struct ForInfo {
|
||
|
ForInfo() : for_node(nullptr), start(nullptr), limit(nullptr), delta(nullptr), for_body(nullptr) {}
|
||
|
ge::NodePtr for_node;
|
||
|
ge::OutDataAnchorPtr start;
|
||
|
ge::OutDataAnchorPtr limit;
|
||
|
ge::OutDataAnchorPtr delta;
|
||
|
std::string body_name;
|
||
|
ge::ComputeGraphPtr for_body;
|
||
|
std::vector<ge::OutDataAnchorPtr> data_inputs;
|
||
|
std::vector<std::vector<ge::InDataAnchorPtr>> data_outputs;
|
||
|
std::vector<ge::OutControlAnchorPtr> ctrl_inputs;
|
||
|
std::vector<ge::InControlAnchorPtr> ctrl_outputs;
|
||
|
};
|
||
|
|
||
|
struct WhileInfo {
|
||
5 years ago
|
WhileInfo()
|
||
|
: while_node(nullptr),
|
||
|
sub_graph_node(nullptr),
|
||
|
i(nullptr),
|
||
|
abs_delta(nullptr),
|
||
|
range(nullptr),
|
||
|
start(nullptr),
|
||
|
delta(nullptr),
|
||
|
for_body(nullptr),
|
||
|
while_cond(nullptr),
|
||
|
while_body(nullptr) {}
|
||
5 years ago
|
ge::NodePtr while_node;
|
||
|
ge::NodePtr sub_graph_node;
|
||
|
ge::OutDataAnchorPtr i;
|
||
5 years ago
|
ge::OutDataAnchorPtr abs_delta;
|
||
|
ge::OutDataAnchorPtr range;
|
||
5 years ago
|
ge::OutDataAnchorPtr start;
|
||
|
ge::OutDataAnchorPtr delta;
|
||
|
std::string for_body_name;
|
||
|
ge::ComputeGraphPtr for_body;
|
||
|
ge::ComputeGraphPtr while_cond;
|
||
|
ge::ComputeGraphPtr while_body;
|
||
|
std::vector<ge::OutDataAnchorPtr> data_inputs;
|
||
|
std::vector<std::vector<ge::InDataAnchorPtr>> data_outputs;
|
||
|
std::vector<ge::OutControlAnchorPtr> ctrl_inputs;
|
||
|
std::vector<ge::InControlAnchorPtr> ctrl_outputs;
|
||
|
};
|
||
|
|
||
|
namespace ge {
|
||
|
class ForPass : public BaseNodePass {
|
||
|
public:
|
||
|
Status Run(NodePtr &node) override;
|
||
|
|
||
|
private:
|
||
|
///
|
||
|
/// @brief Build for_info
|
||
|
/// @param [in] root_graph
|
||
|
/// @param [in] node
|
||
|
/// @param [out] for_info
|
||
|
/// @return Status
|
||
|
///
|
||
|
static Status BuildForInfo(const ComputeGraphPtr &root_graph, const NodePtr &node, ForInfo &for_info);
|
||
|
|
||
|
///
|
||
|
/// @brief Transfer while_info from for_info
|
||
|
/// @param [in] graph
|
||
|
/// @param [in] for_info
|
||
|
/// @param [out] while_info
|
||
|
/// @return Status
|
||
|
///
|
||
|
Status TranWhileInfo(const ComputeGraphPtr &graph, const ForInfo &for_info, WhileInfo &while_info);
|
||
|
|
||
|
///
|
||
|
/// @brief Build cond_graph for while_node
|
||
|
/// @param [in&out] while_info
|
||
|
/// @return ComputeGraphPtr
|
||
|
///
|
||
|
static ComputeGraphPtr BuildCondGraph(WhileInfo &while_info);
|
||
|
|
||
|
///
|
||
|
/// @brief Build body_graph for while_node
|
||
|
/// @param [in&out] while_info
|
||
|
/// @return ComputeGraphPtr
|
||
|
///
|
||
|
static ComputeGraphPtr BuildBodyGraph(WhileInfo &while_info);
|
||
|
|
||
|
///
|
||
|
/// @brief Update InputMapping for for-body-graph
|
||
|
/// @param [in] while_info
|
||
|
/// @return Status
|
||
|
///
|
||
|
static Status UpdateForBodyInputMapping(const WhileInfo &while_info);
|
||
|
|
||
|
///
|
||
|
/// @brief Find input with index for For node
|
||
|
/// @param [in] node
|
||
|
/// @param [in] index
|
||
|
/// @return OutDataAnchorPtr
|
||
|
///
|
||
|
static OutDataAnchorPtr FindInputWithIndex(const NodePtr &node, uint32_t index);
|
||
|
|
||
|
///
|
||
|
/// @brief Find inputs / outputs for for node
|
||
|
/// @param [in] node
|
||
|
/// @param [out] data_inputs
|
||
|
/// @param [out] data_outputs
|
||
|
/// @param [out] ctrl_inputs
|
||
|
/// @param [out] ctrl_outputs
|
||
|
/// @return Status
|
||
|
///
|
||
|
static Status FindInputsAndOutputs(const NodePtr &node, std::vector<OutDataAnchorPtr> &data_inputs,
|
||
|
std::vector<std::vector<ge::InDataAnchorPtr>> &data_outputs,
|
||
|
std::vector<ge::OutControlAnchorPtr> &ctrl_inputs,
|
||
|
std::vector<ge::InControlAnchorPtr> &ctrl_outputs);
|
||
|
|
||
|
///
|
||
|
/// @brief Create const op_desc
|
||
|
/// @param [in] name
|
||
|
/// @param [in] value
|
||
|
/// @return OpDescPtr
|
||
|
///
|
||
|
static OpDescPtr CreateConstDesc(const std::string &name, int32_t value);
|
||
|
|
||
|
///
|
||
5 years ago
|
/// @brief Create loop input
|
||
5 years ago
|
/// @param [in] graph
|
||
|
/// @param [in] for_info
|
||
5 years ago
|
/// @param [out] range_input
|
||
|
/// @param [out] abs_delta_input
|
||
|
/// @return Status
|
||
5 years ago
|
///
|
||
5 years ago
|
Status CreateLoopInput(const ComputeGraphPtr &graph, const ForInfo &for_info, OutDataAnchorPtr &range_input,
|
||
|
OutDataAnchorPtr &abs_delta_input);
|
||
5 years ago
|
|
||
|
///
|
||
|
/// @brief Create op_desc
|
||
|
/// @param [in] name
|
||
|
/// @param [in] type
|
||
|
/// @param [in] io_equal_flag
|
||
|
/// @return OpDescPtr
|
||
|
///
|
||
|
static OpDescPtr CreateOpDesc(const std::string &name, const std::string &type, bool io_equal_flag);
|
||
|
|
||
|
///
|
||
|
/// @brief Build while-info
|
||
|
/// @param [in] for_info
|
||
|
/// @param [in] i_input
|
||
5 years ago
|
/// @param [in] range_input
|
||
|
/// @param [in] abs_delta_input
|
||
5 years ago
|
/// @param [out] while_info
|
||
|
/// @return void
|
||
|
///
|
||
5 years ago
|
static void BuildWhileInfo(const ForInfo &for_info, const OutDataAnchorPtr &i_input,
|
||
|
const OutDataAnchorPtr &range_input, const OutDataAnchorPtr &abs_delta_input,
|
||
5 years ago
|
WhileInfo &while_info);
|
||
|
|
||
|
///
|
||
|
/// @brief Insert while_node
|
||
|
/// @param [in] graph
|
||
|
/// @param [in] name
|
||
|
/// @param [in] while_info
|
||
|
/// @return Status
|
||
|
///
|
||
|
Status InsertWhileNode(const ComputeGraphPtr &graph, const std::string &name, WhileInfo &while_info);
|
||
|
|
||
|
///
|
||
|
/// @brief Build while link-edge
|
||
|
/// @param [in] while_info
|
||
|
/// @return Status
|
||
|
///
|
||
|
static Status BuildWhileLink(const WhileInfo &while_info);
|
||
|
|
||
|
///
|
||
|
/// @brief Create op_desc for subgraph node
|
||
|
/// @param [in] name
|
||
|
/// @param [in] input_num
|
||
|
/// @param [in] output_num
|
||
|
/// @return OpDescPtr
|
||
|
///
|
||
|
static OpDescPtr CreateSubgraphOpDesc(const std::string &name, uint32_t input_num, uint32_t output_num);
|
||
|
};
|
||
|
} // namespace ge
|
||
|
#endif // GE_GRAPH_PASSES_FOR_PASS_H
|