@@ -444,11 +444,17 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
   // check output
   PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), "");
-  if (ctx->HasOutput(framework::GradVarName("Scale"))) {
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")),
-                   "Output(Scale@GRAD) and Output(Bias@GRAD) should not be "
-                   "null at same time");
-  }
+  const bool has_scale_grad = ctx->HasOutput(framework::GradVarName("Scale"));
+  const bool has_bias_grad = ctx->HasOutput(framework::GradVarName("Bias"));
+  PADDLE_ENFORCE_EQ((has_scale_grad == has_bias_grad), true,
+                    platform::errors::InvalidArgument(
+                        "Output(Scale@GRAD) and Output(Bias@GRAD) must be null "
+                        "or not be null at same time. But now, "
+                        "has Scale@Grad=[%d], has Bias@GRAD=[%d]",
+                        has_scale_grad, has_bias_grad));
   const bool use_global_stats = ctx->Attrs().Get<bool>("use_global_stats");
   if (use_global_stats) {
     PADDLE_ENFORCE(!ctx->Attrs().Get<bool>("use_mkldnn"),
@@ -463,7 +469,8 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
                                                   : x_dims[x_dims.size() - 1]);
   ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-  if (ctx->HasOutput(framework::GradVarName("Scale"))) {
+  // has_scale_grad == has_bias_grad, so checking has_scale_grad is enough
+  if (has_scale_grad) {
     ctx->SetOutputDim(framework::GradVarName("Scale"), {C});
     ctx->SetOutputDim(framework::GradVarName("Bias"), {C});
   }
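
The change replaces the old one-sided check, which only verified Bias@GRAD when Scale@GRAD was present, with a symmetric both-or-neither invariant: the two parameter gradients must be requested together or omitted together, and the error message now reports the state of both flags. Caching the two HasOutput results also lets the second hunk reuse has_scale_grad instead of re-querying the context. Below is a minimal standalone sketch of that invariant outside the Paddle framework; the function name CheckScaleBiasGradConsistency is hypothetical, and std::invalid_argument stands in for platform::errors::InvalidArgument.

    #include <iostream>
    #include <stdexcept>
    #include <string>

    // Hypothetical stand-in for the check the diff adds: the two gradient
    // outputs must be requested together or omitted together.
    void CheckScaleBiasGradConsistency(bool has_scale_grad, bool has_bias_grad) {
      if (has_scale_grad != has_bias_grad) {
        throw std::invalid_argument(
            "Output(Scale@GRAD) and Output(Bias@GRAD) must be null or not be "
            "null at same time. But now, has Scale@GRAD=[" +
            std::to_string(has_scale_grad) + "], has Bias@GRAD=[" +
            std::to_string(has_bias_grad) + "]");
      }
    }

    int main() {
      CheckScaleBiasGradConsistency(true, true);    // OK: both requested
      CheckScaleBiasGradConsistency(false, false);  // OK: both omitted
      try {
        CheckScaleBiasGradConsistency(true, false); // mismatched: throws
      } catch (const std::invalid_argument &e) {
        std::cerr << e.what() << "\n";
      }
      return 0;
    }

Collapsing the two asymmetric checks into one equality test means a Bias@GRAD-without-Scale@GRAD request, which the old code silently accepted, now fails with the same diagnostic as the reverse case.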