// graphengine/ge/graph/passes/variable_prepare_op_pass.cc

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "graph/passes/variable_prepare_op_pass.h"
#include <map>
#include <memory>
#include <string>
#include "common/ge/ge_util.h"
#include "external/graph/graph.h"
#include "framework/common/debug/ge_log.h"
#include "graph/common/omg_util.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/node.h"
#include "graph/utils/tensor_utils.h"
namespace ge {
// Ref input -> ref output table consulted by GetWritableNodeOutIndex for FRAMEWORKOP nodes,
// keyed by the node's original type. The inner map goes from an input index to the list of
// output indexes paired with that input.
// Only REFSWITCH is listed: its input 0 is paired with outputs 0 and 1.
std::map<std::string, std::map<int, std::vector<int>>> VariablePrepareOpPass::ref_node_without_prototype_map_ = {
{REFSWITCH, {{0, {0, 1}}}}};
// Entry point of the pass.
//
// Step 1: for every direct node of `graph` whose type has no entry in ref_input_output_map_ yet,
//         let GenerateRefTypeAndInputOutputMap record its ref input/output pairs.
// Step 2: if nothing has been recorded there is nothing to do; otherwise hand every node whose
//         op desc type is VARIABLE to DealVariableNode.
//
// Returns SUCCESS when finished (or when there was nothing to do) and FAILED as soon as
// DealVariableNode reports an error for one variable.
Status VariablePrepareOpPass::Run(ComputeGraphPtr graph) {
  GE_CHECK_NOTNULL(graph);

  for (const auto &direct_node : graph->GetDirectNode()) {
    if (ref_input_output_map_.count(direct_node->GetType()) == 0) {
      GenerateRefTypeAndInputOutputMap(direct_node);
    }
  }

  if (ref_input_output_map_.empty()) {
    GELOGI("No need to add variable_ref.");
    return SUCCESS;
  }

  for (auto &direct_node : graph->GetDirectNode()) {
    const auto op_desc = direct_node->GetOpDesc();
    // Nodes without an op desc and non-variable nodes are skipped.
    if (op_desc == nullptr || op_desc->GetType() != VARIABLE) {
      continue;
    }
    const Status deal_status = DealVariableNode(direct_node);
    if (deal_status != SUCCESS) {
      GELOGE(deal_status, "variable add back edge failed");
      return FAILED;
    }
  }
  return SUCCESS;
}
// Handles one variable node: for every data consumer of `var_node`, looks up which outputs of
// the consumer are paired with the consumed input (GetWritableNodeOutIndex) and calls
// DealWritableNode once per such output.
//
// Returns SUCCESS if every pair was handled, FAILED on the first DealWritableNode error.
// A null `var_node`, consumer node or consumer anchor is rejected by GE_CHECK_NOTNULL.
Status VariablePrepareOpPass::DealVariableNode(NodePtr &var_node) {
  GE_CHECK_NOTNULL(var_node);
  for (auto &consumer : var_node->GetOutDataNodesAndAnchors()) {
    NodePtr consumer_node = consumer.first;
    GE_CHECK_NOTNULL(consumer_node);
    InDataAnchorPtr consumer_in_anchor = consumer.second;
    GE_CHECK_NOTNULL(consumer_in_anchor);

    const auto consumed_input = consumer_in_anchor->GetIdx();
    std::vector<int> paired_outputs;
    GetWritableNodeOutIndex(consumer_node, consumed_input, paired_outputs);

    // An empty list means the consumer does not write through this input; the loop is a no-op then.
    for (const auto paired_output : paired_outputs) {
      if (DealWritableNode(consumer_node, consumed_input, paired_output, var_node) != SUCCESS) {
        GELOGE(FAILED, "Deal writable node[%s] failed, input index: %d, var: %s.", consumer_node->GetName().c_str(),
               consumed_input, var_node->GetName().c_str());
        return FAILED;
      }
    }
  }
  return SUCCESS;
}
// Follows the chain of ref nodes that starts at (`writable_node`, `input_index`, `output_index`)
// and attaches variable ref nodes for `var_node` along the way.
//
// The walk uses an explicit stack of {node, {ref input index, ref output index}} entries.
// For each popped entry GetPeerNodeOfRefOutput pushes the ref consumers of the entry's output:
//  - nothing pushed: the entry is a last ref node.
//      * FRAMEWORKOP nodes are skipped.
//      * if the ref output index is a real output, a variable ref is added after it;
//      * otherwise (ref input without corresponding output) RefIdentity and variable ref are
//        inserted before the ref input.
//  - something pushed: the entry is an intermediate ref node; a variable ref is added after its
//    ref output only if the node has a control output.
//
// Returns SUCCESS when the whole chain was processed, FAILED if any step fails.
Status VariablePrepareOpPass::DealWritableNode(const ge::NodePtr &writable_node, int input_index, int output_index,
                                               const ge::NodePtr &var_node) {
  std::stack<std::pair<NodePtr, std::pair<int, int>>> pending;
  pending.push({writable_node, {input_index, output_index}});

  while (!pending.empty()) {
    const auto entry = pending.top();
    pending.pop();
    auto cur_node = entry.first;
    const int cur_input_index = entry.second.first;
    const int cur_output_index = entry.second.second;

    // Remember the stack depth so we can tell whether any ref consumer was pushed.
    const auto depth_before = pending.size();
    if (GetPeerNodeOfRefOutput(cur_node, cur_output_index, pending) != SUCCESS) {
      GELOGE(FAILED, "GetPeerNodeOfRefOutput for node[%s] failed.", cur_node->GetName().c_str());
      return FAILED;
    }

    if (pending.size() != depth_before) {
      // Intermediate ref node: only a node with control output gets its own variable ref.
      if (HasControlOut(cur_node)) {
        if (AddVariableRef(cur_node, var_node, cur_output_index) != SUCCESS) {
          GELOGE(FAILED, "Add variable ref failed");
          return FAILED;
        }
      }
      continue;
    }

    // Last ref node of this chain.
    const auto &op_desc = cur_node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);
    if (op_desc->GetType() == FRAMEWORKOP) {
      GELOGD("No need to add variable_ref for frameworkop");
      continue;
    }

    const bool has_ref_output = static_cast<uint32_t>(cur_output_index) < op_desc->GetOutputsSize();
    if (has_ref_output) {
      // Add variable ref node after ref output for final ref node.
      if (AddVariableRef(cur_node, var_node, cur_output_index) != SUCCESS) {
        GELOGE(FAILED, "Add variable ref failed");
        return FAILED;
      }
    } else {
      // Insert variable ref node before ref input without corresponding ref output.
      if (InsertVariableRef(cur_node, cur_input_index, var_node) != SUCCESS) {
        GELOGE(FAILED, "Insert variable ref and ref identity failed");
        return FAILED;
      }
    }
  }
  return SUCCESS;
}
// Pushes onto `nodes` every consumer of output `output_index` of `node` that treats the consumed
// input as a ref input. One entry {peer node, {peer input index, peer ref output index}} is
// pushed per paired ref output reported by GetWritableNodeOutIndex.
//
// Returns:
//  - an error status via GE_CHECK_NOTNULL if `node`, its op desc or the out data anchor is null;
//  - PARAM_INVALID if `output_index` is negative;
//  - SUCCESS otherwise. Nothing is pushed when `output_index` equals the op desc's output count
//    (the value GenerateRefTypeAndInputOutputMap records for a ref input that has no
//    corresponding output) or when it is beyond the node's out data anchors (warning only).
Status VariablePrepareOpPass::GetPeerNodeOfRefOutput(const ge::NodePtr &node, int output_index,
                                                     std::stack<pair<NodePtr, pair<int, int>>> &nodes) {
  // `node` is dereferenced by the very first log statement below, so validate it up front.
  GE_CHECK_NOTNULL(node);
  if (output_index < 0) {
    GELOGE(PARAM_INVALID, "Invalid ref output index: %s-%d.", node->GetName().c_str(), output_index);
    return PARAM_INVALID;
  }
  const auto &op_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(op_desc);
  if (static_cast<uint32_t>(output_index) == op_desc->GetOutputsSize()) {
    return SUCCESS;
  }
  if (output_index >= static_cast<int>(node->GetAllOutDataAnchorsSize())) {
    GELOGW("Can not get %d th output anchor of %s", output_index, node->GetName().c_str());
    return SUCCESS;
  }
  const auto &out_anchor = node->GetOutDataAnchor(output_index);
  GE_CHECK_NOTNULL(out_anchor);
  for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
    // Skip dangling peers instead of dereferencing them, in the same way null peer nodes are skipped.
    if (peer_in_anchor == nullptr) {
      continue;
    }
    auto peer_node = peer_in_anchor->GetOwnerNode();
    if (peer_node == nullptr) {
      continue;
    }
    const int peer_in_index = peer_in_anchor->GetIdx();
    vector<int> ref_output_indexes;
    GetWritableNodeOutIndex(peer_node, peer_in_index, ref_output_indexes);
    for (auto ref_output_index : ref_output_indexes) {
      nodes.push({peer_node, {peer_in_index, ref_output_index}});
    }
  }
  return SUCCESS;
}
// Attaches a variable ref node for `var_node` to data output `index` of `final_writable_node`.
//
// Steps:
//  1. If `index` is beyond the node's out data anchors, only a warning is logged and SUCCESS is
//     returned.
//  2. If a VARIABLE peer whose REF_VAR_SRC_VAR_NAME equals var_node's name is already connected
//     to that output, nothing is created and SUCCESS is returned.
//  3. Otherwise a variable ref node named "<var>_TO_<node>_REF_<index>" is created, receives the
//     stream label of `final_writable_node` (CheckStreamLabel), gets control edges towards the
//     current peers of `final_writable_node` (AddControlEdge) and finally a data edge from the
//     output anchor.
//
// Returns SUCCESS on success or when nothing had to be done, FAILED if label/edge handling
// fails; null arguments, anchors, peer nodes or op descs are rejected by GE_CHECK_NOTNULL.
Status VariablePrepareOpPass::AddVariableRef(ge::NodePtr &final_writable_node, const ge::NodePtr &var_node, int index) {
  GE_CHECK_NOTNULL(final_writable_node);
  GE_CHECK_NOTNULL(var_node);
  if (index >= static_cast<int>(final_writable_node->GetAllOutDataAnchorsSize())) {
    GELOGW("Can not get %d th output anchor of %s", index, final_writable_node->GetName().c_str());
    return SUCCESS;
  }
  // Check for duplicate creation
  OutDataAnchorPtr out_anchor = final_writable_node->GetOutDataAnchor(index);
  GE_CHECK_NOTNULL(out_anchor);
  for (const auto &peer_anchor : out_anchor->GetPeerAnchors()) {
    // Validate the peer before dereferencing it; AddControlEdge applies the same checks to the
    // same anchors further down.
    GE_CHECK_NOTNULL(peer_anchor);
    NodePtr peer_node = peer_anchor->GetOwnerNode();
    GE_CHECK_NOTNULL(peer_node);
    OpDescPtr peer_opdesc = peer_node->GetOpDesc();
    GE_CHECK_NOTNULL(peer_opdesc);
    string src_var_name;
    (void)ge::AttrUtils::GetStr(peer_opdesc, REF_VAR_SRC_VAR_NAME, src_var_name);
    if (peer_node->GetType() == VARIABLE && var_node->GetName() == src_var_name) {
      GELOGI("The corresponding variable_ref has been added to this connection.");
      return SUCCESS;
    }
  }
  // Create variable_ref
  std::stringstream variable_ref_name;
  variable_ref_name << "_TO_" << final_writable_node->GetName() << "_REF_" << index;
  NodePtr variable_ref_node = CreateVariableRef(var_node->GetName() + variable_ref_name.str(), var_node);
  GE_CHECK_NOTNULL(variable_ref_node);
  Status ret_check = CheckStreamLabel(variable_ref_node, final_writable_node);
  if (ret_check != SUCCESS) {
    GELOGE(FAILED, "check stream lable failed");
    return FAILED;
  }
  GELOGI("Add variable_ref between [%s] and [%s]", var_node->GetName().c_str(), variable_ref_node->GetName().c_str());
  // add control anchor between variable_ref and final peer node
  // variable_ref_node need to execute before other nodes
  // This must happen before the data edge below is added: AddControlEdge walks the peers of
  // final_writable_node's outputs, and variable_ref_node must not be among them yet.
  CHECK_FALSE_EXEC(AddControlEdge(final_writable_node, variable_ref_node) == SUCCESS,
                   GELOGE(FAILED, "Add control edges between variable ref node and output nodes of ref node failed");
                   return FAILED);
  graphStatus ret = ge::GraphUtils::AddEdge(out_anchor, variable_ref_node->GetInDataAnchor(0));
  if (ret != GRAPH_SUCCESS) {
    GELOGE(FAILED, "add data anchor between variable_ref and final_writable peer node failed");
    return FAILED;
  }
  return SUCCESS;
}
// Handles a ref input of `node` that has no corresponding ref output: splits the edge feeding
// input `in_index` with a RefIdentity node and hangs a variable ref node for `var_node` off it.
//
// Resulting connections (producer = current peer of the input anchor):
//   producer --data--> RefIdentity --data--> node (input `in_index`)
//                      RefIdentity --data--> variable_ref
//   node --control--> variable_ref
//   variable_ref --control--> every current peer of node's outputs (AddControlEdge)
//
// The variable ref node also receives node's stream label (CheckStreamLabel).
//
// Returns SUCCESS when all nodes and edges were created, FAILED if any graph edit fails; null
// arguments, anchors or created nodes are rejected by GE_CHECK_NOTNULL.
Status VariablePrepareOpPass::InsertVariableRef(ge::NodePtr &node, int in_index, const ge::NodePtr &var_node) {
  GE_CHECK_NOTNULL(node);
  GE_CHECK_NOTNULL(var_node);
  // Check connection between two nodes
  const auto in_anchor = node->GetInDataAnchor(in_index);
  GE_CHECK_NOTNULL(in_anchor);
  auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
  GE_CHECK_NOTNULL(peer_out_anchor);
  auto peer_in_node = peer_out_anchor->GetOwnerNode();
  GE_CHECK_NOTNULL(peer_in_node);
  // Create ref_identity
  std::stringstream ref_identity_name;
  ref_identity_name << "RefIdentity_" << peer_in_node->GetName() << "_" << peer_out_anchor->GetIdx() << "_TO_"
                    << node->GetName() << "_" << in_index;
  NodePtr ref_identity_node = CreateRefIdentity(ref_identity_name.str(), node, static_cast<uint32_t>(in_index));
  GE_CHECK_NOTNULL(ref_identity_node);
  // Create variable_ref
  std::stringstream variable_ref_name;
  variable_ref_name << "_TO_" << node->GetName() << "_REF_" << in_index;
  NodePtr variable_ref_node = CreateVariableRef(var_node->GetName() + variable_ref_name.str(), var_node);
  GE_CHECK_NOTNULL(variable_ref_node);
  Status ret_check = CheckStreamLabel(variable_ref_node, node);
  if (ret_check != SUCCESS) {
    GELOGE(FAILED, "check stream lable failed");
    return FAILED;
  }
  GELOGI("Insert variable_ref of [%s] between [%s] and [%s]", var_node->GetName().c_str(),
         peer_in_node->GetName().c_str(), node->GetName().c_str());
  // add control anchor between variable_ref and node
  // variable_ref_node need to execute before other nodes
  CHECK_FALSE_EXEC(AddControlEdge(node, variable_ref_node) == SUCCESS,
                   GELOGE(FAILED, "Add control edges between variable ref node and output nodes of ref node failed");
                   return FAILED);
  // Insert variable ref node between two nodes and remove the original edge.
  CHECK_FALSE_EXEC(ge::GraphUtils::RemoveEdge(peer_out_anchor, in_anchor) == SUCCESS,
                   GELOGE(FAILED, "Remove edge between ref node and its peer node failed");
                   return FAILED);
  CHECK_FALSE_EXEC(ge::GraphUtils::AddEdge(peer_out_anchor, ref_identity_node->GetInDataAnchor(0)) == SUCCESS,
                   GELOGE(FAILED, "Add data edge between pre node and ref_identity failed");
                   return FAILED);
  CHECK_FALSE_EXEC(ge::GraphUtils::AddEdge(ref_identity_node->GetOutDataAnchor(0), in_anchor) == SUCCESS,
                   GELOGE(FAILED, "Add data edge between ref_identity and ref node failed");
                   return FAILED);
  // Add edge from ref identity node to variable ref node.
  CHECK_FALSE_EXEC(
    ge::GraphUtils::AddEdge(ref_identity_node->GetOutDataAnchor(0), variable_ref_node->GetInDataAnchor(0)) == SUCCESS,
    GELOGE(FAILED, "Add data edge between ref_identity and variable_ref failed");
    return FAILED);
  // Control edge from the ref node itself (not from ref_identity) to the variable ref node; the
  // failure message names the nodes actually involved.
  CHECK_FALSE_EXEC(
    ge::GraphUtils::AddEdge(node->GetOutControlAnchor(), variable_ref_node->GetInControlAnchor()) == SUCCESS,
    GELOGE(FAILED, "Add control edge between ref node and variable_ref failed");
    return FAILED);
  return SUCCESS;
}
// Adds a control edge from `variable_ref_node` to the owner node of every anchor currently
// connected to any out anchor of `node`.
//
// Returns SUCCESS when every edge was added, FAILED on the first edge that cannot be added;
// null anchors or peer nodes are rejected by GE_CHECK_NOTNULL.
Status VariablePrepareOpPass::AddControlEdge(const ge::NodePtr &node, const ge::NodePtr &variable_ref_node) {
  const auto src_anchors = node->GetAllOutAnchors();
  for (const auto &src_anchor : src_anchors) {
    GE_CHECK_NOTNULL(src_anchor);
    for (const auto &dst_anchor : src_anchor->GetPeerAnchors()) {
      GE_CHECK_NOTNULL(dst_anchor);
      NodePtr consumer = dst_anchor->GetOwnerNode();
      GE_CHECK_NOTNULL(consumer);
      if (ge::GraphUtils::AddEdge(variable_ref_node->GetOutControlAnchor(), consumer->GetInControlAnchor()) !=
          SUCCESS) {
        GELOGE(FAILED, "Add control edge between variable_ref and ref node's peer node failed");
        return FAILED;
      }
    }
  }
  return SUCCESS;
}
// Creates a REFIDENTITY node named `ref_identity_name` in the graph that owns `node`.
// Both the single input desc and the single output desc of the new node are copies of
// input desc `input_index` of `node`.
//
// Returns the new node, or nullptr if `node`, its op desc or its owner graph is null, or if any
// step of building/adding the node fails.
ge::NodePtr VariablePrepareOpPass::CreateRefIdentity(const std::string &ref_identity_name, const ge::NodePtr &node,
                                                     uint32_t input_index) {
  if (node == nullptr) {
    GELOGE(FAILED, "node is nullptr");
    return nullptr;
  }
  OpDescPtr op_desc = node->GetOpDesc();
  if (op_desc == nullptr) {
    GELOGE(FAILED, "opdesc is nullptr");
    return nullptr;
  }
  // The owner graph is dereferenced below to add the new node, so make sure it exists first.
  const auto owner_graph = node->GetOwnerComputeGraph();
  if (owner_graph == nullptr) {
    GELOGE(FAILED, "owner graph of node[%s] is nullptr", node->GetName().c_str());
    return nullptr;
  }
  OpDescPtr ref_identity_op_desc = MakeShared<OpDesc>(ref_identity_name.c_str(), REFIDENTITY);
  if (ref_identity_op_desc == nullptr) {
    GELOGE(FAILED, "ref_identity op desc is nullptr");
    return nullptr;
  }
  GE_IF_BOOL_EXEC(ref_identity_op_desc->AddOutputDesc(op_desc->GetInputDesc(input_index)) != SUCCESS,
                  GELOGW("add output desc edge failed");
                  return nullptr);
  GE_IF_BOOL_EXEC(ref_identity_op_desc->AddInputDesc(op_desc->GetInputDesc(input_index)) != SUCCESS,
                  GELOGW("add input desc edge failed");
                  return nullptr);
  NodePtr ref_identity_node = owner_graph->AddNode(ref_identity_op_desc);
  GE_IF_BOOL_EXEC(ref_identity_node == nullptr, GELOGW("ref_identity_node is null"); return nullptr);
  return ref_identity_node;
}
// Creates the variable ref node for `var_node` in the graph that owns `var_node`.
// The new node is named `variable_ref_name`, has the same op type as `var_node`, uses output
// desc 0 of `var_node` for both its input and its output, and carries the attribute
// REF_VAR_SRC_VAR_NAME set to the name of var_node's op desc.
//
// Returns the new node, or nullptr if `var_node`, its op desc or its owner graph is null, or if
// any step of building/adding the node fails. A failure to set REF_VAR_SRC_VAR_NAME is only
// logged: the node has already been added to the graph at that point.
ge::NodePtr VariablePrepareOpPass::CreateVariableRef(const std::string &variable_ref_name,
                                                     const ge::NodePtr &var_node) {
  if (var_node == nullptr) {
    GELOGE(FAILED, "var node is nullptr");
    return nullptr;
  }
  OpDescPtr var_op_desc = var_node->GetOpDesc();
  if (var_op_desc == nullptr) {
    GELOGE(FAILED, "get var opdesc is nullptr");
    return nullptr;
  }
  // The owner graph is dereferenced below to add the new node, so make sure it exists first.
  const auto owner_graph = var_node->GetOwnerComputeGraph();
  if (owner_graph == nullptr) {
    GELOGE(FAILED, "owner graph of var node[%s] is nullptr", var_node->GetName().c_str());
    return nullptr;
  }
  OpDescPtr var_ref_op_desc = MakeShared<OpDesc>(variable_ref_name.c_str(), var_op_desc->GetType());
  if (var_ref_op_desc == nullptr) {
    GELOGE(FAILED, "var_ref opdesc is nullptr");
    return nullptr;
  }
  GE_IF_BOOL_EXEC(var_ref_op_desc->AddOutputDesc(var_op_desc->GetOutputDesc(0)) != SUCCESS,
                  GELOGW("add output desc edge failed");
                  return nullptr);
  GE_IF_BOOL_EXEC(var_ref_op_desc->AddInputDesc(var_op_desc->GetOutputDesc(0)) != SUCCESS,
                  GELOGW("add input desc edge failed");
                  return nullptr);
  NodePtr variable_ref_node = owner_graph->AddNode(var_ref_op_desc);
  GE_IF_BOOL_EXEC(variable_ref_node == nullptr, GELOGW("variable_ref_node is null"); return nullptr);
  bool is_set_str = ge::AttrUtils::SetStr(var_ref_op_desc, REF_VAR_SRC_VAR_NAME, var_op_desc->GetName());
  if (is_set_str) {
    GELOGD("Set node [%s] REF_VAR_SRC_VAR_NAME [%s]", variable_ref_node->GetName().c_str(),
           var_op_desc->GetName().c_str());
  } else {
    // AddVariableRef recognises existing variable refs through this attribute, so make the
    // failure visible instead of dropping it silently.
    GELOGW("Set node [%s] REF_VAR_SRC_VAR_NAME [%s] failed", variable_ref_node->GetName().c_str(),
           var_op_desc->GetName().c_str());
  }
  return variable_ref_node;
}
void VariablePrepareOpPass::GetWritableNodeOutIndex(const NodePtr &node, int input_index,
std::vector<int> &output_indexes) {
if (node == nullptr) {
return;
}
GELOGD("get writable node and input index %s:%d", node->GetName().c_str(), input_index);
auto node_type = node->GetType();
if (node_type == FRAMEWORKOP) {
std::string original_type;
GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS, GELOGW("Get node original type fail"));
GELOGD("find frameworkop: [%s], original type is %s", node->GetName().c_str(), original_type.c_str());
FindRefOutIndex(original_type, input_index, ref_node_without_prototype_map_, output_indexes);
return;
}
FindRefOutIndex(node_type, input_index, ref_input_output_map_, output_indexes);
return;
}
// Records in ref_input_output_map_, under the type of `node`, which inputs are ref inputs and
// which output index each one is paired with:
//  - an input whose name is also the name of an output is paired with that output;
//  - otherwise, an input whose desc has a non-empty ref port index list is paired with the
//    op desc's output count, i.e. an index one past the last output, standing for
//    "ref input without corresponding output".
// A node without op desc only triggers a warning.
void VariablePrepareOpPass::GenerateRefTypeAndInputOutputMap(const NodePtr &node) {
  const auto op_desc = node->GetOpDesc();
  if (op_desc == nullptr) {
    GELOGW("op_desc in null, please check node:[%s]", node->GetName().c_str());
    return;
  }

  for (const auto &input_entry : op_desc->GetAllInputName()) {
    const auto &input_name = input_entry.first;
    const auto input_position = input_entry.second;

    const int same_name_output = op_desc->GetOutputIndexByName(input_name);
    if (same_name_output != -1) {
      ref_input_output_map_[node->GetType()][input_position] = {same_name_output};
    } else if (!op_desc->GetInputDesc(input_position).GetRefPortIndex().empty()) {
      ref_input_output_map_[node->GetType()][input_position] = {static_cast<int>(op_desc->GetOutputsSize())};
    }
  }
}
// Looks up `ref_map[node_type][input_index]` and appends every output index stored there to
// `output_indexes`, in stored order. Leaves `output_indexes` untouched when either the type or
// the input index has no entry.
void VariablePrepareOpPass::FindRefOutIndex(const std::string &node_type, int input_index,
                                            const std::map<std::string, std::map<int, vector<int>>> &ref_map,
                                            std::vector<int> &output_indexes) {
  const auto type_entry = ref_map.find(node_type);
  if (type_entry != ref_map.end()) {
    const auto &outputs_by_input = type_entry->second;
    const auto input_entry = outputs_by_input.find(input_index);
    if (input_entry != outputs_by_input.end()) {
      const auto &paired_outputs = input_entry->second;
      output_indexes.insert(output_indexes.end(), paired_outputs.begin(), paired_outputs.end());
    }
  }
}
// Copies the stream label of `final_writable_node` onto `var_ref_node`.
//
// Solves the problem that the writable node is not in the same stream as the subsequent node,
// which causes the stream to not trigger properly. The label of node should be handled uniformly.
//
// Returns SUCCESS if the writable node has no (or an empty) ATTR_NAME_STREAM_LABEL or the label
// was set; otherwise the failing status of SetStreamLabel. A null op desc on the writable node
// is rejected by GE_CHECK_NOTNULL.
Status VariablePrepareOpPass::CheckStreamLabel(const ge::NodePtr &var_ref_node,
                                               const ge::NodePtr &final_writable_node) {
  OpDescPtr source_desc = final_writable_node->GetOpDesc();
  GE_CHECK_NOTNULL(source_desc);

  std::string label;
  (void)AttrUtils::GetStr(source_desc, ATTR_NAME_STREAM_LABEL, label);
  if (label.empty()) {
    return SUCCESS;
  }
  GE_CHK_STATUS_RET(SetStreamLabel(var_ref_node, label), "set stream label failed");
  return SUCCESS;
}
// Tells whether `node` has at least one outgoing control edge to an existing node.
//
// Returns true as soon as one peer of the node's out control anchor is non-null and has a
// non-null owner node; returns false otherwise, including when `node` or its out control anchor
// is null.
bool VariablePrepareOpPass::HasControlOut(const ge::NodePtr &node) {
  if (node == nullptr) {
    return false;
  }
  const auto &out_control_anchor = node->GetOutControlAnchor();
  // Without an out control anchor there cannot be any control output.
  if (out_control_anchor == nullptr) {
    return false;
  }
  for (const auto &peer_in_control_anchor : out_control_anchor->GetPeerInControlAnchors()) {
    if (peer_in_control_anchor == nullptr || peer_in_control_anchor->GetOwnerNode() == nullptr) {
      continue;
    }
    return true;
  }
  return false;
}
} // namespace ge