|
|
|
@ -103,6 +103,19 @@ class Primitive(Primitive_):
|
|
|
|
|
self.add_attr(name, value)
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
def del_prim_attr(self, name):
|
|
|
|
|
"""
|
|
|
|
|
Del primitive attribute.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
name (str): Attribute Name.
|
|
|
|
|
"""
|
|
|
|
|
if name in self.__dict__ and name in self.attrs:
|
|
|
|
|
del self.__dict__[name]
|
|
|
|
|
del self.attrs[name]
|
|
|
|
|
self.del_attr(name)
|
|
|
|
|
return self
|
|
|
|
|
|
|
|
|
|
def set_stage(self, stage):
|
|
|
|
|
"""
|
|
|
|
|
Add stage id to primitive attribute.
|
|
|
|
@ -191,7 +204,7 @@ class Primitive(Primitive_):
|
|
|
|
|
|
|
|
|
|
def init_prim_io_names(self, inputs, outputs):
|
|
|
|
|
"""
|
|
|
|
|
Initializes the name of inputs and outpus of Tensor or attributes.
|
|
|
|
|
Initializes the name of inputs and outputs of Tensor or attributes.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
inputs (list[str]): list of inputs names.
|
|
|
|
@ -222,9 +235,9 @@ class Primitive(Primitive_):
|
|
|
|
|
class PrimitiveWithCheck(Primitive):
|
|
|
|
|
"""
|
|
|
|
|
PrimitiveWithCheck is the base class of primitives in python defines functions for checking operator input arguments
|
|
|
|
|
but used the infer method registed in c++ source codes.
|
|
|
|
|
but used the infer method registered in c++ source codes.
|
|
|
|
|
|
|
|
|
|
There are three methods can be overide to define the check logic of the primitive: __check__(), check_shape(),
|
|
|
|
|
There are three methods can be override to define the check logic of the primitive: __check__(), check_shape(),
|
|
|
|
|
check_dtype(). If __check__() is defined in primitive, the __check__() has highest priority to be called.
|
|
|
|
|
If __check__() is not defined, check_shape() and check_dtype() can be defined to describe the check logic of
|
|
|
|
|
the shape and type. Method infer_value() can also be defined (such as PrimitiveWithInfer) for constant propagation.
|
|
|
|
@ -301,7 +314,7 @@ class PrimitiveWithInfer(Primitive):
|
|
|
|
|
"""
|
|
|
|
|
PrimitiveWithInfer is the base class of primitives in python and defines functions for tracking inference in python.
|
|
|
|
|
|
|
|
|
|
There are four method can be overide to define the infer logic of the primitive: __infer__(), infer_shape(),
|
|
|
|
|
There are four method can be override to define the infer logic of the primitive: __infer__(), infer_shape(),
|
|
|
|
|
infer_dtype(), and infer_value(). If __infer__() is defined in primitive, the __infer__() has highest priority
|
|
|
|
|
to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe the infer
|
|
|
|
|
logic of the shape and type. The infer_value() is used for constant propagation.
|
|
|
|
|