diff --git a/mindspore/ccsrc/pre_activate/common/helper.cc b/mindspore/ccsrc/pre_activate/common/helper.cc index be3bb533ce..290fd24b59 100644 --- a/mindspore/ccsrc/pre_activate/common/helper.cc +++ b/mindspore/ccsrc/pre_activate/common/helper.cc @@ -384,7 +384,7 @@ tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple) { bool IsNopNode(const AnfNodePtr &node) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); - if (context_ptr->device_target() != kAscendDevice) { + if (context_ptr->device_target() != kAscendDevice && context_ptr->device_target() != kGPUDevice) { return false; } static std::unordered_set<std::string> nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName, diff --git a/mindspore/ccsrc/pre_activate/common/helper.h b/mindspore/ccsrc/pre_activate/common/helper.h index fb55c120be..ead338d0af 100644 --- a/mindspore/ccsrc/pre_activate/common/helper.h +++ b/mindspore/ccsrc/pre_activate/common/helper.h @@ -154,6 +154,8 @@ tensor::TensorPtr CreateTensorWithValueTuple(const ValueTuplePtr &value_tuple_pt tensor::TensorPtr CreateTupleTensor(const ValueTuplePtr &value_tuple); +bool IsAllNopNode(const session::KernelGraph *const graph); + bool IsNopNode(const AnfNodePtr &node); void HideNopNode(session::KernelGraph *const graph); diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc index 61d79b56d5..a08c228220 100644 --- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc +++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc @@ -18,6 +18,8 @@ #include <algorithm> #include <memory> #include "pre_activate/mem_reuse/mem_reuse_checker.h" +#include "pre_activate/common/helper.h" + namespace mindspore { namespace memreuse { bool MemReuseUtil::InitDynamicOutputKernelRef() { @@ -324,9 +326,17 @@ void MemReuseUtil::SetSummaryNodesRefCount() { } void MemReuseUtil::SetGraphOutputRefCount() { + auto is_all_nop_node = opt::IsAllNopNode(graph_); auto nodes =
AnfAlgo::GetAllOutput(graph_->output(), {prim::kPrimTupleGetItem}); for (const auto &node : nodes) { - auto kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0); + session::KernelWithIndex kernel_input; + if (is_all_nop_node) { + // The graph does not remove the nop node. + kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0, false); + } else { + // The graph removes the nop node. + kernel_input = AnfAlgo::VisitKernelWithReturnType(node, 0, true); + } MS_EXCEPTION_IF_NULL(kernel_input.first); if (!kernel_input.first->isa<CNode>() || !AnfAlgo::IsRealKernel(kernel_input.first)) { continue; diff --git a/mindspore/ccsrc/session/gpu_session.cc b/mindspore/ccsrc/session/gpu_session.cc index 5bc815de8a..999c5ba163 100644 --- a/mindspore/ccsrc/session/gpu_session.cc +++ b/mindspore/ccsrc/session/gpu_session.cc @@ -75,7 +75,6 @@ void GPUSession::AllocateMemory(KernelGraph *kernel_graph) const { MS_EXCEPTION_IF_NULL(kernel_graph); auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); MS_EXCEPTION_IF_NULL(runtime_instance); - // opt::RemoveNopNode(kernel_graph); runtime_instance->AssignMemory(kernel_graph); } @@ -84,7 +83,6 @@ void GPUSession::RunOpAllocateMemory(const std::vector<tensor::TensorPtr> &input MS_EXCEPTION_IF_NULL(kernel_graph); auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); MS_EXCEPTION_IF_NULL(runtime_instance); - // opt::RemoveNopNode(kernel_graph); runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph); } @@ -156,14 +154,16 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList Optimize(graph); // Assign CUDA streams AssignStream(graph); - // Remove NoOp from execution graph - // opt::HideNopNode(graph.get()); + // Hide NoOp from execution graph + opt::HideNopNode(graph.get()); // Build kernel if node is cnode BuildKernel(graph); // Set graph execution order before memory alloc, ensure that memory alloc is
according to the reorder graph auto execution_order = graph->execution_order(); Reorder(&execution_order); graph->set_execution_order(execution_order); + // Remove NoOp from execution graph + opt::RemoveNopNode(graph.get()); // Alloc memory, including static memory and dynamic memory AllocateMemory(graph.get()); MS_EXCEPTION_IF_NULL(context_); @@ -205,6 +205,8 @@ void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_in MS_EXCEPTION_IF_NULL(kernel_graph); SelectKernel(kernel_graph); StartKernelRT(); + // Hide NoOp from execution graph + opt::HideNopNode(kernel_graph.get()); BuildKernel(kernel_graph); run_op_graphs_[graph_info] = kernel_graph; } @@ -213,6 +215,8 @@ py::tuple GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph const std::vector<tensor::TensorPtr> &input_tensors) { auto kernel_graph = run_op_graphs_[graph_info]; MS_EXCEPTION_IF_NULL(kernel_graph); + // Remove NoOp from execution graph + opt::RemoveNopNode(kernel_graph.get()); RunOpAllocateMemory(input_tensors, kernel_graph.get()); // Execute the computation LoadInputData(kernel_graph, input_tensors);