|
|
|
/**
|
|
|
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
|
|
*
|
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
* you may not use this file except in compliance with the License.
|
|
|
|
* You may obtain a copy of the License at
|
|
|
|
*
|
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
*
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
* See the License for the specific language governing permissions and
|
|
|
|
* limitations under the License.
|
|
|
|
*/
|
|
|
|
|
|
|
|
#include "graph/passes/for_pass.h"
|
|
|
|
#include "common/ge/ge_util.h"
|
|
|
|
#include "common/op/ge_op_utils.h"
|
|
|
|
#include "framework/common/debug/ge_log.h"
|
|
|
|
#include "framework/common/debug/log.h"
|
|
|
|
#include "framework/common/ge_inner_error_codes.h"
|
|
|
|
#include "framework/common/types.h"
|
|
|
|
#include "graph/debug/ge_attr_define.h"
|
|
|
|
#include "graph/utils/graph_utils.h"
|
|
|
|
#include "graph/utils/type_utils.h"
|
|
|
|
#include "graph/utils/node_utils.h"
|
|
|
|
#include "graph/utils/op_desc_utils.h"
|
|
|
|
|
|
|
|
namespace {
|
|
|
|
const uint32_t kWhileIInputIndex = 0;
|
|
|
|
const uint32_t kWhileAbsDeltaInputIndex = 1;
|
|
|
|
const uint32_t kWhileRangeInputIndex = 2;
|
|
|
|
const uint32_t kWhileStartInputIndex = 3;
|
|
|
|
const uint32_t kWhileDeltaInputIndex = 4;
|
|
|
|
const uint32_t kWhileDataInputIndex = 5;
|
|
|
|
const uint32_t kSubgraphLoopVarInputIndex = 0;
|
|
|
|
const uint32_t kSubgraphInputIndex = 1;
|
|
|
|
const uint32_t kWhileOutputIndex = 5;
|
|
|
|
const size_t kIDiffValue = 2;
|
|
|
|
const std::string kAbs = "Abs";
|
|
|
|
}
|
|
|
|
|
|
|
|
namespace ge {
|
|
|
|
Status ForPass::Run(NodePtr &node) {
|
|
|
|
if (node->GetType() != FOR) {
|
|
|
|
GELOGD("no need for_pass for node %s.", node->GetName().c_str());
|
|
|
|
return SUCCESS;
|
|
|
|
}
|
|
|
|
|
|
|
|
GELOGI("Begin to transfer for_op to while_op, node:%s.", node->GetName().c_str());
|
|
|
|
ComputeGraphPtr graph = node->GetOwnerComputeGraph();
|
|
|
|
GE_CHECK_NOTNULL(graph);
|
|
|
|
ComputeGraphPtr root_graph = GraphUtils::FindRootGraph(graph);
|
|
|
|
GE_CHECK_NOTNULL(root_graph);
|
|
|
|
|
|
|
|
ForInfo for_info;
|
|
|
|
GE_CHK_STATUS_RET(BuildForInfo(root_graph, node, for_info),
|
|
|
|
"Build ForInfo failed, node:%s.", node->GetName().c_str());
|
|
|
|
|
|
|
|
WhileInfo while_info;
|
|
|
|
GE_CHK_STATUS_RET(TranWhileInfo(graph, for_info, while_info),
|
|
|
|
"Transfer WhileInfo from ForInfo failed, node:%s.", node->GetName().c_str());
|
|
|
|
|
|
|
|
ComputeGraphPtr cond_graph = BuildCondGraph(while_info);
|
|
|
|
if ((cond_graph == nullptr) || (root_graph->AddSubgraph(cond_graph) != GRAPH_SUCCESS)) {
|
|
|
|
GELOGE(FAILED, "Add while_cond_graph failed, node:%s.", node->GetName().c_str());
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
|
|
|
|
ComputeGraphPtr body_graph = BuildBodyGraph(while_info);
|
|
|
|
if ((body_graph == nullptr) || (root_graph->AddSubgraph(body_graph) != GRAPH_SUCCESS)) {
|
|
|
|
GELOGE(FAILED, "Add while_body_graph failed, node:%s.", node->GetName().c_str());
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
|
|
|
|
GE_CHK_STATUS_RET(UpdateForBodyInputMapping(while_info),
|
|
|
|
"Update InputMapping for for-body-graph failed, node:%s.", node->GetName().c_str());
|
|
|
|
|
|
|
|
// for node has and only has one subgraph
|
|
|
|
GE_CHECK_NOTNULL(node->GetOpDesc());
|
|
|
|
node->GetOpDesc()->RemoveSubgraphInstanceName(node->GetOpDesc()->GetSubgraphInstanceName(0));
|
|
|
|
|
|
|
|
GELOGI("Transfer for_op to while_op succ, node:%s.", node->GetName().c_str());
|
|
|
|
return IsolateAndDeleteNode(node, std::vector<int>());
|
|
|
|
}
|
|
|
|
|
|
|
|
///
|
|
|
|
/// @brief Build for_info
|
|
|
|
/// @param [in] root_graph
|
|
|
|
/// @param [in] node
|
|
|
|
/// @param [out] for_info
|
|
|
|
/// @return Status
|
|
|
|
///
|
|
|
|
Status ForPass::BuildForInfo(const ComputeGraphPtr &root_graph, const NodePtr &node, ForInfo &for_info) {
|
|
|
|
GELOGI("Begin to build for_info for node %s.", node->GetName().c_str());
|
|
|
|
|
|
|
|
OutDataAnchorPtr start = FindInputWithIndex(node, FOR_START_INPUT);
|
|
|
|
OutDataAnchorPtr limit = FindInputWithIndex(node, FOR_LIMIT_INPUT);
|
|
|
|
OutDataAnchorPtr delta = FindInputWithIndex(node, FOR_DELTA_INPUT);
|
|
|
|
if ((start == nullptr) || (limit == nullptr) || (delta == nullptr)) {
|
|
|
|
GELOGE(FAILED, "BuildForInfo for %s failed: start/limit/delta is NULL.", node->GetName().c_str());
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
|
|
|
|
std::vector<OutDataAnchorPtr> data_inputs;
|
|
|
|
std::vector<std::vector<InDataAnchorPtr>> data_outputs;
|
|
|
|
std::vector<OutControlAnchorPtr> ctrl_inputs;
|
|
|
|
std::vector<InControlAnchorPtr> ctrl_outputs;
|
|
|
|
if (FindInputsAndOutputs(node, data_inputs, data_outputs, ctrl_inputs, ctrl_outputs) != SUCCESS) {
|
|
|
|
GELOGE(FAILED, "BuildForInfo for %s failed: find inputs/outputs failed.", node->GetName().c_str());
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
NodeUtils::UnlinkAll(*node);
|
|
|
|
|
|
|
|
OpDescPtr op_desc = node->GetOpDesc();
|
|
|
|
GE_CHECK_NOTNULL(op_desc);
|
|
|
|
// For node has and only has one sub_graph
|
|
|
|
std::string for_body_name = op_desc->GetSubgraphInstanceName(0);
|
|
|
|
if (for_body_name.empty()) {
|
|
|
|
GELOGE(FAILED, "BuildForInfo for %s failed: sub_graph_name is empty.", node->GetName().c_str());
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
ComputeGraphPtr for_body = root_graph->GetSubgraph(for_body_name);
|
|
|
|
if (for_body == nullptr) {
|
|
|
|
GELOGE(FAILED, "BuildForInfo for %s failed: for_body_graph is NULL.", node->GetName().c_str());
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
|
|
|
|
for_info.for_node = node;
|
|
|
|
for_info.start = start;
|
|
|
|
for_info.limit = limit;
|
|
|
|
for_info.delta = delta;
|
|
|
|
for_info.body_name = for_body_name;
|
|
|
|
for_info.for_body = for_body;
|
|
|
|
for_info.data_inputs = std::move(data_inputs);
|
|
|
|
for_info.data_outputs = std::move(data_outputs);
|
|
|
|
for_info.ctrl_inputs = std::move(ctrl_inputs);
|
|
|
|
for_info.ctrl_outputs = std::move(ctrl_outputs);
|
|
|
|
|
|
|
|
GELOGI("Build for_info for node %s success.", node->GetName().c_str());
|
|
|
|
return SUCCESS;
|
|
|
|
}
|
|
|
|
|
|
|
|
///
|
|
|
|
/// @brief Find input with index for For node
|
|
|
|
/// @param [in] node
|
|
|
|
/// @param [in] index
|
|
|
|
/// @return OutDataAnchorPtr
|
|
|
|
///
|
|
|
|
OutDataAnchorPtr ForPass::FindInputWithIndex(const NodePtr &node, uint32_t index) {
|
|
|
|
if (node == nullptr) {
|
|
|
|
GELOGE(FAILED, "FindInputWithIndex failed: node is NULL.");
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index);
|
|
|
|
if (in_data_anchor == nullptr) {
|
|
|
|
GELOGE(FAILED, "FindInputWithIndex %s:%u failed: in_data_anchor is NULL.", node->GetName().c_str(), index);
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
return in_data_anchor->GetPeerOutAnchor();
|
|
|
|
}
|
|
|
|
|
|
|
|
///
|
|
|
|
/// @brief Find inputs / outputs for for node
|
|
|
|
/// @param [in] node
|
|
|
|
/// @param [out] data_inputs
|
|
|
|
/// @param [out] data_outputs
|
|
|
|
/// @param [out] ctrl_inputs
|
|
|
|
/// @param [out] ctrl_outputs
|
|
|
|
/// @return Status
|
|
|
|
///
|
|
|
|
Status ForPass::FindInputsAndOutputs(const NodePtr &node, std::vector<OutDataAnchorPtr> &data_inputs,
|
|
|
|
std::vector<std::vector<ge::InDataAnchorPtr>> &data_outputs,
|
|
|
|
std::vector<ge::OutControlAnchorPtr> &ctrl_inputs,
|
|
|
|
std::vector<ge::InControlAnchorPtr> &ctrl_outputs) {
|
|
|
|
GE_CHECK_NOTNULL(node);
|
|
|
|
|
|
|
|
uint32_t input_data_num = node->GetAllInDataAnchorsSize();
|
|
|
|
for (uint32_t index = FOR_DATA_INPUT; index < input_data_num; index++) {
|
|
|
|
InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(index);
|
|
|
|
GE_CHECK_NOTNULL(in_data_anchor);
|
|
|
|
data_inputs.emplace_back(in_data_anchor->GetPeerOutAnchor());
|
|
|
|
}
|
|
|
|
|
|
|
|
for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
|
|
|
|
std::vector<ge::InDataAnchorPtr> peer_in_data_anchors;
|
|
|
|
for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
|
|
|
|
peer_in_data_anchors.emplace_back(peer_in_data_anchor);
|
|
|
|
}
|
|
|
|
data_outputs.emplace_back(peer_in_data_anchors);
|
|
|
|
}
|
|
|
|
|
|
|
|
InControlAnchorPtr in_ctrl_anchor = node->GetInControlAnchor();
|
|
|
|
GE_CHECK_NOTNULL(in_ctrl_anchor);
|
|
|
|
for (const auto &peer_out_ctrl_anchor : in_ctrl_anchor->GetPeerOutControlAnchors()) {
|
|
|
|
ctrl_inputs.emplace_back(peer_out_ctrl_anchor);
|
|
|
|
}
|
|
|
|
|
|
|
|
OutControlAnchorPtr out_ctrl_anchor = node->GetOutControlAnchor();
|
|
|
|
GE_CHECK_NOTNULL(out_ctrl_anchor);
|
|
|
|
for (const auto &peer_in_ctrl_anchor : out_ctrl_anchor->GetPeerInControlAnchors()) {
|
|
|
|
ctrl_outputs.emplace_back(peer_in_ctrl_anchor);
|
|
|
|
}
|
|
|
|
|
|
|
|
return SUCCESS;
|
|
|
|
}
|
|
|
|
|
|
|
|
///
|
|
|
|
/// @brief Transfer while_info from for_info
|
|
|
|
/// @param [in] graph
|
|
|
|
/// @param [in] for_info
|
|
|
|
/// @param [out] while_info
|
|
|
|
/// @return Status
|
|
|
|
///
|
|
|
|
Status ForPass::TranWhileInfo(const ComputeGraphPtr &graph, const ForInfo &for_info, WhileInfo &while_info) {
|
|
|
|
std::string for_name = for_info.for_node->GetName();
|
|
|
|
GELOGI("Begin to transfer for_info to while_info, node:%s.", for_name.c_str());
|
|
|
|
|
|
|
|
std::string i_name = for_name + "_i";
|
|
|
|
NodePtr i_node = graph->AddNode(CreateConstDesc(i_name, 0));
|
|
|
|
if (i_node == nullptr) {
|
|
|
|
GELOGE(FAILED, "TranWhileInfo failed: create i_node failed.");
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
AddRePassNode(i_node);
|
|
|
|
|
|
|
|
std::string identity_name = i_name + "_Identity";
|
|
|
|
NodePtr identity_node = graph->AddNode(CreateOpDesc(identity_name, IDENTITY, true));
|
|
|
|
// Const node has and only has one output, Identity node has and only has one input
|
|
|
|
if ((identity_node == nullptr) ||
|
|
|
|
(GraphUtils::AddEdge(i_node->GetOutDataAnchor(0), identity_node->GetInDataAnchor(0)) != GRAPH_SUCCESS)) {
|
|
|
|
GELOGE(FAILED, "TranWhileInfo failed: Add data-edge %s:0->%s:0 failed.", i_name.c_str(), identity_name.c_str());
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
AddRePassNode(identity_node);
|
|
|
|
|
|
|
|
// Identity node has and only has one output
|
|
|
|
OutDataAnchorPtr i_input = identity_node->GetOutDataAnchor(0);
|
|
|
|
if (i_input == nullptr) {
|
|
|
|
GELOGE(FAILED, "TranWhileInfo failed: i_input is NULL.");
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
|
|
|
|
OutDataAnchorPtr range_input = nullptr;
|
|
|
|
OutDataAnchorPtr abs_delta_input = nullptr;
|
|
|
|
if (CreateLoopInput(graph, for_info, range_input, abs_delta_input) != SUCCESS) {
|
|
|
|
GELOGE(FAILED, "TranWhileInfo failed: create loop input failed.");
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
|
|
|
|
BuildWhileInfo(for_info, i_input, range_input, abs_delta_input, while_info);
|
|
|
|
|
|
|
|
if (InsertWhileNode(graph, for_name + "_While", while_info) != SUCCESS) {
|
|
|
|
GELOGE(FAILED, "TranWhileInfo failed: insert while node failed.");
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
|
|
|
|
GELOGI("Transfer for_info to while_info succ, for_node:%s, while_node:%s.",
|
|
|
|
for_name.c_str(), while_info.while_node->GetName().c_str());
|
|
|
|
return SUCCESS;
|
|
|
|
}
|
|
|
|
|
|
|
|
///
|
|
|
|
/// @brief Create const op_desc
|
|
|
|
/// @param [in] name
|
|
|
|
/// @param [in] value
|
|
|
|
/// @return OpDescPtr
|
|
|
|
///
|
|
|
|
OpDescPtr ForPass::CreateConstDesc(const std::string &name, int32_t value) {
|
|
|
|
OpDescPtr const_op_desc = MakeShared<OpDesc>(name, CONSTANT);
|
|
|
|
if (const_op_desc == nullptr) {
|
|
|
|
GELOGE(FAILED, "Create op_desc failed, const:%s.", name.c_str());
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_INT32);
|
|
|
|
GeTensorPtr const_value = MakeShared<GeTensor>(data_desc, reinterpret_cast<uint8_t *>(&value), sizeof(int32_t));
|
|
|
|
if (const_value == nullptr) {
|
|
|
|
GELOGE(FAILED, "Create tensor failed, const:%s.", name.c_str());
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (!AttrUtils::SetTensor(const_op_desc, ATTR_NAME_WEIGHTS, const_value)) {
|
|
|
|
GELOGE(FAILED, "Set ATTR_NAME_WEIGHTS failed, const:%s.", name.c_str());
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
if (const_op_desc->AddOutputDesc("y", data_desc) != GRAPH_SUCCESS) {
|
|
|
|
GELOGE(FAILED, "Add output desc failed, const:%s.", name.c_str());
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
return const_op_desc;
|
|
|
|
}
|
|
|
|
|
|
|
|
///
|
|
|
|
/// @brief Create loop node
|
|
|
|
/// @param [in] graph
|
|
|
|
/// @param [in] for_info
|
|
|
|
/// @param [out] range_input
|
|
|
|
/// @param [out] abs_delta_input
|
|
|
|
/// @return Status
|
|
|
|
///
|
|
|
|
Status ForPass::CreateLoopInput(const ComputeGraphPtr &graph, const ForInfo &for_info,
|
|
|
|
OutDataAnchorPtr &range_input, OutDataAnchorPtr &abs_delta_input) {
|
|
|
|
std::string for_name = for_info.for_node->GetName();
|
|
|
|
GELOGD("Begin to create loop_count input, node:%s", for_name.c_str());
|
|
|
|
|
|
|
|
OutDataAnchorPtr start = for_info.start;
|
|
|
|
OutDataAnchorPtr limit = for_info.limit;
|
|
|
|
OutDataAnchorPtr delta = for_info.delta;
|
|
|
|
|
|
|
|
std::string sub_name_0 = for_name + "_Sub_0";
|
|
|
|
std::string abs_name_0 = for_name + "_Abs_0";
|
|
|
|
std::string abs_name_1 = for_name + "_Abs_1";
|
|
|
|
|
|
|
|
// i * |delta| < |limit-start|
|
|
|
|
PartialGraphBuilder graph_builder;
|
|
|
|
graph_builder.SetOwnerGraph(graph)
|
|
|
|
.AddExistNode(for_info.start->GetOwnerNode())
|
|
|
|
.AddExistNode(for_info.limit->GetOwnerNode())
|
|
|
|
.AddExistNode(for_info.delta->GetOwnerNode())
|
|
|
|
.AddNode(CreateOpDesc(sub_name_0, SUB, false))
|
|
|
|
.AddNode(CreateOpDesc(abs_name_0, kAbs, true))
|
|
|
|
.AddNode(CreateOpDesc(abs_name_1, kAbs, true))
|
|
|
|
.AddDataLink(delta->GetOwnerNode()->GetName(), delta->GetIdx(), abs_name_0, 0)
|
|
|
|
.AddDataLink(limit->GetOwnerNode()->GetName(), limit->GetIdx(), sub_name_0, 0)
|
|
|
|
.AddDataLink(start->GetOwnerNode()->GetName(), start->GetIdx(), sub_name_0, 1)
|
|
|
|
.AddDataLink(sub_name_0, 0, abs_name_1, 0);
|
|
|
|
|
|
|
|
graphStatus error_code = GRAPH_SUCCESS;
|
|
|
|
std::string error_msg;
|
|
|
|
if ((graph_builder.Build(error_code, error_msg) == nullptr) || (error_code != GRAPH_SUCCESS)) {
|
|
|
|
GELOGE(FAILED, "Create loop_count node failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Add repass_nodes
|
|
|
|
for (auto &node : graph_builder.GetAllNodes()) {
|
|
|
|
AddRePassNode(node);
|
|
|
|
}
|
|
|
|
|
|
|
|
NodePtr abs_delta_node = graph_builder.GetNode(abs_name_0);
|
|
|
|
NodePtr loop_count_node = graph_builder.GetNode(abs_name_1);
|
|
|
|
if ((abs_delta_node == nullptr) || (loop_count_node == nullptr)) {
|
|
|
|
GELOGE(FAILED, "Create loop node failed: node is NULL.");
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
|
|
|
|
GELOGD("Create loop_range input succ, node:%s", for_name.c_str());
|
|
|
|
// abs_node has and only has one output
|
|
|
|
abs_delta_input = abs_delta_node->GetOutDataAnchor(0);
|
|
|
|
range_input = loop_count_node->GetOutDataAnchor(0);
|
|
|
|
|
|
|
|
return SUCCESS;
|
|
|
|
}
|
|
|
|
|
|
|
|
///
|
|
|
|
/// @brief Create op_desc
|
|
|
|
/// @param [in] name
|
|
|
|
/// @param [in] type
|
|
|
|
/// @param [in] io_equal_flag
|
|
|
|
/// @return OpDescPtr
|
|
|
|
///
|
|
|
|
OpDescPtr ForPass::CreateOpDesc(const std::string &name, const std::string &type, bool io_equal_flag) {
|
|
|
|
OpDescBuilder op_desc_builder(name, type);
|
|
|
|
if (io_equal_flag) {
|
|
|
|
op_desc_builder.AddInput("x")
|
|
|
|
.AddOutput("y");
|
|
|
|
} else {
|
|
|
|
op_desc_builder.AddInput("x1")
|
|
|
|
.AddInput("x2")
|
|
|
|
.AddOutput("y");
|
|
|
|
}
|
|
|
|
|
|
|
|
return op_desc_builder.Build();
|
|
|
|
}
|
|
|
|
|
|
|
|
///
|
|
|
|
/// @brief Build while-info
|
|
|
|
/// @param [in] for_info
|
|
|
|
/// @param [in] i_input
|
|
|
|
/// @param [in] range_input
|
|
|
|
/// @param [in] abs_delta_input
|
|
|
|
/// @param [out] while_info
|
|
|
|
/// @return void
|
|
|
|
///
|
|
|
|
void ForPass::BuildWhileInfo(const ForInfo &for_info, const OutDataAnchorPtr &i_input,
|
|
|
|
const OutDataAnchorPtr &range_input, const OutDataAnchorPtr &abs_delta_input,
|
|
|
|
WhileInfo &while_info) {
|
|
|
|
while_info.i = i_input;
|
|
|
|
while_info.abs_delta = abs_delta_input;
|
|
|
|
while_info.range = range_input;
|
|
|
|
while_info.start = for_info.start;
|
|
|
|
while_info.delta = for_info.delta;
|
|
|
|
while_info.for_body_name = for_info.body_name;
|
|
|
|
while_info.for_body = for_info.for_body;
|
|
|
|
while_info.data_inputs.emplace_back(while_info.i);
|
|
|
|
while_info.data_inputs.emplace_back(while_info.abs_delta);
|
|
|
|
while_info.data_inputs.emplace_back(while_info.range);
|
|
|
|
while_info.data_inputs.emplace_back(while_info.start);
|
|
|
|
while_info.data_inputs.emplace_back(while_info.delta);
|
|
|
|
for (auto &item : for_info.data_inputs) {
|
|
|
|
while_info.data_inputs.emplace_back(item);
|
|
|
|
}
|
|
|
|
for (auto &item : for_info.data_outputs) {
|
|
|
|
while_info.data_outputs.emplace_back(item);
|
|
|
|
}
|
|
|
|
for (auto &item : for_info.ctrl_inputs) {
|
|
|
|
while_info.ctrl_inputs.emplace_back(item);
|
|
|
|
}
|
|
|
|
for (auto &item : for_info.ctrl_outputs) {
|
|
|
|
while_info.ctrl_outputs.emplace_back(item);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
///
|
|
|
|
/// @brief Insert while_node
|
|
|
|
/// @param [in] graph
|
|
|
|
/// @param [in] name
|
|
|
|
/// @param [in&out] while_info
|
|
|
|
/// @return Status
|
|
|
|
///
|
|
|
|
Status ForPass::InsertWhileNode(const ComputeGraphPtr &graph, const std::string &name, WhileInfo &while_info) {
|
|
|
|
GELOGD("Begin to create while node, name:%s.", name.c_str());
|
|
|
|
|
|
|
|
size_t arg_num = while_info.data_inputs.size();
|
|
|
|
OpDescBuilder op_desc_builder(name, WHILE);
|
|
|
|
OpDescPtr op_desc = op_desc_builder.AddDynamicInput("input", arg_num).AddDynamicOutput("output", arg_num).Build();
|
|
|
|
if (op_desc == nullptr) {
|
|
|
|
GELOGE(FAILED, "Create while op_desc failed, name:%s.", name.c_str());
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
NodePtr while_node = graph->AddNode(op_desc);
|
|
|
|
if (while_node == nullptr) {
|
|
|
|
GELOGE(FAILED, "Create while node failed, name:%s.", name.c_str());
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
AddRePassNode(while_node);
|
|
|
|
|
|
|
|
while_info.while_node = while_node;
|
|
|
|
if (BuildWhileLink(while_info) != SUCCESS) {
|
|
|
|
GELOGE(FAILED, "Build while link-edge failed, name:%s.", name.c_str());
|
|
|
|
return FAILED;
|
|
|
|
}
|
|
|
|
|
|
|
|
GELOGD("Create while node succ, name:%s.", name.c_str());
|
|
|
|
return SUCCESS;
|
|
|
|
}
|
|
|
|
|
|
|
|
///
|
|
|
|
/// @brief Build while link-edge
|
|
|
|
/// @param [in] while_info
|
|
|
|
/// @return Status
|
|
|
|
///
|
|
|
|
Status ForPass::BuildWhileLink(const WhileInfo &while_info) {
|
|
|
|
NodePtr while_node = while_info.while_node;
|
|
|
|
GE_CHECK_NOTNULL(while_node);
|
|
|
|
|
|
|
|
size_t input_num = while_info.data_inputs.size();
|
|
|
|
for (size_t i = 0; i < input_num; i++) {
|
|
|
|
InDataAnchorPtr in_data_anchor = while_node->GetInDataAnchor(i);
|
|
|
|
GE_CHECK_NOTNULL(in_data_anchor);
|
|
|
|
OutDataAnchorPtr peer_out_anchor = while_info.data_inputs[i];
|
|
|
|
if (peer_out_anchor == nullptr) {
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(peer_out_anchor, in_data_anchor),
|
|
|
|
"Add data-edge %s:%d->%s:%d failed.",
|
|
|
|
peer_out_anchor->GetOwnerNode()->GetName().c_str(), peer_out_anchor->GetIdx(),
|
|
|
|
while_node->GetName().c_str(), i);
|
|
|
|
}
|
|
|
|
|
|
|
|
size_t output_num = while_info.data_outputs.size();
|
|
|
|
for (size_t i = 0; i < output_num; i++) {
|
|
|
|
OutDataAnchorPtr out_data_anchor = while_node->GetOutDataAnchor(static_cast<int>(i + kWhileOutputIndex));
|
|
|
|
GE_CHECK_NOTNULL(out_data_anchor);
|
|
|
|
for (auto &peer_in_anchor : while_info.data_outputs[i]) {
|
|
|
|
GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_data_anchor, peer_in_anchor),
|
|
|
|
"Add data-edge %s:%d->%s:%d failed.",
|
|
|
|
while_node->GetName().c_str(), i + kWhileOutputIndex,
|
|
|
|
peer_in_anchor->GetOwnerNode()->GetName().c_str(), peer_in_anchor->GetIdx());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
InControlAnchorPtr in_ctrl_anchor = while_node->GetInControlAnchor();
|
|
|
|
GE_CHECK_NOTNULL(in_ctrl_anchor);
|
|
|
|
for (auto &peer_out_anchor : while_info.ctrl_inputs) {
|
|
|
|
GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(peer_out_anchor, in_ctrl_anchor),
|
|
|
|
"Add ctrl-edge %s->%s failed.",
|
|
|
|
peer_out_anchor->GetOwnerNode()->GetName().c_str(),
|
|
|
|
in_ctrl_anchor->GetOwnerNode()->GetName().c_str());
|
|
|
|
}
|
|
|
|
|
|
|
|
OutControlAnchorPtr out_ctrl_anchor = while_node->GetOutControlAnchor();
|
|
|
|
GE_CHECK_NOTNULL(out_ctrl_anchor);
|
|
|
|
for (auto &peer_in_anchor : while_info.ctrl_outputs) {
|
|
|
|
GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_ctrl_anchor, peer_in_anchor),
|
|
|
|
"Add ctrl-edge %s->%s failed.",
|
|
|
|
out_ctrl_anchor->GetOwnerNode()->GetName().c_str(),
|
|
|
|
peer_in_anchor->GetOwnerNode()->GetName().c_str());
|
|
|
|
}
|
|
|
|
|
|
|
|
return SUCCESS;
|
|
|
|
}
|
|
|
|
|
|
|
|
///
|
|
|
|
/// @brief Build cond_graph for while_node
|
|
|
|
/// @param [in&out] while_info
|
|
|
|
/// @return ComputeGraphPtr
|
|
|
|
///
|
|
|
|
ComputeGraphPtr ForPass::BuildCondGraph(WhileInfo &while_info) {
|
|
|
|
std::string cond_name = while_info.for_body_name + "_Cond";
|
|
|
|
CompleteGraphBuilder graph_builder(cond_name);
|
|
|
|
|
|
|
|
// Add parent node
|
|
|
|
graph_builder.SetParentNode(while_info.while_node);
|
|
|
|
|
|
|
|
// Add Node
|
|
|
|
const std::string mul_name = "Mul";
|
|
|
|
graph_builder.AddNode(CreateOpDesc(mul_name, MUL, false));
|
|
|
|
const std::string less_name = "Less";
|
|
|
|
graph_builder.AddNode(CreateOpDesc(less_name, LESS, false));
|
|
|
|
|
|
|
|
// Set Input
|
|
|
|
graph_builder.SetInput(kWhileIInputIndex, { mul_name }, { 0 })
|
|
|
|
.SetInput(kWhileAbsDeltaInputIndex, { mul_name }, { 1 })
|
|
|
|
.SetInput(kWhileRangeInputIndex, { less_name }, { 1 })
|
|
|
|
.SetUselessInput(kWhileStartInputIndex)
|
|
|
|
.SetUselessInput(kWhileDeltaInputIndex);
|
|
|
|
size_t input_num = while_info.data_inputs.size();
|
|
|
|
for (size_t i = kWhileDataInputIndex; i < input_num; i++) {
|
|
|
|
graph_builder.SetUselessInput(i);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Add Output
|
|
|
|
graph_builder.AddOutput(less_name, 0);
|
|
|
|
|
|
|
|
// Add Edges
|
|
|
|
graph_builder.AddDataLink(mul_name, 0, less_name, 0);
|
|
|
|
|
|
|
|
// Add Input-Mapping
|
|
|
|
std::map<uint32_t, uint32_t> input_mapping;
|
|
|
|
for (size_t i = 0; i < input_num; i++) {
|
|
|
|
input_mapping[i] = i;
|
|
|
|
}
|
|
|
|
graph_builder.SetInputMapping(input_mapping);
|
|
|
|
|
|
|
|
graphStatus error_code = GRAPH_SUCCESS;
|
|
|
|
std::string error_msg;
|
|
|
|
ComputeGraphPtr cond_graph = graph_builder.Build(error_code, error_msg);
|
|
|
|
if (cond_graph == nullptr) {
|
|
|
|
GELOGE(FAILED, "Build cond_graph failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
size_t index = while_info.while_node->GetOpDesc()->GetSubgraphInstanceNames().size();
|
|
|
|
while_info.while_node->GetOpDesc()->AddSubgraphName(ATTR_NAME_WHILE_COND);
|
|
|
|
while_info.while_node->GetOpDesc()->SetSubgraphInstanceName(index, cond_name);
|
|
|
|
while_info.while_cond = cond_graph;
|
|
|
|
return cond_graph;
|
|
|
|
}
|
|
|
|
|
|
|
|
///
|
|
|
|
/// @brief Build body_graph for while_node
|
|
|
|
/// @param [in&out] while_info
|
|
|
|
/// @return ComputeGraphPtr
|
|
|
|
///
|
|
|
|
ComputeGraphPtr ForPass::BuildBodyGraph(WhileInfo &while_info) {
|
|
|
|
std::string body_name = while_info.for_body_name + "_Body";
|
|
|
|
CompleteGraphBuilder graph_builder(body_name);
|
|
|
|
|
|
|
|
// Add parent node
|
|
|
|
graph_builder.SetParentNode(while_info.while_node);
|
|
|
|
|
|
|
|
// Add calculation nodes
|
|
|
|
std::string const_name = "Const";
|
|
|
|
std::string add_name_0 = "Add_0";
|
|
|
|
std::string mul_name = "Mul";
|
|
|
|
std::string add_name_1 = "Add_1";
|
|
|
|
graph_builder.AddNode(CreateConstDesc(const_name, 1))
|
|
|
|
.AddNode(CreateOpDesc(add_name_0, ADD, false))
|
|
|
|
.AddNode(CreateOpDesc(mul_name, MUL, false))
|
|
|
|
.AddNode(CreateOpDesc(add_name_1, ADD, false));
|
|
|
|
|
|
|
|
// Add Subgraph node
|
|
|
|
auto input_num = static_cast<uint32_t>(while_info.data_inputs.size());
|
|
|
|
std::string sub_graph_node_name = while_info.for_body_name;
|
|
|
|
uint32_t sub_graph_input_num = input_num - kWhileDataInputIndex + kSubgraphInputIndex;
|
|
|
|
auto sub_graph_output_num = static_cast<uint32_t>(while_info.data_outputs.size());
|
|
|
|
graph_builder.AddNode(CreateSubgraphOpDesc(sub_graph_node_name, sub_graph_input_num, sub_graph_output_num));
|
|
|
|
|
|
|
|
// Set Input
|
|
|
|
graph_builder.SetInput(kWhileIInputIndex, { add_name_0, mul_name }, { 0, 0 })
|
|
|
|
.SetUselessInput(kWhileAbsDeltaInputIndex)
|
|
|
|
.SetUselessInput(kWhileRangeInputIndex)
|
|
|
|
.SetInput(kWhileStartInputIndex, { add_name_1 }, { 0 })
|
|
|
|
.SetInput(kWhileDeltaInputIndex, { mul_name }, { 1 });
|
|
|
|
for (uint32_t i = 0; i < input_num - kWhileDataInputIndex; i++) {
|
|
|
|
graph_builder.SetInput(i + kWhileDataInputIndex, { sub_graph_node_name }, { i + kSubgraphInputIndex });
|
|
|
|
}
|
|
|
|
|
|
|
|
// Add Outputs
|
|
|
|
graph_builder.AddOutput(add_name_0, 0);
|
|
|
|
for (uint32_t i = kWhileAbsDeltaInputIndex; i < kWhileDataInputIndex; i++) {
|
|
|
|
graph_builder.AddOutput("Data_" + std::to_string(i), 0);
|
|
|
|
}
|
|
|
|
for (uint32_t i = 0; i < sub_graph_output_num; i++) {
|
|
|
|
graph_builder.AddOutput(sub_graph_node_name, i);
|
|
|
|
}
|
|
|
|
|
|
|
|
// Add Edges
|
|
|
|
graph_builder.AddDataLink(const_name, 0, add_name_0, 1)
|
|
|
|
.AddDataLink(mul_name, 0, add_name_1, 1)
|
|
|
|
.AddDataLink(add_name_1, 0, sub_graph_node_name, kSubgraphLoopVarInputIndex);
|
|
|
|
|
|
|
|
// Add Input-Mapping
|
|
|
|
std::map<uint32_t, uint32_t> input_mapping;
|
|
|
|
for (size_t i = 0; i < input_num; i++) {
|
|
|
|
input_mapping[i] = i;
|
|
|
|
}
|
|
|
|
graph_builder.SetInputMapping(input_mapping);
|
|
|
|
|
|
|
|
// Add outputMapping
|
|
|
|
std::map<uint32_t, uint32_t> output_mapping;
|
|
|
|
for (size_t i = 0; i < sub_graph_output_num + kWhileOutputIndex; i++) {
|
|
|
|
output_mapping[i] = i;
|
|
|
|
}
|
|
|
|
graph_builder.SetOutputMapping(output_mapping);
|
|
|
|
|
|
|
|
graphStatus error_code = GRAPH_SUCCESS;
|
|
|
|
std::string error_msg;
|
|
|
|
ComputeGraphPtr body_graph = graph_builder.Build(error_code, error_msg);
|
|
|
|
if (body_graph == nullptr) {
|
|
|
|
GELOGE(FAILED, "Build body_graph failed: error_code:%u, error_msg:%s.", error_code, error_msg.c_str());
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
NodePtr sub_graph_node = graph_builder.GetNode(sub_graph_node_name);
|
|
|
|
if (sub_graph_node == nullptr) {
|
|
|
|
GELOGE(FAILED, "Get sub_graph_node failed: name:%s.", sub_graph_node_name.c_str());
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
while_info.sub_graph_node = sub_graph_node;
|
|
|
|
|
|
|
|
size_t index = while_info.while_node->GetOpDesc()->GetSubgraphInstanceNames().size();
|
|
|
|
while_info.while_node->GetOpDesc()->AddSubgraphName(ATTR_NAME_WHILE_BODY);
|
|
|
|
while_info.while_node->GetOpDesc()->SetSubgraphInstanceName(index, body_name);
|
|
|
|
while_info.while_body = body_graph;
|
|
|
|
return body_graph;
|
|
|
|
}
|
|
|
|
|
|
|
|
///
|
|
|
|
/// @brief Create op_desc for subgraph node
|
|
|
|
/// @param [in] name
|
|
|
|
/// @param [in] input_num
|
|
|
|
/// @param [in] output_num
|
|
|
|
/// @return OpDescPtr
|
|
|
|
///
|
|
|
|
OpDescPtr ForPass::CreateSubgraphOpDesc(const std::string &name, uint32_t input_num, uint32_t output_num) {
|
|
|
|
OpDescBuilder op_desc_builder(name, PARTITIONEDCALL);
|
|
|
|
op_desc_builder.AddDynamicInput("args", input_num)
|
|
|
|
.AddDynamicOutput("output", output_num);
|
|
|
|
|
|
|
|
OpDescPtr op_desc = op_desc_builder.Build();
|
|
|
|
if (op_desc == nullptr) {
|
|
|
|
GELOGE(FAILED, "Create op_desc for subgraph node failed, name:%s.", name.c_str());
|
|
|
|
return nullptr;
|
|
|
|
}
|
|
|
|
|
|
|
|
size_t index = op_desc->GetSubgraphInstanceNames().size();
|
|
|
|
op_desc->AddSubgraphName("f");
|
|
|
|
op_desc->SetSubgraphInstanceName(index, name);
|
|
|
|
return op_desc;
|
|
|
|
}
|
|
|
|
|
|
|
|
///
|
|
|
|
/// @brief Update InputMapping for for-body-graph
|
|
|
|
/// @param [in] while_info
|
|
|
|
/// @return Status
|
|
|
|
///
|
|
|
|
Status ForPass::UpdateForBodyInputMapping(const WhileInfo &while_info) {
|
|
|
|
ComputeGraphPtr for_body = while_info.for_body;
|
|
|
|
GE_CHECK_NOTNULL(for_body);
|
|
|
|
|
|
|
|
// index_of_cur_graph_node_input -> index_of_new_graph_node_input
|
|
|
|
std::map<uint32_t, uint32_t> input_mapping;
|
|
|
|
size_t input_num = while_info.data_inputs.size() - kWhileDataInputIndex + FOR_DATA_INPUT;
|
|
|
|
for (size_t i = 0; i < input_num; i++) {
|
|
|
|
if (i == FOR_START_INPUT) {
|
|
|
|
input_mapping[i] = i;
|
|
|
|
} else if ((i == FOR_LIMIT_INPUT) || (i == FOR_DELTA_INPUT)) {
|
|
|
|
continue;
|
|
|
|
} else {
|
|
|
|
input_mapping[i] = i - kIDiffValue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
for_body->UpdateInputMapping(input_mapping);
|
|
|
|
for_body->SetParentNode(while_info.sub_graph_node);
|
|
|
|
for_body->SetParentGraph(while_info.while_body);
|
|
|
|
|
|
|
|
return SUCCESS;
|
|
|
|
}
|
|
|
|
} // namespace ge
|
|
|
|
|