|
|
|
@ -626,7 +626,18 @@ All parameter, weight, gradient are variables in Paddle.
|
|
|
|
|
py::class_<platform::Communicator>(m, "Communicator").def(py::init<>());
|
|
|
|
|
#endif
|
|
|
|
|
py::class_<platform::CUDAPlace>(m, "CUDAPlace")
|
|
|
|
|
.def(py::init<int>())
|
|
|
|
|
.def("__init__",
|
|
|
|
|
[](platform::CUDAPlace &self, int dev_id) {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
dev_id >= 0 && dev_id < platform::GetCUDADeviceCount(),
|
|
|
|
|
"Invalid CUDAPlace(%d), must inside [0, %d)", dev_id,
|
|
|
|
|
platform::GetCUDADeviceCount());
|
|
|
|
|
new (&self) platform::CUDAPlace(dev_id);
|
|
|
|
|
#else
|
|
|
|
|
PADDLE_THROW("Cannot use CUDAPlace in CPU only version");
|
|
|
|
|
#endif
|
|
|
|
|
})
|
|
|
|
|
.def("__str__", string::to_string<const platform::CUDAPlace &>);
|
|
|
|
|
|
|
|
|
|
py::class_<paddle::platform::CPUPlace>(m, "CPUPlace")
|
|
|
|
@ -634,7 +645,12 @@ All parameter, weight, gradient are variables in Paddle.
|
|
|
|
|
.def("__str__", string::to_string<const platform::CPUPlace &>);
|
|
|
|
|
|
|
|
|
|
py::class_<paddle::platform::CUDAPinnedPlace>(m, "CUDAPinnedPlace")
|
|
|
|
|
.def(py::init<>())
|
|
|
|
|
.def("__init__",
|
|
|
|
|
[](platform::CUDAPinnedPlace &) {
|
|
|
|
|
#ifndef PADDLE_WITH_CUDA
|
|
|
|
|
PADDLE_THROW("Cannot use CUDAPinnedPlace in CPU only version");
|
|
|
|
|
#endif
|
|
|
|
|
})
|
|
|
|
|
.def("__str__", string::to_string<const platform::CUDAPinnedPlace &>);
|
|
|
|
|
|
|
|
|
|
py::class_<platform::Place>(m, "Place")
|
|
|
|
|