|
|
|
@ -26,6 +26,7 @@ template <typename T>
|
|
|
|
|
class SendOpV2CUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
#if defined(PADDLE_WITH_NCCL) && NCCL_VERSION_CODE >= 2703
|
|
|
|
|
auto x = ctx.Input<framework::LoDTensor>("X");
|
|
|
|
|
int numel = x->numel();
|
|
|
|
|
|
|
|
|
@ -42,7 +43,6 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
|
|
|
|
|
"The peer (%d) for send_v2 op must be non-negative.", peer));
|
|
|
|
|
cudaStream_t stream = nullptr;
|
|
|
|
|
auto place = ctx.GetPlace();
|
|
|
|
|
#if defined(PADDLE_WITH_NCCL) && NCCL_VERSION_CODE >= 2703
|
|
|
|
|
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
|
|
|
|
|
if (ctx.Attr<bool>("use_calc_stream")) {
|
|
|
|
|
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
|
|
|
|
|