diff --git a/paddle/fluid/operators/allclose_op.cc b/paddle/fluid/operators/allclose_op.cc
index 736483c330..cd83443f05 100644
--- a/paddle/fluid/operators/allclose_op.cc
+++ b/paddle/fluid/operators/allclose_op.cc
@@ -13,12 +13,49 @@
 // limitations under the License.
 
 #include "paddle/fluid/operators/allclose_op.h"
+#include <cmath>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
 namespace operators {
 
+template <typename T>
+struct GetTensorValue<platform::CPUDeviceContext, T> {
+  T operator()(const platform::CPUDeviceContext& dev_ctx,
+               const framework::Tensor& tensor) const {
+    return *(tensor.data<T>());
+  }
+};
+
+template <typename T>
+struct AllcloseFunctor<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& ctx,
+                  const framework::Tensor& in, const framework::Tensor& other,
+                  const double rtol, const double atol, bool equal_nan,
+                  framework::Tensor* output) {
+    auto* in_a = in.data<T>();
+    auto* in_b = other.data<T>();
+    auto* out_data = output->mutable_data<bool>(ctx.GetPlace());
+    auto num = in.numel();
+    *out_data = true;
+    for (int i = 0; i < num; i++) {
+      const T a = in_a[i], b = in_b[i];
+      bool val;
+      if (std::isnan(a) || std::isnan(b)) {
+        val = equal_nan && std::isnan(a) == std::isnan(b);
+      } else {
+        T left = (a > b ? a - b : b - a);
+        T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
+        T diff = (left > right ? left - right : right - left);
+        val = a == b || left <= right || diff <= 1e-15;
+      }
+      *out_data &= val;
+    }
+  }
+};
+
 class AllcloseOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
@@ -26,12 +63,9 @@ class AllcloseOpMaker : public framework::OpProtoAndCheckerMaker {
         "The input tensor, it's data type should be float32, float64.");
     AddInput("Other",
         "The input tensor, it's data type should be float32, float64.");
+    AddInput("Rtol", "The relative tolerance.");
+    AddInput("Atol", "The absolute tolerance.");
     AddOutput("Out", "The output tensor, it's data type is bool.");
-
-    AddAttr<float>("rtol", "The relative tolerance. Default: :math:`1e-5` .")
-        .SetDefault(1e-5);
-    AddAttr<float>("atol", "The absolute tolerance. Default: :math:`1e-8` .")
-        .SetDefault(1e-8);
     AddAttr<bool>("equal_nan",
                   "If :math:`True` , then two :math:`NaNs` will be "
                   "compared as equal. Default: :math:`False` .")
@@ -54,16 +88,12 @@ class AllcloseOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
-                      platform::errors::NotFound(
-                          "Input(Input) of allclose op should not be null."));
-    PADDLE_ENFORCE_EQ(ctx->HasInput("Other"), true,
-                      platform::errors::NotFound(
-                          "Input(Other) of allclose op should not be null."));
-    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
-                      platform::errors::NotFound(
-                          "The output(Out) of allclose op must not be null."));
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "Allclose");
+    OP_INOUT_CHECK(ctx->HasInput("Other"), "Input", "Other", "Allclose");
+    OP_INOUT_CHECK(ctx->HasInput("Rtol"), "Input", "Rtol", "Allclose");
+    OP_INOUT_CHECK(ctx->HasInput("Atol"), "Input", "Atol", "Allclose");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Allclose");
 
     auto input_dim = ctx->GetInputDim("Input");
     auto other_dim = ctx->GetInputDim("Other");
@@ -96,7 +126,7 @@ class AllcloseOp : public framework::OperatorWithKernel {
 
  protected:
   framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const override {
+      const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
         ctx.device_context());
@@ -105,7 +135,7 @@ class AllcloseOp : public framework::OperatorWithKernel {
 
 class AllcloseOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(framework::InferVarTypeContext *ctx) const override {
+  void operator()(framework::InferVarTypeContext* ctx) const override {
     ctx->SetOutputDataType("Out", framework::proto::VarType::BOOL);
   }
 };
diff --git a/paddle/fluid/operators/allclose_op.cu b/paddle/fluid/operators/allclose_op.cu
index aaca4e5b12..f98fe75cd6 100644
--- a/paddle/fluid/operators/allclose_op.cu
+++ b/paddle/fluid/operators/allclose_op.cu
@@ -12,12 +12,70 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#define EIGEN_USE_GPU
-
+#include <cmath>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/operators/allclose_op.h"
 
+namespace paddle {
+namespace operators {
+
+template <typename T>
+struct GetTensorValue<platform::CUDADeviceContext, T> {
+  T operator()(const platform::CUDADeviceContext& dev_ctx,
+               const framework::Tensor& tensor) const {
+    const T* data = tensor.data<T>();
+    T value;
+    const auto gpu_place =
+        BOOST_GET_CONST(platform::CUDAPlace, dev_ctx.GetPlace());
+    memory::Copy(platform::CPUPlace(), &value, gpu_place, data, sizeof(T),
+                 dev_ctx.stream());
+    return value;
+  }
+};
+
+template <typename T>
+__global__ void AllcloseCUDAKernel(const T* in_data, const T* other_data,
+                                   const double rtol, const double atol,
+                                   bool equal_nan, int num, bool* out_data) {
+  unsigned int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  bool val;
+  for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
+    const T a = in_data[i], b = other_data[i];
+    if (isnan(a) || isnan(b)) {
+      val = equal_nan && isnan(a) == isnan(b);
+    } else {
+      T left = (a > b ? a - b : b - a);
+      T right = atol + (b > 0 ? rtol * b : (-rtol) * b);
+      T diff = (left > right ? left - right : right - left);
+      val = a == b || left <= right || diff <= 1e-15;
+    }
+    if (!val) *out_data = false;
+  }
+}
+
+template <typename T>
+struct AllcloseFunctor<platform::CUDADeviceContext, T> {
+  void operator()(const platform::CUDADeviceContext& dev_ctx,
+                  const framework::Tensor& in, const framework::Tensor& other,
+                  const double rtol, const double atol, bool equal_nan,
+                  framework::Tensor* output) {
+    int num = in.numel();
+    const T* in_data = in.data<T>();
+    const T* other_data = other.data<T>();
+    bool* out_data = output->mutable_data<bool>(dev_ctx.GetPlace());
+    int block = 1024;
+    int grid = (block - 1 + num) / block;
+    grid = (grid > block) ? block : grid;
+    cudaMemset(out_data, true, sizeof(bool));
+    AllcloseCUDAKernel<T><<<grid, block, 0, dev_ctx.stream()>>>(
+        in_data, other_data, rtol, atol, equal_nan, num, out_data);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 using CUDA = paddle::platform::CUDADeviceContext;
 REGISTER_OP_CUDA_KERNEL(allclose, ops::AllcloseKernel<CUDA, float>,
                         ops::AllcloseKernel<CUDA, double>);
diff --git a/paddle/fluid/operators/allclose_op.h b/paddle/fluid/operators/allclose_op.h
index 51893c087c..a08ddca9eb 100644
--- a/paddle/fluid/operators/allclose_op.h
+++ b/paddle/fluid/operators/allclose_op.h
@@ -22,38 +22,38 @@ namespace paddle {
 namespace operators {
 
 using Tensor = framework::Tensor;
 
+template <typename DeviceContext, typename T>
+struct GetTensorValue {
+  T operator()(const platform::DeviceContext& ctx,
+               const framework::Tensor& tensor) const;
+};
+
+template <typename DeviceContext, typename T>
+struct AllcloseFunctor {
+  void operator()(const DeviceContext& ctx, const framework::Tensor& in,
+                  const framework::Tensor& other, const double rtol,
+                  const double atol, bool equal_nan, framework::Tensor* output);
+};
+
 template <typename DeviceContext, typename T>
 class AllcloseKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     // get attrs
-    float rtol = ctx.Attr<float>("rtol");
-    float atol = ctx.Attr<float>("atol");
     bool equal_nan = ctx.Attr<bool>("equal_nan");
     // get input/output
-    auto* input = ctx.Input<Tensor>("Input");
-    auto* other = ctx.Input<Tensor>("Other");
+    const auto* input = ctx.Input<Tensor>("Input");
+    const auto* other = ctx.Input<Tensor>("Other");
+    const auto* rtol = ctx.Input<Tensor>("Rtol");
+    const auto* atol = ctx.Input<Tensor>("Atol");
     auto* out = ctx.Output<Tensor>("Out");
-    out->mutable_data<bool>(ctx.GetPlace());
-    // get place
-    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
-
-    auto input_v = framework::EigenVector<T>::Flatten(*input);
-    auto other_v = framework::EigenVector<T>::Flatten(*other);
-    auto out_v = framework::EigenScalar<bool>::From(*out);
-
-    auto left = (input_v - other_v).abs();
-    auto right = static_cast<T>(atol) + static_cast<T>(rtol) * other_v.abs();
-    auto compare_res = left <= right;
-
-    if (equal_nan) {
-      auto input_nan = input_v.isnan();
-      auto other_nan = other_v.isnan();
-      out_v.device(place) =
-          (input_nan == other_nan).all() && (compare_res != input_nan).all();
-    } else {
-      out_v.device(place) = compare_res.all();
-    }
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+
+    GetTensorValue<DeviceContext, double> get_tensor_value;
+    double rtol_v = get_tensor_value(dev_ctx, *rtol);
+    double atol_v = get_tensor_value(dev_ctx, *atol);
+    AllcloseFunctor<DeviceContext, T>()(dev_ctx, *input, *other, rtol_v,
+                                        atol_v, equal_nan, out);
   }
 };
diff --git a/python/paddle/fluid/tests/unittests/test_allclose_op.py b/python/paddle/fluid/tests/unittests/test_allclose_op.py
index 83fef8c29f..6441a789f1 100644
--- a/python/paddle/fluid/tests/unittests/test_allclose_op.py
+++ b/python/paddle/fluid/tests/unittests/test_allclose_op.py
@@ -22,19 +22,20 @@ class TestAllcloseOp(OpTest):
     def set_args(self):
         self.input = np.array([10000., 1e-07]).astype("float32")
         self.other = np.array([10000.1, 1e-08]).astype("float32")
-        self.rtol = 1e-05
-        self.atol = 1e-08
+        self.rtol = np.array([1e-05]).astype("float64")
+        self.atol = np.array([1e-08]).astype("float64")
         self.equal_nan = False
 
     def setUp(self):
         self.set_args()
         self.op_type = "allclose"
-        self.inputs = {'Input': self.input, 'Other': self.other}
-        self.attrs = {
-            'rtol': self.rtol,
-            'atol': self.atol,
-            'equal_nan': self.equal_nan
+        self.inputs = {
+            'Input': self.input,
+            'Other': self.other,
+            "Rtol": self.rtol,
+            "Atol": self.atol
         }
+        self.attrs = {'equal_nan': self.equal_nan}
         self.outputs = {
             'Out': np.array([
                 np.allclose(
@@ -54,8 +55,8 @@ class TestAllcloseOpSmallNum(TestAllcloseOp):
     def set_args(self):
         self.input = np.array([10000., 1e-08]).astype("float32")
         self.other = np.array([10000.1, 1e-09]).astype("float32")
-        self.rtol = 1e-05
-        self.atol = 1e-08
+        self.rtol = np.array([1e-05]).astype("float64")
+        self.atol = np.array([1e-08]).astype("float64")
         self.equal_nan = False
 
 
@@ -63,8 +64,8 @@ class TestAllcloseOpNanFalse(TestAllcloseOp):
     def set_args(self):
         self.input = np.array([1.0, float('nan')]).astype("float32")
         self.other = np.array([1.0, float('nan')]).astype("float32")
-        self.rtol = 1e-05
-        self.atol = 1e-08
+        self.rtol = np.array([1e-05]).astype("float64")
+        self.atol = np.array([1e-08]).astype("float64")
         self.equal_nan = False
 
 
@@ -72,8 +73,8 @@ class TestAllcloseOpNanTrue(TestAllcloseOp):
     def set_args(self):
         self.input = np.array([1.0, float('nan')]).astype("float32")
         self.other = np.array([1.0, float('nan')]).astype("float32")
-        self.rtol = 1e-05
-        self.atol = 1e-08
+        self.rtol = np.array([1e-05]).astype("float64")
+        self.atol = np.array([1e-08]).astype("float64")
         self.equal_nan = True
 
 
@@ -130,5 +131,33 @@ class TestAllcloseError(unittest.TestCase):
         self.assertRaises(TypeError, test_equal_nan)
 
 
+class TestAllcloseOpFloat32(TestAllcloseOp):
+    def set_args(self):
+        self.input = np.array([10.1]).astype("float32")
+        self.other = np.array([10]).astype("float32")
+        self.rtol = np.array([0.01]).astype("float64")
+        self.atol = np.array([0]).astype("float64")
+        self.equal_nan = False
+
+
+class TestAllcloseOpFloat64(TestAllcloseOp):
+    def set_args(self):
+        self.input = np.array([10.1]).astype("float64")
+        self.other = np.array([10]).astype("float64")
+        self.rtol = np.array([0.01]).astype("float64")
+        self.atol = np.array([0]).astype("float64")
+        self.equal_nan = False
+
+
+class TestAllcloseOpLargeDimInput(TestAllcloseOp):
+    def set_args(self):
+        self.input = np.array(np.zeros([2048, 1024])).astype("float64")
+        self.other = np.array(np.zeros([2048, 1024])).astype("float64")
+        self.input[-1][-1] = 100
+        self.rtol = np.array([1e-05]).astype("float64")
+        self.atol = np.array([1e-08]).astype("float64")
+        self.equal_nan = False
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py
index 7b0d210987..1fc1c17d2e 100644
--- a/python/paddle/tensor/logic.py
+++ b/python/paddle/tensor/logic.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from . import to_tensor
 from ..fluid.layer_helper import LayerHelper
 from ..fluid.data_feeder import check_type, check_variable_and_dtype
 from ..fluid.layers.layer_function_generator import templatedoc
@@ -95,8 +96,8 @@ def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
     Args:
         x(Tensor): ${input_comment}.
         y(Tensor): ${other_comment}.
-        rtol(rtoltype, optional): ${rtol_comment}.
-        atol(atoltype, optional): ${atol_comment}.
+        rtol(rtoltype, optional): The relative tolerance. Default: :math:`1e-5` .
+        atol(atoltype, optional): The absolute tolerance. Default: :math:`1e-8` .
         equal_nan(equalnantype, optional): ${equal_nan_comment}.
         name (str, optional): Name for the operation. For more information, please
             refer to :ref:`api_guide_Name`. Default: None.
@@ -142,7 +143,9 @@ def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
     """
 
     if in_dygraph_mode():
-        return core.ops.allclose(x, y, 'rtol', rtol, 'atol', atol, 'equal_nan',
+        rtol_tensor = to_tensor(rtol, dtype='float64')
+        atol_tensor = to_tensor(atol, dtype='float64')
+        return core.ops.allclose(x, y, rtol_tensor, atol_tensor, 'equal_nan',
                                  equal_nan)
 
     check_variable_and_dtype(x, "input", ['float32', 'float64'], 'allclose')
@@ -152,11 +155,26 @@ def allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False, name=None):
     check_type(equal_nan, 'equal_nan', bool, 'allclose')
 
     helper = LayerHelper("allclose", **locals())
+    rtol_var = helper.create_global_variable(
+        name=fluid.unique_name.generate('rtol'),
+        persistable=True,
+        dtype='float64',
+        shape=[1])
+    helper.set_variable_initializer(
+        rtol_var, initializer=fluid.initializer.ConstantInitializer(rtol))
+    atol_var = helper.create_global_variable(
+        name=fluid.unique_name.generate('atol'),
+        persistable=True,
+        dtype='float64',
+        shape=[1])
+    helper.set_variable_initializer(
+        atol_var, initializer=fluid.initializer.ConstantInitializer(atol))
+
     out = helper.create_variable_for_type_inference(dtype='bool')
 
-    inputs = {'Input': x, 'Other': y}
+    inputs = {'Input': x, 'Other': y, 'Rtol': rtol_var, 'Atol': atol_var}
     outputs = {'Out': out}
-    attrs = {'rtol': rtol, 'atol': atol, 'equal_nan': equal_nan}
+    attrs = {'equal_nan': equal_nan}
     helper.append_op(
         type='allclose', inputs=inputs, outputs=outputs, attrs=attrs)
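
Note: for reviewers, a minimal usage sketch of the patched API (assuming a Paddle build carrying this change; input values are borrowed from TestAllcloseOp above). The public signature keeps `rtol`/`atol` as Python floats; the patch only changes how they reach the operator, as float64 `Rtol`/`Atol` input tensors instead of attributes:

```python
import numpy as np
import paddle

paddle.disable_static()  # exercise the dygraph (core.ops.allclose) path

x = paddle.to_tensor(np.array([10000., 1e-07]).astype("float32"))
y = paddle.to_tensor(np.array([10000.1, 1e-08]).astype("float32"))

# rtol/atol are converted internally to float64 tensors and fed to the op
# as its Rtol/Atol inputs; the kernel reads them back via GetTensorValue.
res = paddle.allclose(x, y, rtol=1e-05, atol=1e-08, equal_nan=False)
print(res.numpy())  # [False]: |1e-07 - 1e-08| exceeds atol + rtol * |1e-08|
```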