|
|
|
@ -1073,8 +1073,8 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
typename dX>
|
|
|
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
|
|
|
auto temp1 =
|
|
|
|
|
static_cast<T>(alpha) * (out < static_cast<T>(0)).template cast<T>();
|
|
|
|
|
auto temp2 = (out >= static_cast<T>(0)).template cast<T>();
|
|
|
|
|
static_cast<T>(alpha) * (out <= static_cast<T>(0)).template cast<T>();
|
|
|
|
|
auto temp2 = (out > static_cast<T>(0)).template cast<T>();
|
|
|
|
|
dx.device(d) = dout * (temp1 + temp2).template cast<T>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1418,10 +1418,10 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
|
|
|
|
|
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
|
|
|
|
|
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
|
|
|
|
|
ddout.device(*d) =
|
|
|
|
|
ddx *
|
|
|
|
|
((out >= static_cast<T>(0)).template cast<T>() +
|
|
|
|
|
static_cast<T>(alpha) * (out < static_cast<T>(0)).template cast<T>())
|
|
|
|
|
ddout.device(*d) = ddx *
|
|
|
|
|
((out > static_cast<T>(0)).template cast<T>() +
|
|
|
|
|
static_cast<T>(alpha) *
|
|
|
|
|
(out <= static_cast<T>(0)).template cast<T>())
|
|
|
|
|
.template cast<T>();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|