|
|
|
@ -31,7 +31,7 @@ namespace platform {
|
|
|
|
|
TEST(NCCL, init) {
|
|
|
|
|
std::vector<ncclComm_t> comms;
|
|
|
|
|
comms.resize(dev_count);
|
|
|
|
|
PADDLE_ENFORCE(dynload::ncclCommInitAll(comms.data(), dev_count, nullptr));
|
|
|
|
|
dynload::ncclCommInitAll(comms.data(), dev_count, nullptr);
|
|
|
|
|
for (int i = 0; i < dev_count; ++i) {
|
|
|
|
|
dynload::ncclCommDestroy(comms[i]);
|
|
|
|
|
}
|
|
|
|
@ -62,7 +62,7 @@ TEST(NCCL, all_reduce) {
|
|
|
|
|
std::vector<ncclComm_t> comms;
|
|
|
|
|
comms.resize(dev_count);
|
|
|
|
|
VLOG(1) << "Initializing ncclComm";
|
|
|
|
|
PADDLE_ENFORCE(dynload::ncclCommInitAll(comms.data(), dev_count, nullptr));
|
|
|
|
|
dynload::ncclCommInitAll(comms.data(), dev_count, nullptr);
|
|
|
|
|
VLOG(1) << "ncclComm initialized";
|
|
|
|
|
VLOG(1) << "Creating thread data";
|
|
|
|
|
std::vector<std::unique_ptr<PerThreadData<double>>> data;
|
|
|
|
|