From 97a97e02db0faf58e8cb2ad760b5532630d8834c Mon Sep 17 00:00:00 2001 From: kswang Date: Wed, 31 Mar 2021 16:03:08 +0800 Subject: [PATCH] extract load input --- .../ccsrc/backend/session/ascend_session.cc | 2 -- .../ccsrc/backend/session/cpu_session.cc | 26 +++++++++++++++++++ mindspore/ccsrc/backend/session/cpu_session.h | 2 ++ mindspore/ccsrc/backend/session/executor.cc | 1 + .../ccsrc/backend/session/gpu_session.cc | 2 -- mindspore/ccsrc/backend/session/gpu_session.h | 5 ++-- .../ccsrc/backend/session/kernel_graph.cc | 12 ++++----- .../ccsrc/backend/session/kernel_graph.h | 4 +-- .../ccsrc/backend/session/session_basic.h | 7 +++++ .../runtime/device/cpu/cpu_kernel_runtime.cc | 11 ++------ 10 files changed, 48 insertions(+), 24 deletions(-) diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index d3b9d840c3..3ee7e52c9c 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -571,8 +571,6 @@ void AscendSession::RunGraphImpl(const GraphId &graph_id, const std::vector memo; SyncDataToExtraParams(NOT_NULL(kernel_graph), NOT_NULL(&memo)); memo.clear(); - // load input data from user input - LoadInputData(kernel_graph, inputs); if (debugger_) { debugger_->PreExecute(kernel_graph, graph_sum_); } diff --git a/mindspore/ccsrc/backend/session/cpu_session.cc b/mindspore/ccsrc/backend/session/cpu_session.cc index 76c93921b6..abd101ac9a 100644 --- a/mindspore/ccsrc/backend/session/cpu_session.cc +++ b/mindspore/ccsrc/backend/session/cpu_session.cc @@ -130,6 +130,32 @@ void CPUSession::SyncValueNodeDeviceAddr(const std::shared_ptr &ker runtime_.SyncValueNodeDeviceAddr(kernel_graph.get()); } +void CPUSession::LoadInputData(const std::shared_ptr &kernel_graph, + const std::vector &inputs_const) const { + MS_EXCEPTION_IF_NULL(kernel_graph); + auto &input_nodes = kernel_graph->inputs(); + if (input_nodes.size() != inputs_const.size()) { + MS_LOG(EXCEPTION) << "Input size not equal to input node size!"; + } + for (size_t input_idx = 0; input_idx < input_nodes.size(); ++input_idx) { + auto &item = input_nodes[input_idx]; + MS_EXCEPTION_IF_NULL(item); + if (item->isa() && !HasAbstractMonad(item)) { + auto address = AnfAlgo::GetMutableOutputAddr(item, 0); + auto tensor = inputs_const[input_idx]; + auto tensor_address = tensor->device_address(); + MS_EXCEPTION_IF_NULL(address); + MS_EXCEPTION_IF_NULL(tensor); + if (tensor_address != nullptr && tensor_address != address && + (std::dynamic_pointer_cast(tensor_address)->DeviceType() != + device::DeviceAddressType::kCPU || + AnfAlgo::IsParameterWeight(item->cast()))) { + tensor->data_sync(false); + } + } + } +} + void CPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) { auto kernel_graph = GetGraph(graph_id); diff --git a/mindspore/ccsrc/backend/session/cpu_session.h b/mindspore/ccsrc/backend/session/cpu_session.h index 559d7aaa76..ae1358c354 100644 --- a/mindspore/ccsrc/backend/session/cpu_session.h +++ b/mindspore/ccsrc/backend/session/cpu_session.h @@ -44,6 +44,8 @@ class CPUSession : public SessionBasic { const std::vector &tensors_mask) override; void RunOpImpl(const GraphInfo &graph_info, OpRunInfo *op_run_info, std::vector *input_tensors, VectorRef *outputs, const std::vector &tensors_mask) override; + void LoadInputData(const std::shared_ptr &kernel_graph, + const std::vector &inputs_const) const override; private: void Reorder(std::vector *node_list); diff --git a/mindspore/ccsrc/backend/session/executor.cc b/mindspore/ccsrc/backend/session/executor.cc index c29f73d792..6ef0506c3f 100644 --- a/mindspore/ccsrc/backend/session/executor.cc +++ b/mindspore/ccsrc/backend/session/executor.cc @@ -161,6 +161,7 @@ void RunGraphTask::Run() { } graph->ResetGraphRunningStatus(); try { + session_->LoadInputs(graph_id_, input_tensors_); session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_); UpdateOutputTensors(&outputs_, tensor_to_node_); } catch (const std::exception &e) { diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index 8dd9d1bdb2..2ccac6ba9d 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -425,8 +425,6 @@ void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vectorPreExecute(kernel_graph, graph_sum_); } diff --git a/mindspore/ccsrc/backend/session/gpu_session.h b/mindspore/ccsrc/backend/session/gpu_session.h index 93cadc5cd8..0cbf16c039 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.h +++ b/mindspore/ccsrc/backend/session/gpu_session.h @@ -47,6 +47,8 @@ class GPUSession : public SessionBasic { VectorRef *outputs, const std::vector &tensors_mask) override; std::shared_ptr CreateBucket(uint32_t bucket_id, uint32_t bucket_size) override; std::string GetCommWorldGroup() override { return kNcclWorldGroup; } + void LoadInputData(const std::shared_ptr &kernel_graph, + const std::vector &inputs_const) const override; private: void SelectKernel(const std::shared_ptr &kernel_graph) const; @@ -71,9 +73,6 @@ class GPUSession : public SessionBasic { void RunOpClearMemory(KernelGraph *kernel_graph) const; - void LoadInputData(const std::shared_ptr &kernel_graph, - const std::vector &inputs_const) const override; - void Execute(const std::shared_ptr &kernel_graph) const; void Dump(const std::shared_ptr &kernel_graph) const; diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index b68e7c267d..c1160a28da 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -180,8 +180,8 @@ std::vector KernelGraph::outputs() const { return std::vector(1, graph_output); } -void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue *visit_queue, - std::unordered_set *visited_nodes, bool comm_first) { +void KernelGraph::EnqueueActiveNodes(const AnfNodePtr &node, std::queue *visit_queue, + std::unordered_set *visited_nodes, bool comm_first) { MS_EXCEPTION_IF_NULL(visit_queue); MS_EXCEPTION_IF_NULL(visited_nodes); auto it = node_output_edges_.find(node); @@ -241,7 +241,7 @@ void KernelGraph::SetExecOrderByDefault() { while (!seed_nodes.empty() || !delay_comm_stack.empty()) { // seed nodes first, then delay comm nodes if (seed_nodes.empty()) { - VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false); + EnqueueActiveNodes(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false); delay_comm_stack.pop(); } else { zero_input_nodes.push(seed_nodes.front()); @@ -272,16 +272,16 @@ void KernelGraph::SetExecOrderByDefault() { } if (optimize_comm) { while (!delay_comm_stack.empty()) { - VisitNodeDescendants(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false); + EnqueueActiveNodes(delay_comm_stack.top(), &communication_descendants, &visited_nodes, false); delay_comm_stack.pop(); } delay_comm_stack.push(node); } else if (is_fused_comm) { delay_comm_stack.push(node); } else if (is_communication_descendant) { - VisitNodeDescendants(node, &communication_descendants, &visited_nodes); + EnqueueActiveNodes(node, &communication_descendants, &visited_nodes); } else { - VisitNodeDescendants(node, &zero_input_nodes, &visited_nodes); + EnqueueActiveNodes(node, &zero_input_nodes, &visited_nodes); } } } diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h index 2ee32dcd41..7b34641959 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -283,8 +283,8 @@ class KernelGraph : public FuncGraph { void SetKernelInfoForNode(const AnfNodePtr &node) const; void ResetInFormat(const AnfNodePtr &node, const std::string &format) const; AnfNodePtr MakeValueNode(const AnfNodePtr &node); - void VisitNodeDescendants(const AnfNodePtr &node, std::queue *visit_queue, - std::unordered_set *visited_nodes, bool comm_first = true); + void EnqueueActiveNodes(const AnfNodePtr &node, std::queue *visit_queue, + std::unordered_set *visited_nodes, bool comm_first = true); // update node edge list void UpdateNodeEdgeList(std::queue *seed_nodes); // add node depend edge by data edge or control depend diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index fde8186804..b346cf7839 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -181,6 +181,13 @@ class SessionBasic : public std::enable_shared_from_this { const std::map &cnode_refcount) {} virtual void SetSummaryNodes(KernelGraph *graph); + void LoadInputs(const GraphId &graph_id, const std::vector &inputs_const) { + auto kernel_graph = GetGraph(graph_id); + MS_EXCEPTION_IF_NULL(kernel_graph); + MS_LOG(INFO) << "Load inputs"; + LoadInputData(kernel_graph, inputs_const); + } + virtual void LoadInputData(const std::shared_ptr &kernel_graph, const std::vector &inputs_const) const; void EraseValueNodeTensor(const std::vector &tensors_mask, std::vector *input_tensors); diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc index 20f291a32f..0dfefce755 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc @@ -283,20 +283,14 @@ void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &ker if (input_nodes.size() != inputs.size()) { MS_LOG(EXCEPTION) << "Input size not equal to input node size!"; } - size_t input_idx = 0; - for (auto &item : input_nodes) { + for (size_t input_idx = 0; input_idx < input_nodes.size(); ++input_idx) { + auto &item = input_nodes[input_idx]; MS_EXCEPTION_IF_NULL(item); if (item->isa() && !HasAbstractMonad(item)) { auto address = AnfAlgo::GetMutableOutputAddr(item, 0); auto tensor = inputs[input_idx]; - auto tensor_address = tensor->device_address(); MS_EXCEPTION_IF_NULL(address); MS_EXCEPTION_IF_NULL(tensor); - if (tensor_address != nullptr && tensor_address != address && - (std::dynamic_pointer_cast(tensor_address)->DeviceType() != DeviceAddressType::kCPU || - AnfAlgo::IsParameterWeight(item->cast()))) { - tensor->data_sync(false); - } if (GetTypeByte(TypeIdToType(tensor->data_type())) == GetTypeByte(TypeIdToType(address->type_id_))) { address->ptr_ = tensor->data_c(); } else { @@ -318,7 +312,6 @@ void CPUKernelRuntime::BindInputTensorAddressPtr(const session::KernelGraph &ker address->ref_count_ = INIT_NODE_REF; tensor->set_device_address(address); } - input_idx++; } }