!11430 add del_attr function

From: @changzherui
Reviewed-by: @kingxian,@zh_qh
Signed-off-by: @kingxian
pull/11430/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit a58413479e

@ -283,6 +283,11 @@ void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) {
(void)this->AddAttr(attr_name, converted_ret);
}
void PrimitivePy::DelPyAttr(const py::str &name) {
std::string attr_name = name;
(void)this->DelAttr(attr_name);
}
py::dict PrimitivePy::GetAttrDict() {
py::dict attr_dict;
for (auto &attr : attrs_) {
@ -378,6 +383,7 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
.def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_)
.def(py::init<py::str &, py::object>())
.def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr")
.def("del_attr", &PrimitivePy::DelPyAttr, "del primitive attr")
.def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr")
.def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.")
.def("set_const_prim", &PrimitivePy::set_const_prim, "Set primitive is const.")

@ -49,6 +49,8 @@ class PrimitivePy : public Primitive {
void AddPyAttr(const py::str &name, const py::object &obj);
void DelPyAttr(const py::str &name);
py::dict GetAttrDict();
void set_hook(const py::function &hook) { hook_ = hook; }
py::function hook() const { return hook_; }

@ -60,6 +60,11 @@ class Primitive : public Named {
return *this;
}
Primitive &DelAttr(const std::string &name) {
attrs_.erase(name);
return *this;
}
Primitive &SetAttrs(const std::unordered_map<std::string, ValuePtr> &attrs) {
for (auto &attr : attrs) {
attrs_[attr.first] = attr.second;

@ -103,6 +103,19 @@ class Primitive(Primitive_):
self.add_attr(name, value)
return self
def del_prim_attr(self, name):
"""
Del primitive attribute.
Args:
name (str): Attribute Name.
"""
if name in self.__dict__ and name in self.attrs:
del self.__dict__[name]
del self.attrs[name]
self.del_attr(name)
return self
def set_stage(self, stage):
"""
Add stage id to primitive attribute.
@ -191,7 +204,7 @@ class Primitive(Primitive_):
def init_prim_io_names(self, inputs, outputs):
"""
Initializes the name of inputs and outpus of Tensor or attributes.
Initializes the name of inputs and outputs of Tensor or attributes.
Args:
inputs (list[str]): list of inputs names.
@ -222,9 +235,9 @@ class Primitive(Primitive_):
class PrimitiveWithCheck(Primitive):
"""
PrimitiveWithCheck is the base class of primitives in python defines functions for checking operator input arguments
but used the infer method registed in c++ source codes.
but used the infer method registered in c++ source codes.
There are three methods can be overide to define the check logic of the primitive: __check__(), check_shape(),
There are three methods can be override to define the check logic of the primitive: __check__(), check_shape(),
check_dtype(). If __check__() is defined in primitive, the __check__() has highest priority to be called.
If __check__() is not defined, check_shape() and check_dtype() can be defined to describe the check logic of
the shape and type. Method infer_value() can also be defined (such as PrimitiveWithInfer) for constant propagation.
@ -301,7 +314,7 @@ class PrimitiveWithInfer(Primitive):
"""
PrimitiveWithInfer is the base class of primitives in python and defines functions for tracking inference in python.
There are four method can be overide to define the infer logic of the primitive: __infer__(), infer_shape(),
There are four method can be override to define the infer logic of the primitive: __infer__(), infer_shape(),
infer_dtype(), and infer_value(). If __infer__() is defined in primitive, the __infer__() has highest priority
to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe the infer
logic of the shape and type. The infer_value() is used for constant propagation.

Loading…
Cancel
Save