|
|
|
@ -199,6 +199,39 @@ struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// tanhshrink(x) = x - tanh(x)
|
|
|
|
|
// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct HardShrinkFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
float threshold;
|
|
|
|
|
|
|
|
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
|
|
|
return {{"threshold", &threshold}};
|
|
|
|
|
}
|
|
|
|
|
template <typename Device, typename X, typename Y>
|
|
|
|
|
void operator()(Device d, X x, Y y) const {
|
|
|
|
|
auto temp1 = (x < (threshold * -1)).template cast<T>().eval();
|
|
|
|
|
auto temp2 = (x > threshold).template cast<T>().eval();
|
|
|
|
|
y.device(d) = x * (temp1 + temp2);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
float threshold;
|
|
|
|
|
|
|
|
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
|
|
|
return {{"threshold", &threshold}};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Device, typename X, typename Y, typename dY, typename dX>
|
|
|
|
|
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
|
|
|
|
|
auto temp1 = (x < (threshold * -1)).template cast<T>().eval();
|
|
|
|
|
auto temp2 = (x > threshold).template cast<T>().eval();
|
|
|
|
|
dx.device(d) = dy * (temp1 + temp2).template cast<T>();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// softshrink(x) = x - lambda, if x > lambda; x + lambda, if x < lambda; 0
|
|
|
|
|
// otherwise
|
|
|
|
|
template <typename T>
|
|
|
|
@ -351,8 +384,6 @@ template <typename T>
|
|
|
|
|
struct Relu6Functor : public BaseActivationFunctor<T> {
|
|
|
|
|
float threshold;
|
|
|
|
|
|
|
|
|
|
// NOTE: Explicit hides the `BaseActivationFunctor<T>::GetAttrs`
|
|
|
|
|
// not polymorphism for speed.
|
|
|
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
|
|
|
return {{"threshold", &threshold}};
|
|
|
|
|
}
|
|
|
|
@ -555,4 +586,5 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
__macro(relu6, Relu6Functor, Relu6GradFunctor); \
|
|
|
|
|
__macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \
|
|
|
|
|
__macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
|
|
|
|
|
__macro(elu, ELUFunctor, ELUGradFunctor)
|
|
|
|
|
__macro(elu, ELUFunctor, ELUGradFunctor); \
|
|
|
|
|
__macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor)
|
|
|
|
|