|
|
@ -1134,9 +1134,20 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
|
|
typename dX>
|
|
|
|
typename dX>
|
|
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
|
|
dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>() +
|
|
|
|
auto temp_a_pos = static_cast<T>(alpha > 0);
|
|
|
|
dout * static_cast<T>(alpha) * x.exp() *
|
|
|
|
auto temp_a_neg = static_cast<T>(alpha <= 0);
|
|
|
|
(x <= static_cast<T>(0)).template cast<T>();
|
|
|
|
auto temp_x_pos = (x > static_cast<T>(0)).template cast<T>();
|
|
|
|
|
|
|
|
auto temp_x_neg = (x <= static_cast<T>(0)).template cast<T>();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// dx = dout, if alpha > 0 and x > 0
|
|
|
|
|
|
|
|
// dx = dout * alpha * x.exp(), if alpha > 0 and x <= 0
|
|
|
|
|
|
|
|
// dx = dout * (1 + alpha * x.exp()), if alpha <= 0 and x > 0
|
|
|
|
|
|
|
|
// dx = 0, if alpha <= 0 and x <=0
|
|
|
|
|
|
|
|
dx.device(d) =
|
|
|
|
|
|
|
|
dout * temp_a_pos * temp_x_pos +
|
|
|
|
|
|
|
|
dout * static_cast<T>(alpha) * x.exp() * temp_a_pos * temp_x_neg +
|
|
|
|
|
|
|
|
dout * (static_cast<T>(1) + static_cast<T>(alpha) * x.exp()) *
|
|
|
|
|
|
|
|
temp_a_neg * temp_x_pos;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
|
|