|
|
|
@ -47,10 +47,13 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
|
|
|
|
|
float momentum = ctx.Attr<float>("momentum");
|
|
|
|
|
const bool is_test = ctx.Attr<bool>("is_test");
|
|
|
|
|
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
|
|
|
|
|
const bool trainable_stats = ctx.Attr<bool>("trainable_statistics");
|
|
|
|
|
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
|
|
|
|
|
const DataLayout data_layout =
|
|
|
|
|
framework::StringToDataLayout(data_layout_str);
|
|
|
|
|
|
|
|
|
|
bool test_mode = is_test && (!trainable_stats);
|
|
|
|
|
|
|
|
|
|
// Get the size for each dimension.
|
|
|
|
|
// NCHW [batch_size, in_channels, in_height, in_width]
|
|
|
|
|
const auto *x = ctx.Input<Tensor>("X");
|
|
|
|
@ -66,7 +69,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
|
|
|
|
|
|
|
|
|
|
auto dtype = platform::CudnnDataType<T>::type;
|
|
|
|
|
const bool fast_nhwc_batch_norm =
|
|
|
|
|
is_test ||
|
|
|
|
|
test_mode ||
|
|
|
|
|
(dtype == CUDNN_DATA_HALF && FLAGS_cudnn_batchnorm_spatial_persistent);
|
|
|
|
|
|
|
|
|
|
auto compute_format =
|
|
|
|
@ -133,7 +136,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
|
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
|
platform::dynload::cudnnDeriveBNTensorDescriptor(
|
|
|
|
|
bn_param_desc_, data_desc_,
|
|
|
|
|
is_test ? CUDNN_BATCHNORM_SPATIAL : mode_));
|
|
|
|
|
test_mode ? CUDNN_BATCHNORM_SPATIAL : mode_));
|
|
|
|
|
|
|
|
|
|
const auto *scale = ctx.Input<Tensor>("Scale");
|
|
|
|
|
const auto *bias = ctx.Input<Tensor>("Bias");
|
|
|
|
@ -143,7 +146,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
|
|
|
|
|
auto handle = dev_ctx.cudnn_handle();
|
|
|
|
|
|
|
|
|
|
// Now, depending on whether we are running test or not, we have two paths.
|
|
|
|
|
if (is_test || use_global_stats) {
|
|
|
|
|
if (test_mode || use_global_stats) {
|
|
|
|
|
// Only when testing (or when use_global_stats is set) do we use the input Mean/Variance for the computation.
|
|
|
|
|
const auto *est_mean = ctx.Input<Tensor>("Mean");
|
|
|
|
|
const auto *est_var = ctx.Input<Tensor>("Variance");
|
|
|
|
|