From 701ab0d05fa742da39654c0ef3b7ec09e9deebc7 Mon Sep 17 00:00:00 2001
From: chujinjin <chujinjin@huawei.com>
Date: Tue, 1 Dec 2020 21:34:56 +0800
Subject: [PATCH] add cpu pynative mode

---
 .../ccsrc/backend/session/cpu_session.cc      | 69 +++++++++++++++++++
 mindspore/ccsrc/backend/session/cpu_session.h |  8 +++
 .../pipeline/pynative/pynative_execute.cc     | 22 ++++--
 .../runtime/device/cpu/cpu_kernel_runtime.cc  |  5 ++
 4 files changed, 99 insertions(+), 5 deletions(-)

diff --git a/mindspore/ccsrc/backend/session/cpu_session.cc b/mindspore/ccsrc/backend/session/cpu_session.cc
index 99558356bd..5fda81d88e 100644
--- a/mindspore/ccsrc/backend/session/cpu_session.cc
+++ b/mindspore/ccsrc/backend/session/cpu_session.cc
@@ -93,10 +93,20 @@ void CPUSession::CreateOutputTensors(const GraphId &graph_id, const std::vector<
   runtime_.CreateOutputTensors(kernel_graph.get(), input_tensors, outputs, tensor_to_node);
 }
 
+void CPUSession::SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) {
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
+    return;
+  }
+  runtime_.SyncValueNodeDeviceAddr(kernel_graph.get());
+}
+
 void CPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
   auto kernel_graph = GetGraph(graph_id);
   MS_EXCEPTION_IF_NULL(kernel_graph);
+  SyncValueNodeDeviceAddr(kernel_graph);
   MS_LOG(INFO) << "Bind input output address";
   runtime_.BindInputOutput(kernel_graph.get(), inputs, outputs);
 
@@ -130,6 +140,65 @@ void CPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor:
   }
 }
 
+void CPUSession::BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+                             const std::vector<tensor::TensorPtr> &input_tensors,
+                             const std::vector<int64_t> &tensors_mask) {
+  // Check if the graph cache exists.
+  if (run_op_graphs_.find(graph_info) != run_op_graphs_.end()) {
+    return;
+  }
+  // Prepare the graph
+  auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  SetKernelInfo(kernel_graph.get());
+  BuildKernel(kernel_graph.get());
+  run_op_graphs_[graph_info] = kernel_graph;
+}
+
+void CPUSession::SetOutputFlags(const VectorRef &base_ref, std::vector<tensor::TensorPtr> *outputs_tensors) {
+  for (size_t i = 0; i < base_ref.size(); ++i) {
+    if (utils::isa<VectorRef>(base_ref[i])) {
+      auto ref_iter = utils::cast<VectorRef>(base_ref[i]);
+      SetOutputFlags(ref_iter, outputs_tensors);
+    } else if (utils::isa<tensor::TensorPtr>(base_ref[i])) {
+      auto tensor_ptr = utils::cast<std::shared_ptr<tensor::Tensor>>(base_ref[i]);
+      tensor_ptr->SetNeedWait(false);
+      outputs_tensors->push_back(tensor_ptr);
+    }
+  }
+}
+
+void CPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+                           std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
+                           const std::vector<int64_t> &tensors_mask) {
+  MS_EXCEPTION_IF_NULL(input_tensors);
+  BuildOpImpl(op_run_info, graph_info, *input_tensors, tensors_mask);
+  EraseValueNodeTensor(tensors_mask, input_tensors);
+
+  auto kernel_graph = run_op_graphs_[graph_info];
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+
+  runtime_.AssignKernelAddress(kernel_graph.get());
+  std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node;
+  runtime_.CreateOutputTensors(kernel_graph.get(), *input_tensors, outputs, &tensor_to_node);
+  runtime_.BindInputOutput(kernel_graph.get(), *input_tensors, outputs);
+
+  MS_LOG(INFO) << "Run Op start";
+  auto execution_order = kernel_graph->execution_order();
+  Reorder(&execution_order);
+
+  kernel_graph->set_execution_order(execution_order);
+
+  bool ret = runtime_.Run(kernel_graph.get(), false);
+  if (!ret) {
+    MS_LOG(EXCEPTION) << "Run Op failed";
+  }
+
+  std::vector<tensor::TensorPtr> output_tensors;
+  SetOutputFlags(*outputs, &output_tensors);
+  MS_LOG(INFO) << "Run Op end";
+}
+
 void CPUSession::SetKernelInfo(const KernelGraph *kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   auto &kernel_nodes = kernel_graph->execution_order();
diff --git a/mindspore/ccsrc/backend/session/cpu_session.h b/mindspore/ccsrc/backend/session/cpu_session.h
index 20d28c00f2..0412a0c672 100644
--- a/mindspore/ccsrc/backend/session/cpu_session.h
+++ b/mindspore/ccsrc/backend/session/cpu_session.h
@@ -38,10 +38,18 @@ class CPUSession : public SessionBasic {
   void RunGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
   ParameterPtr CreateNewParameterFromParameter(const AnfNodePtr &anf, KernelGraph *graph) override;
   void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph);
+  void BuildOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+                   const std::vector<tensor::TensorPtr> &input_tensors,
+                   const std::vector<int64_t> &tensors_mask) override;
+  void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
+                 std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
+                 const std::vector<int64_t> &tensors_mask) override;
 
  private:
   void SetKernelInfo(const KernelGraph *kernel_graph);
   void BuildKernel(const KernelGraph *kernel_graph);
+  void SetOutputFlags(const VectorRef &base_ref, std::vector<tensor::TensorPtr> *outputs_tensors);
+  void SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph);
   device::cpu::CPUKernelRuntime runtime_;
 };
 MS_REG_SESSION(kCPUDevice, CPUSession);
diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
index ef16967871..ed28b36a88 100644
--- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc
@@ -994,6 +994,8 @@ void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_ex
     });
     return;
   }
+  auto ms_context = MsContext::GetInstance();
+  auto target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
   const auto &tensor_id_list = op_index_with_tensor_id_[op_index];
   for (size_t i = 0; i < tensor_id_list.size(); ++i) {
     auto tensor_id = tensor_id_list[i];
@@ -1003,7 +1005,20 @@ void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_ex
     std::for_each(tensors_in_value_node.begin(), tensors_in_value_node.end(), [&](tensor::TensorPtr &tensor) {
       tensor->set_shape(new_tensor->shape());
       tensor->set_data_type(new_tensor->data_type());
-      tensor->set_device_address(new_tensor->device_address());
+      if (target != kCPUDevice) {
+        tensor->set_device_address(new_tensor->device_address());
+      } else {
+        auto old_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
+        auto new_device_address = std::dynamic_pointer_cast<device::DeviceAddress>(new_tensor->device_address());
+        auto old_ptr = old_device_address->GetMutablePtr();
+        auto new_ptr = new_device_address->GetPtr();
+        MS_EXCEPTION_IF_NULL(old_ptr);
+        MS_EXCEPTION_IF_NULL(new_ptr);
+        auto ret = memcpy_s(old_ptr, old_device_address->GetSize(), new_ptr, new_device_address->GetSize());
+        if (ret != EOK) {
+          MS_LOG(EXCEPTION) << "Memory copy failed. ret: " << ret;
+        }
+      }
     });
   }
 }
@@ -1264,12 +1279,9 @@ py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynati
   MS_LOG(INFO) << "Start run op [" << op_exec_info->op_name << "] with backend policy ms";
   auto ms_context = MsContext::GetInstance();
   ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, true);
-  std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
-  if (device_target != kAscendDevice && device_target != kGPUDevice) {
-    MS_EXCEPTION(ArgumentError) << "Device target [" << device_target << "] is not supported in Pynative mode";
-  }
 
   if (session == nullptr) {
+    std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
     session = session::SessionFactory::Get().Create(device_target);
     MS_EXCEPTION_IF_NULL(session);
     session->Init(ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID));
diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
index 54ac40791e..cfb115c5fd 100644
--- a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc
@@ -56,6 +56,11 @@ void CPUKernelRuntime::AssignValueNodeAddress(session::KernelGraph *kernel_graph
     }
     auto tensor = node_value->cast<tensor::TensorPtr>();
     MS_EXCEPTION_IF_NULL(tensor);
+    if (tensor->device_address() != nullptr) {
+      AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), 0,
+                             item_node.get());
+      continue;
+    }
     TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item_node, 0);
    if (output_type_id == kTypeUnknown) {
      output_type_id = AnfAlgo::GetOutputInferDataType(item_node, 0);