|
|
|
@ -301,23 +301,22 @@ template <typename T>
|
|
|
|
|
struct GeluFunctor : public BaseActivationFunctor<T> {
  // Forward pass of the erf-based GELU activation:
  //
  //   out = 0.5 * x * (1 + erf(x / sqrt(2)))
  //
  // d   - device object; the result is assigned through out.device(d).
  // x   - input expression (presumably an Eigen tensor expression, since it
  //       is used with .erf() -- confirm against the callers).
  // out - destination expression that receives the activation.
  template <typename Device, typename X, typename Out>
  void operator()(Device d, X x, Out out) const {
    // M_SQRT1_2 is 1/sqrt(2). The erf term is written inline, in a single
    // assignment, so that no intermediate expression object is stored in an
    // `auto` local and the scratch name `temp` is not declared twice.
    out.device(d) = x * static_cast<T>(0.5) *
                    (static_cast<T>(1) + (x * static_cast<T>(M_SQRT1_2)).erf());
  }
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct GeluGradFunctor : BaseActivationFunctor<T> {
|
|
|
|
|
bool Inplace() const { return IsInplace("gelu"); }
|
|
|
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
|
|
|
typename dX>
|
|
|
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
|
|
|
auto temp = (static_cast<T>(0.5 * M_2_SQRTPI * M_SQRT1_2) * x *
|
|
|
|
|
((-static_cast<T>(0.5) * x.square()).exp()))
|
|
|
|
|
.template cast<T>()
|
|
|
|
|
.eval();
|
|
|
|
|
dx.device(d) = dout * (out / x + temp);
|
|
|
|
|
auto first = static_cast<T>(0.5) *
|
|
|
|
|
(static_cast<T>(1) + ((x * static_cast<T>(M_SQRT1_2)).erf()));
|
|
|
|
|
|
|
|
|
|
auto second = static_cast<T>(0.5 * M_2_SQRTPI * M_SQRT1_2) * x *
|
|
|
|
|
(-static_cast<T>(0.5) * x.square()).exp();
|
|
|
|
|
dx.device(d) = dout * (first + second);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|