diff --git a/mindspore/_extends/graph_kernel/expanders/bias_add_grad.py b/mindspore/_extends/graph_kernel/expanders/bias_add_grad.py
index 9d9163efc9..161f33c07c 100644
--- a/mindspore/_extends/graph_kernel/expanders/bias_add_grad.py
+++ b/mindspore/_extends/graph_kernel/expanders/bias_add_grad.py
@@ -20,24 +20,30 @@ from ._utils import Expander, ExpanderInfoValidator as VLD
 @VLD.add_format(DF.DEFAULT)
 @VLD.add_format(DF.NHWC)
 @VLD.add_format(DF.NCHW)
+@VLD.add_format(DF.FRAC_NZ)
 class BiasAddGrad(Expander):
     """BiasAddGrad expander"""
 
     def _expand(self, graph_builder):
-        input_x = self.inputs[0]
+        x = self.inputs[0]
         reduce_axis = ()
-        if input_x.data_format == 'NHWC':
+        if x.data_format == DF.NHWC:
             reduce_axis = (0, 1, 2)
-        elif input_x.data_format == 'NCHW':
+        elif x.data_format == DF.NCHW:
             reduce_axis = (0, 2, 3)
-        # DefaultFormat shape's length should be from 2 to 4
+        elif x.data_format == DF.FRAC_NZ:
+            reduce_axis = (-2, -3)
         else:
-            if len(input_x.shape) == 2:
+            # DefaultFormat shape's length should be from 2 to 4
+            if len(x.shape) == 2:
                 reduce_axis = (0,)
-            elif len(input_x.shape) == 3:
+            elif len(x.shape) == 3:
                 reduce_axis = (0, 1)
             else:
                 reduce_axis = (0, 2, 3)
-        result = graph_builder.emit('ReduceSum', [input_x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
+        result = graph_builder.emit('ReduceSum', [x], attrs={'reduce_axis': reduce_axis, 'keep_dims': False})
+        if x.data_format == DF.FRAC_NZ:
+            out_shape = x.shape[:-4] + [x.shape[-1] * x.shape[-4]]
+            result = graph_builder.emit('Reshape', [result], attrs={'shape': out_shape})
         return result
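For readers unfamiliar with the fractal layout: a FRAC_NZ shape ends in the four axes (N1, M1, M0, N0), so reducing the two M axes and then folding N1 with N0 recovers the flat bias-gradient shape. A minimal standalone sketch of that arithmetic (the sample shape is hypothetical):

```python
# Sketch of the FRAC_NZ shape arithmetic in BiasAddGrad._expand.
# A FRAC_NZ shape ends in (N1, M1, M0, N0); the sample values are made up.
x_shape = [32, 4, 16, 16, 16]   # batch, N1, M1, M0, N0

# ReduceSum over (-2, -3) with keep_dims=False drops the M1 and M0 axes.
reduced = [s for i, s in enumerate(x_shape) if i - len(x_shape) not in (-2, -3)]
assert reduced == [32, 4, 16]   # batch, N1, N0

# The trailing Reshape folds N1 and N0 into a single bias axis.
out_shape = x_shape[:-4] + [x_shape[-1] * x_shape[-4]]
assert out_shape == [32, 64]
```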
diff --git a/mindspore/_extends/graph_kernel/expanders/tile.py b/mindspore/_extends/graph_kernel/expanders/tile.py
index 45e881f550..23c2584105 100644
--- a/mindspore/_extends/graph_kernel/expanders/tile.py
+++ b/mindspore/_extends/graph_kernel/expanders/tile.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 # ===========================================================================
 """generate json desc for Tile"""
-from mindspore._extends.graph_kernel.model import model_builder as builder
+from mindspore._extends.graph_kernel.model.op_infer import Tile as TileInfer
 from mindspore._extends.graph_kernel.model.model import DataFormat as DF
 from ._utils import Expander, ExpanderInfoValidator as VLD
@@ -27,8 +27,9 @@ class Tile(Expander):
         input_x = self.inputs[0]
         multiples = self.attrs['multiples']
-        output_shape, _, _, shape_compatible = builder.get_tile_output_shape(self.inputs[0].shape, multiples)
-        if shape_compatible:
+        tile_infer = TileInfer(self.name, self.inputs, self.attrs)
+        output_shape, _, _ = tile_infer.infer()
+        if tile_infer.broadcast_compatible:
             result = graph_builder.emit('BroadcastTo', [input_x], attrs={'shape': output_shape})
         else:
             result = graph_builder.emit('Tile', [input_x], attrs={'multiples': multiples})
         return result
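A hedged sketch of how the expander drives the new TileInfer. The namedtuple stub below is a hypothetical stand-in for the real graph-kernel Tensor and only carries the fields the infer reads; infer() returns the (shape, dtype, format) triple, and broadcast_compatible is set as a side effect of shape inference:

```python
from collections import namedtuple
from mindspore._extends.graph_kernel.model.op_infer import Tile as TileInfer

# Hypothetical stand-in for the model Tensor; only the fields read
# during inference are provided.
FakeTensor = namedtuple('FakeTensor', ['shape', 'dtype', 'data_format'])

x = FakeTensor(shape=[4, 1], dtype='float32', data_format='DefaultFormat')
tile_infer = TileInfer('Tile', [x], {'multiples': [1, 3]})
out_shape, _, _ = tile_infer.infer()
assert out_shape == [4, 3]
assert tile_infer.broadcast_compatible      # every dim has sh == 1 or mul == 1

# A dim where both sh > 1 and mul > 1 cannot be expressed as a broadcast.
y = FakeTensor(shape=[2, 3], dtype='float32', data_format='DefaultFormat')
tile_infer = TileInfer('Tile', [y], {'multiples': [2, 2]})
tile_infer.infer()
assert not tile_infer.broadcast_compatible
```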
diff --git a/mindspore/_extends/graph_kernel/model/model_builder.py b/mindspore/_extends/graph_kernel/model/model_builder.py
index cdcd4dc92e..4125b75e1b 100644
--- a/mindspore/_extends/graph_kernel/model/model_builder.py
+++ b/mindspore/_extends/graph_kernel/model/model_builder.py
@@ -15,139 +15,8 @@
 """GraphKernel model builder"""
 import copy
-from .model import PrimLib, Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy, DataFormat
-
-
-def get_tile_output_shape(shape, multiples):
-    """compute output shape of tile"""
-
-    if multiples is None:
-        return shape
-    if not isinstance(shape, (list, tuple)):
-        raise TypeError("Input shape of Tile must be of type list or tuple")
-    if not isinstance(multiples, (list, tuple)):
-        raise TypeError("multiples of Tile must be of type list or tuple")
-
-    shape = list(shape)
-    multiples = list(multiples)
-    diff_len = len(multiples) - len(shape)
-    if diff_len < 0:
-        raise ValueError("Dimensions of multiples{} < dimensions of input{} in Tile".format(multiples, shape))
-    if diff_len > 0:
-        for _ in range(diff_len):
-            shape.insert(0, 1)
-
-    shape_compatible = True
-    output_shape = []
-    input_reshape = []
-    output_reshape = []
-    for sh, mul in list(zip(shape, multiples)):
-        dim = sh * mul
-        output_shape.append(dim)
-        if sh == 1 or mul == 1:
-            input_reshape.append(sh)
-            output_reshape.append(dim)
-        else:
-            shape_compatible = False
-            input_reshape.append(1)
-            input_reshape.append(sh)
-            output_reshape.append(mul)
-            output_reshape.append(sh)
-
-    return output_shape, input_reshape, output_reshape, shape_compatible
-
-
-class OpInfer:
-    """Op infer"""
-    @staticmethod
-    def default_reduce_infer(inputs, attrs):
-        """Default reduce infer"""
-        shape = copy.deepcopy(inputs[0].shape)
-        if attrs['keep_dims']:
-            for i in attrs['reduce_axis']:
-                shape[i] = 1
-            return shape
-
-        real_shape = []
-        for i, _ in enumerate(shape):
-            if i not in attrs['reduce_axis'] and i - len(shape) not in attrs['reduce_axis']:
-                real_shape.append(shape[i])
-        return real_shape
-
-    @staticmethod
-    def default_elementwise_infer(inputs, attrs):
-        """Default elementwise infer"""
-        shape = (1,)
-        max_flatten_shape = 1
-        for t in inputs:
-            flatten_shape = 1
-            for s in t.shape:
-                flatten_shape *= s
-            if flatten_shape >= max_flatten_shape:
-                max_flatten_shape = flatten_shape
-                shape = t.shape
-        return shape
-
-    default_infer_shape_func = [
-        None,
-        None,
-        default_elementwise_infer.__func__,
-        lambda inputs, attrs: max([t.shape for t in inputs]),
-        default_reduce_infer.__func__,
-        None,
-        lambda inputs, attrs: [1],  # control op
-    ]
-
-    @staticmethod
-    def default_infer_dtype_func(inputs, attrs):
-        """Infer dtype"""
-        return inputs[0].dtype
-
-    @staticmethod
-    def default_infer_format_func(inputs, attrs):
-        """Infer format"""
-        result = inputs[0].data_format
-        # default_format and other_format results in other_format
-        for input_tensor in inputs[1:]:
-            data_format = input_tensor.data_format
-            if data_format != DataFormat.DEFAULT:
-                if result not in [DataFormat.DEFAULT, data_format]:
-                    raise RuntimeError("Incompatible data format %s and %s" % (data_format, result))
-                result = data_format
-        return result
-
-    infer_shape_func = {
-        # add special infer func here
-        'InplaceAssign': lambda inputs, attrs: inputs[2].shape,
-        'Reshape': lambda inputs, attrs: attrs["shape"],
-        'BroadcastTo': lambda inputs, attrs: attrs["shape"],
-        'Tile': lambda inputs, attrs: get_tile_output_shape(inputs[0].shape, attrs["multiples"])[0],
-        'ExpandDims': lambda inputs, attrs: list(inputs[0].shape).insert(attrs["axis"], 1),
-    }
-    infer_dtype_func = {
-        # add special infer func here
-        'Cast': lambda inputs, attrs: attrs['dst_type'],
-        'Less': lambda inputs, attrs: "bool",
-        'LessEqual': lambda inputs, attrs: "bool",
-        'Equal': lambda inputs, attrs: "bool",
-        'Greater': lambda inputs, attrs: "bool",
-        'GreaterEqual': lambda inputs, attrs: "bool",
-    }
-    infer_format_func = {
-        # add special infer func here
-        'Reshape': lambda inputs, attrs: "DefaultFormat",
-    }
-
-    @classmethod
-    def infer(cls, prim_name, inputs, attrs):
-        prim = PrimLib.primtives[prim_name]
-        infer_shape = cls.infer_shape_func.get(
-            prim_name, cls.default_infer_shape_func[prim.iter_type])
-        infer_dtype = cls.infer_dtype_func.get(
-            prim_name, cls.default_infer_dtype_func)
-        infer_format = cls.infer_format_func.get(
-            prim_name, cls.default_infer_format_func)
-        return infer_shape(inputs, attrs), infer_dtype(inputs, attrs), infer_format(inputs, attrs)
+from . import op_infer
+from .model import Tensor, Value, Operator, Graph, AlignShape, AddControlBuddy
 
 
 class GraphBuilder:
@@ -229,7 +98,7 @@ class GraphBuilder:
         if isinstance(inputs, (Tensor, Value)):
             inputs = [inputs]
         tensor_inputs = [t for t in inputs if isinstance(t, (Tensor, Value))]
-        out_shape, out_dtype, out_format = OpInfer.infer(prim, tensor_inputs, attrs)
+        out_shape, out_dtype, out_format = op_infer.infer(prim, tensor_inputs, attrs)
         output = self.tensor(out_shape, out_dtype, out_format, name)
         self.op(prim, output, inputs, attrs)
         return output
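GraphBuilder.emit is otherwise unchanged; the new module resolves an infer class by primitive name and falls back to the generic _Elemwise/_Reduce infers via the primitive's iter_type. A toy sketch of that name-based dispatch pattern (class and primitive names here are illustrative, not the real PrimLib tables):

```python
import sys

class Add:
    """Specialised infer, chosen whenever a class matches the op name."""
    def infer(self):
        return "Add-specific infer"

class _Elemwise:
    """Generic fallback used when no class matches the op name."""
    def infer(self):
        return "generic elementwise infer"

def dispatch(op_name):
    # Look the class up in this module by name, like op_infer.infer does.
    module = sys.modules[__name__]
    op_cls = getattr(module, op_name, None) or _Elemwise
    return op_cls().infer()

assert dispatch('Add') == "Add-specific infer"
assert dispatch('Mul') == "generic elementwise infer"
```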
diff --git a/mindspore/_extends/graph_kernel/model/op_infer.py b/mindspore/_extends/graph_kernel/model/op_infer.py
new file mode 100644
index 0000000000..8ae890364e
--- /dev/null
+++ b/mindspore/_extends/graph_kernel/model/op_infer.py
@@ -0,0 +1,275 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ===========================================================================
+"""GraphKernel Op Infer"""
+
+
+import copy
+import sys
+from functools import reduce
+from .model import GraphKernelUnsupportedException as GKException
+from .model import PrimLib, DataFormat as DF
+
+
+def infer(op_name, inputs, attrs):
+    """infer shape, dtype and format"""
+    def _create_opinfer():
+        if hasattr(sys.modules[__name__], op_name):
+            op_cls = getattr(sys.modules[__name__], op_name)
+            return op_cls(op_name, inputs, attrs)
+        # common infer
+        class_name_map = {
+            PrimLib.ELEMWISE: "_Elemwise",
+            PrimLib.REDUCE: "_Reduce",
+        }
+        cls_name = class_name_map.get(PrimLib.primtives.get(op_name, PrimLib.default_primtive).iter_type, None)
+        if not cls_name:
+            raise GKException("OpInfer does not support op {}".format(op_name))
+        op_cls = getattr(sys.modules[__name__], cls_name)
+        return op_cls(op_name, inputs, attrs)
+    return _create_opinfer().infer()
+
+
+class OpInfer:
+    """
+    OpInfer is the base class for inferring operator info in the GraphKernel model builder.
+
+    Three methods should be overridden to define the infer logic of an operator:
+    _infer_shape(), _infer_type() and _infer_format().
+    """
+
+    def __init__(self, name, inputs, attrs):
+        self.name = name
+        self.inputs = inputs
+        self.attrs = attrs
+
+    def infer(self):
+        """Infer shape, type and format by op inputs"""
+        self._check()
+        return self._infer_shape(), self._infer_type(), self._infer_format()
+
+    def _infer_shape(self):
+        return self.inputs[0].shape
+
+    def _infer_type(self):
+        return self.inputs[0].dtype
+
+    def _infer_format(self):
+        return self.inputs[0].data_format
+
+    def _check(self):
+        self._check_shape()
+        self._check_type()
+        self._check_format()
+
+    def _check_shape(self):
+        pass
+
+    def _check_type(self):
+        """check that all dtypes are the same"""
+        dtype = self.inputs[0].dtype
+        for i, t in enumerate(self.inputs[1:]):
+            if t.dtype != dtype:
+                raise GKException(
+                    "Incompatible dtype between input {}({}) and {}({})".format(0, dtype, i + 1, t.dtype))
+
+    def _check_format(self):
+        """check that formats are compatible; only DefaultFormat is compatible with others"""
+        result = self.inputs[0].data_format
+        i = 0
+        for j, t in enumerate(self.inputs[1:]):
+            if t.data_format != result:
+                if DF.DEFAULT not in (result, t.data_format):
+                    raise GKException("Incompatible format between input {}({}) and {}({})".format(
+                        i, result, j + 1, t.data_format))
+                if result == DF.DEFAULT:
+                    result = t.data_format
+                    i = j + 1
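The base-class checks encode a single compatibility rule: DefaultFormat mixes with any format, while two different non-default formats conflict. A small standalone sketch of the same rule (the format strings are assumed values, used here in place of the DF constants):

```python
DEFAULT = 'DefaultFormat'   # assumed value behind DF.DEFAULT

def resolve_format(formats):
    """Fold a list of input formats into one, mirroring OpInfer._check_format."""
    result = formats[0]
    for fmt in formats[1:]:
        if fmt != result:
            if DEFAULT not in (result, fmt):
                raise ValueError("Incompatible formats: {} vs {}".format(result, fmt))
            if result == DEFAULT:
                result = fmt
    return result

assert resolve_format([DEFAULT, 'NCHW']) == 'NCHW'
assert resolve_format(['NHWC', DEFAULT, 'NHWC']) == 'NHWC'
# resolve_format(['NCHW', 'NHWC'])   # raises: two distinct non-default formats
```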
+
+
+class _Elemwise(OpInfer):
+    """Common infer for elementwise operators"""
+
+    def _infer_shape(self):
+        """returns the input shape with the largest flattened size"""
+        shape = (1,)
+        max_flatten_size = 1
+        for t in self.inputs:
+            flatten_size = reduce(lambda x, y: x * y, t.shape)
+            if flatten_size >= max_flatten_size:
+                max_flatten_size = flatten_size
+                shape = t.shape
+        return shape
+
+    def _infer_format(self):
+        for tensor in self.inputs:
+            if tensor.data_format != DF.DEFAULT:
+                return tensor.data_format
+        return DF.DEFAULT
+
+
+class _Reduce(OpInfer):
+    """Common infer for reduction operators"""
+
+    def _check(self):
+        super()._check()
+        # check that the reduce axes are in the range [-len, len)
+        shape_len = len(self.inputs[0].shape)
+        axis = self.attrs['reduce_axis']
+        if isinstance(axis, int):
+            axis = [axis]
+        if not all([(-shape_len <= i < shape_len) for i in axis]):
+            raise GKException(
+                "reduce_axis should be in range [{},{}) but got {}".format(-shape_len, shape_len, axis))
+
+    def _infer_shape(self):
+        shape = copy.deepcopy(self.inputs[0].shape)
+        axis = self.attrs['reduce_axis']
+
+        if isinstance(axis, int):
+            axis = [axis]
+        if any([i < 0 for i in axis]):
+            # change the axes to non-negative numbers.
+            axis = list(map(lambda i: i + len(shape) if i < 0 else i, axis))
+        self.attrs['reduce_axis'] = sorted(axis)
+
+        if self.attrs['keep_dims']:
+            for i in axis:
+                shape[i] = 1
+            return shape
+
+        real_shape = []
+        for i, s in enumerate(shape):
+            if i not in axis:
+                real_shape.append(s)
+        return real_shape
+
+    def _infer_format(self):
+        return DF.DEFAULT
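Note that _Reduce._infer_shape normalizes negative axes in place (writing the sorted, non-negative list back into attrs) before applying keep_dims. A worked example of the two behaviours (shape and axes are illustrative):

```python
shape = [8, 16, 32]
axis = [-1, 0]                  # mixed negative and positive axes
axis = sorted(i + len(shape) if i < 0 else i for i in axis)
assert axis == [0, 2]           # what gets written back to attrs['reduce_axis']

# keep_dims=True: reduced axes collapse to 1.
assert [1 if i in axis else s for i, s in enumerate(shape)] == [1, 16, 1]

# keep_dims=False: reduced axes are dropped entirely.
assert [s for i, s in enumerate(shape) if i not in axis] == [16]
```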
+
+
+class _Reshape(OpInfer):
+    """Common infer for reshape operators; should not be instantiated directly"""
+
+    def _infer_shape(self):
+        raise GKException("_infer_shape should be implemented by subclass")
+
+    def _infer_format(self):
+        return DF.DEFAULT
+
+
+class Reshape(_Reshape):
+    def _infer_shape(self):
+        return self.attrs["shape"]
+
+
+class ExpandDims(_Reshape):
+    def _infer_shape(self):
+        # list.insert returns None, so build the list first and return it.
+        shape = list(self.inputs[0].shape)
+        shape.insert(self.attrs["axis"], 1)
+        return shape
+
+
+class Cast(_Elemwise):
+    def _infer_type(self):
+        return self.attrs["dst_type"]
+
+
+class InplaceAssign(_Elemwise):
+    def _infer_shape(self):
+        return [1] if self.attrs["fake_output"] else self.inputs[2].shape
+
+    def _infer_type(self):
+        return self.inputs[2].dtype
+
+    def _infer_format(self):
+        return DF.DEFAULT if self.attrs["fake_output"] else self.inputs[2].data_format
+
+
+class BroadcastTo(OpInfer):
+    def _infer_shape(self):
+        return self.attrs["shape"]
+
+    def _infer_format(self):
+        return self.inputs[0].data_format
+
+
+class Tile(OpInfer):
+    """Op Tile"""
+
+    def __init__(self, op_name, inputs, attrs):
+        super().__init__(op_name, inputs, attrs)
+        self.input_reshape = None
+        self.output_reshape = None
+        self.broadcast_compatible = True
+
+    def _infer_shape(self):
+        shape = list(self.inputs[0].shape)
+        multiples = list(self.attrs["multiples"])
+        diff_len = len(multiples) - len(shape)
+        if diff_len < 0:
+            raise ValueError("Dimensions of multiples {} < dimensions of input {} in Tile".format(multiples, shape))
+        if diff_len > 0:
+            for _ in range(diff_len):
+                shape.insert(0, 1)
+
+        self.broadcast_compatible = True
+        output_shape = []
+        self.input_reshape = []
+        self.output_reshape = []
+        for sh, mul in zip(shape, multiples):
+            dim = sh * mul
+            output_shape.append(dim)
+            if sh == 1 or mul == 1:
+                self.input_reshape.append(sh)
+                self.output_reshape.append(dim)
+            else:
+                self.broadcast_compatible = False
+                self.input_reshape.append(1)
+                self.input_reshape.append(sh)
+                self.output_reshape.append(mul)
+                self.output_reshape.append(sh)
+
+        return output_shape
+
+    def _infer_format(self):
+        return DF.DEFAULT
+
+
+class _CompareOp(_Elemwise):
+    """Compare operators"""
+
+    def _infer_type(self):
+        return "bool"
+
+
+class Less(_CompareOp):
+    pass
+
+
+class LessEqual(_CompareOp):
+    pass
+
+
+class Equal(_CompareOp):
+    pass
+
+
+class Greater(_CompareOp):
+    pass
+
+
+class GreaterEqual(_CompareOp):
+    pass
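For a dimension where both the input extent and the multiple exceed 1, Tile is not a plain broadcast; the input_reshape/output_reshape pair recorded above describes the equivalent reshape-broadcast-reshape factorization. A numpy sketch (with illustrative values) of why that factorization reproduces Tile:

```python
import numpy as np

x = np.arange(6).reshape(2, 3)        # input shape [2, 3]
multiples = (2, 2)                    # both sh > 1 and mul > 1 in each dim

# The decomposition TileInfer records for this case:
input_reshape = (1, 2, 1, 3)          # each incompatible dim becomes [1, sh]
output_reshape = (2, 2, 2, 3)         # ... and broadcasts to [mul, sh]
output_shape = (4, 6)                 # final shape is sh * mul per dim

via_broadcast = np.broadcast_to(
    x.reshape(input_reshape), output_reshape).reshape(output_shape)
assert (via_broadcast == np.tile(x, multiples)).all()
```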