You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
428 lines
15 KiB
428 lines
15 KiB
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ============================================================================
|
|
|
|
"""primitive"""
|
|
|
|
import inspect
|
|
import copy
|
|
from mindspore.common.api import _wrap_func
|
|
from mindspore.common import Parameter
|
|
from mindspore.common._register_for_tensor import tensor_operator_registry
|
|
from mindspore import context
|
|
from .._c_expression import Primitive_, real_run_op, prim_type
|
|
from .._c_expression import signature_rw as sig_rw
|
|
from .._c_expression import signature_kind as sig_kind
|
|
from .._c_expression import signature_dtype as sig_dtype
|
|
|
|
|
|
class Primitive(Primitive_):
    """
    Primitive is the base class of primitives in python.

    Args:
        name (str): Name for the current Primitive.

    Examples:
        >>> add = Primitive('add')
        >>>
        >>> # or work with prim_attr_register:
        >>> # init a Primitive class with attr1 and attr2
        >>> class Add(Primitive):
        ...     @prim_attr_register
        ...     def __init__(self, attr1, attr2):
        ...         # check attr1 and attr2 or do some initializations
        ...         pass
        >>> # init a Primitive obj with attr1=1 and attr2=2
        >>> add = Add(attr1=1, attr2=2)
    """
    # Attribute names that are left out of the __repr__ output.
    _repr_ignore_list = ['input_names', 'output_names']

    def __init__(self, name):
        self.name = name
        # Python-side record of every attribute registered through add_prim_attr.
        self.attrs = {}
        # Keyword arguments used by __deepcopy__ to rebuild an equivalent object.
        self.init_attrs = {"name": name}
        self._update_parameter = False
        Primitive_.__init__(self, name, self)
        # A class may declare __mindspore_signature__; normalize it before handing it to set_signatures.
        if hasattr(self.__class__, '__mindspore_signature__'):
            sig = self._fill_signature(self.__class__.__mindspore_signature__)
            self.set_signatures(sig)

    def _fill_signature(self, signatures):
        """
        Normalize signature entries so each one has at least 5 elements.

        A bare ``sig_dtype`` entry is expanded to
        ``("argument", RW_READ, KIND_POSITIONAL_KEYWORD, KIND_EMPTY_DEFAULT_VALUE, <dtype>)``.
        A 3-element entry is padded with an empty default value and an empty dtype marker.
        A 4-element entry is padded with an empty dtype marker. Longer entries pass through unchanged.

        Args:
            signatures (iterable): Raw signature entries declared on the class.

        Returns:
            tuple, the normalized signature entries.

        Raises:
            ValueError: If a non-dtype entry has fewer than 3 elements.
        """
        signatures_new = []
        for signature in signatures:
            if isinstance(signature, sig_dtype):
                signatures_new.append(("argument", sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD,
                                       sig_kind.KIND_EMPTY_DEFAULT_VALUE, signature))
            else:
                if len(signature) < 3:
                    # Length 3 is accepted below, so the requirement is ">= 3".
                    raise ValueError(f"[Internal Error]Signature for one parameter len must >= 3, but {signature}")
                if len(signature) == 3:
                    signature += (sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T_EMPTY_DEFAULT_VALUE)
                if len(signature) == 4:
                    signature += (sig_dtype.T_EMPTY_DEFAULT_VALUE,)
                signatures_new.append(signature)
        return tuple(signatures_new)

    def _clone(self):
        """
        Deeply clones the primitive object.

        Calls the __init__() method with the same arguments. This method is called in parser if the
        flag self.__setattr_flag__ is True.

        The ``__init__`` must have been wrapped by ``prim_attr_register``, because its
        ``decorated_func`` is inspected to find the init argument names. All current
        ``attrs`` and the instance name, if set, are then copied onto the clone.
        """
        cloned = copy.deepcopy(self)
        init_params = inspect.getfullargspec(cloned.__init__.decorated_func).args[1:]
        init_args = {}
        for name in init_params:
            value = self.attrs[name]
            init_args[name] = value
        # __init__ should be called to construct cpp object.
        cloned.__init__(**init_args)
        for name in self.attrs:
            value = self.attrs[name]
            cloned.add_prim_attr(name, value)
        if hasattr(self, 'instance_name'):
            cloned.set_prim_instance_name(self.instance_name)
        return cloned

    def add_prim_attr(self, name, value):
        """
        Adds primitive attribute.

        The value is stored as a plain instance attribute, recorded in ``self.attrs``,
        and passed to ``self.add_attr`` (inherited from ``Primitive_``).

        Args:
            name (str): Attribute Name.
            value (Any): Attribute value.

        Returns:
            Primitive, self, so calls can be chained.
        """
        self.__dict__[name] = value
        self.attrs[name] = value
        self.add_attr(name, value)
        return self

    def set_strategy(self, strategy):
        """
        Add strategies to primitive attribute.

        Note:
            It is valid only in semi auto parallel or auto parallel mode.
            In other parallel modes, strategies set here will be ignored.

        Args:
            strategy (tuple): Strategy describes the distributed parallel mode of the current primitive.

        Returns:
            Primitive, self.
        """
        self.add_prim_attr("strategy", strategy)
        return self

    def set_prim_instance_name(self, instance_name):
        """
        Set instance name to primitive operator.

        Note:
            It will be called by default when user defines primitive operator.

        Args:
            instance_name (str): Instance name of primitive operator set by user.

        Returns:
            Primitive, self.
        """
        self.set_instance_name(instance_name)
        self.instance_name = instance_name
        return self

    def __getattr__(self, item):
        """
        Fallback attribute lookup, reached only when normal lookup fails.

        ``infer_dynamic_shape`` resolves to None unless a subclass defines it. Otherwise
        the attribute dict returned by ``Primitive_.get_attr_dict`` is checked first,
        then ``self.attrs``.

        Raises:
            AttributeError: If `item` is found in neither place.
        """
        if item == 'infer_dynamic_shape':
            return None
        attr_dict = super().get_attr_dict()
        if item in attr_dict:
            return attr_dict[item]
        if item in self.attrs:
            return self.attrs[item]
        raise AttributeError(item)

    def check_elim(self, *args):
        """
        Check if certain inputs should go to the backend. Subclass in need should override this method.

        Args:
            *args(Primitive args): Same as arguments of current Primitive.

        Returns:
            A tuple consisting of two elements. The first element indicates whether we should filter out current
            arguments; the second element is the output if we need to filter out the arguments.
        """
        return (False, None)

    def __call__(self, *args):
        # check_elim may short-circuit execution and supply the output directly.
        should_elim, output = self.check_elim(*args)
        if should_elim:
            return output
        return _run_op(self, self.name, args)

    def __getstate__(self):
        # Returns the live instance dict (not a copy).
        return self.__dict__

    def __setstate__(self, d):
        self.__dict__.update(d)

    def __deepcopy__(self, memo):
        # Rebuilds from init_attrs only; attributes added after construction are not carried over
        # (_clone re-applies them explicitly). `memo` is not used.
        return type(self)(**self.init_attrs)

    def __repr__(self):
        attr = ', '.join([f'{k}={self.attrs[k]}' for k in self.attrs if k not in Primitive._repr_ignore_list])
        info_str = f'Prim[{self.name}]'
        if attr:
            info_str += f'<{attr}>'
        return info_str

    def init_prim_io_names(self, inputs, outputs):
        """
        Initializes the name of inputs and outputs of Tensor or attributes.

        Args:
            inputs (list[str]): list of inputs names.
            outputs (list[str]): list of outputs names.
        """
        # for checking para names with kernel implementation
        self.add_prim_attr("input_names", inputs)
        # for checking output number with kernel implementation
        self.add_prim_attr("output_names", outputs)

    @property
    def update_parameter(self):
        """ Whether the primitive will update the value of parameter."""
        return self._update_parameter
|
|
|
|
|
|
class PrimitiveWithInfer(Primitive):
    """
    PrimitiveWithInfer is the base class of primitives in python that define functions for tracking inference
    in python.

    There are four methods that can be overridden to define the infer logic of the primitive: __infer__(),
    infer_shape(), infer_dtype(), and infer_value(). If __infer__() is defined in the primitive, it has the
    highest priority to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined
    to describe the infer logic of the shape and type. The infer_value() is used for constant propagation.

    Args:
        name (str): Name of the current Primitive.

    Examples:
        >>> # init a Primitive class with infer
        >>> class Add(PrimitiveWithInfer):
        ...     @prim_attr_register
        ...     def __init__(self):
        ...         pass
        ...
        ...     def infer_shape(self, x, y):
        ...         return x # output shape same as first input 'x'
        ...
        ...     def infer_dtype(self, x, y):
        ...         return x # output type same as first input 'x'
        >>> # init a Primitive obj
        >>> add = Add()
    """

    def __init__(self, name):
        Primitive.__init__(self, name)
        self.set_prim_type(prim_type.py_infer_shape)

    def _clone(self):
        """
        Deeply clones the primitive object.

        Calls the __init__() method with the same arguments. This method is called in parser if the
        flag self.__setattr_flag__ is True.
        """
        cloned_prim = Primitive._clone(self)
        return cloned_prim

    def infer_shape(self, *args):
        """
        Infer output shape based on input shape.

        Note:
            The shape of scalar is an empty tuple.

        Args:
            args (tuple(int)): shapes of input tensors.

        Return:
            `tuple(int)`, shapes of output tensors.
        """
        return None

    def infer_dtype(self, *args):
        """
        Infer output dtype based on input dtype.

        Args:
            args (:class:`mindspore.dtype`): data type of inputs.

        Return:
            :class:`mindspore.dtype`, data type of outputs.
        """
        return None

    def infer_value(self, *args):
        """
        Infer output value based on input value at compile time.

        Args:
            args (Any): value of inputs.

        Return:
            Value of outputs. Return `None`, the value can not be inferred at compile time in this case.
        """
        return None

    def __infer__(self, *args):
        """
        Infer shape, type, and value at the same time by using dictionary as arguments.

        Each element of `args` is a dict with 'shape', 'dtype' and 'value' keys, and it may also
        have 'min_shape' / 'max_shape'. In graph mode, when the primitive defines
        ``infer_dynamic_shape``, that method supplies the base result and only dtype and value
        are inferred on top of it. Otherwise dtype, shape and value are inferred in that order. In
        graph mode, min/max shapes are also inferred when the inputs carry them.

        Returns:
            dict, the inferred 'dtype', 'shape', 'value' and optionally 'min_shape' / 'max_shape'.

        Raises:
            ValueError: If the inputs carry only one of 'min_shape' / 'max_shape'.
        """
        is_graph_mode = context.get_context("mode") == context.GRAPH_MODE

        def run_tracks(tracks, out):
            """Fill out[track] with self.infer_<track> applied to every input's <track> entry."""
            for track in tracks:
                fn = getattr(self, 'infer_' + track)
                # fn may return None
                out[track] = fn(*(x[track] for x in args))
            return out

        fn_infer_dynamic_shape = getattr(self, 'infer_dynamic_shape', None)
        if is_graph_mode and fn_infer_dynamic_shape is not None:
            return run_tracks(('dtype', 'value'), fn_infer_dynamic_shape(*args))

        out = run_tracks(('dtype', 'shape', 'value'), {})

        # in non-graph_mode, it is not necessary to infer min/max shape
        if not is_graph_mode:
            return out

        def get_specified_shape(elems, attr):
            """Collect `attr` from each elem (falling back to 'shape'); report whether any had `attr`."""
            has_specified_shape = False
            ret_vals = []
            for elem in elems:
                if attr in elem:
                    has_specified_shape = True
                    ret_vals.append(elem[attr])
                else:
                    ret_vals.append(elem['shape'])
            return has_specified_shape, tuple(ret_vals)

        has_min_shape, min_shapes = get_specified_shape(args, 'min_shape')
        has_max_shape, max_shapes = get_specified_shape(args, 'max_shape')
        if not (has_min_shape or has_max_shape):
            return out
        if has_min_shape and has_max_shape:
            out['min_shape'] = self.infer_shape(*min_shapes)
            out['max_shape'] = self.infer_shape(*max_shapes)
            return out
        # The f-prefix is required so the actual args are interpolated into the message.
        raise ValueError(f'Input args has invalid dynamic shape, args info: {args}')
|
|
|
|
|
|
def prim_attr_register(fn):
    """
    Primitive attributes register.

    Decorator for the ``__init__`` of a primitive class. The wrapper does four things in order:

    1. Runs the base initializer with the class name as the primitive name. It uses
       ``PrimitiveWithInfer.__init__`` for PrimitiveWithInfer instances, otherwise ``Primitive.__init__``.
    2. Removes the ``'name'`` entry from ``init_attrs``.
    3. Registers every argument of the decorated ``__init__`` (defaults applied, ``self`` excluded)
       through ``add_prim_attr`` and records it in ``init_attrs``.
    4. Calls the original ``__init__``.

    Args:
        fn (function): __init__ function of primitive.

    Returns:
        function, a wrapper around `fn`. The undecorated function is kept on its ``decorated_func`` attribute.
    """
    def deco(self, *args, **kwargs):
        base_cls = PrimitiveWithInfer if isinstance(self, PrimitiveWithInfer) else Primitive
        base_cls.__init__(self, self.__class__.__name__)

        bound = inspect.signature(fn).bind(self, *args, **kwargs)
        bound.apply_defaults()
        init_arguments = bound.arguments
        del init_arguments['self']
        del self.init_attrs['name']

        for attr_name, attr_value in init_arguments.items():
            self.add_prim_attr(attr_name, attr_value)
            self.init_attrs[attr_name] = attr_value

        fn(self, *args, **kwargs)

    # Exposed so callers such as Primitive._clone can inspect the real __init__ signature.
    deco.decorated_func = fn
    return deco
|
|
|
|
|
|
def constexpr(fn=None, get_instance=True, name=None):
    """
    Makes a PrimitiveWithInfer operator that can infer the value at compile time. We can define a function
    to compute between constant variables and use it in `construct`.

    Args:
        fn (function): A `fn` used as the infer_value of the output operator. When None, a decorator is
            returned, so `constexpr(...)` can be applied with keyword arguments.
        get_instance (bool): If true, return the instance of operator, otherwise return the operator class.
        name (str): Defines the operator name. If `name` is None (or empty), use the function name as op name.

    Examples:
        >>> a = (1, 2)
        >>> # make an operator to calculate tuple len
        >>> @constexpr
        ... def tuple_len(x):
        ...     return len(x)
        >>> assert tuple_len(a) == 2
        >>>
        >>> # make a operator class to calculate tuple len
        >>> @constexpr(get_instance=False, name="TupleLen")
        ... def tuple_len_class(x):
        ...     return len(x)
        >>> assert tuple_len_class()(a) == 2
    """
    def deco(target):
        class CompileOp(PrimitiveWithInfer):
            def __init__(self):
                PrimitiveWithInfer.__init__(self, name or target.__name__)
                # Marks the op as producing a constant value.
                self.set_is_const_value(True)

            def infer_value(self, *args):
                return target(*args)

        return CompileOp() if get_instance else CompileOp

    return deco if fn is None else deco(fn)
|
|
|
|
|
|
@_wrap_func
def _run_op(obj, op_name, args):
    """
    Single op execution function supported by ge in PyNative mode.

    Unless the op is "Cast" or `obj.update_parameter` is true, each Parameter argument
    with a truthy ``cast_type`` is first cast to that type. All other arguments pass through
    unchanged.

    Returns:
        The single element when the backend returns exactly one output, otherwise the whole output.

    Raises:
        RuntimeError: If the backend returns an empty/falsy output.
    """
    cast = tensor_operator_registry.get("cast")

    def _maybe_cast(value):
        # Only Parameters that carry a cast_type are converted.
        if isinstance(value, Parameter) and value.cast_type:
            return cast(value, value.cast_type)
        return value

    if op_name == "Cast" or obj.update_parameter:
        run_args = tuple(args)
    else:
        run_args = tuple(_maybe_cast(arg) for arg in args)

    output = real_run_op(obj, op_name, run_args)
    if not output:
        raise RuntimeError("Pynative run op %s failed!" % op_name)
    return output[0] if len(output) == 1 else output
|