fix elu grad whne alpha less then zero, test=develop (#26543)

revert-26856-strategy_example2
Qi Li 5 years ago committed by GitHub
parent 786373ba29
commit 6f69fbc8ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1134,9 +1134,20 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>() +
dout * static_cast<T>(alpha) * x.exp() *
(x <= static_cast<T>(0)).template cast<T>();
auto temp_a_pos = static_cast<T>(alpha > 0);
auto temp_a_neg = static_cast<T>(alpha <= 0);
auto temp_x_pos = (x > static_cast<T>(0)).template cast<T>();
auto temp_x_neg = (x <= static_cast<T>(0)).template cast<T>();
// dx = dout, if alpha > 0 and x > 0
// dx = dout * alpha * x.exp(), if alpha > 0 and x <= 0
// dx = dout * (1 + alpha * x.exp()), if alpha <= 0 and x > 0
// dx = 0, if alpha <= 0 and x <=0
dx.device(d) =
dout * temp_a_pos * temp_x_pos +
dout * static_cast<T>(alpha) * x.exp() * temp_a_pos * temp_x_neg +
dout * (static_cast<T>(1) + static_cast<T>(alpha) * x.exp()) *
temp_a_neg * temp_x_pos;
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }

Loading…
Cancel
Save