|
|
|
@ -145,9 +145,9 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
|
|
|
|
|
class CudnnHolder {
|
|
|
|
|
public:
|
|
|
|
|
CudnnHolder(const cudaStream_t* stream, const CUDAPlace& place)
|
|
|
|
|
: stream_(stream), place_(place), workspace_(nullptr), workspace_len_(0) {
|
|
|
|
|
: workspace_(nullptr), workspace_len_(0), stream_(stream), place_(place) {
|
|
|
|
|
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
|
|
|
|
|
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
|
|
|
|
|
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, *stream_));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
cudnnHandle_t get_cudnn_handle() const { return cudnn_handle_; }
|
|
|
|
@ -157,14 +157,14 @@ class CudnnHolder {
|
|
|
|
|
void* new_workspace = paddle::memory::Alloc(place_, required_len);
|
|
|
|
|
if (workspace_ != nullptr) {
|
|
|
|
|
// Maybe someone is using the current workspace
|
|
|
|
|
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
|
|
|
|
|
PADDLE_ENFORCE(cudaStreamSynchronize(*stream_));
|
|
|
|
|
PADDLE_ENFORCE(cudaGetLastError());
|
|
|
|
|
paddle::memory::Free(place_, workspace_);
|
|
|
|
|
}
|
|
|
|
|
workspace_ = new_workspace;
|
|
|
|
|
workspace_len_ = required_len;
|
|
|
|
|
}
|
|
|
|
|
return workspace_
|
|
|
|
|
return workspace_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
~CudnnHolder() { PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_)); }
|
|
|
|
@ -231,7 +231,7 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
|
|
|
|
|
return cudnn_holder_->get_cudnn_handle();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void* cudnn_workspace(size_t required_len) const {
|
|
|
|
|
void* CUDADeviceContext::cudnn_workspace(size_t required_len) const {
|
|
|
|
|
return cudnn_holder_->get_workspace(required_len);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|