@ -108,32 +108,21 @@ class FusedBatchNormActKernel<platform::CUDADeviceContext, T>
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    cudnnBatchNormMode_t mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    PADDLE_ENFORCE_CUDA_SUCCESS(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::dynload::cudnnCreateTensorDescriptor(&data_desc_),
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::errors::External(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "The error has happened when calling "
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "cudnnCreateTensorDescriptor(&data_desc_)."));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    PADDLE_ENFORCE_CUDA_SUCCESS(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_),
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::errors::External(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "The error has happened when calling "
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "cudnnCreateTensorDescriptor(&bn_param_desc_)."));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    VLOG(3) << "Setting descriptors.";
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    std::vector<int> dims = {N, C, H, W, D};
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    std::vector<int> strides = {H * W * D * C, 1, W * D * C, D * C, C};
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    PADDLE_ENFORCE_CUDA_SUCCESS(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::dynload::cudnnSetTensorNdDescriptor(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            data_desc_, CudnnDataType<T>::type,
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()),
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::errors::External(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "The error has happened when calling cudnnSetTensorNdDescriptor."));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        data_desc_, CudnnDataType<T>::type,
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    PADDLE_ENFORCE_CUDA_SUCCESS(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::dynload::cudnnDeriveBNTensorDescriptor(bn_param_desc_,
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                                                         data_desc_, mode_),
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::errors::External("The error has happened when calling "
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                                   "cudnnDeriveBNTensorDescriptor."));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                                                         data_desc_, mode_));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    double this_factor = 1. - momentum;
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    cudnnBatchNormOps_t bnOps_ = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
 
				
			 
			
		
	
	
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
				
				 
				 
				
					@ -166,10 +155,7 @@ class FusedBatchNormActKernel<platform::CUDADeviceContext, T>
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                /*yDesc=*/data_desc_,
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                /*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                /*activationDesc=*/activation_desc_,
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                /*sizeInBytes=*/&workspace_size),
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::errors::External(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "The error has happened when calling "
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize."));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                /*sizeInBytes=*/&workspace_size));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    // -------------- cudnn batchnorm reserve space --------------
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    PADDLE_ENFORCE_CUDA_SUCCESS(
 
				
			 
			
		
	
	
		
			
				
					
						
						
						
							
								 
							 
						
					 
				
				 
				 
				
					@ -179,10 +165,7 @@ class FusedBatchNormActKernel<platform::CUDADeviceContext, T>
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            /*bnOps=*/bnOps_,
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            /*activationDesc=*/activation_desc_,
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            /*xDesc=*/data_desc_,
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            /*sizeInBytes=*/&reserve_space_size),
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::errors::External(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "The error has happened when calling "
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "cudnnGetBatchNormalizationTrainingExReserveSpaceSize."));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            /*sizeInBytes=*/&reserve_space_size));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    reserve_space_ptr = reserve_space->mutable_data(ctx.GetPlace(), x->type(),
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                                                    reserve_space_size);
 
				
			 
			
		
	
	
		
			
				
					
						
						
						
							
								 
							 
						
					 
				
				 
				 
				
					@ -204,22 +187,13 @@ class FusedBatchNormActKernel<platform::CUDADeviceContext, T>
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            saved_variance->template mutable_data<BatchNormParamType<T>>(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                ctx.GetPlace()),
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            activation_desc_, workspace_ptr, workspace_size, reserve_space_ptr,
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            reserve_space_size),
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::errors::External(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "The error has happened when calling "
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "cudnnBatchNormalizationForwardTrainingEx."));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            reserve_space_size));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    // clean when exit.
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    PADDLE_ENFORCE_CUDA_SUCCESS(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::dynload::cudnnDestroyTensorDescriptor(data_desc_),
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::errors::External(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "The error has happened when calling "
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "cudnnDestroyTensorDescriptor(data_desc_)."));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    PADDLE_ENFORCE_CUDA_SUCCESS(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_),
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::errors::External(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "The error has happened when calling "
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "cudnnDestroyTensorDescriptor(bn_param_desc_)."));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					  }
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					};
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
	
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
				
				 
				 
				
					@ -298,15 +272,9 @@ class FusedBatchNormActGradKernel<platform::CUDADeviceContext, T>
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    cudnnBatchNormMode_t mode_ = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    PADDLE_ENFORCE_CUDA_SUCCESS(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::dynload::cudnnCreateTensorDescriptor(&data_desc_),
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::errors::External(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "The error has happened when calling "
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "cudnnCreateTensorDescriptor(&data_desc_)."));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::dynload::cudnnCreateTensorDescriptor(&data_desc_));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    PADDLE_ENFORCE_CUDA_SUCCESS(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_),
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::errors::External(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "The error has happened when calling "
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "cudnnCreateTensorDescriptor(&bn_param_desc_)."));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::dynload::cudnnCreateTensorDescriptor(&bn_param_desc_));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    if (epsilon <= CUDNN_BN_MIN_EPSILON - FLT_EPSILON) {
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					      LOG(ERROR) << "Provided epsilon is smaller than "
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                 << "CUDNN_BN_MIN_EPSILON. Setting it to "
 
				
			 
			
		
	
	
		
			
				
					
						
						
						
							
								 
							 
						
					 
				
				 
				 
				
					@ -314,17 +282,12 @@ class FusedBatchNormActGradKernel<platform::CUDADeviceContext, T>
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    }
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    epsilon = std::max(epsilon, CUDNN_BN_MIN_EPSILON);
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    PADDLE_ENFORCE_CUDA_SUCCESS(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::dynload::cudnnSetTensorNdDescriptor(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            data_desc_, CudnnDataType<T>::type,
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()),
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::errors::External(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "The error has happened when calling cudnnSetTensorNdDescriptor."));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        data_desc_, CudnnDataType<T>::type,
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    PADDLE_ENFORCE_CUDA_SUCCESS(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::dynload::cudnnDeriveBNTensorDescriptor(bn_param_desc_,
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                                                         data_desc_, mode_),
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::errors::External("The error has happened when calling "
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                                   "cudnnDeriveBNTensorDescriptor."));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                                                         data_desc_, mode_));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
 
				
			 
			
		
	
	
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
				
				 
				 
				
					@ -354,10 +317,7 @@ class FusedBatchNormActGradKernel<platform::CUDADeviceContext, T>
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            /*dxDesc=*/data_desc_,
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            /*bnScaleBiasMeanVarDesc=*/bn_param_desc_,
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            /*activationDesc=*/activation_desc_,
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            /*sizeInBytes=*/&workspace_size),
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::errors::External(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "The error has happened when calling "
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "cudnnGetBatchNormalizationBackwardExWorkspaceSize."));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            /*sizeInBytes=*/&workspace_size));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    workspace_ptr = workspace_tensor.mutable_data(ctx.GetPlace(), x->type(),
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                                                  workspace_size);
 
				
			 
			
		
	
	
		
			
				
					
						
							
								 
							 
						
						
							
								 
							 
						
						
					 
				
				 
				 
				
					@ -395,21 +355,13 @@ class FusedBatchNormActGradKernel<platform::CUDADeviceContext, T>
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            /*workspace=*/workspace_ptr,
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            /*workSpaceSizeInBytes=*/workspace_size,
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            /*reserveSpace=*/const_cast<T *>(reserve_space->template data<T>()),
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            /*reserveSpaceSizeInBytes=*/reserve_space_size),
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::errors::External("The error has happened when calling "
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					                                   "cudnnBatchNormalizationBackwardEx."));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            /*reserveSpaceSizeInBytes=*/reserve_space_size));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    // clean when exit.
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    PADDLE_ENFORCE_CUDA_SUCCESS(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::dynload::cudnnDestroyTensorDescriptor(data_desc_),
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::errors::External(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "The error has happened when calling "
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "cudnnDestroyTensorDescriptor(data_desc_)."));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					    PADDLE_ENFORCE_CUDA_SUCCESS(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_),
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::errors::External(
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "The error has happened when calling "
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					            "cudnnDestroyTensorDescriptor(bn_param_desc_)."));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					        platform::dynload::cudnnDestroyTensorDescriptor(bn_param_desc_));
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					  }
 
				
			 
			
		
	
		
			
				
					 
					 
				
				 
				 
				
					};