|
|
|
@ -72,6 +72,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
|
|
|
|
|
int N, C, H, W, D;
|
|
|
|
|
ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
|
|
|
|
|
|
|
|
|
|
auto *y = ctx.Output<Tensor>("Y");
|
|
|
|
|
y->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
// ------------------- cudnn descriptors ---------------------
|
|
|
|
|
cudnnTensorDescriptor_t data_desc_;
|
|
|
|
|
cudnnTensorDescriptor_t bn_param_desc_;
|
|
|
|
@ -93,7 +96,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
|
|
|
|
|
mode_ = CUDNN_BATCHNORM_SPATIAL;
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
VLOG(1) << "Setting descriptors.";
|
|
|
|
|
VLOG(3) << "Setting descriptors.";
|
|
|
|
|
std::vector<int> dims;
|
|
|
|
|
std::vector<int> strides;
|
|
|
|
|
if (data_layout == DataLayout::kNCHW) {
|
|
|
|
@ -113,11 +116,6 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
|
|
|
|
|
const auto *scale = ctx.Input<Tensor>("Scale");
|
|
|
|
|
const auto *bias = ctx.Input<Tensor>("Bias");
|
|
|
|
|
|
|
|
|
|
auto *y = ctx.Output<Tensor>("Y");
|
|
|
|
|
|
|
|
|
|
// alloc memory
|
|
|
|
|
y->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
|
|
|
|
|
|
|
|
|
auto handle = dev_ctx.cudnn_handle();
|
|
|
|
@ -162,6 +160,11 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
|
|
|
|
|
functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
|
|
|
|
|
functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));
|
|
|
|
|
|
|
|
|
|
if ((N * H * W * D) == 1) {
|
|
|
|
|
LOG(WARNING) << "Only 1 element in normalization dimension, "
|
|
|
|
|
<< "we skip the batch norm calculation, let y = x.";
|
|
|
|
|
framework::TensorCopySync(*x, ctx.GetPlace(), y);
|
|
|
|
|
} else {
|
|
|
|
|
double this_factor = 1. - momentum;
|
|
|
|
|
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining(
|
|
|
|
@ -179,6 +182,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
|
|
|
|
|
saved_variance->template mutable_data<BatchNormParamType<T>>(
|
|
|
|
|
ctx.GetPlace())));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// clean when exit.
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnDestroyTensorDescriptor(data_desc_));
|
|
|
|
@ -209,6 +213,25 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
|
|
|
|
|
int N, C, H, W, D;
|
|
|
|
|
ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
|
|
|
|
|
|
|
|
|
|
// init output
|
|
|
|
|
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
|
|
|
|
|
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
|
|
|
|
|
|
|
|
|
|
d_x->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
d_scale->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
d_bias->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
|
|
|
|
if ((N * H * W * D) == 1) {
|
|
|
|
|
framework::TensorCopySync(*d_y, ctx.GetPlace(), d_x);
|
|
|
|
|
math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
|
|
|
|
|
functor;
|
|
|
|
|
functor(dev_ctx, d_scale, static_cast<BatchNormParamType<T>>(0));
|
|
|
|
|
functor(dev_ctx, d_bias, static_cast<BatchNormParamType<T>>(0));
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL);
|
|
|
|
|
PADDLE_ENFORCE_EQ(scale->dims()[0], C);
|
|
|
|
|
|
|
|
|
@ -247,21 +270,11 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
|
|
|
|
|
bn_param_desc_, data_desc_, mode_));
|
|
|
|
|
|
|
|
|
|
// init output
|
|
|
|
|
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
auto *d_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
|
|
|
|
|
auto *d_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));
|
|
|
|
|
|
|
|
|
|
d_x->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
d_scale->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
d_bias->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
const auto *saved_mean = ctx.Input<Tensor>("SavedMean");
|
|
|
|
|
const auto *saved_var = ctx.Input<Tensor>("SavedVariance");
|
|
|
|
|
const void *saved_mean_data = saved_mean->template data<T>();
|
|
|
|
|
const void *saved_var_data = saved_var->template data<T>();
|
|
|
|
|
|
|
|
|
|
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationBackward(
|
|
|
|
|
dev_ctx.cudnn_handle(), mode_, CudnnDataType<T>::kOne(),
|
|
|
|
|
CudnnDataType<T>::kZero(), CudnnDataType<T>::kOne(),
|
|
|
|
|