|
|
|
@ -17,6 +17,7 @@ limitations under the License. */
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_version_registry.h"
|
|
|
|
|
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -128,18 +129,28 @@ class CompareOp : public framework::OperatorWithKernel {
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
#define REGISTER_COMPARE_OP(op_type, _equation) \
|
|
|
|
|
struct _##op_type##Comment { \
|
|
|
|
|
static char type[]; \
|
|
|
|
|
static char equation[]; \
|
|
|
|
|
}; \
|
|
|
|
|
char _##op_type##Comment::type[]{#op_type}; \
|
|
|
|
|
char _##op_type##Comment::equation[]{_equation}; \
|
|
|
|
|
REGISTER_OPERATOR( \
|
|
|
|
|
op_type, ::paddle::operators::CompareOp<_##op_type##Comment>, \
|
|
|
|
|
::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
|
|
|
|
|
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
|
|
|
|
|
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
|
|
|
|
|
#define REGISTER_COMPARE_OP_VERSION(op_type) \
|
|
|
|
|
REGISTER_OP_VERSION(op_type) \
|
|
|
|
|
.AddCheckpoint( \
|
|
|
|
|
R"ROC(Upgrade compare ops, add a new attribute [force_cpu])ROC", \
|
|
|
|
|
paddle::framework::compatible::OpVersionDesc().NewAttr( \
|
|
|
|
|
"force_cpu", \
|
|
|
|
|
"In order to force fill output variable to cpu memory.", \
|
|
|
|
|
false));
|
|
|
|
|
|
|
|
|
|
#define REGISTER_COMPARE_OP(op_type, _equation) \
|
|
|
|
|
struct _##op_type##Comment { \
|
|
|
|
|
static char type[]; \
|
|
|
|
|
static char equation[]; \
|
|
|
|
|
}; \
|
|
|
|
|
char _##op_type##Comment::type[]{#op_type}; \
|
|
|
|
|
char _##op_type##Comment::equation[]{_equation}; \
|
|
|
|
|
REGISTER_OPERATOR( \
|
|
|
|
|
op_type, ::paddle::operators::CompareOp<_##op_type##Comment>, \
|
|
|
|
|
::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>, \
|
|
|
|
|
::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>, \
|
|
|
|
|
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>); \
|
|
|
|
|
REGISTER_COMPARE_OP_VERSION(op_type);
|
|
|
|
|
|
|
|
|
|
REGISTER_COMPARE_OP(less_than, "Out = X < Y");
|
|
|
|
|
REGISTER_COMPARE_KERNEL(less_than, CPU, paddle::operators::LessThanFunctor);
|
|
|
|
|