diff --git a/mindspore/ccsrc/device/gpu/gpu_stream_assign.cc b/mindspore/ccsrc/device/gpu/gpu_stream_assign.cc index 3594081cc7..42cdcf29ec 100644 --- a/mindspore/ccsrc/device/gpu/gpu_stream_assign.cc +++ b/mindspore/ccsrc/device/gpu/gpu_stream_assign.cc @@ -14,6 +14,7 @@ * limitations under the License. */ +#include "device/gpu/gpu_stream_assign.h" #include #include #include @@ -21,7 +22,6 @@ #include "device/gpu/gpu_common.h" #include "device/gpu/kernel_info_setter.h" #include "device/gpu/gpu_device_manager.h" -#include "device/gpu/gpu_stream_assign.h" namespace mindspore { namespace device { @@ -36,18 +36,19 @@ void AssignGpuStream(const std::shared_ptr &kernel_graph) allreduce_kernels.emplace_back(kernel_node); } else { DeviceStream compute_stream = GPUDeviceManager::GetInstance().default_stream(); - AnfAlgo::SetNodeAttr("stream_id", MakeValue(reinterpret_cast(compute_stream)), kernel_node); + MS_EXCEPTION_IF_NULL(compute_stream); + AnfAlgo::SetNodeAttr(kAttrStreamId, MakeValue(reinterpret_cast(compute_stream)), kernel_node); } } if (allreduce_kernels.size() > 1) { - // Assign multiple streams only when there's Recv node for AllReduce. + // Assign multiple streams only when there're multiple AllReduce nodes. std::vector send_recv_pairs; if (FindAllReduceStreamSwitchPos(kernel_graph, &send_recv_pairs)) { DeviceStream comm_stream = nullptr; GPUDeviceManager::GetInstance().CreateStream(&comm_stream); std::transform( allreduce_kernels.begin(), allreduce_kernels.end(), allreduce_kernels.begin(), [&](CNodePtr allreduce_kernel) { - AnfAlgo::SetNodeAttr("stream_id", MakeValue(reinterpret_cast(comm_stream)), allreduce_kernel); + AnfAlgo::SetNodeAttr(kAttrStreamId, MakeValue(reinterpret_cast(comm_stream)), allreduce_kernel); return allreduce_kernel; }); InsertStreamSwitchNode(kernel_graph, send_recv_pairs); @@ -161,25 +162,28 @@ bool GenSendRecvCNodesForAllReduce(const std::shared_ptr & cudaEvent_t event = nullptr; CHECK_CUDA_RET_WITH_EXCEPT(cudaEventCreate(&event, cudaEventDisableTiming), "Creating cuda event failed."); - AnfAlgo::SetNodeAttr("record_event", MakeValue(reinterpret_cast(event)), *send_node); - AnfAlgo::SetNodeAttr("wait_event", MakeValue(reinterpret_cast(event)), *recv_node); + AnfAlgo::SetNodeAttr(kAttrRecordEvent, MakeValue(reinterpret_cast(event)), *send_node); + AnfAlgo::SetNodeAttr(kAttrWaitEvent, MakeValue(reinterpret_cast(event)), *recv_node); - uintptr_t send_stream = AnfAlgo::GetNodeAttr(mock_send_node, "stream_id"); - AnfAlgo::SetNodeAttr("record_event_stream", MakeValue(send_stream), *send_node); - uintptr_t recv_stream = AnfAlgo::GetNodeAttr(mock_recv_node, "stream_id"); - AnfAlgo::SetNodeAttr("wait_event_stream", MakeValue(recv_stream), *recv_node); + uintptr_t send_stream = AnfAlgo::GetNodeAttr(mock_send_node, kAttrStreamId); + AnfAlgo::SetNodeAttr(kAttrRecordEventStream, MakeValue(send_stream), *send_node); + uintptr_t recv_stream = AnfAlgo::GetNodeAttr(mock_recv_node, kAttrStreamId); + AnfAlgo::SetNodeAttr(kAttrWaitEventStream, MakeValue(recv_stream), *recv_node); return true; } CNodePtr CreateStreamSwitchNode(const std::shared_ptr &kernel_graph, const std::string &name) { auto op = std::make_shared(name); + MS_EXCEPTION_IF_NULL(op); auto apply = std::make_shared(op); + MS_EXCEPTION_IF_NULL(apply); std::vector input_list = {apply}; CNodePtr node = kernel_graph->NewCNode(input_list); MS_EXCEPTION_IF_NULL(node); kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder; AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), node.get()); auto abstract_none = std::make_shared(); + MS_EXCEPTION_IF_NULL(abstract_none); node->set_abstract(abstract_none); SetKernelInfo(node); return node; diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h index a6d9b7e32a..d2c5225ab2 100644 --- a/mindspore/ccsrc/utils/utils.h +++ b/mindspore/ccsrc/utils/utils.h @@ -202,6 +202,11 @@ constexpr auto kAttrLabelSwitchList = "label_switch_list"; constexpr auto kAttrNewAxisMask = "new_axis_mask"; constexpr auto kAttrShrinkAxisMask = "shrink_axis_mask"; constexpr auto kAttrDatadumpOriginalNames = "_datadump_original_names"; +constexpr auto kAttrStreamId = "stream_id"; +constexpr auto kAttrRecordEvent = "record_event"; +constexpr auto kAttrWaitEvent = "wait_event"; +constexpr auto kAttrRecordEventStream = "record_event_stream"; +constexpr auto kAttrWaitEventStream = "wait_event_stream"; // attr value constexpr auto kValueTargetSwitch = "target_switch";