diff --git a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc
index 1ae1b111bc..8095a503e3 100644
--- a/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc
+++ b/mindspore/ccsrc/device/gpu/gpu_kernel_runtime.cc
@@ -228,7 +228,7 @@ void GPUKernelRuntime::ClearKernelOutputAddress(const session::KernelGraph *grap
         continue;
       }
-      auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i);
+      auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
       if (device_address->ptr_) {
         mem_manager_->FreeMemFromMemPool(device_address);
       }
@@ -289,7 +289,7 @@ bool GPUKernelRuntime::AddMemSwapTask(const AnfNodePtr &kernel) {
   for (auto &mem_swap_info : mem_swap_info_list) {
     auto &kernel_exec_info = mem_swap_manager_->SearchKernelExecutionInfo(mem_swap_info.kernel_);
     const HostAddress &host_address = kernel_exec_info.host_addrs_[mem_swap_info.output_idx_];
-    auto device_address = AnfAlgo::GetMutableOutputAddr(mem_swap_info.kernel_, mem_swap_info.output_idx_);
+    auto device_address = AnfAlgo::GetMutableOutputAddr(mem_swap_info.kernel_, mem_swap_info.output_idx_, false);
     if (mem_swap_info.swap_kind_ == SwapKind::kDeviceToHost) {
       mem_swap_manager_->AddMemSwapTask(SwapKind::kDeviceToHost, device_address, host_address);
@@ -379,7 +379,8 @@ bool GPUKernelRuntime::AllocKernelInputDynamicRes(const mindspore::AnfNodePtr &k
   MS_EXCEPTION_IF_NULL(kernel_inputs);
   MS_EXCEPTION_IF_NULL(mem_swap_manager_);
   for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
-    auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
+    // Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
+    auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false);
     MS_EXCEPTION_IF_NULL(device_address);
     if (mem_swap_manager_->trigger_swap()) {
       while (auto device_address_swap_in = mem_swap_manager_->UpdateSwapQueue(SwapKind::kHostToDevice)) {
@@ -437,7 +438,7 @@ bool GPUKernelRuntime::AllocKernelOutputDynamicRes(const mindspore::kernel::Kern
   }
   auto output_sizes = kernel_mod.GetOutputSizeList();
   for (size_t i = 0; i < output_sizes.size(); ++i) {
-    auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i);
+    auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
     MS_EXCEPTION_IF_NULL(device_address);
     if (device_address->ptr_ == nullptr && !AttemptMallocMem(device_address, output_sizes[i])) {
       return false;
@@ -495,7 +496,7 @@ void GPUKernelRuntime::AllocCommunicationOpInputDynamicRes(const mindspore::AnfN
   std::vector<size_t> size_list;
   DeviceAddressPtrList addr_list;
   for (size_t i = 0; i < AnfAlgo::GetInputTensorNum(kernel); ++i) {
-    auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
+    auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false);
     MS_EXCEPTION_IF_NULL(device_address);
     if (device_address->ptr_ == nullptr) {
       is_need_alloc_memory = true;
@@ -520,7 +521,7 @@ void GPUKernelRuntime::AllocCommunicationOpOutputDynamicRes(const mindspore::Anf
   MS_EXCEPTION_IF_NULL(kernel_mod);
   auto output_sizes = kernel_mod->GetOutputSizeList();
   for (size_t i = 0; i < output_sizes.size(); ++i) {
-    auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i);
+    auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
     MS_EXCEPTION_IF_NULL(device_address);
     if (device_address->ptr_ == nullptr) {
       is_need_alloc_memory = true;
@@ -578,7 +579,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
       MS_LOG(EXCEPTION) << "Check dynamic reference count failed.";
     }
     if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) {
-      auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i);
+      auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel, i, false);
       mem_manager_->FreeMemFromMemPool(device_address);
       device_address->set_status(DeviceAddressStatus::kInDevice);
     }
@@ -590,7 +591,7 @@ void GPUKernelRuntime::FreeKernelDynamicRes(const mindspore::AnfNodePtr &kernel,
       continue;
     }
     if (kernel_ref_count_ptr->ref_count_dynamic_use_ == 0) {
-      auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i);
+      auto device_address = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
       mem_manager_->FreeMemFromMemPool(device_address);
       device_address->set_status(DeviceAddressStatus::kInDevice);
     }
diff --git a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc
index 1ee2254c01..b36b397d1e 100644
--- a/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc
+++ b/mindspore/ccsrc/pre_activate/mem_reuse/mem_reuse.cc
@@ -228,7 +228,8 @@ KernelRefCountPtr MemReuseUtil::GetKernelInputRef(const CNodePtr &kernel, size_t
                       << AnfAlgo::GetInputTensorNum(kernel);
   }
   auto input_node = kernel->input(input_idx + 1);
-  auto kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
+  // Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
+  auto kernel_input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
   if (IsPrimitive(kernel_input.first, prim::kPrimMakeTuple)) {
     MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << input_idx << " is MakeTuple";
   }
@@ -269,7 +270,8 @@ void MemReuseUtil::SetKernelDefInputs() {
     if (ref_ptr != nullptr) {
       // set the inputs of this kernel_def
       auto input_node = AnfAlgo::GetInputNode(kernel, i);
-      auto input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, true);
+      // Graph may be all nop nodes and not remove nop node, so this can not skip nop node.
+      auto input = AnfAlgo::VisitKernelWithReturnType(input_node, 0, false);
       if (IsPrimitive(input.first, prim::kPrimMakeTuple)) {
         MS_LOG(EXCEPTION) << "Input node [" << input_node->DebugString() << "]'s input " << i << " is MakeTuple";
       }
diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/session/anf_runtime_algorithm.cc
index 09ea32becb..836110b8a4 100644
--- a/mindspore/ccsrc/session/anf_runtime_algorithm.cc
+++ b/mindspore/ccsrc/session/anf_runtime_algorithm.cc
@@ -544,9 +544,10 @@ TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &an
 }
 
 // get output device addr of anf_node
-const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx) {
+const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx,
+                                                        bool visit_nop_node) {
   MS_EXCEPTION_IF_NULL(node);
-  if (opt::IsNopNode(node)) {
+  if (opt::IsNopNode(node) && visit_nop_node) {
     auto cnode = node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
     if (cnode->inputs().size() == 2) {
@@ -565,9 +566,10 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node,
   return addr;
 }
 
-DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx) {
+DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx,
+                                                           bool visit_nop_node) {
   MS_EXCEPTION_IF_NULL(node);
-  if (opt::IsNopNode(node)) {
+  if (opt::IsNopNode(node) && visit_nop_node) {
     auto cnode = node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
     if (cnode->inputs().size() == 2) {
@@ -598,14 +600,16 @@ bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_
   return kernel_info->OutputAddrExist(output_idx);
 }
 
-const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) {
+const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
+                                                                bool visit_nop_node) {
   KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
-  return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second);
+  return AnfRuntimeAlgorithm::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node);
 }
 
-DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) {
+DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
+                                                                   bool visit_nop_node) {
   KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
-  return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second);
+  return AnfRuntimeAlgorithm::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, visit_nop_node);
 }
 
 // set output device addr of anf_node
diff --git a/mindspore/ccsrc/session/anf_runtime_algorithm.h b/mindspore/ccsrc/session/anf_runtime_algorithm.h
index bab867a3ef..223917ceec 100644
--- a/mindspore/ccsrc/session/anf_runtime_algorithm.h
+++ b/mindspore/ccsrc/session/anf_runtime_algorithm.h
@@ -121,14 +121,16 @@ class AnfRuntimeAlgorithm {
   // get output select data type from prev node,input_index is the input index of current node related to prev node
   static TypeId GetPrevNodeOutputDeviceDataType(const AnfNodePtr &node, size_t input_idx);
   // get output device addr of anf_node
-  static const DeviceAddress *GetOutputAddr(const AnfNodePtr &node, size_t output_idx);
+  static const DeviceAddress *GetOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true);
   // get mutable output device addr of anf_node
-  static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx);
+  static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx, bool visit_nop_node = true);
   // check whether output addr is exist or not
   static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx);
   // get address from prev node,input_index is the input index of current node related to prev node
-  static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx);
-  static DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx);
+  static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx,
+                                                    bool visit_nop_node = true);
+  static DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx,
+                                                       bool visit_nop_node = true);
   // set output device addr of anf_node
   static void SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node);
   // set workspace device addr of anf_node
diff --git a/tests/st/ops/gpu/test_flatten_op.py b/tests/st/ops/gpu/test_flatten_op.py
index 3d8ba96b7f..504d7c06b3 100644
--- a/tests/st/ops/gpu/test_flatten_op.py
+++ b/tests/st/ops/gpu/test_flatten_op.py
@@ -31,6 +31,49 @@ class NetFlatten(nn.Cell):
         return self.flatten(x)
 
 
+class NetAllFlatten(nn.Cell):
+    def __init__(self):
+        super(NetAllFlatten, self).__init__()
+        self.flatten = P.Flatten()
+
+    def construct(self, x):
+        loop_count = 4
+        while loop_count > 0:
+            x = self.flatten(x)
+            loop_count = loop_count - 1
+        return x
+
+
+class NetFirstFlatten(nn.Cell):
+    def __init__(self):
+        super(NetFirstFlatten, self).__init__()
+        self.flatten = P.Flatten()
+        self.relu = P.ReLU()
+
+    def construct(self, x):
+        loop_count = 4
+        while loop_count > 0:
+            x = self.flatten(x)
+            loop_count = loop_count - 1
+        x = self.relu(x)
+        return x
+
+
+class NetLastFlatten(nn.Cell):
+    def __init__(self):
+        super(NetLastFlatten, self).__init__()
+        self.flatten = P.Flatten()
+        self.relu = P.ReLU()
+
+    def construct(self, x):
+        loop_count = 4
+        x = self.relu(x)
+        while loop_count > 0:
+            x = self.flatten(x)
+            loop_count = loop_count - 1
+        return x
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -46,3 +89,55 @@ def test_flatten():
     flatten = NetFlatten()
     output = flatten(x)
     assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_all_flatten():
+    x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32))
+    expect = np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32)
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    flatten = NetAllFlatten()
+    output = flatten(x)
+    assert (output.asnumpy() == expect).all()
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    flatten = NetAllFlatten()
+    output = flatten(x)
+    assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_first_flatten():
+    x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32))
+    expect = np.array([[0, 0.3, 3.6], [0.4, 0.5, 0]]).astype(np.float32)
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    flatten = NetFirstFlatten()
+    output = flatten(x)
+    assert (output.asnumpy() == expect).all()
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    flatten = NetFirstFlatten()
+    output = flatten(x)
+    assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_last_flatten():
+    x = Tensor(np.array([[-0.1, 0.3, 3.6], [0.4, 0.5, -3.2]]).astype(np.float32))
+    expect = np.array([[0, 0.3, 3.6], [0.4, 0.5, 0]]).astype(np.float32)
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    flatten = NetLastFlatten()
+    output = flatten(x)
+    assert (output.asnumpy() == expect).all()
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    flatten = NetLastFlatten()
+    output = flatten(x)
+    assert (output.asnumpy() == expect).all()
+
\ No newline at end of file
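
Note (illustration, not part of the patch): a minimal standalone C++ sketch of what the new visit_nop_node flag controls. The Node struct, device_ptr field and GetOutputAddr helper below are invented for this example and are not MindSpore APIs; only the meaning of the flag mirrors the change above. When nop nodes are kept in the executed graph (for example a graph made entirely of Flatten ops, as the new tests exercise), an address lookup must not see through them, because the nop node's own output owns device memory.

// sketch.cc -- illustrative only; build with: g++ -std=c++11 sketch.cc
#include <cassert>
#include <memory>

// A toy graph node: a "nop" node (shape-only op such as Flatten) forwards data
// from its single input, but when it is kept in the executed graph it still
// owns a buffer for its own output.
struct Node {
  bool is_nop = false;
  std::shared_ptr<Node> input;
  void *device_ptr = nullptr;
};

// visit_nop_node == true : treat nop nodes as transparent and return the
//                          real producer's address (the previous behaviour).
// visit_nop_node == false: return the nop node's own address, which is what
//                          a runtime that still launches the nop node needs.
void *GetOutputAddr(const std::shared_ptr<Node> &node, bool visit_nop_node) {
  if (visit_nop_node && node->is_nop && node->input != nullptr) {
    return GetOutputAddr(node->input, visit_nop_node);
  }
  return node->device_ptr;
}

int main() {
  int producer_buf = 0;
  int flatten_buf = 0;

  auto producer = std::make_shared<Node>();
  producer->device_ptr = &producer_buf;

  auto flatten = std::make_shared<Node>();  // nop node that was not removed
  flatten->is_nop = true;
  flatten->input = producer;
  flatten->device_ptr = &flatten_buf;

  // Seeing through the nop node reuses the producer's memory ...
  assert(GetOutputAddr(flatten, /*visit_nop_node=*/true) == &producer_buf);
  // ... while passing false keeps the nop node's own buffer, matching the
  // new GetMutableOutputAddr(kernel, i, false) call sites in the GPU runtime.
  assert(GetOutputAddr(flatten, /*visit_nop_node=*/false) == &flatten_buf);
  return 0;
}

Keeping the default visit_nop_node = true in anf_runtime_algorithm.h preserves existing call sites; only the GPU dynamic-memory and mem-reuse paths opt out.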