@@ -16,6 +16,11 @@ limitations under the License. */
 #include <utility>
 #include <vector>
 
+#ifndef _USE_MATH_DEFINES
+#define _USE_MATH_DEFINES  // must precede <cmath> so MSVC exposes M_SQRT1_2 and M_2_SQRTPI
+#endif
+#include <cmath>
+
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
@@ -212,6 +217,31 @@ struct ReluGradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
+template <typename T>
+struct GeluFunctor : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Out>
+  void operator()(Device d, X x, Out out) const {
+    auto temp =
+        ((x * static_cast<T>(M_SQRT1_2)).erf()).template cast<T>().eval();
+    out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
+  }
+};
+
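+// gelu'(x) = 0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x^2 / 2) / sqrt(2 * pi)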
+template <typename T>
+struct GeluGradFunctor : BaseActivationFunctor<T> {
+  bool Inplace() const { return IsInplace("gelu"); }
+  template <typename Device, typename X, typename Out, typename dOut,
+            typename dX>
+  void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
+    auto temp = (static_cast<T>(0.5 * M_2_SQRTPI * M_SQRT1_2) * x *
+                 ((-static_cast<T>(0.5) * x.square()).exp()))
+                    .template cast<T>()
+                    .eval();
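+    // out / x == 0.5 * (1 + erf(x / sqrt(2))) and temp == x * exp(-x^2 / 2) / sqrt(2 * pi)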
+    dx.device(d) = dout * (out / x + temp);
+  }
+};
+
 // tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
 template <typename T>
 struct TanhFunctor : public BaseActivationFunctor<T> {
@@ -877,6 +907,7 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
   __macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor);    \
   __macro(exp, ExpFunctor, ExpGradFunctor);                         \
   __macro(relu, ReluFunctor, ReluGradFunctor);                      \
+  __macro(gelu, GeluFunctor, GeluGradFunctor);                      \
   __macro(tanh, TanhFunctor, TanhGradFunctor);                      \
   __macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor);    \
   __macro(sqrt, SqrtFunctor, SqrtGradFunctor);                      \
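
For reference only (not part of the patch): a minimal standalone sketch of the same gelu forward/backward formulas in plain C++, assuming nothing beyond <cmath> and <cstdio>. The names gelu_ref and gelu_grad_ref are illustrative; the loop spot-checks the analytic gradient used by GeluGradFunctor against a central finite difference of the forward pass.

#include <cmath>
#include <cstdio>

// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
static double gelu_ref(double x) {
  return 0.5 * x * (1.0 + std::erf(x / std::sqrt(2.0)));
}

// gelu'(x) = 0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x^2 / 2) / sqrt(2 * pi)
static double gelu_grad_ref(double x) {
  const double pi = std::acos(-1.0);
  return 0.5 * (1.0 + std::erf(x / std::sqrt(2.0))) +
         x * std::exp(-0.5 * x * x) / std::sqrt(2.0 * pi);
}

int main() {
  const double xs[] = {-2.0, -0.5, 0.0, 0.5, 2.0};
  const double eps = 1e-6;
  for (double x : xs) {
    // central finite difference of the forward pass
    const double fd = (gelu_ref(x + eps) - gelu_ref(x - eps)) / (2.0 * eps);
    std::printf("x=%5.2f  gelu=%.6f  grad=%.6f  fd=%.6f\n", x, gelu_ref(x),
                gelu_grad_ref(x), fd);
  }
  return 0;
}

The Eigen-based functors in the patch evaluate these same two expressions elementwise over the input tensor.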