|
|
|
@ -1643,6 +1643,35 @@ struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct RsqrtGradGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
template <typename Device>
|
|
|
|
|
void operator()(const Device& dev, const framework::Tensor* Out,
|
|
|
|
|
const framework::Tensor* ddX, framework::Tensor* ddOut,
|
|
|
|
|
framework::Tensor* dOut, const framework::Tensor* dX) const {
|
|
|
|
|
auto* d = dev.eigen_device();
|
|
|
|
|
auto ddx = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(ddX, "Input", "DDX", "RsqrtGradGrad"));
|
|
|
|
|
auto out = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(Out, "Output", "Out", "RsqrtGradGrad"));
|
|
|
|
|
|
|
|
|
|
// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3/y) * dx * ddx
|
|
|
|
|
if (dOut) {
|
|
|
|
|
auto dx = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(dX, "Output", "DX", "RsqrtGradGrad"));
|
|
|
|
|
auto dout = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(dOut, "Output", "DOut", "RsqrtGradGrad"));
|
|
|
|
|
dout.device(*d) = (static_cast<T>(3.0) / out) * dx * ddx;
|
|
|
|
|
}
|
|
|
|
|
if (ddOut) {
|
|
|
|
|
auto ddout = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "RsqrtGradGrad"));
|
|
|
|
|
ddout.device(*d) = ddx * static_cast<T>(-0.5) * out * out * out;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
template <typename Device>
|
|
|
|
@ -1828,6 +1857,67 @@ class SqrtDoubleGradKernel
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// rsqrt Grad: dx = -0.5 * dy * y * y * y
|
|
|
|
|
// rsqrt GradGrad: ddy = -0.5 * ddx * y * y * y, dy = (3 / y) * dx * ddx
|
|
|
|
|
template <typename DeviceContext, typename Functor>
|
|
|
|
|
class RsqrtDoubleGradKernel
|
|
|
|
|
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
|
|
|
|
|
public:
|
|
|
|
|
using T = typename Functor::ELEMENT_TYPE;
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
const framework::Tensor *Out, *dX, *ddX;
|
|
|
|
|
Out = dX = ddX = nullptr;
|
|
|
|
|
framework::Tensor *ddOut, *dOut;
|
|
|
|
|
ddOut = dOut = nullptr;
|
|
|
|
|
|
|
|
|
|
// extract ddx(input), ddout(output)
|
|
|
|
|
auto ddx_var = ctx.InputVar("DDX");
|
|
|
|
|
auto ddo_var = ctx.OutputVar("DDOut");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
ddx_var, platform::errors::NotFound(
|
|
|
|
|
"Cannot get input Variable DDX, variable name = %s",
|
|
|
|
|
ctx.InputName("DDX")));
|
|
|
|
|
ddX = ctx.Input<framework::Tensor>("DDX");
|
|
|
|
|
if (ddo_var) {
|
|
|
|
|
ddOut = ctx.Output<framework::Tensor>("DDOut");
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
ddX, platform::errors::NotFound(
|
|
|
|
|
"Cannot get input Variable DDX, variable name = %s",
|
|
|
|
|
ctx.InputName("DDX")));
|
|
|
|
|
|
|
|
|
|
// extract out(input), dout(output)
|
|
|
|
|
auto out_var = ctx.InputVar("Out");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
out_var, platform::errors::NotFound(
|
|
|
|
|
"Cannot get input Variable Out, variable name = %s",
|
|
|
|
|
ctx.InputName("Out")));
|
|
|
|
|
auto dout_var = ctx.OutputVar("DOut");
|
|
|
|
|
Out = ctx.Input<framework::Tensor>("Out");
|
|
|
|
|
if (dout_var) {
|
|
|
|
|
dOut = ctx.Output<framework::Tensor>("DOut");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// extract dx(input)
|
|
|
|
|
auto dx_var = ctx.InputVar("DX");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
dx_var, platform::errors::NotFound(
|
|
|
|
|
"Cannot get input Variable DX, variable name = %s",
|
|
|
|
|
ctx.InputName("DX")));
|
|
|
|
|
if (dx_var) {
|
|
|
|
|
dX = ctx.Input<framework::Tensor>("DX");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (dOut) dOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
|
|
|
|
|
if (ddOut) ddOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto& place = ctx.template device_context<DeviceContext>();
|
|
|
|
|
|
|
|
|
|
Functor functor;
|
|
|
|
|
functor(place, Out, ddX, ddOut, dOut, dX);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename Functor>
|
|
|
|
|
class PowKernel : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
|
|
|
|
|
public:
|
|
|
|
@ -1971,7 +2061,6 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
__macro(tanh, Tanh, TanhFunctor, TanhGradFunctor); \
|
|
|
|
|
__macro(atan, Atan, AtanFunctor, AtanGradFunctor); \
|
|
|
|
|
__macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
|
|
|
|
|
__macro(rsqrt, Rsqrt, RsqrtFunctor, RsqrtGradFunctor); \
|
|
|
|
|
__macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \
|
|
|
|
|
__macro(floor, Floor, FloorFunctor, ZeroGradFunctor); \
|
|
|
|
|
__macro(cos, Cos, CosFunctor, CosGradFunctor); \
|
|
|
|
|