|
|
|
@ -1359,6 +1359,28 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
template <typename Device>
|
|
|
|
|
void operator()(const Device& dev, const framework::Tensor* Out,
|
|
|
|
|
const framework::Tensor* ddX, framework::Tensor* ddOut,
|
|
|
|
|
framework::Tensor* dOut, const framework::Tensor* dX) const {
|
|
|
|
|
auto* d = dev.eigen_device();
|
|
|
|
|
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
|
|
|
|
|
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
|
|
|
|
|
if (ddOut) {
|
|
|
|
|
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
|
|
|
|
|
ddout.device(*d) = ddx * static_cast<T>(0.5) / out;
|
|
|
|
|
}
|
|
|
|
|
if (dOut) {
|
|
|
|
|
auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
|
|
|
|
|
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
|
|
|
|
|
dout.device(*d) = dx * ddx * static_cast<T>(-1) / out;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
template <typename Device>
|
|
|
|
@ -1433,8 +1455,8 @@ class SquareDoubleGradKernel
|
|
|
|
|
|
|
|
|
|
ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut);
|
|
|
|
|
|
|
|
|
|
dX->mutable_data<T>(X->dims(), ctx.GetPlace());
|
|
|
|
|
ddOut->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
if (dX) dX->mutable_data<T>(X->dims(), ctx.GetPlace());
|
|
|
|
|
if (ddOut) ddOut->mutable_data<T>(ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto& place = ctx.template device_context<DeviceContext>();
|
|
|
|
|
|
|
|
|
@ -1443,6 +1465,61 @@ class SquareDoubleGradKernel
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename Functor>
|
|
|
|
|
class SqrtDoubleGradKernel
|
|
|
|
|
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
|
|
|
|
|
public:
|
|
|
|
|
using T = typename Functor::ELEMENT_TYPE;
|
|
|
|
|
void Compute(const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
const framework::Tensor *Out, *dX, *ddX;
|
|
|
|
|
Out = dX = ddX = nullptr;
|
|
|
|
|
framework::Tensor *ddOut, *dOut;
|
|
|
|
|
ddOut = dOut = nullptr;
|
|
|
|
|
|
|
|
|
|
// extract ddx(input), ddout(output)
|
|
|
|
|
auto ddx_var = ctx.InputVar("DDX");
|
|
|
|
|
auto ddo_var = ctx.OutputVar("DDOut");
|
|
|
|
|
PADDLE_ENFORCE(ddx_var != nullptr,
|
|
|
|
|
"Cannot get input Variable DDX, variable name = %s",
|
|
|
|
|
ctx.op().Input("DDX"));
|
|
|
|
|
ddX = ctx.Input<framework::Tensor>("DDX");
|
|
|
|
|
if (ddo_var) {
|
|
|
|
|
ddOut = ctx.Output<framework::Tensor>("DDOut");
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE(ddX != nullptr,
|
|
|
|
|
"Cannot get input Variable DDX, variable name = %s",
|
|
|
|
|
ctx.op().Input("DDX"));
|
|
|
|
|
|
|
|
|
|
// extract out(input), dout(output)
|
|
|
|
|
auto out_var = ctx.InputVar("Out");
|
|
|
|
|
PADDLE_ENFORCE(out_var != nullptr,
|
|
|
|
|
"Cannot get input Variable Out, variable name = %s",
|
|
|
|
|
ctx.op().Input("Out"));
|
|
|
|
|
auto dout_var = ctx.OutputVar("DOut");
|
|
|
|
|
Out = ctx.Input<framework::Tensor>("Out");
|
|
|
|
|
if (dout_var) {
|
|
|
|
|
dOut = ctx.Output<framework::Tensor>("DOut");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// extract dx(input)
|
|
|
|
|
auto dx_var = ctx.InputVar("DX");
|
|
|
|
|
PADDLE_ENFORCE(dx_var != nullptr,
|
|
|
|
|
"Cannot get input Variable DX, variable name = %s",
|
|
|
|
|
ctx.op().Input("DX"));
|
|
|
|
|
if (dx_var) {
|
|
|
|
|
dX = ctx.Input<framework::Tensor>("DX");
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (dOut) dOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
|
|
|
|
|
if (ddOut) ddOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto& place = ctx.template device_context<DeviceContext>();
|
|
|
|
|
|
|
|
|
|
Functor functor;
|
|
|
|
|
functor(place, Out, ddX, ddOut, dOut, dX);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
@ -1454,7 +1531,6 @@ class SquareDoubleGradKernel
|
|
|
|
|
__macro(tanh, Tanh, TanhFunctor, TanhGradFunctor); \
|
|
|
|
|
__macro(atan, Atan, AtanFunctor, AtanGradFunctor); \
|
|
|
|
|
__macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
|
|
|
|
|
__macro(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor); \
|
|
|
|
|
__macro(rsqrt, Rsqrt, RsqrtFunctor, RsqrtGradFunctor); \
|
|
|
|
|
__macro(abs, Abs, AbsFunctor, AbsGradFunctor); \
|
|
|
|
|
__macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \
|
|
|
|
|