@ -13,12 +13,49 @@
// limitations under the License.
# include "paddle/fluid/operators/allclose_op.h"
# include <cmath>
# include "paddle/fluid/framework/op_registry.h"
# include "paddle/fluid/framework/operator.h"
# include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
template < typename T >
struct GetTensorValue < platform : : CPUDeviceContext , T > {
T operator ( ) ( const platform : : CPUDeviceContext & dev_ctx ,
const framework : : Tensor & tensor ) const {
return * ( tensor . data < T > ( ) ) ;
}
} ;
template < typename T >
struct AllcloseFunctor < platform : : CPUDeviceContext , T > {
void operator ( ) ( const platform : : CPUDeviceContext & ctx ,
const framework : : Tensor & in , const framework : : Tensor & other ,
const double rtol , const double atol , bool equal_nan ,
framework : : Tensor * output ) {
auto * in_a = in . data < T > ( ) ;
auto * in_b = other . data < T > ( ) ;
auto * out_data = output - > mutable_data < bool > ( ctx . GetPlace ( ) ) ;
auto num = in . numel ( ) ;
* out_data = true ;
for ( int i = 0 ; i < num ; i + + ) {
const T a = in_a [ i ] , b = in_b [ i ] ;
bool val ;
if ( std : : isnan ( a ) | | std : : isnan ( b ) ) {
val = equal_nan & & std : : isnan ( a ) = = std : : isnan ( b ) ;
} else {
T left = ( a > b ? a - b : b - a ) ;
T right = atol + ( b > 0 ? rtol * b : ( - rtol ) * b ) ;
T diff = ( left > right ? left - right : right - left ) ;
val = a = = b | | left < = right | | diff < = 1e-15 ;
}
* out_data & = val ;
}
}
} ;
class AllcloseOpMaker : public framework : : OpProtoAndCheckerMaker {
public :
void Make ( ) override {
@ -26,12 +63,9 @@ class AllcloseOpMaker : public framework::OpProtoAndCheckerMaker {
" The input tensor, it's data type should be float32, float64. " ) ;
AddInput ( " Other " ,
" The input tensor, it's data type should be float32, float64. " ) ;
AddInput ( " Rtol " , " The relative tolerance. " ) ;
AddInput ( " Atol " , " The absolute tolerance. " ) ;
AddOutput ( " Out " , " The output tensor, it's data type is bool. " ) ;
AddAttr < float > ( " rtol " , " The relative tolerance. Default: :math:`1e-5` . " )
. SetDefault ( 1e-5 ) ;
AddAttr < float > ( " atol " , " The absolute tolerance. Default: :math:`1e-8` . " )
. SetDefault ( 1e-8 ) ;
AddAttr < bool > ( " equal_nan " ,
" If :math:`True` , then two :math:`NaNs` will be "
" compared as equal. Default: :math:`False` . " )
@ -55,15 +89,11 @@ class AllcloseOp : public framework::OperatorWithKernel {
using framework : : OperatorWithKernel : : OperatorWithKernel ;
void InferShape ( framework : : InferShapeContext * ctx ) const override {
PADDLE_ENFORCE_EQ ( ctx - > HasInput ( " Input " ) , true ,
platform : : errors : : NotFound (
" Input(Input) of allclose op should not be null. " ) ) ;
PADDLE_ENFORCE_EQ ( ctx - > HasInput ( " Other " ) , true ,
platform : : errors : : NotFound (
" Input(Other) of allclose op should not be null. " ) ) ;
PADDLE_ENFORCE_EQ ( ctx - > HasOutput ( " Out " ) , true ,
platform : : errors : : NotFound (
" The output(Out) of allclose op must not be null. " ) ) ;
OP_INOUT_CHECK ( ctx - > HasInput ( " Input " ) , " Input " , " Input " , " Allclose " ) ;
OP_INOUT_CHECK ( ctx - > HasInput ( " Other " ) , " Input " , " Other " , " Allclose " ) ;
OP_INOUT_CHECK ( ctx - > HasInput ( " Rtol " ) , " Input " , " Rtol " , " Allclose " ) ;
OP_INOUT_CHECK ( ctx - > HasInput ( " Atol " ) , " Input " , " Atol " , " Allclose " ) ;
OP_INOUT_CHECK ( ctx - > HasOutput ( " Out " ) , " Output " , " Out " , " Allclose " ) ;
auto input_dim = ctx - > GetInputDim ( " Input " ) ;
auto other_dim = ctx - > GetInputDim ( " Other " ) ;