Fix the gelu backward to avoid nan (#14857)

* Fix the gelu backward to avoid nan

test=develop

* Remove unnecessary calls

test=develop
ce_debug
Yibing Liu 7 years ago committed by GitHub
parent 322bb8d5c5
commit 6951ef9a55
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -301,23 +301,22 @@ template <typename T>
 struct GeluFunctor : public BaseActivationFunctor<T> {
   template <typename Device, typename X, typename Out>
   void operator()(Device d, X x, Out out) const {
-    auto temp =
-        ((x * static_cast<T>(M_SQRT1_2)).erf()).template cast<T>().eval();
+    auto temp = (x * static_cast<T>(M_SQRT1_2)).erf();
     out.device(d) = x * static_cast<T>(0.5) * (static_cast<T>(1) + temp);
   }
 };

 template <typename T>
 struct GeluGradFunctor : BaseActivationFunctor<T> {
-  bool Inplace() const { return IsInplace("gelu"); }
   template <typename Device, typename X, typename Out, typename dOut,
             typename dX>
   void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
-    auto temp = (static_cast<T>(0.5 * M_2_SQRTPI * M_SQRT1_2) * x *
-                 ((-static_cast<T>(0.5) * x.square()).exp()))
-                    .template cast<T>()
-                    .eval();
-    dx.device(d) = dout * (out / x + temp);
+    auto first = static_cast<T>(0.5) *
+                 (static_cast<T>(1) + ((x * static_cast<T>(M_SQRT1_2)).erf()));
+
+    auto second = static_cast<T>(0.5 * M_2_SQRTPI * M_SQRT1_2) * x *
+                  (-static_cast<T>(0.5) * x.square()).exp();
+    dx.device(d) = dout * (first + second);
   }
 };

Loading…
Cancel
Save