@@ -95,6 +95,41 @@ struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
  }
};

// Originally: logsigmoid(x) = -log (1 + exp(-x))
// For numerical stability, we can use the log-sum-exp trick:
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
// We can rewrite the above equation as:
// y = -log( exp(0) + exp(-x))  [since exp(0) = 1]
//   = -log( exp(max(-x, 0) - max(-x, 0)) + exp(-x + max(-x, 0) - max(-x, 0)))
//   = -log( exp(max(-x, 0)) * exp(-max(-x, 0)) + exp(max(-x, 0)) *
//           exp(-x - max(-x, 0)))
//   = -log( exp(max(-x, 0)) * (exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
//   = -log( exp(max(-x, 0))) - log( exp(-max(-x, 0)) + exp(-x - max(-x, 0)))
//
// Hence, logsigmoid(x) = -(max(-x, 0)
//                          + log( exp(-max(-x, 0)) + exp(-x - max(-x, 0))))
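//
// For example (a quick sanity check of the stable form): with x = -1000,
// max(-x, 0) = 1000, so logsigmoid(-1000) = -(1000 + log(exp(-1000) + exp(0)))
// which is approximately -1000, whereas the naive -log(1 + exp(1000))
// evaluates to -inf because exp(1000) overflows.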
template <typename T>
struct LogSigmoidFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    auto temp = (-x).cwiseMax(static_cast<T>(0));  // temp = max(-x, 0)
    y.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log());
  }
};

// Originally: f' = exp(-x) / (1 + exp(-x))
// For numerical stability: f' = exp(-x - max(-x, 0)) / (exp(-max(-x, 0)) +
//                                exp(-x - max(-x, 0)))
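// (This is simply sigmoid(-x) = 1 - sigmoid(x), evaluated without overflow.)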
template <typename T>
struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    auto temp = (-x).cwiseMax(static_cast<T>(0));  // temp = max(-x, 0)
    dx.device(d) =
        dy * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp()));
  }
};

// exp(x) = e^x
template <typename T>
struct ExpFunctor : public BaseActivationFunctor<T> {
@@ -164,6 +199,37 @@ struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
  }
};

// softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < -lambda; 0
// otherwise
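// For example, with lambda = 0.5: softshrink(1.2) = 0.7, softshrink(0.3) = 0,
// and softshrink(-2.0) = -1.5.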
template <typename T>
struct SoftShrinkFunctor : public BaseActivationFunctor<T> {
  float lambda;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }

  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) const {
    auto temp1 = (x > lambda).template cast<T>().eval();
    auto temp2 = (x < -lambda).template cast<T>().eval();
    y.device(d) = temp1 * (x - lambda) + temp2 * (x + lambda);
  }
};
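
// softshrink'(x) = 1, if |x| > lambda; 0 otherwise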
template <typename T>
struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
  float lambda;
  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"lambda", &lambda}};
  }
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) const {
    auto temp1 = (x > lambda).template cast<T>().eval();
    auto temp2 = (x < -lambda).template cast<T>().eval();
    dx.device(d) = dy * (temp1 + temp2).template cast<T>();
  }
};

// sqrt(x) = x^(1/2)
template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> {
@@ -471,9 +537,11 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {

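// X-macro list of activation kernels: __macro is invoked once per
// (op name, forward functor, gradient functor) triple listed below.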
#define FOR_EACH_KERNEL_FUNCTOR(__macro)                           \
  __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor);            \
  __macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor);   \
  __macro(exp, ExpFunctor, ExpGradFunctor);                        \
  __macro(relu, ReluFunctor, ReluGradFunctor);                     \
  __macro(tanh, TanhFunctor, TanhGradFunctor);                     \
  __macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor);   \
  __macro(sqrt, SqrtFunctor, SqrtGradFunctor);                     \
  __macro(abs, AbsFunctor, AbsGradFunctor);                        \
  __macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor);   \
@@ -484,7 +552,7 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
  __macro(pow, PowFunctor, PowGradFunctor);                        \
  __macro(stanh, STanhFunctor, STanhGradFunctor);                  \
  __macro(softsign, SoftsignFunctor, SoftsignGradFunctor);         \
  __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor);     \
  __macro(relu6, Relu6Functor, Relu6GradFunctor);                  \
  __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor);  \
  __macro(elu, ELUFunctor, ELUGradFunctor)
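
// ---------------------------------------------------------------------------
// Not part of the patch: a minimal standalone sketch (plain C++11, no Eigen or
// Paddle headers) comparing the naive logsigmoid against the numerically
// stable form used above. Function names are hypothetical, for illustration.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <initializer_list>

double naive_logsigmoid(double x) { return -std::log(1.0 + std::exp(-x)); }

double stable_logsigmoid(double x) {
  double t = std::max(-x, 0.0);  // t = max(-x, 0)
  return -(t + std::log(std::exp(-t) + std::exp(-x - t)));
}

int main() {
  // For large negative x the naive form overflows (exp(-x) -> inf), while the
  // stable form returns approximately x.
  for (double x : {-1000.0, -10.0, 0.0, 10.0, 1000.0}) {
    std::printf("x = %8.1f  naive = %14.6f  stable = %14.6f\n", x,
                naive_logsigmoid(x), stable_logsigmoid(x));
  }
  return 0;
}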