refactor primitive GetObj function

pull/3088/head
WilliamLian 5 years ago
parent 6566b38371
commit 4c4c08c726

@ -523,14 +523,8 @@ EvalResultPtr PythonPrimEvaluator::EvalPrim(const AnalysisEnginePtr &, const Abs
return iter->second;
}
auto py_args = PreparePyInputs(prim_py_, args);
auto pyobj = prim_py_->GetPyObj();
if (pyobj == nullptr) {
MS_LOG(EXCEPTION) << "[" << prim_py_->ToString() << "]: pyobj is empty";
}
auto infer_fuc = pyobj.attr("__infer__");
prim_py_->BeginRecordAddAttr();
py::dict output = infer_fuc(*py_args);
py::dict output = prim_py_->RunInfer(py_args);
prim_py_->EndRecordAddAttr();
auto added_attrs = prim_py_->evaluate_added_attrs();
MS_LOG(DEBUG) << "Output type is " << (std::string)py::str(output);

@ -654,17 +654,7 @@ static PrimitivePtr BuildPrimtiveValueWithAttributes(const PrimitivePtr &prim, c
}
}
if (!is_attr_same) {
if (prim->isa<PrimitivePy>()) {
PrimitivePyPtr prim_py = prim->cast<PrimitivePyPtr>();
auto clone_fn = prim_py->GetPyObj().attr("_clone");
py::object new_obj = clone_fn();
auto cloned_prim = new_obj.cast<PrimitivePyPtr>();
for (auto &item : *attrs) {
cloned_prim->AddAttr(item.first, item.second);
}
return cloned_prim;
}
auto cloned_prim = std::make_shared<Primitive>(*prim);
auto cloned_prim = prim->Clone();
for (auto &item : *attrs) {
cloned_prim->AddAttr(item.first, item.second);
}

@ -280,8 +280,8 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
AbstractBasePtrList args_spec_list;
for (size_t i = 0; i < size; i++) {
ValuePtr input_value = PyAttrValue(py_args[i]);
args_spec_list.emplace_back(abstract::FromValueInside(
input_value, !py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa<tensor::Tensor>()));
args_spec_list.emplace_back(
abstract::FromValueInside(input_value, !prim->ObjHasAttr("const_value") && input_value->isa<tensor::Tensor>()));
}
AbstractBasePtr infer_res = EvalOnePrim(prim, args_spec_list)->abstract();
op_exec_info->abstract = infer_res;
@ -296,8 +296,7 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args)
MS_EXCEPTION_IF_NULL(op_exec_info);
op_exec_info->op_name = py::cast<std::string>(args[PY_NAME]);
auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
auto pyobj = prim->GetPyObj();
if (pyobj == nullptr) {
if (!prim->HasPyObj()) {
MS_LOG(EXCEPTION) << "pyobj is empty";
}
@ -708,7 +707,7 @@ py::tuple RunOpInner(const py::args &args) {
value_ret[0] = output["value"];
return value_ret;
}
if (py::hasattr(op_exec_info->py_primitive->GetPyObj(), "const_value")) {
if (op_exec_info->py_primitive->ObjHasAttr("const_value")) {
py::tuple value_ret(1);
value_ret[0] = "";
return value_ret;

@ -100,6 +100,7 @@ class Primitive : public Named {
return !(iter == attrs_.cend());
}
void set_prim_type(const PrimType t) { prim_type_ = t; }
virtual PrimitivePtr Clone() { return std::make_shared<Primitive>(*this); }
void set_instance_name(const std::string s) { instance_name_ = s; }
bool HasPyEvaluator() const { return prim_type_ == kPrimTypePyInferShape || prim_type_ == kPrimTypeUserCustom; }
bool HasPyInferTensor() const { return prim_type_ == kPrimTypePyInferTensor; }

@ -196,6 +196,21 @@ bool PrimitivePy::HasComputeFunction() const {
return true;
}
PrimitivePtr PrimitivePy::Clone() {
auto clone_fn = python_obj_.attr("_clone");
py::object new_obj = clone_fn();
auto cloned_prim = new_obj.cast<PrimitivePyPtr>();
return cloned_prim;
}
py::dict PrimitivePy::RunInfer(const py::tuple &args) {
if (!HasPyObj()) {
MS_LOG(EXCEPTION) << "[" << this->ToString() << "]: pyobj is empty";
}
auto infer_fuc = python_obj_.attr("__infer__");
return infer_fuc(*args);
}
REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) {
(void)py::enum_<PrimType>(*m, "prim_type", py::arithmetic())
.value("unknown", PrimType::kPrimTypeUnknown)

@ -61,6 +61,10 @@ class PrimitivePy : public Primitive {
bool HasComputeFunction() const;
const bool parse_info_ = true;
const py::object &GetPyObj() const { return python_obj_; }
py::dict RunInfer(const py::tuple &args);
bool ObjHasAttr(const char *attr_name) { return py::hasattr(python_obj_, attr_name); }
bool HasPyObj() { return python_obj_ != nullptr; }
PrimitivePtr Clone() override;
bool is_tuple_input_ = false;
private:

Loading…
Cancel
Save