|
|
|
@ -24,6 +24,7 @@
|
|
|
|
|
#include <set>
|
|
|
|
|
#include <stack>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <tuple>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "base/core_ops.h"
|
|
|
|
|
#include "ir/tensor.h"
|
|
|
|
@ -304,27 +305,52 @@ void AtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNo
|
|
|
|
|
user_cnode->set_input(index, depend_cnode);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AtomicCleanInsertter::AddControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &pre_node,
|
|
|
|
|
const AnfNodePtr &post_node, const FuncGraphManagerPtr &mng) {
|
|
|
|
|
// Collect use dependencies firstly.
|
|
|
|
|
auto post_users = mng->node_users()[post_node];
|
|
|
|
|
|
|
|
|
|
AnfNodePtr AtomicCleanInsertter::AddControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node,
|
|
|
|
|
const AnfNodePtr &behind_node, const AnfNodePtr &patron_node) {
|
|
|
|
|
// Create control depend, first input is composite op, second is user
|
|
|
|
|
AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), pre_node, post_node};
|
|
|
|
|
AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), prior_node, behind_node};
|
|
|
|
|
auto control_depend_cnode = main_graph->NewCNode(cd_inputs);
|
|
|
|
|
main_graph->AddNode(control_depend_cnode);
|
|
|
|
|
|
|
|
|
|
// Create depend node to hold new control depend node.
|
|
|
|
|
AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), post_node, control_depend_cnode};
|
|
|
|
|
AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), patron_node, control_depend_cnode};
|
|
|
|
|
auto depend_cnode = main_graph->NewCNode(d_inputs);
|
|
|
|
|
depend_cnode->set_abstract(post_node->abstract());
|
|
|
|
|
depend_cnode->set_abstract(patron_node->abstract());
|
|
|
|
|
main_graph->AddNode(depend_cnode);
|
|
|
|
|
|
|
|
|
|
for (const auto &[user_node, index] : post_users) {
|
|
|
|
|
auto user_cnode = user_node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(user_cnode);
|
|
|
|
|
user_cnode->set_input(index, depend_cnode);
|
|
|
|
|
return depend_cnode;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::tuple<AnfNodePtr, AnfNodePtr, int> AtomicCleanInsertter::FindPatronNode(const KernelGraphPtr &main_graph) {
|
|
|
|
|
auto mng = main_graph->manager();
|
|
|
|
|
if (mng == nullptr) {
|
|
|
|
|
mng = Manage(main_graph, true);
|
|
|
|
|
main_graph->set_manager(mng);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AnfNodePtr patron_node;
|
|
|
|
|
|
|
|
|
|
auto return_cnode = main_graph->get_return()->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(return_cnode);
|
|
|
|
|
auto output_node = return_cnode->input(kFirstDataInputIndex);
|
|
|
|
|
if (IsPrimitiveCNode(output_node, prim::kPrimMakeTuple)) {
|
|
|
|
|
auto output_cnode = output_node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(output_cnode);
|
|
|
|
|
patron_node = output_cnode->input(kFirstDataInputIndex);
|
|
|
|
|
} else {
|
|
|
|
|
patron_node = output_node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto &user_nodes = mng->node_users()[patron_node];
|
|
|
|
|
auto user = user_nodes.begin();
|
|
|
|
|
return std::make_tuple(patron_node, user->first, user->second);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AtomicCleanInsertter::PostprocessForLastPatron(const AnfNodePtr &patron_node, const AnfNodePtr &patron_user,
|
|
|
|
|
int index) {
|
|
|
|
|
auto patron_user_cnode = patron_user->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(patron_user_cnode);
|
|
|
|
|
patron_user_cnode->set_input(index, patron_node);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type) {
|
|
|
|
@ -380,14 +406,14 @@ CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphP
|
|
|
|
|
kernel::Processor::CUDA);
|
|
|
|
|
auto graph_attr = ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "atomic_clean");
|
|
|
|
|
new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr));
|
|
|
|
|
// mng->AddFuncGraph(new_sub_graph);
|
|
|
|
|
|
|
|
|
|
return broadcast_to_composite_node;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
|
|
|
|
|
const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng) {
|
|
|
|
|
// 1. find users, change getitem index if needed.
|
|
|
|
|
std::vector<std::pair<AnfNodePtr, int> > AtomicCleanInsertter::FindOriginCNodeUsers(const KernelGraphPtr &main_graph,
|
|
|
|
|
const AnfNodePtr &composite_node,
|
|
|
|
|
const FuncGraphManagerPtr &mng,
|
|
|
|
|
bool correct_index) {
|
|
|
|
|
std::vector<std::pair<AnfNodePtr, int> > reduce_user_nodes;
|
|
|
|
|
if (real_output_num_ <= 1) {
|
|
|
|
|
auto users = mng->node_users()[composite_node];
|
|
|
|
@ -409,7 +435,7 @@ void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_gra
|
|
|
|
|
auto item_idx = GetValue<int64_t>(value_node->value());
|
|
|
|
|
if (item_idx == static_cast<int64_t>(reduce_real_output_index_)) {
|
|
|
|
|
getitem_user_nodes.push_back(node_index);
|
|
|
|
|
} else {
|
|
|
|
|
} else if (correct_index) {
|
|
|
|
|
if (real_output_num_ > 2) {
|
|
|
|
|
// Recorrect other getitem index.
|
|
|
|
|
int64_t new_item_idx = CalNewIndex(item_idx, reduce_real_output_index_);
|
|
|
|
@ -431,7 +457,6 @@ void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_gra
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto &pair : getitem_user_nodes) {
|
|
|
|
|
// dirctory to find real user.
|
|
|
|
|
auto real_users = mng->node_users()[pair.first];
|
|
|
|
@ -439,12 +464,16 @@ void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_gra
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return reduce_user_nodes;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
|
|
|
|
|
const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng) {
|
|
|
|
|
// 1. find users, change getitem index if needed.
|
|
|
|
|
std::vector<std::pair<AnfNodePtr, int> > reduce_user_nodes =
|
|
|
|
|
FindOriginCNodeUsers(main_graph, composite_node, mng, true);
|
|
|
|
|
for (const auto &[user_node, index] : reduce_user_nodes) {
|
|
|
|
|
// 2. set ac output as user's input.
|
|
|
|
|
auto user_cnode = user_node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(user_cnode);
|
|
|
|
|
user_cnode->set_input(index, broadcast_to_node);
|
|
|
|
|
// mng->SetEdge(user_node, index, broadcast_to_node);
|
|
|
|
|
// 3. Make sure modified composite node running first.
|
|
|
|
|
// * To not change the origin node's dependency relation, add ControlDepend and Depend node.
|
|
|
|
|
// * For Return node and output node, ControlDepend node will change the order of these two node, which will may
|
|
|
|
@ -452,7 +481,10 @@ void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_gra
|
|
|
|
|
if (IsPrimitiveCNode(user_node, prim::kPrimReturn) || user_node == main_graph->output()) {
|
|
|
|
|
AddDepend(main_graph, broadcast_to_node, composite_node, user_node, index);
|
|
|
|
|
} else {
|
|
|
|
|
AddControlDepend(main_graph, composite_node, user_node, mng);
|
|
|
|
|
auto user_cnode = user_node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(user_cnode);
|
|
|
|
|
user_cnode->set_input(index, broadcast_to_node);
|
|
|
|
|
to_process_order_.emplace_back(composite_node, user_node);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -473,6 +505,26 @@ void AtomicCleanInsertter::InsertAtomicClean(const KernelGraphPtr &main_graph, c
|
|
|
|
|
|
|
|
|
|
// Replace origin ReduceSum's user with atomic clean output, and add control depend from composite op to user.
|
|
|
|
|
ProcessOriginCNodeUser(main_graph, origin_composite_node, broadcast_to_node, mng);
|
|
|
|
|
MS_LOG(INFO) << "Target node: " << origin_composite_node->fullname_with_scope()
|
|
|
|
|
<< ", clean node: " << broadcast_to_node->fullname_with_scope();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool AtomicCleanInsertter::IsExistStructuralObstacle(const KernelGraphPtr &main_graph, const AnfNodePtr &node,
|
|
|
|
|
const FuncGraphManagerPtr &mng) {
|
|
|
|
|
auto reduce_users = FindOriginCNodeUsers(main_graph, node, mng, false);
|
|
|
|
|
// If reduce user is MakeTuple and not last node, there is no cheap method to set right running order between reduce
|
|
|
|
|
// node and user node. If reduce is Depend or ControlDepend node, the origin node may be wrong!
|
|
|
|
|
return std::all_of(reduce_users.cbegin(), reduce_users.cend(),
|
|
|
|
|
[&main_graph](const std::pair<AnfNodePtr, int> &user_info) -> bool {
|
|
|
|
|
auto &user = user_info.first;
|
|
|
|
|
if ((IsPrimitiveCNode(user, prim::kPrimMakeTuple) || IsPrimitiveCNode(user, prim::kPrimDepend) ||
|
|
|
|
|
IsPrimitiveCNode(user, prim::kPrimControlDepend)) &&
|
|
|
|
|
!(IsPrimitiveCNode(user, prim::kPrimReturn) || user == main_graph->output())) {
|
|
|
|
|
return false;
|
|
|
|
|
} else {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
});
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
|
|
|
|
@ -487,7 +539,8 @@ bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
|
|
|
|
|
bool changed = false;
|
|
|
|
|
auto topo_nodes = TopoSort(kernel_graph->get_return());
|
|
|
|
|
for (const auto &node : topo_nodes) {
|
|
|
|
|
if (!AnfAlgo::IsGraphKernel(node) || !CanActivateAtomicAdd(node)) {
|
|
|
|
|
if (!AnfAlgo::IsGraphKernel(node) || !CanActivateAtomicAdd(node) ||
|
|
|
|
|
!IsExistStructuralObstacle(kernel_graph, node, mng)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
InsertAtomicClean(kernel_graph, node, mng);
|
|
|
|
@ -495,6 +548,14 @@ bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (changed) {
|
|
|
|
|
if (!to_process_order_.empty()) {
|
|
|
|
|
auto [patron_node, patron_user, user_index] = FindPatronNode(kernel_graph);
|
|
|
|
|
for (const auto &[prior, behind] : to_process_order_) {
|
|
|
|
|
patron_node = AddControlDepend(kernel_graph, prior, behind, patron_node);
|
|
|
|
|
}
|
|
|
|
|
PostprocessForLastPatron(patron_node, patron_user, user_index);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
mng->RemoveRoots();
|
|
|
|
|
mng->KeepRoots({func_graph});
|
|
|
|
|
}
|
|
|
|
|