|
|
|
@ -83,11 +83,12 @@ class CudaDeviceContext : public DeviceContext {
|
|
|
|
|
// Returns the cuBLAS handle owned by this context, creating it on first use.
//
// While blas_handle_ is still null, the handle is created and then attached
// to stream_; every later call returns the cached handle unchanged. Each step
// is checked with PADDLE_ENFORCE against CUBLAS_STATUS_SUCCESS.
//
// The handle is created exactly once, through the paddle::dyload wrappers.
// A second create call on the same member would overwrite blas_handle_ and
// the first handle could never be destroyed.
cublasHandle_t cublas_handle() {
  if (!blas_handle_) {
    // NOTE(review): DeviceGuard presumably makes gpu_place_ the current
    // device for the duration of this scope -- confirm against its definition.
    DeviceGuard guard(gpu_place_);
    PADDLE_ENFORCE(
        paddle::dyload::cublasCreate(&blas_handle_) == CUBLAS_STATUS_SUCCESS,
        "cublasCreate failed");
    PADDLE_ENFORCE(paddle::dyload::cublasSetStream(blas_handle_, stream_) ==
                       CUBLAS_STATUS_SUCCESS,
                   "cublasSetStream failed");
  }
  return blas_handle_;
}
|
|
|
|
@ -95,11 +96,12 @@ class CudaDeviceContext : public DeviceContext {
|
|
|
|
|
// Returns the cuDNN handle owned by this context, creating it on first use.
//
// While dnn_handle_ is still null, the handle is created and then attached to
// stream_; every later call returns the cached handle unchanged. Each step is
// checked with PADDLE_ENFORCE against CUDNN_STATUS_SUCCESS.
//
// The handle is created exactly once, through the paddle::dyload wrappers.
// A second create call on the same member would overwrite dnn_handle_ and the
// first handle could never be destroyed.
cudnnHandle_t cudnn_handle() {
  if (!dnn_handle_) {
    // NOTE(review): DeviceGuard presumably makes gpu_place_ the current
    // device for the duration of this scope -- confirm against its definition.
    DeviceGuard guard(gpu_place_);
    PADDLE_ENFORCE(
        paddle::dyload::cudnnCreate(&dnn_handle_) == CUDNN_STATUS_SUCCESS,
        "cudnnCreate failed");
    PADDLE_ENFORCE(paddle::dyload::cudnnSetStream(dnn_handle_, stream_) ==
                       CUDNN_STATUS_SUCCESS,
                   "cudnnSetStream failed");
  }
  return dnn_handle_;
}
|
|
|
|
@ -107,17 +109,17 @@ class CudaDeviceContext : public DeviceContext {
|
|
|
|
|
// Returns the cuRAND generator owned by this context, creating it on first
// use.
//
// While rand_generator_ is still null, three checked steps run in order:
//   1. create a CURAND_RNG_PSEUDO_DEFAULT generator,
//   2. seed it with random_seed_,
//   3. attach it to stream_.
// Every later call returns the cached generator unchanged. Each step is
// checked with PADDLE_ENFORCE against CURAND_STATUS_SUCCESS.
//
// Each step is issued exactly once, through the paddle::dyload wrappers.
// A second create call on the same member would overwrite rand_generator_
// and the first generator could never be destroyed.
curandGenerator_t curand_generator() {
  if (!rand_generator_) {
    // NOTE(review): DeviceGuard presumably makes gpu_place_ the current
    // device for the duration of this scope -- confirm against its definition.
    DeviceGuard guard(gpu_place_);
    PADDLE_ENFORCE(paddle::dyload::curandCreateGenerator(
                       &rand_generator_, CURAND_RNG_PSEUDO_DEFAULT) ==
                       CURAND_STATUS_SUCCESS,
                   "curandCreateGenerator failed");
    PADDLE_ENFORCE(
        paddle::dyload::curandSetPseudoRandomGeneratorSeed(
            rand_generator_, random_seed_) == CURAND_STATUS_SUCCESS,
        "curandSetPseudoRandomGeneratorSeed failed");
    PADDLE_ENFORCE(paddle::dyload::curandSetStream(
                       rand_generator_, stream_) == CURAND_STATUS_SUCCESS,
                   "curandSetStream failed");
  }
  return rand_generator_;
}
|
|
|
|
@ -125,19 +127,21 @@ class CudaDeviceContext : public DeviceContext {
|
|
|
|
|
~CudaDeviceContext() {
|
|
|
|
|
Wait();
|
|
|
|
|
if (blas_handle_) {
|
|
|
|
|
PADDLE_ENFORCE(cublasDestroy(blas_handle_) == CUBLAS_STATUS_SUCCESS,
|
|
|
|
|
"cublasDestroy failed");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
paddle::dyload::cublasDestroy(blas_handle_) == CUBLAS_STATUS_SUCCESS,
|
|
|
|
|
"cublasDestroy failed");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (dnn_handle_) {
|
|
|
|
|
PADDLE_ENFORCE(cudnnDestroy(dnn_handle_) == CUDNN_STATUS_SUCCESS,
|
|
|
|
|
"cudnnDestroy failed");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
paddle::dyload::cudnnDestroy(dnn_handle_) == CUDNN_STATUS_SUCCESS,
|
|
|
|
|
"cudnnDestroy failed");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (rand_generator_) {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
curandDestroyGenerator(rand_generator_) == CURAND_STATUS_SUCCESS,
|
|
|
|
|
"curandDestroyGenerator failed");
|
|
|
|
|
PADDLE_ENFORCE(paddle::dyload::curandDestroyGenerator(rand_generator_) ==
|
|
|
|
|
CURAND_STATUS_SUCCESS,
|
|
|
|
|
"curandDestroyGenerator failed");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
delete eigen_stream_;
|
|
|
|
|