parent 4273b3513a
commit 9d142d5060
@@ -0,0 +1,141 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/lrn_op.h"

namespace paddle {
namespace operators {

using framework::Tensor;

class LRNOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LRNOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of LRNOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("MidOut"),
                   "Output(MidOut) of LRNOp should not be null.");

    auto x_dim = ctx->GetInputDim("X");
    PADDLE_ENFORCE_EQ(x_dim.size(), 4, "Input(X)'s rank of LRNOp should be 4.");

    ctx->SetOutputDim("Out", x_dim);
    ctx->SetOutputDim("MidOut", x_dim);
    ctx->ShareLoD("X", /*->*/ "Out");
  }
};

template <typename T>
class LRNOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  LRNOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", R"DOC(
(Tensor) The input of LRN operator. It must be a 4D tensor with NCHW format.
)DOC");

    AddOutput("Out",
              "(Tensor) The output of LRN operator, which is also a 4D "
              "tensor with NCHW format.");
    AddOutput("MidOut", R"DOC(
(Tensor) Middle result of the LRN operator. It is computed in the forward
process and is also used in the backward process.
)DOC");

    AddAttr<int>("n", R"DOC(
(int, default 5) n is the number of "adjacent" kernel maps at the same spatial
position.
)DOC")
        .SetDefault(5)
        .GreaterThan(0);

    AddAttr<T>("k", R"DOC(
(float, default 2.0) k is the bias.
)DOC")
        .SetDefault(2.0)
        .GreaterThan(0.0);

    AddAttr<T>("alpha", R"DOC(
(float, default 0.0001) alpha is the scale factor.
)DOC")
        .SetDefault(0.0001)
        .GreaterThan(0.0);

    AddAttr<T>("beta", R"DOC(
(float, default 0.75) beta is the exponent.
)DOC")
        .SetDefault(0.75)
        .GreaterThan(0.0);

    AddComment(R"DOC(
Local Response Normalization.

This operator comes from the paper
"ImageNet Classification with Deep Convolutional Neural Networks".

The original formula is:

                           Input(i, x, y)
Output(i, x, y) = ----------------------------------------------
                               -- upper
                  (k + alpha * >     (Input(j, x, y))^2) ^ (beta)
                               -- j = lower

upper is `min(C, c + n/2)`
lower is `max(0, c - n/2)`

Function implementation:

Inputs and outputs are in NCHW format, and input.shape.ndims() equals 4.
The meaning of dimensions 0-3 is, respectively: batch size, feature maps,
rows, and columns.

Input and Output in the above formula are for each map (i) of one image, and
Input(i, x, y), Output(i, x, y) represent an element of that image.

C is the number of feature maps of one image, and n is a hyper-parameter
configured when the operator is initialized. The sum in the denominator
is taken over the same position in the neighboring maps.
)DOC");
  }
};

class LRNOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("MidOut")),
                   "Input(MidOut@GRAD) should not be null");
    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                   "Input(Out@GRAD) should not be null");

    auto x_dims = ctx->GetInputDim("X");
    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(lrn, ops::LRNOp, ops::LRNOpMaker<float>, lrn_grad, ops::LRNOpGrad);
REGISTER_OP_CPU_KERNEL(lrn, ops::LRNKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(lrn_grad,
                       ops::LRNGradKernel<paddle::platform::CPUPlace, float>);
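As a quick aid to reading the operator comment above, the forward formula can be written out in a few lines of NumPy. This is only an illustrative sketch and not part of the commit: the helper name lrn_forward_reference is hypothetical, it assumes NCHW input, and its window handling follows the documented bounds (lower = max(0, c - n/2), upper = min(C, c + n/2)) rather than the kernel's exact loop limits.

import numpy as np

def lrn_forward_reference(x, n=5, k=2.0, alpha=1e-4, beta=0.75):
    # Sketch of: Out(i, x, y) = In(i, x, y) / (k + alpha * sum_j In(j, x, y)^2)^beta
    # x is an NCHW array; returns (out, mid), analogous to the Out/MidOut pair above.
    N, C, H, W = x.shape
    mid = np.full(x.shape, k, dtype=x.dtype)
    half = n // 2
    for c in range(C):
        lower = max(0, c - half)
        upper = min(C, c + half + 1)  # exclusive bound for slicing
        mid[:, c] += alpha * np.square(x[:, lower:upper]).sum(axis=1)
    out = x * np.power(mid, -beta)
    return out, mid

Returning mid alongside out mirrors the MidOut output of the operator, which the backward pass reuses instead of recomputing the denominator.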
@@ -0,0 +1,22 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include "paddle/operators/lrn_op.h"

namespace ops = paddle::operators;

REGISTER_OP_GPU_KERNEL(lrn, ops::LRNKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(lrn_grad,
                       ops::LRNGradKernel<paddle::platform::GPUPlace, float>);
@@ -0,0 +1,185 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"

namespace paddle {
namespace operators {

template <typename Place, typename T>
class LRNKernel : public framework::OpKernel<T> {
 public:
  using Tensor = framework::Tensor;

  // f(x) = x * ( k + alpha * SUM((x)^2) )^(-beta)
  // where x represents the inputs and f(x) represents the outputs.
  void Compute(const framework::ExecutionContext& ctx) const override {
    // input
    const Tensor* x = ctx.Input<Tensor>("X");
    auto x_dims = x->dims();

    // NCHW
    int N = x_dims[0];
    int C = x_dims[1];
    int H = x_dims[2];
    int W = x_dims[3];

    Tensor* out = ctx.Output<Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());

    // MidOut saves the intermediate result for the backward pass.
    Tensor* mid = ctx.Output<Tensor>("MidOut");
    mid->mutable_data<T>(ctx.GetPlace());

    int n = ctx.Attr<int>("n");
    T alpha = ctx.Attr<float>("alpha");
    T beta = ctx.Attr<float>("beta");
    T k = ctx.Attr<float>("k");

    PADDLE_ENFORCE(n > 0, "n should be greater than 0");
    PADDLE_ENFORCE(alpha >= 0.0, "alpha should be >= 0.0");
    PADDLE_ENFORCE(beta >= 0.0, "beta should be >= 0.0");
    PADDLE_ENFORCE(k >= 0.0, "k should be >= 0.0");

    auto x_v = framework::EigenVector<T>::Flatten(*x);

    const int start = -(n - 1) / 2;
    const int end = start + n;

    auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
    e_mid.device(ctx.GetEigenDevice<Place>()) = e_mid.constant(k);

    auto e_x = framework::EigenTensor<T, 4>::From(*x);
    for (int m = 0; m < N; m++) {
      for (int i = 0; i < C; i++) {
        for (int c = start; c <= end; c++) {
          int ch = i + c;
          if (ch >= 0 && ch < C) {
            auto s = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
                                 Eigen::array<int, 4>({{1, 1, H, W}}));

            auto r = e_x.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
                               Eigen::array<int, 4>({{1, 1, H, W}}));

            s.device(ctx.GetEigenDevice<Place>()) += alpha * r.square();
          }
        }
      }
    }

    auto out_e = framework::EigenVector<T>::Flatten(*out);
    out_e.device(ctx.GetEigenDevice<Place>()) =
        x_v * e_mid.reshape(Eigen::DSizes<int, 1>(e_mid.size())).pow(-beta);
  }
};

/**
 * \brief Backward calculation for normalization across maps.
 *
 * Function implementation:
 *
 * The implementation of this function is derived from the
 * CrossMapNormalFunc implementation.
 *
 * InputGrad = OutputGrad * MidOut ^ (-beta)
 *               -- upper
 *             + >     (OutputGrad * OutputValue * (-2 * alpha * beta) / MidOut) * InputValue
 *               -- lower
 *
 * The data format of the inputs/outputs is the same as in the forward
 * interface and is NCHW.
 *
 * The upper and lower bounds are the same as in the forward pass, and the
 * logic of the sum is also the same as in the forward pass.
 */
template <typename Place, typename T>
class LRNGradKernel : public framework::OpKernel<T> {
 public:
  using Tensor = framework::Tensor;
  void Compute(const framework::ExecutionContext& ctx) const override {
    const Tensor* x = ctx.Input<Tensor>("X");
    const Tensor* out = ctx.Input<Tensor>("Out");
    const Tensor* out_g = ctx.Input<Tensor>(framework::GradVarName("Out"));
    const Tensor* mid = ctx.Input<Tensor>("MidOut");

    auto x_g = ctx.Output<Tensor>(framework::GradVarName("X"));
    x_g->mutable_data<T>(ctx.GetPlace());

    auto x_g_e = framework::EigenVector<T>::Flatten(*x_g);
    x_g_e.device(ctx.GetEigenDevice<Place>()) = x_g_e.constant(0.0);

    auto x_dims = x->dims();
    int N = x_dims[0];
    int C = x_dims[1];
    int H = x_dims[2];
    int W = x_dims[3];

    int n = ctx.Attr<int>("n");
    T alpha = ctx.Attr<T>("alpha");
    T beta = ctx.Attr<T>("beta");
    T ratio = -2 * alpha * beta;

    auto e_x = framework::EigenTensor<T, 4>::From(*x);
    auto e_x_g = framework::EigenTensor<T, 4>::From(*x_g);
    auto e_out = framework::EigenTensor<T, 4>::From(*out);
    auto e_out_g = framework::EigenTensor<T, 4>::From(*out_g);
    auto e_mid = framework::EigenTensor<T, 4>::From(*mid);

    const int start = -(n - 1) / 2;
    const int end = start + n;
    for (int m = 0; m < N; m++) {
      for (int i = 0; i < C; i++) {
        auto i_x = e_x.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
                             Eigen::array<int, 4>({{1, 1, H, W}}));

        auto i_x_g = e_x_g.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
                                 Eigen::array<int, 4>({{1, 1, H, W}}));

        auto i_out_g = e_out_g.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
                                     Eigen::array<int, 4>({{1, 1, H, W}}));

        auto i_mid = e_mid.slice(Eigen::array<int, 4>({{m, i, 0, 0}}),
                                 Eigen::array<int, 4>({{1, 1, H, W}}));

        i_x_g.device(ctx.GetEigenDevice<Place>()) = i_mid.pow(-beta) * i_out_g;
        for (int c = start; c <= end; c++) {
          int ch = i + c;
          if (ch < 0 || ch >= C) {
            continue;
          }

          auto c_out = e_out.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
                                   Eigen::array<int, 4>({{1, 1, H, W}}));

          auto c_mid = e_mid.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
                                   Eigen::array<int, 4>({{1, 1, H, W}}));

          auto c_out_g = e_out_g.slice(Eigen::array<int, 4>({{m, ch, 0, 0}}),
                                       Eigen::array<int, 4>({{1, 1, H, W}}));

          i_x_g.device(ctx.GetEigenDevice<Place>()) +=
              ratio * c_out_g * c_out * i_x / c_mid;
        }
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle
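The gradient spelled out in the comment block above can be sketched the same way in NumPy. Again this is an illustrative sketch rather than the kernel itself: lrn_backward_reference is a hypothetical helper, it uses the symmetric window of the documented formula, and it relies on that window being symmetric (channel i lies in channel c's neighborhood exactly when c lies in i's).

import numpy as np

def lrn_backward_reference(x, out, mid, out_grad, n=5, alpha=1e-4, beta=0.75):
    # Sketch of: InputGrad = OutputGrad * MidOut^(-beta)
    #            + sum over the window of
    #              (OutputGrad * OutputValue * (-2 * alpha * beta) / MidOut) * InputValue
    # All arrays are NCHW; mid is the MidOut tensor saved by the forward pass.
    N, C, H, W = x.shape
    x_grad = np.power(mid, -beta) * out_grad
    ratio = -2.0 * alpha * beta
    half = n // 2
    for c in range(C):
        lower = max(0, c - half)
        upper = min(C, c + half + 1)  # exclusive bound for slicing
        # x[:, c] contributed to mid[:, lower:upper], so it collects one term
        # from every output channel in that window.
        x_grad[:, c] += ratio * x[:, c] * (
            out_grad[:, lower:upper] * out[:, lower:upper] /
            mid[:, lower:upper]).sum(axis=1)
    return x_grad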
@@ -0,0 +1,77 @@
import unittest
import numpy as np
from op_test import OpTest


class TestLRNOp(OpTest):
    def get_input(self):
        ''' TODO(gongweibao): why is its grad diff so large?
        x = np.ndarray(
            shape=(self.N, self.C, self.H, self.W), dtype=float, order='C')
        for m in range(0, self.N):
            for i in range(0, self.C):
                for h in range(0, self.H):
                    for w in range(0, self.W):
                        x[m][i][h][w] = m * self.C * self.H * self.W + \
                                        i * self.H * self.W + \
                                        h * self.W + w + 1
        '''
        x = np.random.rand(self.N, self.C, self.H, self.W).astype("float32")
        return x + 1

    def get_out(self):
        start = -(self.n - 1) / 2
        end = start + self.n

        mid = np.empty((self.N, self.C, self.H, self.W), dtype=float)
        mid.fill(self.k)
        for m in range(0, self.N):
            for i in range(0, self.C):
                for c in range(start, end + 1):
                    ch = i + c
                    if ch < 0 or ch >= self.C:
                        continue

                    s = mid[m][i][:][:]
                    r = self.x[m][ch][:][:]
                    s += np.square(r) * self.alpha

        mid2 = np.power(mid, -self.beta)
        return np.multiply(self.x, mid2), mid

    def get_attrs(self):
        attrs = {
            'n': self.n,
            'k': self.k,
            'alpha': self.alpha,
            'beta': self.beta
        }
        return attrs

    def setUp(self):
        self.op_type = "lrn"
        self.N = 2
        self.C = 3
        self.H = 5
        self.W = 5

        self.n = 5
        self.k = 2.0
        self.alpha = 0.0001
        self.beta = 0.75
        self.x = self.get_input()
        self.out, self.mid_out = self.get_out()

        self.inputs = {'X': self.x}
        self.outputs = {'Out': self.out, 'MidOut': self.mid_out}
        self.attrs = self.get_attrs()

    def test_check_output(self):
        self.check_output()

    def test_check_grad_normal(self):
        self.check_grad(['X'], 'Out', max_relative_error=0.01)


if __name__ == "__main__":
    unittest.main()