|
|
|
@ -590,6 +590,32 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct ThresholdedReluFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
float threshold;
|
|
|
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
|
|
|
return {{"threshold", &threshold}};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Device, typename X, typename Y>
|
|
|
|
|
void operator()(Device d, X x, Y y) const {
|
|
|
|
|
y.device(d) = (x > static_cast<T>(threshold)).template cast<T>() * x;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
float threshold;
|
|
|
|
|
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
|
|
|
|
|
return {{"threshold", &threshold}};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename Device, typename X, typename Y, typename dY, typename dX>
|
|
|
|
|
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
|
|
|
|
|
dx.device(d) = dy * (x > static_cast<T>(threshold)).template cast<T>();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
@ -615,4 +641,5 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
__macro(leaky_relu, LeakyReluFunctor, LeakyReluGradFunctor); \
|
|
|
|
|
__macro(tanh_shrink, TanhShrinkFunctor, TanhShrinkGradFunctor); \
|
|
|
|
|
__macro(elu, ELUFunctor, ELUGradFunctor); \
|
|
|
|
|
__macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor)
|
|
|
|
|
__macro(hard_shrink, HardShrinkFunctor, HardShrinkGradFunctor); \
|
|
|
|
|
__macro(thresholded_relu, ThresholdedReluFunctor, ThresholdedReluGradFunctor);
|
|
|
|
|