diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc index 99e613395c..9f3011d118 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc @@ -523,14 +523,8 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs return iter->second; } auto py_args = PreparePyInputs(prim_py_, args); - - auto pyobj = prim_py_->GetPyObj(); - if (pyobj == nullptr) { - MS_LOG(EXCEPTION) << "[" << prim_py_->ToString() << "]: pyobj is empty"; - } - auto infer_fuc = pyobj.attr("__infer__"); prim_py_->BeginRecordAddAttr(); - py::dict output = infer_fuc(*py_args); + py::dict output = prim_py_->RunInfer(py_args); prim_py_->EndRecordAddAttr(); auto added_attrs = prim_py_->evaluate_added_attrs(); MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output); diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc index ad39190dc3..fe5871fe5e 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc @@ -654,17 +654,7 @@ static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, c } } if (!is_attr_same) { - if (prim->isa()) { - PrimitivePyPtr prim_py = prim->cast(); - auto clone_fn = prim_py->GetPyObj().attr("_clone"); - py::object new_obj = clone_fn(); - auto cloned_prim = new_obj.cast(); - for (auto &item : *attrs) { - cloned_prim->AddAttr(item.first, item.second); - } - return cloned_prim; - } - auto cloned_prim = std::make_shared(*prim); + auto cloned_prim = prim->Clone(); for (auto &item : *attrs) { cloned_prim->AddAttr(item.first, item.second); } diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index db41b2a0a8..fd5c8f1965 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -280,8 +280,8 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn AbstractBasePtrList args_spec_list; for (size_t i = 0; i < size; i++) { ValuePtr input_value = PyAttrValue(py_args[i]); - args_spec_list.emplace_back(abstract::FromValueInside( - input_value, !py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa())); + args_spec_list.emplace_back( + abstract::FromValueInside(input_value, !prim->ObjHasAttr("const_value") && input_value->isa())); } AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract(); op_exec_info->abstract = infer_res; @@ -296,8 +296,7 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args) MS_EXCEPTION_IF_NULL(op_exec_info); op_exec_info->op_name = py::cast(args[PY_NAME]); auto prim = py::cast(args[PY_PRIM]); - auto pyobj = prim->GetPyObj(); - if (pyobj == nullptr) { + if (!prim->HasPyObj()) { MS_LOG(EXCEPTION) << "pyobj is empty"; } @@ -708,7 +707,7 @@ py::tuple RunOpInner(const py::args &args) { value_ret[0] = output["value"]; return value_ret; } - if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) { + if (op_exec_info->py_primitive->ObjHasAttr("const_value")) { py::tuple value_ret(1); value_ret[0] = ""; return value_ret; diff --git a/mindspore/core/ir/primitive.h b/mindspore/core/ir/primitive.h index a1784a85a3..c00af41950 100644 --- a/mindspore/core/ir/primitive.h +++ b/mindspore/core/ir/primitive.h @@ -100,6 +100,7 @@ class Primitive : public Named { return !(iter == attrs_.cend()); } void set_prim_type(const PrimType t) { prim_type_ = t; } + virtual PrimitivePtr Clone() { return std::make_shared(*this); } void set_instance_name(const std::string s) { instance_name_ = s; } bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; } bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; } diff --git a/mindspore/core/ir/primitive_py.cc b/mindspore/core/ir/primitive_py.cc index 2a8f003623..15a19b703a 100644 --- a/mindspore/core/ir/primitive_py.cc +++ b/mindspore/core/ir/primitive_py.cc @@ -196,6 +196,21 @@ bool PrimitivePy::HasComputeFunction() const { return true; } +PrimitivePtr PrimitivePy::Clone() { + auto clone_fn = python_obj_.attr("_clone"); + py::object new_obj = clone_fn(); + auto cloned_prim = new_obj.cast(); + return cloned_prim; +} + +py::dict PrimitivePy::RunInfer(const py::tuple &args) { + if (!HasPyObj()) { + MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty"; + } + auto infer_fuc = python_obj_.attr("__infer__"); + return infer_fuc(*args); +} + REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { (void)py::enum_(*m, "prim_type", py::arithmetic()) .value("unknown", PrimType::kPrimTypeUnknown) diff --git a/mindspore/core/ir/primitive_py.h b/mindspore/core/ir/primitive_py.h index 8c576016fa..01af9c530f 100644 --- a/mindspore/core/ir/primitive_py.h +++ b/mindspore/core/ir/primitive_py.h @@ -61,6 +61,10 @@ class PrimitivePy : public Primitive { bool HasComputeFunction() const; const bool parse_info_ = true; const py::object &GetPyObj() const { return python_obj_; } + py::dict RunInfer(const py::tuple &args); + bool ObjHasAttr(const char *attr_name) { return py::hasattr(python_obj_, attr_name); } + bool HasPyObj() { return python_obj_ != nullptr; } + PrimitivePtr Clone() override; bool is_tuple_input_ = false; private: