|
|
|
@ -146,6 +146,24 @@ struct TanhGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// tanhshrink(x) = x - tanh(x)
|
|
|
|
|
// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct TanhShrinkFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
template <typename Device, typename X, typename Y>
|
|
|
|
|
void operator()(Device d, X x, Y y) const {
|
|
|
|
|
y.device(d) = x - x.tanh();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
template <typename Device, typename X, typename Y, typename dY, typename dX>
|
|
|
|
|
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
|
|
|
|
|
dx.device(d) = dy * (x.tanh() * x.tanh());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// sqrt(x) = x^(1/2)
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct SqrtFunctor : public BaseActivationFunctor<T> {
|
|
|
|
@ -407,4 +425,5 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
__macro(pow, PowFunctor, PowGradFunctor); \
|
|
|
|
|
__macro(stanh, STanhFunctor, STanhGradFunctor); \
|
|
|
|
|
__macro(softsign, SoftsignFunctor, SoftsignGradFunctor); \
|
|
|
|
|
__macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor)
|
|
|
|
|
__macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \
|
|
|
|
|
__macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor)
|
|
|
|
|