@ -108,32 +108,21 @@ class FusedBatchNormActKernel<platform::CUDADeviceContext, T>
cudnnBatchNormMode_t mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnCreateTensorDescriptor(&data_desc_),
platform::errors::External(
"The error has happened when calling "
"cudnnCreateTensorDescriptor(&data_desc_)."));
platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_),
platform::errors::External(
"The error has happened when calling "
"cudnnCreateTensorDescriptor(&bn_param_desc_)."));
platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
VLOG(3) << "Setting descriptors.";
std::vector<int> dims = {N, C, H, W, D};
std::vector<int> strides = {H * W * D * C, 1, W * D * C, D * C, C};
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()),
platform::errors::External(
"The error has happened when calling cudnnSetTensorNdDescriptor."));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDeriveBNTensorDescriptor(bn_param_desc_,
data_desc_, mode_),
platform::errors::External("The error has happened when calling "
"cudnnDeriveBNTensorDescriptor."));
data_desc_, mode_));
double this_factor = 1. - momentum;
cudnnBatchNormOps_t bnOps_ = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
@ -166,10 +155,7 @@ class FusedBatchNormActKernel<platform::CUDADeviceContext, T>
/*yDesc=*/data_desc_,
/*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
/*activationDesc=*/activation_desc_,
/*sizeInBytes=*/&workspace_size),
platform::errors::External(
"The error has happened when calling "
"cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize."));
/*sizeInBytes=*/&workspace_size));
// -------------- cudnn batchnorm reserve space --------------
PADDLE_ENFORCE_CUDA_SUCCESS(
@ -179,10 +165,7 @@ class FusedBatchNormActKernel<platform::CUDADeviceContext, T>
/*bnOps=*/bnOps_,
/*activationDesc=*/activation_desc_,
/*xDesc=*/data_desc_,
/*sizeInBytes=*/&reserve_space_size),
platform::errors::External(
"The error has happened when calling "
"cudnnGetBatchNormalizationTrainingExReserveSpaceSize."));
/*sizeInBytes=*/&reserve_space_size));
reserve_space_ptr = reserve_space->mutable_data(ctx.GetPlace(), x->type(),
reserve_space_size);
@ -204,22 +187,13 @@ class FusedBatchNormActKernel<platform::CUDADeviceContext, T>
saved_variance->template mutable_data<BatchNormParamType<T>>(
ctx.GetPlace()),
activation_desc_, workspace_ptr, workspace_size, reserve_space_ptr,
reserve_space_size),
platform::errors::External(
"The error has happened when calling "
"cudnnBatchNormalizationForwardTrainingEx."));
reserve_space_size));
// clean when exit.
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDestroyTensorDescriptor(data_desc_),
platform::errors::External(
"The error has happened when calling "
"cudnnDestroyTensorDescriptor(data_desc_)."));
platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_),
platform::errors::External(
"The error has happened when calling "
"cudnnDestroyTensorDescriptor(bn_param_desc_)."));
platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
}
};
@ -298,15 +272,9 @@ class FusedBatchNormActGradKernel<platform::CUDADeviceContext, T>
cudnnBatchNormMode_t mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnCreateTensorDescriptor(&data_desc_),
platform::errors::External(
"The error has happened when calling "
"cudnnCreateTensorDescriptor(&data_desc_)."));
platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_),
platform::errors::External(
"The error has happened when calling "
"cudnnCreateTensorDescriptor(&bn_param_desc_)."));
platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
LOG(ERROR) << "Provided epsilon is smaller than "
<< "CUDNN_BN_MIN_EPSILON. Setting it to "
@ -314,17 +282,12 @@ class FusedBatchNormActGradKernel<platform::CUDADeviceContext, T>
}
epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()),
platform::errors::External(
"The error has happened when calling cudnnSetTensorNdDescriptor."));
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
data_desc_, CudnnDataType<T>::type,
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDeriveBNTensorDescriptor(bn_param_desc_,
data_desc_, mode_),
platform::errors::External("The error has happened when calling "
"cudnnDeriveBNTensorDescriptor."));
data_desc_, mode_));
const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
@ -354,10 +317,7 @@ class FusedBatchNormActGradKernel<platform::CUDADeviceContext, T>
/*dxDesc=*/data_desc_,
/*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
/*activationDesc=*/activation_desc_,
/*sizeInBytes=*/&workspace_size),
platform::errors::External(
"The error has happened when calling "
"cudnnGetBatchNormalizationBackwardExWorkspaceSize."));
/*sizeInBytes=*/&workspace_size));
workspace_ptr = workspace_tensor.mutable_data(ctx.GetPlace(), x->type(),
workspace_size);
@ -395,21 +355,13 @@ class FusedBatchNormActGradKernel<platform::CUDADeviceContext, T>
/*workspace=*/workspace_ptr,
/*workSpaceSizeInBytes=*/workspace_size,
/*reserveSpace=*/const_cast<T *>(reserve_space->template data<T>()),
/*reserveSpaceSizeInBytes=*/reserve_space_size),
platform::errors::External("The error has happened when calling "
"cudnnBatchNormalizationBackwardEx."));
/*reserveSpaceSizeInBytes=*/reserve_space_size));
// clean when exit.
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDestroyTensorDescriptor(data_desc_),
platform::errors::External(
"The error has happened when calling "
"cudnnDestroyTensorDescriptor(data_desc_)."));
platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_),
platform::errors::External(
"The error has happened when calling "
"cudnnDestroyTensorDescriptor(bn_param_desc_)."));
platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
}
};