|
|
@ -152,8 +152,8 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
|
|
|
|
functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));
|
|
|
|
functor(dev_ctx, saved_variance, static_cast<BatchNormParamType<T>>(0));
|
|
|
|
|
|
|
|
|
|
|
|
if ((N * H * W * D) == 1) {
|
|
|
|
if ((N * H * W * D) == 1) {
|
|
|
|
LOG(WARNING) << "Only 1 element in normalization dimension, "
|
|
|
|
// Only 1 element in normalization dimension,
|
|
|
|
<< "we skip the batch norm calculation, let y = x.";
|
|
|
|
// skip the batch norm calculation, let y = x.
|
|
|
|
framework::TensorCopy(*x, ctx.GetPlace(), y);
|
|
|
|
framework::TensorCopy(*x, ctx.GetPlace(), y);
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
double this_factor = 1. - momentum;
|
|
|
|
double this_factor = 1. - momentum;
|
|
|
|