diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc index bc12da790b..449d7bc6f3 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.cc +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.cc @@ -283,6 +283,11 @@ void PrimitivePy::AddPyAttr(const py::str &name, const py::object &obj) { (void)this->AddAttr(attr_name, converted_ret); } +void PrimitivePy::DelPyAttr(const py::str &name) { + std::string attr_name = name; + (void)this->DelAttr(attr_name); +} + py::dict PrimitivePy::GetAttrDict() { py::dict attr_dict; for (auto &attr : attrs_) { @@ -378,6 +383,7 @@ REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { .def_readonly(PYTHON_PRIMITIVE_FLAG, &PrimitivePy::parse_info_) .def(py::init()) .def("add_attr", &PrimitivePy::AddPyAttr, "add primitive attr") + .def("del_attr", &PrimitivePy::DelPyAttr, "del primitive attr") .def("get_attr_dict", &PrimitivePy::GetAttrDict, "get primitive attr") .def("set_prim_type", &PrimitivePy::set_prim_type, "Set primitive type.") .def("set_const_prim", &PrimitivePy::set_const_prim, "Set primitive is const.") diff --git a/mindspore/ccsrc/pybind_api/ir/primitive_py.h b/mindspore/ccsrc/pybind_api/ir/primitive_py.h index 91130cf66f..6cdf46fe2a 100644 --- a/mindspore/ccsrc/pybind_api/ir/primitive_py.h +++ b/mindspore/ccsrc/pybind_api/ir/primitive_py.h @@ -49,6 +49,8 @@ class PrimitivePy : public Primitive { void AddPyAttr(const py::str &name, const py::object &obj); + void DelPyAttr(const py::str &name); + py::dict GetAttrDict(); void set_hook(const py::function &hook) { hook_ = hook; } py::function hook() const { return hook_; } diff --git a/mindspore/core/ir/primitive.h b/mindspore/core/ir/primitive.h index b34d5823c6..994526122e 100644 --- a/mindspore/core/ir/primitive.h +++ b/mindspore/core/ir/primitive.h @@ -60,6 +60,11 @@ class Primitive : public Named { return *this; } + Primitive &DelAttr(const std::string &name) { + attrs_.erase(name); + return *this; + } + Primitive &SetAttrs(const std::unordered_map &attrs) { for (auto &attr : attrs) { attrs_[attr.first] = attr.second; diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py index d2008ba770..f0684bae4e 100644 --- a/mindspore/ops/primitive.py +++ b/mindspore/ops/primitive.py @@ -103,6 +103,19 @@ class Primitive(Primitive_): self.add_attr(name, value) return self + def del_prim_attr(self, name): + """ + Del primitive attribute. + + Args: + name (str): Attribute Name. + """ + if name in self.__dict__ and name in self.attrs: + del self.__dict__[name] + del self.attrs[name] + self.del_attr(name) + return self + def set_stage(self, stage): """ Add stage id to primitive attribute. @@ -191,7 +204,7 @@ class Primitive(Primitive_): def init_prim_io_names(self, inputs, outputs): """ - Initializes the name of inputs and outpus of Tensor or attributes. + Initializes the name of inputs and outputs of Tensor or attributes. Args: inputs (list[str]): list of inputs names. @@ -222,9 +235,9 @@ class Primitive(Primitive_): class PrimitiveWithCheck(Primitive): """ PrimitiveWithCheck is the base class of primitives in python defines functions for checking operator input arguments - but used the infer method registed in c++ source codes. + but used the infer method registered in c++ source codes. - There are three methods can be overide to define the check logic of the primitive: __check__(), check_shape(), + There are three methods can be override to define the check logic of the primitive: __check__(), check_shape(), check_dtype(). If __check__() is defined in primitive, the __check__() has highest priority to be called. If __check__() is not defined, check_shape() and check_dtype() can be defined to describe the check logic of the shape and type. Method infer_value() can also be defined (such as PrimitiveWithInfer) for constant propagation. @@ -301,7 +314,7 @@ class PrimitiveWithInfer(Primitive): """ PrimitiveWithInfer is the base class of primitives in python and defines functions for tracking inference in python. - There are four method can be overide to define the infer logic of the primitive: __infer__(), infer_shape(), + There are four method can be override to define the infer logic of the primitive: __infer__(), infer_shape(), infer_dtype(), and infer_value(). If __infer__() is defined in primitive, the __infer__() has highest priority to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe the infer logic of the shape and type. The infer_value() is used for constant propagation.