!12549 [GraphKernel] Refactor GraphKernelExpander (3rd submission)

From: @dayschan
Reviewed-by: 
Signed-off-by:
pull/12549/MERGE
Committed by mindspore-ci-bot via Gitee
commit be99db696d

@@ -20,24 +20,30 @@ from ._utils import Expander, ExpanderInfoValidator as VLD
 @VLD.add_format(DF.DEFAULT)
 @VLD.add_format(DF.NHWC)
 @VLD.add_format(DF.NCHW)
+@VLD.add_format(DF.FRAC_NZ)
 class BiasAddGrad(Expander):
     """BiasAddGrad expander"""
     def _expand(self, graph_builder):
-        input_x = self.inputs[0]
+        x = self.inputs[0]
         reduce_axis = ()
-        if input_x.data_format == 'NHWC':
+        if x.data_format == DF.NHWC:
             reduce_axis = (0, 1, 2)
-        elif input_x.data_format == 'NCHW':
+        elif x.data_format == DF.NCHW:
             reduce_axis = (0, 2, 3)
-        # DefaultFormat shape's length should be from 2 to 4
+        elif x.data_format == DF.FRAC_NZ:
+            reduce_axis = (-2, -3)
         else:
-            if len(input_x.shape) == 2:
+            # DefaultFormat shape's length should be from 2 to 4
+            if len(x.shape) == 2:
                 reduce_axis = (0,)
-            elif len(input_x.shape) == 3:
+            elif len(x.shape) == 3:
                 reduce_axis = (0, 1)
             else:
                 reduce_axis = (0, 2, 3)
-        result = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
+        result = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
+        if x.data_format == DF.FRAC_NZ:
+            out_shape = x.shape[:-4] + [x.shape[-1] * x.shape[-4]]
+            result = graph_builder.emit('Reshape', [result], attrs={'shape': out_shape})
         return result
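Note: FRAC_NZ stores a 2-D [M, N] matrix as [..., N1, M1, M0, N0] with N = N1 * N0, so the new branch reduces the two M axes and then folds N1/N0 back together. A minimal shape walk-through (the sample shape below is assumed for illustration only):

    # hypothetical FRAC_NZ input: [M, N] = [32, 64] stored as [N1, M1, M0, N0]
    shape = [4, 2, 16, 16]
    # ReduceSum over reduce_axis (-2, -3) drops the M1/M0 axes, leaving [4, 16];
    # the emitted Reshape then collapses N1 and N0 back to N = 4 * 16
    out_shape = shape[:-4] + [shape[-1] * shape[-4]]
    assert out_shape == [64]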

@@ -13,7 +13,7 @@
 # limitations under the License.
 # ===========================================================================
 """generate json desc for Tile"""
-from mindspore._extends.graph_kernel.model import model_builder as builder
+from mindspore._extends.graph_kernel.model.op_infer import Tile as TileInfer
 from mindspore._extends.graph_kernel.model.model import DataFormat as DF
 from ._utils import Expander, ExpanderInfoValidator as VLD
@@ -27,8 +27,9 @@ class Tile(Expander):
         input_x = self.inputs[0]
         multiples = self.attrs['multiples']
-        output_shape, _, _, shape_compatible = builder.get_tile_output_shape(self.inputs[0].shape, multiples)
-        if shape_compatible:
+        tile_infer = TileInfer(self.name, self.inputs, self.attrs)
+        output_shape, _, _ = tile_infer.infer()
+        if tile_infer.broadcast_compatible:
             result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape})
         else:
             result = graph_builder.emit('Tile', [input_x], attrs={'multiples': multiples})
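The refactor moves the broadcast-compatibility decision into the new TileInfer: when every dimension has either input size 1 or multiple 1, Tile degenerates into a pure broadcast. A minimal sketch of that rule, assuming shapes are already aligned to the same rank (TileInfer's internals live in op_infer and may differ in detail):

    def tile_is_broadcast(shape, multiples):
        # Tile == BroadcastTo iff no dimension needs real data repetition
        return all(sh == 1 or mul == 1 for sh, mul in zip(shape, multiples))

    assert tile_is_broadcast([1, 16], [8, 1])      # emits BroadcastTo to [8, 16]
    assert not tile_is_broadcast([2, 16], [3, 1])  # falls back to a real Tile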

@@ -15,139 +15,8 @@
 """GraphKernel model builder"""
-import copy
-from .model import PrimLib, Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy, DataFormat
-
-
-def get_tile_output_shape(shape, multiples):
-    """compute output shape of tile"""
-    if multiples is None:
-        return shape
-    if not isinstance(shape, (list, tuple)):
-        raise TypeError("Input shape of Tile must be of type list or tuple")
-    if not isinstance(multiples, (list, tuple)):
-        raise TypeError("multiples of Tile must be of type list or tuple")
-    shape = list(shape)
-    multiples = list(multiples)
-    diff_len = len(multiples) - len(shape)
-    if diff_len < 0:
-        raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape))
-    if diff_len > 0:
-        for _ in range(diff_len):
-            shape.insert(0, 1)
-    shape_compatible = True
-    output_shape = []
-    input_reshape = []
-    output_reshape = []
-    for sh, mul in list(zip(shape, multiples)):
-        dim = sh * mul
-        output_shape.append(dim)
-        if sh == 1 or mul == 1:
-            input_reshape.append(sh)
-            output_reshape.append(dim)
-        else:
-            shape_compatible = False
-            input_reshape.append(1)
-            input_reshape.append(sh)
-            output_reshape.append(mul)
-            output_reshape.append(sh)
-    return output_shape, input_reshape, output_reshape, shape_compatible
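For reference, a worked example of the helper being removed here (values chosen for illustration): an incompatible dimension (sh > 1 and mul > 1) splits into a reshape pair, so that Tile could be expressed as Reshape + BroadcastTo + Reshape:

    # shape [2, 3] tiled by [2, 1]: dim 0 has sh=2 and mul=2, so not compatible
    out, in_rs, out_rs, compatible = get_tile_output_shape([2, 3], [2, 1])
    assert out == [4, 3]
    assert in_rs == [1, 2, 3] and out_rs == [2, 2, 3]
    assert not compatible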
-
-
-class OpInfer:
-    """Op infer"""
-    @staticmethod
-    def default_reduce_infer(inputs, attrs):
-        """Default reduce infer"""
-        shape = copy.deepcopy(inputs[0].shape)
-        if attrs['keep_dims']:
-            for i in attrs['reduce_axis']:
-                shape[i] = 1
-            return shape
-        real_shape = []
-        for i, _ in enumerate(shape):
-            if i not in attrs['reduce_axis'] and i - len(shape) not in attrs['reduce_axis']:
-                real_shape.append(shape[i])
-        return real_shape
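The double membership test (i and i - len(shape)) is what makes negative reduce axes work here; a quick illustration of the same rule in isolation:

    shape, reduce_axis = [8, 16, 32], (-2,)
    # i=1 is dropped because 1 - 3 == -2; keep_dims=True would give [8, 1, 32]
    out = [s for i, s in enumerate(shape)
           if i not in reduce_axis and i - len(shape) not in reduce_axis]
    assert out == [8, 32]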
-
-    @staticmethod
-    def default_elementwise_infer(inputs, attrs):
-        """Default elementwise infer"""
-        shape = (1,)
-        max_flatten_shape = 1
-        for t in inputs:
-            flatten_shape = 1
-            for s in t.shape:
-                flatten_shape *= s
-            if flatten_shape >= max_flatten_shape:
-                max_flatten_shape = flatten_shape
-                shape = t.shape
-        return shape
-
-    default_infer_shape_func = [
-        None,
-        None,
-        default_elementwise_infer.__func__,
-        lambda inputs, attrs: max([t.shape for t in inputs]),
-        default_reduce_infer.__func__,
-        None,
-        lambda inputs, attrs: [1],  # control op
-    ]
-
-    @staticmethod
-    def default_infer_dtype_func(inputs, attrs):
-        """Infer dtype"""
-        return inputs[0].dtype
-
-    @staticmethod
-    def default_infer_format_func(inputs, attrs):
-        """Infer format"""
-        result = inputs[0].data_format
-        # default_format and other_format results in other_format
-        for input_tensor in inputs[1:]:
-            data_format = input_tensor.data_format
-            if data_format != DataFormat.DEFAULT:
-                if result not in [DataFormat.DEFAULT, data_format]:
-                    raise RuntimeError("Incompatible data format %s and %s" % (data_format, result))
-                result = data_format
-        return result
-
-    infer_shape_func = {
-        # add special infer func here
-        'InplaceAssign': lambda inputs, attrs: inputs[2].shape,
-        'Reshape': lambda inputs, attrs: attrs["shape"],
-        'BroadcastTo': lambda inputs, attrs: attrs["shape"],
-        'Tile': lambda inputs, attrs: get_tile_output_shape(inputs[0].shape, attrs["multiples"])[0],
-        'ExpandDims': lambda inputs, attrs: list(inputs[0].shape).insert(attrs["axis"], 1),
-    }
-    infer_dtype_func = {
-        # add special infer func here
-        'Cast': lambda inputs, attrs: attrs['dst_type'],
-        'Less': lambda inputs, attrs: "bool",
-        'LessEqual': lambda inputs, attrs: "bool",
-        'Equal': lambda inputs, attrs: "bool",
-        'Greater': lambda inputs, attrs: "bool",
-        'GreaterEqual': lambda inputs, attrs: "bool",
-    }
-    infer_format_func = {
-        # add special infer func here
-        'Reshape': lambda inputs, attrs: "DefaultFormat",
-    }
-
-    @classmethod
-    def infer(cls, prim_name, inputs, attrs):
-        prim = PrimLib.primtives[prim_name]
-        infer_shape = cls.infer_shape_func.get(
-            prim_name, cls.default_infer_shape_func[prim.iter_type])
-        infer_dtype = cls.infer_dtype_func.get(
-            prim_name, cls.default_infer_dtype_func)
-        infer_format = cls.infer_format_func.get(
-            prim_name, cls.default_infer_format_func)
-        return infer_shape(inputs, attrs), infer_dtype(inputs, attrs), infer_format(inputs, attrs)
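One wart in the removed table is worth flagging: the 'ExpandDims' entry in infer_shape_func returns None, because list.insert mutates the list in place and returns None. Presumably the new op_infer module corrects this; a fixed sketch (not part of this diff) would be:

    def expand_dims_infer(inputs, attrs):
        # build the new shape explicitly instead of relying on insert's return value
        shape = list(inputs[0].shape)
        shape.insert(attrs["axis"], 1)
        return shape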
+from . import op_infer
+from .model import Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy
 class GraphBuilder:
@@ -229,7 +98,7 @@ class GraphBuilder:
         if isinstance(inputs, (Tensor, Value)):
             inputs = [inputs]
         tensor_inputs = [t for t in inputs if isinstance(t, (Tensor, Value))]
-        out_shape, out_dtype, out_format = OpInfer.infer(prim, tensor_inputs, attrs)
+        out_shape, out_dtype, out_format = op_infer.infer(prim, tensor_inputs, attrs)
         output = self.tensor(out_shape, out_dtype, out_format, name)
         self.op(prim, output, inputs, attrs)
         return output
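After this change, emit resolves the output shape, dtype and format through the new op_infer module instead of the class-level tables deleted above; expander call sites are unchanged, e.g. (taken from the BiasAddGrad hunk above):

    result = graph_builder.emit('ReduceSum', [x],
                                attrs={'reduce_axis': reduce_axis, 'keep_dims': False})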

File diff suppressed because it is too large.