parent
a87f4963ed
commit
b1025cf50a
@ -0,0 +1,106 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#include "paddle/operators/norm_op.h"
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
// Declares the proto of the "norm" operator: inputs X and Scale, the float
// attribute "epsilon", the output Out, and the user-facing documentation.
class NormOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  NormOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput(
        "X",
        "(Tensor) The input tensor of norm operator. "
        "The format of input tensor is NCHW. Where N is batch size, C is the "
        "number of channels, H and W is the height and width of feature.");
    AddInput("Scale",
             "(Tensor) The scale tensor of norm operator. "
             "The format of scale tensor is C * 1.");
    // The attribute type is float regardless of the kernel's element type.
    AddAttr<float>("epsilon",
                   "(float, default 1e-10) Constant "
                   "for numerical stability.")
        .SetDefault(1.0e-10f);
    // NormOp::InferShape gives Out exactly the dims of X.
    AddOutput("Out",
              "(Tensor) The output tensor of norm operator. "
              "It has the same NCHW format and shape as Input(X).");
    AddComment(R"DOC(
Input shape: $(N, C, H, W)$
Scale shape: $(C, 1)$
Output shape: $(N, C, H, W)$
Where
forward
$$
[\frac {x_{1}}{\sqrt{\sum{x_{i}^{2}}}} \frac {x_{2}}{\sqrt{\sum{x_{i}^{2}}}} \frac {x_{3}}{\sqrt{\sum{x_{i}^{2}}}} \cdot \cdot \cdot \frac {x_{n}}{\sqrt{\sum{x_{i}^{2}}}}]
$$
backward
$$
\frac{\frac{\mathrm{d}L }{\mathrm{d}y_{1}} - \frac {x_{1}\sum {\frac{\mathrm{d} L}{\mathrm{d} y_{j}}}x_{j}}{\sum x_{j}^{2}} }{\sqrt{\sum{x_{j}^{2}}}}
$$
)DOC");
  }
};
|
||||
|
||||
class NormOp : public framework::OperatorWithKernel {
|
||||
protected:
|
||||
framework::OpKernelType GetKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(
|
||||
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
|
||||
ctx.device_context());
|
||||
}
|
||||
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"),
|
||||
"Input(X) of NormOp"
|
||||
"should not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
||||
"Output(Out) of NormOp should not be null.");
|
||||
auto in_x_dims = ctx->GetInputDim("X");
|
||||
ctx->SetOutputDim("Out", in_x_dims);
|
||||
}
|
||||
};
|
||||
|
||||
class NormOpGrad : public framework::OperatorWithKernel {
|
||||
protected:
|
||||
framework::OpKernelType GetKernelType(
|
||||
const framework::ExecutionContext& ctx) const override {
|
||||
return framework::OpKernelType(
|
||||
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
|
||||
ctx.device_context());
|
||||
}
|
||||
|
||||
public:
|
||||
using framework::OperatorWithKernel::OperatorWithKernel;
|
||||
void InferShape(framework::InferShapeContext* ctx) const override {
|
||||
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
|
||||
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
|
||||
"Input(X@GRAD) should not be null.");
|
||||
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
|
||||
}
|
||||
};
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP(norm, ops::NormOp, ops::NormOpMaker, norm_grad, ops::NormOpGrad);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
norm, ops::NormKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::NormKernel<paddle::platform::CPUDeviceContext, double>);
|
||||
REGISTER_OP_CPU_KERNEL(
|
||||
norm_grad, ops::NormGradKernel<paddle::platform::CPUDeviceContext, float>,
|
||||
ops::NormGradKernel<paddle::platform::CPUDeviceContext, double>);
|
@ -0,0 +1,24 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
#define EIGEN_USE_GPU
|
||||
|
||||
#include "paddle/operators/norm_op.h"
|
||||
|
||||
namespace ops = paddle::operators;
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
norm, ops::NormKernel<paddle::platform::CUDADeviceContext, float>,
|
||||
ops::NormKernel<paddle::platform::CUDADeviceContext, double>);
|
||||
REGISTER_OP_CUDA_KERNEL(
|
||||
norm_grad, ops::NormGradKernel<paddle::platform::CUDADeviceContext, float>,
|
||||
ops::NormGradKernel<paddle::platform::CUDADeviceContext, double>);
|
@ -0,0 +1,162 @@
|
||||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License. */
|
||||
|
||||
#pragma once
|
||||
#include "paddle/framework/op_registry.h"
|
||||
#include "paddle/operators/math/math_function.h"
|
||||
|
||||
namespace paddle {
|
||||
namespace operators {
|
||||
|
||||
template <typename T, int MajorType = Eigen::RowMajor,
|
||||
typename IndexType = Eigen::DenseIndex>
|
||||
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
|
||||
template <typename T, int MajorType = Eigen::RowMajor,
|
||||
typename IndexType = Eigen::DenseIndex>
|
||||
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
|
||||
|
||||
template <typename DeviceContext, typename T>
|
||||
class NormKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
|
||||
const framework::Tensor* scale = context.Input<framework::Tensor>("Scale");
|
||||
auto* out = context.Output<framework::Tensor>("Out");
|
||||
T epsilon = context.Attr<T>("epsilon");
|
||||
out->mutable_data<T>(context.GetPlace());
|
||||
int batch_size = in_x->dims()[0];
|
||||
int channels = in_x->dims()[1];
|
||||
int height = in_x->dims()[2];
|
||||
int width = in_x->dims()[3];
|
||||
int fea_len = height * width;
|
||||
auto* place =
|
||||
context.template device_context<DeviceContext>().eigen_device();
|
||||
auto x = EigenMatrix<T>::From(
|
||||
*in_x, framework::make_ddim({batch_size, fea_len * channels}));
|
||||
// get square
|
||||
framework::Tensor x_square;
|
||||
x_square.mutable_data<T>(in_x->dims(), context.GetPlace());
|
||||
auto x_square_eigen = EigenMatrix<T>::From(
|
||||
x_square, framework::make_ddim({batch_size, fea_len * channels}));
|
||||
x_square_eigen.device(*place) = x.square();
|
||||
auto scale_eigen = EigenVector<T>::Flatten(*scale);
|
||||
for (int n = 0; n < batch_size; ++n) {
|
||||
framework::Tensor in_x_batch = in_x->Slice(n, n + 1);
|
||||
auto in_x_batch_eigen = EigenMatrix<T>::From(
|
||||
in_x_batch, framework::make_ddim({channels, fea_len}));
|
||||
framework::Tensor x_square_batch = x_square.Slice(n, n + 1);
|
||||
auto x_square_batch_eigen = EigenMatrix<T>::From(
|
||||
x_square_batch, framework::make_ddim({channels, fea_len}));
|
||||
framework::Tensor out_batch = out->Slice(n, n + 1);
|
||||
auto out_batch_eigen = EigenMatrix<T>::From(
|
||||
out_batch, framework::make_ddim({channels, fea_len}));
|
||||
framework::Tensor tmp_tensor;
|
||||
tmp_tensor.mutable_data<T>(framework::make_ddim({1, fea_len}),
|
||||
context.GetPlace());
|
||||
auto tmp = EigenVector<T>::Flatten(tmp_tensor);
|
||||
// get colsum and sqrt , inverse
|
||||
auto dim = Eigen::array<int, 1>({{0}});
|
||||
tmp.device(*place) = x_square_batch_eigen.sum(dim);
|
||||
tmp.device(*place) = (tmp + epsilon).sqrt().inverse();
|
||||
Eigen::array<int, 2> broadcast_dim_col;
|
||||
broadcast_dim_col[1] = 1;
|
||||
broadcast_dim_col[0] = channels;
|
||||
out_batch_eigen.device(*place) =
|
||||
in_x_batch_eigen * (tmp.broadcast(broadcast_dim_col));
|
||||
Eigen::array<int, 2> broadcast_dim_row;
|
||||
broadcast_dim_row[1] = fea_len;
|
||||
broadcast_dim_row[0] = 1;
|
||||
out_batch_eigen.device(*place) =
|
||||
out_batch_eigen * (scale_eigen.broadcast(broadcast_dim_row));
|
||||
}
|
||||
}
|
||||
};
|
||||
template <typename DeviceContext, typename T>
|
||||
class NormGradKernel : public framework::OpKernel<T> {
|
||||
public:
|
||||
void Compute(const framework::ExecutionContext& context) const override {
|
||||
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
|
||||
const framework::Tensor* scale = context.Input<framework::Tensor>("Scale");
|
||||
const framework::Tensor* out_grad =
|
||||
context.Input<framework::Tensor>(framework::GradVarName("Out"));
|
||||
T epsilon = context.Attr<T>("epsilon");
|
||||
framework::Tensor* in_x_grad =
|
||||
context.Output<framework::Tensor>(framework::GradVarName("X"));
|
||||
in_x_grad->mutable_data<T>(context.GetPlace());
|
||||
int batch_size = in_x->dims()[0];
|
||||
int channels = in_x->dims()[1];
|
||||
int height = in_x->dims()[2];
|
||||
int width = in_x->dims()[3];
|
||||
int fea_len = height * width;
|
||||
auto* place =
|
||||
context.template device_context<DeviceContext>().eigen_device();
|
||||
|
||||
auto scale_eigen = EigenVector<T>::Flatten(*scale);
|
||||
auto x = EigenMatrix<T>::From(
|
||||
*in_x, framework::make_ddim({batch_size, fea_len * channels}));
|
||||
// get square
|
||||
framework::Tensor x_square;
|
||||
x_square.mutable_data<T>(in_x->dims(), context.GetPlace());
|
||||
auto x_square_eigen = EigenMatrix<T>::From(
|
||||
x_square, framework::make_ddim({batch_size, fea_len * channels}));
|
||||
x_square_eigen.device(*place) = x.square();
|
||||
|
||||
for (int n = 0; n < batch_size; ++n) {
|
||||
framework::Tensor in_x_batch = in_x->Slice(n, n + 1);
|
||||
auto in_x_batch_eigen = EigenMatrix<T>::From(
|
||||
in_x_batch, framework::make_ddim({channels, fea_len}));
|
||||
framework::Tensor in_g_batch = in_x_grad->Slice(n, n + 1);
|
||||
auto in_g_batch_eigen = EigenMatrix<T>::From(
|
||||
in_g_batch, framework::make_ddim({channels, fea_len}));
|
||||
framework::Tensor x_square_batch = x_square.Slice(n, n + 1);
|
||||
auto x_square_batch_eigen = EigenMatrix<T>::From(
|
||||
x_square_batch, framework::make_ddim({channels, fea_len}));
|
||||
framework::Tensor outg_batch = out_grad->Slice(n, n + 1);
|
||||
auto outg_batch_eigen = EigenMatrix<T>::From(
|
||||
outg_batch, framework::make_ddim({channels, fea_len}));
|
||||
|
||||
framework::Tensor tmp_tensor;
|
||||
tmp_tensor.mutable_data<T>(framework::make_ddim({1, fea_len}),
|
||||
context.GetPlace());
|
||||
auto tmp_eigen = EigenVector<T>::Flatten(tmp_tensor);
|
||||
auto dim = Eigen::array<int, 1>({{0}});
|
||||
tmp_eigen.device(*place) = (in_x_batch_eigen * outg_batch_eigen).sum(dim);
|
||||
framework::Tensor norm_tmp_tensor;
|
||||
norm_tmp_tensor.mutable_data<T>(framework::make_ddim({1, fea_len}),
|
||||
context.GetPlace());
|
||||
auto norm_tmp_eigen = EigenVector<T>::Flatten(norm_tmp_tensor);
|
||||
norm_tmp_eigen.device(*place) =
|
||||
(x_square_batch_eigen.sum(dim) + epsilon).sqrt();
|
||||
Eigen::array<int, 2> broadcast_dim_col;
|
||||
broadcast_dim_col[1] = 1;
|
||||
broadcast_dim_col[0] = channels;
|
||||
in_g_batch_eigen.device(*place) =
|
||||
in_x_batch_eigen * tmp_eigen.broadcast(broadcast_dim_col);
|
||||
in_g_batch_eigen.device(*place) =
|
||||
in_g_batch_eigen /
|
||||
(norm_tmp_eigen * norm_tmp_eigen).broadcast(broadcast_dim_col);
|
||||
in_g_batch_eigen.device(*place) = outg_batch_eigen - in_g_batch_eigen;
|
||||
// outg_batch_eigen + (in_g_batch_eigen * -1);
|
||||
in_g_batch_eigen.device(*place) =
|
||||
in_g_batch_eigen / norm_tmp_eigen.broadcast(broadcast_dim_col);
|
||||
Eigen::array<int, 2> broadcast_dim_row;
|
||||
broadcast_dim_row[1] = fea_len;
|
||||
broadcast_dim_row[0] = 1;
|
||||
in_g_batch_eigen.device(*place) =
|
||||
in_g_batch_eigen * (scale_eigen.broadcast(broadcast_dim_row));
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace operators
|
||||
} // namespace paddle
|
@ -0,0 +1,57 @@
|
||||
import unittest
|
||||
import numpy as np
|
||||
from op_test import OpTest
|
||||
|
||||
|
||||
def norm(input, scale, epsilon):
    """NumPy reference for the norm operator.

    Every position (n, h, w) of the 4-D ``input`` is divided by the L2 norm
    taken across the channel axis, ``sqrt(sum_c(x^2) + epsilon)``, and channel
    ``c`` of the result is multiplied by ``scale[c]``.

    Args:
        input: array of shape (N, C, H, W).
        scale: array holding C values, e.g. shape (C,) or (C, 1).
        epsilon: constant added to the squared sum for numerical stability.

    Returns:
        Array of shape (N, C, H, W).
    """
    # Unpacking keeps the original requirement that the input is 4-D.
    _, channels, _, _ = input.shape
    # One inverse norm per (n, h, w), kept broadcastable against the input.
    square_sum = (input * input).sum(axis=1, keepdims=True) + epsilon
    inv_norm = np.reciprocal(np.sqrt(square_sum))
    # Broadcast along the channel axis so channel c is multiplied by scale[c];
    # tiling a 1-D scale and reshaping would interleave the entries instead.
    per_channel_scale = np.asarray(scale).reshape(1, channels, 1, 1)
    return input * inv_norm * per_channel_scale
|
||||
|
||||
|
||||
class TestNormOp(OpTest):
    """Checks the "norm" operator's output and gradient against `norm`."""

    def init_test_case(self):
        # Input shape and the epsilon attribute passed to the operator.
        self.shape = [1, 3, 2, 2]
        self.epsilon = 1e-6

    def setUp(self):
        self.op_type = "norm"
        self.init_test_case()
        x = np.random.random(self.shape).astype("float32")
        scale = np.array([10, 10, 10])
        self.inputs = {
            'X': x.astype('float32'),
            'Scale': scale.astype('float32'),
        }
        self.attrs = {'epsilon': self.epsilon}
        # The reference result is computed from the same random input.
        expected = norm(x, scale, self.epsilon)
        self.outputs = {'Out': expected.astype('float32')}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')
|
||||
|
||||
|
||||
# Run the unit tests when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|
Loading…
Reference in new issue