@@ -15,139 +15,8 @@
"""GraphKernel model builder"""
|
|
|
|
|
|
|
|
|
|
import copy
|
|
|
|
|
from .model import PrimLib, Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy, DataFormat
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_tile_output_shape(shape, multiples):
|
|
|
|
|
"""compute output shape of tile"""
|
|
|
|
|
|
|
|
|
|
if multiples is None:
|
|
|
|
|
return shape
|
|
|
|
|
if not isinstance(shape, (list, tuple)):
|
|
|
|
|
raise TypeError("Input shape of Tile must be of type list or tuple")
|
|
|
|
|
if not isinstance(multiples, (list, tuple)):
|
|
|
|
|
raise TypeError("multiples of Tile must be of type list or tuple")
|
|
|
|
|
|
|
|
|
|
shape = list(shape)
|
|
|
|
|
multiples = list(multiples)
|
|
|
|
|
diff_len = len(multiples) - len(shape)
|
|
|
|
|
if diff_len < 0:
|
|
|
|
|
raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape))
|
|
|
|
|
if diff_len > 0:
|
|
|
|
|
for _ in range(diff_len):
|
|
|
|
|
shape.insert(0, 1)
|
|
|
|
|
|
|
|
|
|
shape_compatible = True
|
|
|
|
|
output_shape = []
|
|
|
|
|
input_reshape = []
|
|
|
|
|
output_reshape = []
|
|
|
|
|
for sh, mul in list(zip(shape, multiples)):
|
|
|
|
|
dim = sh * mul
|
|
|
|
|
output_shape.append(dim)
|
|
|
|
|
if sh == 1 or mul == 1:
|
|
|
|
|
input_reshape.append(sh)
|
|
|
|
|
output_reshape.append(dim)
|
|
|
|
|
else:
|
|
|
|
|
shape_compatible = False
|
|
|
|
|
input_reshape.append(1)
|
|
|
|
|
input_reshape.append(sh)
|
|
|
|
|
output_reshape.append(mul)
|
|
|
|
|
output_reshape.append(sh)
|
|
|
|
|
|
|
|
|
|
return output_shape, input_reshape, output_reshape, shape_compatible
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OpInfer:
|
|
|
|
|
"""Op infer"""
|
|
|
|
|
@staticmethod
|
|
|
|
|
def default_reduce_infer(inputs, attrs):
|
|
|
|
|
"""Default reduce infer"""
|
|
|
|
|
shape = copy.deepcopy(inputs[0].shape)
|
|
|
|
|
if attrs['keep_dims']:
|
|
|
|
|
for i in attrs['reduce_axis']:
|
|
|
|
|
shape[i] = 1
|
|
|
|
|
return shape
|
|
|
|
|
|
|
|
|
|
real_shape = []
|
|
|
|
|
for i, _ in enumerate(shape):
|
|
|
|
|
if i not in attrs['reduce_axis'] and i - len(shape) not in attrs['reduce_axis']:
|
|
|
|
|
real_shape.append(shape[i])
|
|
|
|
|
return real_shape
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def default_elementwise_infer(inputs, attrs):
|
|
|
|
|
"""Default elementwise infer"""
|
|
|
|
|
shape = (1,)
|
|
|
|
|
max_flatten_shape = 1
|
|
|
|
|
for t in inputs:
|
|
|
|
|
flatten_shape = 1
|
|
|
|
|
for s in t.shape:
|
|
|
|
|
flatten_shape *= s
|
|
|
|
|
if flatten_shape >= max_flatten_shape:
|
|
|
|
|
max_flatten_shape = flatten_shape
|
|
|
|
|
shape = t.shape
|
|
|
|
|
return shape
|
|
|
|
|
|
|
|
|
|
default_infer_shape_func = [
|
|
|
|
|
None,
|
|
|
|
|
None,
|
|
|
|
|
default_elementwise_infer.__func__,
|
|
|
|
|
lambda inputs, attrs: max([t.shape for t in inputs]),
|
|
|
|
|
default_reduce_infer.__func__,
|
|
|
|
|
None,
|
|
|
|
|
lambda inputs, attrs: [1], # control op
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def default_infer_dtype_func(inputs, attrs):
|
|
|
|
|
"""Infer dtype"""
|
|
|
|
|
return inputs[0].dtype
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def default_infer_format_func(inputs, attrs):
|
|
|
|
|
"""Infer format"""
|
|
|
|
|
result = inputs[0].data_format
|
|
|
|
|
# default_format and other_format results in other_format
|
|
|
|
|
for input_tensor in inputs[1:]:
|
|
|
|
|
data_format = input_tensor.data_format
|
|
|
|
|
if data_format != DataFormat.DEFAULT:
|
|
|
|
|
if result not in [DataFormat.DEFAULT, data_format]:
|
|
|
|
|
raise RuntimeError("Incompatible data format %s and %s" % (data_format, result))
|
|
|
|
|
result = data_format
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
|
infer_shape_func = {
|
|
|
|
|
# add special infer func here
|
|
|
|
|
'InplaceAssign': lambda inputs, attrs: inputs[2].shape,
|
|
|
|
|
'Reshape': lambda inputs, attrs: attrs["shape"],
|
|
|
|
|
'BroadcastTo': lambda inputs, attrs: attrs["shape"],
|
|
|
|
|
'Tile': lambda inputs, attrs: get_tile_output_shape(inputs[0].shape, attrs["multiples"])[0],
|
|
|
|
|
'ExpandDims': lambda inputs, attrs: list(inputs[0].shape).insert(attrs["axis"], 1),
|
|
|
|
|
}
|
|
|
|
|
infer_dtype_func = {
|
|
|
|
|
# add special infer func here
|
|
|
|
|
'Cast': lambda inputs, attrs: attrs['dst_type'],
|
|
|
|
|
'Less': lambda inputs, attrs: "bool",
|
|
|
|
|
'LessEqual': lambda inputs, attrs: "bool",
|
|
|
|
|
'Equal': lambda inputs, attrs: "bool",
|
|
|
|
|
'Greater': lambda inputs, attrs: "bool",
|
|
|
|
|
'GreaterEqual': lambda inputs, attrs: "bool",
|
|
|
|
|
}
|
|
|
|
|
infer_format_func = {
|
|
|
|
|
# add special infer func here
|
|
|
|
|
'Reshape': lambda inputs, attrs: "DefaultFormat",
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def infer(cls, prim_name, inputs, attrs):
|
|
|
|
|
prim = PrimLib.primtives[prim_name]
|
|
|
|
|
infer_shape = cls.infer_shape_func.get(
|
|
|
|
|
prim_name, cls.default_infer_shape_func[prim.iter_type])
|
|
|
|
|
infer_dtype = cls.infer_dtype_func.get(
|
|
|
|
|
prim_name, cls.default_infer_dtype_func)
|
|
|
|
|
infer_format = cls.infer_format_func.get(
|
|
|
|
|
prim_name, cls.default_infer_format_func)
|
|
|
|
|
return infer_shape(inputs, attrs), infer_dtype(inputs, attrs), infer_format(inputs, attrs)
|
|
|
|
|
from . import op_infer
|
|
|
|
|
from .model import Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GraphBuilder:
@@ -229,7 +98,7 @@ class GraphBuilder:
         if isinstance(inputs, (Tensor, Value)):
             inputs = [inputs]
         tensor_inputs = [t for t in inputs if isinstance(t, (Tensor, Value))]
-        out_shape, out_dtype, out_format = OpInfer.infer(prim, tensor_inputs, attrs)
+        out_shape, out_dtype, out_format = op_infer.infer(prim, tensor_inputs, attrs)
         output = self.tensor(out_shape, out_dtype, out_format, name)
         self.op(prim, output, inputs, attrs)
         return output
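
The diff moves the shape/dtype/format inference out of `graph_builder.py` into a new `op_infer` module that this patch imports but does not show. As a minimal sketch of the removed `get_tile_output_shape` helper's behavior, assuming it keeps the same name and semantics after the move (both the import path and the helper's new location are assumptions, not confirmed by these hunks):

```python
# Hypothetical import path, inferred from "from . import op_infer" above;
# the destination file is not part of this diff.
from mindspore._extends.graph_kernel.model import op_infer

# Compatible case: every (dim, multiple) pair contains a 1, so Tile is a
# plain broadcast and no extra reshape is needed.
out, in_reshape, out_reshape, compatible = op_infer.get_tile_output_shape([1, 3], [2, 1])
# out == [2, 3], in_reshape == [1, 3], out_reshape == [2, 3], compatible == True

# Incompatible case: an axis where both dim and multiple exceed 1 is split
# into two axes, so the caller must reshape before and after the broadcast.
out, in_reshape, out_reshape, compatible = op_infer.get_tile_output_shape([2, 3], [2, 2])
# out == [4, 6], in_reshape == [1, 2, 1, 3], out_reshape == [2, 2, 2, 3], compatible == False
```

The split reshape lists explain the `shape_compatible` flag in the removed code: when it is `False`, Tile apparently has to be lowered as reshape, then broadcast, then reshape, rather than as a single broadcast.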