@@ -21,6 +21,14 @@
 namespace paddle {
 namespace operators {
 
+template <typename place, typename T>
+struct LRNFunctor {
+  void operator()(const framework::ExecutionContext& ctx,
+                  const framework::Tensor& input, framework::Tensor* out,
+                  framework::Tensor* mid, int N, int C, int H, int W, int n,
+                  T k, T alpha, T beta);
+};
+
 template <typename Place, typename T>
 class LRNKernel : public framework::OpKernel<T> {
  public:
@@ -31,8 +39,8 @@ class LRNKernel : public framework::OpKernel<T> {
   // f(x) represents outputs
   void Compute(const framework::ExecutionContext& ctx) const override {
     // input
-    const Tensor* x = ctx.Input<Tensor>("X");
-    auto x_dims = x->dims();
+    const Tensor& x = *ctx.Input<Tensor>("X");
+    auto x_dims = x.dims();
 
     // NCHW
     int N = x_dims[0];
@@ -57,38 +65,20 @@ class LRNKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE(beta >= 0.0, "beta should >= 0.0");
     PADDLE_ENFORCE(k >= 0.0, "k should >= 0.0");
 
-    auto x_v = framework::EigenVector<T>::Flatten(*x);
-
-    const int start = -(n - 1) / 2;
-    const int end = start + n;
-
-    auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
-    e_mid.device(ctx.GetEigenDevice<Place>()) = e_mid.constant(k);
-
-    auto e_x = framework::EigenTensor<T, 4>::From(*x);
-    for (int m = 0; m < N; m++) {
-      for (int i = 0; i < C; i++) {
-        for (int c = start; c <= end; c++) {
-          int ch = i + c;
-          if (ch >= 0 && ch < C) {
-            auto s = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
-                                 Eigen::array<int, 4>({{1, 1, H, W}}));
-
-            auto r = e_x.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
-                               Eigen::array<int, 4>({{1, 1, H, W}}));
-
-            s.device(ctx.GetEigenDevice<Place>()) += alpha * r.square();
-          }
-        }
-      }
-    }
-
-    auto out_e = framework::EigenVector<T>::Flatten(*out);
-    out_e.device(ctx.GetEigenDevice<Place>()) =
-        x_v * e_mid.reshape(Eigen::DSizes<int, 1>(e_mid.size())).pow(-beta);
+    LRNFunctor<Place, T> f;
+    f(ctx, x, out, mid, N, C, H, W, n, k, alpha, beta);
   }
 };
 
+template <typename Place, typename T>
+struct LRNGradFunctor {
+  void operator()(const framework::ExecutionContext& ctx,
+                  const framework::Tensor& x, const framework::Tensor& out,
+                  const framework::Tensor& mid, framework::Tensor* x_g,
+                  const framework::Tensor& out_g, int N, int C, int H, int W,
+                  int n, T alpha, T beta);
+};
+
 /**
  * \brief Backward calculation for normalization with across maps.
  *
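The forward loop deleted above presumably becomes the body of a per-device specialization of the new LRNFunctor. A minimal sketch of what the CPU specialization could look like, assuming it lives in lrn_op.cc (the file and the body below are an assumption based on the removed code, not part of this diff):

// Hypothetical CPU specialization (e.g. in lrn_op.cc); essentially the Eigen
// loop removed from LRNKernel::Compute, with the input passed by reference.
template <typename T>
struct LRNFunctor<platform::CPUPlace, T> {
  void operator()(const framework::ExecutionContext& ctx,
                  const framework::Tensor& input, framework::Tensor* out,
                  framework::Tensor* mid, int N, int C, int H, int W, int n,
                  T k, T alpha, T beta) {
    auto x_v = framework::EigenVector<T>::Flatten(input);

    const int start = -(n - 1) / 2;
    const int end = start + n;

    // MidOut starts at k and accumulates alpha * x^2 over the channel window.
    auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
    e_mid.device(ctx.GetEigenDevice<platform::CPUPlace>()) = e_mid.constant(k);

    auto e_x = framework::EigenTensor<T, 4>::From(input);
    for (int m = 0; m < N; m++) {
      for (int i = 0; i < C; i++) {
        for (int c = start; c <= end; c++) {
          int ch = i + c;
          if (ch >= 0 && ch < C) {
            auto s = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
                                 Eigen::array<int, 4>({{1, 1, H, W}}));
            auto r = e_x.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
                               Eigen::array<int, 4>({{1, 1, H, W}}));
            s.device(ctx.GetEigenDevice<platform::CPUPlace>()) +=
                alpha * r.square();
          }
        }
      }
    }

    // Out = X * MidOut^(-beta)
    auto out_e = framework::EigenVector<T>::Flatten(*out);
    out_e.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
        x_v * e_mid.reshape(Eigen::DSizes<int, 1>(e_mid.size())).pow(-beta);
  }
};

LRNGradFunctor would get a matching CPU specialization built from the backward loop removed further down.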
@@ -97,7 +87,7 @@ class LRNKernel : public framework::OpKernel<T> {
  * The implementation of this Function is derived from the
  * CrossMapNormalFunc implementation.
  *
- * InputGrad = OutputGrad * denoms ^ (-beta)
+ * InputGrad = OutputGrad * MidOut ^ (-beta)
  *    -- upper
  *  + > (OutputGrad * OutputValue * (-2 * alpha * beta) / MidOut) * InputValue
  *    -- lower
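Written as a single equation in the comment's notation (a sketch; y denotes Out, spatial indices are omitted, and \mathcal{N}(c) is the window of n channels centered on channel c):

\[
\frac{\partial L}{\partial x_c}
  = \frac{\partial L}{\partial y_c}\,\mathrm{MidOut}_c^{-\beta}
  \;-\; 2\alpha\beta\, x_c \sum_{c' \in \mathcal{N}(c)}
        \frac{\partial L}{\partial y_{c'}}\,\frac{y_{c'}}{\mathrm{MidOut}_{c'}}
\]

The first term is the "upper" part of the comment; the sum is the "lower" part, with -2 * alpha * beta appearing in the code below as ratio.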
@@ -113,18 +103,15 @@ class LRNGradKernel : public framework::OpKernel<T> {
  public:
   using Tensor = framework::Tensor;
   void Compute(const framework::ExecutionContext& ctx) const override {
-    const Tensor* x = ctx.Input<Tensor>("X");
-    const Tensor* out = ctx.Input<Tensor>("Out");
-    const Tensor* out_g = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    const Tensor* mid = ctx.Input<Tensor>("MidOut");
+    const Tensor& x = *ctx.Input<Tensor>("X");
+    const Tensor& out = *ctx.Input<Tensor>("Out");
+    const Tensor& out_g = *ctx.Input<Tensor>(framework::GradVarName("Out"));
+    const Tensor& mid = *ctx.Input<Tensor>("MidOut");
 
     auto x_g = ctx.Output<Tensor>(framework::GradVarName("X"));
     x_g->mutable_data<T>(ctx.GetPlace());
 
-    auto x_g_e = framework::EigenVector<T>::Flatten(*x_g);
-    x_g_e.device(ctx.GetEigenDevice<Place>()) = x_g_e.constant(0.0);
-
-    auto x_dims = x->dims();
+    auto x_dims = x.dims();
     int N = x_dims[0];
     int C = x_dims[1];
     int H = x_dims[2];
@@ -133,51 +120,9 @@ class LRNGradKernel : public framework::OpKernel<T> {
     int n = ctx.Attr<int>("n");
     T alpha = ctx.Attr<T>("alpha");
     T beta = ctx.Attr<T>("beta");
-    T ratio = -2 * alpha * beta;
-
-    auto e_x = framework::EigenTensor<T, 4>::From(*x);
-    auto e_x_g = framework::EigenTensor<T, 4>::From(*x_g);
-    auto e_out = framework::EigenTensor<T, 4>::From(*out);
-    auto e_out_g = framework::EigenTensor<T, 4>::From(*out_g);
-    auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
-
-    const int start = -(n - 1) / 2;
-    const int end = start + n;
-    for (int m = 0; m < N; m++) {
-      for (int i = 0; i < C; i++) {
-        auto i_x = e_x.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
-                             Eigen::array<int, 4>({{1, 1, H, W}}));
-
-        auto i_x_g = e_x_g.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
-                                 Eigen::array<int, 4>({{1, 1, H, W}}));
-
-        auto i_out_g = e_out_g.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
-                                     Eigen::array<int, 4>({{1, 1, H, W}}));
-
-        auto i_mid = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
-                                 Eigen::array<int, 4>({{1, 1, H, W}}));
-
-        i_x_g.device(ctx.GetEigenDevice<Place>()) = i_mid.pow(-beta) * i_out_g;
-        for (int c = start; c <= end; c++) {
-          int ch = i + c;
-          if (ch < 0 || ch >= C) {
-            continue;
-          }
-
-          auto c_out = e_out.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
-                                   Eigen::array<int, 4>({{1, 1, H, W}}));
-
-          auto c_mid = e_mid.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
-                                   Eigen::array<int, 4>({{1, 1, H, W}}));
-
-          auto c_out_g = e_out_g.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
-                                       Eigen::array<int, 4>({{1, 1, H, W}}));
-
-          i_x_g.device(ctx.GetEigenDevice<Place>()) +=
-              ratio * c_out_g * c_out * i_x / c_mid;
-        }
-      }
-    }
-
+
+    LRNGradFunctor<Place, T> f;
+    f(ctx, x, out, mid, x_g, out_g, N, C, H, W, n, alpha, beta);
   }
 };
 
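Because both functors are only declared in this header, lrn_op.cc and lrn_op.cu can each define their own specializations while registering the same kernel classes. A rough outline of the GPU side under that assumption (the file name, platform::GPUPlace, and the registration macro follow the codebase's usual pattern and are not part of this diff):

// Hypothetical lrn_op.cu outline (an assumption, not part of this diff).
#include "paddle/operators/lrn_op.h"

namespace paddle {
namespace operators {

template <typename T>
struct LRNFunctor<platform::GPUPlace, T> {
  void operator()(const framework::ExecutionContext& ctx,
                  const framework::Tensor& input, framework::Tensor* out,
                  framework::Tensor* mid, int N, int C, int H, int W, int n,
                  T k, T alpha, T beta) {
    // A CUDA-specific implementation can live here without touching the CPU
    // path in lrn_op.cc; a first version could reuse the same Eigen loop with
    // the GPU Eigen device.
  }
};

// LRNGradFunctor<platform::GPUPlace, T> would be specialized the same way.

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(lrn, ops::LRNKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(lrn_grad,
                       ops::LRNGradKernel<paddle::platform::GPUPlace, float>);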