@ -14,6 +14,7 @@ limitations under the License. */
# include "paddle/fluid/operators/layer_norm_op.h"
# include <memory>
# include <string>
# ifdef PADDLE_WITH_MKLDNN
# include "paddle/fluid/platform/mkldnn_helper.h"
@ -98,7 +99,26 @@ class LayerNormOp : public framework::OperatorWithKernel {
protected :
framework : : OpKernelType GetExpectedKernelType (
const framework : : ExecutionContext & ctx ) const {
const framework : : ExecutionContext & ctx ) const override {
auto input_data_type = OperatorWithKernel : : IndicateVarDataType ( ctx , " X " ) ;
// By default, the type of the scale, bias, mean,
// and var tensors should both be float. (For float or float16 input tensor)
// or double (For double input tensor).
auto ln_param_type = framework : : proto : : VarType : : FP32 ;
if ( input_data_type = = framework : : proto : : VarType : : FP64 ) {
ln_param_type = framework : : proto : : VarType : : FP64 ;
}
if ( ctx . HasInput ( " Scale " ) ) {
PADDLE_ENFORCE_EQ ( ln_param_type , ctx . Input < Tensor > ( " Scale " ) - > type ( ) ,
platform : : errors : : InvalidArgument (
" Scale input should be of float type " ) ) ;
}
if ( ctx . HasInput ( " Bias " ) ) {
PADDLE_ENFORCE_EQ ( ln_param_type , ctx . Input < Tensor > ( " Bias " ) - > type ( ) ,
platform : : errors : : InvalidArgument (
" Bias input should be of float type " ) ) ;
}
framework : : LibraryType library = framework : : LibraryType : : kPlain ;
framework : : DataLayout layout = framework : : DataLayout : : kAnyLayout ;
@ -110,9 +130,8 @@ class LayerNormOp : public framework::OperatorWithKernel {
}
# endif
return framework : : OpKernelType (
OperatorWithKernel : : IndicateVarDataType ( ctx , " X " ) , ctx . GetPlace ( ) ,
layout , library ) ;
return framework : : OpKernelType ( input_data_type , ctx . GetPlace ( ) , layout ,
library ) ;
}
} ;
@ -224,7 +243,13 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
}
PADDLE_ENFORCE_NOT_NULL (
t , platform : : errors : : NotFound ( " Y@GRAD of LayerNorm Op is not found. " ) ) ;
return framework : : OpKernelType ( t - > type ( ) , ctx . GetPlace ( ) ) ;
framework : : LibraryType library = framework : : LibraryType : : kPlain ;
framework : : DataLayout layout = framework : : DataLayout : : kAnyLayout ;
return framework : : OpKernelType (
OperatorWithKernel : : IndicateVarDataType ( ctx , " X " ) , ctx . GetPlace ( ) ,
layout , library ) ;
}
} ;