add not_equal

emailweixu-patch-1
qiaolongfei 7 years ago
parent 23ba79b16b
commit 6f78cb9969

@ -102,3 +102,5 @@ REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y");
REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_OP(equal, "Out = X == Y");
REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
REGISTER_LOGICAL_OP(not_equal, "Out = X != Y");
REGISTER_LOGICAL_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor);

@ -17,3 +17,4 @@ limitations under the License. */
REGISTER_LOGICAL_KERNEL(less_than, CUDA, paddle::operators::LessThanFunctor);
REGISTER_LOGICAL_KERNEL(less_equal, CUDA, paddle::operators::LessEqualFunctor);
REGISTER_LOGICAL_KERNEL(equal, CUDA, paddle::operators::EqualFunctor);
REGISTER_LOGICAL_KERNEL(not_equal, CUDA, paddle::operators::NotEqualFunctor);

@ -48,6 +48,14 @@ struct EqualFunctor {
}
};
template <typename T>
struct NotEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const {
return !EqualFunctor<T>()(a, b);
}
};
template <typename DeviceContext, typename Functor>
class CompareOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> {

@ -154,8 +154,9 @@ def monkey_patch_variable():
("__rpow__", "elementwise_pow", True),
# for logical compare
("__eq__", "equal", False),
("__ne__", "not_equal", False),
("__lt__", "less_than", False),
("__le__", "less_equal", False), ):
("__le__", "less_equal", False)):
setattr(Variable, method_name,
_elemwise_method_creator_(method_name, op_type, reverse))

@ -53,6 +53,7 @@ class TestPythonOperatorOverride(unittest.TestCase):
lambda _a, _b: _a > _b,
lambda _a, _b: _a <= _b,
lambda _a, _b: _a >= _b,
lambda _a, _b: _a != _b,
]
# places to check

Loading…
Cancel
Save