Update the API for the compare_ops

Update the code for the compare_ops; update the API and docs
fix_copy_if_different
wawltor committed via GitHub
parent fc6fed3283
commit 595a719795
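
In brief: this change renames the whole-tensor comparison op equal_reduce to equal_all, makes paddle.equal the elementwise compare (replacing elementwise_equal, while the old reduce-style paddle.equal becomes paddle.equal_all), removes broadcasting (and the axis attribute) from equal_all, and adds an optional name argument to the Python compare ops. A minimal usage sketch of the updated API, modeled on the tests added below (the static-graph session setup is illustrative, not part of this patch):

    import numpy as np
    import paddle
    import paddle.fluid as fluid

    with fluid.program_guard(fluid.Program(), fluid.Program()):
        x = fluid.layers.assign(np.array([3, 4], dtype="int32"))
        y = fluid.layers.assign(np.array([3, 4], dtype="int32"))
        eq = paddle.equal(x, y)          # elementwise compare -> [True, True]
        all_eq = paddle.equal_all(x, y)  # single-element bool tensor -> [True]
        exe = fluid.Executor(fluid.CPUPlace())
        res_eq, res_all = exe.run(fetch_list=[eq, all_eq])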

@@ -114,7 +114,7 @@ function(op_library TARGET)
   endif()
 
   # Define operators that don't need pybind here.
-  foreach(manual_pybind_op "compare_reduce_op" "compare_op" "logical_op" "nccl_op"
+  foreach(manual_pybind_op "compare_all_op" "compare_op" "logical_op" "nccl_op"
       "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
       "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
       "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"

@@ -9,4 +9,4 @@ cc_test(conditional_block_op_test SRCS conditional_block_op_test.cc DEPS conditional_block_op)
 target_link_libraries(conditional_block_infer_op conditional_block_op)
 
-file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal_reduce);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
+file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal_all);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")

@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/operators/controlflow/compare_reduce_op.h"
+#include "paddle/fluid/operators/controlflow/compare_all_op.h"
 #include <string>
 #include "paddle/fluid/framework/op_registry.h"
@@ -30,38 +30,44 @@ class CompareReduceOpKernel
     auto* x = context.Input<Tensor>("X");
     auto* y = context.Input<Tensor>("Y");
     auto* z = context.Output<Tensor>("Out");
-    int axis = context.Attr<int>("axis");
+    bool shape_same = true;
     Tensor tmp;
     framework::DDim x_dims = x->dims();
     framework::DDim y_dims = y->dims();
-    int max_dim = std::max(x_dims.size(), y_dims.size());
-    axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
-    std::vector<int> x_dims_array(max_dim);
-    std::vector<int> y_dims_array(max_dim);
-    std::vector<int> tmp_dims_array(max_dim);
-    GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
-                           y_dims_array.data(), tmp_dims_array.data(), max_dim,
-                           axis);
-    tmp.mutable_data<bool>(framework::make_ddim(tmp_dims_array),
-                           context.GetPlace());
-    if (x->numel() == 1 && y->numel() == 1) {
-      bool* z_data = tmp.mutable_data<bool>(context.GetPlace());
-      z_data[0] = Functor()(x->data<T>()[0], y->data<T>()[0]);
-    } else {
-      ElementwiseComputeEx<Functor, platform::CPUDeviceContext, T, bool>(
-          context, x, y, axis, Functor(), &tmp);
+
+    // judge the two inputs shape is same, if not same, just return false
+    if (x_dims.size() != y_dims.size()) {
+      shape_same = false;
+    } else {
+      for (auto i = 0; i < x_dims.size(); i++) {
+        if (x_dims[i] != y_dims[i]) {
+          shape_same = false;
+          break;
+        }
+      }
     }
-    // Reduce by 'logical and' operator
-    z->mutable_data<bool>(context.GetPlace());
-    auto ipt = framework::EigenVector<bool>::Flatten(tmp);
-    auto out = framework::EigenScalar<bool>::From(*z);
-    auto& place = *context.template device_context<platform::CPUDeviceContext>()
-                       .eigen_device();
-    auto reduce_dim = Eigen::array<int, 1>({{0}});
-    out.device(place) = ipt.all(reduce_dim);
+
+    bool* z_data = z->mutable_data<bool>(context.GetPlace());
+    if (!shape_same) {
+      z_data[0] = false;
+    } else {
+      tmp.mutable_data<bool>(x_dims, context.GetPlace());
+      if (x->numel() == 1 && y->numel() == 1) {
+        bool* z_data = tmp.mutable_data<bool>(context.GetPlace());
+        z_data[0] = Functor()(x->data<T>()[0], y->data<T>()[0]);
+      } else {
+        ElementwiseComputeEx<Functor, platform::CPUDeviceContext, T, bool>(
+            context, x, y, 0, Functor(), &tmp);
+      }
+      auto ipt = framework::EigenVector<bool>::Flatten(tmp);
+      auto out = framework::EigenScalar<bool>::From(*z);
+      auto& place =
+          *context.template device_context<platform::CPUDeviceContext>()
+               .eigen_device();
+      auto reduce_dim = Eigen::array<int, 1>({{0}});
+      out.device(place) = ipt.all(reduce_dim);
+    }
   }
 };
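
The rewritten CPU kernel above first checks that the two input shapes match exactly; on a mismatch it writes a single False instead of broadcasting, otherwise it compares elementwise and reduces with a logical all. A NumPy model of that behavior (a sketch for illustration only; equal_all_reference is a hypothetical name, not code from this patch):

    import numpy as np

    def equal_all_reference(x, y):
        # Shape mismatch short-circuits to False; no broadcasting.
        if x.shape != y.shape:
            return np.array(False)
        # Elementwise compare, then reduce with logical all.
        return np.array(np.all(x == y))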
@@ -74,11 +80,6 @@ class CompareReduceOpProtoMaker : public framework::OpProtoAndCheckerMaker {
                                   comment.type));
     AddInput("Y", string::Sprintf("the right hand operand of %s operator",
                                   comment.type));
-    AddAttr<int>(
-        "axis",
-        "The start dimension index for broadcasting Y onto X. [default -1]")
-        .SetDefault(-1)
-        .EqualGreaterThan(-1);
     AddOutput("Out", string::Sprintf(
                          "tensor with a bool element. If all "
                          "element %s, the Out tensor is [True], else [False]",
@@ -144,7 +145,7 @@ class CompareReduceOp : public framework::OperatorWithKernel {
           ::paddle::platform::CPUDeviceContext, functor<float>>,          \
       ::paddle::operators::CompareReduceOpKernel<                         \
           ::paddle::platform::CPUDeviceContext, functor<double>>);
-REGISTER_COMPARE_REDUCE_OP(equal_reduce, "X == Y");
-REGISTER_COMPARE_REDUCE_CPU_KERNEL(equal_reduce,
+REGISTER_COMPARE_REDUCE_OP(equal_all, "X == Y");
+REGISTER_COMPARE_REDUCE_CPU_KERNEL(equal_all,
                                    paddle::operators::EqualReduceFunctor);
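
With the axis attribute removed above, equal_all no longer broadcasts Y onto X: inputs whose ranks or dimensions differ now simply produce [False], which is what the new create_test_not_shape_equal_class case later in this diff asserts. For example (illustrative values):

    import numpy as np

    x = np.random.random(size=(10, 7)).astype('float32')
    y = np.random.random(size=(10,)).astype('float32')
    # equal_reduce would have broadcast y over x along `axis`;
    # equal_all now reports the shape mismatch as plain False.
    expected = np.array(np.array_equal(x, y))  # -> array(False)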

@@ -12,7 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/fluid/operators/controlflow/compare_reduce_op.h"
+#include <thrust/fill.h>
+#include "paddle/fluid/operators/controlflow/compare_all_op.h"
 #include "paddle/fluid/operators/reduce_ops/cub_reduce.h"
 
 namespace paddle {
 namespace operators {
@@ -43,31 +44,41 @@ class CompareReduceOpKernel
     auto* x = context.Input<Tensor>("X");
     auto* y = context.Input<Tensor>("Y");
     auto* z = context.Output<Tensor>("Out");
-    int axis = context.Attr<int>("axis");
+    bool shape_same = true;
     Tensor tmp;
     framework::DDim x_dims = x->dims();
     framework::DDim y_dims = y->dims();
-    int max_dim = std::max(x_dims.size(), y_dims.size());
-    axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
-    std::vector<int> x_dims_array(max_dim);
-    std::vector<int> y_dims_array(max_dim);
-    std::vector<int> tmp_dims_array(max_dim);
-    GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
-                           y_dims_array.data(), tmp_dims_array.data(), max_dim,
-                           axis);
-    tmp.mutable_data<bool>(framework::make_ddim(tmp_dims_array),
-                           context.GetPlace());
-    ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context, x, y, axis,
-                                                          Functor(), &tmp);
-    // Reduce by 'bitwise and' operator
-    std::vector<int> reduce_dims;
-    reduce_dims.resize(tmp.dims().size());
-    for (int i = 0; i < reduce_dims.size(); ++i) reduce_dims[i] = i;
-    auto stream = context.cuda_device_context().stream();
-    TensorReduce<bool, bool, BitwiseAdd, IdentityFunctor<bool>>(
-        tmp, z, reduce_dims, true, BitwiseAdd(), IdentityFunctor<bool>(),
-        stream);
+
+    if (x_dims.size() != y_dims.size()) {
+      shape_same = false;
+    } else {
+      for (auto i = 0; i < x_dims.size(); i++) {
+        if (x_dims[i] != y_dims[i]) {
+          shape_same = false;
+          break;
+        }
+      }
+    }
+
+    bool* z_data = z->mutable_data<bool>(context.GetPlace());
+    if (!shape_same) {
+      thrust::device_ptr<bool> z_dev_ptr(z_data);
+      thrust::fill(z_dev_ptr, z_dev_ptr + 1, false);
+      return;
+    } else {
+      tmp.mutable_data<bool>(x_dims, context.GetPlace());
+      ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context, x, y, 0,
+                                                            Functor(), &tmp);
+      // Reduce by 'bitwise and' operator
+      std::vector<int> reduce_dims;
+      reduce_dims.resize(tmp.dims().size());
+      for (int i = 0; i < reduce_dims.size(); ++i) reduce_dims[i] = i;
+      auto stream = context.cuda_device_context().stream();
+      TensorReduce<bool, bool, BitwiseAdd, IdentityFunctor<bool>>(
+          tmp, z, reduce_dims, true, BitwiseAdd(), IdentityFunctor<bool>(),
+          stream);
+    }
   }
 };
@@ -84,5 +95,5 @@ class CompareReduceOpKernel
           paddle::platform::CUDADeviceContext, functor<float>>,           \
       paddle::operators::CompareReduceOpKernel<                           \
           paddle::platform::CUDADeviceContext, functor<double>>);
-REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_reduce,
+REGISTER_COMPARE_REDUCE_CUDA_KERNEL(equal_all,
                                     paddle::operators::EqualReduceFunctor);

@@ -98,7 +98,7 @@ from .tensor.logic import not_equal  #DEFINE_ALIAS
 from .tensor.logic import reduce_all  #DEFINE_ALIAS
 from .tensor.logic import reduce_any  #DEFINE_ALIAS
 from .tensor.logic import allclose  #DEFINE_ALIAS
-from .tensor.logic import elementwise_equal  #DEFINE_ALIAS
+from .tensor.logic import equal_all  #DEFINE_ALIAS
 # from .tensor.logic import isnan  #DEFINE_ALIAS
 from .tensor.manipulation import cast  #DEFINE_ALIAS
 from .tensor.manipulation import concat  #DEFINE_ALIAS

@@ -1580,7 +1580,7 @@ def create_array(dtype):
 @templatedoc()
-def less_than(x, y, force_cpu=None, cond=None):
+def less_than(x, y, force_cpu=None, cond=None, name=None):
     """
     :alias_main: paddle.less_than
     :alias: paddle.less_than,paddle.tensor.less_than,paddle.tensor.logic.less_than
@@ -1595,6 +1595,8 @@ def less_than(x, y, force_cpu=None, cond=None):
         cond(Variable, optional): Optional output which can be any created Variable
             that meets the requirements to store the result of *less_than*.
             if cond is None, a new Varibale will be created to store the result.
+        name(str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
         ${out_comment}.
@@ -1649,7 +1651,7 @@ def less_than(x, y, force_cpu=None, cond=None):
 @templatedoc()
-def less_equal(x, y, cond=None):
+def less_equal(x, y, cond=None, name=None):
     """
     :alias_main: paddle.less_equal
     :alias: paddle.less_equal,paddle.tensor.less_equal,paddle.tensor.logic.less_equal
@@ -1662,6 +1664,8 @@ def less_equal(x, y, cond=None):
         y(Variable): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
         cond(Variable, optional): Optional output which can be any created Variable that meets the requirements to store the result of *less_equal*.
             if cond is None, a new Varibale will be created to store the result.
+        name(str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
         Variable, the output data type is bool: The tensor variable storing the output, the output shape is same as input :attr:`x`.
@@ -1701,7 +1705,7 @@ def less_equal(x, y, cond=None):
 @templatedoc()
-def greater_than(x, y, cond=None):
+def greater_than(x, y, cond=None, name=None):
     """
     :alias_main: paddle.greater_than
     :alias: paddle.greater_than,paddle.tensor.greater_than,paddle.tensor.logic.greater_than
@@ -1714,6 +1718,8 @@ def greater_than(x, y, cond=None):
         y(Variable): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
         cond(Variable, optional): Optional output which can be any created Variable that meets the requirements to store the result of *greater_than*.
             if cond is None, a new Varibale will be created to store the result.
+        name(str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
         Variable, the output data type is bool: The tensor variable storing the output, the output shape is same as input :attr:`x` .
@@ -1752,7 +1758,7 @@ def greater_than(x, y, cond=None):
 @templatedoc()
-def greater_equal(x, y, cond=None):
+def greater_equal(x, y, cond=None, name=None):
     """
     :alias_main: paddle.greater_equal
     :alias: paddle.greater_equal,paddle.tensor.greater_equal,paddle.tensor.logic.greater_equal
@@ -1765,6 +1771,8 @@ def greater_equal(x, y, cond=None):
         y(Variable): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
         cond(Variable, optional): Optional output which can be any created Variable that meets the requirements to store the result of *greater_equal*.
             if cond is None, a new Varibale will be created to store the result.
+        name(str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
         Variable, the output data type is bool: The tensor variable storing the output, the output shape is same as input :attr:`x`.
@@ -1804,7 +1812,7 @@ def greater_equal(x, y, cond=None):
     return cond
 
 
-def equal(x, y, cond=None):
+def equal(x, y, cond=None, name=None):
     """
     This layer returns the truth value of :math:`x == y` elementwise.
@@ -1814,6 +1822,8 @@ def equal(x, y, cond=None):
         cond(Variable, optional): Optional output which can be any created
             Variable that meets the requirements to store the result of *equal*.
             if cond is None, a new Varibale will be created to store the result.
+        name(str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
         Variable: output Tensor, it's shape is the same as the input's Tensor,
@@ -1849,7 +1859,7 @@ def equal(x, y, cond=None):
     return cond
 
 
-def not_equal(x, y, cond=None):
+def not_equal(x, y, cond=None, name=None):
     """
     :alias_main: paddle.not_equal
     :alias: paddle.not_equal,paddle.tensor.not_equal,paddle.tensor.logic.not_equal
@@ -1862,6 +1872,8 @@ def not_equal(x, y, cond=None):
         y(Variable): Second input to compare which is N-D tensor. The input data type should be float32, float64, int32, int64.
         cond(Variable, optional): Optional output which can be any created Variable that meets the requirements to store the result of *not_equal*.
             if cond is None, a new Varibale will be created to store the result.
+        name(str, optional): The default value is None. Normally there is no need for
+            user to set this property. For more information, please refer to :ref:`api_guide_Name`.
 
     Returns:
         Variable, the output data type is bool: The tensor variable storing the output, the output shape is same as input :attr:`x`.
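
Each compare op now documents and accepts the name argument, which is threaded into the output variable's name; the new test_attr_name case below relies on exactly that behavior. A short usage sketch (the chosen name string is illustrative):

    import paddle
    import paddle.fluid as fluid
    from paddle.fluid import Program, program_guard

    with program_guard(Program(), Program()):
        x = fluid.layers.data(name='x', shape=[4], dtype='int32')
        y = fluid.layers.data(name='y', shape=[4], dtype='int32')
        out = paddle.less_equal(x=x, y=y, name='my_less_equal')
        assert 'my_less_equal' in out.name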

@@ -20,6 +20,7 @@ import numpy
 import numpy as np
 import paddle
 import paddle.fluid as fluid
+import paddle.fluid.core as core
 from paddle.fluid import Program, program_guard
@@ -67,6 +68,49 @@ for _type_name in {'float32', 'float64', 'int32', 'int64'}:
     create_test_class('not_equal', _type_name, lambda _a, _b: _a != _b)
 
 
+def create_paddle_case(op_type, callback):
+    class PaddleCls(unittest.TestCase):
+        def setUp(self):
+            self.op_type = op_type
+            self.input_x = np.array([1, 2, 3, 4])
+            self.input_y = np.array([1, 3, 2, 4])
+            self.real_result = callback(self.input_x, self.input_y)
+
+        def test_api(self):
+            with program_guard(Program(), Program()):
+                x = fluid.layers.data(name='x', shape=[4], dtype='int64')
+                y = fluid.layers.data(name='y', shape=[4], dtype='int64')
+                op = eval("paddle.%s" % (self.op_type))
+                out = op(x, y)
+                place = fluid.CPUPlace()
+                if core.is_compiled_with_cuda():
+                    place = paddle.CUDAPlace(0)
+                exe = fluid.Executor(place)
+                res, = exe.run(feed={"x": self.input_x,
+                                     "y": self.input_y},
+                               fetch_list=[out])
+            self.assertEqual((res == self.real_result).all(), True)
+
+        def test_attr_name(self):
+            with program_guard(Program(), Program()):
+                x = fluid.layers.data(name='x', shape=[4], dtype='int32')
+                y = fluid.layers.data(name='y', shape=[4], dtype='int32')
+                op = eval("paddle.%s" % (self.op_type))
+                out = op(x=x, y=y, name="name_%s" % (self.op_type))
+            self.assertEqual("name_%s" % (self.op_type) in out.name, True)
+
+    cls_name = "TestCase_{}".format(op_type)
+    PaddleCls.__name__ = cls_name
+    globals()[cls_name] = PaddleCls
+
+
+create_paddle_case('less_equal', lambda _a, _b: _a <= _b)
+create_paddle_case('greater_than', lambda _a, _b: _a > _b)
+create_paddle_case('greater_equal', lambda _a, _b: _a >= _b)
+create_paddle_case('equal', lambda _a, _b: _a == _b)
+create_paddle_case('not_equal', lambda _a, _b: _a != _b)
+
+
 class TestCompareOpError(unittest.TestCase):
     def test_errors(self):
         with program_guard(Program(), Program()):
@@ -82,7 +126,7 @@ class API_TestElementwise_Equal(unittest.TestCase):
         with fluid.program_guard(fluid.Program(), fluid.Program()):
             label = fluid.layers.assign(np.array([3, 3], dtype="int32"))
             limit = fluid.layers.assign(np.array([3, 2], dtype="int32"))
-            out = paddle.elementwise_equal(x=label, y=limit)
+            out = paddle.equal(x=label, y=limit)
             place = fluid.CPUPlace()
             exe = fluid.Executor(place)
             res, = exe.run(fetch_list=[out])
@@ -91,7 +135,7 @@ class API_TestElementwise_Equal(unittest.TestCase):
         with fluid.program_guard(fluid.Program(), fluid.Program()):
             label = fluid.layers.assign(np.array([3, 3], dtype="int32"))
             limit = fluid.layers.assign(np.array([3, 3], dtype="int32"))
-            out = paddle.elementwise_equal(x=label, y=limit)
+            out = paddle.equal(x=label, y=limit)
             place = fluid.CPUPlace()
             exe = fluid.Executor(place)
             res, = exe.run(fetch_list=[out])

@@ -22,30 +22,29 @@ import paddle.fluid as fluid
 from paddle.fluid import Program, program_guard
 
 
-def create_test_broadcast_class(op_type, args, callback):
+def create_test_not_equal_class(op_type, typename, callback):
     class Cls(op_test.OpTest):
         def setUp(self):
-            x = np.random.random(size=args['x_size']).astype('int32')
-            y = np.random.random(size=args['y_size']).astype('int32')
+            x = np.random.random(size=(10, 7)).astype(typename)
+            y = np.random.random(size=(10, 7)).astype(typename)
             z = callback(x, y)
             self.inputs = {'X': x, 'Y': y}
             self.outputs = {'Out': z}
             self.op_type = op_type
-            self.axis = args['axis']
 
         def test_output(self):
             self.check_output()
 
-    cls_name = "{0}_{1}".format(op_type, 'broadcast')
+    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'not_equal_all')
     Cls.__name__ = cls_name
     globals()[cls_name] = Cls
 
 
-def create_test_not_equal_class(op_type, typename, callback):
+def create_test_not_shape_equal_class(op_type, typename, callback):
     class Cls(op_test.OpTest):
         def setUp(self):
             x = np.random.random(size=(10, 7)).astype(typename)
-            y = np.random.random(size=(10, 7)).astype(typename)
+            y = np.random.random(size=(10)).astype(typename)
             z = callback(x, y)
             self.inputs = {'X': x, 'Y': y}
             self.outputs = {'Out': z}
@@ -54,7 +53,7 @@ def create_test_not_equal_class(op_type, typename, callback):
         def test_output(self):
             self.check_output()
 
-    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'not_equal')
+    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'not_shape_equal_all')
     Cls.__name__ = cls_name
     globals()[cls_name] = Cls
@@ -71,7 +70,7 @@ def create_test_equal_class(op_type, typename, callback):
         def test_output(self):
             self.check_output()
 
-    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'equal')
+    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'equal_all')
     Cls.__name__ = cls_name
     globals()[cls_name] = Cls
@@ -88,7 +87,7 @@ def create_test_dim1_class(op_type, typename, callback):
         def test_output(self):
             self.check_output()
 
-    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'equal')
+    cls_name = "{0}_{1}_{2}".format(op_type, typename, 'equal_all')
     Cls.__name__ = cls_name
     globals()[cls_name] = Cls
@@ -96,59 +95,16 @@ def create_test_dim1_class(op_type, typename, callback):
 np_equal = lambda _x, _y: np.array(np.array_equal(_x, _y))
 
 for _type_name in {'float32', 'float64', 'int32', 'int64'}:
-    create_test_not_equal_class('equal_reduce', _type_name, np_equal)
-    create_test_equal_class('equal_reduce', _type_name, np_equal)
-    create_test_dim1_class('equal_reduce', _type_name, np_equal)
-
-broadcast_args = [{
-    'x_size': (100, 2, 3),
-    'y_size': (100),
-    'axis': 0
-}, {
-    'x_size': (2, 100, 3),
-    'y_size': (100),
-    'axis': 1
-}, {
-    'x_size': (2, 3, 100),
-    'y_size': (1, 1),
-    'axis': -1
-}, {
-    'x_size': (2, 10, 12, 3),
-    'y_size': (10, 12),
-    'axis': 1
-}, {
-    'x_size': (100, 2, 3, 4),
-    'y_size': (100, 1),
-    'axis': 0
-}, {
-    'x_size': (10, 3, 12),
-    'y_size': (10, 1, 12),
-    'axis': -1
-}, {
-    'x_size': (2, 12, 3, 5),
-    'y_size': (2, 12, 1, 5),
-    'axis': -1
-}, {
-    'x_size': (2, 12, 3, 5),
-    'y_size': (3, 5),
-    'axis': 2
-}]
-
-
-def np_broadcast_equal(_x, _y):
-    res = np.all(np.equal(_x, _y))
-    return np.array(res)
-
-
-for args in broadcast_args:
-    create_test_broadcast_class('equal_reduce', args, np_broadcast_equal)
+    create_test_not_equal_class('equal_all', _type_name, np_equal)
+    create_test_equal_class('equal_all', _type_name, np_equal)
+    create_test_dim1_class('equal_all', _type_name, np_equal)
 
 
 class TestEqualReduceAPI(unittest.TestCase):
     def test_name(self):
         x = fluid.layers.assign(np.array([3, 4], dtype="int32"))
         y = fluid.layers.assign(np.array([3, 4], dtype="int32"))
-        out = paddle.equal(x, y, name='equal_res')
+        out = paddle.equal_all(x, y, name='equal_res')
         assert 'equal_res' in out.name

@@ -71,7 +71,7 @@ from .logic import not_equal  #DEFINE_ALIAS
 from .logic import reduce_all  #DEFINE_ALIAS
 from .logic import reduce_any  #DEFINE_ALIAS
 from .logic import allclose  #DEFINE_ALIAS
-from .logic import elementwise_equal  #DEFINE_ALIAS
+from .logic import equal_all  #DEFINE_ALIAS
 # from .logic import isnan  #DEFINE_ALIAS
 from .manipulation import cast  #DEFINE_ALIAS
 from .manipulation import concat  #DEFINE_ALIAS

File diff suppressed because it is too large