!9678 【GraphKernel】Fix precision problem

From: @dayschan
Reviewed-by: @gaoxiong1,@ckey_dou
Signed-off-by: @ckey_dou
pull/9678/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit c1f65f7460

2
akg

@ -1 +1 @@
Subproject commit f8f4e60bf3c435cec41cbe48fe24901277ef9556
Subproject commit 72b359ad457ed8f4f254c8a3bd2bde88967202fb

@ -283,6 +283,9 @@ class GraphSplitByPattern:
if _check_reduce_exclude(dom):
return None
a, r = list(dom.in_relations.items())[0]
if a.is_output and len(a.ops) >= 10 and _is_atomic_add_available(dom):
# to evade the precision problem in akg
return None
if _reduce_pat_exclude(dom, a, r) or len(a.out_relations) != 1:
return None
return [a], True
@ -292,6 +295,8 @@ class GraphSplitByPattern:
return None
if _check_reduce_exclude(dom):
return None
if len(dom.ops) == 1:
return None
fused = []
for a, r in dom.in_relations.items():
if not _reduce_pat_exclude(dom, a, r) and a.check_circle(dom):
@ -304,7 +309,6 @@ class GraphSplitByPattern:
size *= i
return size
def _reduce_output(dom):
def _is_atomic_add_available(dom):
if any(["Reduce" in x.prim for x in dom.ops[1:]]):
return False
@ -314,6 +318,8 @@ class GraphSplitByPattern:
reduce_size = reduce(lambda x, y: x * y, [op.inputs[0].shape[i] for i in reduce_axis])
return reduce_size >= 1024
return True
def _reduce_output(dom):
if dom.pattern != PrimLib.REDUCE:
return None
if _is_atomic_add_available(dom):

Loading…
Cancel
Save