|
|
|
@ -127,15 +127,21 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
|
|
|
|
|
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
|
|
|
|
|
PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
|
|
|
|
|
PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
|
|
|
|
|
if (dynload::HasCUDNN()) {
|
|
|
|
|
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
|
|
|
|
|
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
|
|
|
|
|
} else {
|
|
|
|
|
cudnn_handle_ = nullptr;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CUDADeviceContext::~CUDADeviceContext() {
|
|
|
|
|
SetDeviceId(place_.device);
|
|
|
|
|
Wait();
|
|
|
|
|
PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
|
|
|
|
|
if (cudnn_handle_ != nullptr) {
|
|
|
|
|
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
|
|
|
|
|
}
|
|
|
|
|
eigen_stream_.reset();
|
|
|
|
|
eigen_device_.reset();
|
|
|
|
|
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
|
|
|
|
@ -160,20 +166,6 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
|
|
|
|
|
|
|
|
|
|
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
|
|
|
|
|
|
|
|
|
|
CUDNNDeviceContext::CUDNNDeviceContext(CUDAPlace place)
|
|
|
|
|
: CUDADeviceContext(place) {
|
|
|
|
|
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
|
|
|
|
|
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CUDNNDeviceContext::~CUDNNDeviceContext() {
|
|
|
|
|
SetDeviceId(boost::get<CUDAPlace>(GetPlace()).device);
|
|
|
|
|
Wait();
|
|
|
|
|
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
cudnnHandle_t CUDNNDeviceContext::cudnn_handle() const { return cudnn_handle_; }
|
|
|
|
|
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
} // namespace platform
|
|
|
|
|