|
|
|
@ -69,10 +69,13 @@ void GraphView::Initialize(const ProgramDesc* pdesc) {
|
|
|
|
|
|
|
|
|
|
struct Device {
|
|
|
|
|
platform::CPUDeviceContext* cpu_device_context;
|
|
|
|
|
#ifndef PADDLE_ONLY_CPU
|
|
|
|
|
platform::CUDADeviceContext* cuda_device_context;
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#ifndef PADDLE_ONLY_CPU
|
|
|
|
|
Device(platform::CPUDeviceContext* cpu, platform::CUDADeviceContext* gpu)
|
|
|
|
|
: cpu_device_context(cpu), cuda_device_context(gpu) {}
|
|
|
|
|
platform::CDUADeviceContext* cuda_device_context;
|
|
|
|
|
#else
|
|
|
|
|
explicit Device(platform::CPUDeviceContext* cpu) : cpu_device_context(cpu) {}
|
|
|
|
|
#endif
|
|
|
|
@ -126,10 +129,16 @@ platform::CUDADeviceContext* GetCUDADeviceContext(
|
|
|
|
|
Device* GetDevice(const platform::Place& place) {
|
|
|
|
|
platform::CPUPlace cpu_place;
|
|
|
|
|
#ifndef PADDLE_ONLY_CPU
|
|
|
|
|
platform::GPUPlace gpu_place = boost::get<platform::GPUPlace>(place);
|
|
|
|
|
static std::unique_ptr<Device> g_device = make_unique<Device>(
|
|
|
|
|
GetCPUDeviceContext(cpu_place), GetCUDADeviceContext(gpu_place));
|
|
|
|
|
return g_device.get();
|
|
|
|
|
if (platform::is_gpu_place(place)) {
|
|
|
|
|
platform::GPUPlace gpu_place = boost::get<platform::GPUPlace>(place);
|
|
|
|
|
static std::unique_ptr<Device> g_device = make_unique<Device>(
|
|
|
|
|
GetCPUDeviceContext(cpu_place), GetCUDADeviceContext(gpu_place));
|
|
|
|
|
return g_device.get();
|
|
|
|
|
} else {
|
|
|
|
|
static std::unique_ptr<Device> g_device =
|
|
|
|
|
make_unique<Device>(GetCPUDeviceContext(cpu_place), nullptr);
|
|
|
|
|
return g_device.get();
|
|
|
|
|
}
|
|
|
|
|
#else
|
|
|
|
|
static std::unique_ptr<Device> g_device =
|
|
|
|
|
make_unique<Device>(GetCPUDeviceContext(cpu_place));
|
|
|
|
@ -153,7 +162,9 @@ void ExecutorImpl::Run() {
|
|
|
|
|
scope_->NewVar();
|
|
|
|
|
device_->cpu_device_context->Wait();
|
|
|
|
|
#ifndef PADDLE_ONLY_CPU
|
|
|
|
|
device_->cuda_device_context->Wait();
|
|
|
|
|
if (device_->cuda_device_context) {
|
|
|
|
|
device_->cuda_device_context->Wait();
|
|
|
|
|
}
|
|
|
|
|
#endif
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|