|
|
|
@ -39,8 +39,8 @@ namespace operators {
|
|
|
|
|
Please refer to the layer_helper.py and get the details.
|
|
|
|
|
*/
|
|
|
|
|
static std::unordered_set<std::string> InplaceOpSet = {
|
|
|
|
|
"sigmoid", "exp", "relu", "tanh", "sqrt", "ceil",
|
|
|
|
|
"floor", "reciprocal", "relu6", "soft_relu", "hard_sigmoid"};
|
|
|
|
|
"sigmoid", "exp", "relu", "tanh", "sqrt", "ceil",
|
|
|
|
|
"floor", "reciprocal", "relu6", "soft_relu", "hard_sigmoid", "rsqrt"};
|
|
|
|
|
|
|
|
|
|
static bool IsInplace(const std::string& op) {
|
|
|
|
|
bool inplace = InplaceOpSet.count(op);
|
|
|
|
@ -463,6 +463,24 @@ struct SqrtGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// rsqrt(x) = x^(-1/2)
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct RsqrtFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
template <typename Device, typename X, typename Out>
|
|
|
|
|
void operator()(Device d, X x, Out out) const {
|
|
|
|
|
out.device(d) = x.rsqrt();
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct RsqrtGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
template <typename Device, typename X, typename Out, typename dOut,
|
|
|
|
|
typename dX>
|
|
|
|
|
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
|
|
|
|
|
dx.device(d) = static_cast<T>(-0.5) * dout * out.pow(3);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// ceil(x) = ceiling(x)
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct CeilFunctor : public BaseActivationFunctor<T> {
|
|
|
|
@ -1098,6 +1116,7 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
__macro(atan, AtanFunctor, AtanGradFunctor); \
|
|
|
|
|
__macro(softshrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
|
|
|
|
|
__macro(sqrt, SqrtFunctor, SqrtGradFunctor); \
|
|
|
|
|
__macro(rsqrt, RsqrtFunctor, RsqrtGradFunctor); \
|
|
|
|
|
__macro(abs, AbsFunctor, AbsGradFunctor); \
|
|
|
|
|
__macro(ceil, CeilFunctor, ZeroGradFunctor); \
|
|
|
|
|
__macro(floor, FloorFunctor, ZeroGradFunctor); \
|
|
|
|
|