|
|
|
@ -56,15 +56,15 @@ void OpHandleBase::Run(bool use_event) {
|
|
|
|
|
RunImpl();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OpHandleBase::Wait(platform::DeviceContext *waited_dev) {
|
|
|
|
|
void OpHandleBase::RecordWaitEventOnCtx(platform::DeviceContext *waited_ctx) {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) {
|
|
|
|
|
if (platform::is_cpu_place(waited_ctx->GetPlace()) || events_.empty()) {
|
|
|
|
|
for (auto &dev_ctx : dev_ctxes_) {
|
|
|
|
|
dev_ctx.second->Wait();
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
auto stream =
|
|
|
|
|
static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
|
|
|
|
|
static_cast<platform::CUDADeviceContext *>(waited_ctx)->stream();
|
|
|
|
|
for (auto &ev : events_) {
|
|
|
|
|
PADDLE_ENFORCE(cudaStreamWaitEvent(stream, ev.second, 0));
|
|
|
|
|
}
|
|
|
|
@ -86,6 +86,28 @@ void OpHandleBase::AddOutput(VarHandleBase *out) {
|
|
|
|
|
out->generated_op_ = this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void OpHandleBase::WaitInputVarGenerated() {
|
|
|
|
|
for (auto in_var : inputs_) {
|
|
|
|
|
if (NeedWait(in_var)) {
|
|
|
|
|
for (auto &pair : dev_ctxes_) {
|
|
|
|
|
in_var->generated_op_->RecordWaitEventOnCtx(pair.second);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Same as the no-argument overload, but only the device context registered
// for `place` is made to wait on the ops that generated this op's inputs.
//
// The context is looked up with at() rather than operator[]: operator[]
// would silently insert a null DeviceContext for an unknown place, which
// RecordWaitEventOnCtx would then dereference, and the null entry would
// remain in dev_ctxes_ for every later iteration over it. With at(), an
// unregistered place is reported instead (std::out_of_range, assuming
// dev_ctxes_ is a standard associative container -- TODO confirm).
void OpHandleBase::WaitInputVarGenerated(const platform::Place &place) {
  for (auto *in : inputs_) {
    if (NeedWait(in)) {
      // Lookup stays inside the guard so that ops with nothing to wait on
      // never touch dev_ctxes_, exactly as before.
      in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_.at(place));
    }
  }
}
|
|
|
|
|
|
|
|
|
|
// Returns true when `in_var` is a VarHandle (checked via dynamic_cast) and
// has a non-null generated_op_; returns false otherwise.
bool OpHandleBase::NeedWait(VarHandleBase *in_var) {
  // Anything that is not a VarHandle never needs a wait.
  if (dynamic_cast<VarHandle *>(in_var) == nullptr) {
    return false;
  }
  return in_var->generated_op_ != nullptr;
}
|
|
|
|
|
|
|
|
|
|
void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
if (!events_.empty()) { // Use event
|
|
|
|
|