refine drop_out_op

del_some_in_makelist
chengduoZH 7 years ago
parent 52965458d2
commit a1e1ae3ff7

@ -25,8 +25,6 @@ class DropoutOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_GE(ctx->Attrs().Get<float>("dropout_prob"), 0);
PADDLE_ENFORCE_LE(ctx->Attrs().Get<float>("dropout_prob"), 1);
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", x_dims); ctx->SetOutputDim("Out", x_dims);
@ -47,7 +45,11 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Mask", "The random sampled dropout mask.").AsIntermediate(); AddOutput("Mask", "The random sampled dropout mask.").AsIntermediate();
AddAttr<float>("dropout_prob", "Probability of setting units to zero.") AddAttr<float>("dropout_prob", "Probability of setting units to zero.")
.SetDefault(.5f); .SetDefault(.5f)
.AddCustomChecker([](const float& drop_p) {
PADDLE_ENFORCE(drop_p > 0.0f && drop_p < 1.0f,
"'dropout_prob' must be between 0 and 1.");
});
AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false); AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0); AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
@ -78,8 +80,6 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) must not be null."); "Input(Out@GRAD) must not be null.");
PADDLE_ENFORCE_GE(ctx->Attrs().Get<float>("dropout_prob"), 0);
PADDLE_ENFORCE_LE(ctx->Attrs().Get<float>("dropout_prob"), 1);
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(x_dims, out_dims, PADDLE_ENFORCE_EQ(x_dims, out_dims,

Loading…
Cancel
Save