|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
#include <Python.h>
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <cstdlib>
|
|
|
|
|
#include <map>
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <mutex> // NOLINT // for call_once
|
|
|
|
@ -786,13 +787,41 @@ All parameter, weight, gradient are variables in Paddle.
|
|
|
|
|
.def("__init__",
|
|
|
|
|
[](platform::CUDAPlace &self, int dev_id) {
|
|
|
|
|
#ifdef PADDLE_WITH_CUDA
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
dev_id >= 0 && dev_id < platform::GetCUDADeviceCount(),
|
|
|
|
|
"Invalid CUDAPlace(%d), must inside [0, %d)", dev_id,
|
|
|
|
|
platform::GetCUDADeviceCount());
|
|
|
|
|
if (UNLIKELY(dev_id < 0)) {
|
|
|
|
|
LOG(ERROR) << string::Sprintf(
|
|
|
|
|
"Invalid CUDAPlace(%d), device id must be 0 or "
|
|
|
|
|
"positive integer",
|
|
|
|
|
dev_id);
|
|
|
|
|
std::exit(-1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (UNLIKELY(dev_id >= platform::GetCUDADeviceCount())) {
|
|
|
|
|
if (platform::GetCUDADeviceCount() == 0) {
|
|
|
|
|
LOG(ERROR) << "Cannot use GPU because there is no GPU "
|
|
|
|
|
"detected on your "
|
|
|
|
|
"machine.";
|
|
|
|
|
std::exit(-1);
|
|
|
|
|
} else {
|
|
|
|
|
LOG(ERROR) << string::Sprintf(
|
|
|
|
|
"Invalid CUDAPlace(%d), must inside [0, %d), because GPU "
|
|
|
|
|
"number on your machine is %d",
|
|
|
|
|
dev_id, platform::GetCUDADeviceCount(),
|
|
|
|
|
platform::GetCUDADeviceCount());
|
|
|
|
|
std::exit(-1);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
new (&self) platform::CUDAPlace(dev_id);
|
|
|
|
|
#else
|
|
|
|
|
PADDLE_THROW("Cannot use CUDAPlace in CPU only version");
|
|
|
|
|
LOG(ERROR) << string::Sprintf(
|
|
|
|
|
"Cannot use GPU because you have installed CPU version "
|
|
|
|
|
"PaddlePaddle.\n"
|
|
|
|
|
"If you want to use GPU, please try to install GPU version "
|
|
|
|
|
"PaddlePaddle by: pip install paddlepaddle-gpu\n"
|
|
|
|
|
"If you only have CPU, please change CUDAPlace(%d) to be "
|
|
|
|
|
"CPUPlace().\n",
|
|
|
|
|
dev_id);
|
|
|
|
|
std::exit(-1);
|
|
|
|
|
#endif
|
|
|
|
|
})
|
|
|
|
|
.def("_type", &PlaceIndex<platform::CUDAPlace>)
|
|
|
|
|