|
|
|
@ -136,18 +136,14 @@ All parameter, weight, gradient are variables in Paddle.
|
|
|
|
|
[]() -> paddle::platform::DeviceContext* {
|
|
|
|
|
return new paddle::platform::CPUDeviceContext();
|
|
|
|
|
})
|
|
|
|
|
#ifndef PADDLE_ONLY_CPU
|
|
|
|
|
.def_static("gpu_context",
|
|
|
|
|
[](paddle::platform::Place& place)
|
|
|
|
|
-> paddle::platform::DeviceContext* {
|
|
|
|
|
#ifdef PADDLE_ONLY_CPU
|
|
|
|
|
|
|
|
|
|
// PADDLE_THROW("'GPUPlace' is not supported in CPU only
|
|
|
|
|
// device.");
|
|
|
|
|
return nullptr;
|
|
|
|
|
#else
|
|
|
|
|
return new paddle::platform::CUDADeviceContext(place);
|
|
|
|
|
})
|
|
|
|
|
#endif
|
|
|
|
|
});
|
|
|
|
|
;
|
|
|
|
|
|
|
|
|
|
py::class_<paddle::platform::Place>(m, "GPUPlace").def(py::init<int>());
|
|
|
|
|
|
|
|
|
|