|
|
|
@ -18,9 +18,9 @@
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
class PreluOp : public framework::OperatorWithKernel {
|
|
|
|
|
class PReluOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
PreluOp(const std::string &type, const framework::VariableNameMap &inputs,
|
|
|
|
|
PReluOp(const std::string &type, const framework::VariableNameMap &inputs,
|
|
|
|
|
const framework::VariableNameMap &outputs,
|
|
|
|
|
const framework::AttributeMap &attrs)
|
|
|
|
|
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
|
|
|
@ -34,13 +34,13 @@ class PreluOp : public framework::OperatorWithKernel {
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// template <typename AttrType>
|
|
|
|
|
class PreluOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
PreluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
|
|
|
|
|
PReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("X", "The input tensor of prelu operator.");
|
|
|
|
|
AddOutput("Out", "The output tensor of prelu operator.");
|
|
|
|
|
AddComment(R"DOC(Prelu operator
|
|
|
|
|
AddComment(R"DOC(PRelu operator
|
|
|
|
|
|
|
|
|
|
The equation is:
|
|
|
|
|
f(x) = alpha * x , for x < 0
|
|
|
|
@ -52,7 +52,7 @@ f(x) = x , for x >= 0
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// The operator to calculate gradients of a prelu operator.
|
|
|
|
|
class PreluGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
class PReluGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
@ -72,9 +72,9 @@ class PreluGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
|
|
|
|
|
REGISTER_OP(prelu, ops::PreluOp, ops::PreluOpMaker, prelu_grad,
|
|
|
|
|
ops::PreluGradOp);
|
|
|
|
|
REGISTER_OP(prelu, ops::PReluOp, ops::PReluOpMaker, prelu_grad,
|
|
|
|
|
ops::PReluGradOp);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(prelu,
|
|
|
|
|
ops::PreluKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|
ops::PReluKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(prelu_grad,
|
|
|
|
|
ops::PreluGradKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|
ops::PReluGradKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|