update error log for batch_norm_grad (#22017)

* update error information about batch_norm_grad

* update bn,test=develop
release/1.7
ceci3 6 years ago committed by GitHub
parent 985e4bae5e
commit 95d79b6d00
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -528,10 +528,18 @@ class BatchNormGradKernel<platform::CPUDeviceContext, T>
const auto *saved_inv_variance = ctx.Input<Tensor>("SavedVariance");
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const bool use_global_stats = ctx.Attr<bool>("use_global_stats");
const bool is_test = ctx.Attr<bool>("is_test");
const float epsilon = ctx.Attr<float>("epsilon");
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
PADDLE_ENFORCE_EQ(
is_test, false,
platform::errors::InvalidArgument(
"`is_test = True` CANNOT be used in train program. If "
"you want to use global status in pre_train model, "
"please set `use_global_stats = True`"));
// Get the size for each dimension.
// NCHW [batch_size, in_channels, in_height, in_width]
const auto &x_dims = x->dims();

@ -423,6 +423,13 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
const auto *x = ctx.Input<Tensor>("X");
const auto *d_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
const auto *scale = ctx.Input<Tensor>("Scale");
const bool is_test = ctx.Attr<bool>("is_test");
PADDLE_ENFORCE_EQ(
is_test, false,
platform::errors::InvalidArgument(
"`is_test = True` CANNOT be used in train program. If "
"you want to use global status in pre_train model, "
"please set `use_global_stats = True`"));
const auto &x_dims = x->dims();

@ -2461,6 +2461,7 @@ def batch_norm(input,
Note:
if build_strategy.sync_batch_norm=True, the batch_norm in network will use
sync_batch_norm automatically.
`is_test = True` can only be used in test program and inference program, `is_test` CANNOT be set to True in train program, if you want to use global status from pre_train model in train program, please set `use_global_stats = True`.
Args:
input(Variable): The rank of input variable can be 2, 3, 4, 5. The data type

Loading…
Cancel
Save