|
|
|
@ -68,6 +68,9 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
|
|
|
|
|
auto scale_dim = ctx->GetInputDim("Scale");
|
|
|
|
|
auto bias_dim = ctx->GetInputDim("Bias");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL);
|
|
|
|
|
PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL);
|
|
|
|
|
|
|
|
|
|
bool check = true;
|
|
|
|
|
if ((!ctx->IsRuntime()) && (framework::product(scale_dim) <= 0 ||
|
|
|
|
|
framework::product(bias_dim) <= 0)) {
|
|
|
|
@ -75,9 +78,7 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (check) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL);
|
|
|
|
|
PADDLE_ENFORCE_EQ(scale_dim[0], C);
|
|
|
|
|
PADDLE_ENFORCE_EQ(scale_dim.size(), 1UL);
|
|
|
|
|
PADDLE_ENFORCE_EQ(scale_dim[0], C);
|
|
|
|
|
}
|
|
|
|
|
ctx->SetOutputDim("Y", x_dims);
|
|
|
|
|