|
|
|
@ -92,12 +92,22 @@ struct ComputationOpHandle : public OpHandle {
|
|
|
|
|
// Operator executed by Run(); created in the constructor via
// framework::OpRegistry::CreateOp(op_desc).
std::unique_ptr<OperatorBase> op_;
|
|
|
|
|
// Scope handed to op_->Run(). Stored as a raw pointer; this handle never
// deletes it in the code shown here.
Scope *scope_;
|
|
|
|
|
// Place the operator runs on; also the key used to look up this handle's
// device context in dev_ctx_.
platform::Place place_;
|
|
|
|
|
// CUDA event created in the constructor only when place_ is a GPU place.
// Run() records it on this place's stream after the operator has run, and
// Wait() makes another stream wait on it.
cudaEvent_t event_;
|
|
|
|
|
|
|
|
|
|
explicit ComputationOpHandle(const OpDesc &op_desc, Scope *scope,
|
|
|
|
|
platform::Place place)
|
|
|
|
|
: op_(framework::OpRegistry::CreateOp(op_desc)),
|
|
|
|
|
scope_(scope),
|
|
|
|
|
place_(place) {}
|
|
|
|
|
place_(place) {
|
|
|
|
|
if (platform::is_gpu_place(place)) {
|
|
|
|
|
cudaSetDevice(boost::get<platform::CUDAPlace>(place_).device);
|
|
|
|
|
cudaEventCreateWithFlags(&event_, cudaEventDisableTiming);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Releases the CUDA event owned by this handle.
//
// The constructor creates event_ only when the place is a GPU place, so that
// is the only case in which there is anything to destroy.
~ComputationOpHandle() {
  if (platform::is_gpu_place(place_)) {
    // The return code is intentionally not passed to PADDLE_ENFORCE: a
    // destructor must not throw, and there is no recovery possible here.
    cudaEventDestroy(event_);
  }
}
|
|
|
|
|
|
|
|
|
|
void Run() override {
|
|
|
|
|
// Wait other op if necessary
|
|
|
|
@ -113,10 +123,22 @@ struct ComputationOpHandle : public OpHandle {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
op_->Run(*scope_, place_);
|
|
|
|
|
if (platform::is_gpu_place(place_)) {
|
|
|
|
|
auto stream = static_cast<platform::CUDADeviceContext *>(dev_ctx_[place_])
|
|
|
|
|
->stream();
|
|
|
|
|
PADDLE_ENFORCE(cudaEventRecord(event_, stream));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Makes `waited_dev` wait for the work issued by this handle.
//
// - If either side is a CPU place, fall back to calling Wait() on this
//   handle's own device context (presumably a host-blocking wait; confirm
//   against DeviceContext::Wait).
// - Otherwise both sides are treated as GPU: enqueue a wait on `waited_dev`'s
//   stream for event_, which Run() records after the operator has been
//   launched. This orders the two streams without an unconditional
//   device-context Wait() on every call.
//
// `waited_dev` must be non-null. In the GPU branch it is assumed to be a
// platform::CUDADeviceContext (unchecked static_cast).
void Wait(platform::DeviceContext *waited_dev) override {
  if (platform::is_cpu_place(waited_dev->GetPlace()) ||
      platform::is_cpu_place(place_)) {
    this->dev_ctx_.at(place_)->Wait();
  } else {
    auto stream =
        static_cast<platform::CUDADeviceContext *>(waited_dev)->stream();
    // Flags argument must be 0 per the CUDA runtime API.
    PADDLE_ENFORCE(cudaStreamWaitEvent(stream, event_, 0));
  }
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|