diff --git a/paddle/fluid/operators/amp/check_finite_and_unscale_op.cu b/paddle/fluid/operators/amp/check_finite_and_unscale_op.cu
index e28a3c1b6d..6840e4847c 100644
--- a/paddle/fluid/operators/amp/check_finite_and_unscale_op.cu
+++ b/paddle/fluid/operators/amp/check_finite_and_unscale_op.cu
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <cuda.h>
-
 #include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h"
 #include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/platform/float16.h"
diff --git a/paddle/fluid/operators/benchmark/op_tester.cc b/paddle/fluid/operators/benchmark/op_tester.cc
index e01b66b7a1..c8a04c3242 100644
--- a/paddle/fluid/operators/benchmark/op_tester.cc
+++ b/paddle/fluid/operators/benchmark/op_tester.cc
@@ -77,7 +77,7 @@ void OpTester::Run() {
   if (platform::is_cpu_place(place_)) {
     platform::EnableProfiler(platform::ProfilerState::kCPU);
   } else {
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     platform::EnableProfiler(platform::ProfilerState::kAll);
     platform::SetDeviceId(config_.device_id);
 #else
diff --git a/paddle/fluid/operators/collective/CMakeLists.txt b/paddle/fluid/operators/collective/CMakeLists.txt
index 3962f7edf9..8920541b9b 100644
--- a/paddle/fluid/operators/collective/CMakeLists.txt
+++ b/paddle/fluid/operators/collective/CMakeLists.txt
@@ -13,7 +13,7 @@ endforeach()
 register_operators(EXCLUDES c_gen_bkcl_id_op gen_bkcl_id_op c_gen_nccl_id_op
                    gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})
 
-if(WITH_NCCL)
+if(WITH_NCCL OR WITH_RCCL)
   set(COLLECTIVE_DEPS ${COLLECTIVE_DEPS} nccl_common collective_helper)
   op_library(c_gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})
   op_library(gen_nccl_id_op DEPS ${COLLECTIVE_DEPS})
diff --git a/paddle/fluid/operators/collective/allreduce_op.h b/paddle/fluid/operators/collective/allreduce_op.h
index e486faa575..157924f085 100644
--- a/paddle/fluid/operators/collective/allreduce_op.h
+++ b/paddle/fluid/operators/collective/allreduce_op.h
@@ -21,7 +21,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/nccl_helper.h"
 #endif
 
@@ -36,7 +36,7 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
                       platform::errors::PreconditionNotMet(
                           "AllReduce op can run on gpu place only for now."));
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
     auto in = ctx.Input<framework::Tensor>("X");
     auto out = ctx.Output<framework::Tensor>("Out");
@@ -73,7 +73,11 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
         sendbuff, recvbuff, numel, static_cast<ncclDataType_t>(dtype), red_type,
         comm, stream));
     if (ctx.Attr<bool>("sync_mode")) {
+#ifdef PADDLE_WITH_RCCL
+      PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
+#else
       PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
+#endif
     }
 #else
     PADDLE_THROW(platform::errors::PreconditionNotMet(
diff --git a/paddle/fluid/operators/collective/barrier_op.cu.cc b/paddle/fluid/operators/collective/barrier_op.cu.cc
index 81597c0dac..f6281aa8ca 100644
--- a/paddle/fluid/operators/collective/barrier_op.cu.cc
+++ b/paddle/fluid/operators/collective/barrier_op.cu.cc
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/collective/barrier_op.h"
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/nccl_helper.h"
 #endif
@@ -26,7 +26,7 @@ template <typename T>
 class BarrierOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     auto in = ctx.Input<framework::Tensor>("X");
     auto out = ctx.Output<framework::Tensor>("Out");
 
@@ -45,7 +45,11 @@ class BarrierOpCUDAKernel : public framework::OpKernel<T> {
         sendbuff, recvbuff, numel, dtype, nccl_red_type, comm->comm(), stream));
     auto comm_stream =
         platform::NCCLCommContext::Instance().Get(rid, place)->stream();
+#ifdef PADDLE_WITH_RCCL
+    PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(comm_stream));
+#else
     PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(comm_stream));
+#endif
 #else
     PADDLE_THROW(platform::errors::Unavailable(
         "PaddlePaddle should compile with NCCL."));
diff --git a/paddle/fluid/operators/collective/broadcast_op.cu.cc b/paddle/fluid/operators/collective/broadcast_op.cu.cc
index 471474818e..fa4d7ee4cc 100644
--- a/paddle/fluid/operators/collective/broadcast_op.cu.cc
+++ b/paddle/fluid/operators/collective/broadcast_op.cu.cc
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/framework/op_registry.h"
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/nccl_helper.h"
 #endif
 
@@ -33,7 +33,7 @@ class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
                       platform::errors::PreconditionNotMet(
                           "The place of ExecutionContext should be CUDAPlace."));
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     int dev_id = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).device;
     int root_dev_id = ctx.Attr<int>("root");
 
@@ -62,7 +62,11 @@ class NCCLBroadcastOpKernel : public framework::OpKernel<T> {
             << " From " << root_dev_id << " to " << dev_id;
 
     if (ctx.Attr<bool>("sync_mode")) {
+#ifdef PADDLE_WITH_RCCL
+      PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
+#else
       PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
+#endif
     }
 #else
     PADDLE_THROW(platform::errors::PreconditionNotMet(
diff --git a/paddle/fluid/operators/collective/c_allgather_op.cu.cc b/paddle/fluid/operators/collective/c_allgather_op.cu.cc
index 763b695e0c..597e4321d6 100644
--- a/paddle/fluid/operators/collective/c_allgather_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_allgather_op.cu.cc
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/collective/c_allgather_op.h"
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/nccl_helper.h"
 #endif
@@ -26,7 +26,7 @@ template <typename T>
 class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     auto in = ctx.Input<framework::Tensor>("X");
     auto out = ctx.Output<framework::Tensor>("Out");
     ncclDataType_t dtype = platform::ToNCCLDataType(in->type());
@@ -48,7 +48,7 @@ class CAllGatherOpCUDAKernel : public framework::OpKernel<T> {
     const T* send_buff = in->data<T>();
     T* recv_buff = out->data<T>();
 
-    cudaStream_t stream = nullptr;
+    gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
       stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h
index 24f7f427cf..2f56f43d79 100644
--- a/paddle/fluid/operators/collective/c_allreduce_op.h
+++ b/paddle/fluid/operators/collective/c_allreduce_op.h
@@ -20,7 +20,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/nccl_helper.h"
 #endif
@@ -109,7 +109,7 @@ template <ReduceType red_type, typename T>
 class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     auto in = ctx.Input<framework::Tensor>("X");
     auto out = ctx.Output<framework::Tensor>("Out");
 
@@ -123,7 +123,7 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
     int rid = ctx.Attr<int>("ring_id");
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
 
-    cudaStream_t stream = nullptr;
+    gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
       stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
diff --git a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
index b7fc785126..b37bd250c1 100644
--- a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/collective/c_broadcast_op.h"
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/nccl_helper.h"
 #endif
@@ -26,7 +26,7 @@ template <typename T>
 class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     auto x = ctx.Input<framework::LoDTensor>("X");
     auto out = ctx.Output<framework::LoDTensor>("Out");
     int numel = x->numel();
@@ -36,7 +36,7 @@ class CBroadcastOpCUDAKernel : public framework::OpKernel<T> {
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
 
-    cudaStream_t stream = nullptr;
+    gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
       stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
diff --git a/paddle/fluid/operators/collective/c_comm_init_all_op.cc b/paddle/fluid/operators/collective/c_comm_init_all_op.cc
index 7d1bb771ae..60a9b1ee44 100644
--- a/paddle/fluid/operators/collective/c_comm_init_all_op.cc
+++ b/paddle/fluid/operators/collective/c_comm_init_all_op.cc
@@ -17,7 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/threadpool.h"
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/nccl_helper.h"
 #endif
@@ -52,7 +52,7 @@ class CCommInitAllOp : public framework::OperatorBase {
                       platform::errors::PreconditionNotMet(
                           "CCommInitAllOp can run on gpu place only"));
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     std::vector<int> devices = Attr<std::vector<int>>("devices");
     if (devices.empty()) {
       devices = platform::GetSelectedDevices();
diff --git a/paddle/fluid/operators/collective/c_comm_init_op.cc b/paddle/fluid/operators/collective/c_comm_init_op.cc
index c5f172763d..3464bff486 100644
--- a/paddle/fluid/operators/collective/c_comm_init_op.cc
+++ b/paddle/fluid/operators/collective/c_comm_init_op.cc
@@ -14,6 +14,9 @@ limitations under the License. */
 #if defined(PADDLE_WITH_NCCL)
 #include <nccl.h>
 #endif
+#if defined(PADDLE_WITH_RCCL)
+#include <rccl.h>
+#endif
 #if defined(PADDLE_WITH_XPU_BKCL)
 #include "xpu/bkcl.h"
 #endif
@@ -26,7 +29,8 @@ namespace framework {
 class Scope;
 }  // namespace framework
 }  // namespace paddle
 
-#if (defined PADDLE_WITH_NCCL) || (defined PADDLE_WITH_XPU_BKCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) || \
+    defined(PADDLE_WITH_XPU_BKCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #endif
@@ -50,7 +54,7 @@ class CCommInitOp : public framework::OperatorBase {
     PADDLE_ENFORCE_NOT_NULL(
         var, platform::errors::InvalidArgument("Input con not be empty."));
     if (is_gpu_place(place)) {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
       ncclUniqueId* nccl_id = var->GetMutable<ncclUniqueId>();
 
       int nranks = Attr<int>("nranks");
diff --git a/paddle/fluid/operators/collective/c_reduce_op.h b/paddle/fluid/operators/collective/c_reduce_op.h
index 81dc5c35bf..1bce01e13a 100644
--- a/paddle/fluid/operators/collective/c_reduce_op.h
+++ b/paddle/fluid/operators/collective/c_reduce_op.h
@@ -24,7 +24,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/nccl_helper.h"
 #endif
@@ -114,7 +114,7 @@ template <ReduceType red_type, typename T>
 class CReduceOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     auto in = ctx.Input<framework::Tensor>("X");
     auto out = ctx.Output<framework::Tensor>("Out");
 
@@ -129,7 +129,7 @@ class CReduceOpCUDAKernel : public framework::OpKernel<T> {
     int root = ctx.Attr<int>("root_id");
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
 
-    cudaStream_t stream = nullptr;
+    gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
       stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
diff --git a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
index af563d022b..4d19ee4264 100644
--- a/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_reducescatter_op.cu.cc
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/collective/c_reducescatter_op.h"
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/nccl_helper.h"
 #endif
@@ -26,7 +26,7 @@ template <typename T>
 class CReduceScatterOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     auto in = ctx.Input<framework::Tensor>("X");
     auto out = ctx.Output<framework::Tensor>("Out");
 
@@ -49,7 +49,7 @@ class CReduceScatterOpCUDAKernel : public framework::OpKernel<T> {
     T* recv_buff = out->data<T>();
     int dtype = platform::ToNCCLDataType(in->type());
 
-    cudaStream_t stream = nullptr;
+    gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
       stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
diff --git a/paddle/fluid/operators/collective/c_scatter_op.cu.cc b/paddle/fluid/operators/collective/c_scatter_op.cu.cc
index 8d9e6b4b7d..0c9dc2af14 100644
--- a/paddle/fluid/operators/collective/c_scatter_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_scatter_op.cu.cc
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/collective/c_scatter_op.h"
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/nccl_helper.h"
 #endif
@@ -26,7 +26,7 @@ template <typename T>
 class CScatterOpCUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     auto x = ctx.Input<framework::LoDTensor>("X");
     auto out = ctx.Output<framework::LoDTensor>("Out");
     int numel = x->numel();
@@ -53,7 +53,7 @@ class CScatterOpCUDAKernel : public framework::OpKernel<T> {
             "The ring_id (%d) for c_scatter_op must be non-negative.",
             ring_id));
 
-    cudaStream_t stream = nullptr;
+    gpuStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
       stream = static_cast<platform::CUDADeviceContext*>(dev_ctx)->stream();
diff --git a/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc b/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc
index bdffe96acd..c4abe284d7 100644
--- a/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc
+++ b/paddle/fluid/operators/collective/c_sync_calc_stream_op.cc
@@ -37,10 +37,14 @@ class CSyncCalcStreamOp : public framework::OperatorBase {
     PADDLE_ENFORCE_EQ(is_gpu_place(place), true,
                       platform::errors::PreconditionNotMet(
                           "Sync stream op can run on gpu place only for now."));
-#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
+#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && !defined(_WIN32)
     auto dev_ctx = static_cast<platform::CUDADeviceContext*>(
         platform::DeviceContextPool::Instance().Get(place));
+#ifdef PADDLE_WITH_HIP
+    PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(dev_ctx->stream()));
+#else
     PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(dev_ctx->stream()));
+#endif
 #else
     PADDLE_THROW(platform::errors::PreconditionNotMet(
         "PaddlePaddle should compile with GPU."));
diff --git a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc
index aef3d83c90..adf27069f5 100644
--- a/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc
+++ b/paddle/fluid/operators/collective/c_sync_comm_stream_op.cc
@@ -19,7 +19,7 @@ namespace framework {
 class Scope;
 }  // namespace framework
 }  // namespace paddle
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #endif
@@ -40,11 +40,15 @@ class CSyncCommStreamOp : public framework::OperatorBase {
                       platform::errors::PreconditionNotMet(
                           "Sync stream op can run on gpu place only for now."));
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
     int ring_id = Attr<int>("ring_id");
     auto stream =
         platform::NCCLCommContext::Instance().Get(ring_id, place)->stream();
+#ifdef PADDLE_WITH_RCCL
+    PADDLE_ENFORCE_CUDA_SUCCESS(hipStreamSynchronize(stream));
+#else
     PADDLE_ENFORCE_CUDA_SUCCESS(cudaStreamSynchronize(stream));
+#endif
 #else
     PADDLE_THROW(platform::errors::PreconditionNotMet(
         "PaddlePaddle should compile with GPU."));
diff --git a/paddle/fluid/operators/collective/recv_v2_op.cu.cc b/paddle/fluid/operators/collective/recv_v2_op.cu.cc
index 892056f213..5b846598b8 100644
--- a/paddle/fluid/operators/collective/recv_v2_op.cu.cc
+++ b/paddle/fluid/operators/collective/recv_v2_op.cu.cc
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/collective/recv_v2_op.h"
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/nccl_helper.h"
 #endif
@@ -26,7 +26,8 @@ template <typename T>
 class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &ctx) const override {
-#if defined(PADDLE_WITH_NCCL) && NCCL_VERSION_CODE >= 2703
+#if (defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)) && \
+    NCCL_VERSION_CODE >= 2703
     int rid = ctx.Attr<int>("ring_id");
     PADDLE_ENFORCE_GE(
         rid, 0,
@@ -45,7 +46,7 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
     framework::proto::VarType::Type type =
         framework::proto::VarType::Type(data_type);
 
-    cudaStream_t stream = nullptr;
+    gpuStream_t stream = nullptr;
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     if (ctx.Attr<bool>("use_calc_stream")) {
@@ -65,12 +66,21 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
     // Recv the number of elements to receive first
     int numel = 0;
     int *numel_ptr = nullptr;
+#ifdef PADDLE_WITH_RCCL
+    PADDLE_ENFORCE_CUDA_SUCCESS(hipMalloc(&numel_ptr, sizeof(int)));
+#else
     PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc(&numel_ptr, sizeof(int)));
+#endif
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::ncclRecv(static_cast<void *>(numel_ptr), 1, ncclInt,
                                     peer, comm->comm(), stream));
+#ifdef PADDLE_WITH_RCCL
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        hipMemcpy(&numel, numel_ptr, sizeof(int), hipMemcpyDeviceToHost));
+#else
     PADDLE_ENFORCE_CUDA_SUCCESS(
         cudaMemcpy(&numel, numel_ptr, sizeof(int), cudaMemcpyDeviceToHost));
+#endif
 
     int rest_numel = 1;
     for (int i = 1; i < out_dims.size(); ++i) {
diff --git a/paddle/fluid/operators/collective/send_v2_op.cu.cc b/paddle/fluid/operators/collective/send_v2_op.cu.cc
index 4de3f47ccc..b70124a7bf 100644
--- a/paddle/fluid/operators/collective/send_v2_op.cu.cc
+++ b/paddle/fluid/operators/collective/send_v2_op.cu.cc
@@ -14,7 +14,7 @@ limitations under the License. */
 
 #include "paddle/fluid/operators/collective/send_v2_op.h"
 
-#if defined(PADDLE_WITH_NCCL)
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
 #include "paddle/fluid/platform/collective_helper.h"
 #include "paddle/fluid/platform/nccl_helper.h"
 #endif
@@ -26,7 +26,8 @@ template <typename T>
 class SendOpV2CUDAKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-#if defined(PADDLE_WITH_NCCL) && NCCL_VERSION_CODE >= 2703
+#if (defined(PADDLE_WITH_RCCL) || defined(PADDLE_WITH_NCCL)) && \
+    NCCL_VERSION_CODE >= 2703
     auto x = ctx.Input<framework::LoDTensor>("X");
     int numel = x->numel();
 
@@ -41,7 +42,7 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
         peer, 0,
         platform::errors::InvalidArgument(
            "The peer (%d) for send_v2 op must be non-negative.", peer));
-    cudaStream_t stream = nullptr;
+    gpuStream_t stream = nullptr;
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
     if (ctx.Attr<bool>("use_calc_stream")) {
@@ -59,9 +60,15 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
     // Send number of elements to the receiver, as the receiver may have
     // no information of the Tensor size.
     int* numel_ptr = nullptr;
+#ifdef PADDLE_WITH_RCCL
+    PADDLE_ENFORCE_CUDA_SUCCESS(hipMalloc(&numel_ptr, sizeof(int)));
+    PADDLE_ENFORCE_CUDA_SUCCESS(
+        hipMemcpy(numel_ptr, &numel, sizeof(int), hipMemcpyHostToDevice));
+#else
     PADDLE_ENFORCE_CUDA_SUCCESS(cudaMalloc(&numel_ptr, sizeof(int)));
     PADDLE_ENFORCE_CUDA_SUCCESS(
         cudaMemcpy(numel_ptr, &numel, sizeof(int), cudaMemcpyHostToDevice));
+#endif
 
     PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclSend(
         numel_ptr, 1, ncclInt, peer, comm->comm(), stream));
diff --git a/paddle/fluid/operators/detail/strided_memcpy.h b/paddle/fluid/operators/detail/strided_memcpy.h
index e29b057ed5..7df0f85523 100644
--- a/paddle/fluid/operators/detail/strided_memcpy.h
+++ b/paddle/fluid/operators/detail/strided_memcpy.h
@@ -34,7 +34,7 @@ struct StridedMemcpyFunctor<T, 0> {
       auto& cpu_place = BOOST_GET_CONST(platform::CPUPlace, place);
       memory::Copy(cpu_place, dst, cpu_place, src, sizeof(T));
     } else {
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
       auto& gpu_place = BOOST_GET_CONST(platform::CUDAPlace, place);
      auto& cuda_ctx =
           reinterpret_cast<const platform::CUDADeviceContext&>(dev_ctx);
@@ -58,7 +58,7 @@ struct StridedMemcpyFunctor<T, 1> {
      auto& cpu_place = BOOST_GET_CONST(platform::CPUPlace, place);
       memory::Copy(cpu_place, dst, cpu_place, src, sizeof(T) * dst_dim[0]);
     } else {
-#ifdef PADDLE_WITH_CUDA
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
       auto& gpu_place = BOOST_GET_CONST(platform::CUDAPlace, place);
       auto& cuda_ctx =
           reinterpret_cast<const platform::CUDADeviceContext&>(dev_ctx);