|
|
|
@ -114,23 +114,11 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
|
|
|
|
|
const auto *bias = ctx.Input<Tensor>("Bias");
|
|
|
|
|
|
|
|
|
|
auto *y = ctx.Output<Tensor>("Y");
|
|
|
|
|
auto *mean_out = ctx.Output<Tensor>("MeanOut");
|
|
|
|
|
auto *variance_out = ctx.Output<Tensor>("VarianceOut");
|
|
|
|
|
auto *saved_mean = ctx.Output<Tensor>("SavedMean");
|
|
|
|
|
auto *saved_variance = ctx.Output<Tensor>("SavedVariance");
|
|
|
|
|
|
|
|
|
|
// alloc memory
|
|
|
|
|
y->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
mean_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
|
|
|
|
|
variance_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
|
|
|
|
|
saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
|
|
|
|
|
saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto &dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
|
|
|
|
|
math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
|
|
|
|
|
functor;
|
|
|
|
|
functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
|
|
|
|
|
functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));
|
|
|
|
|
|
|
|
|
|
auto handle = dev_ctx.cudnn_handle();
|
|
|
|
|
|
|
|
|
@ -159,6 +147,21 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
|
|
|
|
|
// Run training mode.
|
|
|
|
|
// obtain running mean and running inv var, and see if we need to
|
|
|
|
|
// initialize them.
|
|
|
|
|
|
|
|
|
|
auto *mean_out = ctx.Output<Tensor>("MeanOut");
|
|
|
|
|
auto *variance_out = ctx.Output<Tensor>("VarianceOut");
|
|
|
|
|
mean_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
|
|
|
|
|
variance_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto *saved_mean = ctx.Output<Tensor>("SavedMean");
|
|
|
|
|
auto *saved_variance = ctx.Output<Tensor>("SavedVariance");
|
|
|
|
|
saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
|
|
|
|
|
saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
|
|
|
|
|
math::SetConstant<platform::CUDADeviceContext, BatchNormParamType<T>>
|
|
|
|
|
functor;
|
|
|
|
|
functor(dev_ctx, saved_mean, static_cast<BatchNormParamType<T>>(0));
|
|
|
|
|
functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));
|
|
|
|
|
|
|
|
|
|
double this_factor = 1. - momentum;
|
|
|
|
|
|
|
|
|
|
CUDNN_ENFORCE(platform::dynload::cudnnBatchNormalizationForwardTraining(
|
|
|
|
|