@ -14,7 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/lrn_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
@ -23,18 +25,41 @@ namespace paddle {
namespace operators {
using framework::Tensor;
using DataLayout = framework::DataLayout;
// CPU specialization of the LRN forward functor.
//
// NOTE(review): this region reads like a diff with its +/- markers removed:
// the parameter list is terminated twice (once without and once with
// `data_layout`) and idata/odata/mdata are declared twice. Presumably only the
// `data_layout` variant and the *_transpose-based pointers are live -- confirm
// against the real file before relying on the comments below.
//
// Parameters visible here:
//   ctx          execution context; supplies the place, the BLAS handle and
//                the CPU device context.
//   input        source tensor (read only).
//   out, mid     destination tensors, written through mutable_data<T>().
//   N, C, H, W   integer sizes; presumably batch, channels, height, width --
//                TODO confirm against the caller.
//   n            window size; used below as C + n - 1 for the scratch tensor.
//   k, alpha, beta  scalar coefficients of type T.
//   data_layout  when equal to DataLayout::kNHWC the input is first permuted
//                into channel-first order.
template <typename T>
struct LRNFunctor<platform::CPUDeviceContext, T> {
void operator()(const framework::ExecutionContext& ctx,
const framework::Tensor& input, framework::Tensor* out,
framework::Tensor* mid, int N, int C, int H, int W, int n,
T k, T alpha, T beta) {
const T* idata = input.data<T>();
T k, T alpha, T beta, const DataLayout data_layout) {
auto place = ctx.GetPlace();
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
T* odata = out->mutable_data<T>(place);
T* mdata = mid->mutable_data<T>(place);
// 4-D transpose helper and the device context it is invoked with.
math::Transpose<platform::CPUDeviceContext, T, 4> transpose;
auto& dev_ctx = ctx.template device_context<platform::CPUDeviceContext>();
// Working tensors that the data pointers further down are taken from.
Tensor in_transpose, mid_transpose, out_transpose;
// If the data is channel-last (NHWC), permute it to channel-first first.
if (data_layout == DataLayout::kNHWC) {
auto in_dims = input.dims();
// Reordered shape: input dims taken in the order 0, 3, 1, 2.
std::vector<int64_t> shape(
{in_dims[0], in_dims[3], in_dims[1], in_dims[2]});
// All three working tensors are allocated with the reordered shape.
in_transpose.mutable_data<T>(framework::make_ddim(shape), place);
mid_transpose.mutable_data<T>(framework::make_ddim(shape), place);
out_transpose.mutable_data<T>(framework::make_ddim(shape), place);
// Same permutation as used for `shape`, applied to the input data.
std::vector<int> axis = {0, 3, 1, 2};
transpose(dev_ctx, input, &in_transpose, axis);
} else {
// No permutation: the working tensors are assigned from the arguments.
// NOTE(review): presumably Tensor assignment shares the underlying buffer,
// so writes through mid_transpose/out_transpose land in *mid/*out -- confirm.
in_transpose = input;
mid_transpose = *mid;
out_transpose = *out;
mid_transpose.mutable_data<T>(mid->dims(), place);
out_transpose.mutable_data<T>(out->dims(), place);
// Raw pointers into the working tensors.
const T* idata = in_transpose.data<T>();
T* odata = out_transpose.data<T>();
T* mdata = mid_transpose.data<T>();
// Scratch tensor of shape {1, C + n - 1, H, W}, zero-filled before use.
Tensor squared;
T* sdata = squared.mutable_data<T>({1, C + n - 1, H, W}, place);
std::memset(sdata, 0, sizeof(T) * squared.numel());
@ -67,6 +92,13 @@ struct LRNFunctor<platform::CPUDeviceContext, T> {
// Compute the final output from mdata and idata over mid->numel() elements.
// NOTE(review): presumably VPOW writes mdata[i]^(-beta) into odata and VMUL
// then multiplies odata element-wise by idata in place -- confirm against the
// BLAS wrapper, which is not defined in this file.
blas.VPOW(mid->numel(), mdata, -beta, odata);
blas.VMUL(mid->numel(), odata, idata, odata);
// If the data is channel-last (NHWC), permute the channel-first working
// tensors back into the caller's `mid` and `out` with axis order 0, 2, 3, 1.
if (data_layout == DataLayout::kNHWC) {
std::vector<int> axis = {0, 2, 3, 1};
transpose(dev_ctx, mid_transpose, mid, axis);
transpose(dev_ctx, out_transpose, out, axis);
// Explicit instantiation of the CPU LRNFunctor specialization for float.
template struct LRNFunctor<platform::CPUDeviceContext, float>;
@ -78,7 +110,7 @@ struct LRNGradFunctor<platform::CPUDeviceContext, T> {
// Tail of a parameter list followed by the first statements of the body.
// NOTE(review): the list is closed twice below (once without and once with
// `data_layout`); this looks like old and new diff lines interleaved --
// presumably only the `data_layout` variant is live. Confirm.
const framework::Tensor& x, const framework::Tensor& out,
const framework::Tensor& mid, framework::Tensor* x_g,
const framework::Tensor& out_g, int N, int C, int H, int W,
int n, T alpha, T beta) {
int n, T alpha, T beta, const DataLayout data_layout) {
// Constant factor applied to every accumulated term: -2 * alpha * beta.
T ratio = -2 * alpha * beta;
// Zero-fill the gradient output through a flattened Eigen view of *x_g.
auto x_g_e = framework::EigenVector<T>::Flatten(*x_g);
x_g_e = x_g_e.constant(0.0);
@ -93,17 +125,17 @@ struct LRNGradFunctor<platform::CPUDeviceContext, T> {
// Upper bound (exclusive) of the window loop `for (c = start; c < end; c++)`.
const int end = start + n;
// Outer loops: m runs over [0, N) and i over [0, C).
for (int m = 0; m < N; m++) {
for (int i = 0; i < C; i++) {
// NOTE(review): i_x, i_x_g, i_out_g and i_mid are each declared twice in this
// region -- once with hard-coded {m, i, 0, 0} / {1, 1, H, W} slices and once
// with offsets/extents. This looks like old and new diff lines interleaved;
// presumably only the offsets/extents versions are live. Confirm.
auto i_x = e_x.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
auto i_x_g = e_x_g.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
auto i_out_g = e_out_g.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
// Default slice: index i in dimension 1, full H x W extent in dims 2 and 3.
auto offsets = Eigen::array<int, 4>({{m, i, 0, 0}});
auto extents = Eigen::array<int, 4>({{1, 1, H, W}});
// For kNHWC the index i moves to the last dimension and the H x W extent
// moves to dimensions 1 and 2.
if (data_layout == DataLayout::kNHWC) {
offsets = Eigen::array<int, 4>({{m, 0, 0, i}});
extents = Eigen::array<int, 4>({{1, H, W, 1}});
auto i_mid = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
// Layout-aware slices of x, x_g, out_g and mid for (m, i).
auto i_x = e_x.slice(offsets, extents);
auto i_x_g = e_x_g.slice(offsets, extents);
auto i_out_g = e_out_g.slice(offsets, extents);
auto i_mid = e_mid.slice(offsets, extents);
// First term of the gradient slice: mid^(-beta) * out_g.
i_x_g = i_mid.pow(-beta) * i_out_g;
// Window loop over c in [start, end); its body adds further terms to i_x_g.
for (int c = start; c < end; c++) {
@ -112,14 +144,14 @@ struct LRNGradFunctor<platform::CPUDeviceContext, T> {
// NOTE(review): c_out, c_mid and c_out_g are each declared twice below -- once
// with hard-coded {m, ch, 0, 0} / {1, 1, H, W} slices and once with
// offsets/extents. This looks like old and new diff lines interleaved;
// presumably only the offsets/extents versions are live. Confirm.
// `ch` is defined outside the lines visible here.
auto c_out = e_out.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
auto c_mid = e_mid.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
auto c_out_g = e_out_g.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
Eigen::array<int, 4>({{1, 1, H, W}}));
// Re-point `offsets` at index ch: dimension 1 unless the layout is kNHWC, in
// which case the last dimension. `extents` is reused unchanged.
if (data_layout != DataLayout::kNHWC) {
offsets = Eigen::array<int, 4>({{m, ch, 0, 0}});
} else {
offsets = Eigen::array<int, 4>({{m, 0, 0, ch}});
// Layout-aware slices of out, mid and out_g at index ch.
auto c_out = e_out.slice(offsets, extents);
auto c_mid = e_mid.slice(offsets, extents);
auto c_out_g = e_out_g.slice(offsets, extents);
// Accumulate this index's contribution: ratio * out_g * out * x / mid.
i_x_g += ratio * c_out_g * c_out * i_x / c_mid;
@ -156,9 +188,8 @@ class LRNOp : public framework::OperatorWithKernel {
// Selects the kernel type for this operator. Starts from the plain library
// and then checks whether MKLDNN can be used for this context.
// NOTE(review): `layout_` is declared twice below (once from the
// "data_format" attribute, once as kAnyLayout); this looks like old and new
// diff lines interleaved -- presumably only the kAnyLayout line is live, which
// would also leave `data_format` unused in the lines visible here. Confirm.
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
// Only consider MKLDNN while the library choice is still the plain default.
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
@ -242,8 +273,8 @@ $$
Function implementation:
Inputs and outputs are in NCHW format, while input.shape.ndims() equals 4.
And dimensions 0 ~ 3 represent batch size, feature maps, rows,
Inputs and outputs are in NCHW or NHWC format, while input.shape.ndims() equals 4.
If NCHW, the dimensions 0 ~ 3 represent batch size, feature maps, rows,
and columns, respectively.
Input and Output in the formula above is for each map(i) of one image, and
@ -275,9 +306,8 @@ class LRNOpGrad : public framework::OperatorWithKernel {
// Selects the kernel type for the gradient operator. Starts from the plain
// library and then checks whether MKLDNN can be used for this context.
// NOTE(review): `layout_` is declared twice below (once from the
// "data_format" attribute, once as kAnyLayout); this looks like old and new
// diff lines interleaved -- presumably only the kAnyLayout line is live, which
// would also leave `data_format` unused in the lines visible here. Confirm.
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
// Only consider MKLDNN while the library choice is still the plain default.
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {