You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
graphengine/metadef/graph/utils/node_utils.cc

957 lines
33 KiB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "utils/node_utils.h"
#include "utils/op_desc_utils.h"
#include "graph/utils/graph_utils.h"
#include "debug/ge_op_types.h"
#include "debug/ge_util.h"
#include "framework/common/debug/ge_log.h"
#include "graph/anchor.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/types.h"
#include "utils/tensor_utils.h"
#include "utils/type_utils.h"
namespace ge {
// Definitions of the NodeUtils static maps: per-node lists of send / receive event ids.
// They are appended to by AddSendEventId / AddRecvEventId, read by GetSendEventIdList /
// GetRecvEventIdList and emptied by ClearSendInfo / ClearRecvInfo.
std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_send_info_{};
std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_recv_info_{};
// Groups of op type names. In the visible part of this file only kWhileOpTypes is read
// (by IsWhileVaryingInput); the other sets may be used elsewhere - TODO confirm.
const std::set<std::string> kConstOpTypes = {"Const", "Constant"};
const std::set<std::string> kIfOpTypes = {"If", "_If", "StatelessIf"};
const std::set<std::string> kWhileOpTypes = {"While", "_While", "StatelessWhile"};
const std::set<std::string> kCaseOpTypes = {"Case"};
const std::set<std::string> kForOpTypes = {"For"};
bool OpShapeIsUnknown(const OpDescPtr &desc) {
for (const auto &ptr : desc->GetAllInputsDescPtr()) {
auto ge_shape = ptr->GetShape();
for (const auto &dim : ge_shape.GetDims()) {
if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
return true;
}
}
}
for (const auto &ptr : desc->GetAllOutputsDescPtr()) {
auto ge_shape = ptr->GetShape();
for (const auto &dim : ge_shape.GetDims()) {
if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
return true;
}
}
}
return false;
}
///
/// @brief Record one more send event id for a node.
/// @param [in] node: key into map_send_info_; checked with GE_CHECK_NOTNULL
/// @param [in] event_id: id appended to the node's list (the list is created on first use)
/// @return GRAPH_SUCCESS once the id has been appended
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddSendEventId(const NodePtr &node,
                                                                                     const uint32_t &event_id) {
  GE_CHECK_NOTNULL(node);
  auto &send_event_ids = map_send_info_[node];
  send_event_ids.push_back(event_id);
  return GRAPH_SUCCESS;
}
///
/// @brief Record one more receive event id for a node.
/// @param [in] node: key into map_recv_info_; checked with GE_CHECK_NOTNULL
/// @param [in] event_id: id appended to the node's list (the list is created on first use)
/// @return GRAPH_SUCCESS once the id has been appended
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddRecvEventId(const NodePtr &node,
                                                                                     const uint32_t &event_id) {
  GE_CHECK_NOTNULL(node);
  auto &recv_event_ids = map_recv_info_[node];
  recv_event_ids.push_back(event_id);
  return GRAPH_SUCCESS;
}
///
/// @brief Copy the send event ids recorded for a node.
/// @param [in] node: node to look up in map_send_info_; checked with GE_CHECK_NOTNULL
/// @param [out] vec_send: overwritten with the recorded ids when the node has an entry
/// @return GRAPH_SUCCESS if the node has an entry, GRAPH_FAILED if it has none
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
NodeUtils::GetSendEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_send) {
  GE_CHECK_NOTNULL(node);
  const auto entry = map_send_info_.find(node);
  if (entry != map_send_info_.end()) {
    vec_send = entry->second;
    return GRAPH_SUCCESS;
  }
  return GRAPH_FAILED;
}
///
/// @brief Copy the receive event ids recorded for a node.
/// @param [in] node: node to look up in map_recv_info_; checked with GE_CHECK_NOTNULL
/// @param [out] vec_recv: overwritten with the recorded ids when the node has an entry
/// @return GRAPH_SUCCESS if the node has an entry, GRAPH_FAILED if it has none
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
NodeUtils::GetRecvEventIdList(const NodePtr &node, std::vector<uint32_t> &vec_recv) {
  GE_CHECK_NOTNULL(node);
  const auto entry = map_recv_info_.find(node);
  if (entry != map_recv_info_.end()) {
    vec_recv = entry->second;
    return GRAPH_SUCCESS;
  }
  return GRAPH_FAILED;
}
///
/// @brief Drop every recorded send event id, for all nodes.
/// @return always GRAPH_SUCCESS
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearSendInfo() {
map_send_info_.clear();
return GRAPH_SUCCESS;
}
///
/// @brief Drop every recorded receive event id, for all nodes.
/// @return always GRAPH_SUCCESS
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::ClearRecvInfo() {
map_recv_info_.clear();
return GRAPH_SUCCESS;
}
///
/// @brief Follow a chain of single data consumers `depth` layers down from `src`.
/// @param [in] src: starting node; checked with GE_CHECK_NOTNULL
/// @param [in] depth: number of layers to descend, must be >= 1
/// @param [out] dst: the node reached after `depth` hops; only written on success
/// @return GRAPH_SUCCESS on success; GRAPH_FAILED if depth < 1 or any node on the way
///         does not have exactly one out-data node
///
graphStatus NodeUtils::GetSingleOutputNodeOfNthLayer(const NodePtr &src, int depth, NodePtr &dst) {
  GE_CHECK_NOTNULL(src);
  if (depth < 1) {
    return GRAPH_FAILED;
  }
  // Advance from the node reached so far. Reading `src` on every iteration would
  // return the first-layer consumer regardless of depth.
  NodePtr cur_ptr = src;
  for (int i = 0; i < depth; i++) {
    auto out_data_nodes = cur_ptr->GetOutDataNodes();
    if (out_data_nodes.size() != 1) {
      return GRAPH_FAILED;
    }
    cur_ptr = out_data_nodes.at(0);
    GE_CHECK_NOTNULL(cur_ptr);
  }
  dst = cur_ptr;
  return GRAPH_SUCCESS;
}
///
/// @brief Find the first out-data anchor of a node that is linked to an in-control anchor.
/// @param [in] node_ptr: node whose out-data anchors are scanned in order; checked with GE_CHECK_NOTNULL
/// @param [out] out_data: the matching out-data anchor (untouched when nothing is found)
/// @param [out] in_control: the first non-null peer in-control anchor of out_data
/// @return GRAPH_SUCCESS when a pair is found, GRAPH_FAILED when no such link exists
///
graphStatus NodeUtils::GetDataOutAnchorAndControlInAnchor(const NodePtr &node_ptr, OutDataAnchorPtr &out_data,
                                                          InControlAnchorPtr &in_control) {
  GE_CHECK_NOTNULL(node_ptr);
  for (const auto &p : node_ptr->GetAllOutDataAnchors()) {
    GE_CHK_BOOL_EXEC((p != nullptr), continue, "GetAllOutDataAnchors is nullptr");
    for (const auto &p_in : p->GetPeerInControlAnchors()) {
      // The peers come from GetPeerInControlAnchors, so the log names that getter.
      GE_CHK_BOOL_EXEC((p_in != nullptr), continue, "GetPeerInControlAnchors is nullptr");
      out_data = p;
      in_control = p_in;
      return GRAPH_SUCCESS;
    }
  }
  return GRAPH_FAILED;
}
///
/// @brief Remove one in-data anchor from a node and renumber the anchors that followed it.
/// @param [in] node_ptr: node owning the anchor list
/// @param [in] in_data_anchor: anchor to remove (compared by pointer equality)
/// @return GRAPH_SUCCESS if the anchor was found and erased; GRAPH_FAILED if either
///         argument is null or the anchor is not in the node's list
///
graphStatus NodeUtils::ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor) {
  GE_CHK_BOOL_EXEC(node_ptr != nullptr && in_data_anchor != nullptr, return GRAPH_FAILED,
                   "node or in_data_anchor is nullptr");
  auto &anchors = node_ptr->in_data_anchors_;
  uint32_t position = 0;
  for (auto iter = anchors.begin(); iter != anchors.end(); ++iter, ++position) {
    if (*iter != in_data_anchor) {
      continue;
    }
    // erase() returns the element that slid into the freed slot; from there on every
    // anchor gets its new (one smaller) index, starting at the erased position.
    for (auto rest = anchors.erase(iter); rest != anchors.end(); ++rest, ++position) {
      (*rest)->SetIdx(position);
    }
    return GRAPH_SUCCESS;
  }
  return GRAPH_FAILED;
}
///
/// @brief Pointer overload: flag the node's anchor status as updated.
/// @param [in] node_ptr: node to update
/// @return GRAPH_SUCCESS on success; GRAPH_FAILED if node_ptr is null or the
///         reference overload reports a failure
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::SetAllAnchorStatus(const NodePtr &node_ptr) {
  GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "node is nullptr");
  const auto status = SetAllAnchorStatus(*node_ptr);
  GE_CHK_BOOL_EXEC(status == GRAPH_SUCCESS, return GRAPH_FAILED, "set all anchor status failed");
  return GRAPH_SUCCESS;
}
///
/// @brief Flag a node's anchor status as updated.
/// @param [in,out] node: node whose anchor_status_updated_ member is set to true
/// @return always GRAPH_SUCCESS
///
graphStatus NodeUtils::SetAllAnchorStatus(Node &node) {
node.anchor_status_updated_ = true;
return GRAPH_SUCCESS;
}
///
/// @brief Pointer overload: report whether the node's anchor status flag is set.
/// @param [in] node_ptr: node to query
/// @return false if node_ptr is null, otherwise the result of the reference overload
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool NodeUtils::IsAnchorStatusSet(const NodePtr &node_ptr) {
GE_CHK_BOOL_EXEC(node_ptr != nullptr, return false, "node is nullptr");
return IsAnchorStatusSet(*node_ptr);
}
///
/// @brief Report whether the node's anchor status flag is set.
/// @param [in] node: node to query
/// @return the value of node.anchor_status_updated_
///
bool NodeUtils::IsAnchorStatusSet(const Node &node) { return node.anchor_status_updated_; }
///
/// @brief Re-attach every outgoing edge of origin_node to the same-index output of new_node.
/// @param [in] origin_node: node that currently owns the outgoing edges
/// @param [in] new_node: node that takes the edges over; must expose the same number of
///                       out-data anchors as origin_node
/// @return GRAPH_SUCCESS when the move has been attempted; GRAPH_FAILED if a node is null
///         or the out-data anchor counts differ. A failing Unlink/LinkTo on a single edge
///         is only logged and that edge is skipped.
///
graphStatus NodeUtils::MoveOutputEdges(const NodePtr &origin_node, const NodePtr &new_node) {
  if ((origin_node == nullptr) || (new_node == nullptr)) {
    return GRAPH_FAILED;
  }
  auto old_data_anchors = origin_node->GetAllOutDataAnchors();
  auto replacement_data_anchors = new_node->GetAllOutDataAnchors();
  const size_t anchor_count = old_data_anchors.size();
  if (anchor_count != replacement_data_anchors.size()) {
    return GRAPH_FAILED;
  }
  for (size_t idx = 0; idx < anchor_count; ++idx) {
    const auto &old_anchor = old_data_anchors.at(idx);
    const auto &replacement_anchor = replacement_data_anchors.at(idx);
    // Data edges leaving this output.
    for (const auto &peer : old_anchor->GetPeerInDataAnchors()) {
      GE_CHK_BOOL_EXEC(old_anchor->Unlink(peer) == GRAPH_SUCCESS, continue, "unlink peer_anchor failed");
      GE_CHK_BOOL_EXEC(replacement_anchor->LinkTo(peer) == GRAPH_SUCCESS, continue, "linkto peer_anchor failed");
    }
    // Control edges leaving this data output.
    for (const auto &peer : old_anchor->GetPeerInControlAnchors()) {
      GE_CHK_BOOL_EXEC(old_anchor->Unlink(peer) == GRAPH_SUCCESS, continue, "unlink peer_anchor failed");
      GE_CHK_BOOL_EXEC(replacement_anchor->LinkTo(peer) == GRAPH_SUCCESS, continue, "linkto peer_anchor failed");
    }
  }
  auto origin_out_control_anchor = origin_node->GetOutControlAnchor();
  GE_CHECK_NOTNULL(origin_out_control_anchor);
  auto new_out_control_anchor = new_node->GetOutControlAnchor();
  GE_CHECK_NOTNULL(new_out_control_anchor);
  // For the control output the new links are created first and the old ones are
  // dropped in one go afterwards.
  for (const auto &peer : origin_out_control_anchor->GetPeerInControlAnchors()) {
    GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer) == GRAPH_SUCCESS, continue, "linkto peer_anchor failed");
  }
  for (const auto &peer : origin_out_control_anchor->GetPeerInDataAnchors()) {
    GE_CHK_BOOL_EXEC(new_out_control_anchor->LinkTo(peer) == GRAPH_SUCCESS, continue, "linkto peer_anchor failed");
  }
  origin_out_control_anchor->UnlinkAll();
  return GRAPH_SUCCESS;
}
///
/// @brief Tell whether a node is a constant op.
/// @param [in] node: node to classify
/// @return true if the node type equals CONSTANT or CONSTANTOP
///
bool NodeUtils::IsConst(const Node &node) {
  const auto node_type = node.GetType();
  return (node_type == CONSTANT) || (node_type == CONSTANTOP);
}
void NodeUtils::UpdateIsInputConst(const NodePtr &node_ptr) {
if (node_ptr == nullptr) {
GELOGE(GRAPH_FAILED, "node is null");
return;
}
UpdateIsInputConst(*node_ptr);
}
///
/// @brief Recompute the is_input_const flags of a node and store them in its op desc.
///        One flag is produced per in-data anchor: true only when the anchor is linked
///        to an output whose owner node is a constant op (see IsConst).
/// @param [in,out] node: node to refresh; if it has no op desc an error is logged and
///                       nothing is stored
/// @return void
///
void NodeUtils::UpdateIsInputConst(Node &node) {
  std::vector<bool> const_flags;
  const size_t input_count = node.GetAllInDataAnchors().size();
  for (size_t idx = 0; idx < input_count; ++idx) {
    bool fed_by_const = false;
    const auto in_anchor = node.GetInDataAnchor(static_cast<int>(idx));
    if (in_anchor != nullptr) {
      const auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
      if (peer_out_anchor != nullptr) {
        const auto producer = peer_out_anchor->GetOwnerNode();
        fed_by_const = (producer != nullptr) && IsConst(*producer);
      }
    }
    const_flags.push_back(fed_by_const);
  }
  const auto op_desc = node.GetOpDesc();
  if (op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "Node get opdesc is nullptr");
    return;
  }
  op_desc->SetIsInputConst(const_flags);
}
///
/// @brief Disconnect a node from all of its peers on every out anchor and every in anchor.
/// @param [in] node: node to isolate
/// @return void
///
void NodeUtils::UnlinkAll(const Node &node) {
  for (const auto &anchor : node.GetAllOutAnchors()) {
    // Null entries are skipped, as the other anchor walks in this file do.
    if (anchor != nullptr) {
      anchor->UnlinkAll();
    }
  }
  for (const auto &anchor : node.GetAllInAnchors()) {
    if (anchor != nullptr) {
      anchor->UnlinkAll();
    }
  }
}
///
/// @brief Push the output tensor descs of a node into the input descs of its data consumers.
///        For each out-data anchor the output desc first gets its real dim count, origin
///        shape and origin data type set from its current shape / data type. Then, for every
///        peer in-data anchor, origin shape, shape, data type, origin data type, shape range
///        and real dim count of the peer's input desc are overwritten with the output's values.
///        A dtype or shape disagreement found beforehand is only reported as a warning.
/// @param [in] node_ptr: producer node
/// @return GRAPH_SUCCESS on success, or immediately when the owner graph reports
///         GetGraphUnknownFlag(); GRAPH_FAILED if the node, its op desc or its owner graph is null
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeerNodeInputDesc(const NodePtr &node_ptr) {
  if (node_ptr == nullptr) {
    GELOGE(GRAPH_FAILED, "Nodeptr is nullptr");
    return GRAPH_FAILED;
  }
  auto op_desc = node_ptr->GetOpDesc();
  if (op_desc == nullptr) {
    return GRAPH_FAILED;
  }
  // A node that is not attached to a graph would otherwise be dereferenced through null.
  const auto owner_graph = node_ptr->GetOwnerComputeGraph();
  if (owner_graph == nullptr) {
    GELOGE(GRAPH_FAILED, "Owner graph of node [%s] is nullptr", node_ptr->GetName().c_str());
    return GRAPH_FAILED;
  }
  if (owner_graph->GetGraphUnknownFlag()) {
    return GRAPH_SUCCESS;
  }
  // Renders dims as "[d0 d1 ]" for the shape-mismatch warning.
  const auto dims_to_string = [](const std::vector<int64_t> &dims) -> std::string {
    std::string text = "[";
    for (const int64_t dim : dims) {
      text += std::to_string(dim) + " ";
    }
    text += "]";
    return text;
  };
  for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) {
    if (out_anchor == nullptr) {
      continue;
    }
    auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx());
    if (output_tensor == nullptr) {
      GELOGE(GRAPH_FAILED, "Output desc [%d] of node [%s] is nullptr", out_anchor->GetIdx(),
             node_ptr->GetName().c_str());
      continue;
    }
    auto out_dims = output_tensor->GetShape().GetDims();
    auto out_dtype = output_tensor->GetDataType();
    ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(out_dims.size()));
    output_tensor->SetOriginShape(output_tensor->GetShape());
    output_tensor->SetOriginDataType(out_dtype);
    GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s",
           node_ptr->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(),
           TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(),
           TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str());
    for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) {
      if (peer_anchor == nullptr) {
        continue;
      }
      const auto peer_node = peer_anchor->GetOwnerNode();
      if ((peer_node == nullptr) || (peer_node->GetOpDesc() == nullptr)) {
        GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null");
        continue;
      }
      const auto peer_op_desc = peer_node->GetOpDesc();
      auto peer_input_desc = peer_op_desc->MutableInputDesc(peer_anchor->GetIdx());
      if (peer_input_desc == nullptr) {
        GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr");
        continue;
      }
      // check shape and dtype continuity. do not stop process
      auto peer_input_dims = peer_input_desc->GetShape().GetDims();
      auto peer_input_dtype = peer_input_desc->GetDataType();
      if (out_dtype != peer_input_dtype) {
        GELOGW(
          "current node [%s] [%d]\'th out_dtype is [%s].peer input node [%s] [%d]\'th "
          "input_dtype is [%s].The two dtype should be same! Please check graph and fix it",
          node_ptr->GetName().c_str(), out_anchor->GetIdx(), TypeUtils::DataTypeToSerialString(out_dtype).c_str(),
          peer_node->GetName().c_str(), peer_anchor->GetIdx(),
          TypeUtils::DataTypeToSerialString(peer_input_dtype).c_str());
      } else if ((!peer_input_dims.empty()) && (out_dims != peer_input_dims)) {
        const std::string out_shape_str = dims_to_string(out_dims);
        const std::string peer_in_shape_str = dims_to_string(peer_input_dims);
        GELOGW(
          "current node [%s] [%d]\'th out_shape is [%s].peer input node [%s] [%d]\'th "
          "input_shape is [%s].The two shape should be same! Please check graph and fix it",
          node_ptr->GetName().c_str(), out_anchor->GetIdx(), out_shape_str.c_str(),
          peer_node->GetName().c_str(), peer_anchor->GetIdx(), peer_in_shape_str.c_str());
      }
      GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d",
             peer_op_desc->GetName().c_str(), output_tensor->GetShape().GetDimNum(),
             output_tensor->GetDataType(), output_tensor->GetOriginDataType());
      peer_input_desc->SetOriginShape(output_tensor->GetOriginShape());
      peer_input_desc->SetShape(output_tensor->GetShape());
      peer_input_desc->SetDataType(output_tensor->GetDataType());
      peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType());
      std::vector<std::pair<int64_t, int64_t>> shape_range;
      (void)output_tensor->GetShapeRange(shape_range);
      peer_input_desc->SetShapeRange(shape_range);
      ge::TensorUtils::SetRealDimCnt(*peer_input_desc, static_cast<uint32_t>(out_dims.size()));
      GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d",
             peer_op_desc->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(),
             peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType());
    }
  }
  return GRAPH_SUCCESS;
}
///
/// @brief Grow a node to `num` inputs by appending input descs and in-data anchors.
///        Every added input desc is a GeTensorDesc built from GeShape(), FORMAT_ND and DT_FLOAT.
///        Nothing is added when the op desc already has at least `num` inputs.
/// @param [in] node: node to extend
/// @param [in] num: wanted total number of inputs
/// @return GRAPH_SUCCESS on success; GRAPH_FAILED if the node or its op desc is null,
///         adding a desc fails, or an anchor cannot be allocated
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendInputAnchor(const NodePtr &node,
                                                                                        uint32_t num) {
  if (node == nullptr) {
    GELOGE(GRAPH_FAILED, "Input node is null");
    return GRAPH_FAILED;
  }
  GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT);
  const auto &op_desc = node->GetOpDesc();
  // Guard the op desc before it is dereferenced below.
  if (op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "Op desc of input node is null");
    return GRAPH_FAILED;
  }
  for (size_t i = op_desc->GetInputsSize(); i < num; ++i) {
    if (op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "Add input desc failed");
      return GRAPH_FAILED;
    }
    auto anchor = ComGraphMakeShared<InDataAnchor>(node, i);
    if (anchor == nullptr) {
      GELOGE(OUT_OF_MEMORY, "Current in data anchor is null, make shared_ptr failed.");
      return GRAPH_FAILED;
    }
    node->in_data_anchors_.push_back(anchor);
  }
  return GRAPH_SUCCESS;
}
///
/// @brief Shrink a node to at most `num` inputs.
///        Input descs are cleared at index `num` until only `num` remain, the input names are
///        refreshed, the is_input_const vector is resized to `num`, and trailing in-data
///        anchors are popped until at most `num` are left.
/// @param [in] node: node to shrink
/// @param [in] num: number of inputs to keep
/// @return GRAPH_SUCCESS on success; GRAPH_FAILED if the node or its op desc is null or
///         clearing an input desc fails
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveInputAnchor(const NodePtr &node,
                                                                                        uint32_t num) {
  if (node == nullptr) {
    GELOGE(GRAPH_FAILED, "Input node is null");
    return GRAPH_FAILED;
  }
  const auto &op_desc = node->GetOpDesc();
  // Guard the op desc before it is dereferenced below.
  if (op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "Op desc of input node is null");
    return GRAPH_FAILED;
  }
  while (op_desc->GetInputsSize() > num) {
    if (!OpDescUtils::ClearInputDesc(op_desc, num)) {
      return GRAPH_FAILED;
    }
  }
  auto input_names = op_desc->GetAllInputName();
  (void)op_desc->UpdateInputName(input_names);
  auto is_input_const = op_desc->GetIsInputConst();
  is_input_const.resize(num);
  op_desc->SetIsInputConst(is_input_const);
  while (node->in_data_anchors_.size() > num) {
    node->in_data_anchors_.pop_back();
  }
  return GRAPH_SUCCESS;
}
///
/// @brief Grow a node to `num` outputs by appending output descs and out-data anchors.
///        Every added output desc is a GeTensorDesc built from GeShape(), FORMAT_ND and DT_FLOAT.
///        Nothing is added when the op desc already has at least `num` outputs.
/// @param [in] node: node to extend
/// @param [in] num: wanted total number of outputs
/// @return GRAPH_SUCCESS on success; GRAPH_FAILED if the node or its op desc is null,
///         adding a desc fails, or an anchor cannot be allocated
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendOutputAnchor(const NodePtr &node,
                                                                                         uint32_t num) {
  if (node == nullptr) {
    GELOGE(GRAPH_FAILED, "Input node is null");
    return GRAPH_FAILED;
  }
  GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT);
  const OpDescPtr &op_desc = node->GetOpDesc();
  // Guard the op desc before it is dereferenced below.
  if (op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "Op desc of input node is null");
    return GRAPH_FAILED;
  }
  for (size_t i = op_desc->GetOutputsSize(); i < num; ++i) {
    if (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "Add output desc failed");
      return GRAPH_FAILED;
    }
    auto anchor = ComGraphMakeShared<OutDataAnchor>(node, i);
    if (anchor == nullptr) {
      GELOGE(OUT_OF_MEMORY, "Current out data anchor is null, make shared_ptr failed.");
      return GRAPH_FAILED;
    }
    node->out_data_anchors_.push_back(anchor);
  }
  return GRAPH_SUCCESS;
}
///
/// @brief Shrink a node to at most `num` outputs.
///        Output descs are cleared at index `num` until only `num` remain, the output names
///        are refreshed, and trailing out-data anchors are popped until at most `num` are left.
/// @param [in] node: node to shrink
/// @param [in] num: number of outputs to keep
/// @return GRAPH_SUCCESS on success; GRAPH_FAILED if the node or its op desc is null or
///         clearing an output desc fails
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveOutputAnchor(const NodePtr &node,
                                                                                         uint32_t num) {
  if (node == nullptr) {
    GELOGE(GRAPH_FAILED, "Input node is null");
    return GRAPH_FAILED;
  }
  const auto &op_desc = node->GetOpDesc();
  // Guard the op desc before it is dereferenced below.
  if (op_desc == nullptr) {
    GELOGE(GRAPH_FAILED, "Op desc of input node is null");
    return GRAPH_FAILED;
  }
  // NOTE(review): the names are captured BEFORE the descs are cleared, whereas
  // RemoveInputAnchor reads them afterwards - confirm this asymmetry is intended.
  auto output_names = op_desc->GetAllOutputName();
  while (op_desc->GetOutputsSize() > num) {
    if (!OpDescUtils::ClearOutputDesc(op_desc, num)) {
      return GRAPH_FAILED;
    }
  }
  (void)op_desc->UpdateOutputName(output_names);
  while (node->out_data_anchors_.size() > num) {
    node->out_data_anchors_.pop_back();
  }
  return GRAPH_SUCCESS;
}
///
/// @brief Tell whether a node has no upstream node at all, neither via data nor via control edges.
/// @param [in] node: node to inspect
/// @return false as soon as one in-data anchor or the in-control anchor is linked to an
///         anchor that has an owner node; true otherwise
///
bool NodeUtils::IsInNodesEmpty(const Node &node) {
  for (const auto &data_anchor : node.in_data_anchors_) {
    if (data_anchor == nullptr) {
      continue;
    }
    const auto peer_out = data_anchor->GetPeerOutAnchor();
    if ((peer_out != nullptr) && (peer_out->GetOwnerNode() != nullptr)) {
      return false;
    }
  }
  const auto &control_anchor = node.in_control_anchor_;
  if ((control_anchor == nullptr) || control_anchor->IsPeerOutAnchorsEmpty()) {
    return true;
  }
  for (const auto &peer_control : control_anchor->GetPeerOutControlAnchors()) {
    if ((peer_control != nullptr) && (peer_control->GetOwnerNode() != nullptr)) {
      return false;
    }
  }
  return true;
}
///
/// @brief Fetch a copy of a node's output tensor desc.
/// @param [in] node: node to read from
/// @param [in] index: output index forwarded to OpDesc::GetOutputDesc
/// @return the output desc, or a default-constructed GeTensorDesc if the node has no op desc
///
GeTensorDesc NodeUtils::GetOutputDesc(const Node &node, uint32_t index) {
  const auto op_desc = node.GetOpDesc();
  if (op_desc != nullptr) {
    return op_desc->GetOutputDesc(index);
  }
  return GeTensorDesc();
}
///
/// @brief Fetch a copy of a node's input tensor desc.
/// @param [in] node: node to read from
/// @param [in] index: input index forwarded to OpDesc::GetInputDesc
/// @return the input desc, or a default-constructed GeTensorDesc if the node has no op desc
///
GeTensorDesc NodeUtils::GetInputDesc(const Node &node, uint32_t index) {
  const auto op_desc = node.GetOpDesc();
  if (op_desc != nullptr) {
    return op_desc->GetInputDesc(index);
  }
  return GeTensorDesc();
}
///
/// @brief Overwrite the shape of one output tensor desc of a node.
/// @param [in] node: node whose op desc is modified
/// @param [in] index: output index
/// @param [in] shape: new shape
/// @return GRAPH_SUCCESS if the shape was set; GRAPH_PARAM_INVALID if the node has no
///         op desc or no mutable output desc at that index
///
graphStatus NodeUtils::UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape) {
  const auto op_desc = node.GetOpDesc();
  if (op_desc != nullptr) {
    const auto tensor_desc = op_desc->MutableOutputDesc(index);
    if (tensor_desc != nullptr) {
      tensor_desc->SetShape(shape);
      return GRAPH_SUCCESS;
    }
  }
  return GRAPH_PARAM_INVALID;
}
///
/// @brief Overwrite the shape of one input tensor desc of a node.
/// @param [in] node: node whose op desc is modified
/// @param [in] index: input index
/// @param [in] shape: new shape
/// @return GRAPH_SUCCESS if the shape was set; GRAPH_PARAM_INVALID if the node has no
///         op desc or no mutable input desc at that index
///
graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape) {
  const auto op_desc = node.GetOpDesc();
  if (op_desc != nullptr) {
    const auto tensor_desc = op_desc->MutableInputDesc(index);
    if (tensor_desc != nullptr) {
      tensor_desc->SetShape(shape);
      return GRAPH_SUCCESS;
    }
  }
  return GRAPH_PARAM_INVALID;
}
///
/// @brief Determine whether a node, or any node nested in its subgraphs, has an unknown shape.
///        The node's own tensor descs are checked first (OpShapeIsUnknown); only if they are
///        all known are the direct nodes of each subgraph instance examined recursively.
/// @param [in] node: node to examine
/// @param [out] is_unknow: set to true as soon as an unknown dim is found, false otherwise
/// @return GRAPH_SUCCESS when the answer could be computed; GRAPH_PARAM_INVALID if no root
///         graph is found; a failure status from a null check or a nested call otherwise
///
graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) {
  auto desc = node.GetOpDesc();
  GE_CHECK_NOTNULL(desc);
  // The node itself comes first.
  is_unknow = OpShapeIsUnknown(desc);
  if (is_unknow) {
    return GRAPH_SUCCESS;
  }
  const auto instance_names = desc->GetSubgraphInstanceNames();
  if (instance_names.empty()) {
    return GRAPH_SUCCESS;
  }
  auto owner_graph = node.GetOwnerComputeGraph();
  GE_CHECK_NOTNULL(owner_graph);
  const auto root_graph = GraphUtils::FindRootGraph(owner_graph);
  if (root_graph == nullptr) {
    GE_LOGE("Node %s gets null root graph", node.GetName().c_str());
    return GRAPH_PARAM_INVALID;
  }
  for (const auto &instance_name : instance_names) {
    auto sub_graph = root_graph->GetSubgraph(instance_name);
    GE_CHECK_NOTNULL(sub_graph);
    for (const auto &member : sub_graph->GetDirectNode()) {
      const auto member_status = GetNodeUnknownShapeStatus(*member, is_unknow);
      if (member_status != GRAPH_SUCCESS) {
        GE_LOGE("get node unknown shape status failed!");
        return member_status;
      }
      if (is_unknow) {
        return GRAPH_SUCCESS;
      }
    }
  }
  return GRAPH_SUCCESS;
}
///
/// @brief Return the effective type of a node.
/// @param [in] node: node to query
/// @return node.GetType() unless it equals FRAMEWORKOP; in that case the string read from
///         the ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE attribute of the op desc (the lookup result
///         is ignored, so the string stays as default-constructed if nothing is written)
///
std::string NodeUtils::GetNodeType(const Node &node) {
  std::string declared_type = node.GetType();
  if (declared_type == FRAMEWORKOP) {
    std::string original_type;
    (void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type);
    return original_type;
  }
  return declared_type;
}
///
/// @brief Pointer overload of GetNodeType.
/// @param [in] node: node to query, may be null
/// @return an empty string for a null node, otherwise GetNodeType(*node)
///
std::string NodeUtils::GetNodeType(const NodePtr &node) {
  if (node == nullptr) {
    return "";
  }
  return GetNodeType(*node);
}
///
/// @brief Pointer overload of GetInputConstData.
///        NOTE(review): the body is a stub - every argument is ignored and ge_tensor is left
///        untouched; confirm whether a real implementation is expected here.
/// @param [in] node_ptr: unused
/// @param [in] dst_name: unused
/// @param [out] ge_tensor: never written
/// @return always GRAPH_SUCCESS
///
graphStatus NodeUtils::GetInputConstData(const ConstNodePtr &node_ptr, const string &dst_name, GeTensorPtr &ge_tensor) {
return GRAPH_SUCCESS;
}
///
/// @brief Reference overload of GetInputConstData.
///        NOTE(review): the body is a stub - every argument is ignored and ge_tensor is left
///        untouched; confirm whether a real implementation is expected here.
/// @param [in] node: unused
/// @param [in] dst_name: unused
/// @param [out] ge_tensor: never written
/// @return always GRAPH_SUCCESS
///
graphStatus NodeUtils::GetInputConstData(const Node &node, const string &dst_name, GeTensorPtr &ge_tensor) {
return GRAPH_SUCCESS;
}
///
/// @brief Look up the subgraph attached to a node at a given subgraph slot.
/// @param [in] node: parent node
/// @param [in] index: subgraph slot, forwarded to OpDesc::GetSubgraphInstanceName
/// @return the subgraph registered under that instance name in the root graph, or nullptr
///         if the node has no op desc or no root graph is found
///
ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) {
  const auto desc = node.GetOpDesc();
  if (desc == nullptr) {
    return nullptr;
  }
  const auto root = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  if (root == nullptr) {
    return nullptr;
  }
  const auto &instance_name = desc->GetSubgraphInstanceName(index);
  return root->GetSubgraph(instance_name);
}
///
/// @brief Attach a subgraph to a node at a given subgraph slot.
///        The subgraph's name is stored as the instance name of the slot, the subgraph gets
///        `node` as parent node and the node's owner graph as parent graph, and it is finally
///        registered in the root graph.
/// @param [in,out] node: parent node
/// @param [in] index: subgraph slot
/// @param [in] subgraph: graph to attach
/// @return the result of AddSubgraph on the root graph; GRAPH_PARAM_INVALID if the subgraph,
///         the op desc or the root graph is null; the error of SetSubgraphInstanceName if
///         that call fails
///
graphStatus NodeUtils::SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph) {
  if (subgraph == nullptr) {
    GE_LOGE("Failed to set subgraph to node %s index %u, null subgraph", node.GetName().c_str(), index);
    return GRAPH_PARAM_INVALID;
  }
  const auto desc = node.GetOpDesc();
  if (desc == nullptr) {
    return GRAPH_PARAM_INVALID;
  }
  const auto root = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  if (root == nullptr) {
    GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str());
    return GRAPH_PARAM_INVALID;
  }
  const auto name_status = desc->SetSubgraphInstanceName(index, subgraph->GetName());
  if (name_status != GRAPH_SUCCESS) {
    GE_LOGE("Failed to set subgraph to node %s index %u", node.GetName().c_str(), index);
    return name_status;
  }
  // Wire the parent links only after the instance name has been accepted.
  subgraph->SetParentNode(node.shared_from_this());
  subgraph->SetParentGraph(node.GetOwnerComputeGraph());
  return root->AddSubgraph(subgraph);
}
///
/// @brief Check if node is input of subgraph
///        A node qualifies when its owner graph has a parent node and the node's op desc
///        carries ATTR_NAME_PARENT_NODE_INDEX. If the parent op desc has the
///        ATTR_NAME_IS_UNKNOWN_SHAPE attribute, the result is false when the parent graph
///        reports GetGraphUnknownFlag() or when the parent node's own graph has no parent node.
/// @param [in] node
/// @return bool: false for a null node or whenever a graph / op desc needed for the
///         decision is missing
///
bool NodeUtils::IsSubgraphInput(const NodePtr &node) {
  if ((node == nullptr) || (node->GetOpDesc() == nullptr)) {
    return false;
  }
  // Every link of the ownership chain is checked before it is dereferenced.
  auto owner_graph = node->GetOwnerComputeGraph();
  if (owner_graph == nullptr) {
    return false;
  }
  auto parent_node = owner_graph->GetParentNode();
  if (parent_node == nullptr) {
    return false;
  }
  auto parent_op_desc = parent_node->GetOpDesc();
  if (parent_op_desc == nullptr) {
    return false;
  }
  // dynamic shape unknown graph false
  // dynamic shape known graph with functional subgraph maybe true
  if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) {
    auto parent_graph = owner_graph->GetParentGraph();
    if ((parent_graph == nullptr) || parent_graph->GetGraphUnknownFlag()) {
      return false;
    }
    auto parent_owner_graph = parent_node->GetOwnerComputeGraph();
    if ((parent_owner_graph == nullptr) || (parent_owner_graph->GetParentNode() == nullptr)) {
      return false;
    }
  }
  return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX);
}
///
/// @brief Check if node is output of subgraph
///        A node qualifies when it is of type NETOUTPUT, its owner graph has a parent node
///        and at least one of its input tensor descs carries ATTR_NAME_PARENT_NODE_INDEX.
///        If the parent op desc has the ATTR_NAME_IS_UNKNOWN_SHAPE attribute, the result is
///        false when the parent graph reports GetGraphUnknownFlag() or when the parent
///        node's own graph has no parent node.
/// @param [in] node
/// @return bool: false for a null node or whenever a graph / op desc needed for the
///         decision is missing
///
bool NodeUtils::IsSubgraphOutput(const NodePtr &node) {
  if ((node == nullptr) || (node->GetOpDesc() == nullptr)) {
    return false;
  }
  // Every link of the ownership chain is checked before it is dereferenced.
  auto owner_graph = node->GetOwnerComputeGraph();
  if (owner_graph == nullptr) {
    return false;
  }
  auto parent_node = owner_graph->GetParentNode();
  if ((parent_node == nullptr) || (node->GetType() != NETOUTPUT)) {
    return false;
  }
  auto parent_op_desc = parent_node->GetOpDesc();
  if (parent_op_desc == nullptr) {
    return false;
  }
  if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) {
    auto parent_graph = owner_graph->GetParentGraph();
    if ((parent_graph == nullptr) || parent_graph->GetGraphUnknownFlag()) {
      return false;
    }
    auto parent_owner_graph = parent_node->GetOwnerComputeGraph();
    if ((parent_owner_graph == nullptr) || (parent_owner_graph->GetParentNode() == nullptr)) {
      return false;
    }
  }
  for (GeTensorDesc &tensor : node->GetOpDesc()->GetAllInputsDesc()) {
    if (AttrUtils::HasAttr(tensor, ATTR_NAME_PARENT_NODE_INDEX)) {
      return true;
    }
  }
  return false;
}
///
/// @brief Get subgraph original input node.
///        Reads ATTR_NAME_PARENT_NODE_INDEX from the node's op desc, takes the in-data
///        anchor with that index on the parent node of the owning graph and returns the
///        owner of the anchor linked to it.
/// @param [in] node
/// @return Node: the producer feeding the matching input of the parent node, or nullptr if
///         the attribute is absent or any link on the way is missing
///
NodePtr NodeUtils::GetParentInput(const Node &node) {
  uint32_t index_in_parent = 0;
  const bool has_parent_index = AttrUtils::GetInt(node.GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_in_parent);
  if (!has_parent_index) {
    return nullptr;
  }
  // Subgraph Data Node, check for constant input.
  const auto graph = node.GetOwnerComputeGraph();
  GE_CHECK_NOTNULL_EXEC(graph, return nullptr);
  const auto parent_node = graph->GetParentNode();
  GE_CHECK_NOTNULL_EXEC(parent_node, return nullptr);
  const auto in_anchor = parent_node->GetInDataAnchor(index_in_parent);
  GE_CHECK_NOTNULL_EXEC(in_anchor, return nullptr);
  const auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
  GE_CHECK_NOTNULL_EXEC(peer_out_anchor, return nullptr);
  return peer_out_anchor->GetOwnerNode();
}
///
/// @brief Pointer overload of GetParentInput.
/// @param [in] node: subgraph node, may be null
/// @return the (null) argument itself for a null node, otherwise GetParentInput(*node)
///
NodePtr NodeUtils::GetParentInput(const NodePtr &node) {
  if (node == nullptr) {
    return node;
  }
  return GetParentInput(*node);
}
///
/// @brief Get is dynamic shape graph from node.
/// @param [in] node
/// @return bool: the ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED attribute read from the root graph
///         of the node's owner graph; false if no root graph is found (the flag also stays
///         false when the lookup does not write it)
///
bool NodeUtils::IsDynamicShape(const Node &node) {
  bool dynamic_shape_flag = false;
  const auto root = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph());
  if (root != nullptr) {
    (void)AttrUtils::GetBool(root, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, dynamic_shape_flag);
  }
  return dynamic_shape_flag;
}
///
/// @brief Pointer overload of IsDynamicShape.
/// @param [in] node: node to query, may be null
/// @return false for a null node, otherwise IsDynamicShape(*node)
///
bool NodeUtils::IsDynamicShape(const NodePtr &node) {
  if (node == nullptr) {
    return false;
  }
  return IsDynamicShape(*node);
}
///
/// @brief Check is varying_input for while node
///        The node must be a DATA node of a subgraph whose parent node has a type listed in
///        kWhileOpTypes and must carry ATTR_NAME_PARENT_NODE_INDEX. It counts as NOT varying
///        only if it feeds a NETOUTPUT input whose tensor desc has the same
///        ATTR_NAME_PARENT_NODE_INDEX value.
/// @param [in] node: Data node for subgraph
/// @return bool: true if the input is varying; false otherwise, including for a null node,
///         a node without owner graph, or a node missing the parent-index attribute
///
bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) {
  if (node == nullptr) {
    return false;
  }
  if (node->GetType() != DATA) {
    return false;  // not input_node for subgraph
  }
  const auto owner_graph = node->GetOwnerComputeGraph();
  if (owner_graph == nullptr) {
    return false;  // node is not attached to any graph
  }
  const NodePtr parent_node = owner_graph->GetParentNode();
  if (parent_node == nullptr) {
    return false;  // root graph
  }
  if (kWhileOpTypes.count(parent_node->GetType()) == 0) {
    return false;  // not input_node for while subgraph
  }
  uint32_t index_i = 0;
  if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) {
    GELOGW("Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str());
    return false;
  }
  for (const auto &item : node->GetOutDataNodesAndAnchors()) {
    // Skip incomplete entries instead of dereferencing a null consumer or anchor.
    if ((item.first == nullptr) || (item.second == nullptr)) {
      continue;
    }
    if (item.first->GetType() != NETOUTPUT) {
      continue;
    }
    OpDescPtr op_desc = item.first->GetOpDesc();
    uint32_t index_o = 0;
    if ((op_desc == nullptr) ||
        !AttrUtils::GetInt(op_desc->GetInputDesc(item.second->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index_o)) {
      continue;  // input for while-cond subgraph
    }
    if (index_i == index_o) {
      return false;  // passed through unchanged to the same parent index
    }
    // otherwise: varying input for while-body subgraph, keep looking
  }
  return true;
}
///
/// @brief Get subgraph input is constant.
/// @param [in] node
/// @param [out] string
/// @return bool
///
bool NodeUtils::GetConstOpType(const NodePtr &node, std::string &type) {
if (node == nullptr) {
return false;
}
if ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) {
type = node->GetType();
return true;
}
if (node->GetType() != DATA) {
return false; // not subgraph input node
}
const auto &parent = GetParentInput(node);
return GetConstOpType(parent, type);
}
///
/// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph.
///        The instance names are detached from the node's op desc, every subgraph reachable
///        from them is collected breadth-first, and the collected set is then removed from
///        the root graph.
/// @param [in] node
/// @return return GRAPH_SUCCESS if remove successfully, other for failed.
///
Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  auto op_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(op_desc);
  // `auto` gives a by-value snapshot. Keep it that way: the loop below calls
  // RemoveSubgraphInstanceName on the same op_desc, which presumably edits the live list.
  auto direct_names = op_desc->GetSubgraphInstanceNames();
  if (direct_names.empty()) {
    return GRAPH_SUCCESS;
  }
  auto owner_graph = node->GetOwnerComputeGraph();
  GE_CHECK_NOTNULL(owner_graph);
  auto root_graph = GraphUtils::FindRootGraph(owner_graph);
  GE_CHECK_NOTNULL(root_graph);
  std::unordered_set<std::string> subgraph_to_remove;
  for (const auto &direct_name : direct_names) {
    std::deque<std::string> pending(1, direct_name);
    subgraph_to_remove.insert(direct_name);
    op_desc->RemoveSubgraphInstanceName(direct_name);
    while (!pending.empty()) {
      const auto current_name = pending.front();
      pending.pop_front();
      auto subgraph = root_graph->GetSubgraph(current_name);
      GE_CHECK_NOTNULL(subgraph);
      for (const auto &inner_node : subgraph->GetDirectNode()) {
        auto sub_op_desc = inner_node->GetOpDesc();
        GE_CHECK_NOTNULL(sub_op_desc);
        // Subgraph and all nodes in it will be removed later,
        // no need to remove 'SubgraphInstanceName' in op desc here.
        for (const auto &nested_name : sub_op_desc->GetSubgraphInstanceNames()) {
          // insert().second is true only for names not seen before, so each
          // subgraph is queued once.
          if (subgraph_to_remove.insert(nested_name).second) {
            pending.push_back(nested_name);
          }
        }
      }
    }
  }
  // Remove subgraph from root_graph
  for (const auto &doomed_name : subgraph_to_remove) {
    GELOGI("Remove subgraph:%s.", doomed_name.c_str());
    root_graph->RemoveSubgraph(doomed_name);
  }
  return GRAPH_SUCCESS;
}
///
/// @brief Get subgraph input data node by index.
///        Collects, over all subgraph instances of `node`, the nodes accepted by
///        IsSubgraphInput whose ATTR_NAME_PARENT_NODE_INDEX equals `index`.
/// @param [in] node: parent node owning the subgraphs
/// @param [in] index: input index on the parent node
/// @return Node: the matching nodes; empty if the node has no op desc, no subgraph, no owner
///         graph, or nothing matches. Subgraph names that cannot be resolved are skipped.
///
vector<NodePtr> NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, int index) {
  vector<NodePtr> in_data_node_vec;
  auto op_desc = node.GetOpDesc();
  GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec);
  auto subgraph_names = op_desc->GetSubgraphInstanceNames();
  if (subgraph_names.empty()) {
    GELOGW("Node %s is single node without sub graph.", node.GetName().c_str());
    return in_data_node_vec;
  }
  auto compute_graph = node.GetOwnerComputeGraph();
  if (compute_graph == nullptr) {
    GELOGW("Node %s has no owner graph.", node.GetName().c_str());
    return in_data_node_vec;
  }
  for (const std::string &instance_name : subgraph_names) {
    auto subgraph = compute_graph->GetSubgraph(instance_name);
    if (subgraph == nullptr) {
      // An unresolved name is skipped instead of being dereferenced.
      GELOGW("Subgraph %s of node %s is not found.", instance_name.c_str(), node.GetName().c_str());
      continue;
    }
    for (const auto &node_in_subgraph : subgraph->GetDirectNode()) {
      if (!NodeUtils::IsSubgraphInput(node_in_subgraph)) {
        continue;
      }
      int parent_index = -1;
      (void)AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index);
      if (parent_index == index) {
        in_data_node_vec.emplace_back(node_in_subgraph);
      }
    }
  }
  return in_data_node_vec;
}
///
/// @brief Get the output nodes of all subgraphs attached to a node.
///        Collects, over all subgraph instances of `node`, the nodes accepted by IsSubgraphOutput.
/// @param [in] node: parent node owning the subgraphs
/// @return Node: the matching nodes; empty if the node has no op desc, no subgraph, no owner
///         graph, or nothing matches. Subgraph names that cannot be resolved are skipped.
///
vector<NodePtr> NodeUtils::GetSubgraphOutputNodes(const Node &node) {
  vector<NodePtr> out_data_node_vec;
  auto op_desc = node.GetOpDesc();
  GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec);
  auto subgraph_names = op_desc->GetSubgraphInstanceNames();
  if (subgraph_names.empty()) {
    GELOGI("Node %s is single node without sub graph.", node.GetName().c_str());
    return out_data_node_vec;
  }
  auto compute_graph = node.GetOwnerComputeGraph();
  if (compute_graph == nullptr) {
    GELOGW("Node %s has no owner graph.", node.GetName().c_str());
    return out_data_node_vec;
  }
  for (const std::string &instance_name : subgraph_names) {
    auto subgraph = compute_graph->GetSubgraph(instance_name);
    if (subgraph == nullptr) {
      // An unresolved name is skipped instead of being dereferenced.
      GELOGW("Subgraph %s of node %s is not found.", instance_name.c_str(), node.GetName().c_str());
      continue;
    }
    for (const auto &node_in_subgraph : subgraph->GetDirectNode()) {
      if (NodeUtils::IsSubgraphOutput(node_in_subgraph)) {
        out_data_node_vec.emplace_back(node_in_subgraph);
      }
    }
  }
  return out_data_node_vec;
}
///
/// @brief Get the node that feeds a given data input of a node.
/// @param [in] node: consumer node
/// @param [in] index: index of the in-data anchor
/// @return the owner node of the linked out anchor, or nullptr if the in-data anchor does
///         not exist or is not linked
///
NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, const int index) {
  const auto in_anchor = node.GetInDataAnchor(index);
  if (in_anchor == nullptr) {
    return nullptr;
  }
  const auto peer_out = in_anchor->GetPeerOutAnchor();
  if (peer_out == nullptr) {
    return nullptr;
  }
  return peer_out->GetOwnerNode();
}
///
/// @brief List the consumers of one data output together with the in-data anchor they use.
/// @param [in] node: producer node
/// @param [in] index: index of the out-data anchor
/// @return one (peer in-data anchor, its owner node) pair per linked consumer; peers that are
///         null or have no owner node are left out; empty if the out-data anchor does not exist
///
vector<pair<InDataAnchorPtr, NodePtr>> NodeUtils::GetOutDataNodesWithAnchorByIndex(const Node &node, const int index) {
  vector<pair<InDataAnchorPtr, NodePtr>> consumers;
  const auto out_anchor = node.GetOutDataAnchor(index);
  if (out_anchor != nullptr) {
    for (const auto &peer_in : out_anchor->GetPeerInDataAnchors()) {
      if (peer_in == nullptr) {
        continue;
      }
      const auto consumer = peer_in->GetOwnerNode();
      if (consumer != nullptr) {
        consumers.emplace_back(peer_in, consumer);
      }
    }
  }
  return consumers;
}
///
/// @brief Get the node held by an Operator.
/// @param [in] oprt: operator to query
/// @return whatever oprt.GetNode() returns
///
ConstNodePtr NodeUtils::GetNodeFromOperator(const Operator &oprt) { return oprt.GetNode(); }
} // namespace ge