Add Paddle Enforce

helinwang-patch-1
Yu Yang 7 years ago
parent 833e522d16
commit f385228f05

@ -34,7 +34,7 @@ std::string OpHandleBase::DebugString() const {
OpHandleBase::~OpHandleBase() { OpHandleBase::~OpHandleBase() {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
for (auto &ev : events_) { for (auto &ev : events_) {
cudaEventDestroy(ev.second); PADDLE_ENFORCE(cudaEventDestroy(ev.second));
} }
#endif #endif
} }
@ -44,8 +44,9 @@ void OpHandleBase::Run(bool use_event) {
if (events_.empty() && use_event) { if (events_.empty() && use_event) {
for (auto &p : dev_ctx_) { for (auto &p : dev_ctx_) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device; int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
cudaSetDevice(dev_id); PADDLE_ENFORCE(cudaSetDevice(dev_id));
cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming); PADDLE_ENFORCE(
cudaEventCreateWithFlags(&events_[dev_id], cudaEventDisableTiming));
} }
} }
#else #else
@ -60,7 +61,7 @@ void OpHandleBase::Run(bool use_event) {
int dev_id = boost::get<platform::CUDAPlace>(p.first).device; int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
auto stream = auto stream =
static_cast<platform::CUDADeviceContext *>(p.second)->stream(); static_cast<platform::CUDADeviceContext *>(p.second)->stream();
cudaEventRecord(events_.at(dev_id), stream); PADDLE_ENFORCE(cudaEventRecord(events_.at(dev_id), stream));
} }
} }
#endif #endif

Loading…
Cancel
Save