@@ -528,10 +528,18 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
     const auto *saved_inv_variance = ctx.Input<Tensor>("SavedVariance");
     const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
     const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
+    const bool is_test = ctx.Attr<bool>("is_test");
     const float epsilon = ctx.Attr<float>("epsilon");
     const DataLayout data_layout =
         framework::StringToDataLayout(data_layout_str);
 
+    PADDLE_ENFORCE_EQ(
+        is_test, false,
+        platform::errors::InvalidArgument(
+            "`is_test = True` CANNOT be used in train program. If "
+            "you want to use global status in pre_train model, "
+            "please set `use_global_stats = True`"));
+
     // Get the size for each dimension.
     // NCHW [batch_size, in_channels, in_height, in_width]
     const auto &x_dims = x->dims();
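
For context, the hunk above makes the CPU gradient kernel refuse to run with `is_test = true` and points users at `use_global_stats` instead. A minimal standalone sketch of that guard follows; it is not Paddle code, and the `EnforceEq` helper is a hypothetical stand-in for `PADDLE_ENFORCE_EQ` with `platform::errors::InvalidArgument`, using a plain `std::invalid_argument` where Paddle would raise its own enforce error.

```cpp
// Standalone illustration of the is_test guard added in the hunk above.
#include <iostream>
#include <stdexcept>
#include <string>

// Hypothetical helper standing in for PADDLE_ENFORCE_EQ(..., InvalidArgument(...)).
void EnforceEq(bool actual, bool expected, const std::string& msg) {
  if (actual != expected) throw std::invalid_argument(msg);
}

int main() {
  const bool is_test = true;  // would come from ctx.Attr<bool>("is_test") in the kernel
  try {
    EnforceEq(is_test, false,
              "`is_test = True` CANNOT be used in train program. If "
              "you want to use global status in pre_train model, "
              "please set `use_global_stats = True`");
  } catch (const std::invalid_argument& e) {
    std::cerr << e.what() << std::endl;  // the backward pass is rejected with this message
    return 1;
  }
  return 0;
}
```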