|
|
|
@ -1663,6 +1663,10 @@ class SquareDoubleGradKernel
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename Functor>
|
|
|
|
|
class LogDoubleGradKernel
|
|
|
|
|
: public SquareDoubleGradKernel<DeviceContext, Functor> {};
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename Functor>
|
|
|
|
|
class ELUDoubleGradKernel
|
|
|
|
|
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
|
|
|
|
@ -1852,6 +1856,37 @@ class PowGradKernel
|
|
|
|
|
functor(*place, x, out, dout, dx);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct LogGradGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
template <typename Device>
|
|
|
|
|
void operator()(const Device& dev, const framework::Tensor* X,
|
|
|
|
|
const framework::Tensor* ddX, framework::Tensor* ddOut,
|
|
|
|
|
const framework::Tensor* dOut, framework::Tensor* dX) const {
|
|
|
|
|
auto* d = dev.eigen_device();
|
|
|
|
|
auto ddx = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(ddX, "Input", "DDX", "LogGradGrad"));
|
|
|
|
|
auto x = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(X, "Input", "X", "LogGradGrad"));
|
|
|
|
|
// ddout = ddx / x; dx = -(dout / x) * (ddx / x)
|
|
|
|
|
// calculate dx first, so ddout can inplace ddx
|
|
|
|
|
if (dX) {
|
|
|
|
|
auto dout = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(dOut, "Output", "DOut", "LogGradGrad"));
|
|
|
|
|
auto dx = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(dX, "Output", "DX", "LogGradGrad"));
|
|
|
|
|
dx.device(*d) = dout * static_cast<T>(-1) * ddx / (x * x);
|
|
|
|
|
}
|
|
|
|
|
if (ddOut) {
|
|
|
|
|
auto ddout = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "LogGradGrad"));
|
|
|
|
|
ddout.device(*d) = ddx * static_cast<T>(1) / x;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
@ -1872,7 +1907,6 @@ class PowGradKernel
|
|
|
|
|
__macro(cosh, Cosh, CoshFunctor, CoshGradFunctor); \
|
|
|
|
|
__macro(round, Round, RoundFunctor, ZeroGradFunctor); \
|
|
|
|
|
__macro(reciprocal, Reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
|
|
|
|
|
__macro(log, Log, LogFunctor, LogGradFunctor); \
|
|
|
|
|
__macro(log1p, Log1p, Log1pFunctor, Log1pGradFunctor); \
|
|
|
|
|
__macro(brelu, BRelu, BReluFunctor, BReluGradFunctor); \
|
|
|
|
|
__macro(soft_relu, SoftRelu, SoftReluFunctor, SoftReluGradFunctor); \
|
|
|
|
|