|  |  |  | @ -40,9 +40,11 @@ const std::map<std::string, NcclKernelType> kNcclTypeMap = { | 
			
		
	
		
			
				
					|  |  |  |  | static std::map<std::string, ncclDataType_t> kNcclDtypeMap = { | 
			
		
	
		
			
				
					|  |  |  |  |   {"kNumberTypeFloat32", ncclFloat}, {"kNumberTypeFloat16", ncclHalf}, {"kNumberTypeInt32", ncclInt}}; | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | typedef ncclResult_t (*AllReduce)(const void *, void *, size_t, ncclDataType_t, ncclRedOp_t, cudaStream_t); | 
			
		
	
		
			
				
					|  |  |  |  | typedef ncclResult_t (*AllGather)(const void *, void *, size_t, ncclDataType_t, cudaStream_t); | 
			
		
	
		
			
				
					|  |  |  |  | typedef ncclResult_t (*ReduceScatter)(const void *, void *, size_t, ncclDataType_t, ncclRedOp_t, cudaStream_t); | 
			
		
	
		
			
				
					|  |  |  |  | typedef ncclResult_t (*AllReduce)(const void *, void *, size_t, ncclDataType_t, ncclRedOp_t, cudaStream_t, | 
			
		
	
		
			
				
					|  |  |  |  |                                   const std::string &); | 
			
		
	
		
			
				
					|  |  |  |  | typedef ncclResult_t (*AllGather)(const void *, void *, size_t, ncclDataType_t, cudaStream_t, const std::string &); | 
			
		
	
		
			
				
					|  |  |  |  | typedef ncclResult_t (*ReduceScatter)(const void *, void *, size_t, ncclDataType_t, ncclRedOp_t, cudaStream_t, | 
			
		
	
		
			
				
					|  |  |  |  |                                       const std::string &); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  | template <typename T> | 
			
		
	
		
			
				
					|  |  |  |  | class NcclGpuKernel : public GpuKernel { | 
			
		
	
	
		
			
				
					|  |  |  | @ -50,6 +52,7 @@ class NcclGpuKernel : public GpuKernel { | 
			
		
	
		
			
				
					|  |  |  |  |   NcclGpuKernel() | 
			
		
	
		
			
				
					|  |  |  |  |       : nccl_kernel_type_(NCCL_INVALID_TYPE), | 
			
		
	
		
			
				
					|  |  |  |  |         nccl_reduce_type_(ncclSum), | 
			
		
	
		
			
				
					|  |  |  |  |         group_name_(""), | 
			
		
	
		
			
				
					|  |  |  |  |         input_size_(0), | 
			
		
	
		
			
				
					|  |  |  |  |         output_size_(0), | 
			
		
	
		
			
				
					|  |  |  |  |         collective_handle_(nullptr), | 
			
		
	
	
		
			
				
					|  |  |  | @ -71,7 +74,7 @@ class NcclGpuKernel : public GpuKernel { | 
			
		
	
		
			
				
					|  |  |  |  |           reinterpret_cast<AllReduce>(dlsym(const_cast<void *>(collective_handle_), "AllReduce")); | 
			
		
	
		
			
				
					|  |  |  |  |         MS_EXCEPTION_IF_NULL(all_reduce_funcptr); | 
			
		
	
		
			
				
					|  |  |  |  |         CHECK_NCCL_RET_WITH_EXCEPT((*all_reduce_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), | 
			
		
	
		
			
				
					|  |  |  |  |                                                          nccl_data_type_, nccl_reduce_type_, stream), | 
			
		
	
		
			
				
					|  |  |  |  |                                                          nccl_data_type_, nccl_reduce_type_, stream, group_name_), | 
			
		
	
		
			
				
					|  |  |  |  |                                    "ncclAllReduce failed"); | 
			
		
	
		
			
				
					|  |  |  |  |         break; | 
			
		
	
		
			
				
					|  |  |  |  |       } | 
			
		
	
	
		
			
				
					|  |  |  | @ -80,7 +83,7 @@ class NcclGpuKernel : public GpuKernel { | 
			
		
	
		
			
				
					|  |  |  |  |           reinterpret_cast<AllGather>(dlsym(const_cast<void *>(collective_handle_), "AllGather")); | 
			
		
	
		
			
				
					|  |  |  |  |         MS_EXCEPTION_IF_NULL(all_gather_funcptr); | 
			
		
	
		
			
				
					|  |  |  |  |         CHECK_NCCL_RET_WITH_EXCEPT( | 
			
		
	
		
			
				
					|  |  |  |  |           (*all_gather_funcptr)(input_addr, output_addr, input_size_ / sizeof(T), nccl_data_type_, stream), | 
			
		
	
		
			
				
					|  |  |  |  |           (*all_gather_funcptr)(input_addr, output_addr, input_size_ / sizeof(T), nccl_data_type_, stream, group_name_), | 
			
		
	
		
			
				
					|  |  |  |  |           "ncclAllGather failed"); | 
			
		
	
		
			
				
					|  |  |  |  |         break; | 
			
		
	
		
			
				
					|  |  |  |  |       } | 
			
		
	
	
		
			
				
					|  |  |  | @ -89,7 +92,7 @@ class NcclGpuKernel : public GpuKernel { | 
			
		
	
		
			
				
					|  |  |  |  |           reinterpret_cast<ReduceScatter>(dlsym(const_cast<void *>(collective_handle_), "ReduceScatter")); | 
			
		
	
		
			
				
					|  |  |  |  |         MS_EXCEPTION_IF_NULL(reduce_scatter_funcptr); | 
			
		
	
		
			
				
					|  |  |  |  |         CHECK_NCCL_RET_WITH_EXCEPT((*reduce_scatter_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), | 
			
		
	
		
			
				
					|  |  |  |  |                                                              nccl_data_type_, nccl_reduce_type_, stream), | 
			
		
	
		
			
				
					|  |  |  |  |                                                              nccl_data_type_, nccl_reduce_type_, stream, group_name_), | 
			
		
	
		
			
				
					|  |  |  |  |                                    "ncclReduceScatter failed"); | 
			
		
	
		
			
				
					|  |  |  |  |         break; | 
			
		
	
		
			
				
					|  |  |  |  |       } | 
			
		
	
	
		
			
				
					|  |  |  | @ -121,15 +124,18 @@ class NcclGpuKernel : public GpuKernel { | 
			
		
	
		
			
				
					|  |  |  |  |       output_size_list_.push_back(size); | 
			
		
	
		
			
				
					|  |  |  |  |       output_size_ += size; | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  |     InferCommType(kernel_node); | 
			
		
	
		
			
				
					|  |  |  |  |     collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); | 
			
		
	
		
			
				
					|  |  |  |  |     MS_EXCEPTION_IF_NULL(collective_handle_); | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     InferCommType(kernel_node); | 
			
		
	
		
			
				
					|  |  |  |  |     group_name_ = GetAttr<std::string>(kernel_node, kAttrGroup); | 
			
		
	
		
			
				
					|  |  |  |  |     MS_LOG(INFO) << AnfAlgo::GetCNodeName(kernel_node) << " for group " << group_name_; | 
			
		
	
		
			
				
					|  |  |  |  |     auto comm_stream_attr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stream_id"); | 
			
		
	
		
			
				
					|  |  |  |  |     if (comm_stream_attr) { | 
			
		
	
		
			
				
					|  |  |  |  |       comm_stream_ = reinterpret_cast<cudaStream_t>(GetValue<uintptr_t>(comm_stream_attr)); | 
			
		
	
		
			
				
					|  |  |  |  |       MS_EXCEPTION_IF_NULL(comm_stream_); | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle(); | 
			
		
	
		
			
				
					|  |  |  |  |     MS_EXCEPTION_IF_NULL(collective_handle_); | 
			
		
	
		
			
				
					|  |  |  |  |     return true; | 
			
		
	
		
			
				
					|  |  |  |  |   } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
	
		
			
				
					|  |  |  | @ -146,7 +152,7 @@ class NcclGpuKernel : public GpuKernel { | 
			
		
	
		
			
				
					|  |  |  |  |       nccl_kernel_type_ = iter->second; | 
			
		
	
		
			
				
					|  |  |  |  |     } | 
			
		
	
		
			
				
					|  |  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |  |     auto reduce_op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("op"); | 
			
		
	
		
			
				
					|  |  |  |  |     auto reduce_op = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr(kAttrOp); | 
			
		
	
		
			
				
					|  |  |  |  |     if (reduce_op) { | 
			
		
	
		
			
				
					|  |  |  |  |       std::string type = GetValue<std::string>(reduce_op); | 
			
		
	
		
			
				
					|  |  |  |  |       if (type == "sum") { | 
			
		
	
	
		
			
				
					|  |  |  | @ -167,6 +173,7 @@ class NcclGpuKernel : public GpuKernel { | 
			
		
	
		
			
				
					|  |  |  |  |   NcclKernelType nccl_kernel_type_; | 
			
		
	
		
			
				
					|  |  |  |  |   ncclRedOp_t nccl_reduce_type_; | 
			
		
	
		
			
				
					|  |  |  |  |   ncclDataType_t nccl_data_type_; | 
			
		
	
		
			
				
					|  |  |  |  |   std::string group_name_; | 
			
		
	
		
			
				
					|  |  |  |  |   size_t input_size_; | 
			
		
	
		
			
				
					|  |  |  |  |   size_t output_size_; | 
			
		
	
		
			
				
					|  |  |  |  |   std::vector<size_t> input_size_list_; | 
			
		
	
	
		
			
				
					|  |  |  | 
 |