bug fix, test=develop (#28648)

musl/fix_failed_unittests_in_musl
lilong12 5 years ago committed by GitHub
parent 8f2656ef5c
commit b2f7ab6636
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -26,6 +26,7 @@ template <typename T>
class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
#if defined(PADDLE_WITH_NCCL) && NCCL_VERSION_CODE >= 2703
int rid = ctx.Attr<int>("ring_id");
PADDLE_ENFORCE_GE(
rid, 0,
@ -44,7 +45,6 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
framework::proto::VarType::Type type =
framework::proto::VarType::Type(data_type);
#if defined(PADDLE_WITH_NCCL) && NCCL_VERSION_CODE >= 2703
cudaStream_t stream = nullptr;
auto place = ctx.GetPlace();
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);

@ -26,6 +26,7 @@ template <typename T>
class SendOpV2CUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#if defined(PADDLE_WITH_NCCL) && NCCL_VERSION_CODE >= 2703
auto x = ctx.Input<framework::LoDTensor>("X");
int numel = x->numel();
@ -42,7 +43,6 @@ class SendOpV2CUDAKernel : public framework::OpKernel<T> {
"The peer (%d) for send_v2 op must be non-negative.", peer));
cudaStream_t stream = nullptr;
auto place = ctx.GetPlace();
#if defined(PADDLE_WITH_NCCL) && NCCL_VERSION_CODE >= 2703
auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
if (ctx.Attr<bool>("use_calc_stream")) {
auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);

Loading…
Cancel
Save