From 297f075dca8276cf8efc80afae6b4999e687c037 Mon Sep 17 00:00:00 2001 From: dayschan Date: Tue, 8 Dec 2020 22:10:21 +0800 Subject: [PATCH] Fix precision problem --- akg | 2 +- .../graph_kernel/model/graph_split.py | 26 ++++++++++++------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/akg b/akg index f8f4e60bf3..72b359ad45 160000 --- a/akg +++ b/akg @@ -1 +1 @@ -Subproject commit f8f4e60bf3c435cec41cbe48fe24901277ef9556 +Subproject commit 72b359ad457ed8f4f254c8a3bd2bde88967202fb diff --git a/mindspore/_extends/graph_kernel/model/graph_split.py b/mindspore/_extends/graph_kernel/model/graph_split.py index 3ae0167223..8ca1ac8064 100644 --- a/mindspore/_extends/graph_kernel/model/graph_split.py +++ b/mindspore/_extends/graph_kernel/model/graph_split.py @@ -33,7 +33,7 @@ class GraphSplitByPattern: self.out_relations = dict() # {area1: relation1, area2: relation2, ...} self.mode = self.MODE_BASIC if self.pattern == PrimLib.TRANSFORM or self.pattern == PrimLib.BROADCAST or \ - (use_poly_reduce and self.pattern == PrimLib.REDUCE): + (use_poly_reduce and self.pattern == PrimLib.REDUCE): self.mode = self.MODE_COMPOSITE if init_op.prim == "AddN": self.mode = self.MODE_COMPOSITE @@ -283,6 +283,9 @@ class GraphSplitByPattern: if _check_reduce_exclude(dom): return None a, r = list(dom.in_relations.items())[0] + if a.is_output and len(a.ops) >= 10 and _is_atomic_add_available(dom): + # to evade the precision problem in akg + return None if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1: return None return [a], True @@ -292,6 +295,8 @@ class GraphSplitByPattern: return None if _check_reduce_exclude(dom): return None + if len(dom.ops) == 1: + return None fused = [] for a, r in dom.in_relations.items(): if not _reduce_pat_exclude(dom, a, r) and a.check_circle(dom): @@ -304,16 +309,17 @@ class GraphSplitByPattern: size *= i return size + def _is_atomic_add_available(dom): + if any(["Reduce" in x.prim for x in dom.ops[1:]]): + return False + op = dom.ops[0] + reduce_axis = op.attrs["reduce_axis"] + if len(op.inputs[0].shape) - 1 in reduce_axis: + reduce_size = reduce(lambda x, y: x * y, [op.inputs[0].shape[i] for i in reduce_axis]) + return reduce_size >= 1024 + return True + def _reduce_output(dom): - def _is_atomic_add_available(dom): - if any(["Reduce" in x.prim for x in dom.ops[1:]]): - return False - op = dom.ops[0] - reduce_axis = op.attrs["reduce_axis"] - if len(op.inputs[0].shape) - 1 in reduce_axis: - reduce_size = reduce(lambda x, y: x * y, [op.inputs[0].shape[i] for i in reduce_axis]) - return reduce_size >= 1024 - return True if dom.pattern != PrimLib.REDUCE: return None if _is_atomic_add_available(dom):