|
|
|
@ -26,6 +26,8 @@ using Tensor = framework::Tensor;
|
|
|
|
|
// Local shorthand for framework::DataLayout (the layout argument taken by
// ExtractNCWHD below).
using DataLayout = framework::DataLayout;
|
|
|
|
|
// Local shorthand for platform::CudnnDataType<T>. Its ::type, kOne() and
// kZero() members are what the cuDNN calls in this file are parameterized with.
template <typename T>
|
|
|
|
|
using CudnnDataType = platform::CudnnDataType<T>;
|
|
|
|
|
template <typename T>
|
|
|
|
|
using bn_param_type = CudnnDataType<T>::bn_param_type;
|
|
|
|
|
|
|
|
|
|
void ExtractNCWHD(const framework::DDim &dims, const DataLayout &data_layout,
|
|
|
|
|
int *N, int *C, int *H, int *W, int *D) {
|
|
|
|
@ -104,8 +106,9 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnSetTensorNdDescriptor(
|
|
|
|
|
data_desc_, CudnnDataType<T>::type,
|
|
|
|
|
x_dims.size() > 3 ? x_dims.size() : 4, dims.data(), strides.data()));
|
|
|
|
|
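// Note: the PERSISTENT batch-norm mode is not implemented for inference, so
// the parameter descriptor is derived with CUDNN_BATCHNORM_SPATIAL when
// is_test is set.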
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnDeriveBNTensorDescriptor(
|
|
|
|
|
bn_param_desc_, data_desc_, mode_));
|
|
|
|
|
bn_param_desc_, data_desc_, is_test ? CUDNN_BATCHNORM_SPATIAL : mode_));
|
|
|
|
|
|
|
|
|
|
const auto *scale = ctx.Input<Tensor>("Scale");
|
|
|
|
|
const auto *bias = ctx.Input<Tensor>("Bias");
|
|
|
|
@ -118,15 +121,15 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
|
|
|
|
|
|
|
|
|
|
// alloc memory
|
|
|
|
|
y->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
mean_out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
variance_out->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
saved_mean->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
saved_variance->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
mean_out->mutable_data<bn_param_type<T>>(ctx.GetPlace());
|
|
|
|
|
variance_out->mutable_data<bn_param_type<T>>(ctx.GetPlace());
|
|
|
|
|
saved_mean->mutable_data<bn_param_type<T>>(ctx.GetPlace());
|
|
|
|
|
saved_variance->mutable_data<bn_param_type<T>>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
|
|
|
|
math::SetConstant<platform::CUDADeviceContext, T> functor;
|
|
|
|
|
functor(dev_ctx, saved_mean, static_cast<T>(0));
|
|
|
|
|
functor(dev_ctx, saved_variance, static_cast<T>(0));
|
|
|
|
|
math::SetConstant<platform::CUDADeviceContext, bn_param_type<T>> functor;
|
|
|
|
|
functor(dev_ctx, saved_mean, static_cast<bn_param_type<T>>(0));
|
|
|
|
|
functor(dev_ctx, saved_variance, static_cast<bn_param_type<T>>(0));
|
|
|
|
|
|
|
|
|
|
auto handle = dev_ctx.cudnn_handle();
|
|
|
|
|
|
|
|
|
@ -147,8 +150,10 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
|
|
|
|
|
CUDNN_BATCHNORM_SPATIAL, CudnnDataType<T>::kOne(),
|
|
|
|
|
CudnnDataType<T>::kZero(), data_desc_, x->template data<T>(),
|
|
|
|
|
data_desc_, y->template mutable_data<T>(ctx.GetPlace()),
|
|
|
|
|
bn_param_desc_, scale->template data<T>(), bias->template data<T>(),
|
|
|
|
|
est_mean->template data<T>(), est_var->template data<T>(), epsilon));
|
|
|
|
|
bn_param_desc_, scale->template data<bn_param_type<T>>(),
|
|
|
|
|
bias->template data<bn_param_type<T>>(),
|
|
|
|
|
est_mean->template data<bn_param_type<T>>(),
|
|
|
|
|
est_var->template data<bn_param_type<T>>(), epsilon));
|
|
|
|
|
} else {
|
|
|
|
|
// Run training mode.
|
|
|
|
|
// obtain running mean and running inv var, and see if we need to
|
|
|
|
@ -159,11 +164,14 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
|
|
|
|
|
handle, mode_, CudnnDataType<T>::kOne(), CudnnDataType<T>::kZero(),
|
|
|
|
|
data_desc_, x->template data<T>(), data_desc_,
|
|
|
|
|
y->template mutable_data<T>(ctx.GetPlace()), bn_param_desc_,
|
|
|
|
|
scale->template data<T>(), bias->template data<T>(), this_factor,
|
|
|
|
|
mean_out->template mutable_data<T>(ctx.GetPlace()),
|
|
|
|
|
variance_out->template mutable_data<T>(ctx.GetPlace()), epsilon,
|
|
|
|
|
saved_mean->template mutable_data<T>(ctx.GetPlace()),
|
|
|
|
|
saved_variance->template mutable_data<T>(ctx.GetPlace())));
|
|
|
|
|
scale->template data<bn_param_type<T>>(),
|
|
|
|
|
bias->template data<bn_param_type<T>>(), this_factor,
|
|
|
|
|
mean_out->template mutable_data<bn_param_type<T>>(ctx.GetPlace()),
|
|
|
|
|
variance_out->template mutable_data<bn_param_type<T>>(ctx.GetPlace()),
|
|
|
|
|
epsilon,
|
|
|
|
|
saved_mean->template mutable_data<bn_param_type<T>>(ctx.GetPlace()),
|
|
|
|
|
saved_variance->template mutable_data<bn_param_type<T>>(
|
|
|
|
|
ctx.GetPlace())));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Clean up on exit.
|
|
|
|
|