|
|
@ -29,6 +29,11 @@ class PReluOp : public framework::OperatorWithKernel {
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
void InferShape(const framework::InferShapeContext &ctx) const override {
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
|
|
|
|
auto *in = ctx.Input<framework::Tensor>("X");
|
|
|
|
auto *in = ctx.Input<framework::Tensor>("X");
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Alpha"),
|
|
|
|
|
|
|
|
"Input(Alpha) should not be null");
|
|
|
|
|
|
|
|
auto *alpha = ctx.Input<framework::Tensor>("Alpha");
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(alpha->numel() == 1, "Size of weight Alpha must be one.");
|
|
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
|
|
|
|
"Output(Out) should not be null");
|
|
|
|
"Output(Out) should not be null");
|
|
|
|
auto *out = ctx.Output<framework::LoDTensor>("Out");
|
|
|
|
auto *out = ctx.Output<framework::LoDTensor>("Out");
|
|
|
@ -36,15 +41,13 @@ class PReluOp : public framework::OperatorWithKernel {
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
template <typename AttrType>
|
|
|
|
|
|
|
|
class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
PReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
|
|
|
|
PReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
AddInput("X", "The input tensor of prelu operator.");
|
|
|
|
AddInput("X", "The input tensor of prelu operator.");
|
|
|
|
|
|
|
|
AddInput("Alpha", "The alpha weight of prelu operator.");
|
|
|
|
AddOutput("Out", "The output tensor of prelu operator.");
|
|
|
|
AddOutput("Out", "The output tensor of prelu operator.");
|
|
|
|
AddAttr<AttrType>("alpha", "The scaling factor alpha of prelu.")
|
|
|
|
|
|
|
|
.SetDefault(0.0);
|
|
|
|
|
|
|
|
AddComment(R"DOC(PRelu operator
|
|
|
|
AddComment(R"DOC(PRelu operator
|
|
|
|
|
|
|
|
|
|
|
|
The equation is:
|
|
|
|
The equation is:
|
|
|
@ -66,11 +69,15 @@ class PReluGradOp : public framework::OperatorWithKernel {
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
|
|
|
|
"Input(Out@GRAD) should not be null");
|
|
|
|
"Input(Out@GRAD) should not be null");
|
|
|
|
auto *X_grad =
|
|
|
|
auto *dx = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
|
|
|
|
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
|
|
|
|
auto *x = ctx.Input<framework::Tensor>("X");
|
|
|
|
auto *X = ctx.Input<framework::Tensor>("X");
|
|
|
|
|
|
|
|
|
|
|
|
auto *dalpha =
|
|
|
|
|
|
|
|
ctx.Output<framework::LoDTensor>(framework::GradVarName("Alpha"));
|
|
|
|
|
|
|
|
auto *alpha = ctx.Input<framework::Tensor>("Alpha");
|
|
|
|
|
|
|
|
|
|
|
|
X_grad->Resize(X->dims());
|
|
|
|
dx->Resize(x->dims());
|
|
|
|
|
|
|
|
dalpha->Resize(alpha->dims());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
@ -79,7 +86,7 @@ class PReluGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
|
|
|
|
|
|
|
REGISTER_OP(prelu, ops::PReluOp, ops::PReluOpMaker<float>, prelu_grad,
|
|
|
|
REGISTER_OP(prelu, ops::PReluOp, ops::PReluOpMaker, prelu_grad,
|
|
|
|
ops::PReluGradOp);
|
|
|
|
ops::PReluGradOp);
|
|
|
|
REGISTER_OP_CPU_KERNEL(prelu,
|
|
|
|
REGISTER_OP_CPU_KERNEL(prelu,
|
|
|
|
ops::PReluKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
ops::PReluKernel<paddle::platform::CPUPlace, float>);
|
|
|
|