Add allclose_op (#23335)
* Add allclose op, whose behavior is analogous to numpy.allclose: it returns True if two tensors are elementwise equal within a tolerance.
parent 948c57d84b
commit 56b50c97f8
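For orientation, the condition the new op evaluates elementwise is the numpy.allclose condition |input - other| <= atol + rtol * |other|. A minimal numpy sketch of it (mine, not part of the diff), using sample values from the tests below:

import numpy as np

# Elementwise condition: |input - other| <= atol + rtol * |other|
x = np.array([10000., 1e-08], dtype=np.float32)
y = np.array([10000.1, 1e-09], dtype=np.float32)
rtol, atol = 1e-05, 1e-08

close = np.abs(x - y) <= atol + rtol * np.abs(y)
print(bool(close.all()))                        # True
print(np.allclose(x, y, rtol=rtol, atol=atol))  # True, agrees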
@@ -0,0 +1,123 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/allclose_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace operators {

class AllcloseOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Input", "The first input tensor to compare.");
    AddInput("Other", "The second input tensor to compare.");
    AddOutput("Out", "The output tensor of allclose op.");

    AddAttr<float>("rtol", "The relative tolerance. Default: :math:`1e-5` .")
        .SetDefault(1e-5);
    AddAttr<float>("atol", "The absolute tolerance. Default: :math:`1e-8` .")
        .SetDefault(1e-8);
    AddAttr<bool>("equal_nan",
                  "If :math:`True` , then two :math:`NaNs` will be "
                  "compared as equal. Default: :math:`False` .")
        .SetDefault(false);

    AddComment(R"DOC(
This operator checks if all :math:`input` and :math:`other` satisfy the condition:

:math:`\left| input - other \right| \leq atol + rtol \times \left| other \right|`

elementwise, for all elements of :math:`input` and :math:`other`. The behaviour of this
operator is analogous to :math:`numpy.allclose`, namely that it returns :math:`True` if
two tensors are elementwise equal within a tolerance.
)DOC");
  }
};

class AllcloseOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext *ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("Input"), true,
                      platform::errors::NotFound(
                          "Input(Input) of allclose op should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasInput("Other"), true,
                      platform::errors::NotFound(
                          "Input(Other) of allclose op should not be null."));
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      platform::errors::NotFound(
                          "The output(Out) of allclose op must not be null."));

    auto input_dim = ctx->GetInputDim("Input");
    auto other_dim = ctx->GetInputDim("Other");
    PADDLE_ENFORCE_EQ(input_dim.size(), other_dim.size(),
                      platform::errors::PreconditionNotMet(
                          "Input(Input) and Input(Other) must have the same "
                          "dimension size."));
    int n = input_dim.size();
    bool is_runtime = ctx->IsRuntime();
    for (int i = 0; i < n; i++) {
      if (is_runtime) {
        PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i],
                          platform::errors::PreconditionNotMet(
                              "The value at dim %d of Input(Input) is not "
                              "equal to the Input(Other): %ld != %ld.",
                              i, input_dim[i], other_dim[i]));
      } else {
        // At compile time a dimension may still be unresolved (negative),
        // so only compare dimensions that are known on both sides.
        if (!(input_dim[i] < 0 || other_dim[i] < 0)) {
          PADDLE_ENFORCE_EQ(input_dim[i], other_dim[i],
                            platform::errors::PreconditionNotMet(
                                "The value at dim %d of Input(Input) is not "
                                "equal to the Input(Other): %ld != %ld.",
                                i, input_dim[i], other_dim[i]));
        }
      }
    }

    // The result is a single boolean held in a 1-element tensor.
    ctx->SetOutputDim("Out", framework::make_ddim({1}));
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
        ctx.device_context());
  }
};

class AllcloseOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    auto out_var_name = ctx->Output("Out").front();
    ctx->SetDataType(out_var_name, framework::proto::VarType::BOOL);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;

REGISTER_OPERATOR(
    allclose, ops::AllcloseOp, ops::AllcloseOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
    ops::AllcloseOpVarTypeInference);
REGISTER_OP_CPU_KERNEL(allclose, ops::AllcloseKernel<CPU, float>,
                       ops::AllcloseKernel<CPU, double>);
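One behavioural note: InferShape above requires Input and Other to have the same rank and the same resolved dimensions, so there is no broadcasting, unlike numpy.allclose. A small sketch (hypothetical shapes, not from the diff):

import numpy as np

x = np.ones((2, 3), dtype=np.float32)
y = np.ones((1, 3), dtype=np.float32)

# numpy.allclose broadcasts (1, 3) against (2, 3):
print(np.allclose(x, y))  # True

# The allclose op, by contrast, enforces input_dim[i] == other_dim[i] for
# every i in InferShape, so a (2, 3) vs (1, 3) pair would fail the
# PreconditionNotMet checks rather than broadcast.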
@@ -0,0 +1,24 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#define EIGEN_USE_GPU

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/allclose_op.h"

namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
REGISTER_OP_CUDA_KERNEL(allclose, ops::AllcloseKernel<CUDA, float>,
                        ops::AllcloseKernel<CUDA, double>);
@@ -0,0 +1,61 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/place.h"

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class AllcloseKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    // get attrs
    float rtol = ctx.Attr<float>("rtol");
    float atol = ctx.Attr<float>("atol");
    bool equal_nan = ctx.Attr<bool>("equal_nan");
    // get input/output
    auto* input = ctx.Input<Tensor>("Input");
    auto* other = ctx.Input<Tensor>("Other");
    auto* out = ctx.Output<Tensor>("Out");
    out->mutable_data<bool>(ctx.GetPlace());
    // get place
    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();

    auto input_v = framework::EigenVector<T>::Flatten(*input);
    auto other_v = framework::EigenVector<T>::Flatten(*other);
    auto out_v = framework::EigenScalar<bool>::From(*out);

    // Elementwise condition: |input - other| <= atol + rtol * |other|.
    auto left = (input_v - other_v).abs();
    auto right = static_cast<T>(atol) + static_cast<T>(rtol) * other_v.abs();
    auto compare_res = left <= right;

    if (equal_nan) {
      // NaN positions always fail compare_res, so the tensors are all close
      // iff the NaN masks match and every non-NaN position passes the check.
      auto input_nan = input_v.isnan();
      auto other_nan = other_v.isnan();
      out_v.device(place) =
          (input_nan == other_nan).all() && (compare_res != input_nan).all();
    } else {
      out_v.device(place) = compare_res.all();
    }
  }
};

}  // namespace operators
}  // namespace paddle
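The Eigen expressions in Compute are terse, so here is a numpy transcription of the kernel's logic (an illustrative sketch, not code from this PR):

import numpy as np

def allclose_ref(x, y, rtol=1e-05, atol=1e-08, equal_nan=False):
    # Mirrors compare_res = left <= right above; NaN positions compare as False.
    compare_res = np.abs(x - y) <= atol + rtol * np.abs(y)
    if equal_nan:
        x_nan, y_nan = np.isnan(x), np.isnan(y)
        # Mirrors (input_nan == other_nan).all() && (compare_res != input_nan).all():
        # NaN masks must match, and every non-NaN position must pass the condition.
        return bool((x_nan == y_nan).all() and (compare_res != x_nan).all())
    return bool(compare_res.all())

x = np.array([1.0, float('nan')], dtype=np.float32)
y = np.array([1.0, float('nan')], dtype=np.float32)
print(allclose_ref(x, y))                  # False
print(allclose_ref(x, y, equal_nan=True))  # True, matches np.allclose(..., equal_nan=True)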
@@ -0,0 +1,140 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.fluid as fluid
import unittest
import numpy as np


class TestAllcloseLayer(unittest.TestCase):
    def allclose_check(self, use_cuda):
        a = fluid.data(name="a", shape=[2], dtype='float32')
        b = fluid.data(name="b", shape=[2], dtype='float32')

        result = paddle.allclose(
            a, b, rtol=1e-05, atol=1e-08, equal_nan=False, name="ignore_nan")
        result_nan = paddle.allclose(
            a, b, rtol=1e-05, atol=1e-08, equal_nan=True, name="equal_nan")

        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())

        x = np.array([10000., 1e-07]).astype("float32")
        y = np.array([10000.1, 1e-08]).astype("float32")
        result_v, result_nan_v = exe.run(feed={'a': x, 'b': y},
                                         fetch_list=[result, result_nan])
        self.assertEqual(result_v[0], False)
        self.assertEqual(result_nan_v[0], False)

        x = np.array([10000., 1e-08]).astype("float32")
        y = np.array([10000.1, 1e-09]).astype("float32")
        result_v, result_nan_v = exe.run(feed={'a': x, 'b': y},
                                         fetch_list=[result, result_nan])
        self.assertEqual(result_v[0], True)
        self.assertEqual(result_nan_v[0], True)

        x = np.array([1.0, float('nan')]).astype("float32")
        y = np.array([1.0, float('nan')]).astype("float32")
        result_v, result_nan_v = exe.run(feed={'a': x, 'b': y},
                                         fetch_list=[result, result_nan])
        self.assertEqual(result_v[0], False)
        self.assertEqual(result_nan_v[0], True)

    def test_allclose_cpu(self):
        main = fluid.Program()
        startup = fluid.Program()
        with fluid.unique_name.guard():
            with fluid.program_guard(main, startup):
                self.allclose_check(use_cuda=False)

    def test_allclose_gpu(self):
        if fluid.core.is_compiled_with_cuda():
            main = fluid.Program()
            startup = fluid.Program()
            with fluid.unique_name.guard():
                with fluid.program_guard(main, startup):
                    self.allclose_check(use_cuda=True)

    def test_dygraph_mode(self):
        x_1 = np.array([10000., 1e-07]).astype("float32")
        y_1 = np.array([10000.1, 1e-08]).astype("float32")
        x_2 = np.array([10000., 1e-08]).astype("float32")
        y_2 = np.array([10000.1, 1e-09]).astype("float32")
        x_3 = np.array([1.0, float('nan')]).astype("float32")
        y_3 = np.array([1.0, float('nan')]).astype("float32")

        with fluid.dygraph.guard():
            x_v_1 = fluid.dygraph.to_variable(x_1)
            y_v_1 = fluid.dygraph.to_variable(y_1)
            ret_1 = paddle.allclose(
                x_v_1, y_v_1, rtol=1e-05, atol=1e-08, equal_nan=False,
                name='test_1')
            self.assertEqual(ret_1.numpy()[0], False)
            ret_1 = paddle.allclose(
                x_v_1, y_v_1, rtol=1e-05, atol=1e-08, equal_nan=True,
                name='test_2')
            self.assertEqual(ret_1.numpy()[0], False)
            x_v_2 = fluid.dygraph.to_variable(x_2)
            y_v_2 = fluid.dygraph.to_variable(y_2)
            ret_2 = paddle.allclose(
                x_v_2, y_v_2, rtol=1e-05, atol=1e-08, equal_nan=False,
                name='test_3')
            self.assertEqual(ret_2.numpy()[0], True)
            ret_2 = paddle.allclose(
                x_v_2, y_v_2, rtol=1e-05, atol=1e-08, equal_nan=True,
                name='test_4')
            self.assertEqual(ret_2.numpy()[0], True)
            x_v_3 = fluid.dygraph.to_variable(x_3)
            y_v_3 = fluid.dygraph.to_variable(y_3)
            ret_3 = paddle.allclose(
                x_v_3, y_v_3, rtol=1e-05, atol=1e-08, equal_nan=False,
                name='test_5')
            self.assertEqual(ret_3.numpy()[0], False)
            ret_3 = paddle.allclose(
                x_v_3, y_v_3, rtol=1e-05, atol=1e-08, equal_nan=True,
                name='test_6')
            self.assertEqual(ret_3.numpy()[0], True)


if __name__ == "__main__":
    unittest.main()
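The expected booleans in these tests follow directly from the tolerance formula; a quick numpy check of the three cases (my working, not part of the diff):

import numpy as np

rtol, atol = 1e-05, 1e-08

# Case 1: |1e-07 - 1e-08| = 9e-08 > atol + rtol * 1e-08 ~ 1.00001e-08 -> not close.
print(np.allclose([10000., 1e-07], [10000.1, 1e-08], rtol=rtol, atol=atol))  # False

# Case 2: |1e-08 - 1e-09| = 9e-09 <= atol + rtol * 1e-09 ~ 1e-08 -> close.
print(np.allclose([10000., 1e-08], [10000.1, 1e-09], rtol=rtol, atol=atol))  # True

# Case 3: NaN pairs count as close only when equal_nan=True.
print(np.allclose([1.0, float('nan')], [1.0, float('nan')], equal_nan=True))  # True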
@@ -0,0 +1,80 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest


class TestAllcloseOp(OpTest):
    def set_args(self):
        self.input = np.array([10000., 1e-07]).astype("float32")
        self.other = np.array([10000.1, 1e-08]).astype("float32")
        self.rtol = 1e-05
        self.atol = 1e-08
        self.equal_nan = False

    def setUp(self):
        self.set_args()
        self.op_type = "allclose"
        self.inputs = {'Input': self.input, 'Other': self.other}
        self.attrs = {
            'rtol': self.rtol,
            'atol': self.atol,
            'equal_nan': self.equal_nan
        }
        self.outputs = {
            'Out': np.array([
                np.allclose(
                    self.inputs['Input'],
                    self.inputs['Other'],
                    rtol=self.rtol,
                    atol=self.atol,
                    equal_nan=self.equal_nan)
            ])
        }

    def test_check_output(self):
        self.check_output()


class TestAllcloseOpSmallNum(TestAllcloseOp):
    def set_args(self):
        self.input = np.array([10000., 1e-08]).astype("float32")
        self.other = np.array([10000.1, 1e-09]).astype("float32")
        self.rtol = 1e-05
        self.atol = 1e-08
        self.equal_nan = False


class TestAllcloseOpNanFalse(TestAllcloseOp):
    def set_args(self):
        self.input = np.array([1.0, float('nan')]).astype("float32")
        self.other = np.array([1.0, float('nan')]).astype("float32")
        self.rtol = 1e-05
        self.atol = 1e-08
        self.equal_nan = False


class TestAllcloseOpNanTrue(TestAllcloseOp):
    def set_args(self):
        self.input = np.array([1.0, float('nan')]).astype("float32")
        self.other = np.array([1.0, float('nan')]).astype("float32")
        self.rtol = 1e-05
        self.atol = 1e-08
        self.equal_nan = True


if __name__ == "__main__":
    unittest.main()