randperm run error in multi-gpus (#27942)

swt-req
zhupengyang 4 years ago committed by GitHub
parent 74fadeb44a
commit 6dd64b0a30
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -57,7 +57,7 @@ class RandpermKernel : public framework::OpKernel<T> {
tmp_tensor.Resize(framework::make_ddim({n}));
T* tmp_data = tmp_tensor.mutable_data<T>(platform::CPUPlace());
random_permate<T>(tmp_data, n, seed);
framework::TensorCopy(tmp_tensor, platform::CUDAPlace(), out_tensor);
framework::TensorCopy(tmp_tensor, ctx.GetPlace(), out_tensor);
}
}
};

Loading…
Cancel
Save