|
|
|
@ -83,7 +83,7 @@ class CompareOp : public framework::OperatorWithKernel {
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
#define REGISTER_LOGICAL_OP(op_type, _equation) \
|
|
|
|
|
#define REGISTER_COMPARE_OP(op_type, _equation) \
|
|
|
|
|
struct _##op_type##Comment { \
|
|
|
|
|
static char type[]; \
|
|
|
|
|
static char equation[]; \
|
|
|
|
@ -96,11 +96,17 @@ class CompareOp : public framework::OperatorWithKernel {
|
|
|
|
|
::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
|
|
|
|
|
::paddle::framework::EmptyGradOpMaker);
|
|
|
|
|
|
|
|
|
|
REGISTER_LOGICAL_OP(less_than, "Out = X < Y");
|
|
|
|
|
REGISTER_LOGICAL_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
|
|
|
|
|
REGISTER_LOGICAL_OP(less_equal, "Out = X <= Y");
|
|
|
|
|
REGISTER_LOGICAL_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
|
|
|
|
|
REGISTER_LOGICAL_OP(equal, "Out = X == Y");
|
|
|
|
|
REGISTER_LOGICAL_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
|
|
|
|
|
REGISTER_LOGICAL_OP(not_equal, "Out = X != Y");
|
|
|
|
|
REGISTER_LOGICAL_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor);
|
|
|
|
|
REGISTER_COMPARE_OP(less_than, "Out = X < Y");
|
|
|
|
|
REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
|
|
|
|
|
REGISTER_COMPARE_OP(less_equal, "Out = X <= Y");
|
|
|
|
|
REGISTER_COMPARE_KERNEL(less_equal, CPU, paddle::operators::LessEqualFunctor);
|
|
|
|
|
REGISTER_COMPARE_OP(greater_than, "Out = X > Y");
|
|
|
|
|
REGISTER_COMPARE_KERNEL(greater_than, CPU,
|
|
|
|
|
paddle::operators::GreaterThanFunctor);
|
|
|
|
|
REGISTER_COMPARE_OP(greater_equal, "Out = X >= Y");
|
|
|
|
|
REGISTER_COMPARE_KERNEL(greater_equal, CPU,
|
|
|
|
|
paddle::operators::GreaterEqualFunctor);
|
|
|
|
|
REGISTER_COMPARE_OP(equal, "Out = X == Y");
|
|
|
|
|
REGISTER_COMPARE_KERNEL(equal, CPU, paddle::operators::EqualFunctor);
|
|
|
|
|
REGISTER_COMPARE_OP(not_equal, "Out = X != Y");
|
|
|
|
|
REGISTER_COMPARE_KERNEL(not_equal, CPU, paddle::operators::NotEqualFunctor);
|
|
|
|
|