/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <glog/logging.h>
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include <cmath>
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/float16.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif

namespace paddle {
namespace operators {

enum ActBwdOpFwdDeps {
  kNoDeps = 0x00,  // Do not need any forward input/output
  kDepX = 0x01,    // Only need forward input X
  kDepOut = 0x02,  // Only need forward output Out

  // Never add kDepXOut, because Out can always be calculated
  // from the forward input X in the backward part.
  // FIXME(zjl): but in MKLDNN abs, both X and Out are needed...
  // Developers should not rely on this enum value!
  kDepXOut = 0x03
};
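
// For example, a backward functor declares which forward tensors it needs via
// a static FwdDeps() member, and the extraction helpers below test those bits
// before fetching the corresponding variables. A rough sketch:
//
//   static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
//   ...
//   if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
//     /* forward output Out must be read from the scope */
//   }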
std::unique_ptr<std::unordered_set<std::string>> GetInplaceOpSet();

static bool IsInplace(const std::string& op) {
  static auto InplaceOpSet = GetInplaceOpSet();
  bool inplace = InplaceOpSet->count(op);
  // for a gradient op named "<base>_grad", check the base op instead
  const size_t kGradSuffixLen = 5;  // length of "_grad"
  if (op.size() > kGradSuffixLen &&
      op.compare(op.size() - kGradSuffixLen, kGradSuffixLen, "_grad") == 0) {
    inplace = InplaceOpSet->count(op.substr(0, op.size() - kGradSuffixLen));
  }
  return inplace;
}

/* The following operators can be used to process SelectedRows, because the
 * output of these operators for a zero input is zero too.
 */
static std::unordered_set<std::string> CanBeUsedBySelectedRows = {
    "abs", "abs_grad", "square", "square_grad", "sqrt", "sqrt_grad"};
|
|
|
|
inline void ExtractActivationTensor(const framework::ExecutionContext& context,
|
|
const framework::Tensor** X,
|
|
framework::Tensor** Out) {
|
|
auto x_var = context.InputVar("X");
|
|
auto out_var = context.OutputVar("Out");
|
|
PADDLE_ENFORCE(x_var != nullptr,
|
|
"Cannot get input Variable X, variable name = %s",
|
|
context.op().Input("X"));
|
|
PADDLE_ENFORCE(out_var != nullptr,
|
|
"Cannot get output Variable Out, variable name = %s",
|
|
context.op().Output("Out"));
|
|
if (CanBeUsedBySelectedRows.count(context.op().Type())) {
|
|
*X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var);
|
|
*Out = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
|
|
out_var);
|
|
} else {
|
|
*X = context.Input<framework::Tensor>("X");
|
|
*Out = context.Output<framework::Tensor>("Out");
|
|
}
|
|
|
|
PADDLE_ENFORCE(*Out != nullptr,
|
|
"Cannot get output tensor Out, variable name = %s",
|
|
context.op().Output("Out"));
|
|
}
|
|
|
|
template <ActBwdOpFwdDeps kDepValue>
|
|
inline void ExtractActivationGradTensor(
|
|
const framework::ExecutionContext& context, const framework::Tensor** X,
|
|
const framework::Tensor** Out, const framework::Tensor** dOut,
|
|
framework::Tensor** dX) {
|
|
auto out_grad_var = context.InputVar(framework::GradVarName("Out"));
|
|
auto x_grad_var = context.OutputVar(framework::GradVarName("X"));
|
|
const framework::Variable* out_var = nullptr;
|
|
|
|
if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
|
|
out_var = context.InputVar("Out");
|
|
PADDLE_ENFORCE(out_var != nullptr,
|
|
"Cannot get input Variable Out, variable name = %s",
|
|
context.op().Input("Out"));
|
|
}
|
|
PADDLE_ENFORCE(out_grad_var != nullptr,
|
|
"Cannot get input Variable %s, variable name = %s",
|
|
framework::GradVarName("Out"),
|
|
context.op().Input(framework::GradVarName("Out")));
|
|
PADDLE_ENFORCE(x_grad_var != nullptr,
|
|
"Cannot get output Variable %s, variable name = %s",
|
|
framework::GradVarName("X"),
|
|
context.op().Output(framework::GradVarName("X")));
|
|
|
|
if (CanBeUsedBySelectedRows.count(context.op().Type())) {
|
|
*dOut = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(
|
|
*out_grad_var);
|
|
*dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
|
|
x_grad_var);
|
|
|
|
if (out_var) {
|
|
*Out =
|
|
paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var);
|
|
} else {
|
|
*Out = *dOut; // fake out
|
|
}
|
|
|
|
} else {
|
|
*Out = context.Input<framework::Tensor>("Out");
|
|
*dOut = context.Input<framework::Tensor>(framework::GradVarName("Out"));
|
|
*dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
|
|
|
|
if (out_var) {
|
|
*Out = &(out_var->Get<framework::LoDTensor>());
|
|
} else {
|
|
*Out = *dOut; // fake out
|
|
}
|
|
}
|
|
|
|
PADDLE_ENFORCE(*dX != nullptr,
|
|
"Cannot get output tensor %s, variable name = %s",
|
|
framework::GradVarName("X"),
|
|
context.op().Output(framework::GradVarName("X")));
|
|
|
|
if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
|
|
auto x_var = context.InputVar("X");
|
|
PADDLE_ENFORCE(x_var != nullptr,
|
|
"Cannot get input tensor X, variable name = %s",
|
|
context.op().Input("X"));
|
|
if (CanBeUsedBySelectedRows.count(context.op().Type())) {
|
|
*X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var);
|
|
} else {
|
|
*X = context.Input<framework::Tensor>("X");
|
|
}
|
|
} else {
|
|
VLOG(10) << " Inplace activation of Op : " << context.op().Type();
|
|
*X = *dX;
|
|
}
|
|
}
|
|
|
|
template <typename DeviceContext, typename Functor>
|
|
class ActivationKernel
|
|
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
|
|
public:
|
|
using T = typename Functor::ELEMENT_TYPE;
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
const framework::Tensor* X = nullptr;
|
|
framework::Tensor* Out = nullptr;
|
|
ExtractActivationTensor(context, &X, &Out);
|
|
Out->mutable_data<T>(context.GetPlace());
|
|
|
|
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X));
|
|
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
|
|
auto* place =
|
|
context.template device_context<DeviceContext>().eigen_device();
|
|
Functor functor;
|
|
|
|
auto attrs = functor.GetAttrs();
|
|
for (auto& attr : attrs) {
|
|
*attr.second = context.Attr<float>(attr.first);
|
|
}
|
|
functor(*place, x, out);
|
|
}
|
|
};
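
// Usage sketch (illustrative only; the actual registration macros live in the
// corresponding .cc / .cu files): a concrete kernel is obtained by pairing
// this template with a functor, e.g.
//
//   using ReluCPUKernel =
//       ActivationKernel<platform::CPUDeviceContext, ReluFunctor<float>>;
//
// Compute() flattens X and Out into Eigen vectors, copies every float
// attribute declared by Functor::GetAttrs() out of the op description, and
// then evaluates the functor on the device's Eigen device.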
|
|
|
|
template <typename DeviceContext, typename Functor>
|
|
class ActivationGradKernel
|
|
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
|
|
public:
|
|
using T = typename Functor::ELEMENT_TYPE;
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
const framework::Tensor *X, *Out, *dOut;
|
|
framework::Tensor* dX = nullptr;
|
|
X = Out = dOut = nullptr;
|
|
ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut,
|
|
&dX);
|
|
dX->mutable_data<T>(context.GetPlace());
|
|
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
|
|
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
|
|
auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
|
|
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X));
|
|
auto* place =
|
|
context.template device_context<DeviceContext>().eigen_device();
|
|
Functor functor;
|
|
auto attrs = functor.GetAttrs();
|
|
for (auto& attr : attrs) {
|
|
*attr.second = context.Attr<float>(attr.first);
|
|
}
|
|
functor(*place, x, out, dout, dx);
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct BaseActivationFunctor {
|
|
using ELEMENT_TYPE = T;
|
|
|
|
using AttrPair = std::vector<std::pair<const char*, float*>>;
|
|
|
|
AttrPair GetAttrs() { return AttrPair(); }
|
|
|
|
  /* NOTE(*): Out can reuse the memory of X when the gradient does not depend
     on X. For example, sigmoid's gradient does not involve x, so its output
     can reuse the input memory; but abs's gradient does use x, so abs cannot
     be computed in place.
   */
|
|
bool Inplace() const { return false; }
|
|
};
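
// Functors that carry scalar attributes override GetAttrs() to expose
// {name, float*} pairs; the kernels above then fill those floats from the op's
// attributes before invoking the functor. A minimal, hypothetical example
// (for illustration only, not an operator registered by this file):
//
//   template <typename T>
//   struct ScaleFunctor : public BaseActivationFunctor<T> {
//     float factor;
//     typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
//       return {{"factor", &factor}};
//     }
//     template <typename Device, typename X, typename Out>
//     void operator()(Device d, X x, Out out) const {
//       out.device(d) = x * static_cast<T>(factor);
//     }
//   };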
|
|
|
|
// sigmoid(x) = 1 / (1 + exp(-x))
|
|
template <typename T>
|
|
struct SigmoidFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
dx.device(d) = dout * out * (static_cast<T>(1) - out);
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
|
|
};
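
// Since d/dx sigmoid(x) = sigmoid(x) * (1 - sigmoid(x)) = out * (1 - out),
// the backward pass only needs the forward output, hence FwdDeps() == kDepOut.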
|
|
|
|
// Originally: logsigmoid(x) = -log (1 + exp(-x))
// For numerical stability, we can use the log-sum-exp trick:
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
// We can rewrite the above equation as:
// out = -log( exp(0) + exp(-x)) [since exp(0) = 1]
//     = -log( exp(max(-x, 0) - max(-x, 0)) + exp(-x + max(-x, 0) - max(-x, 0)))
//     = -log( exp(max(-x, 0)) * exp(-max(-x, 0)) + exp(max(-x, 0)) * exp(-x -
//             max(-x, 0)))
//     = -log( exp(max(-x, 0)) * (exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
//     = -log( exp(max(-x, 0))) - log(exp(-max(-x, 0)) + exp(-x - max(-x, 0)))
//
// Hence, logsigmoid(x) = - (max(-x, 0) + log(exp(-max(-x, 0))
//                        + exp(-x - max(-x, 0))))
|
|
template <typename T>
|
|
struct LogSigmoidFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
auto temp = (-x).cwiseMax(static_cast<T>(0)); // temp = max(-x, 0)
|
|
out.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log());
|
|
}
|
|
};
|
|
|
|
// Originally: f' = exp(-x) / (1 + exp(-x))
|
|
// For numerical stability: f' = exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) +
|
|
// exp(-x - max(-x, 0)))
|
|
template <typename T>
|
|
struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
auto temp = (-x).cwiseMax(static_cast<T>(0)); // temp = max(-x, 0)
|
|
dx.device(d) =
|
|
dout * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp()));
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
};
|
|
|
|
// exp(x) = e^x
|
|
template <typename T>
|
|
struct ExpFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) = x.exp();
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct ExpGradFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
dx.device(d) = dout * out;
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
|
|
};
|
|
|
|
// relu(x) = max(x, 0)
|
|
template <typename T>
|
|
struct ReluFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) = x.cwiseMax(static_cast<T>(0));
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct ReluGradFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
dx.device(d) = dout * (out > static_cast<T>(0)).template cast<T>();
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
|
|
};
|
|
|
|
// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
|
|
template <typename T>
|
|
struct GeluFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
// Because the execution or device context cannot be passed in here, keep the
// preprocessor macro so that NVCC/CUDA builds fall back to the Eigen path.
|
|
#if defined(PADDLE_WITH_MKLML) && !defined(_WIN32) && !defined(__APPLE__) && \
|
|
!defined(__OSX__) && !defined(PADDLE_WITH_CUDA)
|
|
auto x_data = x.data();
|
|
auto out_data = out.data();
|
|
int n = std::min(x.size(), out.size());
|
|
|
|
std::memset(out_data, 0, n * sizeof(T));
|
|
math::CBlas<T>::AXPY(n, static_cast<T>(M_SQRT1_2), x_data, 1, out_data, 1);
|
|
math::CBlas<T>::VMERF(n, out_data, out_data, VML_LA);
|
|
for (int i = 0; i < n; i++) {
|
|
out_data[i] += static_cast<T>(1);
|
|
}
|
|
math::CBlas<T>::VMUL(n, x_data, out_data, out_data);
|
|
for (int i = 0; i < n; i++) {
|
|
out_data[i] *= static_cast<T>(0.5);
|
|
}
|
|
#else
|
|
auto temp = (x * static_cast<T>(M_SQRT1_2)).erf();
|
|
out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
|
|
#endif
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct GeluGradFunctor : BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
auto first = static_cast<T>(0.5) *
|
|
(static_cast<T>(1) + ((x * static_cast<T>(M_SQRT1_2)).erf()));
|
|
|
|
auto second = static_cast<T>(0.5 * M_2_SQRTPI * M_SQRT1_2) * x *
|
|
(-static_cast<T>(0.5) * x.square()).exp();
|
|
dx.device(d) = dout * (first + second);
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
};
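
// Derivation of the two terms above: writing gelu(x) = x * Phi(x) with
// Phi(x) = 0.5 * (1 + erf(x / sqrt(2))) and phi(x) = exp(-x^2 / 2) / sqrt(2 * pi),
//   d/dx gelu(x) = Phi(x) + x * phi(x),
// where `first` is Phi(x) and `second` is x * phi(x); the constant
// 0.5 * M_2_SQRTPI * M_SQRT1_2 equals 1 / sqrt(2 * pi).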
|
|
|
|
// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
|
|
template <typename T>
|
|
struct TanhFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) = x.tanh();
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct TanhGradFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
dx.device(d) = dout * (static_cast<T>(1) - out * out);
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
|
|
};
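
// Since d/dx tanh(x) = 1 - tanh(x)^2 = 1 - out * out, only the forward output
// is needed in the backward pass (kDepOut).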
|
|
|
|
// tanhshrink(x) = x - tanh(x)
|
|
// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
|
|
template <typename T>
|
|
struct TanhShrinkFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) = x - x.tanh();
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
dx.device(d) = dout * (x.tanh() * x.tanh());
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
};
|
|
|
|
// hard_shrink(x) = x, if x > threshold or x < -threshold; 0 otherwise
|
|
template <typename T>
|
|
struct HardShrinkFunctor : public BaseActivationFunctor<T> {
|
|
float threshold;
|
|
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"threshold", &threshold}};
|
|
}
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
|
|
auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
|
|
out.device(d) = x * (temp1 + temp2);
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
|
|
float threshold;
|
|
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"threshold", &threshold}};
|
|
}
|
|
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
|
|
auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
|
|
dx.device(d) = dout * (temp1 + temp2).template cast<T>();
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
};
|
|
|
|
// softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0
|
|
// otherwise
|
|
template <typename T>
|
|
struct SoftShrinkFunctor : public BaseActivationFunctor<T> {
|
|
float lambda;
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"lambda", &lambda}};
|
|
}
|
|
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
auto lambdaT = static_cast<T>(lambda);
|
|
auto temp1 = (x > lambdaT).template cast<T>().eval();
|
|
auto temp2 = (x < -lambdaT).template cast<T>().eval();
|
|
out.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT);
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
|
|
float lambda;
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"lambda", &lambda}};
|
|
}
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
auto lambdaT = static_cast<T>(lambda);
|
|
auto temp1 = (x > lambdaT).template cast<T>().eval();
|
|
auto temp2 = (x < -lambdaT).template cast<T>().eval();
|
|
dx.device(d) = dout * (temp1 + temp2).template cast<T>();
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
};
|
|
|
|
// sqrt(x) = x^(1/2)
|
|
template <typename T>
|
|
struct SqrtFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) = x.sqrt();
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct SqrtGradFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
dx.device(d) = static_cast<T>(0.5) * dout / out;
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
|
|
};
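
// Derivation: out = sqrt(x), so d(out)/dx = 1 / (2 * sqrt(x)) = 0.5 / out,
// which is why the backward pass depends only on Out (kDepOut).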
|
|
|
|
// rsqrt(x) = x^(-1/2)
|
|
template <typename T>
|
|
struct RsqrtFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) = x.rsqrt();
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct RsqrtGradFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
dx.device(d) = static_cast<T>(-0.5) * dout * out * out * out;
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
|
|
};
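
// Derivation: out = x^(-1/2), so d(out)/dx = -0.5 * x^(-3/2) = -0.5 * out^3,
// hence the backward pass depends only on Out (kDepOut).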
|
|
|
|
// ceil(x) = ceiling(x)
|
|
template <typename T>
|
|
struct CeilFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) = x.ceil();
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct ZeroGradFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
dx.device(d) = static_cast<T>(0) * out;
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kNoDeps; }
|
|
};
|
|
|
|
// floor(x) = flooring(x)
|
|
template <typename T>
|
|
struct FloorFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) = x.floor();
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct Sine {
|
|
HOSTDEVICE T operator()(const T& val) const { return sin(val); }
|
|
};
|
|
|
|
template <>
|
|
struct Sine<platform::float16> {
|
|
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
|
|
return platform::float16(sin(static_cast<float>(val)));
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct Cosine {
|
|
HOSTDEVICE T operator()(const T& val) const { return cos(val); }
|
|
};
|
|
|
|
template <>
|
|
struct Cosine<platform::float16> {
|
|
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
|
|
return platform::float16(cos(static_cast<float>(val)));
|
|
}
|
|
};
|
|
|
|
// cosine'(x) = -sin(x)
|
|
template <typename T>
|
|
struct CosGradFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
dx.device(d) = -dout * x.unaryExpr(Sine<T>());
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
};
|
|
|
|
// cosine(x) = cos(x)
|
|
template <typename T>
|
|
struct CosFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) = x.unaryExpr(Cosine<T>());
|
|
}
|
|
};
|
|
|
|
// sine'(x) = cos(x)
|
|
template <typename T>
|
|
struct SinGradFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
dx.device(d) = dout * x.unaryExpr(Cosine<T>());
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
};
|
|
|
|
// sine(x) = sin(x)
|
|
template <typename T>
|
|
struct SinFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) = x.unaryExpr(Sine<T>());
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct Acos {
|
|
HOSTDEVICE T operator()(const T& val) const { return acos(val); }
|
|
};
|
|
|
|
template <>
|
|
struct Acos<platform::float16> {
|
|
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
|
|
return platform::float16(acos(static_cast<float>(val)));
|
|
}
|
|
};
|
|
|
|
// Acos(x) = acos(x)
|
|
template <typename T>
|
|
struct AcosFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) = x.unaryExpr(Acos<T>());
|
|
}
|
|
};
|
|
|
|
// acos'(x) = -1/sqrt(1-x^2)
|
|
template <typename T>
|
|
struct AcosGradFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
dx.device(d) =
|
|
-dout * static_cast<T>(1) / (static_cast<T>(1) - x.square()).sqrt();
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
};
|
|
|
|
template <typename T>
|
|
struct Asin {
|
|
HOSTDEVICE T operator()(const T& val) const { return asin(val); }
|
|
};
|
|
|
|
template <>
|
|
struct Asin<platform::float16> {
|
|
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
|
|
return platform::float16(asin(static_cast<float>(val)));
|
|
}
|
|
};
|
|
|
|
// Asin(x) = asin(x)
|
|
template <typename T>
|
|
struct AsinFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) = x.unaryExpr(Asin<T>());
|
|
}
|
|
};
|
|
|
|
// asin'(x) = 1/sqrt(1-x^2)
|
|
template <typename T>
|
|
struct AsinGradFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
dx.device(d) =
|
|
dout * static_cast<T>(1) / (static_cast<T>(1) - x.square()).sqrt();
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
};
|
|
|
|
template <typename T>
|
|
struct Atan {
|
|
HOSTDEVICE T operator()(const T& val) const { return atan(val); }
|
|
};
|
|
|
|
template <>
|
|
struct Atan<platform::float16> {
|
|
HOSTDEVICE platform::float16 operator()(const platform::float16& val) const {
|
|
return platform::float16(atan(static_cast<float>(val)));
|
|
}
|
|
};
|
|
|
|
// Atan(x) = atan(x)
|
|
template <typename T>
|
|
struct AtanFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) = x.unaryExpr(Atan<T>());
|
|
}
|
|
};
|
|
|
|
// atan'(x) = 1 / (1 + x^2)
|
|
template <typename T>
|
|
struct AtanGradFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
dx.device(d) = dout * static_cast<T>(1) / (static_cast<T>(1) + x.square());
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
};
|
|
|
|
// round(x) = [x]
|
|
template <typename T>
|
|
struct RoundFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) = x.round();
|
|
}
|
|
};
|
|
|
|
// abs(x) = |x|
|
|
template <typename T>
|
|
struct AbsFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) = x.abs();
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct AbsGradFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
dx.device(d) = dout * x.sign();
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepXOut; }
|
|
};
|
|
|
|
// reciprocal(x) = 1 / x
|
|
template <typename T>
|
|
struct ReciprocalFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) = static_cast<T>(1) / x;
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
dx.device(d) = dout * static_cast<T>(-1) * out * out;
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
|
|
};
|
|
|
|
// log(x) = natural logarithm of x
|
|
template <typename T>
|
|
struct LogFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) = x.log();
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct LogGradFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
dx.device(d) = dout * (static_cast<T>(1) / x);
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
};
|
|
|
|
// square(x) = x^2
|
|
template <typename T>
|
|
struct SquareFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) = x.square();
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct SquareGradFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
dx.device(d) = dout * static_cast<T>(2) * x;
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
};
|
|
|
|
template <typename T>
|
|
struct BReluFunctor : public BaseActivationFunctor<T> {
|
|
float t_min;
|
|
float t_max;
|
|
|
|
  // NOTE: GetAttrs here explicitly hides `BaseActivationFunctor<T>::GetAttrs`;
  // virtual polymorphism is deliberately avoided for speed.
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"t_min", &t_min}, {"t_max", &t_max}};
|
|
}
|
|
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) =
|
|
x.cwiseMax(static_cast<T>(t_min)).cwiseMin(static_cast<T>(t_max));
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct BReluGradFunctor : public BaseActivationFunctor<T> {
|
|
float t_min;
|
|
float t_max;
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"t_min", &t_min}, {"t_max", &t_max}};
|
|
}
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
dx.device(d) = dout *
|
|
((x > static_cast<T>(t_min)) * (x < static_cast<T>(t_max)))
|
|
.template cast<T>();
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
};
|
|
|
|
// relu6(x) = min(max(0, x), 6)
|
|
template <typename T>
|
|
struct Relu6Functor : public BaseActivationFunctor<T> {
|
|
float threshold;
|
|
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"threshold", &threshold}};
|
|
}
|
|
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) =
|
|
x.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(threshold));
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct Relu6GradFunctor : public BaseActivationFunctor<T> {
|
|
float threshold;
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"threshold", &threshold}};
|
|
}
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
dx.device(d) =
|
|
dout *
|
|
((out > static_cast<T>(0)) * (out < static_cast<T>(threshold)))
|
|
.template cast<T>();
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
|
|
};
|
|
|
|
// softplus(x) = log(1 + exp(x))
|
|
// When x is a very large positive number, exp(x) may explode to inf,
|
|
// Using trick below for numerical stability
|
|
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
|
|
// Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0)))
|
|
template <typename T>
|
|
struct SoftplusFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) {
|
|
auto temp = x.cwiseMax(static_cast<T>(0)); // temp = max(x, 0)
|
|
out.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log());
|
|
}
|
|
};
|
|
|
|
// d(softplus(x))/dx = exp(x) / (1 + exp(x))
|
|
// For numerical stability:
|
|
// d(softplus(x))/dx = exp(x - max(x, 0)) / (exp(-max(x, 0)) +
|
|
// exp(x - max(x, 0)))
|
|
template <typename T>
|
|
struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) {
|
|
auto temp = x.cwiseMax(static_cast<T>(0)); // temp = max(x, 0)
|
|
dx.device(d) =
|
|
dout * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp()));
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
};
|
|
|
|
// softsign(x) = x / (1 + |x|)
|
|
template <typename T>
|
|
struct SoftsignFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) {
|
|
out.device(d) = x / (static_cast<T>(1) + x.abs());
|
|
}
|
|
};
|
|
|
|
// d(softsign(x))/dx = 1 / (1 + |x|)^2
|
|
// Taken from https://en.wikipedia.org/wiki/Activation_function
|
|
template <typename T>
|
|
struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) {
|
|
dx.device(d) =
|
|
dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
};
|
|
|
|
template <typename T>
|
|
struct SoftReluFunctor : public BaseActivationFunctor<T> {
|
|
float threshold;
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"threshold", &threshold}};
|
|
}
|
|
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
auto tmp = static_cast<T>(threshold);
|
|
auto temp = x.cwiseMax(-tmp).cwiseMin(tmp);
|
|
out.device(d) = (static_cast<T>(1) + temp.exp()).log();
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
|
|
float threshold;
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"threshold", &threshold}};
|
|
}
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
auto tmp = static_cast<T>(threshold);
|
|
auto temp = ((out > -tmp) * (out < tmp)).template cast<T>().eval();
|
|
dx.device(d) = dout * (static_cast<T>(1) - (-out).exp()) * temp;
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
|
|
};
|
|
|
|
template <typename T>
|
|
struct LeakyReluFunctor : public BaseActivationFunctor<T> {
|
|
float alpha;
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"alpha", &alpha}};
|
|
}
|
|
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
|
|
float alpha;
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"alpha", &alpha}};
|
|
}
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
auto temp1 = static_cast<T>(alpha) *
|
|
(x < static_cast<T>(0)).template cast<T>().eval();
|
|
auto temp2 = (x >= static_cast<T>(0)).template cast<T>().eval();
|
|
dx.device(d) = dout * (temp1 + temp2).template cast<T>();
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepXOut; }
|
|
};
|
|
|
|
template <typename T>
|
|
struct ELUFunctor : public BaseActivationFunctor<T> {
|
|
float alpha;
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"alpha", &alpha}};
|
|
}
|
|
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) = x.cwiseMax(static_cast<T>(0)) +
|
|
(static_cast<T>(alpha) * (x.exp() - static_cast<T>(1)))
|
|
.cwiseMin(static_cast<T>(0));
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct ELUGradFunctor : public BaseActivationFunctor<T> {
|
|
float alpha;
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"alpha", &alpha}};
|
|
}
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>() +
|
|
dout * static_cast<T>(alpha) * x.exp() *
|
|
(x < static_cast<T>(0)).template cast<T>();
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
};
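
// Derivation: elu(x) = x for x > 0 and alpha * (exp(x) - 1) for x <= 0, so
// d(elu)/dx = 1 for x > 0 and alpha * exp(x) for x < 0, which matches the two
// masked terms above; exp(x) is recomputed from the input, hence kDepX.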
|
|
|
|
// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
|
|
template <typename T>
|
|
struct PowFunctor : public BaseActivationFunctor<T> {
|
|
float factor;
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"factor", &factor}};
|
|
}
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) = x.pow(static_cast<T>(factor));
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct PowGradFunctor : public BaseActivationFunctor<T> {
|
|
float factor;
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"factor", &factor}};
|
|
}
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
dx.device(d) = dout * static_cast<T>(factor) *
|
|
x.pow(static_cast<T>(factor) - static_cast<T>(1));
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
};
|
|
|
|
template <typename T>
|
|
struct STanhFunctor : public BaseActivationFunctor<T> {
|
|
float scale_a;
|
|
float scale_b;
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
|
|
}
|
|
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) =
|
|
static_cast<T>(scale_b) * (static_cast<T>(scale_a) * x).tanh();
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct STanhGradFunctor : public BaseActivationFunctor<T> {
|
|
float scale_a;
|
|
float scale_b;
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
|
|
}
|
|
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
auto a = static_cast<T>(scale_a);
|
|
auto b = static_cast<T>(scale_b);
|
|
auto temp = (a * x).tanh() * (a * x).tanh();
|
|
dx.device(d) = dout * a * b * (static_cast<T>(1) - temp);
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
};
|
|
|
|
template <typename T>
|
|
struct ThresholdedReluFunctor : public BaseActivationFunctor<T> {
|
|
float threshold;
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"threshold", &threshold}};
|
|
}
|
|
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
auto th = static_cast<T>(threshold);
|
|
out.device(d) = (x > th).template cast<T>() * x;
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
|
|
float threshold;
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"threshold", &threshold}};
|
|
}
|
|
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
auto th = static_cast<T>(threshold);
|
|
dx.device(d) = dout * (x > th).template cast<T>();
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
};
|
|
|
|
template <typename T>
|
|
struct HardSigmoidFunctor : public BaseActivationFunctor<T> {
|
|
float slope;
|
|
float offset;
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"slope", &slope}, {"offset", &offset}};
|
|
}
|
|
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
auto temp = x * static_cast<T>(slope) + static_cast<T>(offset);
|
|
out.device(d) =
|
|
temp.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(1));
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
|
|
float slope;
|
|
float offset;
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"slope", &slope}, {"offset", &offset}};
|
|
}
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
dx.device(d) = dout *
|
|
((out > static_cast<T>(0)) * (out < static_cast<T>(1)))
|
|
.template cast<T>() *
|
|
static_cast<T>(slope);
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
|
|
};
|
|
|
|
template <typename T>
|
|
struct SwishFunctor : public BaseActivationFunctor<T> {
|
|
float beta;
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"beta", &beta}};
|
|
}
|
|
|
|
template <typename Device, typename X, typename Out>
|
|
void operator()(Device d, X x, Out out) const {
|
|
out.device(d) = x / (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct SwishGradFunctor : public BaseActivationFunctor<T> {
|
|
float beta;
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"beta", &beta}};
|
|
}
|
|
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
typename dX>
|
|
void operator()(Device d, X x, Out fake_out, dOut dout, dX dx) const {
|
|
auto temp1 = static_cast<T>(1) /
|
|
(static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
|
|
auto out = x * temp1;
|
|
auto temp2 = temp1 * (static_cast<T>(1) - (static_cast<T>(beta) * out));
|
|
dx.device(d) = dout * ((static_cast<T>(beta) * out) + temp2);
|
|
}
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
};
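
// Derivation: with s(x) = 1 / (1 + exp(-beta * x)) (temp1) and out = x * s(x),
//   d(out)/dx = s(x) + beta * x * s(x) * (1 - s(x))
//             = beta * out + s(x) * (1 - beta * out),
// i.e. (beta * out) + temp2 above. Only X is required, so the Out argument is
// an unused placeholder (fake_out) and FwdDeps() is kDepX.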
|
|
|
|
/*
|
|
* in arguments: x, out, ddx
|
|
* out arguments: ddout, dout, dx
|
|
*/
|
|
template <ActBwdOpFwdDeps kDepValue>
|
|
inline void ExtractActivationDoubleGradTensor(
|
|
const framework::ExecutionContext& ctx, const framework::Tensor** X,
|
|
const framework::Tensor** Out, const framework::Tensor** ddX,
|
|
framework::Tensor** dX, framework::Tensor** dOut,
|
|
framework::Tensor** ddOut) {
|
|
auto ddx_var = ctx.InputVar("DDX");
|
|
auto ddo_var = ctx.OutputVar("DDOut");
|
|
  PADDLE_ENFORCE(ddx_var != nullptr,
                 "Cannot get input Variable DDX, variable name = %s",
                 ctx.op().Input("DDX"));
|
|
if (CanBeUsedBySelectedRows.count(ctx.op().Type())) {
|
|
*ddX = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*ddx_var);
|
|
if (ddo_var) {
|
|
*ddOut = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
|
|
ddo_var);
|
|
}
|
|
} else {
|
|
*ddX = ctx.Input<framework::Tensor>("DDX");
|
|
if (ddo_var) {
|
|
*ddOut = ctx.Output<framework::Tensor>("DDOut");
|
|
}
|
|
}
|
|
  PADDLE_ENFORCE(*ddX != nullptr,
                 "Cannot get input tensor DDX, variable name = %s",
                 ctx.op().Input("DDX"));
|
|
|
|
if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
|
|
auto x_var = ctx.InputVar("X");
|
|
    PADDLE_ENFORCE(x_var != nullptr,
                   "Cannot get input Variable X, variable name = %s",
                   ctx.op().Input("X"));
|
|
auto dx_var = ctx.OutputVar("DX");
|
|
if (CanBeUsedBySelectedRows.count(ctx.op().Type())) {
|
|
*X = paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_var);
|
|
if (dx_var) {
|
|
*dX = paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
|
|
dx_var);
|
|
}
|
|
} else {
|
|
*X = ctx.Input<framework::Tensor>("X");
|
|
if (dx_var) {
|
|
*dX = ctx.Output<framework::Tensor>("DX");
|
|
}
|
|
}
|
|
} else {
|
|
VLOG(10) << "Inplace activation of Op: " << ctx.op().Type();
|
|
*X = *ddX;
|
|
}
|
|
if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
|
|
auto out_var = ctx.InputVar("Out");
|
|
PADDLE_ENFORCE(out_var != nullptr,
|
|
"Cannot get input tensor Out, variable name = %s",
|
|
ctx.op().Input("Out"));
|
|
auto dout_var = ctx.OutputVar("DOut");
|
|
if (CanBeUsedBySelectedRows.count(ctx.op().Type())) {
|
|
*Out =
|
|
paddle::framework::GetLoDTensorOrSelectedRowsValueFromVar(*out_var);
|
|
if (dout_var) {
|
|
*dOut =
|
|
paddle::framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(
|
|
dout_var);
|
|
}
|
|
} else {
|
|
*Out = ctx.Input<framework::Tensor>("Out");
|
|
if (dout_var) {
|
|
*dOut = ctx.Output<framework::Tensor>("DOut");
|
|
}
|
|
}
|
|
} else {
|
|
VLOG(10) << "Inplace activation of Op: " << ctx.op().Type();
|
|
*Out = *ddX;
|
|
}
|
|
}
|
|
|
|
template <typename DeviceContext, typename Functor>
|
|
class ActivationDoubleGradKernel
|
|
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
|
|
public:
|
|
using T = typename Functor::ELEMENT_TYPE;
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
const framework::Tensor *X, *Out, *ddX;
|
|
X = Out = ddX = nullptr;
|
|
framework::Tensor *ddOut, *dOut, *dX;
|
|
ddOut = dOut = dX = nullptr;
|
|
|
|
ExtractActivationDoubleGradTensor<Functor::FwdDeps()>(ctx, &X, &Out, &ddX,
|
|
&dX, &dOut, &ddOut);
|
|
|
|
if (ddOut) ddOut->mutable_data<T>(ctx.GetPlace());
|
|
if (dOut) dOut->mutable_data<T>(ctx.GetPlace());
|
|
if (dX) dX->mutable_data<T>(Out->dims(), ctx.GetPlace());
|
|
|
|
auto& place = ctx.template device_context<DeviceContext>();
|
|
|
|
Functor functor;
|
|
auto attrs = functor.GetAttrs();
|
|
for (auto& attr : attrs) {
|
|
*attr.second = ctx.Attr<float>(attr.first);
|
|
}
|
|
functor(place, X, Out, ddX, ddOut, dOut, dX);
|
|
}
|
|
};
|
|
|
|
template <typename T>
|
|
struct ReluGradGradFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device>
|
|
void operator()(const Device& dev, const framework::Tensor* X,
|
|
const framework::Tensor* Out, const framework::Tensor* ddX,
|
|
framework::Tensor* ddOut, framework::Tensor* dOut,
|
|
framework::Tensor* dX) const {
|
|
auto* d = dev.eigen_device();
|
|
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
|
|
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
|
|
if (ddOut) {
|
|
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
|
|
ddout.device(*d) = ddx * (out > static_cast<T>(0)).template cast<T>();
|
|
}
|
|
}
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
|
|
};
|
|
|
|
template <typename T>
|
|
struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
|
|
float alpha;
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
return {{"alpha", &alpha}};
|
|
}
|
|
template <typename Device>
|
|
void operator()(const Device& dev, const framework::Tensor* X,
|
|
const framework::Tensor* Out, const framework::Tensor* ddX,
|
|
framework::Tensor* ddOut, framework::Tensor* dOut,
|
|
framework::Tensor* dX) const {
|
|
auto* d = dev.eigen_device();
|
|
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
|
|
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X));
|
|
if (ddOut) {
|
|
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
|
|
ddout.device(*d) = ddx *
|
|
((x >= static_cast<T>(0)).template cast<T>().eval() +
|
|
static_cast<T>(alpha) *
|
|
(x < static_cast<T>(0)).template cast<T>().eval())
|
|
.template cast<T>();
|
|
}
|
|
}
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
};
|
|
|
|
template <typename T>
|
|
struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device>
|
|
void operator()(const Device& dev, const framework::Tensor* Out,
|
|
const framework::Tensor* ddX, framework::Tensor* ddOut,
|
|
framework::Tensor* dOut, const framework::Tensor* dX) const {
|
|
auto* d = dev.eigen_device();
|
|
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
|
|
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
|
|
if (ddOut) {
|
|
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
|
|
ddout.device(*d) = ddx * static_cast<T>(0.5) / out;
|
|
}
|
|
if (dOut) {
|
|
auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
|
|
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
|
|
dout.device(*d) = dx * ddx * static_cast<T>(-1) / out;
|
|
}
|
|
}
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
|
|
};
|
|
|
|
template <typename T>
|
|
struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
|
|
template <typename Device>
|
|
void operator()(const Device& dev, const framework::Tensor* X,
|
|
const framework::Tensor* ddX, framework::Tensor* ddOut,
|
|
const framework::Tensor* dOut, framework::Tensor* dX) const {
|
|
auto* d = dev.eigen_device();
|
|
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
|
|
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X));
|
|
if (ddOut) {
|
|
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
|
|
ddout.device(*d) = ddx * static_cast<T>(2) * x;
|
|
}
|
|
if (dX) {
|
|
auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
|
|
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
|
|
dx.device(*d) = ddx * static_cast<T>(2) * dout;
|
|
}
|
|
}
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
};
|
|
|
|
// TODO(dengkaipeng): the double-gradient calculation for Square/Sqrt needs
// DOut(dy) as an input (not an output), so tensor extraction differs from the
// others. Implement the extraction kernel separately here.
|
|
inline void ExtractDoubleGradTensorWithInputDOut(
|
|
const framework::ExecutionContext& ctx, const framework::Tensor** X,
|
|
const framework::Tensor** ddX, framework::Tensor** dX,
|
|
const framework::Tensor** dOut, framework::Tensor** ddOut) {
|
|
// extract ddX(output), ddOut(input)
|
|
auto ddx_var = ctx.InputVar("DDX");
|
|
auto ddo_var = ctx.OutputVar("DDOut");
|
|
  PADDLE_ENFORCE(ddx_var != nullptr,
                 "Cannot get input Variable DDX, variable name = %s",
                 ctx.op().Input("DDX"));
|
|
*ddX = ctx.Input<framework::Tensor>("DDX");
|
|
if (ddo_var) {
|
|
*ddOut = ctx.Output<framework::Tensor>("DDOut");
|
|
}
|
|
  PADDLE_ENFORCE(*ddX != nullptr,
                 "Cannot get input tensor DDX, variable name = %s",
                 ctx.op().Input("DDX"));
|
|
|
|
// extract x(input), dx(output)
|
|
auto x_var = ctx.InputVar("X");
|
|
  PADDLE_ENFORCE(x_var != nullptr,
                 "Cannot get input Variable X, variable name = %s",
                 ctx.op().Input("X"));
|
|
auto dx_var = ctx.OutputVar("DX");
|
|
*X = ctx.Input<framework::Tensor>("X");
|
|
if (dx_var) {
|
|
*dX = ctx.Output<framework::Tensor>("DX");
|
|
}
|
|
|
|
// extract dOut(input)
|
|
auto dout_var = ctx.InputVar("DOut");
|
|
if (dout_var) {
|
|
*dOut = ctx.Input<framework::Tensor>("DOut");
|
|
}
|
|
}
|
|
|
|
template <typename DeviceContext, typename Functor>
|
|
class SquareDoubleGradKernel
|
|
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
|
|
public:
|
|
using T = typename Functor::ELEMENT_TYPE;
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
const framework::Tensor *X, *ddX, *dOut;
|
|
X = ddX = dOut = nullptr;
|
|
framework::Tensor *dX, *ddOut;
|
|
dX = ddOut = nullptr;
|
|
|
|
ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut);
|
|
|
|
if (dX) dX->mutable_data<T>(X->dims(), ctx.GetPlace());
|
|
if (ddOut) ddOut->mutable_data<T>(ctx.GetPlace());
|
|
|
|
auto& place = ctx.template device_context<DeviceContext>();
|
|
|
|
Functor functor;
|
|
functor(place, X, ddX, ddOut, dOut, dX);
|
|
}
|
|
};
|
|
|
|
template <typename DeviceContext, typename Functor>
|
|
class SqrtDoubleGradKernel
|
|
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
|
|
public:
|
|
using T = typename Functor::ELEMENT_TYPE;
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
const framework::Tensor *Out, *dX, *ddX;
|
|
Out = dX = ddX = nullptr;
|
|
framework::Tensor *ddOut, *dOut;
|
|
ddOut = dOut = nullptr;
|
|
|
|
// extract ddx(input), ddout(output)
|
|
auto ddx_var = ctx.InputVar("DDX");
|
|
auto ddo_var = ctx.OutputVar("DDOut");
|
|
PADDLE_ENFORCE(ddx_var != nullptr,
|
|
"Cannot get input Variable DDX, variable name = %s",
|
|
ctx.op().Input("DDX"));
|
|
ddX = ctx.Input<framework::Tensor>("DDX");
|
|
if (ddo_var) {
|
|
ddOut = ctx.Output<framework::Tensor>("DDOut");
|
|
}
|
|
PADDLE_ENFORCE(ddX != nullptr,
|
|
"Cannot get input Variable DDX, variable name = %s",
|
|
ctx.op().Input("DDX"));
|
|
|
|
// extract out(input), dout(output)
|
|
auto out_var = ctx.InputVar("Out");
|
|
PADDLE_ENFORCE(out_var != nullptr,
|
|
"Cannot get input Variable Out, variable name = %s",
|
|
ctx.op().Input("Out"));
|
|
auto dout_var = ctx.OutputVar("DOut");
|
|
Out = ctx.Input<framework::Tensor>("Out");
|
|
if (dout_var) {
|
|
dOut = ctx.Output<framework::Tensor>("DOut");
|
|
}
|
|
|
|
// extract dx(input)
|
|
auto dx_var = ctx.InputVar("DX");
|
|
PADDLE_ENFORCE(dx_var != nullptr,
|
|
"Cannot get input Variable DX, variable name = %s",
|
|
ctx.op().Input("DX"));
|
|
if (dx_var) {
|
|
dX = ctx.Input<framework::Tensor>("DX");
|
|
}
|
|
|
|
if (dOut) dOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
|
|
if (ddOut) ddOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
|
|
|
|
auto& place = ctx.template device_context<DeviceContext>();
|
|
|
|
Functor functor;
|
|
functor(place, Out, ddX, ddOut, dOut, dX);
|
|
}
|
|
};
|
|
|
|
} // namespace operators
|
|
} // namespace paddle
|
|
|
|
#define FOR_EACH_ACTIVATION_OP(__macro) \
|
|
__macro(sigmoid, Sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
|
|
__macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \
|
|
__macro(exp, Exp, ExpFunctor, ExpGradFunctor); \
|
|
__macro(gelu, Gelu, GeluFunctor, GeluGradFunctor); \
|
|
__macro(tanh, Tanh, TanhFunctor, TanhGradFunctor); \
|
|
__macro(atan, Atan, AtanFunctor, AtanGradFunctor); \
|
|
__macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
|
|
__macro(rsqrt, Rsqrt, RsqrtFunctor, RsqrtGradFunctor); \
|
|
__macro(abs, Abs, AbsFunctor, AbsGradFunctor); \
|
|
__macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \
|
|
__macro(floor, Floor, FloorFunctor, ZeroGradFunctor); \
|
|
__macro(cos, Cos, CosFunctor, CosGradFunctor); \
|
|
__macro(acos, Acos, AcosFunctor, AcosGradFunctor); \
|
|
__macro(sin, Sin, SinFunctor, SinGradFunctor); \
|
|
__macro(asin, Asin, AsinFunctor, AsinGradFunctor); \
|
|
__macro(round, Round, RoundFunctor, ZeroGradFunctor); \
|
|
__macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
|
|
__macro(log, Log, LogFunctor, LogGradFunctor); \
|
|
__macro(brelu, BRelu, BReluFunctor, BReluGradFunctor); \
|
|
__macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \
|
|
__macro(pow, Pow, PowFunctor, PowGradFunctor); \
|
|
__macro(stanh, STanh, STanhFunctor, STanhGradFunctor); \
|
|
__macro(softplus, Softplus, SoftplusFunctor, SoftplusGradFunctor); \
|
|
__macro(softsign, Softsign, SoftsignFunctor, SoftsignGradFunctor); \
|
|
__macro(relu6, Relu6, Relu6Functor, Relu6GradFunctor); \
|
|
__macro(tanh_shrink, TanhShrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
|
|
__macro(elu, ELU, ELUFunctor, ELUGradFunctor); \
|
|
__macro(hard_shrink, HardShrink, HardShrinkFunctor, HardShrinkGradFunctor); \
|
|
__macro(hard_sigmoid, HardSigmoid, HardSigmoidFunctor, \
|
|
HardSigmoidGradFunctor); \
|
|
__macro(swish, Swish, SwishFunctor, SwishGradFunctor); \
|
|
__macro(thresholded_relu, ThresholdedRelu, ThresholdedReluFunctor, \
|
|
ThresholdedReluGradFunctor);
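
// The list above is meant to be consumed by the op registration code in the
// corresponding activation_op.cc / activation_op.cu files, which is expected
// to invoke it roughly as
//
//   FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP_AND_KERNEL);
//
// where the macro name passed in is illustrative only: any macro taking
// (op_name, OpName, functor, grad_functor) can be expanded over every entry.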