|
|
|
@ -174,10 +174,11 @@ void NCCLTester::testNcclAllReduceOp() {
|
|
|
|
|
result_tensor->Resize(kDims);
|
|
|
|
|
auto *ct = result_tensor->mutable_data<float>(cpu_place);
|
|
|
|
|
|
|
|
|
|
paddle::memory::Copy(
|
|
|
|
|
cpu_place, ct, p::CUDAPlace(gpu_list_[i]), rt,
|
|
|
|
|
recv_tensor.numel() * sizeof(float),
|
|
|
|
|
static_cast<p::CUDADeviceContext *>(dev_ctxs_[i])->stream());
|
|
|
|
|
auto *dev_ctx = static_cast<p::CUDADeviceContext *>(dev_ctxs_[i]);
|
|
|
|
|
paddle::memory::Copy(cpu_place, ct, p::CUDAPlace(gpu_list_[i]), rt,
|
|
|
|
|
recv_tensor.numel() * sizeof(float),
|
|
|
|
|
dev_ctx->stream());
|
|
|
|
|
dev_ctx->Wait();
|
|
|
|
|
|
|
|
|
|
for (int64_t j = 0; j < f::product(kDims); ++j) {
|
|
|
|
|
ASSERT_NEAR(ct[j], expected_result, 1e-5);
|
|
|
|
@ -272,10 +273,10 @@ void NCCLTester::testNcclBcastOp() {
|
|
|
|
|
result_tensor->Resize(kDims);
|
|
|
|
|
auto *ct = result_tensor->mutable_data<float>(cpu_place);
|
|
|
|
|
|
|
|
|
|
paddle::memory::Copy(
|
|
|
|
|
cpu_place, ct, p::CUDAPlace(gpu_list_[idx]), rt,
|
|
|
|
|
recv_tensor.numel() * sizeof(float),
|
|
|
|
|
static_cast<p::CUDADeviceContext *>(dev_ctxs_[idx])->stream());
|
|
|
|
|
auto *dev_ctx = static_cast<p::CUDADeviceContext *>(dev_ctxs_[idx]);
|
|
|
|
|
paddle::memory::Copy(cpu_place, ct, p::CUDAPlace(gpu_list_[idx]), rt,
|
|
|
|
|
recv_tensor.numel() * sizeof(float), dev_ctx->stream());
|
|
|
|
|
dev_ctx->Wait();
|
|
|
|
|
|
|
|
|
|
for (int64_t j = 0; j < f::product(kDims); ++j) {
|
|
|
|
|
ASSERT_NEAR(ct[j], result, 1e-5);
|
|
|
|
@ -288,13 +289,9 @@ TEST_F(NCCLTester, ncclInitOp) {}
|
|
|
|
|
TEST_F(NCCLTester, ncclOp) {
|
|
|
|
|
// Serial execution is required for the same nccl comm.
|
|
|
|
|
|
|
|
|
|
// ncclAllReduceOp with desc
|
|
|
|
|
// TODO(helin): https://github.com/PaddlePaddle/Paddle/issues/9367
|
|
|
|
|
testNcclReduceOp();
|
|
|
|
|
|
|
|
|
|
testNcclAllReduceOp();
|
|
|
|
|
|
|
|
|
|
// ncclBcastOp with desc
|
|
|
|
|
// TODO(helin): https://github.com/PaddlePaddle/Paddle/issues/9540
|
|
|
|
|
testNcclBcastOp();
|
|
|
|
|
}
|
|
|
|
|