diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc index 528db5b76d..6b49b4b878 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.cc +++ b/mindspore/ccsrc/backend/session/kernel_graph.cc @@ -961,18 +961,40 @@ void KernelGraph::PrintGraphExecuteOrder() const { } } -void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node) { +void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node, int output_idx, + bool unique_target) { if (front_node == nullptr || node == nullptr) { MS_LOG(INFO) << "Front node or node is nullptr"; return; } MS_LOG(INFO) << "Add internal node " << node->DebugString() << " with front node " << front_node->DebugString(); front_to_internal_outputs_map_[front_node] = node; - int output_idx = 0; if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) { output_idx = AnfAlgo::GetTupleGetItemOutIndex(front_node->cast()); } - internal_outputs_to_front_map_[node][output_idx] = front_node; + internal_outputs_to_front_map_[node][output_idx] = std::pair(front_node, unique_target); +} + +void KernelGraph::AddInternalOutputTensor(const AnfNodePtr &node, int output_idx, const tensor::TensorPtr &tensor) { + if (node == nullptr) { + return; + } + internal_outputs_tensor_map_[node][output_idx] = tensor; +} + +tensor::TensorPtr KernelGraph::GetInternalOutputTensor(const AnfNodePtr &node, int output_idx) { + if (node == nullptr) { + return nullptr; + } + auto iter = internal_outputs_tensor_map_.find(node); + if (iter == internal_outputs_tensor_map_.end()) { + return nullptr; + } + auto idx_iter = iter->second.find(output_idx); + if (idx_iter == iter->second.end()) { + return nullptr; + } + return idx_iter->second; } void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx, @@ -996,7 +1018,7 @@ void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr if (src_output_idx == -1) { internal_outputs_to_front_map_[new_node] = front_nodes; for (const auto &front_node_iter : front_nodes) { - front_to_internal_outputs_map_[front_node_iter.second] = new_node; + front_to_internal_outputs_map_[front_node_iter.second.first] = new_node; } internal_outputs_to_front_map_.erase(iter); return; @@ -1008,9 +1030,9 @@ void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr MS_LOG(INFO) << "The output " << src_output_idx << " of node " << node->DebugString() << " is not an internal node"; return; } - auto front_node = front_node_iter->second; - internal_outputs_to_front_map_[new_node][dst_output_idx] = front_node; - front_to_internal_outputs_map_[front_node] = new_node; + auto front_node_pair = front_node_iter->second; + internal_outputs_to_front_map_[new_node][dst_output_idx] = front_node_pair; + front_to_internal_outputs_map_[front_node_pair.first] = new_node; front_nodes.erase(index); if (front_nodes.empty()) { internal_outputs_to_front_map_.erase(iter); @@ -1027,16 +1049,30 @@ AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_nod bool KernelGraph::IsInternalOutput(const AnfNodePtr &node, int output_idx) const { auto front_nodes_iter = internal_outputs_to_front_map_.find(node); - if (front_nodes_iter != internal_outputs_to_front_map_.end()) { - if (output_idx == -1) { - return true; - } - auto &front_nodes = front_nodes_iter->second; - if (front_nodes.find(output_idx) != front_nodes.end()) { - return true; - } + if (front_nodes_iter == internal_outputs_to_front_map_.end()) { + return false; } - return false; + if (output_idx == -1) { + return true; + } + auto &front_nodes = front_nodes_iter->second; + if (front_nodes.find(output_idx) == front_nodes.end()) { + return false; + } + return true; +} + +bool KernelGraph::IsUniqueTargetInternalOutput(const AnfNodePtr &node, int output_idx) const { + auto front_nodes_iter = internal_outputs_to_front_map_.find(node); + if (front_nodes_iter == internal_outputs_to_front_map_.end()) { + return false; + } + auto &front_nodes = front_nodes_iter->second; + auto idx_iter = front_nodes.find(output_idx); + if (idx_iter == front_nodes.end()) { + return false; + } + return idx_iter->second.second; } void KernelGraph::UpdateChildGraphOrder() { diff --git a/mindspore/ccsrc/backend/session/kernel_graph.h b/mindspore/ccsrc/backend/session/kernel_graph.h index 25cea907fb..047c21ea20 100644 --- a/mindspore/ccsrc/backend/session/kernel_graph.h +++ b/mindspore/ccsrc/backend/session/kernel_graph.h @@ -143,11 +143,16 @@ class KernelGraph : public FuncGraph { void PrintGraphExecuteOrder() const; const std::map> &summary_nodes() const { return summary_nodes_; } void set_summary_nodes(const std::map> &nodes) { summary_nodes_ = nodes; } - void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node); + void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node, int output_idx = 0, + bool unique_target = false); void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx = -1, int dst_output_idx = -1); AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const; bool IsInternalOutput(const AnfNodePtr &node, int output_idx = -1) const; + bool IsUniqueTargetInternalOutput(const AnfNodePtr &node, int output_idx) const; + void AddInternalOutputTensor(const AnfNodePtr &node, int output_idx, const tensor::TensorPtr &tensor); + tensor::TensorPtr GetInternalOutputTensor(const AnfNodePtr &node, int output_idx); + uint32_t current_epoch() const { return current_epoch_; } void set_current_epoch(uint32_t epoch) { current_epoch_ = epoch; } void UpdateChildGraphOrder(); @@ -217,7 +222,8 @@ class KernelGraph : public FuncGraph { CNodePtr end_goto_; bool null_output_; std::unordered_map front_to_internal_outputs_map_; - std::unordered_map> internal_outputs_to_front_map_; + std::unordered_map>> internal_outputs_to_front_map_; + std::unordered_map> internal_outputs_tensor_map_; uint32_t current_epoch_; }; } // namespace session diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 5a4f443388..21ff3180e3 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -58,51 +58,38 @@ ValuePtr GetParamDefaultValue(const AnfNodePtr &node) { return parameter->default_param(); } -BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraph &graph, - const std::vector &input_tensors) { +tensor::TensorPtr CreateOutputTensor(const AnfNodePtr &node, size_t output_index, const KernelGraphPtr &graph, + const DeviceAddressPtr &address) { MS_EXCEPTION_IF_NULL(node); - MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]"; - // if node is a value node, no need sync addr from device to host - if (!AnfAlgo::OutputAddrExist(node, output_index)) { - if (node->isa()) { - auto value_node = node->cast(); - MS_EXCEPTION_IF_NULL(value_node); - return value_node->value(); - } - if (node->isa()) { - for (size_t input_idx = 0; input_idx < graph.inputs().size(); input_idx++) { - if (input_idx >= input_tensors.size()) { - MS_LOG(EXCEPTION) << "Input idx:" << input_idx << "out of range:" << input_tensors.size(); - } - if (graph.inputs()[input_idx] == node) { - return input_tensors[input_idx]; - } - } - MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << " has no output addr"; - } - } - // if proccess reach here,it remarks item_with_index is a real node(Parameter,or executable CNode) - auto address = AnfAlgo::GetMutableOutputAddr(node, output_index); - MS_EXCEPTION_IF_NULL(address); - auto shape = AnfAlgo::GetOutputInferShape(node, output_index); - TypeId type_id = kNumberTypeFloat32; - type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index); + MS_EXCEPTION_IF_NULL(graph); + TypeId type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index); if (type_id == kTypeUnknown) { type_id = AnfAlgo::GetOutputInferDataType(node, output_index); } + tensor::TensorPtr tensor; std::vector temp_shape; - if (graph.IsInternalOutput(node, output_index)) { + if (graph->IsUniqueTargetInternalOutput(node, output_index)) { temp_shape.emplace_back(1); - tensor::TensorPtr tensor = std::make_shared(type_id, temp_shape); + tensor = std::make_shared(type_id, temp_shape); tensor->set_device_address(address); tensor->set_dirty(false); return tensor; } - (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape)); - tensor::TensorPtr tensor = std::make_shared(type_id, temp_shape); + + tensor = graph->GetInternalOutputTensor(node, output_index); + if (tensor == nullptr) { + auto shape = AnfAlgo::GetOutputInferShape(node, output_index); + (void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape)); + tensor = std::make_shared(type_id, temp_shape); + bool is_internal_output = graph->IsInternalOutput(node, output_index); + if (is_internal_output) { + graph->AddInternalOutputTensor(node, output_index, tensor); + } + } // if in paynative mode,data only copyed to host when user want to print data auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); + MS_EXCEPTION_IF_NULL(address); if (ms_context->execution_mode() == kPynativeMode || ms_context->device_target() == kGPUDevice) { tensor->set_device_address(address); tensor->set_dirty(false); @@ -114,7 +101,35 @@ BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const Kerne return tensor; } -BaseRef CreateTensorForOutput(const AnfNodePtr &anf, const KernelGraph &graph, +BaseRef CreateOneTensor(const AnfNodePtr &node, size_t output_index, const KernelGraphPtr &graph, + const std::vector &input_tensors) { + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(graph); + MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]"; + // if node is a value node, no need sync addr from device to host + if (!AnfAlgo::OutputAddrExist(node, output_index)) { + if (node->isa()) { + auto value_node = node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + return value_node->value(); + } + if (node->isa()) { + for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) { + if (input_idx >= input_tensors.size()) { + MS_LOG(EXCEPTION) << "Input idx:" << input_idx << "out of range:" << input_tensors.size(); + } + if (graph->inputs()[input_idx] == node) { + return input_tensors[input_idx]; + } + } + MS_LOG(EXCEPTION) << "Parameter : " << node->DebugString() << " has no output addr"; + } + } + auto address = AnfAlgo::GetMutableOutputAddr(node, output_index); + return CreateOutputTensor(node, output_index, graph, address); +} + +BaseRef CreateTensorForOutput(const AnfNodePtr &anf, const KernelGraphPtr &graph, const std::vector &input_tensors) { MS_EXCEPTION_IF_NULL(anf); MS_LOG(INFO) << "Create tensor for output[" << anf->DebugString() << "]"; @@ -308,7 +323,7 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const auto real_kernel = AnfAlgo::VisitKernel(ref_node, output_idx); auto ref_real_node = real_kernel.first; auto ref_real_node_index = real_kernel.second; - if (ref_real_node->isa() && node_graph->IsInternalOutput(ref_real_node, ref_real_node_index)) { + if (ref_real_node->isa() && node_graph->IsUniqueTargetInternalOutput(ref_real_node, ref_real_node_index)) { auto kernel_info = ref_real_node->kernel_info(); if (kernel_info == nullptr || !kernel_info->has_build_info()) { MS_LOG(INFO) << "No kernel info"; @@ -888,7 +903,7 @@ void SessionBasic::UpdateOutputs(const std::shared_ptr &kernel_grap for (auto &item : anf_outputs) { MS_EXCEPTION_IF_NULL(item); MS_LOG(INFO) << "Update output[" << item->DebugString() << "]"; - outputs->emplace_back(CreateTensorForOutput(item, *kernel_graph, input_tensors)); + outputs->emplace_back(CreateTensorForOutput(item, kernel_graph, input_tensors)); } } @@ -967,6 +982,71 @@ void SessionBasic::Summary(KernelGraph *graph) { summary_callback_(0, params_list); } +namespace { +bool CNodePrimIsValueNode(const AnfNodePtr &node) { + if (node == nullptr) { + return false; + } + auto cnode = node->cast(); + if (cnode == nullptr) { + return false; + } + auto prim = cnode->input(kAnfPrimitiveIndex); + if (prim == nullptr || !prim->isa()) { + return false; + } + return true; +} + +void HandleInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &backend_node, + const FuncGraphManagerPtr &front_func_graph_manager, + const std::shared_ptr &backend_graph) { + auto node_users = front_func_graph_manager->node_users(); + auto users = node_users[front_node]; + auto front_real_kernel_pair = AnfAlgo::VisitKernel(front_node, 0); + auto backend_real_kernel_pair = AnfAlgo::VisitKernel(backend_node, 0); + + auto front_real_kernel = front_real_kernel_pair.first; + std::string kernel_target = GetCNodeTarget(front_real_kernel); + bool internal_output = CNodePrimIsValueNode(front_real_kernel); + bool unique_target = true; + if (internal_output && opt::IsNopNode(front_real_kernel)) { + auto pre_node_pair = AnfAlgo::GetPrevNodeOutput(front_real_kernel, 0); + auto pre_node_target = GetCNodeTarget(pre_node_pair.first); + if (pre_node_target != kernel_target) { + unique_target = false; + } + } + if (internal_output) { + for (auto user : users) { + auto cnode = user.first->cast(); + if (cnode == nullptr) { + internal_output = false; + break; + } + auto prim = cnode->input(kAnfPrimitiveIndex); + if (prim == nullptr || !prim->isa()) { + internal_output = false; + break; + } + if (!AnfAlgo::IsRealKernel(user.first)) { + internal_output = false; + break; + } + if (kernel_target != GetCNodeTarget(user.first)) { + unique_target = false; + } + } + } + if (internal_output) { + MS_LOG(INFO) << "Internal output: " << front_node->DebugString() << "To " + << backend_real_kernel_pair.first->DebugString(); + backend_graph->AddInternalOutput(front_node, backend_real_kernel_pair.first, backend_real_kernel_pair.second, + unique_target); + } +} +} // namespace + CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr &graph) { MS_EXCEPTION_IF_NULL(graph); std::vector output_args; @@ -982,9 +1062,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std: if (context_ptr->execution_mode() == kPynativeMode) { return backend_anf; } - auto front_real_kernel_pair = AnfAlgo::VisitKernel(out, 0); - auto front_real_kernel = front_real_kernel_pair.first; - auto backend_real_kernel_pair = AnfAlgo::VisitKernel(backend_anf, 0); + MS_EXCEPTION_IF_NULL(out); auto out_func_graph = out->func_graph(); MS_EXCEPTION_IF_NULL(out_func_graph); @@ -992,51 +1070,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std: if (out_func_graph_manager == nullptr) { return backend_anf; } - auto node_users = out_func_graph_manager->node_users(); - auto users = node_users[out]; - bool internal_output = true; - std::string kernel_target = GetCNodeTarget(front_real_kernel); - if (front_real_kernel != nullptr && front_real_kernel->isa()) { - auto front_cnode = front_real_kernel->cast(); - if (front_cnode != nullptr) { - auto prim = front_cnode->input(kAnfPrimitiveIndex); - if (prim == nullptr || !prim->isa()) { - internal_output = false; - } - } else { - internal_output = false; - } - } - if (internal_output && opt::IsNopNode(front_real_kernel)) { - auto pre_node_pair = AnfAlgo::GetPrevNodeOutput(front_real_kernel, 0); - auto pre_node_target = GetCNodeTarget(pre_node_pair.first); - if (pre_node_target != kernel_target) { - internal_output = false; - } - } - if (internal_output) { - for (auto user : users) { - auto cnode = user.first->cast(); - if (cnode == nullptr) { - internal_output = false; - break; - } - auto prim = cnode->input(kAnfPrimitiveIndex); - if (prim == nullptr || !prim->isa()) { - internal_output = false; - break; - } - if (!AnfAlgo::IsRealKernel(user.first) || kernel_target != GetCNodeTarget(user.first)) { - internal_output = false; - break; - } - } - } - if (internal_output) { - MS_LOG(INFO) << "Internal output: " << out->DebugString() << "To " - << backend_real_kernel_pair.first->DebugString(); - graph->AddInternalOutput(out, backend_real_kernel_pair.first); - } + HandleInternalOutput(out, backend_anf, out_func_graph_manager, graph); return backend_anf; } MS_LOG(EXCEPTION) << "Can't find the node in the equiv map!"; diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc index 5af81eadbf..6904c6a0be 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.cc @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include #include "backend/kernel_compiler/kernel.h" #include "runtime/device/cpu/cpu_device_address.h" @@ -124,11 +124,10 @@ DeviceAddressPtr CPUKernelRuntime::CreateDeviceAddress(void *device_ptr, size_t return std::make_shared(device_ptr, device_size, format, type_id); } -tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(const CNodePtr &node, size_t index, - std::set *bound_addresses, +tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node, + size_t index, std::vector *need_sync_outputs) { MS_EXCEPTION_IF_NULL(node); - MS_EXCEPTION_IF_NULL(bound_addresses); MS_EXCEPTION_IF_NULL(need_sync_outputs); size_t output_size = AnfAlgo::GetOutputTensorNum(node); if (index >= output_size) { @@ -136,14 +135,21 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(const CNodePtr &node, s } auto address = AnfAlgo::GetMutableOutputAddr(node, index); MS_EXCEPTION_IF_NULL(address); - auto shape = AnfAlgo::GetOutputInferShape(node, index); - std::vector temp_shape; - (void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end()); + TypeId infer_type_id = AnfAlgo::GetOutputInferDataType(node, index); TypeId device_type_id = AnfAlgo::GetOutputDeviceDataType(node, index); - tensor::TensorPtr tensor = std::make_shared(infer_type_id, temp_shape); - MS_EXCEPTION_IF_NULL(tensor); - if (bound_addresses->find(address) != bound_addresses->end()) { + tensor::TensorPtr tensor = kernel_graph->GetInternalOutputTensor(node, index); + if (tensor == nullptr) { + auto shape = AnfAlgo::GetOutputInferShape(node, index); + std::vector temp_shape; + (void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end()); + tensor = std::make_shared(infer_type_id, temp_shape); + bool is_internal_output = kernel_graph->IsInternalOutput(node, index); + if (is_internal_output) { + kernel_graph->AddInternalOutputTensor(node, index, tensor); + } + } + if (bound_addresses_.find(address) != bound_addresses_.end()) { tensor->set_device_address(address); need_sync_outputs->emplace_back(tensor); } else { @@ -159,15 +165,14 @@ tensor::TensorPtr CPUKernelRuntime::CreatTensorForOutput(const CNodePtr &node, s address->ptr_ = tensor->data_c(); } address->ref_count_ = INIT_NODE_REF; - (void)bound_addresses->insert(address); + (void)bound_addresses_.insert(address); } tensor->set_dirty(false); return tensor; } -BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &kernel_with_index, - const std::unordered_map &input_map, - std::set *bound_addresses, +BaseRef CPUKernelRuntime::CreatTensorForOutput(session::KernelGraph *kernel_graph, + const session::KernelWithIndex &kernel_with_index, std::vector *need_sync_outputs) { auto &input_node = kernel_with_index.first; auto index = kernel_with_index.second; @@ -179,15 +184,15 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &k VectorRef ret; for (size_t i = 1; i < node->inputs().size(); i++) { auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node->input(i), 0); - auto out = CreatTensorForOutput(item_with_index, input_map, bound_addresses, need_sync_outputs); + auto out = CreatTensorForOutput(kernel_graph, item_with_index, need_sync_outputs); ret.push_back(out); } return ret; } - return CreatTensorForOutput(node, index, bound_addresses, need_sync_outputs); + return CreatTensorForOutput(kernel_graph, node, index, need_sync_outputs); } else if (input_node->isa()) { - auto iter = input_map.find(input_node.get()); - if (iter != input_map.end()) { + auto iter = input_param_tensor_map_.find(input_node); + if (iter != input_param_tensor_map_.end()) { return iter->second; } } else if (input_node->isa()) { @@ -197,10 +202,8 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(const session::KernelWithIndex &k } return BaseRef(); } - -void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph, - const std::vector &inputs, VectorRef *outputs, - std::vector *need_sync_outputs) { +void CPUKernelRuntime::BindInputOutput(session::KernelGraph *kernel_graph, const std::vector &inputs, + VectorRef *outputs, std::vector *need_sync_outputs) { MS_EXCEPTION_IF_NULL(kernel_graph); MS_EXCEPTION_IF_NULL(outputs); // bind input ptr @@ -208,11 +211,11 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph, if (input_nodes.size() != inputs.size()) { MS_LOG(EXCEPTION) << "Input size not equal to input node size!"; } - std::unordered_map input_map; + input_param_tensor_map_.clear(); size_t input_idx = 0; for (auto &item : input_nodes) { MS_EXCEPTION_IF_NULL(item); - input_map[item.get()] = inputs[input_idx]; + input_param_tensor_map_[item] = inputs[input_idx]; if (item->isa()) { auto address = AnfAlgo::GetMutableOutputAddr(item, 0); auto tensor = inputs[input_idx]; @@ -222,7 +225,6 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph, if (tensor_address != nullptr && tensor_address != address) { (void)tensor->data_sync(); } - if (tensor->data_type() == address->type_id_ || tensor->data_type() == kNumberTypeFloat32 || tensor->data_type() == kNumberTypeInt32) { address->ptr_ = tensor->data_c(); @@ -243,11 +245,11 @@ void CPUKernelRuntime::BindInputOutput(const session::KernelGraph *kernel_graph, input_idx++; } // new output and bind ptr - std::set bound_addresses; + bound_addresses_.clear(); auto output_nodes = kernel_graph->outputs(); for (const auto &item : output_nodes) { auto item_with_index = AnfAlgo::VisitKernelWithReturnType(item, 0, true); - auto out = CreatTensorForOutput(item_with_index, input_map, &bound_addresses, need_sync_outputs); + auto out = CreatTensorForOutput(kernel_graph, item_with_index, need_sync_outputs); outputs->push_back(std::move(out)); } } diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h index a486ab1a8b..e391332f85 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h @@ -19,7 +19,7 @@ #include #include #include -#include +#include #include #include "runtime/device/kernel_runtime.h" #include "backend/session/kernel_graph.h" @@ -38,7 +38,7 @@ class CPUKernelRuntime : public KernelRuntime { bool Init() override { return true; } bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override; void AssignKernelAddress(session::KernelGraph *kernel_graph); - void BindInputOutput(const session::KernelGraph *kernel_graph, const std::vector &inputs, + void BindInputOutput(session::KernelGraph *kernel_graph, const std::vector &inputs, VectorRef *outputs, std::vector *need_sync_outputs); void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); @@ -49,19 +49,18 @@ class CPUKernelRuntime : public KernelRuntime { TypeId type_id) override; private: - tensor::TensorPtr CreatTensorForOutput(const CNodePtr &node, size_t index, - std::set *bound_addresses, + tensor::TensorPtr CreatTensorForOutput(session::KernelGraph *kernel_graph, const CNodePtr &node, size_t index, std::vector *need_sync_outputs); - BaseRef CreatTensorForOutput(const session::KernelWithIndex &kernel_with_index, - const std::unordered_map &input_map, - std::set *bound_addresses, + BaseRef CreatTensorForOutput(session::KernelGraph *kernel_graph, const session::KernelWithIndex &kernel_with_index, std::vector *need_sync_outputs); void AssignValueNodeAddress(session::KernelGraph *kernel_graph); void AssignInputNodeAddress(const session::KernelGraph *kernel_graph); void AssignKernelOutputAddress(const session::KernelGraph *kernel_graph); void AddRuntimeAddress(DeviceAddress *address, std::vector *input_list); CPUResourceManager resource_manager_; + std::set bound_addresses_; + std::map input_param_tensor_map_; }; } // namespace cpu } // namespace device