|
|
|
@ -607,6 +607,30 @@ Status InsertIdentityAsNeeded(const NodePtr &node) {
|
|
|
|
|
}
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
// Scans the direct nodes of `compute_graph` and, for every HCOMALLREDUCE node, makes sure
// that no two of its data inputs are fed by the same output anchor. The first input that
// uses a given output anchor is left untouched; each later input fed by that same anchor
// gets its own identity node inserted between the anchor and the allreduce input.
//
// @param compute_graph  graph whose direct nodes are inspected and possibly modified.
// @return SUCCESS when the whole graph has been processed. The GE_CHECK_NOTNULL /
//         GE_CHK_STATUS_RET macros are expected to return an error status early when a
//         pointer is null or the insertion fails (macro definitions are not visible here).
Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) {
  for (const auto &node : compute_graph->GetDirectNode()) {
    if (node->GetType() != HCOMALLREDUCE) {
      continue;
    }
    // Output anchors already consumed by an earlier input of this allreduce node.
    std::set<OutDataAnchorPtr> pre_out_anchor_set;
    for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
      auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor();
      GE_CHECK_NOTNULL(pre_out_anchor);
      // insert() reports whether the anchor was new, so one lookup both tests and records it.
      if (pre_out_anchor_set.insert(pre_out_anchor).second) {
        continue;
      }
      // This output anchor already feeds another input of the same node: give this input
      // its own identity node.
      auto pre_node = pre_out_anchor->GetOwnerNode();
      // Guard the dereference below; the owner node is not guaranteed to be set.
      GE_CHECK_NOTNULL(pre_node);
      auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx());
      GE_CHECK_NOTNULL(identity_node);
      auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node);
      GE_CHK_STATUS_RET(ret, "Fail to insert identity.");
      GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(),
             pre_node->GetName().c_str(), node->GetName().c_str());
    }
  }
  return SUCCESS;
}
|
|
|
|
|
} // namespace
|
|
|
|
|
|
|
|
|
|
namespace ge {
|
|
|
|
@ -665,13 +689,14 @@ Status GraphOptimize::CheckRWConflict(ComputeGraphPtr &compute_graph, bool &has_
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
}
|
|
|
|
|
Status GraphOptimize::HandleMemoryRWConflict(ComputeGraphPtr &compute_graph) {
|
|
|
|
|
GE_DUMP(compute_graph, "BeforeHandleMemConflict");
|
|
|
|
|
node_rwtype_map_.clear();
|
|
|
|
|
auto sub_graph_vec = compute_graph->GetAllSubgraphs();
|
|
|
|
|
if (sub_graph_vec.empty()) {
|
|
|
|
|
GELOGD("No sub graph here. Ignore memory conflict handle.");
|
|
|
|
|
return SUCCESS;
|
|
|
|
|
// only the root graph exists: handle allreduce nodes that take several inputs from one output anchor
|
|
|
|
|
return HandleAllreduceDuplicateInput(compute_graph);
|
|
|
|
|
}
|
|
|
|
|
GE_DUMP(compute_graph, "BeforeHandleMemConflict");
|
|
|
|
|
|
|
|
|
|
// 1.loop all subgraph, mark rw type from inside to outside
|
|
|
|
|
Status ret = MarkRWTypeForAllSubgraph(sub_graph_vec);
|
|
|
|
|
if (ret != SUCCESS) {
|
|
|
|
|