|
|
|
@ -142,7 +142,43 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
|
|
|
|
|
mutable unsigned int* semaphore_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
|
|
|
|
|
class CudnnHolder {
|
|
|
|
|
public:
|
|
|
|
|
CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
|
|
|
|
|
: stream_(stream), place_(place), workspace_(nullptr), workspace_len_(0) {
|
|
|
|
|
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
|
|
|
|
|
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
cudnnHandle_t get_cudnn_handle() const { return cudnn_handle_; }
|
|
|
|
|
|
|
|
|
|
void* get_workspace(size_t required_len) {
|
|
|
|
|
if (required_len > workspace_len_) {
|
|
|
|
|
void* new_workspace = paddle::memory::Alloc(place_, required_len);
|
|
|
|
|
if (workspace_ != nullptr) {
|
|
|
|
|
// Maybe someone is using the current workspace
|
|
|
|
|
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
|
|
|
|
|
PADDLE_ENFORCE(cudaGetLastError());
|
|
|
|
|
paddle::memory::Free(place_, workspace_);
|
|
|
|
|
}
|
|
|
|
|
workspace_ = new_workspace;
|
|
|
|
|
}
|
|
|
|
|
return workspace_
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
~CudnnHolder() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
cudnnHandle_t cudnn_handle_;
|
|
|
|
|
void* workspace_;
|
|
|
|
|
size_t workspace_len_;
|
|
|
|
|
|
|
|
|
|
const cudaStream_t* stream_; // not owned;
|
|
|
|
|
const CUDAPlace place_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
|
|
|
|
|
: place_(place), cudnn_holder_(nullptr) {
|
|
|
|
|
SetDeviceId(place_.device);
|
|
|
|
|
compute_capability = GetCUDAComputeCapability(place_.device);
|
|
|
|
|
multi_process = GetCUDAMultiProcessors(place_.device);
|
|
|
|
@ -154,10 +190,7 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
|
|
|
|
|
PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
|
|
|
|
|
PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
|
|
|
|
|
if (dynload::HasCUDNN()) {
|
|
|
|
|
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
|
|
|
|
|
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
|
|
|
|
|
} else {
|
|
|
|
|
cudnn_handle_ = nullptr;
|
|
|
|
|
cudnn_holder_.reset(new CudnnHolder(&stream_, place));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -165,9 +198,6 @@ CUDADeviceContext::~CUDADeviceContext() {
|
|
|
|
|
SetDeviceId(place_.device);
|
|
|
|
|
Wait();
|
|
|
|
|
PADDLE_ENFORCE(dynload::cublasDestroy(cublas_handle_));
|
|
|
|
|
if (cudnn_handle_ != nullptr) {
|
|
|
|
|
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
|
|
|
|
|
}
|
|
|
|
|
eigen_stream_.reset();
|
|
|
|
|
eigen_device_.reset();
|
|
|
|
|
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
|
|
|
|
@ -196,7 +226,13 @@ cublasHandle_t CUDADeviceContext::cublas_handle() const {
|
|
|
|
|
return cublas_handle_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
|
|
|
|
|
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
|
|
|
|
|
return cudnn_holder_->get_cudnn_handle();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void* cudnn_workspace(size_t required_len) const {
|
|
|
|
|
return cudnn_holder_->get_workspace(required_len);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
|
|
|
|
|
|
|
|
|
|