|
|
|
@ -26,7 +26,6 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/framework/eigen.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/operators/detail/safe_ref.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/blas.h"
|
|
|
|
|
#include "paddle/fluid/platform/float16.h"
|
|
|
|
|
|
|
|
|
@ -156,8 +155,10 @@ class ActivationKernel
|
|
|
|
|
ExtractActivationTensor(context, &X, &Out);
|
|
|
|
|
Out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X));
|
|
|
|
|
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
|
|
|
|
|
auto x = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(X, "Input", "X", "Activation"));
|
|
|
|
|
auto out = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(Out, "Output", "Out", "Activation"));
|
|
|
|
|
auto* place =
|
|
|
|
|
context.template device_context<DeviceContext>().eigen_device();
|
|
|
|
|
Functor functor;
|
|
|
|
@ -182,10 +183,14 @@ class ActivationGradKernel
|
|
|
|
|
ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut,
|
|
|
|
|
&dX);
|
|
|
|
|
dX->mutable_data<T>(context.GetPlace());
|
|
|
|
|
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
|
|
|
|
|
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
|
|
|
|
|
auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
|
|
|
|
|
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X));
|
|
|
|
|
auto dout = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "ActivationGrad"));
|
|
|
|
|
auto out = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(Out, "Input", "Out", "ActivationGrad"));
|
|
|
|
|
auto dx = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(dX, "Input", "X@GRAD", "ActivationGrad"));
|
|
|
|
|
auto x = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(X, "Input", "X", "ActivationGrad"));
|
|
|
|
|
auto* place =
|
|
|
|
|
context.template device_context<DeviceContext>().eigen_device();
|
|
|
|
|
Functor functor;
|
|
|
|
@ -1285,10 +1290,13 @@ struct ReluGradGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
framework::Tensor* ddOut, framework::Tensor* dOut,
|
|
|
|
|
framework::Tensor* dX) const {
|
|
|
|
|
auto* d = dev.eigen_device();
|
|
|
|
|
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
|
|
|
|
|
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
|
|
|
|
|
auto ddx = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(ddX, "Input", "DDX", "ReluGradGrad"));
|
|
|
|
|
auto out = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(Out, "Output", "Out", "ReluGradGrad"));
|
|
|
|
|
if (ddOut) {
|
|
|
|
|
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
|
|
|
|
|
auto ddout = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ReluGradGrad"));
|
|
|
|
|
ddout.device(*d) = ddx * (out > static_cast<T>(0)).template cast<T>();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -1308,9 +1316,12 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
framework::Tensor* dX) const {
|
|
|
|
|
if (ddOut) {
|
|
|
|
|
auto* d = dev.eigen_device();
|
|
|
|
|
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
|
|
|
|
|
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
|
|
|
|
|
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
|
|
|
|
|
auto ddx = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(ddX, "Input", "DDX", "LeakyReluGradGrad"));
|
|
|
|
|
auto out = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(Out, "Output", "Out", "LeakyReluGradGrad"));
|
|
|
|
|
auto ddout = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(ddOut, "Output", "DOut", "LeakyReluGradGrad"));
|
|
|
|
|
ddout.device(*d) = ddx *
|
|
|
|
|
((out > static_cast<T>(0)).template cast<T>() +
|
|
|
|
|
static_cast<T>(alpha) *
|
|
|
|
@ -1332,18 +1343,23 @@ struct ELUGradGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
const framework::Tensor* ddX, framework::Tensor* ddOut,
|
|
|
|
|
const framework::Tensor* dOut, framework::Tensor* dX) const {
|
|
|
|
|
auto* d = dev.eigen_device();
|
|
|
|
|
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
|
|
|
|
|
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X));
|
|
|
|
|
auto ddx = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(ddX, "Input", "DDX", "ELUGradGrad"));
|
|
|
|
|
auto x = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(X, "Input", "X", "ELUGradGrad"));
|
|
|
|
|
|
|
|
|
|
if (dX) {
|
|
|
|
|
auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
|
|
|
|
|
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
|
|
|
|
|
auto dx = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(dX, "Output", "DX", "ELUGradGrad"));
|
|
|
|
|
auto dout = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(dOut, "Output", "DOut", "ELUGradGrad"));
|
|
|
|
|
dx.device(*d) = ddx * dout * static_cast<T>(alpha) * x.exp() *
|
|
|
|
|
(x < static_cast<T>(0)).template cast<T>();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (ddOut) {
|
|
|
|
|
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
|
|
|
|
|
auto ddout = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "ELUGradGrad"));
|
|
|
|
|
ddout.device(*d) = ddx *
|
|
|
|
|
((x > static_cast<T>(0)).template cast<T>() +
|
|
|
|
|
static_cast<T>(alpha) * x.exp() *
|
|
|
|
@ -1361,17 +1377,22 @@ struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
const framework::Tensor* ddX, framework::Tensor* ddOut,
|
|
|
|
|
framework::Tensor* dOut, const framework::Tensor* dX) const {
|
|
|
|
|
auto* d = dev.eigen_device();
|
|
|
|
|
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
|
|
|
|
|
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
|
|
|
|
|
auto ddx = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(ddX, "Input", "DDX", "SqrtGradGrad"));
|
|
|
|
|
auto out = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(Out, "Output", "Out", "SqrtGradGrad"));
|
|
|
|
|
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx / y
|
|
|
|
|
// calculate dy first, so ddy can inplace ddx
|
|
|
|
|
if (dOut) {
|
|
|
|
|
auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
|
|
|
|
|
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
|
|
|
|
|
auto dx = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(dX, "Output", "DX", "SqrtGradGrad"));
|
|
|
|
|
auto dout = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(dOut, "Output", "DOut", "SqrtGradGrad"));
|
|
|
|
|
dout.device(*d) = dx * ddx * static_cast<T>(-1) / out;
|
|
|
|
|
}
|
|
|
|
|
if (ddOut) {
|
|
|
|
|
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
|
|
|
|
|
auto ddout = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SqrtGradGrad"));
|
|
|
|
|
ddout.device(*d) = ddx * static_cast<T>(0.5) / out;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -1385,17 +1406,22 @@ struct SquareGradGradFunctor : public BaseActivationFunctor<T> {
|
|
|
|
|
const framework::Tensor* ddX, framework::Tensor* ddOut,
|
|
|
|
|
const framework::Tensor* dOut, framework::Tensor* dX) const {
|
|
|
|
|
auto* d = dev.eigen_device();
|
|
|
|
|
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
|
|
|
|
|
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X));
|
|
|
|
|
auto ddx = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(ddX, "Input", "DDX", "SquareGradGrad"));
|
|
|
|
|
auto x = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(X, "Input", "X", "SquareGradGrad"));
|
|
|
|
|
// square GradGrad: ddy=2x*ddx, dx=2dy*ddx
|
|
|
|
|
// calculate dx first, so ddy can inplace ddx
|
|
|
|
|
if (dX) {
|
|
|
|
|
auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
|
|
|
|
|
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
|
|
|
|
|
auto dx = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(dX, "Output", "DX", "SquareGradGrad"));
|
|
|
|
|
auto dout = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(dOut, "Output", "DOut", "SquareGradGrad"));
|
|
|
|
|
dx.device(*d) = ddx * static_cast<T>(2) * dout;
|
|
|
|
|
}
|
|
|
|
|
if (ddOut) {
|
|
|
|
|
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
|
|
|
|
|
auto ddout = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "SquareGradGrad"));
|
|
|
|
|
ddout.device(*d) = ddx * static_cast<T>(2) * x;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -1557,8 +1583,10 @@ class PowKernel : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
|
|
|
|
|
ExtractActivationTensor(context, &X, &Out);
|
|
|
|
|
Out->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X));
|
|
|
|
|
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
|
|
|
|
|
auto x = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(X, "Input", "X", "Pow"));
|
|
|
|
|
auto out = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(Out, "Output", "Out", "Pow"));
|
|
|
|
|
auto* place =
|
|
|
|
|
context.template device_context<DeviceContext>().eigen_device();
|
|
|
|
|
Functor functor;
|
|
|
|
@ -1602,10 +1630,14 @@ class PowGradKernel
|
|
|
|
|
ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut,
|
|
|
|
|
&dX);
|
|
|
|
|
dX->mutable_data<T>(context.GetPlace());
|
|
|
|
|
auto dout = framework::EigenVector<T>::Flatten(detail::Ref(dOut));
|
|
|
|
|
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
|
|
|
|
|
auto dx = framework::EigenVector<T>::Flatten(detail::Ref(dX));
|
|
|
|
|
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X));
|
|
|
|
|
auto dout = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(dOut, "Input", "Out@GRAD", "PowGrad"));
|
|
|
|
|
auto out = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(Out, "Input", "Out", "PowGrad"));
|
|
|
|
|
auto dx = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(dX, "Output", "X@GRAD", "PowGrad"));
|
|
|
|
|
auto x = framework::EigenVector<T>::Flatten(
|
|
|
|
|
GET_DATA_SAFELY(X, "Input", "X", "PowGrad"));
|
|
|
|
|
auto* place =
|
|
|
|
|
context.template device_context<DeviceContext>().eigen_device();
|
|
|
|
|
Functor functor;
|
|
|
|
|