|
|
|
@ -42,7 +42,7 @@ OpHandleBase::~OpHandleBase() {
|
|
|
|
|
void OpHandleBase::Run(bool use_event) {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
if (events_.empty() && use_event) {
|
|
|
|
|
for (auto &p : dev_ctx_) {
|
|
|
|
|
for (auto &p : dev_ctxes_) {
|
|
|
|
|
int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
|
|
|
|
|
PADDLE_ENFORCE(cudaSetDevice(dev_id));
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
@ -57,7 +57,7 @@ void OpHandleBase::Run(bool use_event) {
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
if (use_event) {
|
|
|
|
|
for (auto &p : dev_ctx_) {
|
|
|
|
|
for (auto &p : dev_ctxes_) {
|
|
|
|
|
int dev_id = boost::get<platform::CUDAPlace>(p.first).device;
|
|
|
|
|
auto stream =
|
|
|
|
|
static_cast<platform::CUDADeviceContext *>(p.second)->stream();
|
|
|
|
@ -70,7 +70,7 @@ void OpHandleBase::Run(bool use_event) {
|
|
|
|
|
void OpHandleBase::Wait(platform::DeviceContext *waited_dev) {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
if (platform::is_cpu_place(waited_dev->GetPlace()) || events_.empty()) {
|
|
|
|
|
for (auto &dev_ctx : dev_ctx_) {
|
|
|
|
|
for (auto &dev_ctx : dev_ctxes_) {
|
|
|
|
|
dev_ctx.second->Wait();
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
@ -81,7 +81,7 @@ void OpHandleBase::Wait(platform::DeviceContext *waited_dev) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
for (auto &dev_ctx : dev_ctx_) {
|
|
|
|
|
for (auto &dev_ctx : dev_ctxes_) {
|
|
|
|
|
dev_ctx.second->Wait();
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|