|
|
|
@ -99,5 +99,36 @@ struct ReluGradFunctor {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct TanhFunctor {
|
|
|
|
|
template <typename Device, typename X, typename Y>
|
|
|
|
|
void operator()(Device d, X x, Y y) {
|
|
|
|
|
y.device(d) = x.tanh();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct TanhGradFunctor {
|
|
|
|
|
template <typename Device, typename X, typename Y, typename dY, typename dX>
|
|
|
|
|
void operator()(Device d, X x, Y y, dY dy, dX dx) {
|
|
|
|
|
dx.device(d) = dy * (T(1) - y * y);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
struct SqrtFunctor {
|
|
|
|
|
template <typename Device, typename X, typename Y>
|
|
|
|
|
void operator()(Device d, X x, Y y) {
|
|
|
|
|
y.device(d) = x.sqrt();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct SqrtGradFunctor {
|
|
|
|
|
template <typename Device, typename X, typename Y, typename dY, typename dX>
|
|
|
|
|
void operator()(Device d, X x, Y y, dY dy, dX dx) {
|
|
|
|
|
const T y_conj = Eigen::numext::conj(y);
|
|
|
|
|
dx.device(d) = static_cast<T>(0.5) * dy / y_conj;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|