You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/ge/graph/passes/for_pass.h

191 lines
5.9 KiB

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef GE_GRAPH_PASSES_FOR_PASS_H
#define GE_GRAPH_PASSES_FOR_PASS_H
#include "graph/passes/base_pass.h"
///
/// @brief Record describing one For node: the node itself, its loop-control
///        anchors, its body subgraph and the remaining data / control edges.
///        Every pointer-like member starts out as nullptr.
///
struct ForInfo {
  ForInfo() = default;

  // Node this record was built for.
  ge::NodePtr for_node{nullptr};

  // Output anchors feeding the loop-control inputs (named start / limit / delta).
  ge::OutDataAnchorPtr start{nullptr};
  ge::OutDataAnchorPtr limit{nullptr};
  ge::OutDataAnchorPtr delta{nullptr};

  // Name of the body subgraph and the body subgraph itself.
  std::string body_name;
  ge::ComputeGraphPtr for_body{nullptr};

  // Peer anchors on the data side: one out-anchor per data input, and a list
  // of in-anchors per data output.
  std::vector<ge::OutDataAnchorPtr> data_inputs;
  std::vector<std::vector<ge::InDataAnchorPtr>> data_outputs;

  // Peer anchors on the control side.
  std::vector<ge::OutControlAnchorPtr> ctrl_inputs;
  std::vector<ge::InControlAnchorPtr> ctrl_outputs;
};
///
/// @brief Record describing the While construct built from a For node: the
///        new nodes, the loop-variable anchors, the involved subgraphs and the
///        data / control edges to reconnect.
///        Every pointer-like member starts out as nullptr.
///
struct WhileInfo {
  WhileInfo() = default;

  // Nodes associated with the While construct.
  ge::NodePtr while_node{nullptr};
  ge::NodePtr sub_graph_node{nullptr};

  // Output anchors for the loop variables (named i / abs_delta / range /
  // start / delta).
  ge::OutDataAnchorPtr i{nullptr};
  ge::OutDataAnchorPtr abs_delta{nullptr};
  ge::OutDataAnchorPtr range{nullptr};
  ge::OutDataAnchorPtr start{nullptr};
  ge::OutDataAnchorPtr delta{nullptr};

  // Original for-body subgraph (with its name) plus the cond / body subgraphs
  // of the while node.
  std::string for_body_name;
  ge::ComputeGraphPtr for_body{nullptr};
  ge::ComputeGraphPtr while_cond{nullptr};
  ge::ComputeGraphPtr while_body{nullptr};

  // Peer anchors on the data side: one out-anchor per data input, and a list
  // of in-anchors per data output.
  std::vector<ge::OutDataAnchorPtr> data_inputs;
  std::vector<std::vector<ge::InDataAnchorPtr>> data_outputs;

  // Peer anchors on the control side.
  std::vector<ge::OutControlAnchorPtr> ctrl_inputs;
  std::vector<ge::InControlAnchorPtr> ctrl_outputs;
};
namespace ge {
///
/// @brief Node pass operating on For nodes. Its helpers collect a ForInfo,
///        translate it into a WhileInfo and build / link a while node with
///        its cond and body graphs.
///
class ForPass : public BaseNodePass {
 public:
  ///
  /// @brief Pass entry point, overriding BaseNodePass::Run
  /// @param [in&out] node (taken by non-const reference)
  /// @return Status
  ///
  Status Run(NodePtr &node) override;

 private:
  ///
  /// @brief Fill for_info for the given node
  /// @param [in] root_graph
  /// @param [in] node
  /// @param [out] for_info
  /// @return Status
  ///
  static Status BuildForInfo(const ComputeGraphPtr &root_graph,
                             const NodePtr &node,
                             ForInfo &for_info);

  ///
  /// @brief Translate for_info into while_info
  /// @param [in] graph
  /// @param [in] for_info
  /// @param [out] while_info
  /// @return Status
  ///
  Status TranWhileInfo(const ComputeGraphPtr &graph,
                       const ForInfo &for_info,
                       WhileInfo &while_info);

  ///
  /// @brief Build the cond_graph of the while_node
  /// @param [in&out] while_info
  /// @return ComputeGraphPtr
  ///
  static ComputeGraphPtr BuildCondGraph(WhileInfo &while_info);

  ///
  /// @brief Build the body_graph of the while_node
  /// @param [in&out] while_info
  /// @return ComputeGraphPtr
  ///
  static ComputeGraphPtr BuildBodyGraph(WhileInfo &while_info);

  ///
  /// @brief Update the InputMapping of the for-body-graph
  /// @param [in] while_info
  /// @return Status
  ///
  static Status UpdateForBodyInputMapping(const WhileInfo &while_info);

  ///
  /// @brief Look up the input of a For node at the given index
  /// @param [in] node
  /// @param [in] index
  /// @return OutDataAnchorPtr
  ///
  static OutDataAnchorPtr FindInputWithIndex(const NodePtr &node, uint32_t index);

  ///
  /// @brief Collect the data / control inputs and outputs of a For node
  /// @param [in] node
  /// @param [out] data_inputs
  /// @param [out] data_outputs
  /// @param [out] ctrl_inputs
  /// @param [out] ctrl_outputs
  /// @return Status
  ///
  static Status FindInputsAndOutputs(const NodePtr &node,
                                     std::vector<OutDataAnchorPtr> &data_inputs,
                                     std::vector<std::vector<InDataAnchorPtr>> &data_outputs,
                                     std::vector<OutControlAnchorPtr> &ctrl_inputs,
                                     std::vector<InControlAnchorPtr> &ctrl_outputs);

  ///
  /// @brief Create an op_desc for a const holding an int32 value
  /// @param [in] name
  /// @param [in] value
  /// @return OpDescPtr
  ///
  static OpDescPtr CreateConstDesc(const std::string &name, int32_t value);

  ///
  /// @brief Create the loop inputs (range and abs_delta)
  /// @param [in] graph
  /// @param [in] for_info
  /// @param [out] range_input
  /// @param [out] abs_delta_input
  /// @return Status
  ///
  Status CreateLoopInput(const ComputeGraphPtr &graph,
                         const ForInfo &for_info,
                         OutDataAnchorPtr &range_input,
                         OutDataAnchorPtr &abs_delta_input);

  ///
  /// @brief Create an op_desc with the given name and type
  /// @param [in] name
  /// @param [in] type
  /// @param [in] io_equal_flag  NOTE(review): exact meaning is not visible in
  ///                            this header - confirm against the implementation
  /// @return OpDescPtr
  ///
  static OpDescPtr CreateOpDesc(const std::string &name,
                                const std::string &type,
                                bool io_equal_flag);

  ///
  /// @brief Fill while_info from for_info and the loop input anchors
  /// @param [in] for_info
  /// @param [in] i_input
  /// @param [in] range_input
  /// @param [in] abs_delta_input
  /// @param [out] while_info
  /// @return void
  ///
  static void BuildWhileInfo(const ForInfo &for_info,
                             const OutDataAnchorPtr &i_input,
                             const OutDataAnchorPtr &range_input,
                             const OutDataAnchorPtr &abs_delta_input,
                             WhileInfo &while_info);

  ///
  /// @brief Insert the while_node into graph
  /// @param [in] graph
  /// @param [in] name
  /// @param [in&out] while_info (taken by non-const reference)
  /// @return Status
  ///
  Status InsertWhileNode(const ComputeGraphPtr &graph,
                         const std::string &name,
                         WhileInfo &while_info);

  ///
  /// @brief Build the link-edges of the while node
  /// @param [in] while_info
  /// @return Status
  ///
  static Status BuildWhileLink(const WhileInfo &while_info);

  ///
  /// @brief Create the op_desc of a subgraph node
  /// @param [in] name
  /// @param [in] input_num
  /// @param [in] output_num
  /// @return OpDescPtr
  ///
  static OpDescPtr CreateSubgraphOpDesc(const std::string &name,
                                        uint32_t input_num,
                                        uint32_t output_num);
};
} // namespace ge
#endif //GE_GRAPH_PASSES_FOR_PASS_H