You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
770 lines
31 KiB
770 lines
31 KiB
/**
|
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
#include "common/ge/ge_util.h"
|
|
#include "graph/common/omg_util.h"
|
|
#include "graph/debug/ge_attr_define.h"
|
|
#include "graph/optimize/graph_optimize.h"
|
|
#include "graph/utils/graph_utils.h"
|
|
#include "graph/utils/node_utils.h"
|
|
|
|
namespace {
|
|
using namespace ge;
|
|
// Anchor index used for both the data input and the data output of an Identity node.
constexpr int kIdentityAnchorIndex = 0;
// Number of entries in the rw_type name tables used for log printing.
constexpr size_t kSerialStringVecSize = 4;

// Values matched by the switch in GetInputRwTypeInConflict.
constexpr int kCaseReadOnly = 0;
constexpr int kCaseScopeWriteable = 2;
constexpr int kCaseWriteable = 3;
constexpr int kCaseInvalidRWType = 5;
|
|
|
|
// Read/write behaviour of a node's data input.
// The numeric values are spelled out because the enumerators are cast to int and used as table indices.
enum class InputRWType {
  kReadOnly = 0,        // Normal op input, only read
  kWriteable = 1,       // Op like Assign/ApplyMomentum
  kScopeWriteable = 2,  // Op like hcom_allreduce: it modifies its input, but the change is not expected to
                        // take effect on the previous output
  kInvalidRWType = 3
};
|
|
// Read/write behaviour of a node's data output.
// The numeric values are spelled out because the enumerators are cast to int and used as table indices.
enum class OutputRWType {
  kReadOnly = 0,   // 1. const output  2. not a ref output, but consumed by several peer inputs
  kSoftRead = 1,   // not a ref output and consumed by only one node
  kWriteable = 2,  // ref output. Like Assign/ApplyMomentum
  kInvalidRWType = 3
};
|
|
|
|
// input and output rw_type of one node. key is anchor_idx, value is rw_type
|
|
struct NodeInputOutputRWType {
|
|
map<uint32_t, InputRWType> input_rw_type_map;
|
|
map<uint32_t, OutputRWType> output_rw_type_map;
|
|
};
|
|
// input and output rw_type of node in current graph
|
|
thread_local map<string, NodeInputOutputRWType> node_rwtype_map_;
|
|
|
|
///
|
|
/// @brief Convert input rw_type enum to string. For log print.
|
|
/// @param rw_type
|
|
/// @return rw_type_name
|
|
///
|
|
static std::string InputRWTypeToSerialString(InputRWType rw_type) {
|
|
const static char *names[kSerialStringVecSize] = {"ReadOnly", "Writeable", "ScopeWriteable", "InvalidRWType"};
|
|
return names[static_cast<int>(rw_type)];
|
|
}
|
|
|
|
///
|
|
/// @brief Convert output rw_type enum to string. For log print.
|
|
/// @param rw_type
|
|
/// @return rw_type_name
|
|
///
|
|
static std::string OutputRWTypeToSerialString(OutputRWType rw_type) {
|
|
const static char *names[kSerialStringVecSize] = {"ReadOnly", "SoftRead", "Writeable", "InvalidRWType"};
|
|
return names[static_cast<int>(rw_type)];
|
|
}
|
|
|
|
///
/// @brief Get the rw_type of one output of a node, looking only at the node itself (no sub graph handling).
/// @param node
/// @param index index of the output anchor
/// @return kWriteable for VARIABLE, for an output that shares its name with an input (ref output) and for a
///         framework op whose original type is RefSwitch; kReadOnly for constants and for outputs with more
///         than one peer input node; kSoftRead otherwise; kInvalidRWType when the op desc or the out data
///         anchor is missing
///
OutputRWType GetSingleNodeOutputRWTypeByIndex(const Node &node, uint32_t index) {
  auto desc = node.GetOpDesc();
  if (desc == nullptr) {
    return OutputRWType::kInvalidRWType;
  }
  if (desc->GetType() == VARIABLE) {
    return OutputRWType::kWriteable;
  }

  // A ref output carries the same name as one of the inputs.
  auto all_input_names = desc->GetAllInputName();
  for (auto &name_and_idx : all_input_names) {
    if (desc->GetOutputNameByIndex(index) == name_and_idx.first) {
      return OutputRWType::kWriteable;
    }
  }

  // RefSwitch wrapped in a framework op.
  std::string original_type;
  const bool is_ref_switch = (node.GetType() == FRAMEWORK_OP_TYPE) &&
                             AttrUtils::GetStr(desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type) &&
                             (original_type == REFSWITCH);
  if (is_ref_switch) {
    return OutputRWType::kWriteable;
  }

  const bool is_const = (desc->GetType() == CONSTANT) || (desc->GetType() == CONSTANTOP);
  if (is_const) {
    return OutputRWType::kReadOnly;
  }

  auto out_anchor = node.GetOutDataAnchor(index);
  if (out_anchor == nullptr) {
    return OutputRWType::kInvalidRWType;
  }
  // Several consumers share the output, so nobody may write it; a single consumer makes it a soft read.
  return (out_anchor->GetPeerInDataNodesSize() > 1) ? OutputRWType::kReadOnly : OutputRWType::kSoftRead;
}
|
|
|
|
///
|
|
/// @brief Get input rw_type of one node with sub graph. It will return rw_type after solve conflict scene.
|
|
/// @param rw_type_set
|
|
/// @return
|
|
///
|
|
InputRWType GetInputRwTypeInConflict(const std::set<int> &rw_type_set) {
|
|
// for input rw type calc
|
|
int total_rw_type = 0;
|
|
for (const auto rw : rw_type_set) {
|
|
total_rw_type += rw;
|
|
}
|
|
|
|
switch (total_rw_type) {
|
|
case kCaseReadOnly:
|
|
return InputRWType::kReadOnly; // all input rw type is readonly
|
|
case kCaseScopeWriteable:
|
|
return InputRWType::kScopeWriteable; // readonly 2 scope_writeable
|
|
case kCaseWriteable:
|
|
return InputRWType::kWriteable; // all input rw type is writeable or readonly 2 writeable
|
|
case kCaseInvalidRWType:
|
|
return InputRWType::kInvalidRWType; // writeable 2 scope_writeable
|
|
default:
|
|
return InputRWType::kInvalidRWType;
|
|
}
|
|
}
|
|
|
|
bool IsSubgraphInputNode(const NodePtr &node) {
|
|
if ((node == nullptr) || (node->GetOpDesc() == nullptr) || (node->GetType() != DATA) ||
|
|
(node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) {
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
bool IsSubgraphOutputNode(const NodePtr &node) {
|
|
if ((node == nullptr) || (node->GetOpDesc() == nullptr) || (node->GetType() != NETOUTPUT) ||
|
|
(node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) {
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
///
/// @brief Create an Identity node whose input "x" and output "y" both copy the desc of one output of
///        src_node, and add it to the graph that owns src_node. No edge is linked here.
/// @param src_node node whose output desc is copied
/// @param out_anchor_idx index of that output
/// @return the node added to the graph, or nullptr on any failure
///
NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) {
  if (src_node.GetOpDesc() == nullptr) {
    return nullptr;
  }

  // 1. create new identity op desc; the process-wide counter keeps the generated names distinct
  static std::atomic_long identity_num(0);
  const auto serial_no = identity_num.fetch_add(1);
  string identity_name = src_node.GetName() + "_" + IDENTITY + std::to_string(serial_no);
  auto identity_opdesc = MakeShared<OpDesc>(identity_name, IDENTITY);
  if (identity_opdesc == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Failed to insert identity node, name %s", identity_name.c_str());
    return nullptr;
  }

  // 2. add input_desc & output_desc for new identity
  auto data_desc = src_node.GetOpDesc()->GetOutputDesc(out_anchor_idx);
  Status add_ret = identity_opdesc->AddInputDesc("x", data_desc);
  if (add_ret != SUCCESS) {
    GELOGE(add_ret, "Add Input desc failed for new identity %s.", identity_name.c_str());
    return nullptr;
  }
  add_ret = identity_opdesc->AddOutputDesc("y", data_desc);
  if (add_ret != SUCCESS) {
    GELOGE(add_ret, "Add Output desc failed for new Identity %s.", identity_name.c_str());
    return nullptr;
  }
  GELOGI("Insert new Identity node %s.", identity_name.c_str());

  // 3. register the node in the graph that owns src_node
  auto owner_graph = src_node.GetOwnerComputeGraph();
  if (owner_graph == nullptr) {
    GELOGE(GRAPH_PARAM_INVALID, "Node %s owner compute graph is null.", src_node.GetName().c_str());
    return nullptr;
  }
  return owner_graph->AddNode(identity_opdesc);
}
|
|
|
|
///
/// @brief Get the rw_type of one output of a node, taking sub graphs into account.
///        WHILE is always kSoftRead. A node without sub graph is delegated to
///        GetSingleNodeOutputRWTypeByIndex. For a node with sub graph the type recorded for the sub graph
///        output node in node_rwtype_map_ is used when there is exactly one such node, kSoftRead otherwise;
///        an output with more than one peer input node is then reported as kReadOnly.
/// @param node
/// @param index index of the output anchor
/// @return rw_type, or kInvalidRWType when the op desc, the map record or the out data anchor is missing
///
OutputRWType GetOutputRWTypeByIndex(const Node &node, uint32_t index) {
  auto op_desc = node.GetOpDesc();
  if (op_desc == nullptr) {
    return OutputRWType::kInvalidRWType;
  }
  if (op_desc->GetType() == WHILE) {
    return OutputRWType::kSoftRead;
  }
  if (op_desc->GetSubgraphInstanceNames().empty()) {
    // single node without sub graph
    return GetSingleNodeOutputRWTypeByIndex(node, index);
  }

  // node with sub graph
  auto subgraph_outputs = NodeUtils::GetSubgraphOutputNodes(node);
  auto subgraph_rw_type = OutputRWType::kSoftRead;
  if (subgraph_outputs.size() == 1) {
    // find rw type from map.
    const auto &net_output = subgraph_outputs.front();
    auto node_record = node_rwtype_map_.find(net_output->GetName());
    if (node_record == node_rwtype_map_.end()) {
      GELOGW("Can not find rw type of node %s from map.It could take some effect on following preprocess.",
             net_output->GetName().c_str());
      return OutputRWType::kInvalidRWType;
    }
    const auto &anchor_rw_types = node_record->second.output_rw_type_map;
    auto anchor_record = anchor_rw_types.find(index);
    if (anchor_record == anchor_rw_types.end()) {
      GELOGW("Can not find rw type of node %s from map.It could take some effect on following preprocess.",
             net_output->GetName().c_str());
      return OutputRWType::kInvalidRWType;
    }
    subgraph_rw_type = anchor_record->second;
  }

  // check peer input
  auto out_anchor = node.GetOutDataAnchor(index);
  if (out_anchor == nullptr) {
    return OutputRWType::kInvalidRWType;
  }
  return (out_anchor->GetPeerInDataNodesSize() > 1) ? OutputRWType::kReadOnly : subgraph_rw_type;
}
|
|
|
|
///
/// @brief Get the rw_type of one input of a node, looking only at the node itself (no sub graph handling).
/// @param node
/// @param index index of the input anchor
/// @return kScopeWriteable for HcomAllReduce/HcomAllGather/HcomReduceScatter/HcomReduce; kWriteable for an
///         input that shares its name with an output (ref input) and for input 0 of a framework op whose
///         original type is RefSwitch; kReadOnly otherwise; kInvalidRWType when the op desc is missing
///
InputRWType GetSingleNodeInputRWTypeByIndex(const Node &node, uint32_t index) {
  auto desc = node.GetOpDesc();
  if (desc == nullptr) {
    return InputRWType::kInvalidRWType;
  }

  const auto op_type = desc->GetType();
  const bool is_hcom_writer = (op_type == HCOMALLREDUCE) || (op_type == HCOMALLGATHER) ||
                              (op_type == HCOMREDUCESCATTER) || (op_type == HCOMREDUCE);
  if (is_hcom_writer) {
    return InputRWType::kScopeWriteable;
  }

  // A ref input carries the same name as one of the outputs.
  auto all_output_names = desc->GetAllOutputName();
  for (auto &name_and_idx : all_output_names) {
    if (desc->GetInputNameByIndex(index) == name_and_idx.first) {
      return InputRWType::kWriteable;
    }
  }

  // RefSwitch wrapped in a framework op: only its first input is written.
  std::string original_type;
  const bool is_ref_switch_input = (node.GetType() == FRAMEWORK_OP_TYPE) &&
                                   (AttrUtils::GetStr(desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, original_type)) &&
                                   (original_type == REFSWITCH) && (index == 0);
  return is_ref_switch_input ? InputRWType::kWriteable : InputRWType::kReadOnly;
}
|
|
|
|
///
/// @brief Get the rw_type of one input of a node, taking sub graphs into account.
///        WHILE is always kScopeWriteable. A node without sub graph is delegated to
///        GetSingleNodeInputRWTypeByIndex. For a node with sub graph the types recorded in node_rwtype_map_
///        for the matching sub graph Data nodes are merged with GetInputRwTypeInConflict.
/// @param node
/// @param index index of the input anchor
/// @return rw_type, or kInvalidRWType when the op desc is missing or a Data node has no record in the map
///
InputRWType GetInputRWTypeByIndex(const Node &node, uint32_t index) {
  auto op_desc = node.GetOpDesc();
  if (op_desc == nullptr) {
    return InputRWType::kInvalidRWType;
  }
  if (op_desc->GetType() == WHILE) {
    return InputRWType::kScopeWriteable;
  }
  if (op_desc->GetSubgraphInstanceNames().empty()) {
    // single node without sub graph
    return GetSingleNodeInputRWTypeByIndex(node, index);
  }

  // node with sub graph: collect the rw type of every matching input data node in the sub graphs
  std::set<int> anchor_rw_type_set;
  for (const auto &data_node : NodeUtils::GetSubgraphDataNodesByIndex(node, index)) {
    // Skip entries that cannot be inspected, the same way a missing anchor or op desc is skipped below.
    if (data_node == nullptr) {
      continue;
    }
    // Data only has 1 out data anchor. Here just take first out data anchor. And index 0 is valid.
    auto out_data_anchor = data_node->GetOutDataAnchor(0);
    if (out_data_anchor == nullptr) {
      continue;
    }
    auto data_op_desc = data_node->GetOpDesc();
    if (data_op_desc == nullptr) {
      continue;
    }
    // find rw type from map.
    auto iter = node_rwtype_map_.find(data_op_desc->GetName());
    if (iter == node_rwtype_map_.end()) {
      GELOGW("Can not find rw type of node %s from map.It could take some effect on following preprocess.",
             data_op_desc->GetName().c_str());
      return InputRWType::kInvalidRWType;
    }
    auto input_rw_type = iter->second.input_rw_type_map.find(out_data_anchor->GetIdx());
    if (input_rw_type == iter->second.input_rw_type_map.end()) {
      GELOGW("Can not find rw type of node %s from map.It could take some effect on following preprocess.",
             data_op_desc->GetName().c_str());
      return InputRWType::kInvalidRWType;
    }
    anchor_rw_type_set.emplace(static_cast<int>(input_rw_type->second));
  }
  return GetInputRwTypeInConflict(anchor_rw_type_set);
}
|
|
///
/// @brief Record in node_rwtype_map_ the rw_type of every Data and NetOutput node directly inside sub_graph.
///        Data:      merged input rw_type of all consumers of its output 0, stored under input index 0.
///        NetOutput: output rw_type of the producer of each input, stored under that input's index. When the
///                   producer output is writeable and the parent node is not a PartitionedCall, an Identity
///                   is inserted in front of the NetOutput input and the type is recorded as kSoftRead.
/// @param sub_graph
/// @return SUCCESS, the null-check failure status, or the failure status of the Identity insertion
///
Status MarkRWTypeForSubgraph(const ComputeGraphPtr &sub_graph) {
  GE_CHECK_NOTNULL(sub_graph);
  for (const auto &node : sub_graph->GetDirectNode()) {
    GE_CHECK_NOTNULL(node);
    GE_CHECK_NOTNULL(node->GetOpDesc());
    if (node->GetType() == DATA) {
      // calc all input_rw_type of peer output , as input_rw_type of DATA. Index 0 is valid.
      std::set<int> anchor_rw_type_set;
      auto anchor_2_node_vec = NodeUtils::GetOutDataNodesWithAnchorByIndex(*node, 0);
      // Bind by reference: each element holds two smart pointers, no need to copy them per iteration.
      for (const auto &anchor_2_node_pair : anchor_2_node_vec) {
        GE_CHECK_NOTNULL(anchor_2_node_pair.first);
        GE_CHECK_NOTNULL(anchor_2_node_pair.second);
        auto input_rw_type = GetInputRWTypeByIndex(*anchor_2_node_pair.second, anchor_2_node_pair.first->GetIdx());
        GELOGD("Input rw type of Node %s %dth input anchor is %s", anchor_2_node_pair.second->GetName().c_str(),
               anchor_2_node_pair.first->GetIdx(), InputRWTypeToSerialString(input_rw_type).c_str());
        anchor_rw_type_set.emplace(static_cast<int>(input_rw_type));
      }
      auto anchor_rw_type = GetInputRwTypeInConflict(anchor_rw_type_set);
      GELOGD("Input rw type of Node %s is %s", node->GetName().c_str(),
             InputRWTypeToSerialString(anchor_rw_type).c_str());
      NodeInputOutputRWType data_rw_type;
      data_rw_type.input_rw_type_map.emplace(0U, anchor_rw_type);
      node_rwtype_map_.emplace(node->GetName(), data_rw_type);
    }

    if (node->GetType() == NETOUTPUT) {
      // calc all output_rw_type of peer input , as output_rw_type of NETOUTPUT
      NodeInputOutputRWType output_rw_type;
      for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
        GE_CHECK_NOTNULL(in_data_anchor);
        auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor();
        GE_CHECK_NOTNULL(pre_out_anchor);
        auto pre_node = pre_out_anchor->GetOwnerNode();
        GE_CHECK_NOTNULL(pre_node);

        auto pre_output_rw_type = GetOutputRWTypeByIndex(*pre_node, pre_out_anchor->GetIdx());
        GELOGD("Output rw type of Node %s %dth output anchor is %s", pre_node->GetName().c_str(),
               pre_out_anchor->GetIdx(), OutputRWTypeToSerialString(pre_output_rw_type).c_str());
        if (pre_output_rw_type == OutputRWType::kWriteable) {
          // The parent node is only needed on this path, so it is only validated here.
          auto parent_node = sub_graph->GetParentNode();
          GE_CHECK_NOTNULL(parent_node);
          if (parent_node->GetType() != PARTITIONEDCALL) {
            // insert identity
            auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx());
            GE_CHECK_NOTNULL(identity_node);
            auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node);
            if (ret != SUCCESS) {
              GELOGE(ret, "Fail to insert identity");
              return ret;
            }
            GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(),
                   pre_node->GetName().c_str(), node->GetName().c_str());
            pre_output_rw_type = OutputRWType::kSoftRead;
          }
        }
        output_rw_type.output_rw_type_map.emplace(in_data_anchor->GetIdx(), pre_output_rw_type);
      }
      node_rwtype_map_.emplace(node->GetName(), output_rw_type);
    }
  }
  return SUCCESS;
}
|
|
///
|
|
/// @brief Reverse traversal all subgraph and mark rw_type for Data/Netoutput.
|
|
/// @param sub_graph_vecgs
|
|
///
|
|
Status MarkRWTypeForAllSubgraph(const vector<ComputeGraphPtr> &sub_graph_vec) {
|
|
for (auto iter = sub_graph_vec.rbegin(); iter != sub_graph_vec.rend(); ++iter) {
|
|
auto parent_node = (*iter)->GetParentNode();
|
|
if (parent_node == nullptr) {
|
|
GELOGD("Current sub graph has no parent node. Ignore it.");
|
|
continue;
|
|
}
|
|
if (parent_node->GetType() == WHILE) {
|
|
continue;
|
|
}
|
|
auto ret = MarkRWTypeForSubgraph(*iter);
|
|
if (ret != SUCCESS) {
|
|
return ret;
|
|
}
|
|
}
|
|
return SUCCESS;
|
|
}
|
|
|
|
///
|
|
/// @brief Check identity is near subgraph.
|
|
/// Eg. As output of Data node in subgraph
|
|
/// or as input of Netoutput of subgraph
|
|
/// or as input of one node with subgraph
|
|
/// or as output of one node with subgraph
|
|
/// @param node
|
|
/// @return is_near_subgraph
|
|
///
|
|
bool CheckIdentityIsNearSubgraph(const Node &node) {
|
|
for (const auto &in_node : node.GetInDataNodes()) {
|
|
auto in_node_opdesc = in_node->GetOpDesc();
|
|
if (in_node_opdesc == nullptr) {
|
|
continue;
|
|
}
|
|
// near entrance of subgraph
|
|
if (IsSubgraphInputNode(in_node)) {
|
|
return true;
|
|
}
|
|
// near subgraph
|
|
if (!in_node_opdesc->GetSubgraphInstanceNames().empty()) {
|
|
return true;
|
|
}
|
|
}
|
|
|
|
for (const auto &out_node : node.GetOutDataNodes()) {
|
|
auto out_node_opdesc = out_node->GetOpDesc();
|
|
if (out_node_opdesc == nullptr) {
|
|
continue;
|
|
}
|
|
// near output of subgraph
|
|
if (IsSubgraphOutputNode(out_node)) {
|
|
return true;
|
|
}
|
|
// near subgraph
|
|
if (!out_node_opdesc->GetSubgraphInstanceNames().empty()) {
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
enum ConflictResult { DO_NOTHING, WRONG_GRAPH, INSERT_IDENTITY };
|
|
vector<vector<ConflictResult>> output_2_input_rwtype = {{DO_NOTHING, WRONG_GRAPH, INSERT_IDENTITY},
|
|
{DO_NOTHING, WRONG_GRAPH, DO_NOTHING},
|
|
{DO_NOTHING, DO_NOTHING, INSERT_IDENTITY}};
|
|
ConflictResult GetConflictResultBetweenNode(const OutputRWType output_rw_type, const InputRWType input_rw_type) {
|
|
if (output_rw_type == OutputRWType::kInvalidRWType || input_rw_type == InputRWType::kInvalidRWType) {
|
|
return WRONG_GRAPH;
|
|
}
|
|
auto n = static_cast<int>(output_rw_type);
|
|
auto m = static_cast<int>(input_rw_type);
|
|
// no need to check index or container, because container and index is all defined.
|
|
return output_2_input_rwtype[n][m];
|
|
}
|
|
|
|
///
|
|
/// @brief Keep identity_node which near subgraph or has multi output
|
|
/// @param node
|
|
/// @return
|
|
///
|
|
Status RemoveNoUseIdentity(const NodePtr &node) {
|
|
if (node->GetInDataNodes().empty() || node->GetOutDataNodesSize() > 1) {
|
|
return SUCCESS;
|
|
}
|
|
if (node->GetOutDataNodesSize() == 1 && node->GetOutDataNodes().at(0)->GetType() == STREAMMERGE) {
|
|
return SUCCESS;
|
|
}
|
|
if (CheckIdentityIsNearSubgraph(*node)) {
|
|
return SUCCESS;
|
|
}
|
|
GE_CHECK_NOTNULL(node->GetInDataAnchor(kIdentityAnchorIndex));
|
|
auto pre_out_anchor = node->GetInDataAnchor(kIdentityAnchorIndex)->GetPeerOutAnchor();
|
|
GE_CHECK_NOTNULL(pre_out_anchor);
|
|
auto pre_node = pre_out_anchor->GetOwnerNode();
|
|
auto pre_output_rw_type = GetOutputRWTypeByIndex(*pre_node, pre_out_anchor->GetIdx());
|
|
|
|
auto anchor_2_outnode_vec = NodeUtils::GetOutDataNodesWithAnchorByIndex(*node, kIdentityAnchorIndex);
|
|
ConflictResult conflict_result = WRONG_GRAPH;
|
|
if (!anchor_2_outnode_vec.empty()) {
|
|
auto anchor_2_outnode = anchor_2_outnode_vec.at(0);
|
|
auto peer_input_rw_type = GetInputRWTypeByIndex(*anchor_2_outnode.second, anchor_2_outnode.first->GetIdx());
|
|
|
|
GELOGD("Pre Node %s %dth output rw type is %s, peer node %s %dth input rw type is %s.", pre_node->GetName().c_str(),
|
|
pre_out_anchor->GetIdx(), OutputRWTypeToSerialString(pre_output_rw_type).c_str(),
|
|
anchor_2_outnode.second->GetName().c_str(), anchor_2_outnode.first->GetIdx(),
|
|
InputRWTypeToSerialString(peer_input_rw_type).c_str());
|
|
conflict_result = GetConflictResultBetweenNode(pre_output_rw_type, peer_input_rw_type);
|
|
} else {
|
|
// identity node has no out data node, it can be removed
|
|
conflict_result = DO_NOTHING;
|
|
}
|
|
if (conflict_result != DO_NOTHING) {
|
|
return SUCCESS;
|
|
}
|
|
|
|
GELOGI("No need insert Identity. Node %s need to remove.", node->GetName().c_str());
|
|
auto ret = GraphUtils::IsolateNode(node, {0});
|
|
if (ret != SUCCESS) {
|
|
GELOGE(ret, "Fail to isolate node %s.", node->GetName().c_str());
|
|
return ret;
|
|
}
|
|
ret = GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node);
|
|
if (ret != SUCCESS) {
|
|
GELOGE(ret, "Fail to isolate node %s.", node->GetName().c_str());
|
|
return ret;
|
|
}
|
|
GELOGI("Pre node is %s and %dth output rw type is %s. Isolate and remove Identity node %s.",
|
|
pre_node->GetName().c_str(), pre_out_anchor->GetIdx(), OutputRWTypeToSerialString(pre_output_rw_type).c_str(),
|
|
node->GetName().c_str());
|
|
return SUCCESS;
|
|
}
|
|
|
|
///
/// @brief Detach one consumer from a shared identity node and reconnect it to the identity's producer.
///        A consumer whose input is Writeable or ScopeWriteable gets a new private Identity (which also
///        receives a copy of the consumer's in-control edges); any other consumer is linked to the producer
///        directly and the old identity's control edges are copied to the consumer and the producer.
/// @param out_data_anchor output anchor of the identity being split
/// @param peer_in_data_anchor input anchor of the consumer to detach
/// @param pre_out_data_anchor output anchor that feeds the identity
/// @param pre_node owner of pre_out_data_anchor
/// @return SUCCESS, or a failure status when unlinking, creating the identity or linking fails
///
Status SplitIdentityAlongAnchor(const OutDataAnchorPtr &out_data_anchor, const InDataAnchorPtr &peer_in_data_anchor,
                                const OutDataAnchorPtr &pre_out_data_anchor, NodePtr &pre_node) {
  // 1.check peer in node RW type.
  GE_CHECK_NOTNULL(peer_in_data_anchor);
  auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode();
  GE_CHECK_NOTNULL(peer_in_data_node);
  const auto input_rw_type = GetInputRWTypeByIndex(*peer_in_data_node, peer_in_data_anchor->GetIdx());

  auto unlink_ret = out_data_anchor->Unlink(peer_in_data_anchor);
  auto old_identity = out_data_anchor->GetOwnerNode();
  if (unlink_ret != SUCCESS) {
    GELOGE(unlink_ret, "Failed to unlink from %s %dth out to %s.", old_identity->GetName().c_str(),
           out_data_anchor->GetIdx(), peer_in_data_anchor->GetOwnerNode()->GetName().c_str());
    return unlink_ret;
  }

  const bool consumer_writes_input =
    (input_rw_type == InputRWType::kScopeWriteable) || (input_rw_type == InputRWType::kWriteable);
  if (!consumer_writes_input) {
    // copy control edge to pre and peer node
    if (GraphUtils::CopyInCtrlEdges(old_identity, peer_in_data_node) != SUCCESS ||
        GraphUtils::CopyOutCtrlEdges(old_identity, pre_node) != SUCCESS) {
      GELOGW("Fail to copy control edge from node %s.", old_identity->GetName().c_str());
      return FAILED;
    }
    // link identity pre node to next node directly
    if (GraphUtils::AddEdge(pre_out_data_anchor, peer_in_data_anchor) != SUCCESS) {
      GELOGW("Fail to link data edge from node %s to %s.", pre_out_data_anchor->GetOwnerNode()->GetName().c_str(),
             peer_in_data_anchor->GetOwnerNode()->GetName().c_str());
      return FAILED;
    }
    GELOGI("Node %s input rw type is %s, link data edge from Identity input node %s to out node %s directly.",
           peer_in_data_node->GetName().c_str(), InputRWTypeToSerialString(input_rw_type).c_str(),
           pre_node->GetName().c_str(), peer_in_data_node->GetName().c_str());
    return SUCCESS;
  }

  // The consumer writes its input: give it a private identity between producer and consumer.
  auto new_identity = CreateIdentityAfterSrcNode(*pre_node, pre_out_data_anchor->GetIdx());
  GE_CHECK_NOTNULL(new_identity);
  if (GraphUtils::AddEdge(pre_out_data_anchor, new_identity->GetInDataAnchor(kIdentityAnchorIndex)) != SUCCESS ||
      GraphUtils::AddEdge(new_identity->GetOutDataAnchor(kIdentityAnchorIndex), peer_in_data_anchor) != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Failed to insert Identity between node %s and %s",
           pre_out_data_anchor->GetOwnerNode()->GetName().c_str(),
           peer_in_data_anchor->GetOwnerNode()->GetName().c_str());
    return INTERNAL_ERROR;
  }

  // 2. copy in-control-edge from dst to Identity
  if (GraphUtils::CopyInCtrlEdges(peer_in_data_node, new_identity) != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Failed to copy in_control edges from node %s to %s", peer_in_data_node->GetName().c_str(),
           new_identity->GetName().c_str());
    return INTERNAL_ERROR;
  }
  GELOGI("Node %s intput rw type is %s. Insert Identity between %s and %s.", peer_in_data_node->GetName().c_str(),
         InputRWTypeToSerialString(input_rw_type).c_str(), pre_out_data_anchor->GetOwnerNode()->GetName().c_str(),
         peer_in_data_anchor->GetOwnerNode()->GetName().c_str());
  return SUCCESS;
}
|
|
|
|
///
/// @brief Split an identity node that feeds more than one consumer: every consumer is detached with
///        SplitIdentityAlongAnchor, and the identity is isolated and removed once it has no data output
///        node left. An identity with at most one consumer is left untouched.
/// @param node identity node
/// @return SUCCESS, or a failure status when a required anchor/node is missing or a graph edit fails
///
Status SplitIdentity(const NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  auto identity_out_anchor = node->GetOutDataAnchor(kIdentityAnchorIndex);
  GE_CHECK_NOTNULL(identity_out_anchor);
  if (identity_out_anchor->GetPeerInDataNodesSize() <= 1) {
    return SUCCESS;
  }

  // get pre node and next node of identity
  GE_CHECK_NOTNULL(node->GetInDataAnchor(kIdentityAnchorIndex));
  auto producer_out_anchor = node->GetInDataAnchor(kIdentityAnchorIndex)->GetPeerOutAnchor();
  GE_CHECK_NOTNULL(producer_out_anchor);
  auto producer = producer_out_anchor->GetOwnerNode();
  GE_CHECK_NOTNULL(producer);
  for (const auto &consumer_in_anchor : identity_out_anchor->GetPeerInDataAnchors()) {
    const Status split_ret =
      SplitIdentityAlongAnchor(identity_out_anchor, consumer_in_anchor, producer_out_anchor, producer);
    if (split_ret != SUCCESS) {
      GELOGE(split_ret, "Split identity node along anchor failed.");
      return split_ret;
    }
  }

  // 2.isolate Identity node with no data output
  if (node->GetOutDataNodesSize() != 0) {
    return SUCCESS;
  }
  if (GraphUtils::IsolateNode(node, {}) != SUCCESS) {
    GELOGE(FAILED, "IsolateAndDelete identity node %s.", node->GetName().c_str());
    return FAILED;
  }
  if (GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node) != SUCCESS) {
    GELOGE(FAILED, "IsolateAndDelete identity node %s.", node->GetName().c_str());
    return FAILED;
  }
  GELOGI("IsolateAndDelete identity node %s.", node->GetName().c_str());
  return SUCCESS;
}
|
|
|
|
///
/// @brief For every data edge leaving node, compare the producer output rw type with the consumer input rw
///        type and insert an Identity on the edge when the conflict table asks for it. Edges rated
///        DO_NOTHING or WRONG_GRAPH are left unchanged.
/// @param node producer node
/// @return SUCCESS, FAILED when the Identity cannot be created, INTERNAL_ERROR when it cannot be linked, or
///         the null-check failure status
///
Status InsertIdentityAsNeeded(const NodePtr &node) {
  GE_CHECK_NOTNULL(node);
  auto op_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(op_desc);
  if (node->GetOutDataNodesSize() == 0) {
    return SUCCESS;
  }
  for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
    GE_CHECK_NOTNULL(out_data_anchor);
    auto output_rw_type = GetOutputRWTypeByIndex(*node, out_data_anchor->GetIdx());
    for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
      GE_CHECK_NOTNULL(peer_in_data_anchor);
      auto peer_in_node = peer_in_data_anchor->GetOwnerNode();
      GE_CHECK_NOTNULL(peer_in_node);
      auto input_rw_type = GetInputRWTypeByIndex(*peer_in_node, peer_in_data_anchor->GetIdx());
      GELOGD("Node %s output rw type is %s, Node %s input rw type is %s", node->GetName().c_str(),
             OutputRWTypeToSerialString(output_rw_type).c_str(), peer_in_node->GetName().c_str(),
             InputRWTypeToSerialString(input_rw_type).c_str());
      const auto conflict_result = GetConflictResultBetweenNode(output_rw_type, input_rw_type);
      if (conflict_result != INSERT_IDENTITY) {
        // DO_NOTHING and WRONG_GRAPH are both left alone here.
        GELOGD("No need insert Identity.");
        continue;
      }

      auto identity_node = CreateIdentityAfterSrcNode(*node, out_data_anchor->GetIdx());
      if (identity_node == nullptr) {
        GELOGE(FAILED, "Create identity node failed.");
        return FAILED;
      }
      auto ret = GraphUtils::InsertNodeBetweenDataAnchors(out_data_anchor, peer_in_data_anchor, identity_node);
      if (ret != GRAPH_SUCCESS) {
        // The node being inserted is an Identity; the message names it accordingly.
        GELOGE(INTERNAL_ERROR, "Failed to insert Identity between node %s and %s", node->GetName().c_str(),
               peer_in_node->GetName().c_str());
        return INTERNAL_ERROR;
      }
      GELOGI("Insert Identity between %s and %s to handle memory conflict.", node->GetName().c_str(),
             peer_in_node->GetName().c_str());
    }
  }
  return SUCCESS;
}
|
|
///
/// @brief For every HcomAllReduce directly in compute_graph, make sure no two inputs are fed by the same
///        output anchor: the first input from an anchor keeps its edge, every further input from that
///        anchor gets an Identity inserted on its edge.
/// @param compute_graph
/// @return SUCCESS, or a failure status when an anchor is unconnected or the Identity cannot be inserted
///
Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) {
  for (const auto &node : compute_graph->GetDirectNode()) {
    if (node->GetType() != HCOMALLREDUCE) {
      continue;
    }
    std::set<OutDataAnchorPtr> seen_producer_anchors;
    for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
      auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor();
      GE_CHECK_NOTNULL(pre_out_anchor);
      // insert() reports whether the anchor was new; a new anchor needs no identity.
      const bool first_use = seen_producer_anchors.insert(pre_out_anchor).second;
      if (first_use) {
        continue;
      }
      // need insert identity
      auto pre_node = pre_out_anchor->GetOwnerNode();
      auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx());
      GE_CHECK_NOTNULL(identity_node);
      auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node);
      GE_CHK_STATUS_RET(ret, "Fail to insert identity.");
      GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(),
             pre_node->GetName().c_str(), node->GetName().c_str());
    }
  }
  return SUCCESS;
}
|
|
} // namespace
|
|
|
|
namespace ge {
|
|
///
/// @brief Scan the graph for an edge whose producer output and consumer input rw types make the graph wrong.
/// @param compute_graph graph to inspect (all nodes, including sub graph nodes, are walked)
/// @param has_conflict output; always written: true when a WRONG_GRAPH edge was found, false otherwise
///        (including the case that the graph has no sub graph and nothing is inspected)
/// @return SUCCESS, or a failure status when marking the sub graphs fails or a node/anchor is null
///
Status GraphOptimize::CheckRWConflict(ComputeGraphPtr &compute_graph, bool &has_conflict) {
  // Give the output a defined value on every return path, including the early ones below.
  has_conflict = false;
  GE_CHECK_NOTNULL(compute_graph);
  node_rwtype_map_.clear();
  auto sub_graph_vec = compute_graph->GetAllSubgraphs();
  if (sub_graph_vec.empty()) {
    GELOGD("No sub graph here. Ignore memory conflict handle.");
    return SUCCESS;
  }
  // 1.loop all subgraph, mark rw type from inside to outside
  Status ret = MarkRWTypeForAllSubgraph(sub_graph_vec);
  if (ret != SUCCESS) {
    GELOGE(ret, "Fail to mark rw type for subgraph.");
    return ret;
  }
  for (const auto &node : compute_graph->GetAllNodes()) {
    GE_CHECK_NOTNULL(node);
    auto op_desc = node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);
    // NOTE(review): the three `return SUCCESS` statements below (node without data output, While node,
    // While consumer) end the whole scan, not just the current node. Kept as-is; confirm whether `continue`
    // was intended, at least for the node without data output.
    if (node->GetOutDataNodesSize() == 0) {
      return SUCCESS;
    }
    if (node->GetType() == WHILE) {
      return SUCCESS;
    }
    for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) {
      GE_CHECK_NOTNULL(out_data_anchor);
      auto output_rw_type = GetOutputRWTypeByIndex(*node, out_data_anchor->GetIdx());
      for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
        GE_CHECK_NOTNULL(peer_in_data_anchor);
        auto peer_in_node = peer_in_data_anchor->GetOwnerNode();
        GE_CHECK_NOTNULL(peer_in_node);
        if (peer_in_node->GetType() == WHILE) {
          return SUCCESS;
        }
        auto input_rw_type = GetInputRWTypeByIndex(*peer_in_node, peer_in_data_anchor->GetIdx());
        auto conflict_result = GetConflictResultBetweenNode(output_rw_type, input_rw_type);
        switch (conflict_result) {
          case DO_NOTHING:
            GELOGD("No rw conflict.");
            continue;
          case WRONG_GRAPH:
            has_conflict = true;
            GELOGI("Node %s output rw type is %s, next node %s input_rw_type is %s.It is wrong graph.",
                   node->GetName().c_str(), OutputRWTypeToSerialString(output_rw_type).c_str(),
                   peer_in_node->GetName().c_str(), InputRWTypeToSerialString(input_rw_type).c_str());
            return SUCCESS;
          case INSERT_IDENTITY:
            GELOGD("There is rw conflict. It will handle later.");
            continue;
        }
      }
    }
  }
  return SUCCESS;
}
|
|
///
/// @brief Resolve memory read/write conflicts in the graph.
///        Without sub graphs only duplicated HcomAllReduce inputs are handled. Otherwise the rw types of all
///        sub graphs are marked first; then, for every node outside While sub graphs that is not a sub graph
///        Data/NetOutput, Identity/ReadVariableOp nodes are split and removed when useless, and an Identity
///        is inserted on every outgoing edge that needs one.
/// @param compute_graph graph to rewrite; dumped before and after the pass
/// @return SUCCESS, or the status of the first step that fails
///
Status GraphOptimize::HandleMemoryRWConflict(ComputeGraphPtr &compute_graph) {
  GE_DUMP(compute_graph, "BeforeHandleMemConflict");
  node_rwtype_map_.clear();
  auto sub_graph_vec = compute_graph->GetAllSubgraphs();
  if (sub_graph_vec.empty()) {
    // only root graph, to handle allreduce several input from one output anchor
    return HandleAllreduceDuplicateInput(compute_graph);
  }

  // 1.loop all subgraph, mark rw type from inside to outside
  const Status mark_ret = MarkRWTypeForAllSubgraph(sub_graph_vec);
  if (mark_ret != SUCCESS) {
    GELOGE(mark_ret, "Fail to mark rw type for subgraph.");
    return mark_ret;
  }

  // 2.loop all node, including node in subgraph and handle memory rw conflict
  for (auto &node : compute_graph->GetAllNodes()) {
    // ignore while subgraph node
    const auto parent_node = node->GetOwnerComputeGraph()->GetParentNode();
    const bool in_while_subgraph = (parent_node != nullptr) && (kWhileOpTypes.count(parent_node->GetType()) > 0);
    if (in_while_subgraph) {
      continue;
    }
    // ignore data / netoutput of subgraph
    const bool is_data_or_netoutput = (node->GetType() == DATA) || (node->GetType() == NETOUTPUT);
    if (is_data_or_netoutput && AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) {
      continue;
    }

    if (node->GetType() == IDENTITY || node->GetType() == READVARIABLEOP) {
      // split identity
      const Status split_ret = SplitIdentity(node);
      if (split_ret != SUCCESS) {
        GELOGE(split_ret, "Fail to split identity node %s.", node->GetName().c_str());
        return split_ret;
      }
      // remove no use identity
      const Status remove_ret = RemoveNoUseIdentity(node);
      if (remove_ret != SUCCESS) {
        GELOGE(remove_ret, "Fail to remove useless identity node %s.", node->GetName().c_str());
        return remove_ret;
      }
    }

    // insert Identity
    const Status insert_ret = InsertIdentityAsNeeded(node);
    if (insert_ret != SUCCESS) {
      GELOGE(insert_ret, "Fail to insert Identity node.");
      return insert_ret;
    }
  }
  GE_DUMP(compute_graph, "AfterHandleMemConflict");
  return SUCCESS;
}
|
|
} // namespace ge
|