@@ -50,10 +50,6 @@ class BatchNormOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasOutput("SavedMean"), "");
     PADDLE_ENFORCE(ctx->HasOutput("SavedVariance"), "");
 
-    const float epsilon = ctx->Attrs().Get<float>("epsilon");
-    PADDLE_ENFORCE_GE(epsilon, 0.0, "epsilon should be larger than 0");
-    PADDLE_ENFORCE_LE(epsilon, 0.001, "epsilon should not be too large");
-
     // make sure Mean/MeanOut and Variance/VarianceOut share memory in Python
     PADDLE_ENFORCE_EQ(ctx->Inputs("Mean")[0], ctx->Outputs("MeanOut")[0],
                       "Mean and MeanOut should share the same memory");
@@ -91,7 +87,12 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddAttr<bool>("is_test", "").SetDefault(false);
     AddAttr<float>("momentum", "").SetDefault(0.9);
-    AddAttr<float>("epsilon", "").SetDefault(1e-5);
+    AddAttr<float>("epsilon", "")
+        .SetDefault(1e-5)
+        .AddCustomChecker([](const float &epsilon) {
+          PADDLE_ENFORCE(epsilon >= 0.0f && epsilon <= 0.001f,
+                         "'epsilon' should be between 0.0 and 0.001.");
+        });
     AddAttr<std::string>("data_layout", "").SetDefault("NCHW");
     AddInput("X", "The input tensor");
     AddInput("Scale",