|
|
|
@ -58,8 +58,12 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
|
|
|
|
|
// NCHW [batch_size, in_channels, in_height, in_width]
|
|
|
|
|
const auto *x = ctx.Input<Tensor>("X");
|
|
|
|
|
const auto &x_dims = x->dims();
|
|
|
|
|
PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
|
|
|
|
|
"The Input dim size should be between 2 and 5");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims.size() >= 2 && x_dims.size() <= 5, true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The size of input's dimensions should be between 2 and 5"
|
|
|
|
|
"But received: the size of input's dimensions is [%d]",
|
|
|
|
|
x_dims.size()));
|
|
|
|
|
|
|
|
|
|
auto *y = ctx.Output<Tensor>("Y");
|
|
|
|
|
y->mutable_data<T>(ctx.GetPlace());
|
|
|
|
@ -151,10 +155,34 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
|
|
|
|
|
const auto *est_mean = ctx.Input<Tensor>("Mean");
|
|
|
|
|
const auto *est_var = ctx.Input<Tensor>("Variance");
|
|
|
|
|
// Run inference mode.
|
|
|
|
|
PADDLE_ENFORCE_EQ(est_mean->dims().size(), 1UL);
|
|
|
|
|
PADDLE_ENFORCE_EQ(est_var->dims().size(), 1UL);
|
|
|
|
|
PADDLE_ENFORCE_EQ(est_mean->dims()[0], C);
|
|
|
|
|
PADDLE_ENFORCE_EQ(est_var->dims()[0], C);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
est_mean->dims().size(), 1UL,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The size of mean's dimensions must equal to 1."
|
|
|
|
|
"But received: the size of mean's dimensions mean is [%d],"
|
|
|
|
|
"the dimensions of mean is [%s].",
|
|
|
|
|
est_mean->dims().size(), est_mean->dims()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
est_var->dims().size(), 1UL,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The size of variance's dimensions must equal to 1."
|
|
|
|
|
"But received: the size of variance's dimensions is [%d],"
|
|
|
|
|
"the dimensions of variance is [%s].",
|
|
|
|
|
est_var->dims().size(), est_var->dims()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
est_mean->dims()[0], C,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The first dimension of mean must equal to the number of "
|
|
|
|
|
"Channels, which is [%d]. But received: the first dimension"
|
|
|
|
|
"of mean is [%d], the dimensions of mean is [%s].",
|
|
|
|
|
C, est_mean->dims()[0], est_mean->dims()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
est_var->dims()[0], C,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The first dimension of variance must equal to the number"
|
|
|
|
|
"of Channels, which is [%d]. But received: the first dimension of"
|
|
|
|
|
"variance is [%d], the dimensions of variance is [%s].",
|
|
|
|
|
C, est_var->dims()[0], est_var->dims()));
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_CUDA_SUCCESS(
|
|
|
|
|
platform::dynload::cudnnBatchNormalizationForwardInference(
|
|
|
|
@ -503,8 +531,13 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
|
|
|
|
|
|
|
|
|
|
const auto &x_dims = x->dims();
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
|
|
|
|
|
"The Input dim size should be between 2 and 5");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
x_dims.size() >= 2 && x_dims.size() <= 5, true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The size of input's dimensions should be between 2 and 5."
|
|
|
|
|
"But received: the size of input's dimensions is [%d],"
|
|
|
|
|
"the dimensions of input is [%s]",
|
|
|
|
|
x_dims.size(), x_dims));
|
|
|
|
|
int N, C, H, W, D;
|
|
|
|
|
ExtractNCWHD(x_dims, data_layout, &N, &C, &H, &W, &D);
|
|
|
|
|
|
|
|
|
@ -515,8 +548,19 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
|
|
|
|
|
d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
|
|
|
|
|
d_bias->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(scale->dims().size(), 1UL);
|
|
|
|
|
PADDLE_ENFORCE_EQ(scale->dims()[0], C);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
scale->dims().size(), 1UL,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The size of scale's dimensions must equal to 1. But received: "
|
|
|
|
|
"the size of scale's dimensions is [%d], the dimensions of scale "
|
|
|
|
|
"is [%s].",
|
|
|
|
|
scale->dims().size(), scale->dims()));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
scale->dims()[0], C,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"The first dimension of scale must equal to Channels[%d]. But "
|
|
|
|
|
"received: the first dimension of scale is [%d]",
|
|
|
|
|
C, scale->dims()[0]));
|
|
|
|
|
|
|
|
|
|
auto dtype = platform::CudnnDataType<T>::type;
|
|
|
|
|
const auto *reserve_space = ctx.Input<Tensor>("ReserveSpace");
|
|
|
|
|