|
|
|
@ -616,30 +616,63 @@ struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct HardSigmoidFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
float slope;
|
|
|
|
|
float offset;
|
|
|
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
|
|
|
return {{"slope", &slope}, {"offset", &offset}};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Device, typename X, typename Y>
|
|
|
|
|
void operator()(Device d, X x, Y y) const {
|
|
|
|
|
auto temp = x * static_cast<T>(slope) + static_cast<T>(offset);
|
|
|
|
|
y.device(d) = temp.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(1));
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
float slope;
|
|
|
|
|
float offset;
|
|
|
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
|
|
|
return {{"slope", &slope}, {"offset", &offset}};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Device, typename X, typename Y, typename dY, typename dX>
|
|
|
|
|
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
|
|
|
|
|
dx.device(d) =
|
|
|
|
|
dy *
|
|
|
|
|
((y > static_cast<T>(0)) * (y < static_cast<T>(1))).template cast<T>() *
|
|
|
|
|
static_cast<T>(slope);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
// Applies `__macro` once per activation kernel, passing three arguments:
//   (operator name, forward functor, gradient functor).
// Every invocation except the last is followed by ';', so `__macro` should
// expand to something that is valid when terminated by a semicolon; the caller
// supplies the final ';'.
//
// This is the single definition of the list. A previous, truncated copy that
// ended on a dangling line continuation and lacked the hard_sigmoid and
// thresholded_relu entries has been folded into it.
#define FOR_EACH_KERNEL_FUNCTOR(__macro)                             \
  __macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor);              \
  __macro(logsigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor);     \
  __macro(exp, ExpFunctor, ExpGradFunctor);                          \
  __macro(relu, ReluFunctor, ReluGradFunctor);                       \
  __macro(tanh, TanhFunctor, TanhGradFunctor);                       \
  __macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor);     \
  __macro(sqrt, SqrtFunctor, SqrtGradFunctor);                       \
  __macro(abs, AbsFunctor, AbsGradFunctor);                          \
  __macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor);     \
  __macro(log, LogFunctor, LogGradFunctor);                          \
  __macro(square, SquareFunctor, SquareGradFunctor);                 \
  __macro(brelu, BReluFunctor, BReluGradFunctor);                    \
  __macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor);          \
  __macro(pow, PowFunctor, PowGradFunctor);                          \
  __macro(stanh, STanhFunctor, STanhGradFunctor);                    \
  __macro(softplus, SoftplusFunctor, SoftplusGradFunctor);           \
  __macro(softsign, SoftsignFunctor, SoftsignGradFunctor);           \
  __macro(relu6, Relu6Functor, Relu6GradFunctor);                    \
  __macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor);       \
  __macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor);    \
  __macro(elu, ELUFunctor, ELUGradFunctor);                          \
  __macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor);    \
  __macro(hard_sigmoid, HardSigmoidFunctor, HardSigmoidGradFunctor); \
  __macro(thresholded_relu, ThresholdedReluFunctor, ThresholdedReluGradFunctor);
|
|
|
|
|