!10257 【GraphKernel】Enhance the fusion capacity for getitem nodes

From: @dayschan
Reviewed-by: @gaoxiong1, @ckey_dou
Signed-off-by: @ckey_dou
pull/10257/MERGE
Committed by mindspore-ci-bot via Gitee
commit 639e0c5fbd

@@ -31,12 +31,8 @@ class GraphSplitByPattern:
self.ops = [init_op]
self.in_relations = dict() # {area1: relation1, area2: relation2, ...}
self.out_relations = dict() # {area1: relation1, area2: relation2, ...}
self.mode = self.MODE_BASIC
if self.pattern == PrimLib.TRANSFORM or self.pattern == PrimLib.BROADCAST or \
(use_poly_reduce and self.pattern == PrimLib.REDUCE):
self.mode = self.MODE_COMPOSITE
if init_op.prim == "AddN":
self.mode = self.MODE_COMPOSITE
self.mode = None
self.set_default_mode()
self.is_output = is_output
self.output_excluded = set()
if self.pattern == PrimLib.REDUCE:
@@ -55,6 +51,17 @@ class GraphSplitByPattern:
def __repr__(self):
return str(self)
def set_default_mode(self):
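"""Set the default mode of this area according to its first op."""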
def _get_default_mode(op):
if op.prim == "AddN":
return self.MODE_COMPOSITE
pattern = PrimLib.iter_type(op)
if pattern == PrimLib.TRANSFORM or pattern == PrimLib.BROADCAST or \
(use_poly_reduce and pattern == PrimLib.REDUCE):
return self.MODE_COMPOSITE
return self.MODE_BASIC
self.mode = _get_default_mode(self.ops[0])
def get_relation(self, op, i):
relation = PrimLib.UNKNOWN
_, elem_relation = PrimLib.input_relation(op, i)
@@ -359,9 +366,40 @@ class GraphSplitByPattern:
if use_poly_reduce:
changed = self.fuse(_reduce_output) or changed
self.fuse(_transpose)
# The Reshape should not be an output node.
# Note: after this function, the input-output relations are not maintained.
self.split_output_reshapes()
subgraphs, graphmodes = self.to_subgraphs()
return subgraphs, graphmodes
def split_output_reshapes(self):
"""Force split the output Reshape ops into new areas."""
new_areas = []
for area in self.areas:
out_reshape_ops = [op for op in area.ops if PrimLib.iter_type(op) == PrimLib.RESHAPE]
remain_ops = [op for op in area.ops if op not in out_reshape_ops]
if not remain_ops or not out_reshape_ops:
continue
changed = True
while changed:
changed = False
for op in out_reshape_ops:
if any([to_op in remain_ops for to_op in op.output.to_ops]):
out_reshape_ops.remove(op)
remain_ops.append(op)
changed = True
break
if out_reshape_ops:
for op in out_reshape_ops:
new_areas.append(self.Area(op, False))
area.ops = remain_ops
if len(remain_ops) == 1:
area.set_default_mode()
if new_areas:
self.areas += new_areas
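A minimal standalone sketch of the fixed-point loop above, with a toy Op class whose to_ops field stands in for op.output.to_ops in the real code:

class Op:
    def __init__(self, name, is_reshape=False):
        self.name = name
        self.is_reshape = is_reshape
        self.to_ops = []  # consumers of this op's output

def split_trailing_reshapes(ops):
    # Start by assuming every Reshape can be split out of the area.
    reshapes = [op for op in ops if op.is_reshape]
    remain = [op for op in ops if not op.is_reshape]
    changed = True
    while changed:
        changed = False
        for op in reshapes:
            # A Reshape that feeds a remaining op must stay in the area.
            if any(to_op in remain for to_op in op.to_ops):
                reshapes.remove(op)
                remain.append(op)
                changed = True
                break
    return remain, reshapes

# r1 feeds b, so it is pulled back into the area; only the trailing r2 is split out.
a, b = Op("a"), Op("b")
r1, r2 = Op("r1", True), Op("r2", True)
a.to_ops, r1.to_ops, b.to_ops = [r1], [b], [r2]
remain, out = split_trailing_reshapes([a, r1, b, r2])
assert [op.name for op in out] == ["r2"]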
def split(graph):
"""Split graph"""

@@ -97,15 +97,6 @@ AnfNodePtrList RemoveWildGetitem(const AnfNodePtrList &fused_op) {
remove_list = DeepLinkedGraphSearch(getitem, check_include);
break;
}
// To fix the getitem-index issue, the previous node is only fused together with all of its users.
const auto &brothers = mng->node_users()[prev_node];
if (std::any_of(brothers.begin(), brothers.end(), [check_include](const std::pair<AnfNodePtr, int> &user) {
return check_include(user.first) == EXCLUDE;
})) {
remove_list = DeepLinkedGraphSearch(getitem, check_include);
break;
}
}
if (!remove_list.empty()) {
for (auto node : remove_list) {

@@ -41,6 +41,16 @@ inline size_t GetIndex(const AnfNodePtr &getitem_node) {
getitem_node->cast<CNodePtr>()->input(kInputNodeOutputIndexInTupleGetItem)->cast<ValueNodePtr>()->value()));
}
void SetIndex(const AnfNodePtr &getitem_node, size_t index) {
auto getitem = getitem_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(getitem);
auto idx_node = NewValueNode(MakeValue<int64_t>(SizeToLong(index)));
auto abstract = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
idx_node->set_abstract(abstract);
idx_node->set_kernel_info(std::make_shared<device::KernelInfo>());
getitem->set_input(kInputNodeOutputIndexInTupleGetItem, idx_node);
}
bool GetGraphKernelGetitemList(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, AnfNodePtrList *getitem_list,
bool merge_repeated_getitem = false) {
MS_EXCEPTION_IF_NULL(mng);
@@ -85,7 +95,69 @@ AnfNodePtrList FindGraphKernelsWithMultiOutput(const FuncGraphPtr &func_graph) {
return result;
}
/* Merge the get_item nodes that have the same index.
/* Unify the repeated outputs in a func_graph.
* %1 = call @graph_kernel(p1, p2)
* %2 = tuple_getitem(%1, 0)
* %3 = tuple_getitem(%1, 1)
* graph_kernel:
* %1 = TensorAdd(p1, p2)
* %2 = Reshape(%1)
* return make_tuple(%2, %2)
* -->
* %1 = call @graph_kernel(p1, p2)
* %2 = tuple_getitem(%1, 0)
* %3 = tuple_getitem(%1, 0) // changed the index to 0.
* graph_kernel:
* %1 = TensorAdd(p1, p2)
* %2 = Reshape(%1)
* return make_tuple(%2, %2)
*/
class UnifyRepeatedOutput : public Pass {
public:
bool Run(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
MS_EXCEPTION_IF_NULL(mng);
auto todos = FindGraphKernelsWithMultiOutput(func_graph);
bool changed = false;
for (auto node : todos) {
if (CheckRepeatedOutput(AnfAlgo::GetCNodeFuncGraphPtr(node))) {
changed = true;
AnfNodePtrList getitem_list;
GetGraphKernelGetitemList(mng, node, &getitem_list, false);
if (getitem_list.size() != index_map_.size()) {
MS_LOG(EXCEPTION) << "getitem_list.size (" << getitem_list.size() << ") should be equal to index_map.size ("
<< index_map_.size() << ").";
}
for (size_t i = 0; i < index_map_.size(); ++i) {
if (index_map_[i] != i && getitem_list[i] != nullptr) {
SetIndex(getitem_list[i], index_map_[i]);
}
}
}
}
return changed;
}
private:
bool CheckRepeatedOutput(const FuncGraphPtr &sub_func_graph) {
// The output should be a MakeTuple.
auto maketuple = sub_func_graph->output()->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(maketuple);
AnfNodePtrList outputs(maketuple->inputs().begin() + 1, maketuple->inputs().end());
index_map_.resize(outputs.size());
bool found = false;
for (size_t i = 0; i < outputs.size(); ++i) {
index_map_[i] = std::find(outputs.begin(), outputs.begin() + i, outputs[i]) - outputs.begin();
if (index_map_[i] != i) {
found = true;
}
}
return found;
}
std::vector<size_t> index_map_;
};
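The mapping built by CheckRepeatedOutput is just a first-occurrence index; a minimal Python sketch, with plain strings standing in for the output nodes:

# index_map[i] is the position of the first output identical to outputs[i],
# mirroring the std::find over outputs[0..i) above.
def build_index_map(outputs):
    index_map = [outputs.index(out) for out in outputs]
    found = any(mapped != i for i, mapped in enumerate(index_map))
    return index_map, found

# make_tuple(%2, %2) from the example comment: both entries map to index 0,
# so the second getitem's index is rewritten via SetIndex.
assert build_index_map(["%2", "%2"]) == ([0, 0], True)
assert build_index_map(["%1", "%2"]) == ([0, 1], False)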
/* Unify the get_item nodes that have the same index.
* %1 = call @graph_kernel(p1, p2)
* %2 = tuple_getitem(%1, 0)
* %3 = tuple_getitem(%1, 0)
@@ -95,13 +167,13 @@ AnfNodePtrList FindGraphKernelsWithMultiOutput(const FuncGraphPtr &func_graph) {
* %7 = user_z(%4)
* --->
* %1 = call @graph_kernel(p1, p2)
* %2 = tuple_getitem(%1, 0)
* %2 = tuple_getitem(%1, 0) // unify the original %2 and %3
* %3 = tuple_getitem(%1, 1)
* %4 = user_x(%2)
* %5 = user_y(%2)
* %6 = user_z(%3)
*/
class MergeRepeatedGetitem : public Pass {
class UnifyRepeatedGetitem : public Pass {
public:
bool Run(const FuncGraphPtr &func_graph) {
auto mng = func_graph->manager();
@@ -237,26 +309,24 @@ bool EliminateRedundantOutput::Run(const FuncGraphPtr &func_graph) {
mng = Manage(func_graph, true);
func_graph->set_manager(mng);
}
bool changed = std::make_shared<MergeRepeatedGetitem>()->Run(func_graph);
bool changed = false;
changed = std::make_shared<UnifyRepeatedGetitem>()->Run(func_graph) || changed;
changed = std::make_shared<UnifyRepeatedOutput>()->Run(func_graph) || changed;
changed = std::make_shared<UnifyRepeatedGetitem>()->Run(func_graph) || changed;
changed = std::make_shared<EliminateGetitemForControlDepend>()->Run(func_graph) || changed;
changed = Process(func_graph) || changed;
return changed;
}
void EliminateRedundantOutput::UpdateGetitemIndex(const CNodePtr &getitem, int64_t offset) {
void EliminateRedundantOutput::UpdateGetitemIndex(const AnfNodePtr &getitem, size_t offset) {
if (offset == 0) return;
MS_EXCEPTION_IF_NULL(getitem);
int64_t index = SizeToLong(GetIndex(getitem));
auto index = GetIndex(getitem);
if (offset > index) {
MS_LOG(EXCEPTION) << "The offset is greater than the original index of GetItem: " << getitem->DebugString();
}
index -= offset;
auto idx_node = NewValueNode(MakeValue<int64_t>(index));
auto abstract = std::make_shared<abstract::AbstractScalar>(std::make_shared<Int64Imm>(index));
idx_node->set_abstract(abstract);
idx_node->set_kernel_info(std::make_shared<device::KernelInfo>());
getitem->set_input(kInputNodeOutputIndexInTupleGetItem, idx_node);
SetIndex(getitem, index);
}
AnfNodePtr EliminateRedundantOutput::ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems) {
@@ -266,14 +336,14 @@ AnfNodePtr EliminateRedundantOutput::ReplaceMakeTuple(const AnfNodePtr &node, co
MS_EXCEPTION_IF_NULL(old_maketuple);
AnfNodePtrList new_maketuple_inputs = {NewValueNode(prim::kPrimMakeTuple)};
AbstractBasePtrList abstract_list;
int64_t offset = 0;
size_t offset = 0;
for (size_t i = 0; i < getitems.size(); ++i) {
if (getitems[i] == nullptr) {
offset++;
} else {
new_maketuple_inputs.push_back(old_maketuple->input(i + 1));
abstract_list.push_back(old_maketuple->input(i + 1)->abstract());
UpdateGetitemIndex(getitems[i]->cast<CNodePtr>(), offset);
UpdateGetitemIndex(getitems[i], offset);
}
}
if (offset == 0) return nullptr;
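The offset bookkeeping above can be read as a simple index compaction; a minimal Python sketch, assuming None marks an eliminated getitem as in the loop:

def compact_indices(getitems):
    # Every kept getitem shifts down by the number of holes seen so far,
    # matching UpdateGetitemIndex(getitems[i], offset).
    new_index, offset = {}, 0
    for old_index, item in enumerate(getitems):
        if item is None:
            offset += 1  # one more eliminated output
        else:
            new_index[old_index] = old_index - offset
    return new_index, offset

# Outputs 1 and 3 were eliminated, so old index 2 -> 1 and old index 4 -> 2.
assert compact_indices(["g0", None, "g2", None, "g4"]) == ({0: 0, 2: 1, 4: 2}, 2)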

@@ -28,7 +28,7 @@ class EliminateRedundantOutput : public Pass {
private:
bool Process(const FuncGraphPtr &func_graph);
void UpdateGetitemIndex(const CNodePtr &getitem, int64_t offset);
void UpdateGetitemIndex(const AnfNodePtr &getitem, size_t offset);
AnfNodePtr ReplaceMakeTuple(const AnfNodePtr &node, const AnfNodePtrList &getitems);
};
} // namespace opt

@@ -496,7 +496,7 @@ void ReplaceNewFuseCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &new_f
fn_inputs.clear();
fn_inputs.push_back(NewValueNode(prim::kPrimTupleGetItem));
fn_inputs.push_back(new_fuse_cnode);
fn_inputs.push_back(NewValueNode(MakeValue(SizeToLong(out_idx))));
fn_inputs.push_back(NewValueNode(MakeValue(SizeToLong(out_idx + offset))));
auto new_out = func_graph->NewCNode(fn_inputs);
new_out->set_abstract(outputs[out_idx]->abstract());
mng->Replace(outputs[out_idx], new_out);

@@ -544,8 +544,7 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer {
}
func_graph_ = func_graph;
this->Run();
if (split_plan_.empty()) return false;
return split_plan_.size() > 1 || NeedInline(0);
return !split_plan_.empty();
}
bool NeedInline(size_t group_id) const override {
@@ -630,7 +629,12 @@ class CostModelSplitSchemer : public Splitter::SplitSchemer {
}
GetValidKernelNodes();
// call CostModel to get a split plan.
if (!SplitByCostModel()) {
if (!SplitByCostModel() || split_plan_.size() != need_inline_.size()) {
split_plan_.clear();
need_inline_.clear();
return;
} else if (split_plan_.size() == 1 && !NeedInline(0)) {
// In this case, the CostModel decided to keep the whole graph unchanged.
split_plan_.clear();
need_inline_.clear();
return;
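A minimal Python sketch of the acceptance logic above, assuming need_inline holds one flag per group as need_inline_ does in the C++ code:

def accept_split_plan(split_plan, need_inline):
    # Reject a failed or inconsistent cost-model result.
    if not split_plan or len(split_plan) != len(need_inline):
        return False
    # A single group with no inlining means the graph is kept unchanged.
    if len(split_plan) == 1 and not need_inline[0]:
        return False
    return True

assert not accept_split_plan([["op1", "op2"]], [False])
assert accept_split_plan([["op1"], ["op2"]], [False, True])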

@@ -188,6 +188,8 @@ void GPUSession::GraphKernelOptimize(const std::shared_ptr<KernelGraph> &kernel_
pm->AddPass(std::make_shared<opt::TensorPromotion>());
pm->AddPass(std::make_shared<opt::GraphKernelSplitter>());
pm->AddPass(std::make_shared<opt::GraphKernelCSE>());
// The CSE may output a graph with repeated outputs.
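// e.g. two identical Reshape outputs merged by CSE leave the sub-graph
// returning make_tuple(%2, %2); UnifyRepeatedOutput then normalizes it.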
pm->AddPass(std::make_shared<opt::EliminateRedundantOutput>());
// After Simplify and Splitter, a lot of redundant getitem/maketuple
// nodes will be exposed; use the GetitemTuple pass to delete them.
pm->AddPass(std::make_shared<opt::GetitemTuple>());
