@@ -407,6 +407,33 @@ struct Relu6GradFunctor : public BaseActivationFunctor<T> {
   }
 };
 
+// softplus(x) = log(1 + exp(x))
+// When x is a very large positive number, exp(x) may explode to inf,
+// Using trick below for numerical stability
+// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
+// Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0)))
+template <typename T>
+struct SoftplusFunctor : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Y>
+  void operator()(Device d, X x, Y y) {
+    auto temp = x.cwiseMax(static_cast<T>(0));  // temp = max(x, 0)
+    y.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log());
+  }
+};
+
+// d(softplus(x))/dx = exp(x) / (1 + exp(x))
+// For numerical stability:
+// d(softplus(x))/dx = exp(x - max(x, 0)) / (exp(-max(x, 0)) +
+//                     exp(x - max(x, 0)))
+template <typename T>
+struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
+  template <typename Device, typename X, typename Y, typename dY, typename dX>
+  void operator()(Device d, X x, Y y, dY dy, dX dx) {
+    auto temp = x.cwiseMax(static_cast<T>(0));  // temp = max(x, 0)
+    dx.device(d) = dy * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp()));
+  }
+};
+
 // softsign(x) = x / (1 + |x|)
 template <typename T>
 struct SoftsignFunctor : public BaseActivationFunctor<T> {
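Not part of the patch, but a minimal standalone sketch of the max-shift (log-sum-exp) trick that SoftplusFunctor and SoftplusGradFunctor above rely on. The helper names (softplus_naive, softplus_stable, softplus_grad_stable) are illustrative only and do not exist in Paddle; the point is that after factoring out exp(max(x, 0)), both remaining exponents are non-positive, so neither exp() can overflow, and the stabilized gradient reduces algebraically to the logistic sigmoid 1 / (1 + exp(-x)).

// Standalone sketch (hypothetical helpers, not Paddle code) of the
// max-shift trick used by the functors above.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <initializer_list>

// Naive form: overflows to inf once exp(x) exceeds the double range.
double softplus_naive(double x) { return std::log(1.0 + std::exp(x)); }

// Stable form: softplus(x) = m + log(exp(-m) + exp(x - m)), m = max(x, 0).
// Both exponents are <= 0, so neither exp() can overflow.
double softplus_stable(double x) {
  double m = std::max(x, 0.0);
  return m + std::log(std::exp(-m) + std::exp(x - m));
}

// Stable gradient: exp(x - m) / (exp(-m) + exp(x - m)),
// algebraically equal to the logistic sigmoid 1 / (1 + exp(-x)).
double softplus_grad_stable(double x) {
  double m = std::max(x, 0.0);
  return std::exp(x - m) / (std::exp(-m) + std::exp(x - m));
}

int main() {
  for (double x : {-1000.0, -1.0, 0.0, 1.0, 1000.0}) {
    std::printf("x=%8.1f  naive=%g  stable=%g  grad=%g\n", x,
                softplus_naive(x), softplus_stable(x), softplus_grad_stable(x));
  }
  return 0;
}

Running this shows softplus_naive(1000) overflowing to inf while softplus_stable(1000) returns 1000, and the gradient staying inside [0, 1] across the whole range.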
@@ -582,6 +609,7 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
   __macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \
   __macro(pow, PowFunctor, PowGradFunctor); \
   __macro(stanh, STanhFunctor, STanhGradFunctor); \
+  __macro(softplus, SoftplusFunctor, SoftplusGradFunctor); \
   __macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \
   __macro(relu6, Relu6Functor, Relu6GradFunctor); \
   __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \
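The second hunk only threads the new functors into the existing registration list. Those __macro(...) lines are an X-macro: an outer macro (not shown in this excerpt) expands __macro once per (op name, forward functor, gradient functor) triple. A rough sketch of that pattern under hypothetical names — ACTIVATION_LIST, REGISTER_ACTIVATION, RegisterKernel and the placeholder functor types are inventions for illustration, not Paddle's actual macros:

// Standalone sketch of the X-macro pattern behind the registration list.
#include <cstdio>

struct Identity {};      // placeholder "functor" types for the sketch
struct IdentityGrad {};

template <typename Functor, typename GradFunctor>
void RegisterKernel(const char* name) {
  // A real consumer would instantiate forward/backward kernels here.
  std::printf("registered activation: %s\n", name);
}

// The list expands __macro once per (name, Functor, GradFunctor) triple,
// mirroring the registration list extended by the diff above.
#define ACTIVATION_LIST(__macro)             \
  __macro(softplus, Identity, IdentityGrad); \
  __macro(softsign, Identity, IdentityGrad);

#define REGISTER_ACTIVATION(op_name, functor, grad_functor) \
  RegisterKernel<functor, grad_functor>(#op_name)

int main() {
  ACTIVATION_LIST(REGISTER_ACTIVATION);  // expands into two RegisterKernel calls
  return 0;
}

The appeal of the pattern is that adding one line to the list, as the diff does for softplus, is enough for every consumer of the list to pick up the new op.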