|
|
@ -477,7 +477,7 @@ struct RsqrtGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
|
|
typename dX>
|
|
|
|
typename dX>
|
|
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
|
|
dx.device(d) = static_cast<T>(-0.5) * dout * out.pow(3);
|
|
|
|
dx.device(d) = static_cast<T>(-0.5) * dout * out * out * out;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|