|
|
|
@ -13,11 +13,12 @@
|
|
|
|
|
# limitations under the License.
|
|
|
|
|
# ===========================================================================
|
|
|
|
|
"""Cost model splitter"""
|
|
|
|
|
|
|
|
|
|
from functools import reduce
|
|
|
|
|
from .model import PrimLib, Graph, Tensor
|
|
|
|
|
|
|
|
|
|
use_poly_reduce = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GraphSplitByPattern:
|
|
|
|
|
"""Graph splitter"""
|
|
|
|
|
class Area:
|
|
|
|
@ -33,6 +34,8 @@ class GraphSplitByPattern:
|
|
|
|
|
self.mode = self.MODE_BASIC
|
|
|
|
|
if self.pattern == PrimLib.TRANSFORM or (use_poly_reduce and self.pattern == PrimLib.REDUCE):
|
|
|
|
|
self.mode = self.MODE_COMPOSITE
|
|
|
|
|
if init_op.prim == "AddN":
|
|
|
|
|
self.mode = self.MODE_COMPOSITE
|
|
|
|
|
self.is_output = is_output
|
|
|
|
|
self.output_excluded = set()
|
|
|
|
|
if self.pattern == PrimLib.REDUCE:
|
|
|
|
@ -196,7 +199,7 @@ class GraphSplitByPattern:
|
|
|
|
|
min_area, forward_fuse = None, False
|
|
|
|
|
for a, _ in dom.out_relations.items():
|
|
|
|
|
if a.pattern <= PrimLib.BROADCAST and dom.check_circle(a) and \
|
|
|
|
|
(min_area is None or a.pattern < min_area.pattern):
|
|
|
|
|
(min_area is None or a.pattern < min_area.pattern):
|
|
|
|
|
min_area = a
|
|
|
|
|
for a, _ in dom.in_relations.items():
|
|
|
|
|
if a.pattern <= PrimLib.BROADCAST and a.check_circle(dom) and \
|
|
|
|
@ -210,7 +213,7 @@ class GraphSplitByPattern:
|
|
|
|
|
return None
|
|
|
|
|
a, r = list(dom.in_relations.items())[0]
|
|
|
|
|
if a.pattern > PrimLib.BROADCAST or len(a.out_relations) != 1 or r != PrimLib.ELEMWISE or \
|
|
|
|
|
a.dom_op().output.shape != dom.dom_op().output.shape:
|
|
|
|
|
a.dom_op().output.shape != dom.dom_op().output.shape:
|
|
|
|
|
return None
|
|
|
|
|
return [a], True
|
|
|
|
|
|
|
|
|
@ -220,7 +223,7 @@ class GraphSplitByPattern:
|
|
|
|
|
fused = []
|
|
|
|
|
for a, r in dom.in_relations.items():
|
|
|
|
|
if a.pattern <= PrimLib.BROADCAST and r == PrimLib.ELEMWISE and a.check_circle(dom) and \
|
|
|
|
|
a.dom_op().output.shape == dom.dom_op().output.shape:
|
|
|
|
|
a.dom_op().output.shape == dom.dom_op().output.shape:
|
|
|
|
|
fused.append(a)
|
|
|
|
|
return fused, True
|
|
|
|
|
|
|
|
|
@ -231,7 +234,7 @@ class GraphSplitByPattern:
|
|
|
|
|
|
|
|
|
|
def _broadcast_depth(dom):
|
|
|
|
|
if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or len(dom.out_relations) != 1 or \
|
|
|
|
|
dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH:
|
|
|
|
|
dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH:
|
|
|
|
|
return None
|
|
|
|
|
a, r = list(dom.out_relations.items())[0]
|
|
|
|
|
if _broadcast_pat_exclude(dom, a, r) or len(a.in_relations) != 1:
|
|
|
|
@ -240,12 +243,12 @@ class GraphSplitByPattern:
|
|
|
|
|
|
|
|
|
|
def _broadcast_width(dom):
|
|
|
|
|
if dom.pattern not in (PrimLib.ELEMWISE, PrimLib.BROADCAST) or \
|
|
|
|
|
dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH:
|
|
|
|
|
dom.is_output or len(dom.ops) > self.BORADCAST_FUSE_DEPTH:
|
|
|
|
|
return None
|
|
|
|
|
fused = []
|
|
|
|
|
for a, r in dom.out_relations.items():
|
|
|
|
|
if _broadcast_pat_exclude(dom, a, r) or not dom.check_circle(a) or \
|
|
|
|
|
(fused and fused[0].dom_op().output.shape != a.dom_op().output.shape):
|
|
|
|
|
(fused and fused[0].dom_op().output.shape != a.dom_op().output.shape):
|
|
|
|
|
return None
|
|
|
|
|
fused.append(a)
|
|
|
|
|
return fused, False
|
|
|
|
@ -301,8 +304,19 @@ class GraphSplitByPattern:
|
|
|
|
|
return size
|
|
|
|
|
|
|
|
|
|
def _reduce_output(dom):
|
|
|
|
|
def _is_atomic_add_available(dom):
|
|
|
|
|
if any(["Reduce" in x.prim for x in dom.ops[1:]]):
|
|
|
|
|
return False
|
|
|
|
|
op = dom.ops[0]
|
|
|
|
|
reduce_axis = op.attrs["reduce_axis"]
|
|
|
|
|
if len(op.inputs[0].shape) - 1 in reduce_axis:
|
|
|
|
|
reduce_size = reduce(lambda x, y: x * y, [op.inputs[0].shape[i] for i in reduce_axis])
|
|
|
|
|
return reduce_size >= 1024
|
|
|
|
|
return True
|
|
|
|
|
if dom.pattern != PrimLib.REDUCE:
|
|
|
|
|
return None
|
|
|
|
|
if _is_atomic_add_available(dom):
|
|
|
|
|
return None
|
|
|
|
|
is_all_reduce = _tensor_size(dom.ops[0].output) == 1
|
|
|
|
|
# excluded large size all reduce
|
|
|
|
|
if is_all_reduce and _tensor_size(dom.ops[0].inputs[0]) > 1024 * 12:
|
|
|
|
@ -310,7 +324,7 @@ class GraphSplitByPattern:
|
|
|
|
|
fused = []
|
|
|
|
|
for a, r in dom.out_relations.items():
|
|
|
|
|
if a.pattern <= PrimLib.BROADCAST and r <= PrimLib.BROADCAST and \
|
|
|
|
|
dom.check_circle(a) and not dom.reduce_out_exclude(a):
|
|
|
|
|
dom.check_circle(a) and not dom.reduce_out_exclude(a):
|
|
|
|
|
fused.append(a)
|
|
|
|
|
return fused, False
|
|
|
|
|
|
|
|
|
|