|
|
|
@ -14,6 +14,7 @@
|
|
|
|
|
|
|
|
|
|
#include "paddle/operators/compare_op.h"
|
|
|
|
|
#include "paddle/framework/op_registry.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
template <typename OpComment>
|
|
|
|
@ -61,19 +62,34 @@ class CompareOpInferShape : public framework::InferShapeBase {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class CompareOp : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
framework::OpKernelType GetKernelType(
|
|
|
|
|
const framework::ExecutionContext &ctx) const override {
|
|
|
|
|
framework::OpKernelType kt = OperatorWithKernel::GetKernelType(ctx);
|
|
|
|
|
// CompareOp kernel's device type is decided by input tensor place
|
|
|
|
|
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
|
|
|
|
|
return kt;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
#define REGISTER_LOGICAL_OP(op_type, _equation) \
|
|
|
|
|
struct _##op_type##Comment { \
|
|
|
|
|
static char type[]; \
|
|
|
|
|
static char equation[]; \
|
|
|
|
|
}; \
|
|
|
|
|
char _##op_type##Comment::type[]{#op_type}; \
|
|
|
|
|
char _##op_type##Comment::equation[]{_equation}; \
|
|
|
|
|
REGISTER_OP_WITH_KERNEL( \
|
|
|
|
|
op_type, ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
|
|
|
|
|
::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
|
|
|
|
|
#define REGISTER_LOGICAL_OP(op_type, _equation) \
|
|
|
|
|
struct _##op_type##Comment { \
|
|
|
|
|
static char type[]; \
|
|
|
|
|
static char equation[]; \
|
|
|
|
|
}; \
|
|
|
|
|
char _##op_type##Comment::type[]{#op_type}; \
|
|
|
|
|
char _##op_type##Comment::equation[]{_equation}; \
|
|
|
|
|
REGISTER_OPERATOR( \
|
|
|
|
|
op_type, ::paddle::operators::CompareOp, \
|
|
|
|
|
::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
|
|
|
|
|
::paddle::operators::CompareOpInferShape<_##op_type##Comment>, \
|
|
|
|
|
::paddle::framework::EmptyGradOpMaker);
|
|
|
|
|
|
|
|
|
|
REGISTER_LOGICAL_OP(less_than, "Out = X < Y");
|
|
|
|
|