|
|
|
@ -22,6 +22,7 @@ from mindspore import context
|
|
|
|
|
from .._c_expression import Primitive_, real_run_op, prim_type
|
|
|
|
|
from . import signature as sig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Primitive(Primitive_):
|
|
|
|
|
"""
|
|
|
|
|
Primitive is the base class of primitives in python.
|
|
|
|
@ -168,7 +169,7 @@ class Primitive(Primitive_):
|
|
|
|
|
return type(self)(**self.init_attrs)
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
|
|
|
|
|
attr = ', '.join([f'{k}={self.attrs[k]}'for k in self.attrs if not k in Primitive._repr_ignore_list])
|
|
|
|
|
attr = ', '.join([f'{k}={self.attrs[k]}' for k in self.attrs if not k in Primitive._repr_ignore_list])
|
|
|
|
|
info_str = f'Prim[{self.name}]'
|
|
|
|
|
if attr:
|
|
|
|
|
info_str += f'<{attr}>'
|
|
|
|
@ -425,6 +426,7 @@ def prim_attr_register(fn):
|
|
|
|
|
Returns:
|
|
|
|
|
function, original function.
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def deco(self, *args, **kwargs):
|
|
|
|
|
if isinstance(self, PrimitiveWithInfer):
|
|
|
|
|
PrimitiveWithInfer.__init__(self, self.__class__.__name__)
|
|
|
|
@ -442,6 +444,7 @@ def prim_attr_register(fn):
|
|
|
|
|
self.add_prim_attr(name, value)
|
|
|
|
|
self.init_attrs[name] = value
|
|
|
|
|
fn(self, *args, **kwargs)
|
|
|
|
|
|
|
|
|
|
deco.decorated_func = fn
|
|
|
|
|
return deco
|
|
|
|
|
|
|
|
|
@ -470,6 +473,7 @@ def constexpr(fn=None, get_instance=True, name=None):
|
|
|
|
|
>>> return len(x)
|
|
|
|
|
>>> assert tuple_len_class()(a) == 2
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def deco(fn):
|
|
|
|
|
class CompileOp(PrimitiveWithInfer):
|
|
|
|
|
def __init__(self):
|
|
|
|
@ -479,9 +483,11 @@ def constexpr(fn=None, get_instance=True, name=None):
|
|
|
|
|
|
|
|
|
|
def infer_value(self, *args):
|
|
|
|
|
return fn(*args)
|
|
|
|
|
|
|
|
|
|
if get_instance:
|
|
|
|
|
return CompileOp()
|
|
|
|
|
return CompileOp
|
|
|
|
|
|
|
|
|
|
if fn is not None:
|
|
|
|
|
return deco(fn)
|
|
|
|
|
return deco
|
|
|
|
|