@ -81,7 +81,7 @@ class DataNormOp : public framework::OperatorWithKernel {
protected :
framework : : OpKernelType GetExpectedKernelType (
const framework : : ExecutionContext & ctx ) const override {
auto input_data_type = ctx. Input < Tensor > ( " X " ) - > type ( ) ;
auto input_data_type = OperatorWithKernel: : IndicateVarDataType ( ctx , " X " ) ;
// By default, the type of the scale, bias, mean,
// and var tensors should both be float. (For float or float16 input tensor)
// or double (For double input tensor).
@ -89,12 +89,14 @@ class DataNormOp : public framework::OperatorWithKernel {
if ( input_data_type = = framework : : proto : : VarType : : FP64 ) {
dn_param_type = framework : : proto : : VarType : : FP64 ;
}
PADDLE_ENFORCE_EQ ( dn_param_type , ctx . Input < Tensor > ( " BatchSize " ) - > type ( ) ,
PADDLE_ENFORCE_EQ ( dn_param_type ,
OperatorWithKernel : : IndicateVarDataType ( ctx , " BatchSize " ) ,
" BatchSize input should be of float type " ) ;
PADDLE_ENFORCE_EQ ( dn_param_type , ctx . Input < Tensor > ( " BatchSum " ) - > type ( ) ,
" BatchSum input should be of float type " ) ;
PADDLE_ENFORCE_EQ ( dn_param_type ,
ctx . Input < Tensor > ( " BatchSquareSum " ) - > type ( ) ,
OperatorWithKernel : : IndicateVarDataType ( ctx , " BatchSum " ) ,
" BatchSum input should be of float type " ) ;
PADDLE_ENFORCE_EQ ( dn_param_type , OperatorWithKernel : : IndicateVarDataType (
ctx , " BatchSquareSum " ) ,
" BatchSquareSum input should be of float type " ) ;
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
@ -276,8 +278,9 @@ class DataNormGradOp : public framework::OperatorWithKernel {
}
# endif
return framework : : OpKernelType ( ctx . Input < Tensor > ( " X " ) - > type ( ) ,
ctx . GetPlace ( ) , layout , library ) ;
return framework : : OpKernelType (
OperatorWithKernel : : IndicateVarDataType ( ctx , " X " ) , ctx . GetPlace ( ) ,
layout , library ) ;
}
} ;