|
|
|
@ -30,7 +30,7 @@ class DropoutOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
ctx->SetOutputDim("Out", x_dims);
|
|
|
|
|
if (ctx->Attrs().Get<bool>("is_training") == 1) {
|
|
|
|
|
if (ctx->Attrs().Get<bool>("is_training") == true) {
|
|
|
|
|
ctx->SetOutputDim("Mask", x_dims);
|
|
|
|
|
}
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Out");
|
|
|
|
@ -43,7 +43,7 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
DropoutOpMaker(framework::OpProto* proto,
|
|
|
|
|
framework::OpAttrChecker* op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddAttr<AttrType>("dropout_prob", "Probability of setting units to zero.")
|
|
|
|
|
AddAttr<float>("dropout_prob", "Probability of setting units to zero.")
|
|
|
|
|
.SetDefault(.5f);
|
|
|
|
|
AddAttr<bool>("is_training", "Whether in training phase.").SetDefault(true);
|
|
|
|
|
AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
|
|
|
|
@ -69,7 +69,7 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_training"), 1,
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_training"), true,
|
|
|
|
|
"GradOp is only callable when is_training is true");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
|
|
|
|
@ -77,8 +77,8 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) must not be null.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GE(ctx->Attrs().Get<AttrType>("dropout_prob"), 0);
|
|
|
|
|
PADDLE_ENFORCE_LE(ctx->Attrs().Get<AttrType>("dropout_prob"), 1);
|
|
|
|
|
PADDLE_ENFORCE_GE(ctx->Attrs().Get<float>("dropout_prob"), 0);
|
|
|
|
|
PADDLE_ENFORCE_LE(ctx->Attrs().Get<float>("dropout_prob"), 1);
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims, out_dims,
|
|
|
|
|