|
|
@ -109,7 +109,6 @@ struct OpHandle {
|
|
|
|
|
|
|
|
|
|
|
|
virtual void Wait(platform::DeviceContext *waited_dev) {
|
|
|
|
virtual void Wait(platform::DeviceContext *waited_dev) {
|
|
|
|
if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) {
|
|
|
|
if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) {
|
|
|
|
VLOG(4) << "I am here";
|
|
|
|
|
|
|
|
for (auto &dev_ctx : dev_ctx_) {
|
|
|
|
for (auto &dev_ctx : dev_ctx_) {
|
|
|
|
dev_ctx.second->Wait();
|
|
|
|
dev_ctx.second->Wait();
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -255,7 +254,7 @@ struct FetchOpHandle : public OpHandle {
|
|
|
|
|
|
|
|
|
|
|
|
class ParallelExecutorPrivate {
|
|
|
|
class ParallelExecutorPrivate {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
explicit ParallelExecutorPrivate(size_t num_threads = 12)
|
|
|
|
explicit ParallelExecutorPrivate(size_t num_threads = 0)
|
|
|
|
: pool_(num_threads == 0 ? nullptr : new ThreadPool(num_threads)) {}
|
|
|
|
: pool_(num_threads == 0 ? nullptr : new ThreadPool(num_threads)) {}
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<platform::Place> places_;
|
|
|
|
std::vector<platform::Place> places_;
|
|
|
@ -397,8 +396,6 @@ struct NCCLAllReduceOpHandle : public OpHandle {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
platform::dynload::ncclGroupEnd();
|
|
|
|
platform::dynload::ncclGroupEnd();
|
|
|
|
PADDLE_ENFORCE(cudaDeviceSynchronize());
|
|
|
|
PADDLE_ENFORCE(cudaDeviceSynchronize());
|
|
|
|
|
|
|
|
|
|
|
|
VLOG(3) << "After NCCL";
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|