diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
index fc9d51ad53..2543fc9878 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
@@ -296,57 +296,6 @@ std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
   return graph_info;
 }
 
-py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
-  MS_LOG(INFO) << "RunOpInVM start";
-
-  MS_EXCEPTION_IF_NULL(status);
-  MS_EXCEPTION_IF_NULL(op_exec_info);
-  MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive);
-
-  auto &op_inputs = op_exec_info->op_inputs;
-  if (op_exec_info->op_name == "HookBackward" || op_exec_info->op_name == "InsertGradientOf") {
-    py::tuple result(op_inputs.size());
-    for (size_t i = 0; i < op_inputs.size(); i++) {
-      py::object input = op_inputs[i];
-      auto tensor = py::cast<tensor::TensorPtr>(input);
-      auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
-      new_tensor->set_device_address(tensor->device_address());
-      new_tensor->set_sync_status(tensor->sync_status());
-      result[i] = new_tensor;
-    }
-    *status = PYNATIVE_SUCCESS;
-    MS_LOG(INFO) << "RunOpInVM end";
-    return std::move(result);
-  }
-  auto primitive = op_exec_info->py_primitive;
-  MS_EXCEPTION_IF_NULL(primitive);
-  auto result = primitive->RunPyComputeFunction(op_inputs);
-  if (py::isinstance<py::none>(result)) {
-    MS_LOG(ERROR) << "VM got the result none, please check whether it is failed to get func";
-    *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR;
-    py::tuple err_ret(0);
-    return std::move(err_ret);
-  }
-  if (op_exec_info->op_name == "stop_gradient" && py::isinstance<tensor::Tensor>(result)) {
-    py::tuple tuple_result(1);
-    auto tensor = py::cast<tensor::TensorPtr>(result);
-    MS_EXCEPTION_IF_NULL(tensor);
-    auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
-    new_tensor->set_device_address(tensor->device_address());
-    new_tensor->set_sync_status(tensor->sync_status());
-    tuple_result[0] = new_tensor;
-    *status = PYNATIVE_SUCCESS;
-    MS_LOG(INFO) << "RunOpInVM end";
-    return std::move(tuple_result);
-  }
-
-  // execute op
-  py::tuple tuple_result = py::make_tuple(result);
-  *status = PYNATIVE_SUCCESS;
-  MS_LOG(INFO) << "RunOpInVM end";
-  return std::move(tuple_result);
-}
-
 bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index, const PrimitivePtr &op_prim,
                                   const std::unordered_set<size_t> &input_attrs) {
   MS_EXCEPTION_IF_NULL(op_prim);
@@ -1321,6 +1270,54 @@ py::object PynativeExecutor::RunOpWithBackendPolicy(MsBackendPolicy backend_poli
   return result;
 }
 
+py::object PynativeExecutor::RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
+  MS_LOG(INFO) << "RunOpInVM start";
+  MS_EXCEPTION_IF_NULL(status);
+  MS_EXCEPTION_IF_NULL(op_exec_info);
+  MS_EXCEPTION_IF_NULL(op_exec_info->py_primitive);
+
+  auto &op_inputs = op_exec_info->op_inputs;
+  if (op_exec_info->op_name == "HookBackward" || op_exec_info->op_name == "InsertGradientOf" ||
+      op_exec_info->op_name == "stop_gradient") {
+    py::tuple result(op_inputs.size());
+    for (size_t i = 0; i < op_inputs.size(); i++) {
+      py::object input = op_inputs[i];
+      auto input_obj_id = GetId(input);
+      auto tensor = py::cast<tensor::TensorPtr>(input);
+      MS_EXCEPTION_IF_NULL(tensor);
+      if (obj_to_forward_id_.find(input_obj_id) == obj_to_forward_id_.end() &&
+          op_exec_info->op_name == "HookBackward") {
+        // the input object is not an output of a forward cnode, e.g. a parameter
+        result[i] = tensor;
+      } else {
+        // the input object is an output of a forward cnode
+        auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape(), tensor->data_ptr());
+        new_tensor->set_device_address(tensor->device_address());
+        new_tensor->set_sync_status(tensor->sync_status());
+        result[i] = new_tensor;
+      }
+    }
+    *status = PYNATIVE_SUCCESS;
+    MS_LOG(INFO) << "RunOpInVM end";
+    return std::move(result);
+  }
+
+  auto primitive = op_exec_info->py_primitive;
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto result = primitive->RunPyComputeFunction(op_inputs);
+  if (py::isinstance<py::none>(result)) {
+    MS_LOG(ERROR) << "VM got the result none, please check whether it is failed to get func";
+    *status = PYNATIVE_OP_NOT_IMPLEMENTED_ERR;
+    py::tuple err_ret(0);
+    return std::move(err_ret);
+  }
+  // execute op
+  py::tuple tuple_result = py::make_tuple(result);
+  *status = PYNATIVE_SUCCESS;
+  MS_LOG(INFO) << "RunOpInVM end";
+  return std::move(tuple_result);
+}
+
 py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
   MS_EXCEPTION_IF_NULL(op_exec_info);
   MS_EXCEPTION_IF_NULL(status);
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
index f41034017c..ad89c92c47 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h
@@ -53,8 +53,6 @@ struct PrimAbsInfo {
 using AbstractListMap = std::unordered_map<abstract::AbstractBasePtrList, PrimAbsInfo,
                                            abstract::AbstractBasePtrListHasher, abstract::AbstractBasePtrListEqual>;
 
-py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
-
 py::tuple RunOp(const py::args &args);
 
 void ClearPyNativeSession();
@@ -124,6 +122,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   MsBackendPolicy InitEnv(const OpExecInfoPtr &op_exec_info);
   py::tuple RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info);
   void RunParameterAutoMixPrecisionCast(const OpExecInfoPtr &op_exec_info);
+  py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
   py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status);
   py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecInfoPtr &op_exec_info,
                                     PynativeStatusCode *const status);