|
|
|
@ -26,7 +26,6 @@ class DropoutOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|
// validity check
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
|
|
|
|
|
PADDLE_ENFORCE_GE(ctx.Attr<float>("dropout_prob"), 0);
|
|
|
|
|
PADDLE_ENFORCE_LE(ctx.Attr<float>("dropout_prob"), 1);
|
|
|
|
@ -34,10 +33,11 @@ class DropoutOp : public framework::OperatorWithKernel {
|
|
|
|
|
PADDLE_ENFORCE(ctx.Attr<int>("is_training") == 0 ||
|
|
|
|
|
ctx.Attr<int>("is_training") == 1);
|
|
|
|
|
|
|
|
|
|
// resize
|
|
|
|
|
auto dims = ctx.Input<Tensor>("X")->dims();
|
|
|
|
|
ctx.Output<LoDTensor>("Out")->Resize(dims);
|
|
|
|
|
ctx.Output<LoDTensor>("Mask")->Resize(dims);
|
|
|
|
|
if (ctx.Attr<int>("is_training") == 1) {
|
|
|
|
|
ctx.Output<LoDTensor>("Mask")->Resize(dims);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -75,24 +75,27 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
|
// validity check
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx.Attr<int>("is_training"), 1,
|
|
|
|
|
"GradOp is only callable when is_training is true");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Mask"), "Mask must not be null.");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) must not be null.");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_GE(ctx.Attr<AttrType>("dropout_prob"), 0);
|
|
|
|
|
PADDLE_ENFORCE_LE(ctx.Attr<AttrType>("dropout_prob"), 1);
|
|
|
|
|
// TODO(xinghai-sun): remove this check after swtiching to bool
|
|
|
|
|
PADDLE_ENFORCE(ctx.Attr<int>("is_training") == 0 ||
|
|
|
|
|
ctx.Attr<int>("is_training") == 1);
|
|
|
|
|
auto x_dims = ctx.Input<Tensor>("X")->dims();
|
|
|
|
|
auto mask_dims = ctx.Input<Tensor>("Mask")->dims();
|
|
|
|
|
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims, out_dims,
|
|
|
|
|
"Dimensions of Input(X) and Out@Grad must be the same.");
|
|
|
|
|
auto mask_dims = ctx.Input<Tensor>("Mask")->dims();
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_dims, mask_dims,
|
|
|
|
|
"Dimensions of Input(X) and Mask must be the same.");
|
|
|
|
|
// resize
|
|
|
|
|
|
|
|
|
|
auto *x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
|
|
|
|
|
x_grad->Resize(x_dims);
|
|
|
|
|
}
|
|
|
|
|