diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index 0bf447751b..6fcd3b65a0 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -307,7 +307,7 @@ CNodePtr KernelGraph::NewCNode(const std::vector &inputs) { if (inputs.size() == 1 || !feature_map_input_indexs.empty()) { kernel_info->SetFeatureMapFlag(true); } - if (AnfAlgo::IsRealCNodeKernel(cnode)) { + if (AnfAlgo::IsRealKernel(cnode)) { AnfAlgo::SetNodeAttr(kIsFeatureMapOutput, MakeValue(kernel_info->is_feature_map()), cnode); AnfAlgo::SetNodeAttr(kIsFeatureMapInputList, MakeValue(feature_map_input_indexs), cnode); } diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 5e3add1b5f..db41b2a0a8 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -363,19 +363,21 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat MS_LOG(INFO) << "RunOpInVM end"; return std::move(result); } - auto func = op_exec_info->py_primitive->GetComputeFunction(); - if (py::isinstance(func)) { - MS_LOG(ERROR) << "VM failed to get func"; + auto primitive = op_exec_info->py_primitive; + MS_EXCEPTION_IF_NULL(primitive); + auto result = primitive->RunPyComputeFunction(op_exec_info->op_inputs); + if (py::isinstance(result)) { + MS_LOG(ERROR) << "VM got the result none, please check whether it is failed to get func"; *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR; py::tuple err_ret(0); return std::move(err_ret); } // execute op - py::tuple result = py::make_tuple(func(*op_exec_info->op_inputs)); + py::tuple tuple_result = py::make_tuple(result); *status = PYNATIVE_SUCCESS; MS_LOG(INFO) << "RunOpInVM end"; - return std::move(result); + return std::move(tuple_result); } bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim, diff --git a/mindspore/ccsrc/utils/primitive_utils.cc b/mindspore/ccsrc/utils/primitive_utils.cc index 490e2517a9..abd5cb1660 100644 --- a/mindspore/ccsrc/utils/primitive_utils.cc +++ b/mindspore/ccsrc/utils/primitive_utils.cc @@ -15,6 +15,9 @@ */ #include "utils/primitive_utils.h" + +#include + #include "pipeline/jit/parse/python_adapter.h" #include "utils/log_adapter.h" #include "common/utils.h" @@ -43,4 +46,25 @@ py::function GetComputeFunction(std::string name) { py::object fn = mod.attr(common::SafeCStr(name)); return fn; } + +py::tuple ConvertDatatoPyTuple(const VectorRef &args) { + auto py_args = py::tuple(args.size()); + size_t i = 0; + for (auto &arg : args) { + py_args[i] = BaseRefToPyData(arg); + MS_LOG(DEBUG) << "arg:" << i << ":" << arg.ToString(); + i++; + } + return py_args; +} + +BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args) { + auto func = GetComputeFunction(prim->name()); + if (py::isinstance(func)) { + MS_LOG(EXCEPTION) << prim->name() << " 's compute function run failed, please check whether it is not implemented"; + } + auto py_args = ConvertDatatoPyTuple(args); + py::object obj = func(*py_args); + return std::make_shared(obj); +} } // namespace mindspore diff --git a/mindspore/ccsrc/utils/primitive_utils.h b/mindspore/ccsrc/utils/primitive_utils.h index b7e2515aea..0faeca9c47 100644 --- a/mindspore/ccsrc/utils/primitive_utils.h +++ b/mindspore/ccsrc/utils/primitive_utils.h @@ -19,6 +19,7 @@ #include #include "pybind11/pybind11.h" +#include "utils/base_ref.h" namespace py = pybind11; @@ -28,6 +29,10 @@ py::function GetBpropFunctionByObj(py::object obj); py::function GetBpropFunction(std::string name); py::function GetComputeFunction(std::string name); + +BaseRef RunComputeFunction(const PrimitivePtr &prim, const VectorRef &args); + +py::tuple ConvertDatatoPyTuple(const VectorRef &args); } // namespace mindspore #endif // MINDSPORE_CCSRC_UTILS_PRIMITIVE_UTILS_H_ diff --git a/mindspore/ccsrc/vm/vmimpl.cc b/mindspore/ccsrc/vm/vmimpl.cc index 2aebf8ad0d..8ce65c3a26 100644 --- a/mindspore/ccsrc/vm/vmimpl.cc +++ b/mindspore/ccsrc/vm/vmimpl.cc @@ -440,25 +440,13 @@ VectorRef VM::RunGraph(const FuncGraphPtr &g, const VectorRef &args) { } BaseRef RunOperation(const PrimitivePtr &prim, const VectorRef &args) { - PrimitivePyPtr operation = dyn_cast(prim); - MS_LOG(DEBUG) << "operation start " << prim->name(); - auto func = operation != nullptr ? operation->GetComputeFunction() : GetComputeFunction(prim->name()); - if (py::isinstance(func)) { - MS_LOG(EXCEPTION) << prim->name() << " 's compute function is not implemented"; - } - - py::tuple py_args = py::tuple(args.size()); - MS_LOG(DEBUG) << "input for operation:"; - size_t i = 0; - for (auto &arg : args) { - py_args[i] = BaseRefToPyData(arg); - MS_LOG(DEBUG) << "arg: " << i << ":"; - i++; - } - py::object obj = func(*py_args); - MS_LOG(DEBUG) << "result:" << py::str(obj); - return obj; + MS_EXCEPTION_IF_NULL(prim); + auto result = prim->RunComputeFunction(args); + if (result.is_null()) { + return RunComputeFunction(prim, args); + } + return result; } } // namespace compile diff --git a/mindspore/core/ir/primitive.h b/mindspore/core/ir/primitive.h index 5471b58063..a1784a85a3 100644 --- a/mindspore/core/ir/primitive.h +++ b/mindspore/core/ir/primitive.h @@ -83,6 +83,7 @@ class Primitive : public Named { void set_attr(const std::string &attrName, const ValuePtr &attr) { attrs_[attrName] = attr; } void EraseAttr(const std::string &attrName) { (void)attrs_.erase(attrName); } + virtual BaseRef RunComputeFunction(const VectorRef &args) const { return nullptr; } ValuePtr GetAttr(const std::string &attrName) const { auto iter = attrs_.find(attrName); diff --git a/mindspore/core/ir/primitive_py.cc b/mindspore/core/ir/primitive_py.cc index 1a97487ddc..2a8f003623 100644 --- a/mindspore/core/ir/primitive_py.cc +++ b/mindspore/core/ir/primitive_py.cc @@ -79,13 +79,7 @@ py::function PrimitivePy::GetBpropFunction() { } BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const { - auto py_args = py::tuple(args.size()); - size_t i = 0; - for (auto &arg : args) { - py_args[i] = BaseRefToPyData(arg); - MS_LOG(DEBUG) << "arg:" << i << ":"; - i++; - } + auto py_args = ConvertDatatoPyTuple(args); py::object obj; bool is_bprop = this->HasAttr(kBpropAttrName); if (is_bprop) { @@ -123,7 +117,7 @@ BaseRef PrimitivePy::RunHookFunction(const VectorRef &args) const { return std::make_shared(obj); } -py::function PrimitivePy::GetComputeFunction() { +py::function PrimitivePy::GetComputeFunction() const { static const char *const compute_func_name = "vm_impl"; if (py::hasattr(python_obj_, compute_func_name)) { @@ -176,6 +170,32 @@ void PrimitivePy::CopyHookFunction(const PrimitivePtr &primitive) { this->set_hook(primitive_py->hook()); } +BaseRef PrimitivePy::RunComputeFunction(const VectorRef &args) const { + auto py_args = ConvertDatatoPyTuple(args); + auto result = this->RunPyComputeFunction(py_args); + if (py::isinstance(result)) { + return std::make_shared(nullptr); + } + return std::make_shared(result); +} + +py::object PrimitivePy::RunPyComputeFunction(const py::tuple &py_args) const { + auto func = this->GetComputeFunction(); + if (py::isinstance(func)) { + return py::none(); + } + auto result = func(*py_args); + return result; +} + +bool PrimitivePy::HasComputeFunction() const { + auto func = GetComputeFunction(); + if (py::isinstance(func)) { + return false; + } + return true; +} + REGISTER_PYBIND_DEFINE(Primitive_, ([](const py::module *m) { (void)py::enum_(*m, "prim_type", py::arithmetic()) .value("unknown", PrimType::kPrimTypeUnknown) diff --git a/mindspore/core/ir/primitive_py.h b/mindspore/core/ir/primitive_py.h index 2dc45ac341..8c576016fa 100644 --- a/mindspore/core/ir/primitive_py.h +++ b/mindspore/core/ir/primitive_py.h @@ -41,7 +41,6 @@ class PrimitivePy : public Primitive { ~PrimitivePy() override = default; MS_DECLARE_PARENT(PrimitivePy, Primitive); py::function GetBpropFunction(); - py::function GetComputeFunction(); void set_signatures( std::vector> @@ -57,11 +56,15 @@ class PrimitivePy : public Primitive { void set_hook(const py::function &hook) { hook_ = hook; } py::function hook() const { return hook_; } BaseRef RunHookFunction(const VectorRef &args) const override; + BaseRef RunComputeFunction(const VectorRef &args) const override; + py::object RunPyComputeFunction(const py::tuple &py_args) const; + bool HasComputeFunction() const; const bool parse_info_ = true; const py::object &GetPyObj() const { return python_obj_; } bool is_tuple_input_ = false; private: + py::function GetComputeFunction() const; py::object python_obj_; py::function hook_; std::vector signatures_; diff --git a/tests/ut/cpp/operator/ops_test.cc b/tests/ut/cpp/operator/ops_test.cc index 789b1cab25..20f4734bf0 100644 --- a/tests/ut/cpp/operator/ops_test.cc +++ b/tests/ut/cpp/operator/ops_test.cc @@ -454,8 +454,7 @@ TEST_F(TestOps, GetConv2DPrimPyTest) { ASSERT_TRUE(conv2d_ptr); if (nullptr != conv2d_ptr) { MS_LOG(INFO) << "Get PrimitivePyPtr: " << conv2d_ptr->name(); - auto func = conv2d_ptr->GetComputeFunction(); - if (py::isinstance(func)) { + if(!conv2d_ptr->HasComputeFunction()){ MS_LOG(EXCEPTION) << "" << conv2d_ptr->name() << "'s compute function is not implemented"; } diff --git a/tests/ut/cpp/parallel/step_parallel_test.cc b/tests/ut/cpp/parallel/step_parallel_test.cc index 5657db8790..383a061805 100644 --- a/tests/ut/cpp/parallel/step_parallel_test.cc +++ b/tests/ut/cpp/parallel/step_parallel_test.cc @@ -294,8 +294,7 @@ TEST_F(TestStepParallel, CreatOpInstance) { ASSERT_TRUE(allreduce_ptr); if (nullptr != allreduce_ptr) { MS_LOG(INFO) << "Get PrimitivePyPtr: " << allreduce_ptr->name(); - auto func = allreduce_ptr->GetComputeFunction(); - if (py::isinstance(func)) { + if (!allreduce_ptr->HasComputeFunction()) { MS_LOG(EXCEPTION) << "" << allreduce_ptr->name() << "'s compute function is not implemented"; } diff --git a/tests/ut/cpp/vm/segment_runner_test.cc b/tests/ut/cpp/vm/segment_runner_test.cc index c83b1b3434..60c027b077 100644 --- a/tests/ut/cpp/vm/segment_runner_test.cc +++ b/tests/ut/cpp/vm/segment_runner_test.cc @@ -57,11 +57,11 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert1) { std::vector todos(splits.size()); auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos), - [](const BaseRef& seg) -> bool { return utils::isa(seg); }); + [](const BaseRef &seg) -> bool { return utils::isa(seg); }); todos.resize(std::distance(todos.begin(), it)); ASSERT_EQ(todos.size(), 1); - AnfNodePtrList anf_list; + AnfNodePtrList anf_list; for (auto &item : utils::cast(todos[0])) { anf_list.push_back(utils::cast(item)); } @@ -81,11 +81,11 @@ TEST_F(TestCompileSegmentRunner, test_MsVmConvert2) { std::vector todos(splits.size()); auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos), - [](const BaseRef& seg) -> bool { return utils::isa(seg); }); + [](const BaseRef &seg) -> bool { return utils::isa(seg); }); todos.resize(std::distance(todos.begin(), it)); ASSERT_EQ(todos.size(), 1); - AnfNodePtrList anf_list; + AnfNodePtrList anf_list; for (auto &item : utils::cast(todos[0])) { anf_list.push_back(utils::cast(item)); } @@ -105,11 +105,11 @@ TEST_F(TestCompileSegmentRunner, test_if) { std::vector todos(splits.size()); auto it = std::copy_if(std::begin(splits), std::end(splits), std::begin(todos), - [](const BaseRef& seg) -> bool { return utils::isa(seg); }); + [](const BaseRef &seg) -> bool { return utils::isa(seg); }); todos.resize(std::distance(todos.begin(), it)); ASSERT_EQ(todos.size(), 1); - AnfNodePtrList anf_list; + AnfNodePtrList anf_list; for (auto &item : utils::cast(todos[0])) { anf_list.push_back(utils::cast(item)); } @@ -122,13 +122,13 @@ TEST_F(TestCompileSegmentRunner, test_if) { TEST_F(TestCompileSegmentRunner, test_RunOperation1) { VectorRef args({1}); - auto res = RunOperation(prim::kPrimIdentity, args); + auto res = RunOperation(std::make_shared(py::str(prim::kPrimIdentity->name()), py::none()), args); ASSERT_EQ(py::cast(BaseRefToPyData(res)), 1); } TEST_F(TestCompileSegmentRunner, test_RunOperation2) { VectorRef args({1, 2}); - auto res = RunOperation(prim::kPrimScalarGt, args); + auto res = RunOperation(std::make_shared(py::str(prim::kPrimScalarGt->name()), py::none()), args); ASSERT_EQ(py::cast(BaseRefToPyData(res)), false); } } // namespace compile