change atomic add struct and add new condition for controldepend

pull/9594/head
tronzhang 4 years ago
parent 2ced65ece3
commit 2b88731417

@ -24,6 +24,7 @@
#include <set>
#include <stack>
#include <string>
#include <tuple>
#include <vector>
#include "base/core_ops.h"
#include "ir/tensor.h"
@ -304,27 +305,52 @@ void AtomicCleanInsertter::AddDepend(const FuncGraphPtr &main_graph, const AnfNo
user_cnode->set_input(index, depend_cnode);
}
void AtomicCleanInsertter::AddControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &pre_node,
const AnfNodePtr &post_node, const FuncGraphManagerPtr &mng) {
// Collect use dependencies firstly.
auto post_users = mng->node_users()[post_node];
AnfNodePtr AtomicCleanInsertter::AddControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node,
const AnfNodePtr &behind_node, const AnfNodePtr &patron_node) {
// Create control depend, first input is composite op, second is user
AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), pre_node, post_node};
AnfNodePtrList cd_inputs = {NewValueNode(prim::kPrimControlDepend), prior_node, behind_node};
auto control_depend_cnode = main_graph->NewCNode(cd_inputs);
main_graph->AddNode(control_depend_cnode);
// Create depend node to hold new control depend node.
AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), post_node, control_depend_cnode};
AnfNodePtrList d_inputs = {NewValueNode(prim::kPrimDepend), patron_node, control_depend_cnode};
auto depend_cnode = main_graph->NewCNode(d_inputs);
depend_cnode->set_abstract(post_node->abstract());
depend_cnode->set_abstract(patron_node->abstract());
main_graph->AddNode(depend_cnode);
for (const auto &[user_node, index] : post_users) {
auto user_cnode = user_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(user_cnode);
user_cnode->set_input(index, depend_cnode);
return depend_cnode;
}
std::tuple<AnfNodePtr, AnfNodePtr, int> AtomicCleanInsertter::FindPatronNode(const KernelGraphPtr &main_graph) {
auto mng = main_graph->manager();
if (mng == nullptr) {
mng = Manage(main_graph, true);
main_graph->set_manager(mng);
}
AnfNodePtr patron_node;
auto return_cnode = main_graph->get_return()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(return_cnode);
auto output_node = return_cnode->input(kFirstDataInputIndex);
if (IsPrimitiveCNode(output_node, prim::kPrimMakeTuple)) {
auto output_cnode = output_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output_cnode);
patron_node = output_cnode->input(kFirstDataInputIndex);
} else {
patron_node = output_node;
}
auto &user_nodes = mng->node_users()[patron_node];
auto user = user_nodes.begin();
return std::make_tuple(patron_node, user->first, user->second);
}
void AtomicCleanInsertter::PostprocessForLastPatron(const AnfNodePtr &patron_node, const AnfNodePtr &patron_user,
int index) {
auto patron_user_cnode = patron_user->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(patron_user_cnode);
patron_user_cnode->set_input(index, patron_node);
}
CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type) {
@ -380,14 +406,14 @@ CNodePtr AtomicCleanInsertter::CreateAtomicCleanCompositeNode(const KernelGraphP
kernel::Processor::CUDA);
auto graph_attr = ExtractGraphKernelName(TopoSort(new_sub_graph->get_return()), "", "atomic_clean");
new_sub_graph->set_attr(FUNC_GRAPH_ATTR_GRAPH_KERNEL, MakeValue(graph_attr));
// mng->AddFuncGraph(new_sub_graph);
return broadcast_to_composite_node;
}
void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng) {
// 1. find users, change getitem index if needed.
std::vector<std::pair<AnfNodePtr, int> > AtomicCleanInsertter::FindOriginCNodeUsers(const KernelGraphPtr &main_graph,
const AnfNodePtr &composite_node,
const FuncGraphManagerPtr &mng,
bool correct_index) {
std::vector<std::pair<AnfNodePtr, int> > reduce_user_nodes;
if (real_output_num_ <= 1) {
auto users = mng->node_users()[composite_node];
@ -409,7 +435,7 @@ void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_gra
auto item_idx = GetValue<int64_t>(value_node->value());
if (item_idx == static_cast<int64_t>(reduce_real_output_index_)) {
getitem_user_nodes.push_back(node_index);
} else {
} else if (correct_index) {
if (real_output_num_ > 2) {
// Recorrect other getitem index.
int64_t new_item_idx = CalNewIndex(item_idx, reduce_real_output_index_);
@ -431,7 +457,6 @@ void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_gra
}
}
}
for (auto &pair : getitem_user_nodes) {
// dirctory to find real user.
auto real_users = mng->node_users()[pair.first];
@ -439,12 +464,16 @@ void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_gra
}
}
return reduce_user_nodes;
}
void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng) {
// 1. find users, change getitem index if needed.
std::vector<std::pair<AnfNodePtr, int> > reduce_user_nodes =
FindOriginCNodeUsers(main_graph, composite_node, mng, true);
for (const auto &[user_node, index] : reduce_user_nodes) {
// 2. set ac output as user's input.
auto user_cnode = user_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(user_cnode);
user_cnode->set_input(index, broadcast_to_node);
// mng->SetEdge(user_node, index, broadcast_to_node);
// 3. Make sure modified composite node running first.
// * To not change the origin node's dependency relation, add ControlDepend and Depend node.
// * For Return node and output node, ControlDepend node will change the order of these two node, which will may
@ -452,7 +481,10 @@ void AtomicCleanInsertter::ProcessOriginCNodeUser(const KernelGraphPtr &main_gra
if (IsPrimitiveCNode(user_node, prim::kPrimReturn) || user_node == main_graph->output()) {
AddDepend(main_graph, broadcast_to_node, composite_node, user_node, index);
} else {
AddControlDepend(main_graph, composite_node, user_node, mng);
auto user_cnode = user_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(user_cnode);
user_cnode->set_input(index, broadcast_to_node);
to_process_order_.emplace_back(composite_node, user_node);
}
}
}
@ -473,6 +505,26 @@ void AtomicCleanInsertter::InsertAtomicClean(const KernelGraphPtr &main_graph, c
// Replace origin ReduceSum's user with atomic clean output, and add control depend from composite op to user.
ProcessOriginCNodeUser(main_graph, origin_composite_node, broadcast_to_node, mng);
MS_LOG(INFO) << "Target node: " << origin_composite_node->fullname_with_scope()
<< ", clean node: " << broadcast_to_node->fullname_with_scope();
}
bool AtomicCleanInsertter::IsExistStructuralObstacle(const KernelGraphPtr &main_graph, const AnfNodePtr &node,
const FuncGraphManagerPtr &mng) {
auto reduce_users = FindOriginCNodeUsers(main_graph, node, mng, false);
// If reduce user is MakeTuple and not last node, there is no cheap method to set right running order between reduce
// node and user node. If reduce is Depend or ControlDepend node, the origin node may be wrong!
return std::all_of(reduce_users.cbegin(), reduce_users.cend(),
[&main_graph](const std::pair<AnfNodePtr, int> &user_info) -> bool {
auto &user = user_info.first;
if ((IsPrimitiveCNode(user, prim::kPrimMakeTuple) || IsPrimitiveCNode(user, prim::kPrimDepend) ||
IsPrimitiveCNode(user, prim::kPrimControlDepend)) &&
!(IsPrimitiveCNode(user, prim::kPrimReturn) || user == main_graph->output())) {
return false;
} else {
return true;
}
});
}
bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
@ -487,7 +539,8 @@ bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
bool changed = false;
auto topo_nodes = TopoSort(kernel_graph->get_return());
for (const auto &node : topo_nodes) {
if (!AnfAlgo::IsGraphKernel(node) || !CanActivateAtomicAdd(node)) {
if (!AnfAlgo::IsGraphKernel(node) || !CanActivateAtomicAdd(node) ||
!IsExistStructuralObstacle(kernel_graph, node, mng)) {
continue;
}
InsertAtomicClean(kernel_graph, node, mng);
@ -495,6 +548,14 @@ bool AtomicCleanInsertter::Run(const FuncGraphPtr &func_graph) {
}
if (changed) {
if (!to_process_order_.empty()) {
auto [patron_node, patron_user, user_index] = FindPatronNode(kernel_graph);
for (const auto &[prior, behind] : to_process_order_) {
patron_node = AddControlDepend(kernel_graph, prior, behind, patron_node);
}
PostprocessForLastPatron(patron_node, patron_user, user_index);
}
mng->RemoveRoots();
mng->KeepRoots({func_graph});
}

@ -18,6 +18,8 @@
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_ADD_ATOMIC_CLEAN_GPU_H_
#include <memory>
#include <tuple>
#include <utility>
#include <vector>
#include "backend/optimizer/common/optimizer.h"
#include "backend/session/kernel_graph.h"
@ -37,18 +39,26 @@ class AtomicCleanInsertter : public Pass {
void InsertAtomicClean(const KernelGraphPtr &main_graph, const AnfNodePtr &anf_node, const FuncGraphManagerPtr &mng);
void AddDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &clean_node, const AnfNodePtr &composite_node,
const AnfNodePtr &user_node, int index);
void AddControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &pre_node, const AnfNodePtr &post_node,
const FuncGraphManagerPtr &mng);
void CreateInplaceAssignNodeAndCorrectReturn(const FuncGraphPtr &sub_graph, const AnfNodePtr &new_parameter);
void CorrectAbstract(const AnfNodePtr &composite_node);
void CorrectKernelBuildInfo(const AnfNodePtr &composite_node, const AnfNodePtr &new_input);
CNodePtr CreateAtomicCleanCompositeNode(const KernelGraphPtr &main_graph, TypeId dst_type);
void ProcessOriginCNodeUser(const KernelGraphPtr &main_graph, const AnfNodePtr &composite_node,
const AnfNodePtr &broadcast_to_node, const FuncGraphManagerPtr &mng);
std::tuple<AnfNodePtr, AnfNodePtr, int> FindPatronNode(const KernelGraphPtr &main_graph);
AnfNodePtr AddControlDepend(const FuncGraphPtr &main_graph, const AnfNodePtr &prior_node,
const AnfNodePtr &behind_node, const AnfNodePtr &patron_node);
void PostprocessForLastPatron(const AnfNodePtr &patron_node, const AnfNodePtr &patron_user, int index);
std::vector<std::pair<AnfNodePtr, int>> FindOriginCNodeUsers(const KernelGraphPtr &main_graph,
const AnfNodePtr &composite_node,
const FuncGraphManagerPtr &mng, bool correct_index);
bool IsExistStructuralObstacle(const KernelGraphPtr &main_graph, const AnfNodePtr &node,
const FuncGraphManagerPtr &mng);
CNodePtr atomic_add_node_{nullptr};
size_t reduce_real_output_index_{0};
size_t real_output_num_{0};
std::vector<std::pair<AnfNodePtr, AnfNodePtr>> to_process_order_;
};
using AtomicCleanInsertterPtr = std::shared_ptr<AtomicCleanInsertter>;
} // namespace opt

Loading…
Cancel
Save