/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/activation_op.h"
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/fluid/platform/port.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cudnn_helper.h"
#endif
DECLARE_bool(use_mkldnn);
namespace paddle {
namespace operators {
using paddle::framework::Tensor;
template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
return GradFunctor::FwdDeps() == kDepOut || GradFunctor::FwdDeps() == kNoDeps;
}
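// Note (editorial): in-place execution (Out reusing X's buffer) is only legal
// when the backward pass never needs to read the forward input X, i.e. when
// the grad functor depends solely on Out (kDepOut) or on nothing (kNoDeps).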
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
class OP_NAME##OpMaker \
: public ::paddle::framework::OpProtoAndCheckerMaker { \
public: \
void Make() override { \
AddInput("X", "Input of " #OP_NAME " operator"); \
AddOutput("Out", "Output of " #OP_NAME " operator"); \
AddAttr<bool>("use_mkldnn", \
"(bool, default false) Only used in mkldnn kernel") \
.SetDefault(false); \
AddAttr<bool>("use_cudnn", \
"(bool, default false) Only used in cudnn kernel, need " \
"install cudnn") \
.SetDefault(false); \
AddAttr<bool>( \
"is_test", \
"(bool, default false) Set to true for inference only, false " \
"for training. Some layers may run faster when this is true.") \
.SetDefault(false); \
AddComment(OP_COMMENT); \
} \
}
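// For illustration (editorial note): REGISTER_ACTIVATION_OP_MAKER(Sigmoid,
// SigmoidDoc); expands, roughly, to:
//
//   class SigmoidOpMaker : public ::paddle::framework::OpProtoAndCheckerMaker {
//    public:
//     void Make() override {
//       AddInput("X", "Input of Sigmoid operator");
//       AddOutput("Out", "Output of Sigmoid operator");
//       // ... use_mkldnn, use_cudnn and is_test attrs (default false) ...
//       AddComment(SigmoidDoc);
//     }
//   };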
template <ActBwdOpFwdDeps kDepValue>
class ActivationGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType(ForwardOpType() + "_grad");
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
if ((static_cast<int>(kDepValue) &
static_cast<int>(ActBwdOpFwdDeps::kDepX)) ||
FLAGS_use_mkldnn || (op->HasAttr("use_mkldnn") &&
boost::get<bool>(op->GetAttr("use_mkldnn")))) {
op->SetInput("X", Input("X"));
}
if (static_cast<int>(kDepValue) &
static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
op->SetInput("Out", Output("Out"));
}
return op;
}
};
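// Editorial note: the maker above generates the "<op>_grad" desc, taking
// Out@GRAD as input and producing X@GRAD. The forward input X is forwarded
// only when the grad functor reports kDepX or when MKL-DNN is requested (via
// the global flag or the op's "use_mkldnn" attribute); the forward output Out
// is forwarded only when the functor reports kDepOut.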
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel& oper,
const std::string& name) {
framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
// FIXME(liuwei1031) temporarily disable the code to unblock users
// TODO(liuwei1031) figure out the reason behind
// https://github.com/PaddlePaddle/Paddle/issues/16096
// and re-enable this in the future
// #ifdef PADDLE_WITH_CUDA
// auto it1 = oper.Attrs().find("use_cudnn");
// if (it1 != oper.Attrs().end() && platform::CanCUDNNBeUsed(ctx)) {
// library = framework::LibraryType::kCUDNN;
// }
// #endif
#ifdef PADDLE_WITH_MKLDNN
auto it = oper.Attrs().find("use_mkldnn");
if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(
framework::GetDataTypeOfVar(ctx.InputVar(name)), ctx.GetPlace(), layout,
library);
}
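// Editorial note: kernel selection above takes the data type from the input
// variable named by `name`, then picks the library: the cuDNN branch is
// temporarily disabled (issue 16096), and the MKL-DNN library/layout pair is
// chosen only when the op carries a "use_mkldnn" attribute and MKL-DNN can be
// used in this context; otherwise a plain kernel with kAnyLayout is returned.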
class ActivationOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "X");
}
};
class ActivationOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
}
};
class ActivationOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto out_grad_name = framework::GradVarName("Out");
ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, framework::GradVarName("Out"));
}
};
UNUSED constexpr char SigmoidDoc[] = R"DOC(
Sigmoid Activation Operator.
$$out = \\frac{1}{1 + e^{-x}}$$
)DOC";
UNUSED constexpr char LogSigmoidDoc[] = R"DOC(
LogSigmoid Activation Operator.
$$out = \\log \\frac{1}{1 + e^{-x}}$$
)DOC";
UNUSED constexpr char ExpDoc[] = R"DOC(
Exp Activation Operator.
$out = e^x$
)DOC";
UNUSED constexpr char ReluDoc[] = R"DOC(
Relu Activation Operator.
$out = \max(x, 0)$
)DOC";
UNUSED constexpr char GeluDoc[] = R"DOC(
Gelu Activation Operator.
$out = \\frac{1 + erf(\\frac{x}{\\sqrt{2}})}{2} x$
)DOC";
UNUSED constexpr char TanhDoc[] = R"DOC(
Tanh Activation Operator.
$$out = \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
)DOC";
UNUSED constexpr char TanhShrinkDoc[] = R"DOC(
TanhShrink Activation Operator.
$$out = x - \\frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
)DOC";
UNUSED constexpr char SqrtDoc[] = R"DOC(
Sqrt Activation Operator.
Please make sure the input is legal. When the input is a negative value close to zero,
you should add a small epsilon (1e-12) to it to avoid a negative number caused by numerical error.
$out = \sqrt{x}$
)DOC";
UNUSED constexpr char RsqrtDoc[] = R"DOC(
Rsqrt Activation Operator.
Please make sure the input is legal to avoid numeric errors.
$out = \frac{1}{\sqrt{x}}$
)DOC";
UNUSED constexpr char AbsDoc[] = R"DOC(
Abs Activation Operator.
$out = |x|$
)DOC";
UNUSED constexpr char CeilDoc[] = R"DOC(
Ceil Activation Operator.
$out = \left \lceil x \right \rceil$
)DOC";
UNUSED constexpr char FloorDoc[] = R"DOC(
Floor Activation Operator.
$out = \left \lfloor x \right \rfloor$
)DOC";
UNUSED constexpr char CosDoc[] = R"DOC(
Cosine Activation Operator.
$out = cos(x)$
)DOC";
UNUSED constexpr char SinDoc[] = R"DOC(
Sine Activation Operator.
$out = sin(x)$
)DOC";
UNUSED constexpr char RoundDoc[] = R"DOC(
Round Activation Operator.
$out = [x]$
)DOC";
UNUSED constexpr char ReciprocalDoc[] = R"DOC(
Reciprocal Activation Operator.
$$out = \\frac{1}{x}$$
)DOC";
UNUSED constexpr char LogDoc[] = R"DOC(
Log Activation Operator.
$out = \ln(x)$
Natural logarithm of x.
)DOC";
UNUSED constexpr char SquareDoc[] = R"DOC(
Square Activation Operator.
$out = x^2$
)DOC";
UNUSED constexpr char SoftplusDoc[] = R"DOC(
Softplus Activation Operator.
$out = \ln(1 + e^{x})$
)DOC";
UNUSED constexpr char SoftsignDoc[] = R"DOC(
Softsign Activation Operator.
$$out = \\frac{x}{1 + |x|}$$
)DOC";
class AcosOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of acos operator");
AddOutput("Out", "Output of acos operator");
AddComment(R"DOC(
Arccosine Activation Operator.
$$out = \cos^{-1}(x)$$
)DOC");
}
};
class AsinOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of asin operator");
AddOutput("Out", "Output of asin operator");
AddComment(R"DOC(
Arcsine Activation Operator.
$$out = \sin^{-1}(x)$$
)DOC");
}
};
class AtanOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of atan operator");
AddOutput("Out", "Output of atan operator");
AddComment(R"DOC(
Arctangent Activation Operator.
$$out = \tan^{-1}(x)$$
)DOC");
}
};
class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of LeakyRelu operator");
AddOutput("Out", "Output of LeakyRelu operator");
AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddComment(R"DOC(
LeakyRelu Activation Operator.
$out = \max(x, \alpha * x)$
)DOC");
}
};
class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of Softshrink operator");
AddOutput("Out", "Output of Softshrink operator");
AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
AddComment(R"DOC(
:strong:`Softshrink Activation Operator`
.. math::
out = \begin{cases}
x - \lambda, \text{if } x > \lambda \\
x + \lambda, \text{if } x < -\lambda \\
0, \text{otherwise}
\end{cases}
)DOC");
}
};
class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of HardShrink operator");
AddOutput("Out", "Output of HardShrink operator");
AddAttr<float>("threshold",
"The value of threshold for HardShrink. [default: 0.5]")
.SetDefault(0.5f);
AddComment(R"DOC(
:strong:`HardShrink activation operator`
.. math::
out = \begin{cases}
x, \text{if } x > \lambda \\
x, \text{if } x < -\lambda \\
0, \text{otherwise}
\end{cases}
)DOC");
}
};
class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of BRelu operator");
AddOutput("Out", "Output of BRelu operator");
AddAttr<float>("t_min", "The min marginal value of BRelu")
.SetDefault(static_cast<float>(0));
AddAttr<float>("t_max", "The max marginal value of BRelu")
.SetDefault(static_cast<float>(24));
AddComment(R"DOC(
BRelu Activation Operator.
$out = \min(\max(x, t_{min}), t_{max})$
)DOC");
}
};
class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of SoftRelu operator");
AddOutput("Out", "Output of SoftRelu operator");
AddAttr<float>("threshold", "The threshold value of SoftRelu")
.SetDefault(40.0f);
AddComment(R"DOC(
SoftRelu Activation Operator.
$out = \ln(1 + \exp(\max(\min(x, threshold), -threshold)))$
)DOC");
}
};
class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of ELU operator");
AddOutput("Out", "Output of ELU operator");
AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
AddComment(R"DOC(
ELU Activation Operator.
Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.
$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$
)DOC");
}
};
class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of Relu6 operator");
AddOutput("Out", "Output of Relu6 operator");
AddAttr<float>("threshold", "The threshold value of Relu6")
.SetDefault(6.0f);
AddComment(R"DOC(
Relu6 Activation Operator.
$out = \min(\max(0, x), 6)$
)DOC");
}
};
class PowOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of Pow operator");
AddInput("FactorTensor",
"(Tensor<float>, optional). If provided, pow will use this"
"The shape of FactorTensor MUST BE [1]."
"it has higher priority than attr(factor).")
.AsDispensable();
AddOutput("Out", "Output of Pow operator");
AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
AddComment(R"DOC(
Pow Activation Operator.
$out = x^{factor}$
)DOC");
}
};
class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of STanh operator");
AddOutput("Out", "Output of STanh operator");
AddAttr<float>("scale_a", "The scale parameter of a for the input")
.SetDefault(2.0f / 3.0f);
AddAttr<float>("scale_b", "The scale parameter of b for the input")
.SetDefault(1.7159f);
AddComment(R"DOC(
STanh Activation Operator.
$$out = b * \\frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
)DOC");
}
};
class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of ThresholdedRelu operator");
AddOutput("Out", "Output of ThresholdedRelu operator");
AddAttr<float>("threshold",
"The threshold location of activation. [default 1.0].")
.SetDefault(1.0f);
AddComment(R"DOC(
:strong:`ThresholdedRelu activation operator`
.. math::
out = \begin{cases}
x, \text{if } x > threshold \\
0, \text{otherwise}
\end{cases}
)DOC");
}
};
class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of HardSigmoid operator");
AddOutput("Out", "Output of HardSigmoid operator");
AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
.SetDefault(0.2f);
AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
.SetDefault(0.5f);
AddComment(R"DOC(
HardSigmoid Activation Operator.
Segment-wise linear approximation of sigmoid (https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.
$out = \max(0, \min(1, slope * x + offset))$
The slope should be positive. The offset can be either positive or negative.
The default slope and offset are set according to the above reference.
It is recommended to use the defaults for this activation.
)DOC");
}
};
class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of Swish operator");
AddOutput("Out", "Output of Swish operator");
AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
AddComment(R"DOC(
Swish Activation Operator.
$$out = \\frac{x}{1 + e^{-\beta x}}$$
)DOC");
}
};
class HardSwishOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of HardSwish operator");
AddOutput("Out", "Output of HardSwish operator");
AddAttr<float>("threshold", "The threshold parameter of HardSwish operator")
.SetDefault(6.0f);
AddAttr<float>("scale", "The scale parameter of HardSwish operator")
.SetDefault(6.0f);
AddAttr<float>("offset", "The offset parameter of HardSwish operator")
.SetDefault(3.0f);
AddComment(R"DOC(
HardSwish Activation Operator.
The hard version of swish (https://arxiv.org/pdf/1905.02244.pdf).
$out = \frac{x * (min(max(0, x+offset), threshold))}{scale}$
The threshold and scale should be positive. The offset can be either positive or negative.
The default parameters are set according to the above reference.
It is recommended to use the defaults for this activation.
)DOC");
}
};
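// Worked example (editorial), using the default threshold = 6, scale = 6 and
// offset = 3: hard_swish(1) = 1 * min(max(0, 1 + 3), 6) / 6 = 4 / 6 ~= 0.667,
// while hard_swish(-4) = 0 because max(0, -4 + 3) = 0.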
REGISTER_ACTIVATION_OP_MAKER(Sigmoid, SigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(LogSigmoid, LogSigmoidDoc);
REGISTER_ACTIVATION_OP_MAKER(Exp, ExpDoc);
REGISTER_ACTIVATION_OP_MAKER(Relu, ReluDoc);
REGISTER_ACTIVATION_OP_MAKER(Gelu, GeluDoc);
REGISTER_ACTIVATION_OP_MAKER(Tanh, TanhDoc);
REGISTER_ACTIVATION_OP_MAKER(TanhShrink, TanhShrinkDoc);
REGISTER_ACTIVATION_OP_MAKER(Sqrt, SqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Rsqrt, RsqrtDoc);
REGISTER_ACTIVATION_OP_MAKER(Abs, AbsDoc);
REGISTER_ACTIVATION_OP_MAKER(Ceil, CeilDoc);
REGISTER_ACTIVATION_OP_MAKER(Floor, FloorDoc);
REGISTER_ACTIVATION_OP_MAKER(Cos, CosDoc);
REGISTER_ACTIVATION_OP_MAKER(Sin, SinDoc);
REGISTER_ACTIVATION_OP_MAKER(Round, RoundDoc);
REGISTER_ACTIVATION_OP_MAKER(Reciprocal, ReciprocalDoc);
REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);
template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
if (ctx->HasOutput("DX")) {
ctx->ShareDim("X", "DX");
ctx->ShareLoD("X", "DX");
}
if (ctx->HasOutput("DDOut")) {
ctx->ShareDim("X", "DDOut");
ctx->ShareLoD("X", "DDOut");
}
}
if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
if (ctx->HasOutput("DOut")) {
ctx->ShareDim("Out", "DOut");
ctx->ShareLoD("Out", "DOut");
}
if (ctx->HasOutput("DDOut")) {
ctx->ShareDim("Out", "DDOut");
ctx->ShareLoD("Out", "DDOut");
}
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "DDX");
}
};
template <ActBwdOpFwdDeps kDepValue>
class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
if (ctx->HasOutput("DDOut")) {
ctx->ShareDim("X", "DDOut");
ctx->ShareLoD("X", "DDOut");
}
}
if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
if (ctx->HasOutput("DDOut")) {
ctx->ShareDim("Out", "DDOut");
ctx->ShareLoD("Out", "DDOut");
}
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "DDX");
}
};
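// Editorial note: ActivationOpDoubleGrad handles the full double-grad
// signature (it shapes DX/DOut as well as DDOut), while
// ActivationOpDoubleGrad2 only shapes DDOut; in both, kDepValue selects
// whether shapes are shared from the forward X or from Out, and the kernel
// type follows the DDX input.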
//
// ReluGrad: dx = dy if y >= 0 else 0
// ReluGradGrad: ddy = ddx if y >= 0 else 0
//
class ReluDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
public:
using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
auto* op = new ::paddle::framework::OpDesc();
op->SetType("relu_grad_grad");
// input1: Out
op->SetInput("Out", Input("Out"));
// input2: ddx
op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
op->SetAttrMap(Attrs());
// output: ddy
op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
return std::unique_ptr<::paddle::framework::OpDesc>(op);
}
};
// leaky_relu Grad: dx=dy if y>=0 else alpha * dy
// leaky_relu GradGrad: ddy=ddx if y>=0 else alpha * ddx
class LeakyReluDoubleGradMaker
: public ::paddle::framework::SingleGradOpDescMaker {
public:
using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
auto* op = new ::paddle::framework::OpDesc();
op->SetType("leaky_relu_grad_grad");
// input1: Out
op->SetInput("Out", Input("Out"));
// X@GRAD@GRAD: ddx
op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
op->SetAttrMap(Attrs());
// Out@GRAD@GRAD: ddy
op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
return std::unique_ptr<::paddle::framework::OpDesc>(op);
}
};
// sqrt Grad: dx = 0.5 * dy / y
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
class SqrtDoubleGradMaker : public ::paddle::framework::SingleGradOpDescMaker {
public:
using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
auto* op = new ::paddle::framework::OpDesc();
op->SetType("sqrt_grad_grad");
op->SetInput("Out", Input("Out"));
op->SetInput("DX", Output(framework::GradVarName("X")));
op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
op->SetAttrMap(Attrs());
op->SetOutput("DOut", InputGrad("Out"));
op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
return std::unique_ptr<::paddle::framework::OpDesc>(op);
}
};
// square Grad: dx=2x*dy
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
class SquareDoubleGradMaker
: public ::paddle::framework::SingleGradOpDescMaker {
public:
using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
auto* op = new ::paddle::framework::OpDesc();
op->SetType("square_grad_grad");
op->SetInput("X", Input("X"));
// Out@GRAD: dy
op->SetInput("DOut", Input(framework::GradVarName("Out")));
// X@GRAD@GRAD: ddx
op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
op->SetAttrMap(Attrs());
// X@GRAD: dx
op->SetOutput("DX", InputGrad("X"));
// Out@GRAD@GRAD: ddy
op->SetOutput("DDOut", InputGrad(framework::GradVarName("Out")));
return std::unique_ptr<::paddle::framework::OpDesc>(op);
}
};
DECLARE_INPLACE_OP_INFERER(ActivationGradOpInplaceInference,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
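// Editorial note: this in-place rule declares that an activation grad op may
// reuse the Out@GRAD buffer for X@GRAD, so the two variables can share memory
// when the in-place pass decides to apply the substitution.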
class PowGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("pow_grad");
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetInput("FactorTensor", Input("FactorTensor"));
op->SetAttrMap(Attrs());
return op;
}
};
class PowOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "X");
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "FactorTensor") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
class PowOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto out_grad_name = framework::GradVarName("Out");
ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, framework::GradVarName("Out"));
}
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
if (var_name == "FactorTensor") {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
REGISTER_OPERATOR( \
KERNEL_TYPE, ops::ActivationOp, ops::OP_NAME##OpMaker, \
ops::ActivationOpInferVarType, \
ops::ActivationGradOpDescMaker<ops::grad_functor<float>::FwdDeps()>, \
std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \
::paddle::framework::SingleOpInplaceInToOut, \
void>::type); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationOpGrad, \
ops::ActivationGradOpInplaceInference);
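// For illustration (editorial note): a hypothetical invocation
// REGISTER_ACTIVATION_OP(tanh, Tanh, TanhFunctor, TanhGradFunctor) would
// register the forward "tanh" op with TanhOpMaker, the shared var-type
// inference and grad-desc maker, enable the single-op in-place pass only when
// CanInplaceAct<TanhGradFunctor<float>>() holds, and register "tanh_grad"
// with the X@GRAD <- Out@GRAD in-place rule.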
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor, \
grad_functor) \
REGISTER_OP_CPU_KERNEL( \
act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
ops::functor<float>>, \
ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
ops::functor<double>>); \
REGISTER_OP_CPU_KERNEL( \
act_type##_grad, \
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext, \
ops::grad_functor<float>>, \
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext, \
ops::grad_functor<double>>);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
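// Editorial note: FOR_EACH_ACTIVATION_OP applies the macro it receives to
// every (op_name, OpName, Functor, GradFunctor) entry in the activation list,
// so the two lines above register both the operator definitions and the
// float/double CPU kernels for all of the "plain" activations in one go.
// relu, leaky_relu, sqrt, square and pow are registered separately below
// because they need extra machinery (double-grad makers or FactorTensor).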
/* ========================== relu register ============================= */
REGISTER_OPERATOR(
relu, ops::ActivationOp, ops::ReluOpMaker, ops::ActivationOpInferVarType,
ops::ActivationGradOpDescMaker<ops::ReluGradFunctor<float>::FwdDeps()>,
paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference,
ops::ReluDoubleGradMaker);
REGISTER_OPERATOR(
relu_grad_grad,
ops::ActivationOpDoubleGrad2<ops::ReluGradFunctor<float>::FwdDeps()>);
REGISTER_ACTIVATION_CPU_KERNEL(relu, Relu, ReluFunctor, ReluGradFunctor);
REGISTER_OP_CPU_KERNEL(
relu_grad_grad,
ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
ops::ReluGradGradFunctor<float>>,
ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
ops::ReluGradGradFunctor<double>>,
ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
ops::ReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */
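// Editorial note: relu bypasses FOR_EACH_ACTIVATION_OP so that relu_grad can
// declare ReluDoubleGradMaker and so that relu_grad_grad can be registered
// with ActivationOpDoubleGrad2 plus CPU double-grad kernels for float, double
// and float16; the same pattern repeats for leaky_relu, sqrt and square below.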
/* ======================== leaky relu register ============================ */
REGISTER_OPERATOR(
leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker,
ops::ActivationOpInferVarType,
ops::ActivationGradOpDescMaker<ops::LeakyReluGradFunctor<float>::FwdDeps()>,
paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(leaky_relu_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference,
ops::LeakyReluDoubleGradMaker);
REGISTER_OPERATOR(
leaky_relu_grad_grad,
ops::ActivationOpDoubleGrad2<ops::LeakyReluGradFunctor<float>::FwdDeps()>);
REGISTER_ACTIVATION_CPU_KERNEL(leaky_relu, LeakyRelu, LeakyReluFunctor,
LeakyReluGradFunctor);
REGISTER_OP_CPU_KERNEL(
leaky_relu_grad_grad,
ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
ops::LeakyReluGradGradFunctor<float>>,
ops::ActivationDoubleGradKernel<plat::CPUDeviceContext,
ops::LeakyReluGradGradFunctor<double>>,
ops::ActivationDoubleGradKernel<
plat::CPUDeviceContext, ops::LeakyReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* =========================== sqrt register ============================= */
REGISTER_OPERATOR(
sqrt, ops::ActivationOp, ops::SqrtOpMaker, ops::ActivationOpInferVarType,
ops::ActivationGradOpDescMaker<ops::SqrtGradFunctor<float>::FwdDeps()>,
paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(sqrt_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference,
ops::SqrtDoubleGradMaker);
REGISTER_OPERATOR(
sqrt_grad_grad,
ops::ActivationOpDoubleGrad<ops::SqrtGradGradFunctor<float>::FwdDeps()>);
REGISTER_ACTIVATION_CPU_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
REGISTER_OP_CPU_KERNEL(
sqrt_grad_grad, ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
ops::SqrtGradGradFunctor<float>>,
ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
ops::SqrtGradGradFunctor<double>>,
ops::SqrtDoubleGradKernel<plat::CPUDeviceContext,
ops::SqrtGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* ========================== square register ============================ */
REGISTER_OPERATOR(
square, ops::ActivationOp, ops::SquareOpMaker,
ops::ActivationOpInferVarType,
ops::ActivationGradOpDescMaker<ops::SquareGradFunctor<float>::FwdDeps()>,
paddle::framework::SingleOpInplaceInToOut);
REGISTER_OPERATOR(square_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInference,
ops::SquareDoubleGradMaker);
REGISTER_OPERATOR(
square_grad_grad,
ops::ActivationOpDoubleGrad<ops::SquareGradGradFunctor<float>::FwdDeps()>);
REGISTER_ACTIVATION_CPU_KERNEL(square, Square, SquareFunctor,
SquareGradFunctor);
REGISTER_OP_CPU_KERNEL(
square_grad_grad,
ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
ops::SquareGradGradFunctor<float>>,
ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
ops::SquareGradGradFunctor<double>>,
ops::SquareDoubleGradKernel<plat::CPUDeviceContext,
ops::SquareGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* ========================== pow register ============================ */
REGISTER_OPERATOR(
pow, ops::PowOp, ops::PowOpMaker, ops::ActivationOpInferVarType,
ops::PowGradOpDescMaker,
std::conditional<ops::CanInplaceAct<ops::PowGradFunctor<float>>(),
::paddle::framework::SingleOpInplaceInToOut, void>::type);
REGISTER_OPERATOR(pow_grad, ops::PowOpGrad,
ops::ActivationGradOpInplaceInference);
REGISTER_OP_CPU_KERNEL(
pow, ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<float>>,
ops::PowKernel<plat::CPUDeviceContext, ops::PowFunctor<double>>);
REGISTER_OP_CPU_KERNEL(
pow_grad,
ops::PowGradKernel<plat::CPUDeviceContext, ops::PowGradFunctor<float>>,
ops::PowGradKernel<plat::CPUDeviceContext, ops::PowGradFunctor<double>>);
/* ========================================================================== */