@@ -16,40 +16,34 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename AttrType>
class NormOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput(
        "X",
        "(Tensor) The input tensor of norm operator. "
        "The format of the input tensor is NCHW, where N is the batch size, "
        "C is the number of channels, and H and W are the height and width "
        "of the feature.");
    AddInput("Scale",
             "(Tensor) The input tensor of norm operator. "
             "The format of the input tensor is C * 1.");
    AddAttr<AttrType>("epsilon",
                      "(float, default 1e-10) Constant "
                      "for numerical stability.")
AddInput ( " X " , " (Tensor) A tensor of rank >= axis. " ) ;
AddAttr < int > ( " axis " ,
" The axis on which to apply normalization. If axis < 0, "
" the dimension to normalization is rank(X) + axis. -1 is "
" the last dimension. " ) ;
AddAttr < float > ( " epsilon " ,
" (float, default 1e-10) The epsilon value is used "
" to avoid division by zero. " )
. SetDefault ( 1.0e-10 f ) ;
AddOutput ( " Out " ,
" (Tensor) The output tensor of norm operator. "
" N * M. "
" M = C * H * W " ) ;
AddOutput ( " Norm " ,
" (Tensor) A tensor saved the `sqrt(sum(x) + epsion)` will "
" be used in backward kernel. " )
. AsIntermediate ( ) ;
AddOutput ( " Out " , " (Tensor) A tensor of the same shape as X. " ) ;
    AddComment(R"DOC(
Input shape: $(N, C, H, W)$
Scale shape: $(C, 1)$
Output shape: $(N, C, H, W)$

Forward:
$$
\left[ \frac{x_1}{\sqrt{\sum{x_i^2}}}, \frac{x_2}{\sqrt{\sum{x_i^2}}}, \frac{x_3}{\sqrt{\sum{x_i^2}}}, \cdots, \frac{x_n}{\sqrt{\sum{x_i^2}}} \right]
$$

Backward:
$$
\frac{\frac{\mathrm{d}L}{\mathrm{d}y_1} - \frac{x_1 \sum{\frac{\mathrm{d}L}{\mathrm{d}y_j} x_j}}{\sum{x_j^2}}}{\sqrt{\sum{x_j^2}}}
$$
)DOC");
Given a tensor, apply L2 normalization along the provided axis.

$$
y = \frac{x}{\sqrt{\sum{x^2} + \epsilon}}
$$

where $\sum{x^2}$ is calculated along the `axis` dimension.
)DOC");
  }
};
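A minimal standalone sketch (not part of this diff) of what the new forward formula above computes for a single row, assuming `axis = -1` and the default `epsilon`; names such as `sum_sq` are illustrative only:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const float epsilon = 1.0e-10f;       // default value of the "epsilon" attribute
  std::vector<float> x = {3.0f, 4.0f};  // one row of X, normalized along axis = -1

  // norm = sqrt(sum(x^2) + epsilon); this is what the "Norm" output stores.
  float sum_sq = 0.0f;
  for (float v : x) sum_sq += v * v;
  const float norm = std::sqrt(sum_sq + epsilon);

  // y = x / norm, elementwise along the normalized axis.
  for (float v : x) std::printf("%f ", v / norm);  // prints ~0.6 and ~0.8
  std::printf("\n");
  return 0;
}
```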
@@ -58,15 +52,15 @@ class NormOp : public framework::OperatorWithKernel {
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of NormOp "
                   "should not be null.");
    PADDLE_ENFORCE(ctx->HasInput("Scale"),
                   "Input(Scale) of NormOp "
                   "should not be null.");
                   "Input(X) of NormOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of NormOp should not be null.");
    auto in_x_dims = ctx->GetInputDim("X");
    ctx->SetOutputDim("Out", in_x_dims);
    auto xdim = ctx->GetInputDim("X");
    ctx->SetOutputDim("Out", xdim);
    int axis = ctx->Attrs().Get<int>("axis");
    if (axis < 0) axis = xdim.size() + axis;
    xdim[axis] = 1;
    ctx->SetOutputDim("Norm", xdim);
  }
};
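For context, a small sketch of the shape rule the new `InferShape` encodes: `Out` keeps X's shape, while `Norm` keeps X's rank but collapses the normalized axis to 1. The helper `NormShape` below is made up for illustration and uses `std::vector<int64_t>` instead of `framework::DDim`:

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Shape of the "Norm" output for a given X shape and axis (illustrative only).
std::vector<int64_t> NormShape(std::vector<int64_t> xdim, int axis) {
  if (axis < 0) axis += static_cast<int>(xdim.size());  // e.g. -1 -> last dim
  assert(axis >= 0 && axis < static_cast<int>(xdim.size()));
  xdim[axis] = 1;  // the normalized axis is reduced to size 1
  return xdim;
}

int main() {
  // X: [2, 3, 4, 5], axis = -1  ->  Out: [2, 3, 4, 5], Norm: [2, 3, 4, 1].
  for (int64_t d : NormShape({2, 3, 4, 5}, -1)) {
    std::printf("%lld ", static_cast<long long>(d));
  }
  std::printf("\n");
  return 0;
}
```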
@@ -84,12 +78,12 @@ class NormOpGrad : public framework::OperatorWithKernel {
}  // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(norm, ops::NormOp, ops::NormOpMaker<float>,
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(norm, ops::NormOp, ops::NormOpMaker,
                  paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(norm_grad, ops::NormOpGrad);
REGISTER_OP_CPU_KERNEL(
    norm, ops::NormKernel<paddle::platform::CPUDeviceContext, float>,
    ops::NormKernel<paddle::platform::CPUDeviceContext, double, float>);
REGISTER_OP_CPU_KERNEL(
    norm_grad, ops::NormGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::NormGradKernel<paddle::platform::CPUDeviceContext, double, float>);
REGISTER_OP_CPU_KERNEL(norm, ops::NormKernel<CPU, float>,
                       ops::NormKernel<CPU, double>);
REGISTER_OP_CPU_KERNEL(norm_grad, ops::NormGradKernel<CPU, float>,
                       ops::NormGradKernel<CPU, double>);