|
|
|
@ -388,9 +388,9 @@ struct HardShrinkFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
}
|
|
|
|
|
template <typename Device, typename X, typename Out>
|
|
|
|
|
void operator()(Device d, X x, Out out) const {
|
|
|
|
|
auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>();
|
|
|
|
|
auto temp2 = (x > static_cast<T>(threshold)).template cast<T>();
|
|
|
|
|
out.device(d) = x * (temp1 + temp2);
|
|
|
|
|
auto temp1 = x < static_cast<T>(threshold * -1.f);
|
|
|
|
|
auto temp2 = x > static_cast<T>(threshold);
|
|
|
|
|
out.device(d) = x * (temp1 + temp2 > 0).template cast<T>();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -405,9 +405,9 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
|
|
|
typename dX>
|
|
|
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
|
|
|
auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>();
|
|
|
|
|
auto temp2 = (x > static_cast<T>(threshold)).template cast<T>();
|
|
|
|
|
dx.device(d) = dout * (temp1 + temp2).template cast<T>();
|
|
|
|
|
auto temp1 = x < static_cast<T>(threshold * -1.f);
|
|
|
|
|
auto temp2 = x > static_cast<T>(threshold);
|
|
|
|
|
dx.device(d) = dout * (temp1 + temp2 > 0).template cast<T>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
|
|
|